## L428_Velisek> Checking the Daub4 (Daubechies D4) identities; defines D4trend, D4fluct and the one-level D4 transform.

import sympy as s
from cdi import *
import matplotlib.pyplot as plt

# SymPy shorthands.
Eq, simplify = s.Eq, s.simplify

# Exact radicals and the common denominator 4*sqrt(2).
r2, r3 = s.sqrt(2), s.sqrt(3)
d = 4*r2

# Daub4 scaling numbers: (1 + sqrt3), (3 + sqrt3), (3 - sqrt3), (1 - sqrt3), each over 4*sqrt2.
a1, a2, a3, a4 = ((1 + r3)/d, (3 + r3)/d, (3 - r3)/d, (1 - r3)/d)

def ev(x) :
    """Numerically evaluate ``x`` by calling its ``evalf()`` method.

    Lists (including nested lists) are handled element by element, and a
    new list with the same nesting is returned.
    """
    if not isinstance(x, list):
        return x.evalf()
    return list(map(ev, x))

def eq(a,b) :
    """Print and return whether the identity ``eval(a) == b`` holds exactly.

    Parameters
    ----------
    a : str
        Python expression over the module-level constants (``a1`` .. ``a4``,
        ``r2``, ...). It is printed verbatim and then evaluated.
    b : number or SymPy expression
        Expected value of the expression.

    Returns
    -------
    bool
        True when ``eval(a) - b`` simplifies to exactly 0. False means
        SymPy could not reduce the difference to 0.
    """
    # NOTE(review): eval is acceptable only because callers pass hard-coded
    # literals from this file; never route external input here.
    lhs = eval(a, globals())
    # Eq() decides equality of radical expressions heuristically: it may stay
    # unevaluated or fall back on a numerical estimate. Reducing the
    # difference symbolically gives an exact verdict.
    holds = s.simplify(lhs - b) == 0
    print("%s == %s ? \n"%(a,b), holds)
    return holds

# Daub4 identity checks; each call prints the expression and its verdict.
eq(a='a1**2+a2**2+a3**2+a4**2', b=1)    # scaling numbers have unit energy
eq(a='a1+a2+a3+a4', b=r2)               # scaling numbers sum to sqrt(2)
eq(a='a1*a3 + a2*a4', b=0)              # orthogonal to their shift by 2
eq(a='a4-a3+a2-a1', b=0)                # wavelet numbers (a4,-a3,a2,-a1) sum to 0
eq(a='0*a4 - 1*a3 + 2*a2 - 3*a1', b=0)  # ...and have zero first moment

# D4trend
def D4trend(f) :
    """One level of the periodic Daub4 trend of ``f``.

    Entry j is a1*f[2j] + a2*f[2j+1] + a3*f[2j+2] + a4*f[2j+3], with the
    last two indices wrapped modulo len(f). The values are returned as a
    list evaluated with ``ev`` (``evalf``).

    For an odd-length ``f`` the function does not raise. It returns the
    message string ``'D4trend: <N> is not divisible by 2'``.
    """
    size = len(f)
    if size % 2 != 0:
        return 'D4trend: %d is not divisible by 2' % size
    sums = [a1*f[k] + a2*f[k+1] + a3*f[(k+2) % size] + a4*f[(k+3) % size]
            for k in range(0, size, 2)]
    return ev(sums)
    
# D4fluct
def D4fluct(f) :
    """One level of the periodic Daub4 fluctuation of ``f``.

    Entry j is a4*f[2j] - a3*f[2j+1] + a2*f[2j+2] - a1*f[2j+3], i.e. the
    wavelet numbers (a4, -a3, a2, -a1), with the last two indices wrapped
    modulo len(f). The values are returned as a list evaluated with ``ev``.

    For an odd-length ``f`` the function does not raise. It returns the
    message string ``'D4fluct: <N> is not divisible by 2'``.
    """
    size = len(f)
    if size % 2 != 0:
        return 'D4fluct: %d is not divisible by 2' % size
    diffs = [a4*f[k] - a3*f[k+1] + a2*f[(k+2) % size] - a1*f[(k+3) % size]
             for k in range(0, size, 2)]
    return ev(diffs)

# D4 transform
def D4(f) :
    """One level of the Daub4 transform: trend followed by fluctuation.

    Returns the list ``D4trend(f) + D4fluct(f)`` (len(f) values). For an
    odd-length ``f`` it returns the string
    ``'D4: <N> is not divisible by 2'`` instead.
    """
    length = len(f)
    if length % 2 == 1:
        return 'D4: %d is not divisible by 2' % length
    halves = D4trend(f) + D4fluct(f)
    return ev(halves)
    
# Example
# Disabled usage example: the code sits inside a string literal, so it does
# not run. It shows one D4 level of an 8-sample signal; remove the quotes
# to try it.
'''
f = [10,11,10,12,22,21,20,21]
print('function:', f)


print('trend:', D4trend(f))
print('fluct:', D4fluct(f))
print('D4:', D4(f))
'''
    
    
# plots

