## L507_x

import numpy as np
import matplotlib.pyplot as plt

# Synonyms: short module-level aliases for the NumPy routines used below.
Id, sqrt, array = np.eye, np.sqrt, np.array
stack, splice = np.vstack, np.hstack   # stack rows / concatenate 1-D pieces
dot, ls, zeros = np.dot, np.linspace, np.zeros
mat = np.matrix


# Basic constants
nd = 4  # presumably the Daub4 filter length (four coefficients) -- TODO confirm
r2, r3 = sqrt(2), sqrt(3)

# Daub4 scaling numbers (used for trends and scaling vectors).
a1 = (1 + r3) / (4 * r2)
a2 = (3 + r3) / (4 * r2)
a3 = (3 - r3) / (4 * r2)
a4 = (1 - r3) / (4 * r2)

# Daub4 wavelet numbers: the scaling numbers reversed, with alternating signs.
b1, b2, b3, b4 = a4, -a3, a2, -a1

'''
I) D4trend, D4fluct, D4
'''

def D4trend_rec(f, r=1):
    """Recursive reference version of D4trend: level-r Daub4 trend of f.

    One level extends a copy of f periodically with f[0], f[1] and returns
        a1*f[2j] + a2*f[2j+1] + a3*f[2j+2] + a4*f[2j+3]   for j < len(f)//2.
    For r > 1 the level-1 trend is fed back into this function with r-1.

    Returns array(f) when r == 0, and an error *string* (not an exception)
    when len(f) is not divisible by 2**r.
    """
    N = len(f)
    f = list(f)
    if r == 0: return array(f)
    if N % 2**r:
        return "D4trend_rec: %d is not divisible by 2**%d " % (N, r)
    if r == 1:
        # Periodic extension so indices 2j+2 and 2j+3 stay in range.
        f = f + f[:2]
        return \
            array([a1*f[2*j]+a2*f[2*j+1]+a3*f[2*j+2]+a4*f[2*j+3] for j in range(N//2)])
    # Recurse through this function itself (not the iterative D4trend) so the
    # recursive family stays an independent cross-check, like D4fluct_rec.
    return D4trend_rec(D4trend_rec(f), r-1)
        
def D4fluct_rec(f,r=1):
    """Recursive counterpart of D4fluct: level-r Daub4 fluctuation of f.

    Level 1 extends a copy of f periodically with f[0], f[1] and returns
        b1*f[2k] + b2*f[2k+1] + b3*f[2k+2] + b4*f[2k+3]   for k < len(f)//2.
    Any other level is the level-1 fluctuation of D4trend_rec(f, r-1).

    Returns zeros(len(f)) when r == 0, and an error string when len(f) is
    not divisible by 2**r.
    """
    N = len(f)
    f = list(f)
    if r == 0:
        return zeros(N)
    if N % 2**r:
        return "D4fluct_rec: %d is not divisible by 2**%d " % (N, r)
    if r != 1:
        # Fluctuation at level r = level-1 fluctuation of the level r-1 trend.
        return D4fluct_rec(D4trend_rec(f, r-1), 1)
    ext = f + f[:2]
    return array([b1*ext[2*k] + b2*ext[2*k+1] + b3*ext[2*k+2] + b4*ext[2*k+3]
                  for k in range(N//2)])

def D4trend(f, r=1):
    """Return the level-r Daub4 trend of f (iterative version).

    Each pass maps a sequence g of length M to the M//2 values
        a1*g[2k] + a2*g[2k+1] + a3*g[(2k+2) % M] + a4*g[(2k+3) % M],
    i.e. one trend level with periodic wrap-around; the pass is repeated
    r times.

    Returns array(f) when r == 0, and an error *string* (not an exception)
    when len(f) is not divisible by 2**r.
    """
    N = len(f)
    if r == 0:
        return array(f)
    if N % 2**r:
        return "D4trend: %d is not divisible by 2**%d " % (N, r)
    remaining = r
    g = f
    while remaining >= 1:
        M = len(g)
        g = array([a1*g[2*k] + a2*g[2*k+1] + a3*g[(2*k+2) % M] + a4*g[(2*k+3) % M]
                   for k in range(M//2)])
        remaining -= 1
    return g
        
def D4fluct(f,r=1):
    """Return the level-r Daub4 fluctuation of f (iterative version).

    The level r-1 trend a = D4trend(f, r-1) is computed first, then
        d[j] = b1*a[2j] + b2*a[2j+1] + b3*a[(2j+2) % M] + b4*a[(2j+3) % M]
    for j < M//2, where M = len(a) (periodic wrap-around).

    Returns zeros(len(f)) when r == 0, and an error *string* (not an
    exception) when len(f) is not divisible by 2**r, like D4trend.
    """
    N = len(f)
    if r == 0: return zeros(N)
    if N % 2**r: 
        return "D4fluct: %d is not divisible by 2**%d " % (N, r)
    a = D4trend(f, r-1)
    # The output length and the wrap-around modulus come from the trend a,
    # which has len(f) / 2**(r-1) entries -- using len(f) here overran a
    # for every r >= 2.
    M = len(a)
    d = array([b1*a[2*j] + b2*a[2*j+1] + b3*a[(2*j+2) % M] + b4*a[(2*j+3) % M]
               for j in range(M//2)])
    return d

def D4(f,r=1):
    """Return the r-level Daub4 transform of f as one flat array.

    Layout: [a^r, d^r, d^(r-1), ..., d^1], where a^r is the level-r trend
    and d^k the level-k fluctuation, each level computed from the previous
    trend with D4trend / D4fluct.

    Returns array(f) when r == 0, and an error *string* when len(f) is not
    divisible by 2**r.
    """
    N = len(f)
    f = list(f)
    if r == 0:
        return array(f)
    if N % 2**r:
        return "D4: %d is not divisible by 2**%d " % (N, r)
    details = []
    trend = f
    levels_left = r
    while levels_left >= 1:
        # Prepend the new fluctuation so coarser levels end up in front.
        details = splice([D4fluct(trend), details])
        trend = D4trend(trend)
        levels_left -= 1
    return splice([trend, details])
    
def D4_rec(f,r=1):
    """Return the r-level Daub4 transform of f using the recursive helpers.

    Same layout as D4 -- [a^r, d^r, d^(r-1), ..., d^1] -- but each level is
    computed with D4trend_rec / D4fluct_rec.

    Returns array(f) when r == 0, and an error *string* when len(f) is not
    divisible by 2**r.
    """
    N = len(f)
    f = list(f)
    if r == 0:
        return array(f)
    if N % 2**r:
        return "D4_rec: %d is not divisible by 2**%d " % (N, r)
    flucts = []
    current = f
    while r >= 1:
        flucts = splice([D4fluct_rec(current), flucts])
        current = D4trend_rec(current)
        r -= 1
    return splice([current, flucts])


'''
II) Daub4 scaling and wavelet arrays
'''

# To construct the array of D4 level r scaling vectors
# from the array V of D4 level r-1 scaling vectors
def D4V(V):
    """Return the level-r Daub4 scaling vectors built from the level r-1 ones.

    Row j of the result is
        a1*V[2j] + a2*V[2j+1] + a3*V[(2j+2) % N] + a4*V[(2j+3) % N],
    with N = len(V) (periodic wrap-around over the rows).

    V is a 2-D array with one level r-1 scaling vector per row; its row
    count is expected to be even and at least 2.  The result has N//2 rows
    and the same number of columns as V.
    """
    N = len(V)
    # Build each row exactly once and stack once at the end: seeding the
    # result with row 0 before looping over j = 0.. would duplicate that row
    # (making the row count odd and corrupting the next level's `% N` wrap),
    # and re-stacking inside the loop is quadratic.
    rows = [a1*V[2*j, :] + a2*V[2*j+1, :] + a3*V[(2*j+2) % N, :] + a4*V[(2*j+3) % N, :]
            for j in range(N//2)]
    return stack(rows)
    
# print(D4V(Id(4)))
# print(D4V(D4V(Id(4))))

# To construct the array of D4 level r wavelet vectors
# from the array V of D4 level r-1 scaling vectors
def D4W(V):
    """Return the level-r Daub4 wavelet vectors built from level r-1 scaling vectors.

    Row j of the result is
        b1*V[2j] + b2*V[2j+1] + b3*V[(2j+2) % N] + b4*V[(2j+3) % N],
    with N = len(V) (periodic wrap-around over the rows).

    V is a 2-D array with one level r-1 scaling vector per row; its row
    count is expected to be even and at least 2.  The result has N//2 rows
    and the same number of columns as V.
    """
    N = len(V)
    # One row per j, stacked once: seeding with row 0 and then looping from
    # j = 0 would duplicate the first wavelet vector; stacking inside the
    # loop is quadratic.
    rows = [b1*V[2*j, :] + b2*V[2*j+1, :] + b3*V[(2*j+2) % N, :] + b4*V[(2*j+3) % N, :]
            for j in range(N//2)]
    return stack(rows)

# print(D4W(Id(2))) 
# 1/sqrt(2) = 0.707 so it seems that it is working

# To construct the pair formed with the array V 
# of all D4 scale vectors and the array W of all
# D4 wavelet vectors.
def D4VW(N):
    """Return (X, Y): the Daub4 scaling and wavelet vectors at every level.

    X[0] is the N x N identity (level-0 scaling vectors) and X[k] is the
    array of level-k scaling vectors produced by D4V; Y[k-1] is the array of
    level-k wavelet vectors produced by D4W.  Every entry, including the
    last level, is the array itself (one vector per row), so e.g.
    proj(f, X[k]) and proj_coeffs(f, Y[k]) work uniformly for all k.

    Assumes N is a power of 2 with N >= 2 (D4V/D4W index rows 0..3 modulo
    the row count).
    """
    V = Id(N)
    X = [V]
    Y = []
    while N > 2:
        W = D4W(V)  # wavelets are built from V before V is replaced
        V = D4V(V)
        X.append(V)
        Y.append(W)
        N = len(V)
    # Final level, built from the last (2-row) scaling array.
    W = D4W(V)
    V = D4V(V)
    # Append the arrays directly so the last level has the same form as
    # every other level (not wrapped in an extra list).
    X.append(V)
    Y.append(W)
    return (X, Y)
    
# print(D4VW(2))

# Orthogonal projection in orthonormal basis
def proj(f,V):
    """Return sum over v in V of <f, v> * v.

    This is the orthogonal projection of f onto span(V) when the vectors in
    V are orthonormal (the caller's responsibility).  The sum starts from a
    zero vector of length len(V[0]).
    """
    start = zeros(len(V[0]))
    return sum((dot(f, basis_vec) * basis_vec for basis_vec in V), start)

# Projection coefficients
def proj_coeffs(f,V):
    """Return the array of inner products <f, v>, one per vector v in V, in order."""
    coeffs = []
    for basis_vec in V:
        coeffs.append(dot(f, basis_vec))
    return array(coeffs)
    
## The Daub4 = D4 Transform using D4V and D4W
def D4T(f,r=1):
    """Placeholder for the Daub4 transform built from D4V / D4W.

    Not implemented yet: ignores f and r and always returns None.
    TODO: compute the coefficients of f on the level-r scaling vectors and
    on the level r..1 wavelet vectors.
    """
    return None
    
    
    
    
    
    
    
# Examples 


import matplotlib.pyplot as plt

# Rename: short aliases for the pyplot functions used by the examples.
canvas, axis, show = plt.figure, plt.axis, plt.show
view, close = plt.imshow, plt.close
plot, xlim = plt.plot, plt.xlim


def hline(a,b,h,lw=1,dash='k-', color='#000000'):
    """Draw a horizontal segment at height h from x = a to x = b.

    lw, dash (format string) and color are passed on to plt.plot; the
    return value is whatever plt.plot returns.
    """
    # A straight segment only needs its two end points.  Generating them with
    # np.arange(a, b+1e-4, b-a) fails for a == b (zero step) and drops b
    # when a > b.
    x = np.array([a, b], dtype=float)
    return plt.plot(x,h+0*x,dash, lw=lw, color=color)

def vline(a,b,d,lw=1, dash='k-', color='#000000'):
    """Draw a vertical segment at x = d from y = a to y = b.

    lw, dash (format string) and color are passed on to plt.plot; the
    return value is whatever plt.plot returns.
    """
    # Two end points suffice; np.arange(a, b+1e-4, b-a) fails for a == b
    # (zero step) and drops b when a > b.
    y = np.array([a, b], dtype=float)
    return plt.plot(d+0*y,y,dash, lw=lw, color=color)
    
def write(X,T,ha='center', va = 'center', fs='12', rot=0.0):
    """Place the text T at the point X = (x, y) of the current axes.

    ha / va are the horizontal / vertical alignment, fs the font size and
    rot the rotation; returns the object created by plt.text.
    """
    x, y = X[0], X[1]
    return plt.text(x, y, T,
                    horizontalalignment=ha,
                    verticalalignment=va,
                    fontsize=fs,
                    rotation=rot)

def drawLvl(r, lw=1, dash='--', size=1, xmin=None, xmax=None, ymin=None) :
    """Overlay the level layout of an r-level transform on the current plot.

    Draws a light-grey horizontal line at y = 0 from xmin to xmax, and for
    k = 1..r a grey vertical separator at x = 1/2**k running from ymin up
    to 0, with the label d^k centred at x = 3/2**(k+1).  The label a^r is
    centred at x = 1/2**(r+1).  Labels sit at ymin + 0.1.  The separator
    positions suggest the transform is plotted over x in [0, 1].

    xmin, xmax and ymin default to the current axes limits (plt.axis()).
    `size` is accepted for backward compatibility but is not used.
    """
    # Read the limits before drawing, since adding lines may rescale the
    # axes; this replaces the former reliance on module globals
    # xmin / xmax / ymin.
    cur_xmin, cur_xmax, cur_ymin, _ = plt.axis()
    xmin = cur_xmin if xmin is None else xmin
    xmax = cur_xmax if xmax is None else xmax
    ymin = cur_ymin if ymin is None else ymin
    hline(xmin, xmax, 0, color='#bbbbbb')
    for k in range(1, r+1):
        vline(ymin, 0, 1/2**k, lw=lw, dash=dash, color='#bbbbbb')
        # Braces keep multi-digit levels (k >= 10) fully in the superscript.
        write((3/2/2**k, ymin+0.1), '$d^{' + str(k) + '}$', va='bottom')
    write((1/2**(r+1), ymin+0.1), '$a^{' + str(r) + '}$', va='bottom')
    
    

# Demo: build every Daub4 scaling/wavelet vector for signals of length
# N = 2**10 and plot a few of them.  Each curve is shifted up by its y value
# so the curves do not overlap, and is labelled near the right edge.
n = 10; N = 2**n
# V[k][x] is the x-th level-k scaling vector (V[0] is the identity);
# W[k-1][x] is the x-th level-k wavelet vector.
V, W = D4VW(N)


close('all')
canvas('1. Example of a scaling vectors')
xlim(0,N)


# Level-5 scaling vectors number 1, 8 and 16, drawn at offsets 2.5, 2 and 1.5.
X = [1, 8, 16]
Y = [2.5, 2, 1.5]
lvl = 5
P5 = zip(X, Y)
for (x, y) in P5 :
    write((N-5, y),"lvl "+str(lvl), va="bottom", ha="right")
    plot(y + V[lvl][x])
    
    
# Level-6 scaling vectors number 1, 4 and 14, drawn at offsets 1, 0.5 and 0.
X = [1, 4, 14]
Y = [1, 0.5, 0]
P6 = zip(X, Y)
lvl = 6
for (x, y) in P6 :
    write((N-5, y),"lvl "+str(lvl), va="bottom", ha="right")
    plot(y + V[lvl][x])

# Observation: the first value of these scaling vectors is unexpectedly large;
# the cause was not found.  NOTE(review): this looks like it may be an artifact
# rather than a property of Daub4 -- check that D4V/D4W return exactly half as
# many rows as they receive (an extra leading row would make the row counts
# odd, break the `% N` wrap-around and over-weight the first basis vector).

canvas('2. Example of a wavelet vectors')
xlim(0,N)

# Level-5 wavelet vectors; W is offset by one, so level lvl lives in W[lvl-1].
X = [1, 8, 16]
Y = [2.5, 2, 1.5]
P5 = zip(X, Y)
lvl = 5
for (x, y) in P5 :
    write((N-5, y),"lvl "+str(lvl), va="bottom", ha="right")
    plot(y + W[lvl-1][x])
    
# Level-6 wavelet vectors number 1, 4 and 14.
X = [1, 4, 14]
Y = [1, 0.5, 0]
P6 = zip(X, Y)
lvl = 6
for (x, y) in P6 :
    write((N-5, y),"lvl "+str(lvl), va="bottom", ha="right")
    plot(y + W[lvl-1][x])
    
