## L505_Velisek
'''
Modify the functions D4trend(f) and D4fluct(f) to
functions D4trend(f,r=1) and D4fluct(f,r=1)
that compute the trend and fluctuation for any
level r. Ditto for D4 (iterative version).
Include examples for r=2 and r=3.
''' 

import numpy as np
# Short module-level aliases for NumPy functions and constants.
sqrt = np.sqrt
array = np.array
pi = np.pi
ls = np.linspace
splice = np.hstack    # joins 1-D pieces end to end (used by D4)
stack = np.vstack

# Daubechies transform (D4)
# r2, r3: square roots of 2 and 3; d: the common denominator 4*sqrt(2).
r2 = sqrt(2); r3 = sqrt(3)
d = 4*r2

# a1..a4: weights that D4trend applies to four consecutive samples.
a1 = (1+r3)/d; a2 = (3+r3)/d
a3 = (3-r3)/d; a4 = (1-r3)/d
# b1..b4: weights that D4fluct applies -- the a's in reverse order with
# alternating signs.
b1, b2, b3, b4 = (a4, -a3, a2, -a1)

def D4trend(f, r=1):
    """Return the level-r D4 trend of the signal f.

    f -- sequence of samples; its length must be divisible by 2**r.
    r -- level; r == 0 returns the samples unchanged as an array.

    Level 1 combines the samples in windows of four (weights a1..a4),
    advancing two samples at a time; indices are taken modulo len(f), so
    the last window wraps around to the start of the signal.  Higher
    levels apply level 1 to the previous level's trend.

    When len(f) is not divisible by 2**r, no exception is raised: an
    error message is returned as a string instead.
    """
    count = len(f)
    samples = list(f)

    if r == 0:
        return array(samples)
    if count % 2**r != 0:
        return 'D4trend: %d is not divisible by 2**(%d)' % (count, r)
    if r != 1:
        # Level r is the level-(r-1) trend of the level-1 trend.
        return D4trend(D4trend(samples), r - 1)

    trend = []
    for j in range(count // 2):
        k = 2*j
        trend.append(a1*samples[k] + a2*samples[k + 1]
                     + a3*samples[(k + 2) % count]
                     + a4*samples[(k + 3) % count])
    return array(trend)

def D4fluct(f, r=1):
    """Return the level-r D4 fluctuation of the signal f.

    f -- sequence of samples; its length must be divisible by 2**r.
    r -- level; r == 0 returns an array of len(f) zeros.

    Level 1 combines the samples in windows of four (weights b1..b4),
    advancing two samples at a time; indices are taken modulo len(f), so
    the last window wraps around to the start of the signal.  For r > 1
    the result is the level-(r-1) fluctuation of the level-1 trend.

    When len(f) is not divisible by 2**r, no exception is raised: an
    error message is returned as a string instead.
    """
    N = len(f)
    f = list(f)
    if r == 0:
        # The original called an undefined bare `zeros` here (NameError).
        return np.zeros(N)
    if N % 2**r:
        return 'D4fluct: %d is not divisible by 2**(%d)' % (N, r)
    if r == 1:
        # The modulo already provides the wraparound, so the signal does
        # not need to be padded with its first two samples.
        return np.array([b1*f[2*j] + b2*f[2*j + 1]
                         + b3*f[(2*j + 2) % N] + b4*f[(2*j + 3) % N]
                         for j in range(N // 2)])
    return D4fluct(D4trend(f), r - 1)
        
def D4(f, r=1):
    """Return the level-r D4 transform of f, computed iteratively.

    The result is the concatenation (a^r | d^r | d^(r-1) | ... | d^1):
    the level-r trend followed by the fluctuations from the coarsest
    level down to level 1.  r == 0 returns f unchanged as an array.

    When len(f) is not divisible by 2**r, no exception is raised: an
    error message is returned as a string instead.
    """
    size = len(f)
    if r == 0:
        return array(f)
    if size % 2**r != 0:
        return 'D4: %d is not divisible by 2**(%d)' % (size, r)

    trend = f
    details = []
    remaining = r
    while remaining >= 1:
        current = trend
        trend = D4trend(current)
        # Prepend, so that coarser fluctuations end up further left.
        details = splice([D4fluct(current), details])
        remaining -= 1
    return splice([trend, details])
    
# Examples 


import matplotlib.pyplot as plt

# Test signal: N samples of (x/N - 1)**2 * sin(x/16) for x = 0 .. N-1.
N = 1000
f = [ (x/N-1)**2 * (np.sin(x/16)) for x in range(N)]

# Rename: short module-level aliases for frequently used pyplot functions.
canvas = plt.figure
axis=plt.axis
show = plt.show
view = plt.imshow
close = plt.close
plot = plt.plot
xlim = plt.xlim


def hline(a, b, h, lw=1, dash='k-', color='#000000'):
    """Draw a horizontal segment from x=a to x=b at height h.

    lw, dash and color are forwarded to plt.plot as the line width, the
    format string and the color.  Returns whatever plt.plot returns.
    """
    # Two explicit endpoints.  The previous np.arange(a, b+0.0001, b-a)
    # failed for a == b (zero step) and overshot b for segments shorter
    # than 0.0001.
    x = np.array([a, b], dtype=float)
    return plt.plot(x, h + 0*x, dash, lw=lw, color=color)

def vline(a, b, d, lw=1, dash='k-', color='#000000'):
    """Draw a vertical segment from y=a to y=b at horizontal position d.

    lw, dash and color are forwarded to plt.plot as the line width, the
    format string and the color.  Returns whatever plt.plot returns.
    """
    # Two explicit endpoints.  The previous np.arange(a, b+0.0001, b-a)
    # failed for a == b (zero step) and overshot b for segments shorter
    # than 0.0001.
    y = np.array([a, b], dtype=float)
    return plt.plot(d + 0*y, y, dash, lw=lw, color=color)
    
def write(X,T,ha='center', va = 'center', fs='16', rot=0.0):
    """Put the text T at the point (X[0], X[1]) via plt.text.

    ha / va -- horizontal and vertical alignment
    fs      -- font size
    rot     -- rotation
    Returns whatever plt.text returns.
    """
    text_options = {
        'horizontalalignment': ha,
        'verticalalignment': va,
        'fontsize': fs,
        'rotation': rot,
    }
    x_pos = X[0]
    y_pos = X[1]
    return plt.text(x_pos, y_pos, T, **text_options)

def drawLvl(r, lw=1, dash='--', size=1):
    """Draw guide lines and labels for a level-r transform plot.

    Draws a grey horizontal line at height 0 from xmin to xmax, a grey
    vertical separator at x = 1/2**x for each level x = 1..r (from ymin
    up to 0), a label d^x centred in each band [1/2**x, 1/2**(x-1)] and
    a label a^r centred in [0, 1/2**r].

    Reads the module-level globals xmin, xmax and ymin, which must be set
    before the call.  lw and dash style the vertical separators only;
    size is accepted but not used.
    """
    hline(xmin, xmax, 0, color='#bbbbbb')
    for x in range(1, r + 1):
        vline(ymin, 0, 1/2**x, lw=lw, dash=dash, color='#bbbbbb')
        # Braces keep a multi-digit level in the superscript as a whole;
        # without them only the first digit would be raised.
        write((3/2/2**x, ymin + 0.1), '$d^{' + str(x) + '}$', va='bottom')
    write((1/2**(r + 1), ymin + 0.1), '$a^{' + str(r) + '}$', va='bottom')
    
# Sample positions in [0, 1), one per transform coefficient.  They are
# derived from N so they always match len(D4(f, r)); the previous
# hard-coded np.arange(0.0, 1.0, 0.001) only fitted N == 1000.
t1 = np.linspace(0.0, 1.0, N, endpoint=False)
t2 = np.linspace(0.0, 1.0, N, endpoint=False)


# Plot window; drawLvl reads xmin, xmax and ymin as globals.
xmin = 0; xmax = 1
ymin = -3; ymax = 3

plt.close()
plt.figure("1. Higher (2 and 3) level D4")

# One subplot per example level: r = 2 on top, r = 3 below.
for _row, (r, _t) in enumerate(((2, t1), (3, t2)), start=1):
    plt.subplot(2, 1, _row)
    drawLvl(r)
    plt.plot(_t, D4(f, r), '-')
    plt.xlim(xmin, xmax)
    plt.ylim(ymin, ymax)

# We can see small jumps at the end of the differences (fluctuations).
# They are caused by the wraparound.

plt.show()
