## L505_sanchez
'''
Generalize the functions D4trend(f) and D4fluct(f) to
D4trend(f, r=1) and D4fluct(f, r=1), which compute the
D4 trend and fluctuation at any level r. Do the same
for D4 (the iterative, multi-level transform).
Include examples for r=2 and r=3.
''' 

import numpy as np
import matplotlib.pyplot as plt
# Short module-level aliases for the NumPy functions/constants used below.
sqrt = np.sqrt
array = np.array
pi = np.pi  # was ``np.array`` (typo), which made every use of ``pi`` fail
linspace = np.linspace
splice = np.hstack
stack = np.vstack
from cdi_haar import *

# Daubechies transform (D4)
r2 = np.sqrt(2)
r3 = np.sqrt(3)
d = 4*r2

# Scaling numbers, used to compute the trend (D4trend).
a1 = (1 + r3)/d
a2 = (3 + r3)/d
a3 = (3 - r3)/d
a4 = (1 - r3)/d
# Wavelet numbers, used to compute the fluctuation (D4fluct): the scaling
# numbers in reverse order with alternating signs.
b1, b2, b3, b4 = (a4, -a3, a2, -a1)

_D4_SCALING = (a1, a2, a3, a4)
_D4_WAVELET = (b1, b2, b3, b4)


def _d4_level(name, N, r):
    """Validate level ``r`` for a signal of length ``N`` and return it as an int.

    Raises ValueError if ``r`` is negative or not integral, or if ``N`` is
    not divisible by ``2**r`` (every level halves the signal length).
    """
    if r < 0 or r != int(r):
        raise ValueError('%s: level r must be a non-negative integer, got %r'
                         % (name, r))
    r = int(r)
    if N % 2**r:
        raise ValueError('%s: %d is not divisible by 2**%d' % (name, N, r))
    return r


def _d4_filter(f, c):
    """Apply one level of the periodic 4-tap D4 filter ``c`` to ``f``.

    Returns an array of length ``len(f)//2`` whose j-th entry is
    ``c[0]*f[2j] + c[1]*f[2j+1] + c[2]*f[(2j+2)%N] + c[3]*f[(2j+3)%N]``,
    i.e. indices wrap around and the signal is treated as periodic.
    ``f`` must be a 1-D NumPy array of even length.
    """
    g = np.roll(f, -2)  # g[k] == f[(k + 2) % N]
    return c[0]*f[0::2] + c[1]*f[1::2] + c[2]*g[0::2] + c[3]*g[1::2]


def D4trend(f, r=1):
    """Return the level-``r`` D4 trend of the signal ``f``.

    The scaling filter is applied ``r`` times, each pass halving the length,
    so the result has ``len(f) // 2**r`` entries. ``r = 0`` returns a copy of
    ``f`` as an array.

    Raises ValueError if ``r`` is not a non-negative integer or if
    ``len(f)`` is not divisible by ``2**r``.
    """
    r = _d4_level('D4trend', len(f), r)
    trend = np.array(f)
    for _ in range(r):
        trend = _d4_filter(trend, _D4_SCALING)
    return trend


def D4fluct(f, r=1):
    """Return the level-``r`` D4 fluctuation of the signal ``f``.

    This is the wavelet filter applied once to the level-``(r-1)`` trend, so
    the result has ``len(f) // 2**r`` entries. ``r = 0`` returns an all-zero
    array of length ``len(f)``.

    Raises ValueError if ``r`` is not a non-negative integer or if
    ``len(f)`` is not divisible by ``2**r``.
    """
    N = len(f)
    r = _d4_level('D4fluct', N, r)
    if r == 0:
        return np.zeros(N)
    return _d4_filter(D4trend(f, r - 1), _D4_WAVELET)


def D4(f, r=1):
    """Return the ``r``-level D4 transform of ``f`` as a single array.

    The layout is ``(a^r | d^r | d^(r-1) | ... | d^1)``, where ``a^k`` and
    ``d^k`` are the level-k trend and fluctuation; the total length equals
    ``len(f)``. ``r = 0`` returns a copy of ``f`` as an array.

    Raises ValueError if ``r`` is not a non-negative integer or if
    ``len(f)`` is not divisible by ``2**r``.
    """
    r = _d4_level('D4', len(f), r)
    trend = np.array(f)
    if r == 0:
        return trend
    flucts = []
    for _ in range(r):
        flucts.append(_d4_filter(trend, _D4_WAVELET))
        trend = _d4_filter(trend, _D4_SCALING)
    # Coarsest fluctuation first, finest (level 1) last.
    return np.hstack([trend] + flucts[::-1])
    
# Examples: plot a test signal and its D4 transforms at levels r = 1, 2, 3.
def F(x):
    """Test signal cos(10*pi*x) * sin(10*pi/(x+1))."""
    return np.cos(10*np.pi*x)*np.sin(10*np.pi/(x + 1))


n = 10
N = 2**n
# NOTE(review): ``sample`` comes from ``cdi_haar``; presumably it returns N
# values here so that D4 can be taken up to level 3 — confirm in cdi_haar.
f = sample(F, N-1)
plt.close('all')

fig = plt.figure("Daubechies transform levels")
fig.suptitle("Daubechies transform levels")
ax = fig.add_subplot(111)
# Raw strings so the mathtext escapes (\/, \pi) reach matplotlib verbatim
# without Python's invalid-escape warning.
ax.set_title(r"Black: $f:\/cos(10 \pi x)sin(10 \pi /(x+1))$ | Green: $D_{1}(f)$ | "
             r"Blue: $D_{2}(f)$ | Red: $D_{3}(f)$", fontsize=11)
plt.xlim(0, N-1)
plt.plot(f, 'k-')
# Shift each level 5 units further down so the curves do not overlap.
for level, style in ((1, 'g-'), (2, 'b-'), (3, 'r-')):
    plt.plot(D4(f, level) - 5*level, style)
