## L423_x
'''
The lab proper starts at the '## Tasks for L423' section.
The lines before it contain the functions
defined in previous labs devoted to
Haar wavelet algorithms.
'''

# Utilities
from cdi import *
import numpy as np
import matplotlib.pyplot as plt

# Synonyms: short NumPy names used throughout the lab
# (the sqrt alias is left disabled)
#sqrt = np.sqrt
Id, pi = np.eye, np.pi
cos, sin = np.cos, np.sin
array, dot = np.array, np.dot
stack, splice = np.vstack, np.hstack   # stack rows / join side by side
ls, zeros = np.linspace, np.zeros
mat, transpose = np.matrix, np.transpose
det, inv = np.linalg.det, np.linalg.inv

# Short matplotlib.pyplot names
canvas, view = plt.figure, plt.imshow
axis, xlim = plt.axis, plt.xlim
plot, show, close = plt.plot, plt.show, plt.close


## Core functions 

# trend(f) computes the trend signal of a discrete signal f
def trend(f):
    """Return the first-level Haar trend signal of f.

    Entry j is (f[2j] + f[2j+1]) / sqrt(2) for j = 0, ..., len(f)//2 - 1.
    An unpaired last sample (odd length) is ignored.  Always returns a list.
    """
    root2 = 1.4142135623730951  # sqrt(2)
    samples = iter(f)
    # zip over the same iterator twice yields consecutive, non-overlapping pairs
    return [(first + second) / root2 for first, second in zip(samples, samples)]

# fluct(f) computes the fluctuation
# (or difference) signal of a discrete signal f
def fluct(f):
    """Return the first-level Haar fluctuation (difference) signal of f.

    Entry j is (f[2j] - f[2j+1]) / sqrt(2) for j = 0, ..., len(f)//2 - 1.
    An unpaired last sample (odd length) is ignored.  Always returns a list.
    """
    root2 = 1.4142135623730951  # sqrt(2)
    samples = iter(f)
    # zip over the same iterator twice yields consecutive, non-overlapping pairs
    return [(first - second) / root2 for first, second in zip(samples, samples)]

# haar(f,r) computes the Haar transform of f of order r.
# haar(f) is equivalent to haar(f,1)
def haar(f, r=1):
    """Return the order-r Haar transform of f (recursive formulation).

    r == 0 returns f itself; r == 1 returns trend(f) + fluct(f).  For r > 1
    the order-(r-1) transform is computed first.  Trend and fluctuation are
    then applied to its leading len(f) // 2**(r-1) entries, and the
    remaining (finer) fluctuation entries are appended unchanged.
    """
    if r == 0:
        return f
    if r == 1:
        return trend(f) + fluct(f)
    previous = haar(f, r - 1)
    head_len = len(f) // 2**(r - 1)
    head = previous[:head_len]
    return trend(head) + fluct(head) + previous[head_len:]

# Define a function that computes
# the array of level r scaling vectors
# from the array V of level r-1 scaling vectors
def HaarV(V):
    """Return the level-r Haar scaling vectors built from the level r-1 ones.

    Row j of the result is a1*V[2j] + a2*V[2j+1] with a1 = a2 = 2**(-1/2),
    for j = 0, ..., len(V)//2 - 1; an odd trailing row of V is ignored.

    Parameters
    ----------
    V : 2-D NumPy array holding one level r-1 scaling vector per row.

    Returns
    -------
    A 2-D array with len(V)//2 rows.  When len(V)//2 == 1 the single row is
    returned as a 1-D array (HaarVA/HaarVWA wrap that last vector in a list).

    Raises
    ------
    ValueError
        If V has fewer than two rows.
    """
    a1 = a2 = 2**(-1/2)
    half = len(V) // 2
    if half == 0:
        raise ValueError("HaarV: V needs at least two rows, got %d" % len(V))
    # Combine all row pairs in one vectorized step.  Growing the result with
    # vstack inside a loop re-copied it on every iteration (quadratic cost).
    X = a1 * V[0:2*half:2, :] + a2 * V[1:2*half:2, :]
    return X[0, :] if half == 1 else X
# HaarT(f,r) computes the level-r Haar transform of f iteratively;
# for len(f) divisible by 2**r it matches haar(f,r).
def HaarT(f,r=1):
    """Return the level-r Haar transform a^r | d^r | ... | d^1 of f.

    a^r is the trend left after r passes of trend().  d^k is the fluctuation
    produced at pass k, so the coarsest fluctuation comes first.  r == 0
    returns f unchanged; otherwise the result is a list.

    Raises
    ------
    ValueError
        If r is negative or len(f) is not divisible by 2**r.
    """
    if r == 0:
        return f
    if r < 0:
        raise ValueError("HaarT: r must be non-negative, got %s" % (r,))
    a = list(f)
    N = len(a)
    # Raise instead of returning a message string, so an error can never be
    # mistaken for (and silently propagated as) transform data.
    if N % 2**r:
        raise ValueError("HaarT: %d is not divisible by 2**%d" % (N, r))
    details = []
    for _ in range(r):
        details = fluct(a) + details   # coarser fluctuations go in front
        a = trend(a)
    return a + details
