## cdi_haar.py, a tool module for Haar wavelets (T6)
## SXD 428


# Utilities
from cdi import *
import numpy as np
import matplotlib.pyplot as plt

# Synonyms: short NumPy aliases used by the functions below.
# `sqrt` is intentionally not aliased here (the old alias was commented out);
# NOTE(review): if a bare `sqrt` is needed it presumably comes from
# `from cdi import *` -- confirm.
Id, pi = np.eye, np.pi
cos, sin = np.cos, np.sin
array, zeros, mat = np.array, np.zeros, np.matrix
stack, splice = np.vstack, np.hstack   # vertical / horizontal concatenation
dot, transpose = np.dot, np.transpose
ls = np.linspace
det, inv = np.linalg.det, np.linalg.inv

# Short aliases for matplotlib.pyplot functions.
canvas, axis, show = plt.figure, plt.axis, plt.show
view, close = plt.imshow, plt.close
plot, xlim = plt.plot, plt.xlim


## Core functions 

def trend(f):
    """Return the Haar trend signal of the discrete signal ``f``.

    Entry ``j`` of the result is ``(f[2j] + f[2j+1]) / sqrt(2)``.  When
    ``len(f)`` is odd the unpaired last sample is ignored.

    Args:
        f: any sequence supporting ``len`` and integer indexing.

    Returns:
        A new list of length ``len(f) // 2``.
    """
    root2 = 1.4142135623730951  # sqrt(2)
    starts = range(0, 2 * (len(f) // 2), 2)  # index of the first sample of each pair
    return [(f[k] + f[k + 1]) / root2 for k in starts]

def fluct(f):
    """Return the Haar fluctuation (difference) signal of ``f``.

    Entry ``j`` of the result is ``(f[2j] - f[2j+1]) / sqrt(2)``.  When
    ``len(f)`` is odd the unpaired last sample is ignored.

    Args:
        f: any sequence supporting ``len`` and integer indexing.

    Returns:
        A new list of length ``len(f) // 2``.
    """
    root2 = 1.4142135623730951  # sqrt(2)
    pair_count = len(f) // 2
    return [(f[2 * i] - f[2 * i + 1]) / root2 for i in range(pair_count)]

# haar(f, r) computes the Haar transform of f of order r.
# haar(f) is equivalent to haar(f, 1).
def haar(f, r=1):
    """Return the order-``r`` Haar transform of the signal ``f``.

    The result is laid out as ``trend_r + fluct_r + fluct_(r-1) + ... +
    fluct_1``: the level-``r`` trend followed by the fluctuation signals
    from the coarsest level down to the first.  ``trend``/``fluct`` ignore
    an unpaired trailing sample, so when ``len(f)`` is not divisible by
    ``2**r`` the result is shorter than ``f``.

    Args:
        f: sequence supporting ``len`` and integer indexing.
        r: transform order (number of levels).  ``0`` returns ``f`` as is.

    Returns:
        ``f`` itself when ``r == 0``; otherwise a new list.

    Raises:
        ValueError: if ``r`` is negative.
    """
    if r == 0:
        return f
    if r < 0:
        # Without this guard the recursion below never reaches a base case.
        raise ValueError("haar: order r must be non-negative, got %r" % (r,))
    if r == 1:
        return trend(f) + fluct(f)
    # After r-1 levels the leading len(f)//2**(r-1) entries are the trend;
    # only that part is split again, the fluctuations are carried along.
    m = len(f) // 2**(r - 1)
    a = haar(f, r - 1)
    x = a[:m]
    return trend(x) + fluct(x) + a[m:]

def HaarV(V):
    """Return the level-``r`` Haar scaling vectors built from ``V``.

    ``V`` holds the level ``r-1`` scaling vectors as rows.  Row ``j`` of the
    result is ``(V[2j] + V[2j+1]) / sqrt(2)``.  An unpaired last row (odd
    row count) is ignored.

    Args:
        V: 2-D array whose rows are the level ``r-1`` scaling vectors.

    Returns:
        A 2-D array with ``len(V) // 2`` rows.  When there is exactly one
        row pair, the single resulting row is returned as a 1-D vector.

    Raises:
        ValueError: if ``V`` has fewer than two rows.
    """
    a1 = a2 = 2**(-1/2)
    pairs = len(V) // 2
    if pairs == 0:
        raise ValueError("HaarV: need at least two rows, got %d" % len(V))
    # Combine all row pairs in one vectorized step instead of growing the
    # result with vstack inside a loop (which re-copies it every iteration).
    X = a1 * V[0:2 * pairs:2, :] + a2 * V[1:2 * pairs:2, :]
    # Keep the historical return shape: one pair yields a 1-D vector.
    return X[0] if pairs == 1 else X
# HaarT(f, r) -- iterative order-r Haar transform (alias: D2).
def HaarT(f,r=1):
    """Return the order-``r`` Haar transform of ``f`` computed iteratively.

    The result is ``trend_r + fluct_r + fluct_(r-1) + ... + fluct_1`` as a
    list.  ``r == 0`` returns ``f`` unchanged.  If ``len(f)`` is not
    divisible by ``2**r``, no transform is computed and an explanatory
    message string is returned instead of a list.
    """
    if r == 0:
        return f
    approx = list(f)
    length = len(approx)
    if length % 2**r:
        return "trend: %d is not divisible by 2**%d " % (length, r)
    target = length // 2**r          # length of the final trend signal
    details = []
    while len(approx) > target:
        # Newer (coarser) fluctuations go in front of the finer ones.
        details = fluct(approx) + details
        approx = trend(approx)
    return approx + details
D2 = HaarT
    
def HaarW(V):
    """Return the level-``r`` Haar wavelet vectors built from ``V``.

    ``V`` holds the level ``r-1`` scaling vectors as rows.  Row ``j`` of the
    result is ``(V[2j] - V[2j+1]) / sqrt(2)``.  An unpaired last row (odd
    row count) is ignored.

    Args:
        V: 2-D array whose rows are the level ``r-1`` scaling vectors.

    Returns:
        A 2-D array with ``len(V) // 2`` rows.  When there is exactly one
        row pair, the single resulting row is returned as a 1-D vector.

    Raises:
        ValueError: if ``V`` has fewer than two rows.
    """
    a1 = a2 = 2**(-1/2)
    pairs = len(V) // 2
    if pairs == 0:
        raise ValueError("HaarW: need at least two rows, got %d" % len(V))
    # Difference all row pairs in one vectorized step instead of growing the
    # result with vstack inside a loop (which re-copies it every iteration).
    X = a1 * V[0:2 * pairs:2, :] - a2 * V[1:2 * pairs:2, :]
    # Keep the historical return shape: one pair yields a 1-D vector.
    return X[0] if pairs == 1 else X
    
    
def HaarVA(N):
    """Return the list of Haar scaling-vector arrays for every level on R^N.

    Element 0 is ``Id(N)`` (the level-0 scaling vectors).  Element ``k`` is
    the array of level-``k`` scaling vectors produced by ``HaarV``.  The
    final element is a one-item list wrapping the last level, which
    ``HaarV`` returns as a 1-D vector.

    NOTE(review): intended for N a power of two (>= 2).  If some level has
    exactly 3 rows, ``HaarV`` returns a 1-D vector inside the loop and the
    next ``HaarV`` call fails with IndexError.  N < 2 also fails inside
    ``HaarV``.
    """
    levels = [Id(N)]
    current = levels[0]
    size = N
    while size > 2:
        current = HaarV(current)
        levels.append(current)
        size = len(current)
    levels.append([HaarV(current)])
    return levels
VA = HaarVA

def HaarVWA(N):
    """Return ``(scaling, wavelets)``: Haar basis arrays for every level on R^N.

    ``scaling`` is laid out as follows:
      * element 0 is ``Id(N)``;
      * element ``k`` holds the level-``k`` scaling vectors from ``HaarV``;
      * the last element is a one-item list wrapping the final 1-D vector.

    ``wavelets`` is laid out as follows:
      * element ``k`` holds the level-``k+1`` wavelet vectors from ``HaarW``;
      * its last element is also a one-item list wrapping the final 1-D
        vector.

    NOTE(review): intended for N a power of two (>= 2).  If some level has
    exactly 3 rows, ``HaarV`` returns a 1-D vector inside the loop and the
    next call fails with IndexError.  N < 2 also fails in the helpers.
    """
    scaling = [Id(N)]
    wavelets = []
    current = scaling[0]
    size = N
    while size > 2:
        wavelets.append(HaarW(current))
        current = HaarV(current)
        scaling.append(current)
        size = len(current)
    last_w = HaarW(current)
    last_v = HaarV(current)
    scaling.append([last_v])
    wavelets.append([last_w])
    return (scaling, wavelets)
VWA = HaarVWA