# Haar (D2) computation
def trend(f):
    """One level of the Haar trend: (f[2j] + f[2j+1]) / sqrt(2) for each pair.

    A trailing unpaired sample (odd length) is ignored. The result is a list.
    """
    root2 = np.sqrt(2)
    return [(left + right) / root2 for left, right in zip(f[0::2], f[1::2])]
    
def fluct(f):
    """One level of the Haar fluctuation: (f[2j] - f[2j+1]) / sqrt(2) for each pair.

    A trailing unpaired sample (odd length) is ignored. The result is a list.
    """
    root2 = np.sqrt(2)
    return [(left - right) / root2 for left, right in zip(f[0::2], f[1::2])]
    
def haar(f,r=1):
    """Apply ``r`` levels of the Haar (D2) transform to ``f``.

    Level 1 is ``trend(f) + fluct(f)``. Each further level k re-transforms
    only the leading trend section (the first ``len(f) // 2**(k-1)``
    entries) and keeps the fluctuation sections already computed behind it.

    Parameters
    ----------
    f : sequence of numbers
        Input signal (list or 1-D array). Its length should be divisible by
        ``2**r``; otherwise odd-length sections lose their last sample.
    r : int, optional
        Number of levels (default 1). ``r == 0`` returns ``f`` itself.

    Returns
    -------
    list, or ``f`` unchanged when ``r == 0``.

    Raises
    ------
    ValueError
        If ``r`` is negative. The former recursive version never terminated
        in that case.
    """
    if r < 0:
        raise ValueError('haar: number of levels r must be >= 0, got %r' % (r,))
    if r == 0:
        return f
    N = len(f)
    # Level 1 transforms the whole signal; h is a plain list from here on,
    # so slicing and concatenation below never mix lists with arrays.
    h = trend(f) + fluct(f)
    for k in range(2, r + 1):
        m = N // 2**(k - 1)          # length of the current trend section
        a = h[:m]
        h = trend(a) + fluct(a) + h[m:]
    return h

# Rename
# Short aliases for the matplotlib.pyplot functions used by the plotting code.
(canvas, axis, show, view, close, plot, xlim) = (
    plt.figure, plt.axis, plt.show, plt.imshow, plt.close, plt.plot, plt.xlim,
)


# Utility to write text T at position X = (x, y); the five keyword
# parameters (ha, va, fs, rot, color) have default values.
def write(X,T,ha='center', va = 'center', fs='16', rot=0.0, color='#000000'):
    """Place text ``T`` at ``X = (x, y)`` on the current axes.

    ``ha``/``va`` are the horizontal/vertical alignment, ``fs`` the font
    size, ``rot`` the rotation and ``color`` the text colour. Returns the
    object produced by ``plt.text``.
    """
    x, y = X[0], X[1]
    style = dict(horizontalalignment=ha, verticalalignment=va,
                 fontsize=fs, rotation=rot, color=color)
    return plt.text(x, y, T, **style)

# vline draws a vertical line from (d,a) to (d,b).
# The default width and dashing style are 1 and 'k-'.
def vline(a,b,d,lw=1, dash='k-', color='#000000'):
    """Draw a vertical segment from (d, a) to (d, b) on the current axes.

    ``lw`` is the line width, ``dash`` the matplotlib format string and
    ``color`` the line colour. Returns the list of lines from ``plt.plot``.
    """
    # The two endpoints are all a straight segment needs. The former
    # np.arange(a, b+0.0001, b-a) sampling failed with a zero step when
    # a == b, produced a single point (no line) when b < a, and ran past b
    # when b - a < 1e-4.
    return plt.plot([d, d], [a, b], dash, lw=lw, color=color)
    
    
# Figure 1: one level of the D2 (Haar) transform vs. the D4 transform.
N = 256
# Test signal: the envelope (x/N - 1)**2, falling from 1 towards 0, times sin(x/6).
f = [ (x/N-1)**2 * (np.sin(x/6)) for x in range(N)]

close('all')
canvas('1. compare D2 and D4')
xlim(0,N)

offset = 4  # vertical spacing between the two stacked curves

# Dashed separator at N//2: trend values on the left, fluctuations on the right.
vline(-2, 6, N//2, dash='--', color='#999999')

plot([0*offset + y for y in haar(f)], color='#0000ff')
plot([1*offset + y for y in D4(f)], color='#ff0000')

write((N//2, 0*offset+0.1), '$D2$', va = 'bottom', color = '#0000ff', fs='20')
write((N//2, 1*offset+0.1), '$D4$', va = 'bottom', color = '#ff0000', fs='20')

write((N//2+5, 1.4*offset), '$difference$', ha = 'left', color = '#666666')
write((N//2-5, 1.4*offset), '$trend$', ha = 'right', color = '#666666')

write((N//2, 0.5*offset), 'We can see differences of D4 \n is much closer to $0$ than D2 as we expected', color = '#666666')

# Without an explicit show() the figure is never displayed when the file is
# run as a plain script (non-interactive backend). In interactive sessions
# it typically returns immediately.
show()

    



    

    
    
    
    
    
    
    

