## L505
## Albert Puente Encinas

'''
Modify the functions D4trend(f) and D4fluct(f) into
functions D4trend(f,r=1) and D4fluct(f,r=1)
that compute the trend and fluctuation at any
level r. Do the same for D4 (the iterative version).
Include examples for r=2 and r=3.
''' 

import numpy as np

# Short module-level aliases for the NumPy helpers used throughout the file.
pi, sqrt, array = np.pi, np.sqrt, np.array
ls = np.linspace        # evenly spaced samples
splice = np.hstack      # horizontal concatenation of 1-D arrays
stack = np.vstack       # vertical stacking of arrays

from cdi_haar import *



# Daubechies transform (D4)
r2 = sqrt(2); r3 = sqrt(3)
d = 4*r2

# Scaling (trend) coefficients a1..a4 and the matching wavelet
# (fluctuation) coefficients b1..b4 = (a4, -a3, a2, -a1).
a1 = (1+r3)/d; a2 = (3+r3)/d
a3 = (3-r3)/d; a4 = (1-r3)/d
b1,b2,b3,b4 = (a4,-a3,a2,-a1)


def _D4_filter(f, c1, c2, c3, c4):
    """One periodic D4 filtering + downsampling step.

    Returns the array whose j-th entry (j = 0 .. N/2-1) is
        c1*f[2j] + c2*f[2j+1] + c3*f[2j+2] + c4*f[2j+3],
    with indices taken modulo N = len(f) (periodic extension of f).
    `f` must be a 1-D NumPy array of even length.
    """
    even, odd = f[0::2], f[1::2]
    # Rolling the even/odd halves back by one pair yields f[2j+2], f[2j+3],
    # wrapping around to f[0], f[1] at the end of the signal.
    return (c1*even + c2*odd
            + c3*np.roll(even, -1) + c4*np.roll(odd, -1))


def D4trend(f, r=1):
    """Return the level-r D4 trend of the signal f.

    The level-1 trend is
        a[j] = a1*f[2j] + a2*f[2j+1] + a3*f[2j+2] + a4*f[2j+3]
    (indices modulo N = len(f)); the level-r trend applies it r times and
    has N/2**r values. r = 0 returns f itself as an array.

    Errors are reported by returning a message string (not raising), as in
    the rest of this file: when r is negative or N is not divisible by 2**r.
    """
    N = len(f)
    if r < 0:
        return 'D4trend: level r=%d must be non-negative' % r
    if r == 0:
        return array(f)
    if N % 2**r:
        return 'D4trend: %d is not divisible by 2**(%d)' % (N, r)
    trend = np.asarray(f)
    # The level-r trend is the level-1 trend applied r times.
    for _ in range(int(r)):
        trend = _D4_filter(trend, a1, a2, a3, a4)
    return trend


def D4fluct(f, r=1):
    """Return the level-r D4 fluctuation of the signal f.

    The level-1 fluctuation is
        d[j] = b1*f[2j] + b2*f[2j+1] + b3*f[2j+2] + b4*f[2j+3]
    (indices modulo N = len(f)); the level-r fluctuation is the level-1
    fluctuation of the level-(r-1) trend and has N/2**r values.
    r = 0 returns an all-zero array of length N.

    Errors are reported by returning a message string (not raising): when
    r is negative or N is not divisible by 2**r.
    """
    N = len(f)
    if r < 0:
        return 'D4fluct: level r=%d must be non-negative' % r
    if r == 0:
        return np.zeros(N)
    if N % 2**r:
        return 'D4fluct: %d is not divisible by 2**(%d)' % (N, r)
    # d^r = level-1 fluctuation of the level-(r-1) trend a^(r-1).
    return _D4_filter(D4trend(f, r - 1), b1, b2, b3, b4)


def D4(f, r=1):
    """Return the level-r D4 transform of the signal f (iterative version).

    The result is the concatenation
        (a^r | d^r | d^(r-1) | ... | d^1),
    where a^k = D4trend(f, k) and d^k = D4fluct(f, k); it has the same
    length N as f. r = 0 returns f itself as an array.

    Errors are reported by returning a message string (not raising): when
    r is negative or N is not divisible by 2**r.
    """
    N = len(f)
    if r < 0:
        return 'D4: level r=%d must be non-negative' % r
    if r == 0:
        return array(f)
    if N % 2**r:
        return 'D4: %d is not divisible by 2**(%d)' % (N, r)
    trend = np.asarray(f)
    flucts = []  # d^1, d^2, ..., d^r (finest level first)
    for _ in range(int(r)):
        flucts.append(_D4_filter(trend, b1, b2, b3, b4))
        trend = _D4_filter(trend, a1, a2, a3, a4)
    # Coarsest fluctuation goes right after the trend, the finest at the end.
    return splice([trend] + flucts[::-1])
    

# Examples
from matplotlib.pyplot import *

prec = 0.001
# 1000 samples: divisible by 2**3, so levels r = 1, 2, 3 are all valid.
X = np.arange(0.0, 1.0, prec)

power = np.power
exp = np.exp
def f(x):
    """Test signal f(x) = 4 x^2 exp(-256 x^8 sin^2(40 x))."""
    return 4*power(x, 2)*exp(-256*power(x, 8)*power(np.sin(40*x), 2))

Y = f(X)

close('all')
fig, axes = subplots(nrows=4, ncols=1, figsize=(8, 10))
fig.set_facecolor("w")

ax = axes[0]
ax.set_title("Original function")
ax.plot(X, Y, color='g', lw=1)
ax.grid(True)
ax.text(0.1, 2, r'$f(x) =4x^2 e^{-256x^8 \sin^2(40x)}$', color='k', fontsize=22)

# Iterative D4 transform at levels r = 1, 2, 3 (same length as the signal).
for ax, r, color in zip(axes[1:], (1, 2, 3), ('r', 'b', 'm')):
    ax.set_title("D4, r = %d" % r)
    ax.plot(X, D4(Y, r), color=color, lw=1)
    ax.grid(True)

# Lay out only after everything is drawn so titles/labels are accounted for.
fig.tight_layout()
show()