D2 = HaarT
    
# Define a function that computes
# the array of level r wavelet vectors
# from the array V of level r-1 scaling vectors
def HaarW(V):
    """Return the level-r Haar wavelet vectors built from level r-1 scaling vectors.

    Row j of the result is a1*V[2j] - a2*V[2j+1] with a1 = a2 = 2**(-1/2),
    for j = 0, ..., len(V)//2 - 1; an odd trailing row of V is ignored.

    Parameters
    ----------
    V : 2-D NumPy array holding one level r-1 scaling vector per row.

    Returns
    -------
    A 2-D array with len(V)//2 rows.  When len(V)//2 == 1 the single row is
    returned as a 1-D array (HaarVWA wraps that last vector in a list).

    Raises
    ------
    ValueError
        If V has fewer than two rows.
    """
    a1 = a2 = 2**(-1/2)
    half = len(V) // 2
    if half == 0:
        raise ValueError("HaarW: V needs at least two rows, got %d" % len(V))
    # Combine all row pairs in one vectorized step.  Growing the result with
    # vstack inside a loop re-copied it on every iteration (quadratic cost).
    X = a1 * V[0:2*half:2, :] - a2 * V[1:2*half:2, :]
    return X[0, :] if half == 1 else X
    
    
# Define a function that computes
# the array of all scaling vectors, that is,
# its r-th component is the matrix of level r
# scaling vectors.
def HaarVA(N):
    """Return [V^0, V^1, ..., V^n]: the Haar scaling vectors of every level.

    V^0 is the N x N identity, and each following level is obtained from the
    previous one with HaarV.  The last entry is a one-element list holding
    the single coarsest scaling vector.
    NOTE(review): N is expected to be a power of two (>= 2); other sizes
    appear to fail inside HaarV — TODO confirm.
    """
    levels = [Id(N)]
    size = N
    while size > 2:
        levels.append(HaarV(levels[-1]))
        size = len(levels[-1])
    levels.append([HaarV(levels[-1])])
    return levels
VA = HaarVA

# Define a function that computes a pair
# formed with the array of all scaling vectors
# and the array of all wavelet vectors
def HaarVWA(N):
    """Return (scaling, wavelets) for every Haar level of an N-point signal.

    scaling[r] holds the level-r scaling vectors (scaling[0] is the N x N
    identity).  wavelets[r-1] holds the level-r wavelet vectors.  Both lists
    are built level by level with HaarW / HaarV from the previous scaling
    vectors.  The final entry of each list is a one-element list holding the
    single coarsest vector.
    NOTE(review): N is expected to be a power of two (>= 2); other sizes
    appear to fail inside HaarV/HaarW — TODO confirm.
    """
    scaling = [Id(N)]
    wavelets = []
    size = N
    while size > 2:
        previous = scaling[-1]
        wavelets.append(HaarW(previous))
        scaling.append(HaarV(previous))
        size = len(scaling[-1])
    coarsest = scaling[-1]
    wavelets.append([HaarW(coarsest)])
    scaling.append([HaarV(coarsest)])
    return (scaling, wavelets)
VWA = HaarVWA




## Tasks for L423
'''
1. Plot some scaling and wavelet vectors
'''

n = 10          # number of dyadic levels
N = 2**n        # signal length: 2**10 = 1024 samples
# V[r]: level-r scaling vectors; W[r-1]: level-r wavelet vectors
V, W = VWA(N)

'''
close('all')
canvas('1. Examples of scaling vectors')
xlim(0,N)

plot(2.5+V[5][1])
plot(2+V[5][8])
plot(1.5+V[5][16])

plot(1+V[6][1])
plot(0.5+V[6][4])
plot(V[6][8])

#improve it with zip

canvas('2. Examples of wavelet vectors')
xlim(0,N)

X = [1,8,16]; Y = [2.5,2,1.5]
P5 = zip(X,Y)
for (x,y) in P5:
    plot(y+W[4][x])
    
X = [1,4,15]; Y = [1,0.5,0]
P6 = zip(X,Y)
for (x,y) in P6:
    plot(y+W[5][x])
'''

