## L514 - Andrés Mingorance

'''
I) D6trend, D6fluct, D6
'''
from cdi import dual
import numpy as np
import matplotlib.pyplot as plt

# Synonyms: short names for the NumPy routines used throughout this file.
# NOTE: the `from numpy import *` further down rebinds `array` (to the same
# np.array) and `stack` (to np.stack, which is NOT np.vstack).
Id, sqrt, array = np.eye, np.sqrt, np.array
stack, splice = np.vstack, np.hstack      # vertical / horizontal concatenation
dot, ls, zeros = np.dot, np.linspace, np.zeros

# Constants
# Daubechies D6 scaling numbers (they sum to ~sqrt(2)).  h0..h5 and a1..a6
# are two naming schemes for the same six values held in the list h_.
h_ = [0.332670552950083, 0.806891509311092,
      0.459877502118491, -0.135011020010255,
      -0.0854412738820267, 0.0352262918857095]
h0, h1, h2, h3, h4, h5 = h_

# Wavelet numbers obtained from h_ through cdi.dual; g0..g5 name its entries.
g_ = dual(h_)
g0, g1, g2, g3, g4, g5 = g_

a1, a2, a3, a4, a5, a6 = h_

def D6trend(f, r=1):
    """Return the level-r D6 trend (approximation) signal of f.

    One level maps a signal of even length N to the N//2 values
        t[j] = a1*f[2j] + a2*f[2j+1] + a3*f[2j+2] + ... + a6*f[2j+5],
    with indices taken modulo N (periodic extension).  Level r applies
    this step r times.

    Parameters
    ----------
    f : sequence of numbers (list, tuple or 1-D array).
    r : number of levels, >= 0; r == 0 returns f unchanged as an array.

    Returns
    -------
    numpy array of length len(f) // 2**r, or an error string when len(f)
    is not divisible by 2**r (string return kept for existing callers).

    Raises
    ------
    ValueError
        If r is negative.
    """
    # 2**r is fractional for r < 0, so the divisibility guard below would
    # let a negative r through.
    if r < 0:
        raise ValueError("D6trend: r must be >= 0, got " + str(r))
    N = len(f)
    f = list(f)
    if r == 0: return array(f)
    if N % 2**r: return "D6trend: "+str(N)+" is not divisible by "+str(2**r)
    levels_left = r
    while levels_left >= 1:
        N = len(f)
        f = [a1*f[2*j] + a2*f[2*j+1] + a3*f[(2*j+2) % N] + a4*f[(2*j+3) % N]
             + a5*f[(2*j+4) % N] + a6*f[(2*j+5) % N]
             for j in range(N // 2)]
        levels_left -= 1
    return array(f)
        
def D6fluct(f, r=1):
    """Return the level-r D6 fluctuation (detail) signal of f.

    Level 1 maps a signal of even length N to the N//2 values
        d[j] = b1*f[2j] + b2*f[2j+1] + b3*f[2j+2] + ... + b6*f[2j+5],
    where b1..b6 are the wavelet numbers g_ and indices are taken modulo N
    (periodic extension, exactly as in D6trend).  Level r > 1 is the level-1
    fluctuation of the level-(r-1) trend.

    Parameters
    ----------
    f : sequence of numbers (list, tuple or 1-D array).
    r : level, >= 1.

    Returns
    -------
    numpy array of length len(f) // 2**r, or an error string when len(f)
    is not divisible by 2**r (string return kept for existing callers).

    Raises
    ------
    ValueError
        If r < 1 (there is no level-0 fluctuation).
    """
    if r < 1:
        raise ValueError("D6fluct: r must be >= 1, got " + str(r))
    [b1, b2, b3, b4, b5, b6] = g_
    N = len(f)
    f = list(f)
    if N % 2**r: return "D6fluct: "+str(N)+" is not divisible by "+str(2**r)
    if r > 1:
        return D6fluct(D6trend(f, r - 1))
    # Wrap with % N instead of padding with f[:4]: padding supplied too few
    # samples when N == 2, which raised IndexError at the coarsest level.
    return array([b1*f[2*j] + b2*f[2*j+1] + b3*f[(2*j+2) % N]
                  + b4*f[(2*j+3) % N] + b5*f[(2*j+4) % N] + b6*f[(2*j+5) % N]
                  for j in range(N // 2)])

def D6(f, r=1):
    """Return the r-level Daubechies D6 transform of f.

    The result is the concatenation (a^r | d^r | d^(r-1) | ... | d^1): the
    level-r trend followed by the fluctuation of every level, coarsest
    first.  r == 0 returns f unchanged as an array.  If len(f) is not
    divisible by 2**r an error string is returned instead of an array.
    """
    N = len(f)
    f = list(f)
    if r == 0:
        return array(f)
    if N % 2**r:
        return "D6: " + str(N) + " is not divisible by " + str(2**r)
    trend = f
    flucts = []                  # one fluctuation per level, finest first
    levels_left = r
    while levels_left >= 1:
        flucts.append(D6fluct(trend))
        trend = D6trend(trend)
        levels_left -= 1
    # Coarsest fluctuation goes right after the trend.  The trailing empty
    # list acts as an empty seed, so with no levels run (r < 0) the result
    # is still splice([f, []]).
    return splice([trend] + flucts[::-1] + [[]])
# Lower-case alias: daub6 is the same function object as D6.
daub6=D6

'''
II) Daub6 scaling and wavelet arrays
'''

# To construct the array of D6 level r scaling vectors
# from the array V of D6 level r-1 scaling vectors
def D6V(V, coeffs=None):
    """Return the level-r D6 scaling vectors built from the level-(r-1) ones.

    Row j of the result is  sum_k c[k] * V[(2*j + k) % N, :], where N is the
    number of rows of V.  Row indices wrap around periodically.

    Parameters
    ----------
    V : 2-D array whose rows are the level-(r-1) scaling vectors.
    coeffs : optional sequence of filter numbers c; defaults to the D6
        scaling numbers [a1, a2, a3, a4, a5, a6].

    Returns
    -------
    2-D array with N // 2 rows, or a single 1-D row when N // 2 <= 1 (that
    shape is kept for existing callers such as D6VW).
    """
    c = [a1, a2, a3, a4, a5, a6] if coeffs is None else list(coeffs)
    V = np.asarray(V)
    N = len(V)
    rows = max(N // 2, 1)            # row 0 is built even when N < 2
    starts = 2 * np.arange(rows)
    # Accumulate tap by tap in the same left-to-right order as
    # a1*V[.] + a2*V[.] + ..., so results match the scalar form exactly.
    # Indexing with [..., :] requires a 2-D V.  This avoids the module alias
    # `stack`, which `from numpy import *` rebinds to np.stack.
    X = c[0] * V[starts % N, :]
    for k in range(1, len(c)):
        X = X + c[k] * V[(starts + k) % N, :]
    return X[0] if rows == 1 else X

# To construct the array of D6 level r wavelet vectors
# from the array V of D6 level r-1 scaling vectors
def D6W(V, coeffs=None):
    """Return the level-r D6 wavelet vectors built from level-(r-1) scaling vectors.

    Row j of the result is  sum_k c[k] * V[(2*j + k) % N, :], where N is the
    number of rows of V.  Row indices wrap around periodically.

    Parameters
    ----------
    V : 2-D array whose rows are the level-(r-1) scaling vectors.
    coeffs : optional sequence of filter numbers c; defaults to
        dual([a1, a2, a3, a4, a5, a6]) (the D6 wavelet numbers from cdi.dual).

    Returns
    -------
    2-D array with N // 2 rows, or a single 1-D row when N // 2 <= 1 (that
    shape is kept for existing callers such as D6VW).
    """
    c = dual([a1, a2, a3, a4, a5, a6]) if coeffs is None else coeffs
    c = list(c)
    V = np.asarray(V)
    N = len(V)
    rows = max(N // 2, 1)            # row 0 is built even when N < 2
    starts = 2 * np.arange(rows)
    # Same left-to-right accumulation order as b1*V[.] + b2*V[.] + ...,
    # so results match the scalar form exactly.  This avoids the module
    # alias `stack`, which `from numpy import *` rebinds to np.stack.
    Y = c[0] * V[starts % N, :]
    for k in range(1, len(c)):
        Y = Y + c[k] * V[(starts + k) % N, :]
    return Y[0] if rows == 1 else Y

# Build the scaling-vector arrays (X) and wavelet-vector arrays (Y)
# for every D6 level of signals of length N.
def D6VW(N):
    """Return (X, Y): the D6 scaling and wavelet vectors at every level.

    X[0] is the N x N identity (level-0 scaling vectors).  Each following
    entry X[k] is D6V(X[k-1]) and Y[k-1] is D6W(X[k-1]).  The final
    (coarsest) step, taken once the current size is <= 2, yields single 1-D
    rows; they are stored wrapped as [V] and [W].

    Intended for N a power of two.  Some other sizes (e.g. N = 6) reach a
    single 1-D row while the size is still > 2, and D6W then raises
    IndexError.
    """
    V = Id(N)
    scalings, wavelets = [V], []
    last = False
    while not last:
        last = N <= 2
        W, V = D6W(V), D6V(V)
        if last:
            scalings.append([V])
            wavelets.append([W])
        else:
            scalings.append(V)
            wavelets.append(W)
            N = len(V)
    return (scalings, wavelets)


##
def zer(v, n=4) :
    """Return n-1 zero-padded copies of v shifted right by 0, 2, ..., 2*(n-2).

    Every row has length len(v) + 2*(n-2).  The checks below use this to lay
    out the filter numbers at the even shifts of the D6 transform.

    v may be any sequence (list, tuple or 1-D numpy array).  It is converted
    to a list so the padding is list concatenation, not numpy broadcasting.
    Returns an empty list when n < 2.
    """
    v = list(v)
    return [[0, 0]*j + v + [0, 0]*(n-2-j) for j in range(n-1)]

# check stuff: sanity checks on the scaling numbers h_ and wavelet numbers
# g_, run once at import time.  Failures are printed, not raised.
err = 1e-12   # absolute tolerance for the floating-point identities below


def _check_shifted_rows(scale_rows, wave_rows, tol=err):
    """Print an error line for every failed orthonormality identity.

    scale_rows[i] / wave_rows[i] hold the scaling / wavelet numbers shifted
    by 2*i and zero-padded (as built by zer).  Every row must have unit norm,
    and every pair of distinct rows (scaling or wavelet, same shift or
    different shifts) must be orthogonal.
    """
    for i, (v1, w1) in enumerate(zip(scale_rows, wave_rows)):
        if abs(dot(v1,v1) - 1.0) > tol : print("error; v1v1:", dot(v1,v1))
        if abs(dot(v1,w1) - 0.0) > tol : print("error; v1w1:", dot(v1,w1))
        if abs(dot(w1,w1) - 1.0) > tol : print("error; w1w1:", dot(w1,w1))
        # Rows at different shifts must be orthogonal as well.
        for j in range(i + 1, len(scale_rows)):
            v2, w2 = scale_rows[j], wave_rows[j]
            for label, x, y in (("v1v2", v1, v2), ("w1w2", w1, w2),
                                ("v1w2", v1, w2), ("w1v2", w1, v2)):
                if abs(dot(x, y)) > tol:
                    print("error; " + label + " (shifts " + str(2*i) + ", "
                          + str(2*j) + "):", dot(x, y))


_check_shifted_rows(zer(h_), zer(g_))

# Vanishing moments of the wavelet numbers: sum_l l**k * g_l == 0 for k = 0, 1, 2.
if abs(sum(g_) - 0.0) > err : print("error: sum Bl != 0")
if abs(sum(j*g_[j] for j in range(len(g_))) - 0.0) > err : print("error: sum l*Bl != 0")
if abs(sum(j**2*g_[j] for j in range(len(g_))) - 0.0) > err : print("error: sum l^2 * Bl != 0")

print("checks finished")

from numpy import *
from cdi import sample
import matplotlib.pyplot as plt

class CDIPlotter:
    """Plot a sequence of signals as the subplots of a rows x cols grid.

    Each call to plot() draws into the next cell of the current figure.
    When the grid is full, a new figure is opened and filling restarts at
    the first cell.
    """

    def __init__(self, title="", rows=1, cols=1, grid=False, tight=True):
        plt.figure(title)
        self.title = title    # label of the first figure; later pages add " (k)"
        self.nrows = rows
        self.ncols = cols
        self.grid = grid      # draw a grid on every subplot
        self.tight = tight    # call plt.tight_layout() after every plot
        self.index = 1        # 1-based position of the next subplot cell
        self.page = 1         # number of figures opened by this plotter

    def plot(self, f, title="", color='b'):
        """Plot the 1-D signal f in the next free cell.

        title is the subplot title; color is passed to plt.plot as its
        format string.
        """
        if self.index > self.nrows * self.ncols:
            # plt.figure(label) re-activates an existing figure with that
            # label, so reusing self.title would draw over the first page.
            # Give each new page its own label instead.
            self.page += 1
            plt.figure(str(self.title) + " (" + str(self.page) + ")")
            self.index = 1
        plt.subplot(self.nrows, self.ncols, self.index)
        plt.title(title)
        self.index = self.index + 1
        plt.plot(f, color)
        plt.xlim(0, len(f))
        if self.grid: plt.grid()
        if self.tight: plt.tight_layout()

# Demo: sample a test signal and plot its D6 transforms for 1 to 4 levels.
N = 2**10


def F(x):
    # NOTE(review): 15**x**2 parses as 15**(x**2); confirm 15*x**2 was not intended.
    return 15**x**2 * 2*(1-x)**4 * cos(9*pi*x)


f = sample(F, N-1)   # cdi.sample; presumably yields a length-N signal — TODO confirm

plotter = CDIPlotter(rows=4, cols=1)
for level in (1, 2, 3, 4):
    plotter.plot(D6(f, level), "D6, R=" + str(level))
plt.show()
