## A505_dominguez

import numpy as np
import matplotlib.pyplot as plt

# Synonyms: short aliases for the NumPy / Matplotlib routines used below.
Id, sqrt, array = np.eye, np.sqrt, np.array
stack, splice = np.vstack, np.hstack   # vertical / horizontal concatenation
dot, ls = np.dot, np.linspace
zeros, mat = np.zeros, np.matrix
canvas = plt.figure


# Basic constants
nd = 4
r2=sqrt(2); r3=sqrt(3);
a1=(1+r3)/(4*r2); a2=(3+r3)/(4*r2); 
a3=(3-r3)/(4*r2); a4=(1-r3)/(4*r2);
[b1,b2,b3,b4] = [a4,-a3,a2,-a1]

def D4trend(f, r=1):
    """Return the level-``r`` D4 trend (approximation) of the signal ``f``.

    Each level combines four consecutive samples with the scaling
    coefficients ``a1..a4``. Indices wrap around periodically. Only every
    second position is kept, so the signal length halves at each level.

    Parameters
    ----------
    f : sequence of numbers
        Input signal. Its length must be divisible by ``2**r``.
    r : int, optional
        Number of levels to apply (default 1). ``r == 0`` returns ``f`` as
        an array.

    Returns
    -------
    numpy.ndarray or str
        The trend, of length ``len(f) // 2**r``. If ``len(f)`` is not
        divisible by ``2**r``, an error-message string is returned instead.
        That is the existing calling convention, kept for compatibility.
    """
    N = len(f)
    if r == 0:
        return array(f)
    if N % 2**r:
        return "D4trend: %d is not divisible by 2**%d " % (N, r)
    f = array(f)
    while r >= 1:
        N = len(f)
        even = 2*np.arange(N//2)
        # Periodic boundary: indices past the end wrap around to the start.
        # The result must be kept (the previous version discarded it and
        # returned the input unchanged).
        f = a1*f[even] + a2*f[even + 1] + a3*f[(even + 2) % N] + a4*f[(even + 3) % N]
        r -= 1
    return f
        
def D4fluct(f, r=1):
    """Return the level-``r`` D4 fluctuation (detail) coefficients of ``f``.

    The level-``r`` fluctuation is obtained by applying the wavelet
    coefficients ``b1..b4`` to the level-``(r-1)`` trend of ``f``. Indices
    wrap around periodically, and only every second position is kept.

    Parameters
    ----------
    f : sequence of numbers
        Input signal. Its length must be divisible by ``2**r``.
    r : int, optional
        Level (default 1). ``r == 0`` returns ``zeros(len(f))``.

    Returns
    -------
    numpy.ndarray or str
        The fluctuation, of length ``len(f) // 2**r``. If ``len(f)`` is not
        divisible by ``2**r``, an error-message string is returned instead.
        That is the existing calling convention, kept for compatibility.
    """
    N = len(f)
    if r == 0:
        return zeros(N)
    if N % 2**r:
        return "D4fluct: %d is not divisible by 2**%d " % (N, r)
    # Level-r details come from the level-(r-1) trend. asarray makes the
    # indexing below work whatever sequence type D4trend hands back.
    a = np.asarray(D4trend(f, r - 1))
    N = len(a)
    even = 2*np.arange(N//2)
    # Periodic boundary: indices past the end wrap around to the start.
    # (The previous version referenced an undefined name `aa` here.)
    d = b1*a[even] + b2*a[even + 1] + b3*a[(even + 2) % N] + b4*a[(even + 3) % N]
    return d
    
def D4(f, r=1):
    """Return the ``r``-level D4 wavelet transform of ``f``.

    Each pass replaces the working signal by its D4 trend and prepends its
    D4 fluctuation to the collected details. The result is the final trend
    followed by the details, with the most recent (coarsest) level first.
    ``r == 0`` returns ``f`` as an array. If ``len(f)`` is not divisible by
    ``2**r``, an error-message string is returned instead of an array.
    """
    length = len(f)
    signal = list(f)
    if r == 0:
        return array(signal)
    if length % 2**r:
        return "D4: %d is not divisible by 2**%d " % (length, r)
    details = []
    remaining = r
    while remaining >= 1:
        trend = D4trend(signal)
        details = splice([D4fluct(signal), details])
        signal = trend
        remaining -= 1
    return splice([signal, details])

# To construct the array of D4 level r scaling vectors
# from the array V of D4 level r-1 scaling vectors
def D4V(V):
    """Build the next level of D4 scaling vectors from the rows of ``V``.

    Row ``j`` of the result is
    ``a1*V[2j] + a2*V[2j+1] + a3*V[2j+2] + a4*V[2j+3]``, with row indices
    taken modulo ``len(V)`` (periodic wrap-around).

    Parameters
    ----------
    V : 2-D numpy.ndarray
        Scaling vectors of the previous level, one per row.

    Returns
    -------
    numpy.ndarray
        ``len(V)//2`` rows. When that is a single row, it is returned as a
        1-D array, as the previous row-stacking implementation did. D4VW
        depends on this.
    """
    N = len(V)
    even = 2*np.arange(N//2)
    # One vectorised expression instead of repeated vstack, which re-copied
    # the growing result on every row (quadratic work). Same operation order,
    # so the values are identical.
    X = a1*V[even, :] + a2*V[even + 1, :] + a3*V[(even + 2) % N, :] + a4*V[(even + 3) % N, :]
    return X[0] if len(X) == 1 else X

# To construct the array of D4 level r wavelet vectors
# from the array V of D4 level r-1 scaling vectors
# analogous to HaarW(V)
def D4W(V):
    """Build the next level of D4 wavelet vectors from the scaling rows ``V``.

    Row ``j`` of the result is
    ``b1*V[2j] + b2*V[2j+1] + b3*V[2j+2] + b4*V[2j+3]``, with row indices
    taken modulo ``len(V)`` (periodic wrap-around).

    Parameters
    ----------
    V : 2-D numpy.ndarray
        Scaling vectors of the previous level, one per row.

    Returns
    -------
    numpy.ndarray
        ``len(V)//2`` rows. When that is a single row, it is returned as a
        1-D array, as the previous row-stacking implementation did. D4VW
        depends on this.
    """
    N = len(V)
    even = 2*np.arange(N//2)
    # Vectorised replacement for the row-by-row vstack loop (which copied
    # the growing result on each row). Same operation order, same values.
    Y = b1*V[even, :] + b2*V[even + 1, :] + b3*V[(even + 2) % N, :] + b4*V[(even + 3) % N, :]
    return Y[0] if len(Y) == 1 else Y

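# Construct the pair (X, Y) for signals of length N: X is the list of
# D4 scaling-vector arrays for every level (X[0] is the identity), and
# Y is the list of D4 wavelet-vector arrays for levels 1, 2, ...
# The last entry of each list is wrapped in a one-element list.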
def D4VW(N):
    """Return ``(X, Y)``: the D4 scaling and wavelet vectors of every level.

    ``X`` starts with ``Id(N)`` and then holds the result of each
    successive ``D4V`` application. ``Y`` holds the corresponding ``D4W``
    results, so ``Y[k-1]`` sits alongside ``X[k]``. Iteration stops once the
    scaling array has at most 2 rows. The final pair of ``D4W``/``D4V``
    results is appended wrapped in a one-element list. For a 2-row input
    those functions return 1-D arrays, so presumably the wrapper keeps
    ``X[-1][0]`` / ``Y[-1][0]`` indexing uniform.
    """
    scal = Id(N)
    scaling_levels = [scal]
    wavelet_levels = []
    size = N
    while size > 2:
        wav = D4W(scal)
        scal = D4V(scal)
        scaling_levels.append(scal)
        wavelet_levels.append(wav)
        size = len(scal)
    # Tuple RHS is evaluated left to right: D4W first, both on the old scal.
    wav, scal = D4W(scal), D4V(scal)
    scaling_levels.append([scal])
    wavelet_levels.append([wav])
    return (scaling_levels, wavelet_levels)
    
## ASSIGNMENT

n = 10; N = 2**n
V, W = D4VW(N)
# V[k] holds the level-k scaling vectors (V[0] is the identity), while
# W[k-1] holds the level-k wavelet vectors (there is no level-0 entry).

plt.close('all')
canvas("Examples of scaling vectors")
plt.xlim(0, N)

# level 5
X = [1, 8, 16]  # which level-5 vectors to plot (sets their horizontal position)
Y = [2.5, 2, 1.5]  # vertical offsets so the curves do not overlap
P5 = zip(X, Y)
for (x, y) in P5:
    plt.plot(y + V[5][x])

# level 6
X = [1, 4, 14]  # which level-6 vectors to plot (sets their horizontal position)
Y = [1, 0.5, 0]  # vertical offsets so the curves do not overlap
P6 = zip(X, Y)
for (x, y) in P6:
    plt.plot(y + V[6][x])

canvas("Examples of wavelets vectors")
plt.xlim(0, N)

# level 5 (stored at W[4], see the index note above)
X = [1, 8, 16]  # which level-5 wavelets to plot (sets their horizontal position)
Y = [2.5, 2, 1.5]  # vertical offsets so the curves do not overlap
P5 = zip(X, Y)
for (x, y) in P5:
    plt.plot(y + W[4][x])

# level 6 (stored at W[5])
X = [1, 4, 14]  # which level-6 wavelets to plot (sets their horizontal position)
Y = [1, 0.5, 0]  # vertical offsets so the curves do not overlap
P6 = zip(X, Y)
for (x, y) in P6:
    plt.plot(y + W[5][x])

# Without this, running the file as a plain script creates the figures
# but never displays them. In interactive sessions it returns immediately.
plt.show()
