## A512_dominguez

# Utilities
from cdi import *
import numpy as np
import matplotlib.pyplot as plt

# Synonyms: short NumPy aliases used throughout this script.
# Constants and elementwise math
pi = np.pi
cos = np.cos
sin = np.sin
sqrt = np.sqrt
# Array construction
Id = np.eye
array = np.array
zeros = np.zeros
ls = np.linspace
# Joining / reshaping
stack = np.vstack    # stack 1-D rows into a 2-D array
splice = np.hstack   # join arrays horizontally
transpose = np.transpose
# Linear algebra
dot = np.dot
det = np.linalg.det
inv = np.linalg.inv
# NOTE: NumPy discourages np.matrix in new code; the alias is kept only for
# backward compatibility (it is not used by the code shown in this file).
mat = np.matrix

# Matplotlib synonyms
# figure lifecycle
canvas, close, show = plt.figure, plt.close, plt.show
# drawing
plot, view = plt.plot, plt.imshow
# axis control
axis, xlim = plt.axis, plt.xlim


## Functions for Ar and Dr

def _d4_rows(V, c1, c2, c3, c4):
    """Apply one periodic 4-tap D4 step to the rows of V.

    Output row j is
        c1*V[2j] + c2*V[2j+1] + c3*V[(2j+2) % N] + c4*V[(2j+3) % N]
    with N = len(V), so the last rows wrap around to the first ones.
    Rows are produced for j = 0 .. N//2 - 1 (always at least j = 0).
    A single resulting row is returned as a 1-D array; otherwise the
    rows are stacked into a 2-D array.
    """
    N = len(V)
    rows = [c1*V[2*j, :] + c2*V[2*j + 1, :]
            + c3*V[(2*j + 2) % N, :] + c4*V[(2*j + 3) % N, :]
            for j in range(max(1, N // 2))]
    return rows[0] if len(rows) == 1 else stack(rows)


# To construct the array of D4 level r scaling vectors
# from the array V of D4 level r-1 scaling vectors
def D4V(V):   # analogous to HaarV(V)
    """Level-r D4 scaling vectors from the level r-1 ones (taps a1..a4)."""
    return _d4_rows(V, a1, a2, a3, a4)


# To construct the array of D4 level r wavelet vectors
# from the array V of D4 level r-1 scaling vectors
def D4W(V):  # analogous to HaarW(V)
    """Level-r D4 wavelet vectors from the level r-1 scaling ones (taps b1..b4)."""
    return _d4_rows(V, b1, b2, b3, b4)

# To construct the pair formed with the array V
# of all D4 scale vectors and the array W of all
# D4 wavelet vectors.
def D4VW(N):  # analogous to HaarVW(V)
    """Build every level of the periodic D4 scaling and wavelet bases.

    Parameters
    ----------
    N : int
        Signal length; must be a power of two with N >= 2 (N = 2**n).

    Returns
    -------
    (X, Y) : tuple of two lists
        X[0] is the N x N identity (the level-0 scaling vectors).  For
        1 <= k < n, X[k] is a 2-D array whose N // 2**k rows are the
        level-k scaling vectors, and Y[k - 1] is the matching array of
        level-k wavelet vectors.  The coarsest level is stored as a
        one-element list: X[n] == [V^n_1] and Y[n - 1] == [W^n_1].
        In every case X[k][i] / Y[k - 1][i] is the i-th vector of level k.

    Raises
    ------
    ValueError
        If N is not a power of two >= 2.  Any other length would otherwise
        fail deep inside D4V/D4W with an obscure IndexError.
    """
    if N < 2 or N & (N - 1):
        raise ValueError("D4VW: N must be a power of 2 and at least 2, got %r" % (N,))
    V = Id(N)
    X = [V]
    Y = []
    while N > 2:
        W = D4W(V)
        V = D4V(V)
        X += [V]
        Y += [W]
        N = len(V)
    # V has exactly two rows here, so D4V/D4W return single 1-D vectors;
    # wrap them in lists so the last level is indexed like the others.
    W = D4W(V)
    V = D4V(V)
    X += [[V]]
    Y += [[W]]
    return (X, Y)
VWA = D4VW  # basis builder used by Ar and Dr

# To compute A^r(f)
def Ar(f, r):
    """Return the level-r D4 averaged signal A^r(f).

    A^r(f) = sum_i (f . V^r_i) V^r_i over the N // 2**r level-r scaling
    vectors V^r_i taken from VWA(N).  The whole basis is rebuilt on every
    call.
    """
    length = len(f)
    scaling, _wavelets = VWA(length)
    total = zeros(length)
    for idx in range(length // 2**r):
        vec = scaling[r][idx]
        total += dot(dot(f, vec), vec)
    return total
    
# To compute D^r(f)
def Dr(f, r):
    """Return the level-r D4 detail signal D^r(f).

    D^r(f) = sum_i (f . W^r_i) W^r_i over the N // 2**r level-r wavelet
    vectors W^r_i, which VWA(N) stores at index r - 1.  The whole basis is
    rebuilt on every call.  For r == 0 there is no detail signal and an
    all-zero array of length len(f) is returned, matching LF4(f, 0) and
    D4fluct(f, 0).
    """
    N = len(f)
    if r == 0:
        # Without this guard W[r - 1] wraps to W[-1] (the coarsest level)
        # and the loop then runs past its single vector (IndexError).
        return zeros(N)
    m = N // 2**r - 1
    _, W = VWA(N)
    D = zeros(N)
    for i in range(m + 1):
        D += dot(dot(f, W[r - 1][i]), W[r - 1][i])
    return D
    
## Functions for HF4 and LF4

# Basic constants of the Daubechies D4 filters
nd = 4  # not referenced by the code shown here; presumably the filter length
r2, r3 = sqrt(2), sqrt(3)
# Scaling (low-pass) taps
a1, a2, a3, a4 = ((1 + r3) / (4 * r2), (3 + r3) / (4 * r2),
                  (3 - r3) / (4 * r2), (1 - r3) / (4 * r2))
# Wavelet (high-pass) taps: the alternating flip of the scaling taps
b1, b2, b3, b4 = a4, -a3, a2, -a1

def D4trend(f, r=1):
    """Return the level-r D4 trend of the signal f.

    Each level applies one periodic step with the scaling taps a1..a4:
        t[j] = a1*f[2j] + a2*f[2j+1] + a3*f[(2j+2) % N] + a4*f[(2j+3) % N],
    halving the length, so the result has len(f) // 2**r entries.
    r == 0 returns f unchanged as an array.

    Raises:
        ValueError: if r is negative or len(f) is not divisible by 2**r.
            Raising here (rather than returning a message string) stops
            callers such as HF4 from feeding text into the filters.
    """
    N = len(f)
    if r == 0:
        return array(f)
    if r < 0:
        raise ValueError("D4trend: level r must be non-negative, got %d" % r)
    if N % 2**r:
        raise ValueError("D4trend: %d is not divisible by 2**%d" % (N, r))
    for _ in range(r):
        N = len(f)
        f = array([a1*f[2*j] + a2*f[2*j + 1]
                   + a3*f[(2*j + 2) % N] + a4*f[(2*j + 3) % N]
                   for j in range(N // 2)])
    return f
        
def D4fluct(f, r=1):
    """Return the level-r D4 fluctuation of the signal f.

    The signal is first reduced to its level r-1 trend with D4trend, then
    one periodic step with the wavelet taps b1..b4 is applied:
        d[j] = b1*a[2j] + b2*a[2j+1] + b3*a[(2j+2) % M] + b4*a[(2j+3) % M],
    giving len(f) // 2**r values.  r == 0 returns zeros(len(f)).

    Raises:
        ValueError: if r is negative or len(f) is not divisible by 2**r.
            Raising here (rather than returning a message string) stops
            callers such as LF4 from feeding text into the filters.
    """
    N = len(f)
    if r == 0:
        return zeros(N)
    if r < 0:
        raise ValueError("D4fluct: level r must be non-negative, got %d" % r)
    if N % 2**r:
        raise ValueError("D4fluct: %d is not divisible by 2**%d" % (N, r))
    a = D4trend(f, r - 1)
    M = len(a)
    return array([b1*a[2*j] + b2*a[2*j + 1]
                  + b3*a[(2*j + 2) % M] + b4*a[(2*j + 3) % M]
                  for j in range(M // 2)])

# auxiliary functions
def up_sample(a):
    """Return `a` as an array with a 0 inserted after every element.

    Example: [t0, t1, t2] -> array([t0, 0, t1, 0, t2, 0]).
    """
    return array([item for t in a for item in (t, 0)])
U_ = up_sample

def dual(h):
    """Return the alternating flip of the filter taps `h` as a list.

    The taps are reversed and every second one (starting with the second)
    is negated: [h0, h1, h2, h3] -> [h3, -h2, h1, -h0].
    """
    return [(-1) ** k * t for k, t in enumerate(reversed(h))]
    
a_ = [a1, a2, a3, a4]   # low-pass (scaling) taps, used by H4
b_ = dual(a_)           # high-pass taps for G4; same values as [b1, b2, b3, b4]

def filter(h, x):
    """Periodic (circular) convolution of the signal x with the filter taps h.

    Returns a list y of length len(x) with
        y[k] = sum_j h[j] * x[(k - j) % len(x)].
    The full linear convolution is computed first and every entry past
    index len(x) - 1 is folded back onto index t % len(x).  Folding with
    ``%`` (instead of a single shift by len(x)) keeps the result correct
    when x is shorter than len(h) - 1.

    NOTE: this name shadows the builtin ``filter`` inside this module; it is
    kept because H4/G4 (and possibly other code) call it by this name.
    """
    m = len(h)
    n = len(x)
    if n == 0:
        return []
    if m == 0:
        return [0] * n
    full = [sum(h[j]*x[t - j] for j in range(max(0, t - n + 1), min(m, t + 1)))
            for t in range(m + n - 1)]
    y = full[:n]
    for t in range(n, m + n - 1):
        y[t % n] += full[t]
    return y
    
def H4(x):
    """Low-pass D4 filtering: apply filter() with the scaling taps a_."""
    return filter(a_, x)


def G4(x):
    """High-pass D4 filtering: apply filter() with the wavelet taps b_."""
    return filter(b_, x)

def HF4(f, r=1):
    """Level-r D4 averaged signal computed with filters instead of basis vectors.

    The level-r trend D4trend(f, r) is up-sampled and low-pass filtered
    (H4) r times, which restores the length of f.  The assignment script
    compares the result with Ar(f, r).  r == 0 returns f as an array.
    """
    if r == 0:
        return array(f)
    smooth = D4trend(f, r)
    for _level in range(r):
        smooth = H4(U_(smooth))
    return array(smooth)
    
def LF4(f, r=1):
    """Level-r D4 detail signal computed with filters instead of basis vectors.

    The level-r fluctuation D4fluct(f, r) is up-sampled and high-pass
    filtered once (G4), then up-sampled and low-pass filtered (H4) r - 1
    more times.  The assignment script compares the result with Dr(f, r).
    r == 0 returns zeros(len(f)).
    """
    if r == 0:
        return zeros(len(f))
    detail = D4fluct(f, r)
    detail = G4(U_(detail))
    for _level in range(1, r):
        detail = H4(U_(detail))
    return array(detail)
    
## ASSIGNMENT

# Test signal.
# NOTE(review): `15**x**2` parses as 15**(x**2) because `**` is
# right-associative.  If the intended signal is 15*x**2*(1-x)**4*cos(9*pi*x)
# this is a typo -- confirm against the assignment statement.
F = lambda x: 15**x**2*(1 - x)**4*cos(9*pi*x)
n = 5; N = 2**n
# `sample` comes from cdi (not shown).  It is called with N - 1, presumably
# yielding N samples of F -- the wavelet code needs len(f) to be 2**n.
f = sample(F, N - 1)

close('all')

# Figure 1: A^r(f) and D^r(f) from the explicit D4 basis vectors.
# The title uses n (the deepest level plotted); `r` is not defined yet here.
fig, axarr = plt.subplots(n)
fig.suptitle("Original function (blue), Ar (green) and Dr (red) from level 1 (top) to %d (bottom)" % n)
for r in range(1, n + 1):
    axarr[r-1].plot(f, c='b')
    axarr[r-1].plot(Ar(f, r), c='g')
    axarr[r-1].plot(Dr(f, r), c='r')

# Figure 2: the same signals computed by filtering (HF4 / LF4).
fig, axarr = plt.subplots(n)
fig.suptitle("Original function (blue), HF4 (green) and LF4 (red) from level 1 (top) to %d (bottom)" % n)
for r in range(1, n + 1):
    axarr[r-1].plot(f, c='b')
    axarr[r-1].plot(HF4(f, r), c='g')
    axarr[r-1].plot(LF4(f, r), c='r')

# Check that both constructions agree.  np.allclose prints a single
# True/False per level instead of an element-wise boolean array; rtol=0 with
# atol=1e-10 is close to the previous round(..., 10) comparison but does not
# fail spuriously on values straddling a rounding boundary.
for r in range(1, n + 1):
    print("Ar(f, %d) == HF4(f, %d)?" % (r, r))
    print(np.allclose(Ar(f, r), HF4(f, r), rtol=0, atol=1e-10))
    print("Dr(f, %d) == LF4(f, %d)?" % (r, r))
    print(np.allclose(Dr(f, r), LF4(f, r), rtol=0, atol=1e-10))
