## L512 A512
## Albert Puente Encinas

'''
1. Complete L507 by defining D4T(f, r=1), which implements
the Daub4 (D4) transform using D4V and D4W (see D4T below)
'''

import numpy as np
from cdi_haar import *
from matplotlib.pyplot import *

# Synonyms: short aliases for the numpy functions used throughout this file
Id, sqrt, array = np.eye, np.sqrt, np.array
stack, splice = np.vstack, np.hstack
dot, ls, zeros, mat = np.dot, np.linspace, np.zeros, np.matrix
power, exp = np.power, np.exp

# Basic constants
nd = 4
r2, r3 = sqrt(2), sqrt(3)
# Daub4 scaling coefficients a1..a4 (used by D4trend / D4V)
a1, a2, a3, a4 = ((1 + r3) / (4*r2), (3 + r3) / (4*r2),
                  (3 - r3) / (4*r2), (1 - r3) / (4*r2))
# Daub4 wavelet coefficients: the a's reversed with alternating signs
b1, b2, b3, b4 = a4, -a3, a2, -a1

'''
I) D4trend, D4fluct, D4
'''

def D4trend_rec(f, r=1):
    """Return the level-r Daub4 trend of f, computed recursively.

    r == 0 returns array(f). If len(f) is not divisible by 2**r, an error
    message string is returned instead of an array. Sample indices wrap
    around periodically at the end of the signal.
    """
    N = len(f)
    samples = list(f)
    if r == 0:
        return array(samples)
    if N % 2**r:
        return "D4trend_rec: %d is not divisible by 2**%d " % (N, r)
    if r != 1:
        # Take one level first, then recurse on the half-length trend.
        return D4trend_rec(D4trend_rec(samples, 1), r - 1)
    # Periodic extension: positions N and N+1 read samples 0 and 1.
    ext = samples + samples[:2]
    return array([a1*ext[k] + a2*ext[k+1] + a3*ext[k+2] + a4*ext[k+3]
                  for k in range(0, N, 2)])
        
def D4fluct_rec(f, r=1):
    """Return the level-r Daub4 fluctuation of f, computed recursively.

    r == 0 returns zeros(len(f)). If len(f) is not divisible by 2**r, an
    error message string is returned instead of an array. For r > 1 the
    level r-1 trend is computed with D4trend_rec, then one fluctuation step
    is applied to it.
    """
    N = len(f)
    samples = list(f)
    if r == 0:
        return zeros(N)
    if N % 2**r:
        return "D4fluct_rec: %d is not divisible by 2**%d " % (N, r)
    if r != 1:
        return D4fluct_rec(D4trend_rec(samples, r - 1), 1)
    # Periodic extension: positions N and N+1 read samples 0 and 1.
    ext = samples + samples[:2]
    return array([b1*ext[k] + b2*ext[k+1] + b3*ext[k+2] + b4*ext[k+3]
                  for k in range(0, N, 2)])


