## L423_Velisek
'''
The lab proper starts at the "Tasks for L423" section.
The lines before it contain the functions
defined in previous labs devoted to
Haar wavelet algorithms
'''

# Utilities
from cdi import *
import numpy as np
import matplotlib.pyplot as plt

# Synonyms
Id = np.eye
#sqrt = np.sqrt
pi = np.pi
cos = np.cos
sin = np.sin
array = np.array
stack = np.vstack
splice= np.hstack
dot = np.dot
ls  = np.linspace
zeros=np.zeros
mat=np.matrix
transpose=np.transpose
det = np.linalg.det
inv = np.linalg.inv

canvas = plt.figure
axis=plt.axis
show = plt.show
view = plt.imshow
close = plt.close
plot = plt.plot
xlim = plt.xlim


## Core functions 

# trend(f) computes the trend signal of a discrete signal f
def trend(f):
    '''Return the trend signal of the discrete signal f as a list.

    Samples are paired as (f[0], f[1]), (f[2], f[3]), ... and every pair
    contributes (a + b) / sqrt(2).  A trailing unpaired sample is
    ignored, so the result has len(f)//2 entries.
    '''
    root2 = 1.4142135623730951  # sqrt(2)
    evens, odds = f[0::2], f[1::2]
    # zip stops at the shorter slice, which drops an unpaired last sample.
    return [(a + b) / root2 for a, b in zip(evens, odds)]

# fluct(f) computes the fluctuation 
# (or difference) signal of a discrete signal f    
def fluct(f):
    '''Return the fluctuation (difference) signal of f as a list.

    Samples are paired as (f[0], f[1]), (f[2], f[3]), ... and every pair
    contributes (a - b) / sqrt(2).  A trailing unpaired sample is
    ignored, so the result has len(f)//2 entries.
    '''
    root2 = 1.4142135623730951  # sqrt(2)
    evens, odds = f[0::2], f[1::2]
    # zip stops at the shorter slice, which drops an unpaired last sample.
    return [(a - b) / root2 for a, b in zip(evens, odds)]

# haar(f,r) computes the Haar transform of f of order r.
# haar(f) is equivalent to haar(f,1)
def haar(f,r=1):
    '''Return the order-r Haar transform of the signal f.

    For r == 0 the input f itself is returned.  For r == 1 the result is
    trend(f) followed by fluct(f).  For larger r the order r-1 transform
    is computed first and its leading len(f)//2**(r-1) entries (the
    current trend) are transformed once more; the remaining entries are
    carried over unchanged.
    '''
    if r == 0:
        return f
    if r == 1:
        return trend(f) + fluct(f)
    # Length of the trend part inside the order r-1 transform.
    cut = len(f) // 2**(r-1)
    previous = haar(f, r-1)
    head, tail = previous[:cut], previous[cut:]
    return trend(head) + fluct(head) + tail

# Define a function that computes 
# the array of level r scaling vectors
# from the array V of level r-1 scaling vectors
def HaarV(V):
    '''Return the level r scaling vectors built from the level r-1 ones.

    V is a 2-D NumPy array whose rows are the level r-1 scaling vectors.
    Row j of the result is (V[2j] + V[2j+1]) / sqrt(2); an unpaired last
    row of V is ignored.

    Returns a 2-D array with len(V)//2 rows, except when V yields a
    single pair: then that one vector is returned as a 1-D array (the
    callers HaarVA and HaarVWA wrap it in a list themselves).

    Raises IndexError if V has fewer than two rows.
    '''
    a1 = a2 = 2**(-1/2)
    pairs = len(V) // 2
    if pairs == 0:
        raise IndexError('HaarV: V must have at least two rows')
    # Combine all even rows with all odd rows in one vectorized step
    # instead of stacking the result row by row (which copies the whole
    # partial result on every iteration).
    X = a1 * V[0:2*pairs:2, :] + a2 * V[1:2*pairs:2, :]
    return X[0] if pairs == 1 else X

def HaarT(f,r=1):
    '''Return the level-r Haar transform of f, computed iteratively.

    For r == 0 the input f itself is returned.  Otherwise the result is
    a list: the level-r trend followed by the fluctuations of levels
    r, r-1, ..., 1.

    If len(f) is not divisible by 2**r no transform is computed and an
    explanatory message string is returned instead.
    '''
    if r == 0:
        return f
    approx = list(f)
    size = len(approx)
    if size % 2**r:
        return "trend: %d is not divisible by 2**%d " % (size, r)
    target = size // 2**r
    details = []
    # Halve the trend until it reaches the target length; each new
    # fluctuation goes in front of the ones already collected.
    while len(approx) > target:
        details = fluct(approx) + details
        approx = trend(approx)
    return approx + details