'''
2. Define a function that computes the orthogonal
   projection. Include some examples.
'''
# Orthogonal projection in orthonormal basis
def proj(f,V):
    """Project f onto span(V), assuming the rows of V are orthonormal.

    Returns sum over rows v of <f, v> * v, as a float NumPy array of length
    len(V[0]).

    NOTE: a general ``proj`` defined a few lines below rebinds this name, so
    calls made after that point use the general version instead.
    """
    result = np.zeros(len(V[0]))
    for basis_vector in V:
        coefficient = np.dot(f, basis_vector)
        result = result + coefficient * basis_vector
    return result

# Projection coefficients
def proj_coeffs(f,V):
    """Return the NumPy array of inner products <f, v>, one per row v of V.

    For orthonormal rows these are the coordinates of the projection of f.
    """
    return np.array(list(map(lambda row: np.dot(f, row), V)))

def gram(x,V):
    """Return the list [<x, v> for each row v of V]."""
    products = []
    for row in V:
        products.append(np.dot(x, row))
    return products

def Gram(V):
    """Return the Gram matrix of the rows of V as a list of lists, G[i][j] = <V[i], V[j]>."""
    return [gram(row, V) for row in V]

def proj(x,V):
    """Return the orthogonal projection of x onto the span of the rows of V.

    The rows need not be orthonormal, or even linearly independent.  The
    result is V^T t, where t solves the least-squares problem
    min_t ||V^T t - x||; this is exactly the orthogonal projection of x
    onto span(V).

    Parameters
    ----------
    x : vector of length M.
    V : sequence of vectors of length M, or a 2-D array with one vector per
        row.  A single 1-D vector is treated as one row.

    Returns
    -------
    1-D float NumPy array of length M.
    """
    rows = np.atleast_2d(np.asarray(V, dtype=float))
    target = np.asarray(x, dtype=float)
    # SVD-based least squares avoids forming and inverting the Gram matrix
    # V V^T (the previous version built it with O(k^2) Python-level dot
    # calls, tested det == 0 exactly and returned an error string).  It also
    # copes with rank-deficient V.
    coeffs = np.linalg.lstsq(rows.T, target, rcond=None)[0]
    return rows.T @ coeffs

# Examples, I: projections onto coordinate subspaces of R^3
e0, e1, e2 = [1, 0, 0], [0, 1, 0], [0, 0, 1]
x = [3, 5, -2]

x0 = proj(x, [e0])            # onto the e0 axis
print('x0 =', x0)

x1 = proj(x, [e0, e1])        # onto the e0-e1 plane
print('x1 =', x1)

x2 = proj(x, [e0, e1, e2])    # onto all of R^3
print('x2 =', x2)

# Examples, II: projection of a 3-D vector onto a plane
x = [1, 1, -2]
V = [[4, 3, 2], [7, -3, -8]]  # rows span a plane in R^3 (rebinds the global V)
x1 = proj(x, V)               # orthogonal projection of x onto that plane

# Check: x - x1 must be orthogonal to both spanning vectors
v1, v2 = V
d1 = dot(x, v1) - dot(x1, v1)
d2 = dot(x, v2) - dot(x1, v2)
print("d1 and d2 should vanish:")
print(round(d1, 12), round(d2, 12))


'''
3. Define functions 
   - high_filter(f,r=1)
   - low_filter(f,r=1)
   that compute A^r(f) and D^r(f), respectively
'''
# To compute A^r(f) -- T6.13 to T6.16
def high_filter(f,r=1):
    """Return A^r(f): every block of 2**r consecutive samples replaced by its mean.

    NOTE: despite its name, this computes the averaging (trend) part A^r(f);
    the name is kept for compatibility with existing callers.

    Trailing samples that do not fill a whole block are dropped, so the
    result has length len(f) - len(f) % 2**r.  Returns a list.

    Raises
    ------
    ValueError
        If r is negative.
    """
    if r < 0:
        raise ValueError("high_filter: r must be non-negative, got %s" % (r,))
    m = 2**r
    usable = len(f) - len(f) % m
    averages = []
    # Walk the blocks by index instead of re-slicing the remaining signal
    # each time (that re-slicing copied lists quadratically).
    for start in range(0, usable, m):
        block_sum = sum(f[start:start + m])
        averages += [block_sum / m] * m
    return averages

