## A512_martin
## David Martín Alaminos

from cdi import *
import numpy as np
import matplotlib.pyplot as plt
from sys import stdout


# Short aliases for the NumPy routines used throughout this file.
Id = np.eye          # identity matrix
sqrt = np.sqrt
array = np.array
stack = np.vstack    # stack arrays row-wise
splice = np.hstack   # concatenate arrays end to end
dot = np.dot
ls = np.linspace
zeros = np.zeros
mat = np.matrix

# Basic D4 constants.
nd = 4
r2, r3 = sqrt(2), sqrt(3)
# Scaling (trend) coefficients a1..a4.
a1, a2 = (1+r3)/(4*r2), (3+r3)/(4*r2)
a3, a4 = (3-r3)/(4*r2), (1-r3)/(4*r2)
# Wavelet (fluctuation) coefficients: the a's reversed with alternating signs.
b1, b2, b3, b4 = a4, -a3, a2, -a1



'''
I) D4 trend, fluctuation and transform: D4trend_rec, D4fluct_rec, D4trend, D4fluct, D4
'''

def D4trend_rec(f, r=1):
    """Level-r D4 trend of f (recursive version).

    One level maps a signal of length N to the N//2 values
    a1*f[2j] + a2*f[2j+1] + a3*f[2j+2] + a4*f[2j+3], with indices taken
    periodically. r == 0 returns f unchanged (as an array). If len(f) is not
    divisible by 2**r, an error message string is returned instead.
    """
    N = len(f)
    values = list(f)
    if r == 0:
        return array(values)
    if N % 2**r:
        return "D4trend_rec: %d is not divisible by 2**%d " % (N, r)
    if r != 1:
        # r levels = one level, then r-1 further levels on the result.
        return D4trend_rec(D4trend_rec(values), r-1)
    # Periodic extension: the first two samples are appended at the end.
    ext = values + values[:2]
    quads = zip(ext[0::2], ext[1::2], ext[2::2], ext[3::2])
    return array([a1*p + a2*q + a3*s + a4*t for p, q, s, t in quads])

def D4fluct_rec(f,r=1):
    """Level-r D4 fluctuation of f (recursive version).

    One level maps a signal of length N to the N//2 values
    b1*f[2j] + b2*f[2j+1] + b3*f[2j+2] + b4*f[2j+3], with indices taken
    periodically. r == 0 returns a zero array of length len(f). If len(f) is
    not divisible by 2**r, an error message string is returned instead.
    """
    N = len(f)
    values = list(f)
    if r == 0:
        return zeros(N)
    if N % 2**r:
        return "D4fluct_rec: %d is not divisible by 2**%d " % (N, r)
    if r != 1:
        # Level-r fluctuation = level-1 fluctuation of the level-(r-1) trend.
        return D4fluct_rec(D4trend_rec(values, r-1))
    # Periodic extension: the first two samples are appended at the end.
    ext = values + values[:2]
    quads = zip(ext[0::2], ext[1::2], ext[2::2], ext[3::2])
    return array([b1*p + b2*q + b3*s + b4*t for p, q, s, t in quads])