D2 = HaarT
    
# Define a function that computes 
# the array of level r wavelet vectors
# from the array V of level r-1 scaling vectors
def HaarW(V):
    '''Return the level r wavelet vectors built from level r-1 scaling vectors.

    V is a 2-D NumPy array whose rows are the level r-1 scaling vectors.
    Row j of the result is (V[2j] - V[2j+1]) / sqrt(2); an unpaired last
    row of V is ignored.

    Returns a 2-D array with len(V)//2 rows, except when V yields a
    single pair: then that one vector is returned as a 1-D array (the
    caller HaarVWA wraps it in a list itself).

    Raises IndexError if V has fewer than two rows.
    '''
    a1 = a2 = 2**(-1/2)
    pairs = len(V) // 2
    if pairs == 0:
        raise IndexError('HaarW: V must have at least two rows')
    # Combine all even rows with all odd rows in one vectorized step
    # instead of stacking the result row by row (which copies the whole
    # partial result on every iteration).
    X = a1 * V[0:2*pairs:2, :] - a2 * V[1:2*pairs:2, :]
    return X[0] if pairs == 1 else X
    
    
# Define a function that computes 
# the array of all scaling vectors, that is,
# its r-th component is the matrix of level r
# scaling vectors.
def HaarVA(N):
    '''Return the list of scaling-vector arrays for signals of length N.

    Entry 0 is the N x N identity (level 0); each following entry is
    HaarV applied to the previous one.  The final entry is the last
    HaarV result wrapped in a one-element list, so that it can be
    indexed by row like the other entries.

    NOTE(review): the loop appears to assume N is a power of two (>= 2);
    confirm before using other lengths.
    '''
    current = Id(N)
    levels = [current]
    rows = N
    while rows > 2:
        current = HaarV(current)
        levels.append(current)
        rows = len(current)
    # Last step: two rows remain and HaarV yields a single vector.
    levels.append([HaarV(current)])
    return levels
VA = HaarVA

# Define a function that computes a pair 
# formed with the array of all scaling vectors
# and the array of all wavelet vectors
def HaarVWA(N):
    '''Return the pair (scaling arrays, wavelet arrays) for length N.

    The first list starts with the N x N identity (level 0) followed by
    successive HaarV results.  The second list holds the matching HaarW
    results, so its entry k is computed from entry k of the first list
    (i.e. it holds the level k+1 wavelet vectors).  In both lists the
    final entry is a single vector wrapped in a one-element list, so
    that it can be indexed by row like the other entries.

    NOTE(review): the loop appears to assume N is a power of two (>= 2);
    confirm before using other lengths.
    '''
    scaling = Id(N)
    all_scaling = [scaling]
    all_wavelets = []
    rows = N
    while rows > 2:
        # Wavelets must be taken from the current scaling vectors
        # before those are replaced by the next level.
        wavelets = HaarW(scaling)
        scaling = HaarV(scaling)
        all_scaling.append(scaling)
        all_wavelets.append(wavelets)
        rows = len(scaling)
    last_wavelet = HaarW(scaling)
    last_scaling = HaarV(scaling)
    all_scaling.append([last_scaling])
    all_wavelets.append([last_wavelet])
    return (all_scaling, all_wavelets)
VWA = HaarVWA




## Tasks for L423
'''
1. Plot some scaling and wavelet vectors
'''

# Signal length used throughout the lab: N = 2**n samples.
n = 10
N = 2**n
# V: list of scaling-vector arrays, W: list of wavelet-vector arrays,
# both as returned by HaarVWA for length N.
V, W = VWA(N)
'''
close('all')
canvas('1. Example of a scaling vectors')
xlim(0,N)

plot(2.5 + V[5][1])
plot(2 + V[5][8])
plot(1.5 + V[5][16])

plot(1 + V[6][1])
plot(0.5 + V[6][4])
plot(0 + V[6][8])

canvas('2. Example of a wavelet vectors')
xlim(0,N)

X = [1, 8, 16] # horizontal positions
Y = [2.5, 2, 1.5]
P5 = zip(X, Y)
for (x, y) in P5 :
    plot(y + W[4][x])
    
X = [1, 4, 14] # horizontal positions
Y = [1, 0.5, 0]
P6 = zip(X, Y)
for (x, y) in P6 :
    plot(y + W[5][x])
'''