# To compute D^r(f) -- T5.13 to T5.16
def low_filter(f,r=1):
    """Return D^r(f), the level-r detail (difference) signal, from block sums.

    NOTE: despite its name, this computes the difference part D^r(f); the
    name is kept for compatibility with existing callers.

    f is split into consecutive blocks of 2**r samples.  Let a block's first
    and second halves (2**(r-1) samples each) sum to s1 and s2.  The block is
    replaced by (s1 - s2) / 2**r over its first half and by the negative of
    that value over its second half.  Trailing samples that do not fill a
    whole block are dropped.  Returns a list.

    Raises
    ------
    ValueError
        If r < 1.
    """
    if r < 1:
        raise ValueError("low_filter: r must be at least 1, got %s" % (r,))
    half = 2**(r - 1)
    block = 2 * half
    usable = len(f) - len(f) % block
    details = []
    # Walk the blocks by index instead of re-slicing the remaining signal
    # each time (that re-slicing copied lists quadratically).
    for start in range(0, usable, block):
        x = sum(f[start:start + half])
        x -= sum(f[start + half:start + block])
        details += [x / block] * half + [-x / block] * half
    return details

'''
# Examples
F = lambda x: 15*x**2*(1-x)**4*cos(9*pi*x)
f = sample(F,N-1)
close('all')
for i in range(1,n+1):
    A = high_filter(f,i); D = low_filter(f,i)
    canvas("Level %d high and low filter "%i)
    xlim(0,N)
    plot(f); plot(A); plot(D)
'''

'''
4. Compute A^r and D^r using multiresolution
   and give examples. Compare with the filter procedures 
'''
def A(f,r):
    """A^r(f) via multiresolution: project f onto the level-r scaling vectors V[r].

    Reads the module-level V at call time.
    """
    level_scaling = V[r]
    return proj(f, level_scaling)

def A1(f,r):
    """A^r(f) via the direct block-averaging procedure high_filter, as an array."""
    averaged = high_filter(f, r)
    return array(averaged)

def D(f,r):
    """D^r(f) via multiresolution: project f onto the level-r wavelet vectors.

    Index 0 of the module-level W holds level 1, hence W[r-1].
    """
    level_wavelets = W[r - 1]
    return proj(f, level_wavelets)

def D1(f,r):
    """D^r(f) via the direct block-difference procedure low_filter, as an array."""
    differences = low_filter(f, r)
    return array(differences)

# Mildly oscillating test signal on [0, 1]
F = lambda t: 1500*t**2*(1-t)**4*((t-0.35)**2+0.01)*np.cos(41*t)
t = ls(0,1,N)   # N equally spaced sample points in [0, 1]
f = F(t)        # sampled signal (NumPy array of length N)

# lo, hi: vertical plot limits; sep: vertical offset between successive levels
lo =-1; hi=15; sep=1.4 

# Examples
# Recompute the scaling/wavelet vectors: the projection examples above
# rebound the global V to a small 3-D example basis.
V, W = VWA(N)
close('all')

# Level-r averages A^r(f) via projection (r = 0 projects onto the identity basis)
canvas('1. Graphs of A(f,r), r=0,1,...,%d'%n)

axis([0, N, lo, hi])

# Each level is shifted up by sep*r so the curves do not overlap
for r in range(n+1):
    plot(sep*r+A(f,r))
    
# Same averages via the direct block-averaging filter
canvas('1*. Graphs of A1(f,r), r=0,1,...,%d'%n)
axis([0, N, lo, hi])

for r in range(n+1):
    plot(sep*r+A1(f,r))
    
# Level-r details D^r(f) via projection onto the wavelet vectors
canvas('2. Graphs of D(f,r), r=1,...,%d'%n)
axis([0, N, lo, hi])
    
# Be careful: index 0 for W is level 1
for r in range(1,n+1):
    plot(sep*r+D(f,r))
    
# Same details via the direct block-difference filter
canvas('2*. Graphs of D1(f,r), r=1,...,%d'%n)
axis([0, N, lo, hi])
    
# D1 does not use W; levels start at 1 to match the D(f,r) figure
for r in range(1,n+1):
    plot(sep*r+D1(f,r))

# Checking equalities A=A1 and D=D1, entrywise up to an absolute
# tolerance of 10**(-nd)
nd = 14
tol = 10**(-nd)
Error = False
for r in range(1,n+1):
    X = A(f,r); X1 = A1(f,r)
    Y = D(f,r); Y1 = D1(f,r)
    if any(abs(x - x1) > tol for (x, x1) in zip(X, X1)):
        # The message has three %d fields, so r is passed twice
        print('A(f,%d) and A1(f,%d) differ in the first %d decimal places'%(r,r,nd))
        Error = True   # was `error = True`, which never updated the flag
    if any(abs(y - y1) > tol for (y, y1) in zip(Y, Y1)):
        print('D(f,%d) and D1(f,%d) differ in the first %d decimal places'%(r,r,nd))
        Error = True
if not Error:
    print('A(f,r) = A1(f,r) and D(f,r) = D1(f,r) for r = 1,...,%d up to %d decimal digits'%(n,nd))