## L421_martin

'''
Computing the Haar scaling and wavelet arrays
'''

# Utilities
from cdi import *
import numpy as np
import matplotlib.pyplot as plt

# Synonyms
# Short aliases for the NumPy routines used throughout this script.
# (sqrt is intentionally not aliased here.)
Id, pi, cos, sin = np.eye, np.pi, np.cos, np.sin
array, zeros, mat, transpose = np.array, np.zeros, np.matrix, np.transpose
stack, splice = np.vstack, np.hstack   # vertical / horizontal concatenation
dot = np.dot
ls = np.linspace
# Plotting aliases: `view` shows an array as an image, `canvas` opens a figure.
view, canvas = plt.imshow, plt.figure


# Define a function that computes 
# the array of level r scaling vectors
# from the array V of level r-1 scaling vectors
def HaarV(V):
    """Return the level-r Haar scaling vectors built from the level r-1 ones.

    Row j of the result is ``(V[2j] + V[2j+1]) / sqrt(2)``, i.e. the
    normalized sum of each consecutive pair of rows of ``V``.

    Parameters:
        V: 2-D array (or ``np.matrix``) whose rows are the level r-1
           scaling vectors.  If it has an odd number of rows, the last
           row is ignored.

    Returns:
        An array with ``len(V) // 2`` rows.  When that count is 1, the single
        row is returned on its own (a 1-D vector for an ndarray input), as the
        earlier loop-based implementation did; callers rely on this.

    Raises:
        ValueError: if ``V`` is not 2-D or has fewer than two rows.
    """
    a1 = a2 = 2 ** -0.5
    if np.ndim(V) != 2 or len(V) < 2:
        raise ValueError("HaarV expects a 2-D array with at least two rows")
    m = len(V) // 2  # number of complete row pairs
    # Pair each even row with the next odd row in one vectorized pass,
    # instead of re-stacking the result on every loop iteration (quadratic).
    X = a1 * V[0:2 * m:2] + a2 * V[1:2 * m:2]
    return X[0] if m == 1 else X
    
# Worked example for N = 4: V0 is the 4 x 4 identity, whose rows are the
# level-0 scaling vectors; each HaarV step halves the number of rows.
V0 = Id(4)
V1 = HaarV(V0)  # 2 x 4 array of level-1 scaling vectors
V2 = HaarV(V1)  # the single level-2 scaling vector, returned as a 1-D array
    
    
# Define a function that computes
# the array of level r wavelet vectors
# from the array V of level r-1 scaling vectors
def HaarW(V):
    """Return the level-r Haar wavelet vectors built from level r-1 scaling vectors.

    Row j of the result is ``(V[2j] - V[2j+1]) / sqrt(2)``, i.e. the
    normalized difference of each consecutive pair of rows of ``V``.

    Parameters:
        V: 2-D array (or ``np.matrix``) whose rows are the level r-1
           *scaling* vectors.  If it has an odd number of rows, the last
           row is ignored.

    Returns:
        An array with ``len(V) // 2`` rows.  When that count is 1, the single
        row is returned on its own (a 1-D vector for an ndarray input), as the
        earlier loop-based implementation did; callers rely on this.

    Raises:
        ValueError: if ``V`` is not 2-D or has fewer than two rows.
    """
    a1 = a2 = 2 ** -0.5
    if np.ndim(V) != 2 or len(V) < 2:
        raise ValueError("HaarW expects a 2-D array with at least two rows")
    m = len(V) // 2  # number of complete row pairs
    # Even row minus the following odd row, for all pairs at once (avoids the
    # quadratic re-stacking of the previous loop).
    X = a1 * V[0:2 * m:2] - a2 * V[1:2 * m:2]
    return X[0] if m == 1 else X
    
# Level-1 and level-2 wavelets for N = 4.  Each level's wavelets come from the
# previous level's *scaling* array, so W2 uses V1 (not W1).
W1, W2 = HaarW(V0), HaarW(V1)
    

# Define a function that computes
# the list of all scaling arrays for N = 2**n, that is,
# its r-th component is the matrix of level r
# scaling vectors: [V^0, V^1, ..., V^n]
def HaarVA(N):
    """Return the list of all Haar scaling arrays for signals of length N.

    The result is ``[V^0, V^1, ..., V^n]`` with ``N == 2**n``.  ``V^0`` is the
    N x N identity, and each following entry is obtained from the previous one
    with ``HaarV``, so ``V^r`` has ``N / 2**r`` rows.  The last level consists
    of a single scaling vector ``v``, and it is stored as the one-element
    list ``[v]``.

    Raises:
        ValueError: if N is not a power of two with N >= 2.  (Previously such
            inputs either crashed with an IndexError or silently dropped
            trailing samples.)
    """
    # N & (N - 1) == 0 exactly when N is a power of two.
    if N < 2 or N & (N - 1):
        raise ValueError(
            "HaarVA needs N to be a power of two >= 2, got {!r}".format(N))
    V = Id(N)
    X = [V]
    # Loop on the row count of the current 2-D array, down to two rows.
    while len(V) > 2:
        V = HaarV(V)
        X.append(V)
    # Final step: two rows collapse into one (1-D) scaling vector; wrap it so
    # the last level still reads as a list of vectors.
    V = HaarV(V)
    X.append([V])
    return X
   
# Short alias for HaarVA, evaluated once for N = 4.  In a plain script the
# returned list is discarded; it is only displayed in an interactive session.
VA = HaarVA
VA(4)


# Define a function that computes a pair
# formed by the list of all scaling arrays
# and the list of all wavelet arrays
def HaarVWA(N):
    """Return the pair (scaling arrays, wavelet arrays) for length N = 2**n.

    The first component is the same list that ``HaarVA(N)`` returns:
    ``[V^0, V^1, ..., V^n]``.  The second is ``[W^1, ..., W^n]`` where
    ``W^r = HaarW(V^(r-1))``.  In both lists the last level consists of a
    single vector, and it is stored wrapped in a one-element list.

    Raises:
        ValueError: if N is not a power of two with N >= 2.  (Previously such
            inputs either crashed with an IndexError or silently dropped
            trailing samples.)
    """
    # N & (N - 1) == 0 exactly when N is a power of two.
    if N < 2 or N & (N - 1):
        raise ValueError(
            "HaarVWA needs N to be a power of two >= 2, got {!r}".format(N))
    V = Id(N)
    X = [V]
    Y = []
    while len(V) > 2:
        # The next wavelets come from the current *scaling* array, so compute
        # W before V is overwritten.
        W = HaarW(V)
        V = HaarV(V)
        X.append(V)
        Y.append(W)
    # Last step: V has two rows, so HaarW/HaarV each yield one 1-D vector.
    W = HaarW(V)
    V = HaarV(V)
    X.append([V])
    Y.append([W])
    return (X, Y)
    
# Short alias for HaarVWA, evaluated once for N = 4.  In a plain script the
# returned pair is discarded; it is only displayed in an interactive session.
VWA = HaarVWA
VWA(4)
    

# Worked example for N = 8: build the level-2 scaling array (2 x 8) and
# project the signal f onto the span of its (orthonormal) rows.
V0 = Id(8)
V1 = HaarV(V0)
V2 = HaarV(V1)
f = [1, 2, 3, 4, 5, 6, 7, 8]
# Sum of <f, v> * v over the rows v of V2.
x = sum(dot(f, v) * v for v in V2)
print(x)  # projection of f onto level-2 scaling space