'''
2. Define a function that computes the orthogonal
   projection. Include some examples.
'''
# Orthogonal projection in orthonormal basis
def proj_on(f,U):
    '''Return the sum of dot(f, u) * u over the rows u of U.

    This is the orthogonal projection of f on the span of the rows of U
    when those rows are orthonormal.  U must contain at least one row;
    the result is a float NumPy array with the length of U[0].
    '''
    dim = len(U[0])
    basis = array(U)
    result = zeros(dim)
    for vec in basis:
        coeff = dot(f, vec)
        result += coeff * vec
    return result

# Projection coefficients
def proj_coeffs(f,V):
    '''Return the array of inner products dot(f, v), one per row v of V.'''
    inner_products = [dot(f, row) for row in V]
    return array(inner_products)
# Alias: one row of a Gram matrix is exactly such a list of inner products.
gram = proj_coeffs

def Gram(V) :
    '''Return the Gram matrix of the rows of V as a list of arrays.

    Entry [i][j] is dot(V[i], V[j]).
    '''
    return [gram(row, V) for row in V]

def proj(f, V) :
    '''Return the orthogonal projection of f on the span of the rows of V.

    Unlike proj_on, the rows of V need not be orthonormal: the inner
    products of f with the rows are multiplied by the inverse of the
    Gram matrix of V.  The result is a flat NumPy array.

    If the determinant of the Gram matrix is exactly 0 no projection is
    computed and an explanatory message string is returned instead.
    '''
    coeffs = gram(f, V)
    gram_matrix = Gram(V)
    if det(gram_matrix) == 0 :
        return 'proj: rows of V non lin.indep.'
    # With np.matrix operands, * is the matrix product.
    weights = coeffs * inv(mat(gram_matrix))
    projection = weights * mat(V)
    # .A1 flattens the 1 x n matrix to a 1-D array.
    return projection.A1
'''
# Examples
e0 = [1, 0, 0]
e1 = [0, 1, 0]
e2 = [0, 0, 1]
x = [3, 5, -2]

x0 = proj_on(x, [e0])
x1 = proj_on(x, [e0, e1])
x2 = proj_on(x, [e0, e1, e2])

print(x0, x1, x2)

y0 = proj_coeffs(x, [e0])
y1 = proj_coeffs(x, [e0, e1])
y2 = proj_coeffs(x, [e0, e1, e2])

print(y0, y1, y2)

z0 = proj(x, [e0])
z1 = proj(x, [e0, e1])
z2 = proj(x, [e0, e1, e2])

print(z0, z1, z2)
'''

'''
3. Define functions 
   - high_filter(f,r=1)
   - low_filter(f,r=1)
   that compute A^r(f) and D^r(f)
   Make some examples.
'''
# To compute A^r(f) 
def high_filter(f,r=1):
    '''Return A^r(f): f averaged over consecutive blocks of 2**r samples.

    Every complete block of 2**r samples is replaced by 2**r copies of
    its mean.  Trailing samples that do not fill a whole block are
    dropped, so the result is a list of (len(f) // 2**r) * 2**r values.
    '''
    width = 2**r
    remaining = len(f)
    start = 0
    averaged = []
    while remaining >= width:
        total = sum(f[start:start+width])
        averaged.extend(width * [total / width])
        start += width
        remaining -= width
    return averaged

# To compute D^r(f)
def low_filter(f,r=1):
    '''Return D^r(f): the level-r detail of f on blocks of 2**r samples.

    Each complete block of 2**r samples is split into two halves.  With
    d = (sum of first half - sum of second half) / 2**r, the first half
    is replaced by d and the second half by -d.  Trailing samples that
    do not fill a whole block are dropped.
    '''
    half = 2**(r-1)
    span = 2*half
    remaining = len(f)
    start = 0
    detail = []
    while remaining >= span:
        middle = start + half
        diff = sum(f[start:middle]) - sum(f[middle:middle+half])
        detail.extend(half * [diff / span])
        detail.extend(half * [-diff / span])
        start += span
        remaining -= span
    return detail

