import matplotlib.pyplot as plt
import numpy as np
import sympy as s

# Exact SymPy helpers used to build (and, via ``eq``, check) the D4 coefficients.
Eq = s.Eq
r2, r3 = s.sqrt(2), s.sqrt(3)
d = 4*r2  # common denominator 4*sqrt(2)

# Daubechies D4 scaling numbers, kept as exact SymPy expressions:
#   (a1, a2, a3, a4) = (1+sqrt3, 3+sqrt3, 3-sqrt3, 1-sqrt3) / (4*sqrt2)
# They satisfy a1^2 + a2^2 + a3^2 + a4^2 == 1 and a1 + a2 + a3 + a4 == sqrt(2).
a1, a2, a3, a4 = (
    (1 + r3) / d,
    (3 + r3) / d,
    (3 - r3) / d,
    (1 - r3) / d,
)

def eq(a, b):
    """Print a SymPy equality check between expression string *a* and value *b*.

    *a* is passed to ``eval``, so it can name this module's globals
    (``a1``..``a4``, ``r2``, ``r3``, ...). Only call this with trusted strings.
    Prints the question text followed by ``Eq(<value of a>, b)``; returns None.
    """
    # Evaluate first, while `a` and `b` are still the only local names eval sees.
    lhs = eval(a)
    verdict = Eq(lhs, b)
    print("{!s} == {!s} ? \n".format(a, b), verdict)

# eq("a1**2+a2**2+a3**2+a4**2",1)
# eq("a1+a2+a3+a4",r2)
# eq("a1*a3 + a2*a4",0)
# eq("a4-a3 + a2-a1",0)
# eq("0*a4-1*a3+2*a2-3*a1",0)

def D2trend(f):
    """Return the one-level Haar (D2) trend signal of the discrete signal *f*.

    Entry j is (f[2j] + f[2j+1]) / sqrt(2) for j in range(len(f) // 2).
    A trailing unpaired sample (odd length) is ignored.
    """
    root2 = 1.4142135623730951  # float sqrt(2)
    evens, odds = f[0::2], f[1::2]
    # zip stops at the shorter slice, i.e. after len(f) // 2 pairs.
    return [(x + y) / root2 for x, y in zip(evens, odds)]

def D2fluct(f):
    """Return the one-level Haar (D2) fluctuation (difference) signal of *f*.

    Entry j is (f[2j] - f[2j+1]) / sqrt(2) for j in range(len(f) // 2).
    A trailing unpaired sample (odd length) is ignored.
    """
    root2 = 1.4142135623730951  # float sqrt(2)
    return [(x - y) / root2 for x, y in zip(f[0::2], f[1::2])]

def D2(f, r=1):
    """Return the r-level Haar (D2) transform of the discrete signal *f*.

    The result is the list ``a^r + d^r + d^(r-1) + ... + d^1``: the level-r
    trend followed by the fluctuation signals, coarsest level first.

    :param f: sequence of numbers whose length is divisible by ``2**r``.
    :param r: number of levels (non-negative); ``r == 0`` returns *f* itself.
    :raises ValueError: if ``r`` is negative or ``len(f)`` is not divisible
        by ``2**r``.
    """
    if r < 0:
        raise ValueError("D2: level r must be non-negative, got %d" % r)
    if r == 0:
        return f
    trend = list(f)
    N = len(trend)
    if N % 2**r:
        # Raise rather than return a message string, so invalid input cannot
        # flow silently into numeric code such as plotting.
        raise ValueError("D2: %d is not divisible by 2**%d" % (N, r))
    target = N // 2**r
    details = []
    while len(trend) > target:
        # Each new (coarser) fluctuation goes in front of the finer ones.
        details = D2fluct(trend) + details
        trend = D2trend(trend)
    return trend + details

def D4trend(f, coeffs=None):
    """Return the one-level Daubechies D4 trend signal of *f*.

    Indices wrap around periodically; entry j is
    ``c1*f[2j] + c2*f[2j+1] + c3*f[(2j+2) % N] + c4*f[(2j+3) % N]``.

    :param f: sequence of numbers of even length N.
    :param coeffs: optional scaling numbers ``(c1, c2, c3, c4)``. They default
        to the module's exact SymPy values ``(a1, a2, a3, a4)``, in which case
        numeric input yields SymPy expressions. Pass floats (for example
        ``[float(c) for c in (a1, a2, a3, a4)]``) to get plain numeric output.
    :raises ValueError: if N is odd.
    """
    N = len(f)
    if N % 2:
        raise ValueError("D4trend: %d is not divisible by 2" % N)
    c1, c2, c3, c4 = (a1, a2, a3, a4) if coeffs is None else coeffs
    return [c1*f[2*j] + c2*f[2*j+1] + c3*f[(2*j+2) % N] + c4*f[(2*j+3) % N]
            for j in range(N // 2)]


def D4fluct(f, coeffs=None):
    """Return the one-level Daubechies D4 fluctuation signal of *f*.

    The wavelet numbers are derived from the scaling numbers
    ``(c1, c2, c3, c4)``, and indices wrap around periodically. Entry j is
    ``c4*f[2j] - c3*f[2j+1] + c2*f[(2j+2) % N] - c1*f[(2j+3) % N]``.

    :param f: sequence of numbers of even length N.
    :param coeffs: optional scaling numbers ``(c1, c2, c3, c4)``. They default
        to the module's exact SymPy values ``(a1, a2, a3, a4)``, in which case
        numeric input yields SymPy expressions. Pass floats to get plain
        numeric output.
    :raises ValueError: if N is odd.
    """
    N = len(f)
    if N % 2:
        raise ValueError("D4fluct: %d is not divisible by 2" % N)
    c1, c2, c3, c4 = (a1, a2, a3, a4) if coeffs is None else coeffs
    return [c4*f[2*j] - c3*f[2*j+1] + c2*f[(2*j+2) % N] - c1*f[(2*j+3) % N]
            for j in range(N // 2)]

def D4(f, coeffs=None):
    """Return the one-level Daubechies D4 transform of *f*.

    The result is the D4 trend list followed by the D4 fluctuation list.

    :param f: sequence of numbers of even length.
    :param coeffs: optional scaling numbers ``(c1, c2, c3, c4)`` forwarded to
        ``D4trend``/``D4fluct``. When omitted, those helpers use their defaults.
    :raises ValueError: if ``len(f)`` is odd.
    """
    N = len(f)
    if N % 2:
        raise ValueError("D4: %d is not divisible by 2" % N)
    if coeffs is None:
        # Call the helpers with their original one-argument signature.
        return D4trend(f) + D4fluct(f)
    return D4trend(f, coeffs) + D4fluct(f, coeffs)



def F(x):
    """Test signal cos(x) * sin(x**2), evaluated elementwise on NumPy arrays."""
    return np.cos(x)*np.sin(x**2)


# Sample grid on [0, 10) with step 0.005.
a = np.arange(0.0, 10, 0.005)

plt.close('all')

plt.figure("Function f, D2(f), D4(f)")
# One row per panel: the signal itself, its D2 transform, its D4 transform.
samples = F(a)
panels = ((lambda v: v, 'r'), (D2, 'g'), (D4, 'b'))
for row, (transform, color) in enumerate(panels, start=1):
    plt.subplot(3, 1, row)
    plt.plot(transform(samples), color)

plt.show()
