## A505

##From L507:

import numpy as np
import matplotlib.pyplot as plt

# Synonyms: short aliases for the NumPy routines used throughout this exercise.
Id, sqrt, array = np.eye, np.sqrt, np.array
stack, splice = np.vstack, np.hstack
dot, ls, zeros = np.dot, np.linspace, np.zeros
mat = np.matrix


# Basic constants
nd = 4  # NOTE(review): unused in this chunk; presumably the D4 filter length -- confirm
r2 = np.sqrt(2)
r3 = np.sqrt(3)
# D4 scaling (low-pass) taps; they satisfy a1+a2+a3+a4 = sqrt(2) and
# a1**2+a2**2+a3**2+a4**2 = 1.
a1 = (1 + r3) / (4 * r2)
a2 = (3 + r3) / (4 * r2)
a3 = (3 - r3) / (4 * r2)
a4 = (1 - r3) / (4 * r2)
# D4 wavelet (high-pass) taps: the scaling taps reversed with alternating signs.
b1, b2, b3, b4 = a4, -a3, a2, -a1


def _d4_step(V, c1, c2, c3, c4):
    """Combine the rows of V with the 4-tap filter (c1, c2, c3, c4), stride 2, periodically.

    With N = len(V), row j (j = 0 .. N//2 - 1) of the result is
        c1*V[2j] + c2*V[2j+1] + c3*V[(2j+2) % N] + c4*V[(2j+3) % N].
    A single-row result is returned as a 1-D array rather than a 1 x M array.

    Raises ValueError if V is not 2-D or has fewer than 2 rows.
    """
    if np.ndim(V) != 2 or len(V) < 2:
        raise ValueError("V must be a 2-D array with at least 2 rows")
    N = len(V)
    j2 = 2 * np.arange(N // 2)  # even row indices 0, 2, ..., 2*(N//2 - 1)
    # Gather all needed rows at once with fancy indexing instead of growing the
    # result with repeated vstack calls (which copied O(N^2) data).
    X = c1 * V[j2] + c2 * V[j2 + 1] + c3 * V[(j2 + 2) % N] + c4 * V[(j2 + 3) % N]
    return X[0] if len(X) == 1 else X


# To construct the array of D4 level r scaling vectors
# from the array V of D4 level r-1 scaling vectors
def D4V(V):  # analogous to HaarV(V)
    """Return the D4 level-r scaling vectors (as rows) from the level r-1 array V.

    Row j is a1*V[2j] + a2*V[2j+1] + a3*V[2j+2] + a4*V[2j+3], row indices
    taken modulo len(V).  A single resulting row is returned as a 1-D array.
    """
    return _d4_step(V, a1, a2, a3, a4)


# To construct the array of D4 level r wavelet vectors
# from the array V of D4 level r-1 scaling vectors
def D4W(V):  # analogous to HaarW(V)
    """Return the D4 level-r wavelet vectors (as rows) from the level r-1 array V.

    Row j is b1*V[2j] + b2*V[2j+1] + b3*V[2j+2] + b4*V[2j+3], row indices
    taken modulo len(V).  A single resulting row is returned as a 1-D array.
    """
    return _d4_step(V, b1, b2, b3, b4)
# To construct the pair formed with the array V
# of all D4 scale vectors and the array W of all
# D4 wavelet vectors.
def D4VW(N):  # analogous to HaarVW
    """Return the D4 scaling and wavelet vectors of every level for length-N signals.

    Returns (X, Y):
      X[0] is Id(N); X[r] holds the level-r scaling vectors as rows.
      Y[r-1] holds the level-r wavelet vectors as rows.
    Each level halves the number of rows.  At the last level D4V/D4W yield a
    single 1-D vector; it is wrapped in a one-element list so that X[-1][0]
    and Y[-1][0] index it like a row.

    Raises ValueError if N is not a power of two >= 2 (otherwise the halving
    would hit an odd row count and fail part-way through).
    """
    if N < 2 or N & (N - 1):
        raise ValueError("N must be a power of two >= 2, got %r" % (N,))
    V = Id(N)
    X = [V]
    Y = []
    while len(V) > 2:
        Y.append(D4W(V))  # wavelets come from the current V, before it is replaced
        V = D4V(V)
        X.append(V)
    # Here len(V) == 2: the final step produces one 1-D vector each.
    Y.append([D4W(V)])
    X.append([D4V(V)])
    return (X, Y)
    
    
#From cdi-graphics
def write(X, T, ha='center', va='center', fs='16', rot=0.0):
    """Draw the text T at position (X[0], X[1]) with plt.text.

    ha / va: horizontal / vertical alignment; fs: font size; rot: rotation.
    Returns the Text object produced by plt.text.
    """
    x, y = X[0], X[1]
    return plt.text(
        x, y, T,
        horizontalalignment=ha,
        verticalalignment=va,
        fontsize=fs,
        rotation=rot,
    )


'''
Draw graphics of some Daub4 scaling and wavelet vectors of different levels. You can adapt the analogous exercise done for the Haar wavelets. 
'''
# NOTE(review): wildcard import from the course's cdi module; it may rebind
# names defined above (e.g. write) -- confirm.
from cdi import *
canvas = plt.figure
xlim = plt.xlim
plot = plt.plot

n = 10; N = 2**n
V, W = D4VW(N)


def _plot_first_vectors(title, per_level, levels, num_vectors):
    """Draw one figure with a panel per level, stacking that level's first vectors.

    per_level[i] is the collection of vectors of level levels[i].  Vector k
    (0-based) is drawn shifted up by k + 1 and labelled '# k+1' on the first panel.
    """
    canvas(title)
    ncols = len(levels)
    for col, (r, vectors) in enumerate(zip(levels, per_level), start=1):
        f = plt.subplot(1, ncols, col)
        f.grid(True)
        # Set the limits on each panel: a plt.xlim() call made before the
        # subplots only affects a separate full-figure Axes, not these panels.
        f.set_xlim(0, N)
        f.set_xticklabels([])
        f.set_yticklabels([])
        if col == 1:
            for num in range(1, num_vectors + 1):
                write((0, num), '# %d' % num, 'right', fs='10')
        write((0, -0.75), 'Level %d' % r, 'left')
        # Start at index 0 so the "first" vectors really are the first ones.
        for k in range(num_vectors):
            plot(k + 1 + vectors[k])


num_vectors = n       # how many vectors to draw per level
levels = range(1, 7)  # levels to show, one panel each

# V[r] holds the level-r scaling vectors; W[r-1] the level-r wavelet vectors.
_plot_first_vectors(
    "First %d scaling vectors of levels %d to %d " % (num_vectors, levels[0], levels[-1]),
    [V[r] for r in levels], levels, num_vectors)
_plot_first_vectors(
    "First %d wavelet vectors of levels %d to %d " % (num_vectors, levels[0], levels[-1]),
    [W[r - 1] for r in levels], levels, num_vectors)
