## A505
## Joan Gines i Ametlle

#import sys
#sys.path.append('C:/Users/Joan/EI/UPC/Q8/CDI/module')
import numpy as np
from math import *
from matplotlib.pyplot import *

# Synonyms: short aliases for the NumPy routines used in this module.
Id = np.eye          # identity matrix
sqrt = np.sqrt       # rebinds the math.sqrt brought in by `from math import *`
array = np.array
stack = np.vstack    # stack arrays row-wise (vertically)
splice = np.hstack   # join arrays column-wise (horizontally)
dot = np.dot
ls = np.linspace
zeros = np.zeros
mat = np.matrix      # NOTE(review): NumPy discourages np.matrix; prefer plain arrays

# Basic constants
nd = 4  # presumably the number of D4 filter taps; not referenced by the code in this file
r2, r3 = sqrt(2), sqrt(3)

# D4 scaling-filter taps (combined by D4V):
# (1 + sqrt3)/(4 sqrt2), (3 + sqrt3)/(4 sqrt2), (3 - sqrt3)/(4 sqrt2), (1 - sqrt3)/(4 sqrt2)
a1 = (1 + r3) / (4 * r2)
a2 = (3 + r3) / (4 * r2)
a3 = (3 - r3) / (4 * r2)
a4 = (1 - r3) / (4 * r2)

# D4 wavelet-filter taps (combined by D4W): the scaling taps in reverse
# order with alternating signs.
b1, b2, b3, b4 = a4, -a3, a2, -a1


# Shared kernel of D4V and D4W: apply a four-tap filter to pairs of rows.
def _d4_combine(V, c1, c2, c3, c4):
    """Combine consecutive rows of V with the four taps c1..c4.

    With N = len(V), output row j (j = 0, ..., N//2 - 1) is

        c1*V[2j] + c2*V[2j+1] + c3*V[(2j+2) % N] + c4*V[(2j+3) % N]

    so the last rows wrap around cyclically to the first ones.

    Returns an array of shape (N//2, M) for a 2-D input of shape (N, M),
    except that when N//2 == 1 the single row is returned as a 1-D array of
    length M (D4VW relies on this for the final level).

    Raises ValueError if V is not 2-D or has fewer than two rows.
    """
    V = np.asanyarray(V)
    if V.ndim != 2 or len(V) < 2:
        raise ValueError('V must be a 2-D array with at least two rows, '
                         'got shape %s' % (V.shape,))
    N = len(V)
    idx = 2 * np.arange(N // 2)  # even row indices 0, 2, 4, ...
    # One vectorized expression instead of vstack-ing row by row (which
    # re-copied the growing result on every iteration). The sum is evaluated
    # left to right exactly as in the per-row formula.
    X = (c1 * V[idx] + c2 * V[idx + 1]
         + c3 * V[(idx + 2) % N] + c4 * V[(idx + 3) % N])
    # Historic convention: a single output row comes back as a 1-D vector.
    return X[0] if len(X) == 1 else X


def D4V(V):
    """Return the D4 level r scaling vectors built from the level r-1 scaling vectors V.

    V holds one vector per row. Row j of the result is
    a1*V[2j] + a2*V[2j+1] + a3*V[2j+2] + a4*V[2j+3], with row indices taken
    modulo len(V). A single resulting vector is returned as a 1-D array.
    """
    return _d4_combine(V, a1, a2, a3, a4)


def D4W(V):
    """Return the D4 level r wavelet vectors built from the level r-1 scaling vectors V.

    V holds one vector per row. Row j of the result is
    b1*V[2j] + b2*V[2j+1] + b3*V[2j+2] + b4*V[2j+3], with row indices taken
    modulo len(V). A single resulting vector is returned as a 1-D array.
    """
    return _d4_combine(V, b1, b2, b3, b4)

# To construct the pair formed with the list X of the D4 scaling
# vectors at every level and the list Y of the D4 wavelet vectors
# at every level.
def D4VW(N):
    """Build the D4 scaling and wavelet vectors of every level for length N.

    Parameters
    ----------
    N : int
        Vector length; must be a power of two and at least 2.

    Returns
    -------
    (X, Y) : tuple of two lists
        X[0] is the N x N identity (level-0 scaling vectors, one per row)
        and X[r] holds the level-r scaling vectors. Y[r-1] holds the level-r
        wavelet vectors. The last entry of each list is a one-element list
        containing the final level's single 1-D vector, so every entry can
        be indexed by vector number.

    Raises
    ------
    ValueError
        If N is smaller than 2 or not a power of two.
    """
    # Each level halves the number of vectors (D4V/D4W return len(V)//2
    # rows), so the cascade only ends cleanly at one vector when N is a
    # power of two. Other sizes either fail inside D4V/D4W or produce
    # levels whose sizes don't halve evenly.
    if N < 2 or N & (N - 1):
        raise ValueError('N must be a power of two >= 2, got %r' % (N,))
    V = Id(N)
    X = [V]
    Y = []
    while len(V) > 2:
        Y.append(D4W(V))  # wavelet vectors of the next level
        V = D4V(V)        # scaling vectors of the next level
        X.append(V)
    # Final level: from two vectors D4V/D4W yield a single 1-D vector,
    # wrapped in a list to keep the "indexable by vector number" layout.
    Y.append([D4W(V)])
    X.append([D4V(V)])
    return (X, Y)
    
    
# Plot Daub4 scaling and wavelet vectors at different levels
close('all')
lbl = [' wavelet', ' scaling']

N = 16
X = D4VW(N)  # X[0]: scaling levels (X[0][0] is the identity); X[1]: wavelet levels

# Plot (at most) the first three levels; fewer when N is too small to have them.
nlev = min(3, len(X[1]))

# i = 1 draws the scaling figure (X[0][1..nlev]),
# i = 0 draws the wavelet figure (X[1][0..nlev-1]).
for i in [1, 0]:
    fig, axes = subplots(nrows=nlev, ncols=1, squeeze=False)
    for j in range(nlev):
        ax = axes[j, 0]
        vectors = X[1 - i][j + i]
        length = len(vectors[0])
        ax.set_title('Level ' + str(j + 1) + lbl[i] + ' vectors')
        # Lay the vectors side by side: vector k occupies x = k*length ... (k+1)*length - 1.
        for k, vec in enumerate(vectors):
            ax.plot(range(k * length, (k + 1) * length), vec, lw=2)
        ax.grid(True)
    # Lay out only after titles and data exist, otherwise titles overlap.
    fig.tight_layout()

show()

