## A421
## Joan Gines i Ametlle

'''
The lab proper starts at the "Tasks for L423" section (around line 140).
The code before that point consists of functions defined in
previous labs, covering Haar wavelet algorithms.
'''

# Utilities
#import sys
#sys.path.append('C:/Users/Joan/EI/UPC/Q8/CDI/module')
from cdi import *
from math import *
import numpy as np
import matplotlib.pyplot as plt

# Synonyms
# Short module-level aliases for the NumPy and matplotlib callables used
# throughout this file. They are bound after the star imports above, so
# pi, cos and sin below replace the versions brought in by
# `from math import *`.

# --- NumPy ---
Id = np.eye
# sqrt is deliberately left as whatever the star imports provide.
#sqrt = np.sqrt
pi = np.pi
cos = np.cos
sin = np.sin
array = np.array
stack = np.vstack        # stack rows vertically
splice= np.hstack        # join arrays horizontally
dot = np.dot
ls  = np.linspace
zeros=np.zeros
mat=np.matrix
transpose=np.transpose
det = np.linalg.det
inv = np.linalg.inv

# --- matplotlib.pyplot ---
canvas = plt.figure      # open a new figure
axis=plt.axis
show = plt.show
view = plt.imshow
close = plt.close
plot = plt.plot
xlim = plt.xlim


# Core functions 