def D4trend(f, r=1):
    """Return the level-r Daub4 trend of f (iterative version).

    Each pass halves the signal with the scaling coefficients a1..a4, with
    indices taken modulo the current length. r == 0 returns array(f). If
    len(f) is not divisible by 2**r, an error message string is returned.
    """
    size = len(f)
    if r == 0:
        return array(f)
    if size % 2**r:
        return "D4trend: %d is not divisible by 2**%d " % (size, r)
    trend = f
    remaining = r
    while remaining >= 1:
        m = len(trend)
        trend = array([a1*trend[k] + a2*trend[k+1]
                       + a3*trend[(k+2) % m] + a4*trend[(k+3) % m]
                       for k in range(0, 2*(m//2), 2)])
        remaining -= 1
    return trend
        
def D4fluct(f, r=1):
    """Return the level-r Daub4 fluctuation of f.

    Applies one high-pass step (coefficients b1..b4, periodic indices) to
    the level r-1 trend from D4trend. r == 0 returns zeros(len(f)). If
    len(f) is not divisible by 2**r, an error message string is returned.
    """
    if r == 0:
        return zeros(len(f))
    if len(f) % 2**r:
        return "D4fluct: %d is not divisible by 2**%d " % (len(f), r)
    trend = D4trend(f, r - 1)
    m = len(trend)
    return array([b1*trend[k] + b2*trend[k+1]
                  + b3*trend[(k+2) % m] + b4*trend[(k+3) % m]
                  for k in range(0, 2*(m//2), 2)])

def D4(f, r=1):
    """Return the level-r Daub4 transform of f as (a^r | d^r | ... | d^1).

    Each level replaces the current signal by its level-1 trend and puts its
    level-1 fluctuation in front of the details collected so far. r == 0
    returns array(f). If len(f) is not divisible by 2**r, an error message
    string is returned.
    """
    N = len(f)
    signal = list(f)
    if r == 0:
        return array(signal)
    if N % 2**r:
        return "D4: %d is not divisible by 2**%d " % (N, r)
    details = []
    level = r
    while level >= 1:
        # Newer (coarser) fluctuations go in front of the finer ones.
        details = splice([D4fluct(signal), details])
        signal = D4trend(signal)
        level -= 1
    return splice([signal, details])


'''
II) Daub4 scaling and wavelet arrays
'''

# To construct the array of D4 level r scaling vectors
# from the array V of D4 level r-1 scaling vectors
def D4V(V):   # analogous to HaarV(V)
    """Return the D4 scaling vectors one level coarser than the rows of V.

    Row j of the result is a1*V[2j] + a2*V[2j+1] + a3*V[2j+2] + a4*V[2j+3],
    with row indices taken modulo N = len(V) (periodic wrap-around).

    V: 2-D array with N >= 2 rows (the level r-1 scaling vectors).
    Returns a 2-D array with N//2 rows. When N//2 == 1 the single row is
    returned as a 1-D array, as D4VW expects.
    Raises ValueError if V has fewer than 2 rows.
    """
    N = len(V)
    if N < 2:
        raise ValueError("D4V: need at least 2 rows, got %d" % N)
    # All even row indices at once. One vectorized expression replaces the
    # per-row vstack, which re-copied the growing result (O(N^2) copies).
    j2 = 2*np.arange(N//2)
    X = a1*V[j2, :] + a2*V[j2 + 1, :] + a3*V[(j2 + 2) % N, :] + a4*V[(j2 + 3) % N, :]
    return X[0] if N//2 == 1 else X

# To construct the array of D4 level r wavelet vectors
# from the array V of D4 level r-1 scaling vectors
def D4W(V):  # analogous to HaarW(V)
    """Return the D4 wavelet vectors one level coarser than the rows of V.

    Row j of the result is b1*V[2j] + b2*V[2j+1] + b3*V[2j+2] + b4*V[2j+3],
    with row indices taken modulo N = len(V) (periodic wrap-around).

    V: 2-D array with N >= 2 rows (the level r-1 scaling vectors).
    Returns a 2-D array with N//2 rows. When N//2 == 1 the single row is
    returned as a 1-D array, as D4VW expects.
    Raises ValueError if V has fewer than 2 rows.
    """
    N = len(V)
    if N < 2:
        raise ValueError("D4W: need at least 2 rows, got %d" % N)
    # All even row indices at once. One vectorized expression replaces the
    # per-row vstack, which re-copied the growing result (O(N^2) copies).
    j2 = 2*np.arange(N//2)
    Y = b1*V[j2, :] + b2*V[j2 + 1, :] + b3*V[(j2 + 2) % N, :] + b4*V[(j2 + 3) % N, :]
    return Y[0] if N//2 == 1 else Y
    
# To construct the pair formed with the array V 
# of all D4 scale vectors and the array W of all
# D4 wavelet vectors.
def D4VW(N):  # analogous to HaarVW
    """Return (X, Y): the D4 scaling and wavelet vectors of every level.

    X[0] is Id(N), and X[k] holds the level-k scaling vectors as rows. Y[k-1]
    holds the level-k wavelet vectors. The last entry of each list is a
    one-element list wrapping the single 1-D vector that D4V/D4W return when
    two rows remain, so every entry can be iterated as a collection of
    vectors.

    N: signal length. It must be a power of two >= 2, because the bases are
    built by repeated halving down to one vector.
    Raises ValueError otherwise.
    """
    if N < 2 or N & (N - 1):
        raise ValueError("D4VW: N must be a power of 2 (>= 2), got %r" % (N,))
    V = Id(N)
    X = [V]
    Y = []
    while len(V) > 2:
        # W must be built from the previous level's V before V is replaced.
        W = D4W(V)
        V = D4V(V)
        X.append(V)
        Y.append(W)
    # Final step from two rows: D4V/D4W return a 1-D vector here.
    X.append([D4V(V)])
    Y.append([D4W(V)])
    return (X, Y)

# Orthogonal projection in orthonormal basis
def proj(f,V):
    """Return sum_v (f . v) v over the vectors v in V.

    This is the orthogonal projection of f onto span(V) when the vectors
    of V are orthonormal. V must be non-empty (the length of V[0] sizes
    the result).
    """
    return sum((dot(f, v)*v for v in V), zeros(len(V[0])))

# Projection coefficients
def proj_coeffs(f,V):
    """Return the array of inner products f . v, one per vector v in V."""
    coeffs = []
    for v in V:
        coeffs.append(dot(f, v))
    return array(coeffs)

# The Daub4 = D4 Transform using D4V and D4W    
def D4T(f, r = 1):
    """Return the level-r Daub4 (D4) transform of f, computed with D4V and D4W.

    The result is the concatenation (a^r | d^r | d^(r-1) | ... | d^1), the
    same layout as D4(f, r). Here a^k are the inner products of f with the
    level-k scaling vectors and d^k those with the level-k wavelet vectors.
    Both bases are built level by level from the identity basis with D4V and
    D4W.

    f: sequence of numbers of length N.
    r: number of levels. r == 0 (or an empty f) returns array(f).
    Returns a numpy array of length N. Like D4, it returns an error message
    string when N is not divisible by 2**r.
    """
    N = len(f)
    if r == 0 or N == 0:
        return array(f)
    if N % 2**r:
        return "D4T: %d is not divisible by 2**%d " % (N, r)
    V = Id(N)
    details = []
    for _ in range(r):
        # W must be built from the previous level's V before V is replaced.
        # D4V/D4W return a 1-D vector when a single row is produced (last
        # level when N == 2**r); atleast_2d keeps the result a set of rows.
        W = np.atleast_2d(D4W(V))
        V = np.atleast_2d(D4V(V))
        # Coarser details go in front: final order is d^r, ..., d^1.
        details.insert(0, proj_coeffs(f, W))
    return splice([proj_coeffs(f, V)] + details)


'''
2. Define functions: up_sample, dual, H4, G4, HF4 and LF4
'''

def up_sample(a):
    """Return a numpy array with a 0 inserted after every entry of a.

    For a = [t0, t1, ...] the result is [t0, 0, t1, 0, ...], twice as long.
    """
    return array([value for t in a for value in (t, 0)])

U_ = up_sample

def dual(h):
    """Return h reversed with alternating signs, as a list.

    dual([h0, h1, ..., hn]) == [hn, -h(n-1), h(n-2), ...]: the k-th entry of
    reversed(h) is multiplied by (-1)**k.
    """
    return [(-1)**k * t for k, t in enumerate(reversed(h))]
    
# Scaling filter (the Daub4 coefficients) and its dual wavelet filter,
# which equals [b1, b2, b3, b4].
a_ = [a1, a2, a3, a4]
b_ = dual(a_)

def H4(x):
    """Filter x with the scaling filter a_.

    NOTE(review): `filter` is presumably the filtering routine provided by
    the `cdi_haar` star import. The builtin filter would not accept a list
    as its first argument. Confirm.
    """
    return filter(a_, x)

def G4(x):
    """Filter x with the wavelet filter b_ (same `filter` as H4)."""
    return filter(b_, x)

def HF4(f, r = 1):
    """Return the level-r approximation of f computed by filtering.

    Starts from the level-r trend D4trend(f, r), then r times up-samples
    (U_) and applies the scaling filter (H4). r == 0 returns array(f). The
    script below compares the result against A(f, r).
    """
    if r != 0:
        approx = D4trend(f, r)
        for _level in range(r):
            approx = H4(U_(approx))
        return array(approx)
    return array(f)

def LF4(f, r = 1):
    """Return the level-r detail of f computed by filtering.

    Up-samples the level-r fluctuation D4fluct(f, r) and applies the wavelet
    filter (G4). It then up-samples and applies the scaling filter (H4) for
    the remaining r-1 levels. r == 0 returns zeros(len(f)). The script below
    compares the result against D(f, r).
    """
    if r != 0:
        detail = G4(U_(D4fluct(f, r)))
        for _level in range(r - 1):
            detail = H4(U_(detail))
        return array(detail)
    return zeros(len(f))
    
## A512 Albert Puente Encinas
'''
Compute A^r and D^r using multiresolution
'''

from matplotlib.pyplot import *


def A(f, r):
    """Level-r approximation: projection of f onto the level-r scaling vectors V[r]."""
    scaling = V[r]
    return proj(f, scaling)

def AF(f, r):
    """Level-r approximation computed with the filter-based HF4."""
    return array(HF4(f, r))

def D(f, r):
    """Level-r detail: projection of f onto the level-r wavelet vectors W[r-1]."""
    wavelets = W[r - 1]
    return proj(f, wavelets)

def DF(f, r):
    """Level-r detail computed with the filter-based LF4."""
    return array(LF4(f, r))

# Number of samples. D4VW builds its bases by repeated halving, so n is a
# power of two.
n = 1 << 10
V, W = D4VW(n)
# Sample points k/n for k = 0, ..., n-1.
X = np.arange(0.0, 1.0, 1/n)

def f(x):
    """Test signal 4 x^2 exp(-256 x^8 sin(40 x)^2), evaluated elementwise."""
    # Use numpy explicitly. `sin` is not among this file's synonyms, so the
    # bare name only resolved if one of the star imports happened to export it.
    return 4*np.power(x, 2)*np.exp(-256*np.power(x, 8)*np.power(np.sin(40*x), 2))

Y = f(X)

close('all')


def _plot_comparison(exact, fast, exact_name, fast_name,
                     levels=(1, 2, 4), colors=('g', 'r', 'b')):
    """Plot exact(Y, r) on the top row and fast(Y, r) on the bottom row.

    Draws one column per level r in `levels`, titled e.g. "A(f,1)" above
    "AF(f,1)". Returns (fig, axes), with axes a 2 x len(levels) array.
    """
    # Create the whole 2 x k grid in one call. Adding pyplot.subplot() cells
    # to a figure that already holds other axes leaves those axes visible
    # on current Matplotlib.
    fig, axes = subplots(nrows=2, ncols=len(levels), squeeze=False)
    fig.set_facecolor("w")
    rows = ((exact, exact_name), (fast, fast_name))
    for col, (r, color) in enumerate(zip(levels, colors)):
        for row, (fun, name) in enumerate(rows):
            ax = axes[row, col]
            ax.set_title("%s(f,%d)" % (name, r))
            ax.plot(X, fun(Y, r), color=color, lw=1)
            ax.grid(True)
    # tight_layout only has an effect once all axes and titles exist.
    fig.tight_layout()
    fig.show()
    return fig, axes


fig, axes = _plot_comparison(A, AF, "A", "AF")
fig, axes = _plot_comparison(D, DF, "D", "DF")

# Differences are at floating-point rounding level (around 1e-16 to 1e-15).
# Example: maximum difference between A and AF for r = 4
print ('Maximum difference between A and AF for r = 4: ', max(abs(A(Y,4) - AF(Y,4))))

