## A512
## Joan Gines i Ametlle

# Utilities
import sys
#sys.path.append('C:/Users/Joan/EI/UPC/Q8/CDI/module')
from cdi import *
from math import *
import numpy as np
import matplotlib.pyplot as plt

# Synonyms: short aliases for NumPy helpers used throughout the file.
# NOTE: this rebinds `sqrt` (star-imported from math above) to np.sqrt.
Id, sqrt, array = np.eye, np.sqrt, np.array
stack, splice, dot = np.vstack, np.hstack, np.dot
ls, zeros, mat = np.linspace, np.zeros, np.matrix

# Synonyms: short aliases for matplotlib.pyplot helpers.
canvas, axis, show = plt.figure, plt.axis, plt.show
view, close, plot, xlim = plt.imshow, plt.close, plt.plot, plt.xlim

# Basic constants
nd = 4  # presumably the D4 filter length (4 taps); not used in this file section

# Daubechies D4 scaling coefficients a1..a4 = (1+-sqrt3, 3+-sqrt3) / (4*sqrt2)
r2, r3 = sqrt(2), sqrt(3)
a1, a2 = (1 + r3)/(4*r2), (3 + r3)/(4*r2)
a3, a4 = (3 - r3)/(4*r2), (1 - r3)/(4*r2)

# D4 wavelet coefficients derived from the scaling ones.
# NOTE: b1..b4 are reassigned further down from dual(a_).
b1, b2, b3, b4 = a4, -a3, a2, -a1

'''
D4 transform: D4trend (level-r trend), D4fluct (level-r fluctuation), D4 (full level-r transform)
'''
def D4trend(f, r=1):
    """Return the level-r Daubechies D4 trend (approximation) of signal f.

    One level combines four consecutive samples with the scaling
    coefficients a1..a4, wrapping around the end of the signal (periodic
    extension), and keeps one output per pair of samples, so the length
    halves.  The level is applied r times.

    Parameters
    ----------
    f : sequence of numbers
        Input signal; len(f) must be divisible by 2**r.
    r : int, optional
        Number of levels (default 1).  r == 0 returns f as a new array.

    Returns
    -------
    numpy.ndarray
        The trend a^r, of length len(f) // 2**r.

    Raises
    ------
    ValueError
        If r is negative or len(f) is not divisible by 2**r.
    """
    N = len(f)
    if r == 0:
        return array(f)
    if r < 0:
        raise ValueError("D4trend: level r must be non-negative, got %d" % r)
    if N % 2**r:
        raise ValueError("D4trend: %d is not divisible by 2**%d " % (N, r))
    f = np.asarray(f)
    for _ in range(r):
        N = len(f)
        even = np.arange(0, N, 2)  # indices 0, 2, ..., N-2
        # Summed in the same left-to-right order as the scalar formula
        # a1*f[2j] + a2*f[2j+1] + a3*f[2j+2] + a4*f[2j+3] (indices mod N).
        f = (a1*f[even] + a2*f[even + 1]
             + a3*f[(even + 2) % N] + a4*f[(even + 3) % N])
    return f
        
def D4fluct(f, r=1):
    """Return the level-r Daubechies D4 fluctuation (detail) of signal f.

    The level r-1 trend of f (see D4trend) is combined with the wavelet
    coefficients b1..b4, wrapping around periodically, keeping one output
    per pair of samples.

    Parameters
    ----------
    f : sequence of numbers
        Input signal; len(f) must be divisible by 2**r.
    r : int, optional
        Level (default 1).  r == 0 returns zeros of length len(f).

    Returns
    -------
    numpy.ndarray
        The fluctuation d^r, of length len(f) // 2**r.

    Raises
    ------
    ValueError
        If r is negative or len(f) is not divisible by 2**r.
    """
    N = len(f)
    if r == 0:
        return zeros(N)
    if r < 0:
        raise ValueError("D4fluct: level r must be non-negative, got %d" % r)
    if N % 2**r:
        raise ValueError("D4fluct: %d is not divisible by 2**%d " % (N, r))
    a = np.asarray(D4trend(f, r - 1))
    N = len(a)
    even = np.arange(0, N, 2)  # indices 0, 2, ..., N-2
    # Same left-to-right summation order as the scalar formula (indices mod N).
    return (b1*a[even] + b2*a[even + 1]
            + b3*a[(even + 2) % N] + b4*a[(even + 3) % N])

