## A421_martin
## David Martín Alaminos

from cdi import *
import numpy as np
import matplotlib.pyplot as plt


# Synonyms: short names for the NumPy routines used throughout this file.
Id, pi, cos, sin = np.eye, np.pi, np.cos, np.sin
array, stack, splice, dot = np.array, np.vstack, np.hstack, np.dot
ls, zeros, mat, transpose = np.linspace, np.zeros, np.matrix, np.transpose
det, inv = np.linalg.det, np.linalg.inv

# Short names for the pyplot functions used by the plotting script below.
canvas, axis, show = plt.figure, plt.axis, plt.show
view, close = plt.imshow, plt.close
plot, xlim = plt.plot, plt.xlim



'''
Core functions for Haar wavelets algorithms
'''

# trend(f) computes the trend signal of a discrete signal f
def trend(f):
    """Return the Haar trend signal of f as a list.

    Entry k is (f[2k] + f[2k+1]) / sqrt(2).  When len(f) is odd the last,
    unpaired sample is ignored.
    """
    root2 = 1.4142135623730951  # sqrt(2)
    averages = []
    for k in range(len(f) // 2):
        left, right = f[2 * k], f[2 * k + 1]
        averages.append((left + right) / root2)
    return averages

# fluct(f) computes the fluctuation
# (or difference) signal of a discrete signal f
def fluct(f):
    """Return the Haar fluctuation (difference) signal of f as a list.

    Entry k is (f[2k] - f[2k+1]) / sqrt(2).  When len(f) is odd the last,
    unpaired sample is ignored.
    """
    root2 = 1.4142135623730951  # sqrt(2)
    differences = []
    for k in range(len(f) // 2):
        left, right = f[2 * k], f[2 * k + 1]
        differences.append((left - right) / root2)
    return differences

# haar(f,r) computes the Haar transform of f of order r.
# haar(f) is equivalent to haar(f,1)
def haar(f, r=1):
    """Return the order-r Haar transform of f (recursive version).

    r = 0 returns f unchanged; r = 1 returns trend(f) + fluct(f).  For
    r > 1 the order-(r-1) transform is computed first, and its leading
    len(f) // 2**(r-1) entries (the current trend) are transformed once
    more while the remaining entries are carried over unchanged.
    """
    if r == 0:
        return f
    if r == 1:
        return trend(f) + fluct(f)
    previous = haar(f, r - 1)
    split = len(f) // 2 ** (r - 1)
    head, tail = previous[:split], previous[split:]
    return trend(head) + fluct(head) + tail

def HaarT(f, r=1):
    """Return the order-r Haar transform of f, computed iteratively.

    The result is a list laid out as [a^r | d^r | d^(r-1) | ... | d^1]:
    the level-r trend followed by the fluctuation signals, coarsest first.
    r = 0 returns f unchanged.

    If len(f) is not divisible by 2**r, an error-message string is returned
    instead of a list (kept for compatibility with existing callers).
    """
    if r == 0:
        return f
    a = list(f)
    N = len(a)
    if N % 2**r:
        # The message names this function (it previously said "trend:").
        return "HaarT: %d is not divisible by 2**%d " % (N, r)
    final_len = N // 2**r  # length of the level-r trend signal
    h = []
    while len(a) > final_len:
        # Prepend so that the coarsest fluctuation ends up first.
        h = fluct(a) + h
        a = trend(a)
    return a + h

# Define a function that computes
# the array of level r scaling vectors
# from the array V of level r-1 scaling vectors
def HaarV(V):
    """Return the level-r scaling vectors built from the level-(r-1) rows of V.

    Row j of the result is (V[2j] + V[2j+1]) / sqrt(2).  If V has an odd
    number of rows, the last row is ignored.  When only one pair of rows
    exists, the result is returned as a 1-D vector rather than a 1-row
    matrix; HaarVA/HaarVWA wrap that final level in a list themselves.

    Raises ValueError if V has fewer than two rows.
    """
    c = 2**(-1/2)
    pairs = len(V) // 2
    if pairs == 0:
        raise ValueError("HaarV: need at least two rows, got %d" % len(V))
    # One vectorised step instead of growing X with repeated vstack calls,
    # which copied the accumulated array on every iteration (O(N**2)).
    X = c * V[0:2 * pairs:2, :] + c * V[1:2 * pairs:2, :]
    return X[0] if pairs == 1 else X

# Define a function that computes
# the array of level r wavelet vectors
# from the array V of level r-1 wavelet vectors
def HaarW(V):
    """Return the level-r wavelet vectors built from the level-(r-1) rows of V.

    Row j of the result is (V[2j] - V[2j+1]) / sqrt(2).  In this file, V is
    the array of level-(r-1) *scaling* vectors (see HaarVWA).  If V has an
    odd number of rows, the last row is ignored.  When only one pair of rows
    exists, the result is returned as a 1-D vector rather than a 1-row matrix.

    Raises ValueError if V has fewer than two rows.
    """
    c = 2**(-1/2)
    pairs = len(V) // 2
    if pairs == 0:
        raise ValueError("HaarW: need at least two rows, got %d" % len(V))
    # One vectorised step instead of growing X with repeated vstack calls,
    # which copied the accumulated array on every iteration (O(N**2)).
    X = c * V[0:2 * pairs:2, :] - c * V[1:2 * pairs:2, :]
    return X[0] if pairs == 1 else X

# Define a function that computes
# the array of all scaling vectors, that is,
# its r-th component is the matrix of level r
# scaling vectors.
def HaarVA(N):
    """Return the list of Haar scaling-vector arrays for signals of length N.

    Element 0 is the N x N identity, and element r holds the level-r scaling
    vectors produced by HaarV.  The last element is a one-item list wrapping
    the single, 1-D coarsest scaling vector, so that it can still be iterated
    row by row.  N is expected to be a power of two (>= 2); other sizes end
    up passing a 1-D vector to HaarV and fail.
    """
    current = Id(N)
    levels = [current]
    size = N
    while size > 2:
        current = HaarV(current)
        levels.append(current)
        size = len(current)
    levels.append([HaarV(current)])
    return levels
VA = HaarVA

# Define a function that computes a pair
# formed with the array of all scaling vectors
# and the array of all wavelet vectors
def HaarVWA(N):
    """Return (scaling_levels, wavelet_levels) for signals of length N.

    scaling_levels[r] holds the level-r scaling vectors, with
    scaling_levels[0] the N x N identity.  wavelet_levels[r - 1] holds the
    level-r wavelet vectors; no level-0 wavelet array is built.  In both
    lists the last element is a one-item list wrapping the single, 1-D
    coarsest vector.  N is expected to be a power of two (>= 2).
    """
    scaling = Id(N)
    scaling_levels = [scaling]
    wavelet_levels = []
    size = N
    while size > 2:
        wavelet_levels.append(HaarW(scaling))
        scaling = HaarV(scaling)
        scaling_levels.append(scaling)
        size = len(scaling)
    wavelet_levels.append([HaarW(scaling)])
    scaling_levels.append([HaarV(scaling)])
    return scaling_levels, wavelet_levels
VWA = HaarVWA



'''
Functions from previous lab L423
'''

#Projection coefficients (for any basis)
def proj_coeffs(f, V):
    """Return the NumPy array of inner products dot(f, v) for each v in V."""
    coefficients = []
    for basis_vector in V:
        coefficients.append(dot(f, basis_vector))
    return array(coefficients)

#Gram matrix
def Gram(V):
    """Return the Gram matrix of the rows of V as a list of coefficient arrays.

    Row i is proj_coeffs(V[i], V), i.e. the inner products of V[i] with every
    vector of V.
    """
    rows = []
    for vector in V:
        rows.append(proj_coeffs(vector, V))
    return rows

#Orthogonal projection
def proj(f, V):
    """Return the orthogonal projection of f onto the span of the rows of V.

    V may be a 2-D array or a sequence of 1-D vectors (for example the
    one-item list that HaarVA/HaarVWA use for the coarsest level).  The rows
    need not be orthonormal: the result is t @ V, where t solves G t = w,
    G = V V^T is the Gram matrix and w = V f holds the inner products
    <f, v_i>.  Because G is symmetric, this equals w G^{-1} V.

    Returns a 1-D array.  If the rows of V are linearly dependent, the string
    'proj: rows of V non lin. indep.' is returned instead.
    """
    B = np.atleast_2d(np.asarray(V))
    w = B @ np.asarray(f)  # indep. terms <f, v_i>
    # Numerical rank instead of `det(G) == 0`: the determinant of a large
    # Gram matrix under/overflows easily, so an exact-zero test rejects
    # well-conditioned bases and misses nearly dependent ones.
    if np.linalg.matrix_rank(B) < B.shape[0]:
        return 'proj: rows of V non lin. indep.'
    G = B @ B.T
    # Solve G t = w rather than forming inv(G); this also avoids np.matrix.
    t = np.linalg.solve(G, w)
    return t @ B

#To compute A^r(f)
def high_filter(f, r=1):
    """Return A^r(f) as a list: each block of 2**r consecutive samples of f
    is replaced by 2**r copies of its mean.

    NOTE(review): despite its name, this computes the averaging (trend) part,
    not the detail part; the name is kept for existing callers.

    Trailing samples that do not fill a whole block are dropped, so the
    result has length len(f) - len(f) % 2**r.
    """
    m = 2**r
    sums = []
    # Step through block starts by index.  Re-slicing `f = f[m:]` on every
    # block copied the whole remaining tail each time (O(N**2 / m)).
    for start in range(0, len(f) - m + 1, m):
        x = sum(f[start:start + m])
        sums += m * [x]
    return [s / m for s in sums]

#To compute D^r(f)
def low_filter(f, r=1):
    """Return D^r(f) as a list.

    f is split into blocks of 2**r samples.  For each block, x is the sum of
    its first half minus the sum of its second half; the first half of the
    output block is x / 2**r and the second half is -x / 2**r.

    NOTE(review): despite its name, this computes the detail part, not the
    averaging part; the name is kept for existing callers.  r must be >= 1
    (r = 0 gives a non-integer half-width and fails).  Trailing samples that
    do not fill a whole block are dropped.
    """
    m = 2**(r-1)
    width = 2 * m
    details = []
    # Step through block starts by index.  Re-slicing `f = f[m:]` twice per
    # block copied the remaining tail each time (O(N**2 / m)).
    for start in range(0, len(f) - width + 1, width):
        x = sum(f[start:start + m])
        x -= sum(f[start + m:start + width])
        details += m * [x] + m * [-x]
    return [d / width for d in details]



'''
Assignment tasks begin here.

Use the expressions in T6.14 for A^{r}(f,r) and D^{r}(f,r), together with the
scaling and wavelet arrays, to define functions that compute them.
'''

#Computation using orthogonal projection of r-th order scaling array
#(for average signal) and wavelet array (for detail signal)

def A(f, r):
    """Return the level-r average signal A^r(f).

    This is the orthogonal projection of f onto the level-r scaling vectors
    V[r].  It reads the module-level V, so it can only be called after
    `V, W = VWA(N)` has run.
    """
    level_r_scaling = V[r]
    return proj(f, level_r_scaling)

def D(f, r):
    """Return the level-r detail signal D^r(f).

    This is the orthogonal projection of f onto the level-r wavelet vectors.
    HaarVWA stores level-1 wavelets at index 0 (no level-0 wavelet array is
    built), hence the r - 1 offset.  It reads the module-level W.
    """
    level_r_wavelets = W[r - 1]
    return proj(f, level_r_wavelets)



'''
Pick your favorite signal f (to be useful, it should be reasonably tame)
and use the technique of vertical offsets to represent the A^{r}(f, r),
for a suitable range of r, in a single canvas. Ditto for D^{r}(f, r).
'''

# Signal length N = 2**n; levels r = 1..M are plotted and compared below.
n = 10
N = 2 ** n
M = 9


def F(x):
    """Test signal from T6.5: 20 x^2 (1-x)^4 cos(12 pi x)."""
    return 20*x**2 * (1-x)**4 * cos(12*pi*x)


# NOTE(review): `sample` comes from cdi; presumably it returns N samples of F
# (N-1 subintervals) -- confirm against cdi.
f = sample(F, N-1)
# Scaling (V) and wavelet (W) arrays used by A and D.
V, W = VWA(N)


#Plot average signals using A(f,r); level r is shifted up by r (vertical offsets)
close('all')
canvas("1-%d level average signals of function 20*x**2 * (1-x)**4 * cos(12*pi*x)" % M)
# Raw strings: "\pi" and "\in" are mathtext commands, not Python escapes
# (unknown escapes in normal strings warn, and are a SyntaxWarning in 3.12+).
plt.title(r"Average signal $A(f,r)$. $f(x) = 20x^2 (1-x)^{4} cos(12 \pi x)$, $r \in 1..%d$" % M, color='black')
xlim(0,N)

for i in range(1,M+1):
    plot(i + A(f,i))


#Plot detail signals using D(f,r), with the same vertical-offset layout
canvas("1-%d level detail signals of function 20*x**2 * (1-x)**4 * cos(12*pi*x)" % M)
plt.title(r"Detail signal $D(f,r)$. $f(x) = 20x^2 (1-x)^{4} cos(12 \pi x)$, $r \in 1..%d$" % M, color='black')
xlim(0,N)

for i in range(1,M+1):
    plot(i + D(f,i))


show()
#Opened windows need to be closed before moving on to the last section



'''
Compare the results with similar computations using high_filter and low_filter.
'''

#This test may take a few seconds
def _signals_agree(x, y, atol=1e-6):
    """Return True if x and y have the same shape and agree within atol."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    return x.shape == y.shape and np.allclose(x, y, rtol=0, atol=atol)


results_eq = True
for i in range(1, M+1):
    # Compare numerically: round() is not defined for the lists returned by
    # high_filter/low_filter, and `==` on arrays is elementwise, so its
    # truth value is ambiguous.  atol=1e-6 mirrors the 6-decimal intent.
    if not _signals_agree(high_filter(f, i), A(f, i)):
        results_eq = False
        print("Different results for high_filter and A (level %d)" % i)
    if not _signals_agree(low_filter(f, i), D(f, i)):
        results_eq = False
        print("Different results for low_filter and D (level %d)" % i)

if results_eq: print("Success! Same results for high_filter and A, low_filter and D")