'''
4. Compute A^r and D^r using multiresolution
   and give examples. Compare with the filter procedures 
'''
def A(f,r=1):
    '''Return A^r(f) computed from the level-r Haar scaling vectors.

    For every level-r scaling vector v the coefficient dot(f, v) is
    multiplied by the value v takes on its support, and that product is
    repeated over the 2**r samples of the support.  All len(f)//2**r
    blocks are covered, so for len(f) a power of two the result is a
    list of len(f) values.
    '''
    N = len(f)
    # Only the scaling vectors are needed here.
    scaling = VA(N)
    m = 2**r
    averaged = []
    # One block per level-r scaling vector, including the last one.
    for x in range(N // m):
        v = scaling[r][x]
        y = dot(f, v)
        # v is constant on its support, which starts at index x*m.
        z = v[x*m]
        averaged += m * [y * z]
    return averaged
    
    
def D(f,r=1):
    '''Return D^r(f) computed from the level-r Haar wavelet vectors.

    For every level-r wavelet vector w the coefficient dot(f, w) is
    multiplied by the value of w on the first half of its support and
    repeated over that half (2**(r-1) samples), then likewise for the
    second half.  r defaults to 1, matching A, high_filter and
    low_filter.
    '''
    N = len(f)
    _, wavelets = VWA(N)
    m = 2**(r-1)
    detail = []
    for x in range(N // (2**r)):
        # wavelets[k] holds the level k+1 vectors, hence index r-1.
        w = wavelets[r-1][x]
        y = dot(f, w)
        # First half of the support starts at x*2*m, second half m later.
        z = w[x*(2*m)]
        detail += m * [y * z]
        z = w[x*(2*m)+m]
        detail += m * [y * z]
    return detail

# Examples

# Sample signal of length N used for the tests and the plots.
f = [ (x/N-1)**2 * (sin(x/64)) for x in range(N)]

# Tests

# Number of decimals kept before comparing the two procedures.
precision = 10

# NOTE(review): round is applied to whole lists here, which the built-in
# round does not accept; presumably cdi provides an element-wise
# version -- confirm.
for r in range(1, 8+1) :
    print('Test r =',r,'of',8)
    Af = round(A(f, r), precision)
    Df = round(D(f, r), precision)
    Hf = round(high_filter(f, r), precision)
    Lf = round(low_filter(f, r), precision)
    # A result that is too short would otherwise go unnoticed, because
    # the element-wise comparison only covers the common prefix.
    lengths = (len(Af), len(Df), len(Hf), len(Lf))
    if len(set(lengths)) != 1 :
        print('Test fail: r =',r,',  lengths of (Af, Df, Hf, Lf) differ  => ',lengths)
    for x in range(min(lengths)) :
        if Af[x] != Hf[x] :
            print('Test fail: r =',r,',  Af[',x,'] != Hf[',x,']  => ',Af[x],'!=',Hf[x])
        if Df[x] != Lf[x] :
            print('Test fail: r =',r,',  Df[',x,'] != Lf[',x,']  => ',Df[x],'!=',Lf[x])
print('All tests done')

# Plots

# Utility to write text T at position X with
# five parameters with default values
def write(X,T,ha='center', va = 'center', fs='16', rot=0.0, color='#000000'):
    '''Draw the text T at the point (X[0], X[1]) of the current axes.

    ha and va are the horizontal and vertical alignment, fs the font
    size, rot the rotation and color the text colour; all are passed on
    to plt.text, whose return value is returned.
    '''
    x, y = X[0], X[1]
    return plt.text(
        x, y, T,
        horizontalalignment=ha,
        verticalalignment=va,
        fontsize=fs,
        rotation=rot,
        color=color,
    )
    
    
print('Drawing plots...')
close('all')
canvas('1. A^r and D^r using multiresolution')
xlim(0,N)

# Vertical distance between the curves of consecutive levels.
offset = 2

# For every level r draw f, A^r(f) and D^r(f), all shifted up by r*offset.
for r in range(1, 8+1) :
    shift = r*offset
    plot([shift + y for y in f], color='#ff0000')
    plot([shift + y for y in A(f, r)], color='#00bb00')
    plot([shift + y for y in D(f, r)], color='#0000ff')

# Colour legend along the bottom of the figure.
write((N//4, 0.1), '$f$', va='bottom', color='#ff0000')
write((N//2, 0.1), '$A$', va='bottom', color='#00bb00')
write((N*3//4, 0.1), '$D$', va='bottom', color='#0000ff')

    
    
canvas('2. A^r and D^r using filters')
xlim(0,N)

# For every level r draw f and its two filtered versions, all shifted up
# by r*offset (offset is the spacing defined for the first figure).
for r in range(1, 8+1) :
    plot([r*offset + y for y in f], color='#ff0000')
    plot([r*offset + y for y in high_filter(f, r)], color='#00bb00')
    plot([r*offset + y for y in low_filter(f, r)], color='#0000ff')

# Colour legend along the bottom of the figure.  Plain spaces are
# ignored inside $...$ (math mode), so an explicit '\ ' is needed to
# keep the two words apart.
write((N//4, 0.1), '$f$', va = 'bottom', color = '#ff0000')
write((N//2, 0.1), r'$high\ filter$', va = 'bottom', color = '#00bb00')
write((N*3//4, 0.1), r'$low\ filter$', va = 'bottom', color = '#0000ff')



print('Plots done')












