## wavelets.py 
## SXD 141201, 150514 
'''
Utilities for working with Haar, Daub4 and Daub6 wavelets
'''

from cdi import *
import numpy as np
import matplotlib.pyplot as plt


# Synonyms: short module-level aliases for the NumPy callables and
# constants used throughout this file.
Id, sqrt, pi = np.eye, np.sqrt, np.pi
cos, sin = np.cos, np.sin
array, zeros, mat = np.array, np.zeros, np.matrix
stack, splice, transpose = np.vstack, np.hstack, np.transpose
dot, ls = np.dot, np.linspace
# Plotting shortcuts taken from matplotlib.pyplot.
view, canvas = plt.imshow, plt.figure

# Basic constants: sqrt(2) and sqrt(3), used by the Haar and Daub4
# scaling/wavelet numbers below.
r2 = 2 ** 0.5
r3 = 3 ** 0.5


## 0. Generic functions

# Orthogonal projection in orthonormal basis
def proj(f,V):
    """Return sum_k <f, v_k> v_k over the rows v_k of V.

    When the rows of V are orthonormal this is the orthogonal projection of
    f onto their span.  The result is a NumPy array of length len(V[0]).
    """
    coeffs = [dot(f, v) for v in V]
    result = zeros(len(V[0]))
    for c, v in zip(coeffs, V):
        result = result + c * v
    return result

# Projection coefficients
def proj_coeffs(f,V):
    """Return the NumPy array of inner products <f, v>, one per row v of V."""
    coeffs = []
    for v in V:
        coeffs.append(dot(f, v))
    return array(coeffs)


## 1. Haar wavelets

