## L505 - Andrés Mingorance
'''
Generalize the functions D4trend(f) and D4fluct(f) into
D4trend(f, r=1) and D4fluct(f, r=1), which compute the
level-r trend and fluctuation for any level r. Do the
same for D4 (the iterative version).
Include examples for r = 2 and r = 3.
'''

from numpy import sqrt, array, pi, arange
from numpy import linspace as ls
from numpy import hstack as splice
from numpy import vstack as stack

from cdi_haar import *
from cdi import sample

# Daubechies transform (D4)
# Scaling numbers a1..a4 share the denominator d = 4*sqrt(2).
r2 = sqrt(2); r3 = sqrt(3)
d = 4*r2

a1 = (1+r3)/d; a2 = (3+r3)/d
a3 = (3-r3)/d; a4 = (1-r3)/d
# Wavelet numbers: the scaling numbers reversed, with alternating signs.
b1,b2,b3,b4 = a4,-a3,a2,-a1

_TREND = (a1, a2, a3, a4)   # filter used for trend (averaging) coefficients
_FLUCT = (b1, b2, b3, b4)   # filter used for fluctuation (detail) coefficients


def _d4_check(name, N, r):
    """Return an error string if level r is invalid for length N, else None.

    r must be a nonnegative integer and N must be divisible by 2**r.
    The messages are returned (not raised) to match this file's convention.
    """
    if r < 0 or r != int(r):
        # Without this guard a negative r passes N % 2**r == 0
        # (e.g. N % 0.5 == 0) and the level loop never terminates.
        return '%s: level r = %s must be a nonnegative integer' % (name, r)
    if N % 2**int(r):
        return '%s: %d is not divisible by 2**(%d)' % (name, N, r)
    return None


def _d4_level(f, c):
    """Apply one periodic D4 filtering step with coefficients c = (c1..c4).

    Returns the array out with out[j] = c1*f[2j] + c2*f[2j+1]
    + c3*f[(2j+2) % N] + c4*f[(2j+3) % N] for j = 0 .. N/2 - 1.
    Assumes len(f) is even (guaranteed by _d4_check in the callers).
    """
    f = array(f)
    N = len(f)
    if N == 0:
        return array([])          # nothing to filter; also avoids "% 0" below
    i = arange(0, N, 2)           # the even positions 2j
    # Same summation order as the scalar formula, so results are identical.
    return (c[0]*f[i] + c[1]*f[i + 1]
            + c[2]*f[(i + 2) % N] + c[3]*f[(i + 3) % N])


def D4trend(f, r=1):
    """Return the level-r D4 trend of the signal f (periodic extension).

    Parameters:
        f: sequence of numbers whose length N is divisible by 2**r.
        r: nonnegative integer level. r = 0 returns array(f).
    Returns:
        A NumPy array of length N / 2**r, or an error-message string
        if r is invalid or N is not divisible by 2**r.
    """
    err = _d4_check('D4trend', len(f), r)
    if err:
        return err
    trend = array(f)
    for _ in range(int(r)):
        trend = _d4_level(trend, _TREND)
    return trend


def D4fluct(f, r=1):
    """Return the level-r D4 fluctuation of the signal f (periodic extension).

    The level-r fluctuation is the fluctuation filter applied to the
    level-(r-1) trend.

    Parameters:
        f: sequence of numbers whose length N is divisible by 2**r.
        r: nonnegative integer level. r = 0 returns an all-zero array of
           length N (there is no fluctuation at level 0).
    Returns:
        A NumPy array of length N / 2**r (length N when r = 0), or an
        error-message string if r is invalid or N is not divisible by 2**r.
    """
    from numpy import zeros
    err = _d4_check('D4fluct', len(f), r)
    if err:
        return err
    if r == 0:
        return zeros(len(f))
    return _d4_level(D4trend(f, int(r) - 1), _FLUCT)


def D4(f, r=1):
    """Return the r-level D4 transform of f, computed iteratively.

    Layout of the result: [trend_r | fluct_r | fluct_(r-1) | ... | fluct_1],
    with total length N = len(f). r = 0 returns array(f).

    Returns an error-message string if r is invalid or len(f) is not
    divisible by 2**r.
    """
    err = _d4_check('D4', len(f), r)
    if err:
        return err
    trend = array(f)
    details = []                  # fluctuation vectors, coarsest level first
    for _ in range(int(r)):
        details.insert(0, _d4_level(trend, _FLUCT))
        trend = _d4_level(trend, _TREND)
    return splice([trend] + details) if details else trend

# Example: D4 transforms of a sampled test signal at levels r = 2, 3, 4.
# np must be in scope for F below; it was never imported explicitly.
import numpy as np
import matplotlib.pyplot as plt

N = 1024
# NOTE(review): 15**x**2 parses as 15**(x**2); if 15*x**2 was intended, adjust.
F = lambda x: 15**x**2 * 2*(1-x)**4 * np.cos(9*np.pi*x)
# Exactly N equally spaced samples i/N, i = 0 .. N-1, on [0, 1).
X = ls(0.0, 1.0, N, endpoint=False)
f = F(X)

plt.figure("D4(f, r)")
levels = (2, 3, 4)
for k, r in enumerate(levels, start=1):
    plt.subplot(len(levels), 1, k)
    plt.title('D4(f, %d)' % r)
    plt.plot(D4(f, r), 'b')
    plt.xlim(0, N)
    plt.grid()

plt.show()