# trend(f) computes the trend signal of a discrete signal f
def trend(f):
    """Return the Haar trend signal of the discrete signal f.

    Samples are taken in consecutive pairs (f[0], f[1]), (f[2], f[3]), ...
    and each pair is replaced by its sum divided by sqrt(2). A trailing
    unpaired sample is ignored, so the result is a list of len(f)//2 values.
    """
    root2 = 1.4142135623730951  # sqrt(2)
    sums = []
    for k in range(len(f) // 2):
        left, right = f[2*k], f[2*k + 1]
        sums.append((left + right) / root2)
    return sums

# fluct(f) computes the fluctuation 
# (or difference) signal of a discrete signal f    
def fluct(f):
    """Return the Haar fluctuation (difference) signal of the discrete signal f.

    Samples are taken in consecutive pairs (f[0], f[1]), (f[2], f[3]), ...
    and each pair is replaced by first-minus-second divided by sqrt(2).
    A trailing unpaired sample is ignored, so the result is a list of
    len(f)//2 values.
    """
    root2 = 1.4142135623730951  # sqrt(2)
    diffs = []
    for k in range(len(f) // 2):
        left, right = f[2*k], f[2*k + 1]
        diffs.append((left - right) / root2)
    return diffs

# haar(f,r) computes the Haar transform of f of order r. 
# haar(f) is equivalent to haar(f,1)
def haar(f,r=1):
    """Return the order-r Haar transform of the discrete signal f.

    r == 0 returns f itself, untouched. r == 1 returns the trend of f
    followed by its fluctuation. For larger r the order-(r-1) transform is
    computed recursively and its leading len(f)//2**(r-1) entries (the
    current trend) are split once more into trend + fluctuation, while the
    remaining entries are kept as they are.
    """
    if r == 0:
        return f
    if r == 1:
        return trend(f) + fluct(f)
    # Length of the trend part inside the order-(r-1) transform.
    cut = len(f) // 2**(r - 1)
    previous = haar(f, r - 1)
    head, tail = previous[:cut], previous[cut:]
    return trend(head) + fluct(head) + tail

def HaarT(f,r=1):
    """Return the order-r Haar transform of f, computed iteratively.

    r == 0 returns f itself. Otherwise f is copied into a list; if its
    length is not a multiple of 2**r an explanatory string is returned
    instead of a transform. The trend is then halved repeatedly until it
    has len(f)//2**r entries, and each new fluctuation is placed in front
    of the ones already collected, so the result is
    [final trend] + [coarsest fluctuation] + ... + [finest fluctuation].
    """
    if r == 0:
        return f
    approx = list(f)
    size = len(approx)
    if size % 2**r != 0:
        return "trend: %d is not divisible by 2**%d " % (size, r)
    target = size // 2**r
    details = []
    while len(approx) > target:
        # Fluctuation first: it must be taken from the trend before it is halved.
        details = fluct(approx) + details
        approx = trend(approx)
    return approx + details
D2 = HaarT

# Define a function that computes 
# the array of level r scaling vectors
# from the array V of level r-1 scaling vectors
def HaarV(V):
    """Return the level-r scaling vectors built from the level-(r-1) ones.

    Rows of V are combined in consecutive pairs: output row j is
    (V[2*j] + V[2*j+1]) / sqrt(2). A trailing unpaired row is ignored.

    Parameters
    ----------
    V : 2-D NumPy array whose rows are the level-(r-1) scaling vectors.

    Returns
    -------
    A 2-D array with len(V)//2 rows. When only one pair is available
    (len(V) is 2 or 3) the single combined row is returned as a 1-D array;
    callers in this file rely on that shape.

    Raises
    ------
    IndexError
        If V has fewer than two rows.
    """
    a1 = a2 = 2 ** -0.5
    pairs = len(V) // 2
    if pairs == 0:
        raise IndexError("HaarV: V must have at least two rows")
    # Build every combined row first and stack once, instead of re-copying
    # the growing result on each iteration.
    rows = [a1 * V[2*j, :] + a2 * V[2*j + 1, :] for j in range(pairs)]
    if pairs == 1:
        return rows[0]
    return np.vstack(rows)
    
# Define a function that computes 
# the array of level r wavelet vectors
# from the array V of level r-1 scaling vectors
def HaarW(V):
    """Return the level-r wavelet vectors built from the level-(r-1) scaling vectors.

    Rows of V are combined in consecutive pairs: output row j is
    (V[2*j] - V[2*j+1]) / sqrt(2). A trailing unpaired row is ignored.

    Parameters
    ----------
    V : 2-D NumPy array whose rows are the level-(r-1) scaling vectors.

    Returns
    -------
    A 2-D array with len(V)//2 rows. When only one pair is available
    (len(V) is 2 or 3) the single combined row is returned as a 1-D array;
    callers in this file rely on that shape.

    Raises
    ------
    IndexError
        If V has fewer than two rows.
    """
    a1 = a2 = 2 ** -0.5
    pairs = len(V) // 2
    if pairs == 0:
        raise IndexError("HaarW: V must have at least two rows")
    # Build every combined row first and stack once, instead of re-copying
    # the growing result on each iteration.
    rows = [a1 * V[2*j, :] - a2 * V[2*j + 1, :] for j in range(pairs)]
    if pairs == 1:
        return rows[0]
    return np.vstack(rows)
    
# Define a function that computes 
# the array of all scaling vectors, that is,
# its r-th component is the matrix of level r
# scaling vectors.
def HaarVA(N):
    """Return the list of scaling-vector arrays for every level, starting at level 0.

    Entry 0 is the N x N identity; each following entry is HaarV applied to
    the previous one. Halving continues while the current array has more
    than two rows; one last HaarV call is then made and its result is
    stored wrapped in a one-element list.
    """
    current = Id(N)
    levels = [current]
    rows = N
    while rows > 2:
        current = HaarV(current)
        levels.append(current)
        rows = len(current)
    # The final level is appended inside its own list.
    levels.append([HaarV(current)])
    return levels
VA = HaarVA

# Define a function that computes a pair 
# formed with the array of all scaling vectors
# and the array of all wavelet vectors
def HaarVWA(N):
    """Return the pair (scaling levels, wavelet levels) for signals of length N.

    The scaling list starts with the N x N identity and gains one HaarV
    result per level; the wavelet list gains the matching HaarW result,
    computed from the same previous-level scaling array. Halving continues
    while the current scaling array has more than two rows; the results of
    one last HaarW/HaarV call are stored wrapped in one-element lists.
    """
    scaling = Id(N)
    scaling_levels = [scaling]
    wavelet_levels = []
    rows = N
    while rows > 2:
        # Wavelets must be taken from the scaling array before it is replaced.
        wavelets = HaarW(scaling)
        scaling = HaarV(scaling)
        scaling_levels.append(scaling)
        wavelet_levels.append(wavelets)
        rows = len(scaling)
    last_wavelets = HaarW(scaling)
    last_scaling = HaarV(scaling)
    scaling_levels.append([last_scaling])
    wavelet_levels.append([last_wavelets])
    return (scaling_levels, wavelet_levels)
VWA = HaarVWA


# Tasks for L423
'''
1. Plot some scaling and wavelet vectors
'''

'''
n = 10; N = 2**n
V, W = VWA(N)

close('all')
canvas('1. Examples of scaling vectors')
xlim(0,N)

plot(2.5+V[5][1])
plot(2+V[5][8])
plot(1.5+V[5][16])

plot(1+V[6][1])
plot(0.5+V[6][4])
plot(V[6][8])

canvas('2. Examples of wavelet vectors')
xlim(0,N)

# level 5
X = [1,8,16]        # horizontal positions
Y = [2.5,2,1.5]
P5 = zip(X,Y)
for (x,y) in P5:
    plot(y+W[4][x])

# level 6
X = [1,4,14]
Y = [1,0.5,0]
P6 = zip(X,Y)
for (x,y) in P6:
    plot(y+W[5][x])
'''


'''
2. Define a function that computes the orthogonal
   projection. Include some examples.
'''

# Orthogonal projection in orthonormal basis
def proj_on(f,U):
    """Return the sum of dot(f, u) * u over the rows u of U.

    When the rows of U are orthonormal this is the orthogonal projection of
    f onto their span. The result is a NumPy array with len(U[0]) entries.
    """
    dimension = len(U[0])
    basis = array(U)
    result = zeros(dimension)
    for vector in basis:
        coefficient = dot(f, vector)
        result = result + coefficient * vector
    return result

# Projection coefficients
def proj_coeffs(f,V):
    """Return the array of dot products of f with each element of V, in order."""
    coefficients = []
    for v in V:
        coefficients.append(dot(f, v))
    return array(coefficients)

# gram(v, V) is one row of the Gram matrix of V.
gram = proj_coeffs

def Gram(V):
    """Return the Gram matrix of V as a list of rows.

    Row i is gram(V[i], V), i.e. the dot products of the i-th element of V
    with every element of V.
    """
    rows = []
    for v in V:
        rows.append(gram(v, V))
    return rows

def proj(f,V):
    """Orthogonally project f onto the span of the rows of V.

    The rows need not be orthonormal. With G the Gram matrix of the rows
    and w the dot products of f with them, the coefficients c solve the
    normal equations G c = w and the projection is sum(c[i] * V[i]).

    Parameters
    ----------
    f : 1-D sequence of numbers.
    V : sequence of rows, each the same length as f.

    Returns
    -------
    A 1-D NumPy array with the projection, or the string
    'proj: rows of V non lin indep.' when the determinant of the Gram
    matrix is exactly zero.
    """
    B = np.asarray(V)
    w = np.dot(B, f)          # dot product of f with every row
    G = np.dot(B, B.T)        # Gram matrix of the rows
    if np.linalg.det(G) == 0:
        return 'proj: rows of V non lin indep.'
    # Solve the linear system rather than forming the inverse; plain
    # arrays replace np.matrix, which NumPy no longer recommends.
    c = np.linalg.solve(G, w)
    return np.dot(c, B)

'''
# Examples
e0 = [1,0,0]; e1 = [0,1,0]; e2 = [0,0,1]
x = [3,5,-2]

x0 = proj_on(x,[e0])
x1 = proj_on(x,[e0,e1])
x2 = proj_on(x,[e0,e1,e2])

print(proj_coeffs(x,[e0,e1,e2]))
print(proj(x,[e0,e1,e2]))
'''


'''
3. Define functions 
   - high_filter(f,r=1)
   - low_filter(f,r=1)
   that compute A^r(f) and D^r(f)
   Make some examples.
'''

# compute A^r(f)
def high_filter(f,r=1):
    """Return A^r(f): f with each block of 2**r samples replaced by its mean.

    The signal is cut into consecutive blocks of 2**r samples; every sample
    of a block is replaced by the block's arithmetic mean. Samples left
    over after the last complete block are dropped, so the returned list
    may be shorter than f.
    """
    width = 2**r
    start = 0
    averaged = []
    while len(f) - start >= width:
        mean = sum(f[start:start + width]) / width
        averaged.extend(width * [mean])
        start += width
    return averaged
    
    
# compute D^r(f) 
def low_filter(f,r=1):
    """Return D^r(f): the level-r detail signal of f.

    The signal is cut into consecutive blocks of 2**r samples. For each
    block, x = (sum of first half) - (sum of second half); the first half
    of the block is replaced by x / 2**r and the second half by -x / 2**r.
    Samples left over after the last complete block are dropped.
    """
    half = 2**(r - 1)
    span = 2 * half
    start = 0
    detail = []
    while len(f) - start >= span:
        middle = start + half
        gap = sum(f[start:middle]) - sum(f[middle:middle + half])
        detail.extend(half * [gap / span] + half * [(-gap) / span])
        start = middle + half
    return detail
    
    
'''
4. Compute A^r and D^r using multiresolution
   and give examples. Compare with the filter procedures 
'''

# utility to write text T at position X
def write(X, T, ha='center', va='center', fs='16', rot=0.0, **kwargs):
    """Draw the text T at position X on the current axes and return the text object.

    Parameters
    ----------
    X : pair (x, y) with the position of the text.
    T : the string to draw.
    ha, va : horizontal / vertical alignment passed to plt.text.
    fs : font size passed to plt.text as `fontsize`.
    rot : rotation passed to plt.text.
    **kwargs : any further plt.text keyword arguments. A `fontsize` or
        `size` entry here overrides fs.
    """
    # `size` and `fontsize` name the same text property. This function
    # already forwards fontsize=fs, so passing either of them along as well
    # would specify that property twice; fold them into fs instead.
    for alias in ('fontsize', 'size'):
        if alias in kwargs:
            fs = kwargs.pop(alias)
    return plt.text(X[0], X[1], T, horizontalalignment=ha, verticalalignment=va, fontsize=fs, rotation=rot, **kwargs)
    
def A(f,r=1):
    """Return A^r(f), the level-r averaged signal, via multiresolution.

    The trend of f is taken r times and, in parallel, HaarV is applied r
    times to the identity to obtain the level-r scaling vectors. The
    result is the combination of those vectors weighted by the trend
    values, as a NumPy array with len(f) entries.

    Parameters
    ----------
    f : sequence of numbers (the discrete signal).
    r : number of levels; r <= 0 returns f as a float array.
    """
    a = list(f)
    V = Id(len(a))
    level = r
    while level > 0:
        a = trend(a)
        V = HaarV(V)
        level -= 1
    # When a single scaling vector is left, HaarV returns it as a 1-D
    # array; promote it to one row so the length-1 trend lines up in dot.
    return dot(a, np.atleast_2d(V))

def D(f,r=1):
    """Return D^r(f), the level-r detail signal, via multiresolution.

    The trend of f is taken r-1 times and its fluctuation once more; in
    parallel the level-r wavelet vectors are obtained with HaarW from the
    level-(r-1) scaling vectors. The result is the combination of those
    wavelet vectors weighted by the fluctuation values, as a NumPy array
    with len(f) entries.

    Parameters
    ----------
    f : sequence of numbers (the discrete signal).
    r : number of levels, at least 1.

    Raises
    ------
    ValueError
        If r is not positive (there is no level-0 detail signal).
    """
    if r <= 0:
        raise ValueError("D: r must be a positive integer")
    a = list(f)
    V = Id(len(a))
    level = r
    while level > 0:
        # Fluctuation and wavelets are taken from the current level before
        # the trend and scaling vectors move on to the next one.
        d = fluct(a)
        a = trend(a)
        W = HaarW(V)
        V = HaarV(V)
        level -= 1
    # When a single wavelet vector is left, HaarW returns it as a 1-D
    # array; promote it to one row so the length-1 fluctuation lines up.
    return dot(d, np.atleast_2d(W))
    
# Example

# Sample the test signal f(x) = cos(8*pi*x) * sin(2*pi*x^2).
# atan(1) is pi/4, so 32*atan(1) = 8*pi and 8*atan(1) = 2*pi.
n = 9; N = 2**n
F = lambda x: cos(32*atan(1)*x)*sin(8*atan(1)*x**2)
# sample() comes from the cdi module; presumably it returns N values of F
# on a uniform grid -- TODO confirm against cdi.
f = sample(F,N-1)

# create new canvas (A)
close('all')
canvas('A^r(f) for r = 2, 4, 6, 8')
xlim(0,N)

# plot original signal for reference
plot(f)
# The font size is given through write()'s own `fs` parameter. write()
# already forwards fontsize= to plt.text, so also passing `size` (another
# name for the same property) would specify it twice.
myargs = {'color':'b', 'fs':'x-large'}
write((50, 0.75), r'$f(x) = cos(8\pi x)sin(2\pi x^2)$', **myargs)
cls = ['g','r','c','m']

# plot A^r for different values of r, each curve shifted up by 1.25*r
for r in [2,4,6,8]:
    X = A(f,r)
    offset = 1.25*r
    plot([x+offset for x in X])
    
    # label colour: one entry of cls per plotted level
    myargs['color'] = cls[r//2-1]
    write((50, 0.75+offset), r'$A^'+str(r)+'(f)$', **myargs)
    
    # check that the result matches the one obtained with high_filter(f,r)
    Y = high_filter(f,r)
    if any(abs(x-y) > 1e-12 for x,y in zip(X,Y)):
        print('A(f,r) and high_filter(f,r) give different results for r=', r)


# create new canvas (D)
canvas('D^r(f) for r = 2, 4, 6, 8')
xlim(0,N)

# plot original signal for reference
plot(f)
myargs = {'color':'b', 'fs':'x-large'}
write((50, 0.75), r'$f(x) = cos(8\pi x)sin(2\pi x^2)$', **myargs)

# plot D^r for different values of r, each curve shifted up by 1.25*r
for r in [2,4,6,8]:
    X = D(f,r)
    offset = 1.25*r
    plot([x+offset for x in X])
    
    # label colour: one entry of cls per plotted level
    myargs['color'] = cls[r//2-1]
    write((50, 0.75+offset), r'$D^'+str(r)+'(f)$', **myargs)
    
    # check that the result matches the one obtained with low_filter(f,r)
    Y = low_filter(f,r)
    if any(abs(x-y) > 1e-12 for x,y in zip(X,Y)):
        print('D(f,r) and low_filter(f,r) give different results for r=', r)