## Iterative versions
def D4trend(f, r=1):
    """Level-r D4 trend of f (iterative version).

    Each pass halves the signal: entry k becomes
    a1*s[2k] + a2*s[2k+1] + a3*s[2k+2] + a4*s[2k+3], with indices wrapped
    modulo the current length. r == 0 returns f unchanged (as an array).
    If len(f) is not divisible by 2**r, an error message string is returned.
    """
    N = len(f)
    if r == 0:
        return array(f)
    if N % 2**r:
        return "D4trend: %d is not divisible by 2**%d " % (N, r)

    trend = f
    remaining = r
    while remaining >= 1:
        size = len(trend)
        trend = array([a1*trend[2*k] + a2*trend[2*k+1]
                       + a3*trend[(2*k+2) % size] + a4*trend[(2*k+3) % size]
                       for k in range(size//2)])
        remaining -= 1
    return trend

def D4fluct(f,r=1):
    """Level-r D4 fluctuation of f (iterative version).

    Computed as one wavelet step applied to the level-(r-1) trend: entry k is
    b1*a[2k] + b2*a[2k+1] + b3*a[2k+2] + b4*a[2k+3], with indices wrapped
    modulo len(a). r == 0 returns a zero array of length len(f). If len(f)
    is not divisible by 2**r, an error message string is returned.
    """
    N = len(f)
    if r == 0:
        return zeros(N)
    if N % 2**r:
        return "D4fluct: %d is not divisible by 2**%d " % (N, r)

    trend = D4trend(f, r-1)
    size = len(trend)
    return array([b1*trend[2*k] + b2*trend[2*k+1]
                  + b3*trend[(2*k+2) % size] + b4*trend[(2*k+3) % size]
                  for k in range(size//2)])

def D4(f,r=1):
    """Level-r D4 transform of f.

    Returns the concatenation [a^r | d^r | d^(r-1) | ... | d^1] of the
    level-r trend and the fluctuations of every level from r down to 1.
    r == 0 returns f unchanged (as an array). If len(f) is not divisible by
    2**r, an error message string is returned instead.
    """
    N = len(f)
    trend = list(f)
    if r == 0:
        return array(trend)
    if N % 2**r:
        return "D4: %d is not divisible by 2**%d " % (N, r)
    detail = []
    levels = r
    while levels >= 1:
        # Deeper fluctuations are prepended, so d^1 ends up last.
        detail = splice([D4fluct(trend), detail])
        trend = D4trend(trend)
        levels -= 1
    return splice([trend, detail])



'''
II) Daub4 scaling and wavelet arrays
'''

# To construct the array of D4 level r scaling vectors
# from the array V of D4 level r-1 scaling vectors
def D4V(V):
    """Build the array of D4 level-r scaling vectors from the level-(r-1) ones.

    Row j of the result is a1*V[2j] + a2*V[2j+1] + a3*V[2j+2] + a4*V[2j+3],
    with row indices taken modulo N = len(V) (periodic wrap-around).

    V -- 2D array whose rows are the level r-1 scaling vectors.
    Returns a 2D array with N//2 rows. When N//2 == 1, the single row is
    returned as a 1D array; D4VW wraps that last vector in a list itself.
    """
    N = len(V)
    j = np.arange(N//2)
    # One vectorized expression instead of growing the result with vstack in
    # a loop; the loop copied the whole partial result on every row (O(N^2)).
    X = a1*V[2*j, :] + a2*V[2*j+1, :] + a3*V[(2*j+2) % N, :] + a4*V[(2*j+3) % N, :]
    return X[0] if len(X) == 1 else X

# To construct the array of D4 level r wavelet vectors
# from the array V of D4 level r-1 scaling vectors
def D4W(V):
    """Build the array of D4 level-r wavelet vectors from the level-(r-1) scaling vectors.

    Row j of the result is b1*V[2j] + b2*V[2j+1] + b3*V[2j+2] + b4*V[2j+3],
    with row indices taken modulo N = len(V) (periodic wrap-around).

    V -- 2D array whose rows are the level r-1 scaling vectors.
    Returns a 2D array with N//2 rows. When N//2 == 1, the single row is
    returned as a 1D array; D4VW wraps that last vector in a list itself.
    """
    N = len(V)
    j = np.arange(N//2)
    # One vectorized expression instead of growing the result with vstack in
    # a loop; the loop copied the whole partial result on every row (O(N^2)).
    Y = b1*V[2*j, :] + b2*V[2*j+1, :] + b3*V[(2*j+2) % N, :] + b4*V[(2*j+3) % N, :]
    return Y[0] if len(Y) == 1 else Y

# To construct the pair formed with the array V
# of all D4 scale vectors and the array W of all
# D4 wavelet vectors.
def D4VW(N):
    """Return (X, Y): all D4 scaling arrays and all D4 wavelet arrays for length N.

    X[r] holds the level-r scaling vectors (X[0] is the N x N identity) and
    Y[r-1] holds the level-r wavelet vectors. The last level, a single
    vector, is stored wrapped in a one-element list so it can still be
    iterated row by row. N is expected to be a power of 2 (as in the script
    below); other sizes are not handled.
    """
    current = Id(N)
    scaling = [current]
    wavelets = []

    while len(current) > 2:
        wavelets.append(D4W(current))
        current = D4V(current)
        scaling.append(current)
    # Final step (2 vectors -> 1): D4V/D4W return a single 1D vector here.
    wavelets.append([D4W(current)])
    scaling.append([D4V(current)])

    return (scaling, wavelets)


# Orthogonal projection in orthonormal basis
def proj(f,V):
    """Orthogonal projection of f onto the span of the vectors in V.

    Assumes the vectors of V are orthonormal; returns sum over v of (f . v) v.
    """
    total = zeros(len(V[0]))
    for basis_vector in V:
        coeff = dot(f, basis_vector)
        total = total + coeff*basis_vector
    return total

# Projection coefficients
def proj_coeffs(f,V):
    """Return the array of inner products (f . v), one per vector v of V."""
    inner = [dot(f, row) for row in V]
    return array(inner)

## The Daub4 = D4 Transform using D4V and D4W
def D4T(f,r=1):
    """D4 transform of f at level r, computed from the arrays of D4VW.

    Returns the concatenation [a^r | d^r | d^(r-1) | ... | d^1], where a^r are
    the coefficients of f on the level-r scaling vectors and d^k those on the
    level-k wavelet vectors (the same layout D4 produces).
    """
    scaling, wavelets = D4VW(len(f))
    parts = [proj_coeffs(f, scaling[r])]
    parts += [proj_coeffs(f, wavelets[k-1]) for k in reversed(range(1, r+1))]
    return splice(parts)

## Filters and auxiliary functions
def up_sample(a):
    """Insert a zero after every entry of a: [t0, t1, ...] -> [t0, 0, t1, 0, ...]."""
    return array([value for t in a for value in (t, 0)])

U_ = up_sample  # short alias used by the filter-bank routines

def dual(h):
    """Return the dual filter of h as a list.

    The taps of h are reversed, and every second tap (starting with the
    second) is negated: [h0, h1, h2, h3] -> [h3, -h2, h1, -h0].
    """
    return [(-1)**k * t for k, t in enumerate(reversed(h))]

# D4 scaling filter taps, and the wavelet taps derived from them with dual()
# (dual(a_) equals [b1, b2, b3, b4]).
a_ = [a1, a2, a3, a4]
b_ = dual(a_)

def filter(h,x):
    """Circular (periodic) convolution of the filter h with the signal x.

    Computes y[k] = sum_i h[i] * x[(k - i) mod n] for k = 0..n-1, where
    n = len(x). This is the linear convolution h * x (length m + n - 1)
    folded back onto n samples.

    h -- filter taps (sequence of length m).
    x -- signal (sequence of length n).
    Returns a float NumPy array of length n (empty when x is empty).
    """
    m = len(h)
    n = len(x)
    y = zeros(n)
    if n == 0:
        return y

    for l in range(m+n-1):
        # Term l of the linear convolution: sum of h[l-j]*x[j] over valid j.
        lo = max(0, l-m+1)
        hi = min(l+1, n)
        # Fold index l into 0..n-1. A single "l - n" shift is not enough when
        # n < m - 1 (e.g. a length-2 signal with a 4-tap filter).
        y[l % n] += sum(h[l-j]*x[j] for j in range(lo, hi))

    return y


def H4(x):
    """Apply the D4 scaling filter a_ to x (circular convolution via filter)."""
    return filter(a_, x)
def G4(x):
    """Apply the D4 wavelet filter b_ to x (circular convolution via filter)."""
    return filter(b_, x)

def HF4(f,r=1):
    """Filter-bank counterpart of A(f, r): the level-r average signal.

    Starts from the level-r trend D4trend(f, r) and applies r rounds of
    up-sampling followed by the scaling filter H4. r == 0 returns f as an
    array.
    """
    if r == 0:
        return array(f)
    average = D4trend(f, r)
    for _step in range(r):
        average = H4(U_(average))
    return array(average)

def LF4(f,r=1):
    """Filter-bank counterpart of D(f, r): the level-r detail signal.

    Up-samples the level-r fluctuation and applies the wavelet filter G4,
    then does r-1 rounds of up-sampling followed by the scaling filter H4.
    r == 0 yields a zero signal of length len(f).
    """
    if r == 0:
        return zeros(len(f))
    detail = G4(U_(D4fluct(f, r)))
    for _step in range(r-1):
        detail = H4(U_(detail))
    return array(detail)



'''
Assignment tasks begin here.

By analogy with L423, choose your favorite signal f (to be useful, it
should not be too wild) and compute the arrays A^r(f) and D^r(f) (here
A(f,r) and D(f,r)) both by using the scaling and wavelet vectors and by
using the method of filters. Present the results in suitable graphical
representations.
'''

## Multiresolution procedures

#Computation using orthogonal projection of r-th order scaling array
#(for average signal) and wavelet array (for detail signal)

def A(f,r):
    """Level-r average signal of f: its projection onto the level-r scaling
    vectors V[r]. V is the module-level result of D4VW computed by the script."""
    return proj(f, V[r])

def D(f,r):
    """Level-r detail signal of f: its projection onto the level-r wavelet
    vectors W[r-1]. W is the module-level result of D4VW computed by the script."""
    return proj(f, W[r-1])


## Graphics

n = 10
N = 2**n        # signal length
M = n           # deepest level; at level n the trend has a single value
V,W = D4VW(N)   # scaling arrays V[0..n] and wavelet arrays W[0..n-1], used by A and D

# Test signal. sample() comes from cdi; presumably it returns N samples of F
# on [0, 1] -- TODO confirm against cdi.
F = lambda x: 17*x**2 * ((1-x)**5)*np.cos(20*np.pi*x)
f = sample(F,N-1)


print("Comparing results between HF4 and A, LF4 and D... ", end="", flush=True)

# Compare with an absolute tolerance of 1e-6. The former test
# `not round(x,6) == round(y,6)` breaks on NumPy arrays: with the builtin
# round, `not` of an element-wise comparison raises "truth value ... is
# ambiguous". Rounding can also separate nearly equal values that straddle
# a rounding boundary.
results_eq = True
for i in range(1,M+1):
    if not np.allclose(HF4(f,i), A(f,i), rtol=0, atol=1e-6):
        results_eq = False
        print("Different results for HF4 and A (level %d)" % i)
    if not np.allclose(LF4(f,i), D(f,i), rtol=0, atol=1e-6):
        results_eq = False
        print("Different results for LF4 and D (level %d)" % i)

if results_eq:
    print("same results!")


print("Generating graphics... ", end="", flush=True)

plt.close('all')

# TeX for the test signal; it must match F above. Raw strings keep the
# backslashes of \cos, \pi and \in intact (in normal strings they are
# invalid escape sequences).
f_tex = r"$f(x) = 17x^2 (1-x)^{5} \cos(20 \pi x)$"

#Plot average signals using A(f,r); curve r is shifted up by r (r = 0 is f itself).
plt.figure("1-%d level average signals" % M)
plt.title(r"Average signal $A(f,r)$. %s, $r \in 0..%d$, $N=2^{%d}$" % (f_tex, M, n), color='black')
plt.xlim(0,N)

for i in range(0,M+1):
    plt.plot(i + A(f,i))


#Plot detail signals using D(f,r); curve r is shifted up by r.
plt.figure("1-%d level detail signals" % M)
plt.title(r"Detail signal $D(f,r)$. %s, $r \in 1..%d$, $N=2^{%d}$" % (f_tex, M, n), color='black')
plt.xlim(0,N)

for i in range(1,M+1):
    plt.plot(i + D(f,i))


print("done!")

plt.show()
