## L514 // cdi_daub6: Daub6 wavelets

'''
I) D6trend, D6fluct, D6: Daub6 trend, fluctuation and multilevel transform
'''
from cdi import *
import numpy as np
import matplotlib.pyplot as plt

# Synonyms: short aliases for the NumPy routines used throughout this module.
Id, sqrt, array = np.eye, np.sqrt, np.array
stack, splice, dot = np.vstack, np.hstack, np.dot
ls, zeros = np.linspace, np.zeros

# Short aliases for the matplotlib.pyplot functions used in this module.
canvas, axis, show = plt.figure, plt.axis, plt.show
view, close = plt.imshow, plt.close
plot, xlim = plt.plot, plt.xlim

# Constants
# Daub6 scaling ("trend") filter coefficients h0..h5.
# They satisfy sum(h_) == sqrt(2) and dot(h_, h_) == 1.
h_ = (0.332670552950083, 0.806891509311092, 0.459877502118491,
      -0.135011020010255, -0.0854412738820267, 0.0352262918857095)
h0, h1, h2, h3, h4, h5 = h_

# Wavelet ("fluctuation") filter obtained from h_ via `dual` (presumably from cdi).
# NOTE(review): presumably the alternating flip g_k = (-1)**k * h_{5-k}; confirm in cdi.
g_ = dual(h_)
g0, g1, g2, g3, g4, g5 = g_

# 1-based names for the scaling coefficients, used by the transform code below.
a1, a2, a3, a4, a5, a6 = h_