def D4(f, r=1):
    """Return the level-r Daubechies D4 transform of signal f.

    The result is laid out as (a^r | d^r | d^(r-1) | ... | d^1): the final
    trend followed by the fluctuations from the coarsest level to the
    finest.

    Parameters
    ----------
    f : sequence of numbers
        Input signal; len(f) must be divisible by 2**r.
    r : int, optional
        Number of levels (default 1).  r == 0 returns f as a new array.

    Returns
    -------
    numpy.ndarray
        Array of the same length as f.

    Raises
    ------
    ValueError
        If r is negative or len(f) is not divisible by 2**r.
    """
    N = len(f)
    if r == 0:
        return array(f)
    if r < 0:
        raise ValueError("D4: level r must be non-negative, got %d" % r)
    if N % 2**r:
        raise ValueError("D4: %d is not divisible by 2**%d " % (N, r))
    trend = array(f)
    details = []  # d^1, d^2, ..., d^r (finest first)
    for _ in range(r):
        details.append(D4fluct(trend))
        trend = D4trend(trend)
    # Coarsest fluctuation goes right after the trend.
    return splice([trend] + details[::-1])

def _d4_rows(V, c1, c2, c3, c4):
    """Combine rows of V pairwise with a periodic 4-tap filter (c1..c4).

    Row k of the result is
        c1*V[2k] + c2*V[2k+1] + c3*V[(2k+2)%N] + c4*V[(2k+3)%N],
    for k = 0 .. N//2 - 1, where N = len(V).  The result is always 2-D with
    N//2 rows, even when N//2 == 1.
    """
    V = np.asarray(V)
    N = len(V)
    even = np.arange(0, 2*(N//2), 2)  # 0, 2, ..., one entry per output row
    # Same left-to-right summation order as the per-row formula.
    return (c1*V[even] + c2*V[(even + 1) % N]
            + c3*V[(even + 2) % N] + c4*V[(even + 3) % N])


def D4V(V):
    """Build the D4 level r scaling vectors from the level r-1 ones.

    V holds the level r-1 scaling vectors as rows (start from Id(N) for
    level 0).  The scaling coefficients a1..a4 are applied periodically;
    the result has len(V)//2 rows.
    """
    return _d4_rows(V, a1, a2, a3, a4)


def D4W(V):
    """Build the D4 level r wavelet vectors from the level r-1 scaling vectors.

    V holds the level r-1 scaling vectors as rows (start from Id(N) for
    level 0).  The wavelet coefficients b1..b4 are applied periodically;
    the result has len(V)//2 rows.
    """
    return _d4_rows(V, b1, b2, b3, b4)
    
'''
Define functions that implement the D4 filters
'''
# Filter taps.  a_ lists the D4 scaling coefficients; b_ is whatever
# cdi's dual() returns for them, and b1..b4 are rebound to its entries,
# replacing the values set in the constants section.
# NOTE(review): presumably dual() yields the matching wavelet taps — confirm in cdi.
a_ = [a1, a2, a3, a4]
b_ = dual(a_)
b1, b2, b3, b4 = b_
U_ = up_sample  # up_sample from cdi (star import)


def H4(x):
    """Filter x with the scaling taps a_ using the star-imported filter()."""
    taps = a_
    return filter(taps, x)


def G4(x):
    """Filter x with the wavelet taps b_ using the star-imported filter()."""
    taps = b_
    return filter(taps, x)

def HF4(f, r=1):
    """Compute A^r(f) with the filter-bank (synthesis) formulation.

    Starts from the level-r D4 trend of f and performs r synthesis stages,
    each one applying U_ (cdi's up_sample) and then the scaling filter H4.
    r == 0 returns f as a new array.
    """
    if r == 0:
        return array(f)
    smooth = D4trend(f, r)
    for _stage in range(r):
        upsampled = U_(smooth)
        smooth = H4(upsampled)
    return array(smooth)
    
def LF4(f, r=1):
    """Compute D^r(f) with the filter-bank (synthesis) formulation.

    Starts from the level-r D4 fluctuation of f.  The first synthesis stage
    applies U_ (cdi's up_sample) then the wavelet filter G4; the remaining
    r-1 stages apply U_ then the scaling filter H4.
    r == 0 returns zeros of length len(f).
    """
    if r == 0:
        return zeros(len(f))
    upsampled = U_(D4fluct(f, r))
    detail = G4(upsampled)
    for _stage in range(r - 1):
        upsampled = U_(detail)
        detail = H4(upsampled)
    return array(detail)
    
# compute Ar and Dr using the scaling and wavelet vectors
def A(f, r=1):
    """Return the level-r D4 averaged signal A^r(f) from scaling vectors.

    The level-r trend values of f (via D4trend) weight the level-r scaling
    vectors, which are built as rows of a matrix by applying D4V r times to
    the identity.  A^r(f) is their linear combination, dot(a, V).
    r <= 0 returns f unchanged (as an array).

    Raises
    ------
    ValueError
        If r > 0 and len(f) is not divisible by 2**r.
    """
    N = len(f)
    if r > 0 and N % 2**r:
        raise ValueError("A: %d is not divisible by 2**%d " % (N, r))
    a = list(f)
    V = Id(N)
    for _ in range(r):
        a = D4trend(a)
        V = D4V(V)
    return dot(a, V)

def D(f, r=1):
    """Return the level-r D4 detail signal D^r(f) from wavelet vectors.

    The level-r fluctuation values of f weight the level-r wavelet vectors
    (rows of D4W applied to the level r-1 scaling vectors, which come from
    applying D4V r-1 times to the identity).  D^r(f) is their linear
    combination, dot(d, W).  r == 0 returns zeros of length len(f),
    matching LF4 and D4fluct.

    Raises
    ------
    ValueError
        If r is negative or len(f) is not divisible by 2**r.
    """
    N = len(f)
    if r == 0:
        # Previously fell through to an unassigned `d` (UnboundLocalError).
        return zeros(N)
    if r < 0:
        raise ValueError("D: level r must be non-negative, got %d" % r)
    if N % 2**r:
        raise ValueError("D: %d is not divisible by 2**%d " % (N, r))
    a = list(f)
    V = Id(N)
    # Only level r's fluctuation and wavelet vectors are needed, so advance
    # the trend and scaling vectors to level r-1 first.
    for _ in range(r - 1):
        a = D4trend(a)
        V = D4V(V)
    d = D4fluct(a)
    W = D4W(V)
    return dot(d, W)
    
# Plot Ar and Dr    
def write(X, T, ha='center', va='center', fs='16', rot=0.0, **kwargs):
    """Write text T on the current axes at data position X = (x, y).

    Parameters
    ----------
    X : pair of numbers
        Text anchor position (x, y).
    T : str
        Text to draw (matplotlib mathtext such as r'$x^2$' is allowed).
    ha, va : str
        Horizontal / vertical alignment (default 'center').
    fs : str or number
        Font size (default '16').  A 'size' or 'fontsize' entry in kwargs
        overrides it: matplotlib treats both as aliases of the fontsize
        passed here, and recent versions raise TypeError when more than one
        alias is supplied.
    rot : float
        Rotation in degrees.
    **kwargs
        Further keyword arguments forwarded to plt.text.

    Returns
    -------
    The matplotlib Text object created by plt.text.
    """
    for alias in ('size', 'fontsize'):
        if alias in kwargs:
            fs = kwargs.pop(alias)
    return plt.text(X[0], X[1], T, horizontalalignment=ha,
                    verticalalignment=va, fontsize=fs, rotation=rot,
                    **kwargs)
    
# Example

# get sample from signal f  (32*atan(1) = 8*pi, 8*atan(1) = 2*pi)
n = 9; N = 2**n
F = lambda x: cos(32*atan(1)*x)*sin(8*atan(1)*x**2)
f = sample(F,N-1)

close('all')
cls = ['g', 'r', 'c', 'm']  # one colour per level r = 2, 4, 6, 8
flabel = r'$f(x) = cos(8\pi x)sin(2\pi x^2)$'

# For each figure: title, filter-bank implementation, vector-based
# implementation used as a cross-check, plotted symbol, and the pair of
# names reported if the two implementations disagree.
experiments = [
    ('A^r(f) for r = 2, 4, 6, 8', HF4, A, 'A', 'A(f,r) and HF4(f,r)'),
    ('D^r(f) for r = 2, 4, 6, 8', LF4, D, 'D', 'D(f,r) and LF4(f,r)'),
]

for title, fast, exact, symbol, pair in experiments:
    # create new canvas
    canvas(title)
    xlim(0, N)

    # plot original signal for reference
    plot(f)
    # 'fs' is write()'s font-size parameter; passing matplotlib's 'size'
    # alias next to the fontsize write() always sets fails on recent
    # matplotlib versions.
    myargs = {'color': 'b', 'fs': 'x-large'}
    write((50, 0.75), flabel, **myargs)

    # plot A^r / D^r for different values of r, stacked by an offset
    for r in [2, 4, 6, 8]:
        X = fast(f, r)
        offset = 1.25*r
        plot([x + offset for x in X])

        myargs['color'] = cls[r//2 - 1]
        write((50, 0.75 + offset), '$' + symbol + '^' + str(r) + '(f)$', **myargs)

        # verify the filter-bank result against the vector-based one;
        # a length mismatch or NaN also counts as a difference
        Y = exact(f, r)
        if len(X) != len(Y) or not np.allclose(X, Y, rtol=0, atol=1e-12):
            print(pair + ' give different results for r=', r)
        
    