# trend(f, r) computes the level-r Haar trend signal of a discrete signal f.
def trend(f, r=1):
    """Return the level-r Haar trend of f.

    Each pass replaces the signal by the list (f[2j] + f[2j+1]) / sqrt(2),
    halving its length.  The pass is applied r times.  r == 0 returns f
    itself (no copy).  If len(f) is not divisible by 2**r, an error-message
    string is returned instead of a list.
    """
    if r == 0:
        return f
    N = len(f)
    if N % 2**r:
        return "trend: %d is not divisible by 2**%d " % (N, r)
    while r >= 1:
        f = [(f[2*j] + f[2*j+1])/r2 for j in range(len(f)//2)]
        r -= 1
    return f
#
D2trend = trend

# fluct(f, r) computes the level-r fluctuation (difference) signal of a
# discrete signal f.
def fluct(f, r=1):
    """Return the level-r Haar fluctuation of f.

    The level-(r-1) trend a = trend(f, r-1) is computed. The returned list
    holds (a[2j] - a[2j+1]) / sqrt(2) and has length len(f) // 2**r.
    r == 0 returns a list of len(f) zeros.  If len(f) is not divisible by
    2**r, an error-message string is returned instead (module convention).
    """
    if r == 0:
        return len(f)*[0]
    N = len(f)
    if N % 2**r:
        # Name this function in the message (it previously said "trend:").
        return "fluct: %d is not divisible by 2**%d " % (N, r)
    a = trend(f, r-1)
    return [(a[2*j] - a[2*j+1])/r2 for j in range(len(a)//2)]
#
D2fluct = fluct

# haar(f, r) and HaarT(f, r) compute the level-r Haar transform of f;
# haar(f) and HaarT(f) are equivalent to haar(f, 1) and HaarT(f, 1).
# haar is recursive and HaarT is iterative, so HaarT is usually preferable.
def haar(f,r=1):
    """Return the level-r Haar transform of f, computed recursively.

    The result is the list a^r + d^r + d^(r-1) + ... + d^1: the level-r
    trend followed by the fluctuations, coarsest first.  r == 0 returns f
    unchanged.  If len(f) is not divisible by 2**r, an error-message string
    is returned instead of a list.
    """
    if r == 0:
        return f
    N = len(f)
    if N % 2**r:
        # Name this function in the message (it previously said "trend:").
        return "haar: %d is not divisible by 2**%d " % (N, r)
    if r == 1:
        return trend(f) + fluct(f)
    m = N//2**(r-1)          # length of the level-(r-1) trend
    a = haar(f, r-1)
    x = a[:m]                # a^(r-1); a[m:] holds d^(r-1) ... d^1
    return trend(x) + fluct(x) + a[m:]
#
def HaarT(f,r=1):
    """Return the level-r Haar transform of f, computed iteratively.

    Same layout as haar: [a^r | d^r | d^(r-1) | ... | d^1] as one list.
    r == 0 returns f unchanged.  If len(f) is not divisible by 2**r, an
    error-message string is returned instead of a list.
    """
    if r == 0:
        return f
    a = list(f)
    N = len(a)
    if N % 2**r:
        # Name this function in the message (it previously said "trend:").
        return "HaarT: %d is not divisible by 2**%d " % (N, r)
    N = N//2**r              # length of the final trend a^r
    h = []
    while len(a) > N:
        h = fluct(a) + h     # newest (coarser) fluctuation goes in front
        a = trend(a)
    return a + h
D2 = HaarT

# To compute A^r(f), the level-r Haar averaged signal of f.
def high_filter(f,r=1):
    """Return A^r(f) as a list.

    Each block of 2**r consecutive samples is replaced by its mean, repeated
    2**r times.  Despite the name, this is the averaging (smooth) part of
    the signal.  Trailing samples that do not fill a whole block are
    dropped, so the result can be shorter than f.
    """
    width = 2**r
    blocks = []
    start = 0
    while len(f) - start >= width:
        total = sum(f[start:start + width])
        blocks += width*[total]
        start += width
    return [b/width for b in blocks]

# To compute D^r(f), the level-r Haar detail signal of f.
def low_filter(f,r=1):
    """Return D^r(f) as a list.

    For every block of 2**r samples, let s = (sum of the first half) - (sum
    of the second half).  The first half of the block becomes s / 2**r and
    the second half -s / 2**r.  Despite the name, this is the detail part
    of the signal.  Trailing samples that do not fill a whole block are
    dropped.
    """
    half = 2**(r-1)
    width = 2*half
    details = []
    start = 0
    while len(f) - start >= width:
        x = sum(f[start:start + half])
        x -= sum(f[start + half:start + width])
        details += half*[x] + half*[-x]
        start += width
    return [v/width for v in details]

# Haar matrix of level r
def haar_matrix(N,r=1):  # for vectors of length N
    """Return the np.matrix whose i-th row is HaarT(e_i, r).

    e_i is the i-th standard basis vector of length N, so by linearity the
    level-r Haar transform of a row vector f equals f * H.
    """
    rows = []
    for e in Id(N):
        rows.append(HaarT(e, r))
    return mat(rows)

# inverse of the haar transform of level r
def i_haar(h,r=1):
    """Invert the level-r Haar transform of a vector h.

    With H = haar_matrix(len(h), r), the transform is f -> f * H.  The
    inverse is taken as h * H^T, which relies on H being orthogonal.  The
    product with an np.matrix yields an np.matrix (a 1 x N row).
    """
    H = haar_matrix(len(h), r)
    return h * H.T


''' 
Haar scaling and wavelet arrays
'''

# To construct the array of level r scaling vectors
# from the array V of level r-1 scaling vectors
def HaarV(V):
    """Combine consecutive row pairs of V into Haar scaling vectors.

    Row j of the result is V[2j]/sqrt(2) + V[2j+1]/sqrt(2).  If V has only
    two rows, the result is a single 1-D vector, not a 2-D array.
    """
    c = 1/r2
    half = len(V)//2
    X = c*V[0,:] + c*V[1,:]
    for j in range(1, half):
        X = stack([X, c*V[2*j,:] + c*V[2*j+1,:]])
    return X

# To construct the array of level r wavelet vectors
# from the array V of level r-1 scaling vectors
def HaarW(V):
    """Combine consecutive row pairs of V into Haar wavelet vectors.

    Row j of the result is V[2j]/sqrt(2) - V[2j+1]/sqrt(2).  If V has only
    two rows, the result is a single 1-D vector, not a 2-D array.
    """
    c = 1/r2
    half = len(V)//2
    X = c*V[0,:] - c*V[1,:]
    for j in range(1, half):
        X = stack([X, c*V[2*j,:] - c*V[2*j+1,:]])
    return X

# To construct the array of all scale vectors.
# The r-th component of the output is
# the matrix of level r scale vectors.
def HaarVA(N):
    """Return the list of Haar scaling-vector arrays for every level.

    Entry 0 is the N x N identity and entry k holds the level-k scaling
    vectors.  The last entry is a one-element list wrapping the coarsest
    vector, so it can be iterated like the other levels.  Presumably N must
    be a power of two (TODO confirm).
    """
    V = Id(N)
    levels = [V]
    while N > 2:
        V = HaarV(V)
        levels.append(V)
        N = len(V)
    levels.append([HaarV(V)])
    return levels
VA = HaarVA
 
# To construct the pair formed with the array
# of all scale vectors and the array of all
# wavelet vectors.
def HaarVWA(N):
    """Return (scales, waves) for all Haar levels.

    scales[0] is the N x N identity and scales[k] holds the level-k scaling
    vectors.  waves[k-1] holds the level-k wavelet vectors.  The last entry
    of each list is wrapped in a one-element list.
    """
    V = Id(N)
    scales, waves = [V], []
    while N > 2:
        waves.append(HaarW(V))
        V = HaarV(V)
        scales.append(V)
        N = len(V)
    waves.append([HaarW(V)])
    scales.append([HaarV(V)])
    return scales, waves
VWA = HaarVWA


## 2. Daub4 wavelets

# Constants: Daub4 scaling numbers a1..a4 (alpha_1..alpha_4) and the
# wavelet numbers b1..b4 derived from them by dual(), which comes from cdi
# (not shown here; presumably the alternating-sign reversal -- confirm).
a1, a2 = (1 + r3)/(4*r2), (3 + r3)/(4*r2)
a3, a4 = (3 - r3)/(4*r2), (1 - r3)/(4*r2)
a_ = [a1, a2, a3, a4]
b_ = dual(a_)
b1, b2, b3, b4 = b_

'''
a) D4trend, D4fluct, D4
'''

def D4trend(f, r=1):
    """Return the level-r Daub4 trend of f (periodic boundary).

    Each pass produces the list a1*f[2j] + a2*f[2j+1] + a3*f[2j+2] +
    a4*f[2j+3], with indices wrapping modulo the current length.  The pass
    is applied r times.  r == 0 returns list(f).  If len(f) is not divisible
    by 2**r, an error-message string is returned instead.
    """
    N = len(f)
    if r == 0:
        return list(f)
    if N % 2**r:
        return "D4trend: %d is not divisible by 2**%d " % (N, r)
    while r >= 1:
        n = len(f)
        f = [a1*f[2*j] + a2*f[2*j+1] + a3*f[(2*j+2) % n] + a4*f[(2*j+3) % n]
             for j in range(n//2)]
        r -= 1
    return f
        
def D4fluct(f, r=1):
    """Return the level-r Daub4 fluctuation of f (periodic boundary).

    The wavelet numbers b1..b4 are applied to the level-(r-1) trend
    a = D4trend(f, r-1), with indices wrapping modulo len(a).  r == 0
    returns len(f) zeros.  If len(f) is not divisible by 2**r, an
    error-message string is returned instead.
    """
    if r == 0:
        return len(f)*[0]
    N = len(f)
    if N % 2**r:
        return "D4fluct: %d is not divisible by 2**%d " % (N, r)
    a = D4trend(f, r-1)
    n = len(a)
    return [b1*a[2*j] + b2*a[2*j+1] + b3*a[(2*j+2) % n] + b4*a[(2*j+3) % n]
            for j in range(n//2)]

def D4(f,r=1):
    """Return the level-r Daub4 transform of f as a NumPy array.

    Layout: [a^r | d^r | d^(r-1) | ... | d^1] (trend, then fluctuations,
    coarsest first).  r == 0 returns array(f).  If len(f) is not divisible
    by 2**r, an error-message string is returned instead.
    """
    N = len(f)
    if r == 0:
        return array(f)
    if N % 2**r:
        return "D4: %d is not divisible by 2**%d " % (N, r)
    details = []
    while r >= 1:
        coarse = D4trend(f)
        details = splice([D4fluct(f), details])  # prepend newest level
        f, r = coarse, r - 1
    return splice([f, details])
#
daub4 = D4

'''
b) Filters
'''
def H4(x):
    """Apply the Daub4 scaling numbers a_ to x via `filter`.

    NOTE(review): `filter` is assumed to come from `from cdi import *`,
    shadowing the builtin; its exact convolution convention is not shown.
    """
    return filter(a_, x)

def G4(x):
    """Apply the Daub4 wavelet numbers b_ to x via `filter` (see H4)."""
    return filter(b_, x)

# Upsampling helper, presumably provided by cdi.
U_ = up_sample

def HF4(f,r=1):
    """Rebuild the level-r Daub4 smooth part of f as a NumPy array.

    Takes the level-r trend D4trend(f, r), then r times upsamples (U_) and
    filters with H4.  r == 0 returns array(f).
    """
    if r == 0:
        return array(f)
    signal = D4trend(f, r)
    for _level in range(r):
        signal = H4(U_(signal))
    return array(signal)
    
def LF4(f,r=1):
    """Rebuild the level-r Daub4 detail part of f as a NumPy array.

    Upsamples the level-r fluctuation D4fluct(f, r) and filters it with G4.
    The remaining r-1 levels are then upsampled and filtered with H4.
    r == 0 returns zeros(len(f)).
    """
    if r == 0:
        return zeros(len(f))
    signal = G4(U_(D4fluct(f, r)))
    for _level in range(r - 1):
        signal = H4(U_(signal))
    return array(signal)


'''
c) Daub4 scaling and wavelet arrays
'''

# To construct the array of D4 level r scaling vectors
# from the array V of D4 level r-1 scaling vectors
def D4V(V):
    """Combine rows of V into Daub4 scaling vectors (periodic wrap).

    Row j is a1*V[2j] + a2*V[2j+1] + a3*V[2j+2] + a4*V[2j+3], with the last
    two row indices taken modulo len(V).  With two input rows, the result is
    a single 1-D vector.
    """
    n = len(V)

    def combine(j):
        return (a1*V[2*j,:] + a2*V[2*j+1,:]
                + a3*V[(2*j+2) % n,:] + a4*V[(2*j+3) % n,:])

    X = combine(0)
    for j in range(1, n//2):
        X = stack([X, combine(j)])
    return X

# To construct the array of D4 level r wavelet vectors
# from the array V of D4 level r-1 scaling vectors
def D4W(V):
    """Combine rows of V into Daub4 wavelet vectors (periodic wrap).

    Row j is b1*V[2j] + b2*V[2j+1] + b3*V[2j+2] + b4*V[2j+3], with the last
    two row indices taken modulo len(V).  With two input rows, the result is
    a single 1-D vector.
    """
    n = len(V)

    def combine(j):
        return (b1*V[2*j,:] + b2*V[2*j+1,:]
                + b3*V[(2*j+2) % n,:] + b4*V[(2*j+3) % n,:])

    Y = combine(0)
    for j in range(1, n//2):
        Y = stack([Y, combine(j)])
    return Y

# To construct the pair formed with the array V
# of all D4 scale vectors and the array W of all
# D4 wavelet vectors.
def D4VW(N):
    """Return (X, Y): Daub4 scaling-vector arrays and wavelet arrays.

    X[0] is the N x N identity and X[k] holds the level-k scaling vectors.
    Y[k-1] holds the level-k wavelet vectors.  The last entry of each list
    is wrapped in a one-element list.
    """
    V = Id(N)
    X, Y = [V], []
    while N > 2:
        Y = Y + [D4W(V)]
        V = D4V(V)
        X = X + [V]
        N = len(V)
    Y = Y + [[D4W(V)]]
    X = X + [[D4V(V)]]
    return (X, Y)

'''
d) D4T: Daub4 transform using the scaling and wavelet arrays
'''
def D4T(f,r=1):
    """Return the level-r Daub4 transform of f computed by projection.

    Coefficients on the level-r scaling vectors come first.  They are
    followed by the coefficients on the wavelets of levels r, r-1, ..., 1.
    """
    V, W = D4VW(len(f))
    coeffs = proj_coeffs(f, V[r])
    # W[k] holds the level k+1 wavelets, so walk k = r-1 down to 0.
    for k in reversed(range(r)):
        coeffs = splice([coeffs, proj_coeffs(f, W[k])])
    return coeffs


## 3. Daub6 wavelets

# Constants: Daub6 scaling numbers h0..h5 (their sum is about sqrt(2)) and
# the wavelet numbers g0..g5 derived from them by dual(), which comes from
# cdi (not shown here).
h_ = (0.332670552950083, 0.806891509311092, 0.459877502118491,
      -0.135011020010255, -0.0854412738820267, 0.0352262918857095)
h0, h1, h2, h3, h4, h5 = h_
g_ = dual(h_)
[g0, g1, g2, g3, g4, g5] = g_


'''
a) D6trend, D6fluct, D6
'''

def D6trend(f, r=1):
    """Return the level-r Daub6 trend of f (periodic boundary).

    Each pass applies h0..h5 to f[2j..2j+5], with indices wrapping modulo
    the current length.  The pass is applied r times.  r == 0 returns
    list(f).  If len(f) is not divisible by 2**r, an error-message string
    is returned instead.
    """
    N = len(f)
    if r == 0:
        return list(f)
    if N % 2**r:
        return "D6trend: %d is not divisible by 2**%d "%(N,r)
    while r >= 1:
        n = len(f)
        f = [h0*f[2*j] + h1*f[2*j+1] + h2*f[(2*j+2) % n]
             + h3*f[(2*j+3) % n] + h4*f[(2*j+4) % n] + h5*f[(2*j+5) % n]
             for j in range(n//2)]
        r -= 1
    return f

def D6fluct(f, r=1):
    """Return the level-r Daub6 fluctuation of f (periodic boundary).

    The wavelet numbers g0..g5 are applied to the level-(r-1) trend
    a = D6trend(f, r-1), with indices wrapping modulo len(a).  r == 0
    returns len(f) zeros.  If len(f) is not divisible by 2**r, an
    error-message string is returned instead.
    """
    if r == 0:
        return len(f)*[0]
    N = len(f)
    if N % 2**r:
        return "D6fluct: %d is not divisible by 2**%d "%(N,r)
    a = D6trend(f, r-1)
    n = len(a)
    return [g0*a[2*j] + g1*a[2*j+1] + g2*a[(2*j+2) % n]
            + g3*a[(2*j+3) % n] + g4*a[(2*j+4) % n] + g5*a[(2*j+5) % n]
            for j in range(n//2)]

def D6(f,r=1):
    """Return the level-r Daub6 transform of f as a NumPy array.

    Layout: [a^r | d^r | d^(r-1) | ... | d^1].  r == 0 returns array(f).
    If len(f) is not divisible by 2**r, an error-message string is returned
    instead.
    """
    N = len(f)
    f = list(f)
    if r == 0:
        return array(f)
    if N % 2**r:
        return "D6: "+str(N)+" is not divisible by "+str(2**r)
    details = []
    while r >= 1:
        coarse = D6trend(f)
        details = np.hstack([D6fluct(f), details])  # prepend newest level
        f, r = coarse, r - 1
    return np.hstack([f, details])
#
daub6=D6

'''
b) Filters
'''
def H6(x):
    """Apply the Daub6 scaling numbers h_ to x via `filter`.

    NOTE(review): `filter` is assumed to come from `from cdi import *`,
    shadowing the builtin.
    """
    return filter(h_, x)

def G6(x):
    """Apply the Daub6 wavelet numbers g_ to x via `filter` (see H6)."""
    return filter(g_, x)

def HF6(f,r=1):
    """Rebuild the level-r Daub6 smooth part of f as a NumPy array.

    Takes the level-r trend D6trend(f, r), then r times upsamples (U_) and
    filters with H6.  r == 0 returns array(f).
    """
    if r == 0:
        return array(f)
    signal = D6trend(f, r)
    for _level in range(r):
        signal = H6(U_(signal))
    return array(signal)
    
def LF6(f,r=1):
    """Rebuild the level-r Daub6 detail part of f as a NumPy array.

    Upsamples the level-r fluctuation D6fluct(f, r) and filters it with G6.
    The remaining r-1 levels are then upsampled and filtered with H6.
    r == 0 returns zeros(len(f)).
    """
    if r == 0:
        return zeros(len(f))
    signal = G6(U_(D6fluct(f, r)))
    for _level in range(r - 1):
        signal = H6(U_(signal))
    return array(signal)

    
'''
c) Daub6 scaling and wavelet arrays
'''

# To construct the array of D6 level r scaling vectors
# from the array V of D6 level r-1 scaling vectors
def D6V(V):
    """Combine rows of V into Daub6 scaling vectors (periodic wrap).

    Row j is sum_k h_k * V[(2j+k) mod len(V)] for k = 0..5.  With two input
    rows, the result is a single 1-D vector.
    """
    n = len(V)

    def combine(j):
        k = 2*j
        return (h0*V[k % n,:] + h1*V[(k+1) % n,:] + h2*V[(k+2) % n,:]
                + h3*V[(k+3) % n,:] + h4*V[(k+4) % n,:] + h5*V[(k+5) % n,:])

    X = combine(0)
    for j in range(1, n//2):
        X = stack([X, combine(j)])
    return X

# To construct the array of D6 level r wavelet vectors
# from the array V of D6 level r-1 scaling vectors
def D6W(V):
    """Combine rows of V into Daub6 wavelet vectors (periodic wrap).

    Row j is sum_k g_k * V[(2j+k) mod len(V)] for k = 0..5.  With two input
    rows, the result is a single 1-D vector.
    """
    n = len(V)

    def combine(j):
        k = 2*j
        return (g0*V[k % n,:] + g1*V[(k+1) % n,:] + g2*V[(k+2) % n,:]
                + g3*V[(k+3) % n,:] + g4*V[(k+4) % n,:] + g5*V[(k+5) % n,:])

    Y = combine(0)
    for j in range(1, n//2):
        Y = stack([Y, combine(j)])
    return Y

# To construct the pair formed with the array V
# of all D6 scale vectors and the array W of all
# D6 wavelet vectors.
def D6VW(N):
    """Return (X, Y): Daub6 scaling-vector arrays and wavelet arrays.

    X[0] is the N x N identity and X[k] holds the level-k scaling vectors.
    Y[k-1] holds the level-k wavelet vectors.  The last entry of each list
    is wrapped in a one-element list.
    """
    V = Id(N)
    X, Y = [V], []
    while N > 2:
        Y = Y + [D6W(V)]
        V = D6V(V)
        X = X + [V]
        N = len(V)
    Y = Y + [[D6W(V)]]
    X = X + [[D6V(V)]]
    return (X, Y)