def D6trend(f, r=1):
    """Return the level-r Daub6 trend (approximation) of the signal f.

    One level maps a signal of length M to M//2 values

        t[j] = a1*f[2j] + a2*f[2j+1] + a3*f[2j+2]
             + a4*f[2j+3] + a5*f[2j+4] + a6*f[2j+5],

    with indices taken modulo M, i.e. f is treated as periodic.  Level r
    repeats this step r times, each time on the previous trend.

    Parameters
    ----------
    f : sequence of numbers (list, tuple or 1-D array).
    r : number of levels, r >= 0 (default 1).

    Returns
    -------
    A NumPy array of length len(f) // 2**r.  For r == 0, f itself as an array.
    If len(f) is not divisible by 2**r, an explanatory string is returned
    instead of an array (kept for compatibility with existing callers).

    Raises
    ------
    ValueError
        If r is negative (a negative number of levels has no meaning).
    """
    if r < 0:
        raise ValueError("D6trend: r must be non-negative, got " + str(r))
    N = len(f)
    f = list(f)
    if r == 0:
        return array(f)
    if N % 2**r:
        return "D6trend: "+str(N)+" is not divisible by "+str(2**r)
    levels = r
    while levels >= 1:
        M = len(f)
        # Modulo indexing is the periodic extension of f past its right edge.
        f = [a1*f[2*j] + a2*f[2*j+1] + a3*f[(2*j+2) % M] + a4*f[(2*j+3) % M]
             + a5*f[(2*j+4) % M] + a6*f[(2*j+5) % M]
             for j in range(M//2)]
        levels -= 1
    return array(f)
        
def D6fluct(f, r=1):
    """Return the level-r Daub6 fluctuation (detail) of the signal f.

    The level-1 fluctuation maps a signal of length N to N//2 values

        d[j] = b1*f[2j] + b2*f[2j+1] + b3*f[2j+2]
             + b4*f[2j+3] + b5*f[2j+4] + b6*f[2j+5],

    where (b1, ..., b6) = g_ and indices are taken modulo N (periodic f).
    The level-r fluctuation is the level-1 fluctuation of the level-(r-1)
    trend computed by D6trend.

    Parameters
    ----------
    f : sequence of numbers (list, tuple or 1-D array).
    r : level, r >= 1 (default 1).

    Returns
    -------
    A NumPy array of length len(f) // 2**r.  If len(f) is not divisible by
    2**r, an explanatory string is returned instead (kept for compatibility).

    Raises
    ------
    ValueError
        If r < 1 (there is no fluctuation below level 1).
    """
    if r < 1:
        raise ValueError("D6fluct: r must be at least 1, got " + str(r))
    [b1,b2,b3,b4,b5,b6] = g_
    N = len(f)
    f = list(f)
    if N % 2**r:
        return "D6fluct: "+str(N)+" is not divisible by "+str(2**r)
    if r == 1:
        # Modulo indexing gives the periodic extension for every even N,
        # including N == 2 where padding with f[:4] would be too short.
        return array([b1*f[2*j] + b2*f[2*j+1] + b3*f[(2*j+2) % N]
                      + b4*f[(2*j+3) % N] + b5*f[(2*j+4) % N] + b6*f[(2*j+5) % N]
                      for j in range(N//2)])
    return D6fluct(D6trend(f, r-1))

def D6(f, r=1):
    """Return the level-r Daub6 transform of f as one flat array.

    Layout: (a^r | d^r | d^(r-1) | ... | d^1), i.e. the level-r trend
    followed by the fluctuations from the coarsest level down to the finest.
    Each level applies D6trend / D6fluct (level 1) to the previous trend.

    r == 0 returns f unchanged as an array.  If len(f) is not divisible by
    2**r, an explanatory string is returned instead of an array.
    """
    length = len(f)
    signal = list(f)
    if r == 0:
        return array(signal)
    if length % 2**r:
        return "D6: " + str(length) + " is not divisible by " + str(2**r)
    trend = signal
    details = []  # fluctuations, finest level first
    remaining = r
    while remaining >= 1:
        next_trend = D6trend(trend)
        details.append(D6fluct(trend))
        trend = next_trend
        remaining -= 1
    # Coarsest fluctuation first.  The trailing empty list reproduces the
    # empty-seed concatenation (it yields a float result when no level ran).
    return np.hstack([trend] + details[::-1] + [[]])
# Lower-case alias for the full Daub6 transform D6.
daub6=D6

'''
II) Daub6 scaling and wavelet arrays: D6V, D6W, D6VW
'''

# Build the array of D6 level r scaling vectors
# from the array V of D6 level r-1 scaling vectors.
def D6V(V):
    """Return the next-level Daub6 scaling vectors built from the rows of V.

    Row j of the result is
        a1*V[2j] + a2*V[2j+1] + a3*V[2j+2] + a4*V[2j+3] + a5*V[2j+4] + a6*V[2j+5]
    with row indices taken modulo len(V) (periodic wrap-around).

    V must support 2-D indexing V[i, :].  With len(V) >= 4 the result is a
    2-D array of len(V)//2 rows; with fewer rows only row j = 0 is formed
    and it is returned as a 1-D array.
    """
    count = len(V)

    def combo(j):
        # Periodic weighted sum of rows 2j .. 2j+5 of V.
        base = 2*j
        return (a1*V[base % count, :] + a2*V[(base+1) % count, :]
                + a3*V[(base+2) % count, :] + a4*V[(base+3) % count, :]
                + a5*V[(base+4) % count, :] + a6*V[(base+5) % count, :])

    new_rows = [combo(j) for j in range(max(count//2, 1))]
    if len(new_rows) == 1:
        return new_rows[0]
    return stack(new_rows)

# Build the array of D6 level r wavelet vectors
# from the array V of D6 level r-1 scaling vectors.
def D6W(V):
    """Return the next-level Daub6 wavelet vectors built from the rows of V.

    Row j of the result is
        b1*V[2j] + b2*V[2j+1] + b3*V[2j+2] + b4*V[2j+3] + b5*V[2j+4] + b6*V[2j+5]
    with row indices taken modulo len(V), where (b1, ..., b6) is
    dual([a1, ..., a6]).
    NOTE(review): presumably the same wavelet filter as g_; confirm in cdi.

    V must support 2-D indexing V[i, :].  With len(V) >= 4 the result is a
    2-D array of len(V)//2 rows; with fewer rows only row j = 0 is formed
    and it is returned as a 1-D array.
    """
    b1, b2, b3, b4, b5, b6 = dual([a1, a2, a3, a4, a5, a6])
    count = len(V)

    def combo(j):
        # Periodic weighted sum of rows 2j .. 2j+5 of V.
        base = 2*j
        return (b1*V[base % count, :] + b2*V[(base+1) % count, :]
                + b3*V[(base+2) % count, :] + b4*V[(base+3) % count, :]
                + b5*V[(base+4) % count, :] + b6*V[(base+5) % count, :])

    new_rows = [combo(j) for j in range(max(count//2, 1))]
    if len(new_rows) == 1:
        return new_rows[0]
    return stack(new_rows)

# Build the pair formed by the list of all D6 scaling-vector arrays
# and the list of all D6 wavelet-vector arrays.
def D6VW(N):
    """Return (X, Y): Daub6 scaling and wavelet vector arrays for length N.

    X[0] is the N-by-N identity.  Each further X[k] is D6V applied to
    X[k-1], and Y[k-1] is D6W applied to X[k-1].  Levels are produced while
    the current array has more than two rows.  A final level is then added
    in which D6V/D6W return a single 1-D row.  That row is wrapped in a list
    so the entry reads as a one-row matrix.
    """
    scaling = Id(N)
    scale_levels = [scaling]
    wavelet_levels = []
    while N > 2:
        wavelets = D6W(scaling)
        scaling = D6V(scaling)
        scale_levels.append(scaling)
        wavelet_levels.append(wavelets)
        N = len(scaling)
    wavelets = D6W(scaling)
    scaling = D6V(scaling)
    scale_levels.append([scaling])
    wavelet_levels.append([wavelets])
    return (scale_levels, wavelet_levels)
    
# Example of transform: plot the D6 transform of a test signal for levels
# r = 0..4, each curve drawn 8 units below the previous one.
F = lambda t: 1500*t**2*(1-t)**4*((t-0.35)**2+0.01)*np.cos(41*t)
n = 9
N = 2**n
# NOTE(review): `sample` is not defined in this file (presumably from cdi);
# the plot assumes it returns N samples of F on [0, 1]. Confirm this.
f = sample(F, N-1)
close('all')

fig = canvas("D6 transform")
fig.suptitle("D6 transform")
ax = fig.add_subplot(111)
ax.set_title("$f: 1500t^2(1-t)^4((t-0.35)^2+0.01)cos(41t)$", fontsize=11)
offset = 0
for r in range(5):
    T = D6(f, r)
    xlim(0, N)
    plot([value + offset for value in T])
    offset -= 8


# Testing assertions: sanity checks on the Daub6 filters h_ (scaling) and
# g_ (wavelet).  v1..v5 / w1..w5 are the circular shifts of h_ / g_ by
# 0, 2, 4, 6, 8 inside a length-10 periodic signal, i.e. the rows of the
# level-1 trend / fluctuation matrices for N = 10.
from itertools import combinations

v1 = list(h_) + [0]*4
v2 = [0]*2 + list(h_) + [0]*2
v3 = [0]*4 + list(h_)
v4 = list(h_)[4:] + [0]*4 + list(h_)[:4]
v5 = list(h_)[2:] + [0]*4 + list(h_)[:2]
w1 = list(g_) + [0]*4
w2 = [0]*2 + list(g_) + [0]*2
w3 = [0]*4 + list(g_)
w4 = list(g_)[4:] + [0]*4 + list(g_)[:4]
w5 = list(g_)[2:] + [0]*4 + list(g_)[:2]
scaling_shifts = [v1, v2, v3, v4, v5]
wavelet_shifts = [w1, w2, w3, w4, w5]

print("> Starting assertions")
# Scaling filter: sums to sqrt(2), unit norm, orthogonal to the wavelet filter.
assert round(sum(h_),8) == round(sqrt(2),8)
assert round(dot(h_,h_),8) == 1
assert round(dot(h_,g_),8) == 0
# Even shifts of the scaling filter are mutually orthogonal.
for p, q in combinations(scaling_shifts, 2):
    assert round(dot(p,q),8) == 0
# Even shifts of the wavelet filter are mutually orthogonal ...
for p, q in combinations(wavelet_shifts, 2):
    assert round(dot(p,q),8) == 0
# ... and orthogonal to every even shift of the scaling filter.
for p in scaling_shifts:
    for q in wavelet_shifts:
        assert round(dot(p,q),8) == 0
# Wavelet filter: vanishing moments of order 0, 1 and 2.
assert round(sum(g_),8) == 0
assert round(sum([k*g_[k] for k in range(len(h_))]),8) == 0
assert round(sum([(k**2)*g_[k] for k in range(len(h_))]),8) == 0
print("> All assertions have been succesfully tested")
