## A505_martin
## David Martín Alaminos

import numpy as np
import matplotlib.pyplot as plt
from cdi import *


# Synonyms: short aliases for the NumPy routines used throughout this file
Id, sqrt, array = np.eye, np.sqrt, np.array
stack, splice = np.vstack, np.hstack  # stack rows vertically / concatenate side by side
dot, ls, zeros, mat = np.dot, np.linspace, np.zeros, np.matrix

# Basic constants
nd = 4  # not used in the code shown; presumably the number of Daub4 filter taps (a1..a4)
r2, r3 = sqrt(2), sqrt(3)
# Daub4 scaling numbers; they sum to sqrt(2) and their squares sum to 1.
a1, a2, a3, a4 = (
    (1 + r3) / (4*r2),
    (3 + r3) / (4*r2),
    (3 - r3) / (4*r2),
    (1 - r3) / (4*r2),
)
# Daub4 wavelet numbers: the scaling numbers reversed, with alternating signs.
b1, b2, b3, b4 = a4, -a3, a2, -a1



'''
I) Daub4 trend, fluctuation and transform: D4trend_rec, D4fluct_rec (recursive), D4trend, D4fluct (iterative), D4
'''

def D4trend_rec(f, r=1):
    """Recursive level-r Daub4 trend of the signal f.

    One level maps f (length N) to the N//2 values
        a1*f[2k] + a2*f[2k+1] + a3*f[2k+2] + a4*f[2k+3],
    where the last value wraps around to f[0], f[1] (periodic extension).
    Level r > 1 applies level r-1 to the level-1 trend.

    Returns array(f) when r == 0, and an error message string (not an
    exception) when len(f) is not divisible by 2**r.
    """
    N = len(f)
    f = list(f)
    if r == 0:
        return array(f)
    if N % 2**r:
        return "D4trend_rec: %d is not divisible by 2**%d " % (N, r)
    if r != 1:
        # Trend of level r = trend of level r-1 taken on the level-1 trend.
        return D4trend_rec(D4trend_rec(f), r - 1)
    padded = f + f[:2]  # periodic extension so indices 2k+2 and 2k+3 stay in range
    return array([a1*padded[k] + a2*padded[k + 1] + a3*padded[k + 2] + a4*padded[k + 3]
                  for k in range(0, 2*(N//2), 2)])

def D4fluct_rec(f, r=1):
    """Recursive level-r Daub4 fluctuation of the signal f.

    One level maps f (length N) to the N//2 values
        b1*f[2k] + b2*f[2k+1] + b3*f[2k+2] + b4*f[2k+3],
    wrapping around to f[0], f[1] at the end. Level r > 1 is the level-1
    fluctuation of the level-(r-1) trend (from D4trend_rec).

    Returns a zero vector of length len(f) when r == 0, and an error message
    string (not an exception) when len(f) is not divisible by 2**r.
    """
    N = len(f)
    f = list(f)
    if r == 0:
        return zeros(N)
    if N % 2**r:
        return "D4fluct_rec: %d is not divisible by 2**%d " % (N, r)
    if r != 1:
        # Fluctuation of level r = one fluctuation step on the level-(r-1) trend.
        return D4fluct_rec(D4trend_rec(f, r - 1))
    padded = f + f[:2]  # periodic extension: the last output reuses f[0], f[1]
    return array([
        b1*padded[k] + b2*padded[k + 1] + b3*padded[k + 2] + b4*padded[k + 3]
        for k in range(0, 2*(N//2), 2)
    ])

## Iterative versions
def D4trend(f, r=1):
    """Iterative level-r Daub4 trend of the signal f.

    Applies the one-level trend step r times. Each step maps a signal g of
    length m to the m//2 values
        a1*g[2k] + a2*g[2k+1] + a3*g[2k+2] + a4*g[2k+3],
    with indices taken modulo m (periodic wrap-around).

    Returns array(f) when r == 0, and an error message string (not an
    exception) when len(f) is not divisible by 2**r.
    """
    N = len(f)
    if r == 0:
        return array(f)
    if N % 2**r:
        return "D4trend: %d is not divisible by 2**%d " % (N, r)

    def one_level(g):
        # One periodic trend step on g.
        m = len(g)
        return array([a1*g[2*k] + a2*g[2*k + 1] + a3*g[(2*k + 2) % m] + a4*g[(2*k + 3) % m]
                      for k in range(m//2)])

    trend = f
    while r >= 1:
        trend = one_level(trend)
        r -= 1
    return trend

def D4fluct(f, r=1):
    """Iterative level-r Daub4 fluctuation of the signal f.

    Computes the level-(r-1) trend with D4trend, then applies one fluctuation
    step to it with periodic (mod length) indexing:
        b1*t[2k] + b2*t[2k+1] + b3*t[2k+2] + b4*t[2k+3].

    Returns a zero vector of length len(f) when r == 0, and an error message
    string (not an exception) when len(f) is not divisible by 2**r.
    """
    N = len(f)
    if r == 0:
        return zeros(N)
    if N % 2**r:
        return "D4fluct: %d is not divisible by 2**%d " % (N, r)

    trend = D4trend(f, r - 1)
    m = len(trend)

    def at(i):
        # Periodic access into the trend signal.
        return trend[i % m]

    return array([b1*at(k) + b2*at(k + 1) + b3*at(k + 2) + b4*at(k + 3)
                  for k in range(0, 2*(m//2), 2)])

def D4(f, r=1):
    """Level-r Daub4 transform of the signal f.

    The result is laid out as [a^r | d^r | d^(r-1) | ... | d^1]: the level-r
    trend followed by the fluctuations from the coarsest level down to
    level 1. Each level is computed from the previous trend with D4trend
    and D4fluct.

    Returns array(f) when r == 0, and an error message string (not an
    exception) when len(f) is not divisible by 2**r.
    """
    N = len(f)
    f = list(f)
    if r == 0:
        return array(f)
    if N % 2**r:
        return "D4: %d is not divisible by 2**%d " % (N, r)
    trend, fluct = f, []
    while r >= 1:
        # Newer (coarser) fluctuations are prepended to the ones already found.
        trend, fluct = D4trend(trend), splice([D4fluct(trend), fluct])
        r -= 1
    return splice([trend, fluct])



'''
II) Daub4 scaling and wavelet arrays
'''

# To construct the array of D4 level r scaling vectors
# from the array V of D4 level r-1 scaling vectors
def D4V(V):
    """Return the level-r Daub4 scaling vectors built from the level-(r-1) array V.

    V holds one vector per row. Row j of the result is
        a1*V[2j] + a2*V[2j+1] + a3*V[2j+2] + a4*V[2j+3],
    with row indices taken modulo len(V) (periodic wrap-around).

    When the result has a single row (len(V) is 2 or 3), that row is returned
    as a 1-D vector rather than a 1 x M array; D4VW wraps that top-level
    vector in a list.

    Raises ValueError if V is not 2-D or has fewer than two rows.
    """
    if np.ndim(V) != 2 or len(V) < 2:
        raise ValueError("D4V: V must be a 2-D array with at least 2 rows")
    N = len(V)
    # First input row combined into each output row: 0, 2, 4, ...
    even = np.arange(0, 2*(N//2), 2)
    # Build every output row in one pass with integer-array indexing. Stacking
    # row by row re-copied the growing result each time (quadratic cost).
    X = a1*V[even] + a2*V[even + 1] + a3*V[(even + 2) % N] + a4*V[(even + 3) % N]
    return X[0] if len(X) == 1 else X

# To construct the array of D4 level r wavelet vectors
# from the array V of D4 level r-1 scaling vectors
def D4W(V):
    """Return the level-r Daub4 wavelet vectors built from the level-(r-1) scaling array V.

    V holds one vector per row. Row j of the result is
        b1*V[2j] + b2*V[2j+1] + b3*V[2j+2] + b4*V[2j+3],
    with row indices taken modulo len(V) (periodic wrap-around).

    When the result has a single row (len(V) is 2 or 3), that row is returned
    as a 1-D vector rather than a 1 x M array; D4VW wraps that top-level
    vector in a list.

    Raises ValueError if V is not 2-D or has fewer than two rows.
    """
    if np.ndim(V) != 2 or len(V) < 2:
        raise ValueError("D4W: V must be a 2-D array with at least 2 rows")
    N = len(V)
    # First input row combined into each output row: 0, 2, 4, ...
    even = np.arange(0, 2*(N//2), 2)
    # Build every output row in one pass with integer-array indexing. Stacking
    # row by row re-copied the growing result each time (quadratic cost).
    Y = b1*V[even] + b2*V[even + 1] + b3*V[(even + 2) % N] + b4*V[(even + 3) % N]
    return Y[0] if len(Y) == 1 else Y

# To construct the pair formed with the array V
# of all D4 scale vectors and the array W of all
# D4 wavelet vectors.
def D4VW(N):
    """Return (X, Y): all Daub4 scaling arrays X and wavelet arrays Y for length N.

    X[k] holds the level-k scaling vectors as rows (X[0] is the N x N
    identity), and Y[k] holds the level-(k+1) wavelet vectors. At the final
    step D4V/D4W produce a single bare 1-D vector. The last entry of each
    list is therefore a one-element list wrapping that vector, so X[-1][0]
    and Y[-1][0] index like the other levels.
    Presumably meant for N a power of 2 (TODO confirm); N < 2 raises an error.
    """
    V = Id(N)
    scaling, wavelets = [V], []
    # Keep stepping up a level while more than two scaling vectors remain.
    while N > 2:
        W, V = D4W(V), D4V(V)
        scaling.append(V)
        wavelets.append(W)
        N = len(V)
    # Final step from the last two scaling vectors.
    W, V = D4W(V), D4V(V)
    scaling.append([V])
    wavelets.append([W])
    return (scaling, wavelets)


# Orthogonal projection in orthonormal basis
def proj(f, V):
    """Return sum_k (f . V[k]) * V[k], accumulated in row order.

    This is the orthogonal projection of f onto the span of the rows of V
    when those rows are orthonormal (assumed by the caller, TODO confirm).
    V must contain at least one vector; its length sets the result length.
    """
    acc = zeros(len(V[0]))
    for row in V:
        coeff = dot(f, row)
        acc = acc + coeff*row
    return acc

# Projection coefficients
def proj_coeffs(f, V):
    """Return the array of dot products f . v for each vector v in V (in order)."""
    coeffs = [dot(f, basis_vec) for basis_vec in V]
    return array(coeffs)

## The Daub4 = D4 Transform using D4V and D4W
def D4T(f, r=1):
    """Return the level-r Daub4 transform of f computed from explicit basis arrays.

    Starting from the N x N identity, D4W and D4V give the wavelet and
    scaling vectors of each successive level. The transform values are the
    dot products of f with those vectors (proj_coeffs). The layout matches
    D4: [a^r | d^r | d^(r-1) | ... | d^1]. By linearity this equals D4(f, r)
    up to floating-point rounding.

    Like D4, returns array(f) when r == 0, and an error message string (not an
    exception) when len(f) is not divisible by 2**r. It builds N x N arrays,
    so it suits checking and illustration rather than long signals.
    """
    N = len(f)
    f = array(f)
    if r == 0 or N == 0:
        # Nothing to transform; D4 also yields an empty array for N == 0.
        return f
    if N % 2**r:
        return "D4T: %d is not divisible by 2**%d " % (N, r)
    V = Id(N)
    d = []
    while r >= 1:
        # D4V/D4W return a bare 1-D vector when only two rows remain;
        # atleast_2d keeps V and W as one-vector-per-row arrays.
        W = np.atleast_2d(D4W(V))
        V = np.atleast_2d(D4V(V))
        # Coarser fluctuations go in front of the finer ones already collected.
        d = splice([proj_coeffs(f, W), d])
        r -= 1
    return splice([proj_coeffs(f, V), d])



'''
Assignment tasks begin here.

Draw graphics of some Daub4 scaling and wavelet vectors of different
levels. You can adapt the analogous exercise done for the Haar wavelets.
'''


def plot_vectors(L, t='b-', offset=1.0, yini=0.0, lab=""):
    """Plot the vectors in L as vertically stacked curves on the current figure.

    Vector i is shifted up by yini + offset*i and drawn with format string t.
    Only the first curve receives the label lab; the others get no explicit
    label.

    Returns the shift of the last curve, yini + offset*(len(L) - 1), so a
    following call can start above it. For an empty L nothing is drawn and
    yini - offset is returned (no curve was placed at yini).
    """
    for i, vec in enumerate(L):
        if i == 0:
            plt.plot(yini + offset*i + vec, t, label=lab)
        else:
            plt.plot(yini + offset*i + vec, t)
    # Computed from len(L) so an empty L no longer hits an unbound loop variable.
    return yini + offset*(len(L) - 1)

n = 10
N = 2**n
# From D4VW: V[k] holds the level-k scaling vectors (V[0] is the identity),
# while W[k] holds the level-(k+1) wavelet vectors.
V, W = D4VW(N)

plt.close('all')

# Scaling vectors: levels 5, 6 and 8 (V[5], V[6], V[8]).
plt.figure("Examples of scaling vectors")
# The title must name the levels actually drawn (and labelled) below.
plt.title("Examples of scaling vectors of 5th, 6th and 8th level, $N=2^{%d}$" % n, color='black')
plt.xlim(0, N)

lasty = plot_vectors([V[5][0], V[5][15], V[5][31]], offset=0.5, lab="Level 5")
lasty = plot_vectors([V[6][0], V[6][7], V[6][15]], t='r-', offset=0.5, yini=0.5 + lasty, lab="Level 6")
plot_vectors([V[8][0], V[8][1], V[8][2], V[8][3]], t='g-', offset=0.5, yini=0.5 + lasty, lab="Level 8")
plt.legend(loc=4, ncol=3, mode="expand")


# Wavelet vectors: levels 5, 6 and 8 (W is shifted by one: W[4], W[5], W[7]).
plt.figure("Examples of wavelet vectors")
plt.title("Examples of wavelet vectors of 5th, 6th and 8th level, $N=2^{%d}$" % n, color='black')
plt.xlim(0, N)

lasty = plot_vectors([W[4][0], W[4][15], W[4][31]], offset=0.5, lab="Level 5")
lasty = plot_vectors([W[5][0], W[5][7], W[5][15]], t='r-', offset=0.5, yini=0.5 + lasty, lab="Level 6")
plot_vectors([W[7][0], W[7][1], W[7][2], W[7][3]], t='g-', offset=0.5, yini=0.5 + lasty, lab="Level 8")
plt.legend(loc=4, ncol=3, mode="expand")


plt.show()

