import matplotlib.pyplot as plt
import numpy as np
import cdi_haar
from math import pi, sqrt

# Local alias for the D2 transform provided by the cdi_haar module;
# draw_graph plots its output under the label "Haar Transform".
D2transform = cdi_haar.D2



# Text helper: put the string T at the point X, optionally overriding the
# default alignment, font size and rotation.
def write(X, T, ha='center', va='center', fs='16', rot=0.0):
    '''
    Place the text T at the coordinates (X[0], X[1]) through plt.text.

    ha, va -- horizontal and vertical alignment of the text
    fs     -- font size
    rot    -- rotation handed to plt.text
    Returns the value returned by plt.text.
    '''
    text_options = {
        'horizontalalignment': ha,
        'verticalalignment': va,
        'fontsize': fs,
        'rotation': rot,
    }
    x_pos, y_pos = X[0], X[1]
    return plt.text(x_pos, y_pos, T, **text_options)

def energy(f):
    '''
    Return the energy of the sequence f, i.e. the sum of its squared items.
    '''
    # The built-in sum is kept on purpose so the accumulation is exactly
    # the one the interpreter performs for this iterable.
    squared_items = (item**2 for item in f)
    return sum(squared_items)


# Square roots shared by the four filter coefficients below.
r2, r3 = sqrt(2), sqrt(3)
# Common denominator of the coefficients: 4*sqrt(2).
d = 4 * r2

# Coefficients a1..a4 consumed by D4trend and D4fluct:
# (1+r3)/d, (3+r3)/d, (3-r3)/d and (1-r3)/d.
a1, a2, a3, a4 = (
    numerator / d for numerator in (1 + r3, 3 + r3, 3 - r3, 1 - r3)
)

def D4trend(f):
    '''
    Return the list of N/2 trend values of f, where N = len(f).

    Each value combines four consecutive items of f, starting at an even
    index, weighted by a1, a2, a3 and a4. Indices running past the end of
    f wrap around to its beginning (modulo N).
    If N is odd, a message string is returned instead of a list.
    '''
    N = len(f)
    if N % 2 != 0:
        return 'D4trend: N not divisible by 2'
    trend = []
    for start in range(0, N, 2):
        # The last window needs items beyond the end: wrap modulo N.
        third = (start + 2) % N
        fourth = (start + 3) % N
        trend.append(a1*f[start] + a2*f[start + 1] + a3*f[third] + a4*f[fourth])
    return trend
   
def D4fluct(f):
    '''
    Return the list of N/2 fluctuation values of f, where N = len(f).

    Each value combines four consecutive items of f, starting at an even
    index, weighted by a4, -a3, a2 and -a1. Indices running past the end
    of f wrap around to its beginning (modulo N).
    If N is odd, a message string is returned instead of a list.
    '''
    N = len(f)
    if N % 2 != 0:
        return 'D4fluct: N not divisible by 2'
    fluct = []
    for start in range(0, N, 2):
        # The last window needs items beyond the end: wrap modulo N.
        third = (start + 2) % N
        fourth = (start + 3) % N
        fluct.append(a4*f[start] - a3*f[start + 1] + a2*f[third] - a1*f[fourth])
    return fluct

def D4transform(f):
    '''
    Return the trend values of f followed by its fluctuation values.

    The result is the list D4trend(f) concatenated with D4fluct(f).
    If len(f) is odd, a message string is returned instead of a list.
    '''
    if len(f) % 2 != 0:
        return 'D4transform: N not divisibly by 2'
    trend_part = D4trend(f)
    fluct_part = D4fluct(f)
    return trend_part + fluct_part

def draw_graph(f, x=0.0, y=10.0):
    '''
    Draw the graph of the function f sampled on the interval [x, y).
    Also draws the Haar and Daubechies transforms of the samples.
    Computes and prints the energies of f, D2(f) and D4(f), so that they
    can be compared (they are expected to be equal).

    f    -- callable applied to a NumPy array of sample points
    x, y -- ends of the sampling interval; one sample every 0.005
    '''

    plt.close('all')
    plt.figure('Daubechies and Haar transforms')

    # Sample the interval requested by the caller; the defaults give the
    # 2000 points 0, 0.005, ..., 9.995.
    samples = np.arange(x, y, 0.005)
    if len(samples) % 2:
        # D4transform only handles an even number of samples, so drop the
        # last one when the interval yields an odd count.
        samples = samples[:-1]

    f_x = f(samples)
    f_haar = D2transform(f_x)
    f_daub = D4transform(f_x)

    # original function (plotted against the sample index)
    plt.subplot(3,1,1)
    plt.plot(f_x, 'r-', lw=1)

    # The label position (1500, 50) is in data coordinates:
    # (sample index, value).
    plt.subplot(3,1,2)
    write((1500,50),'Haar Transform $D2$')
    plt.plot(f_haar, 'b-', lw=1)

    plt.subplot(3,1,3)
    write((1500,50), 'Daubechies Transform $D4$')
    plt.plot(f_daub, 'g-', lw=1)

    print('energy f(x):\t', energy(f_x))
    print('energy D2(f(x)\t', energy(f_haar))
    print('energy D4(f(x))\t', energy(f_daub))

    plt.show()




def f(x):
    '''
    Sample signal: x squared times sin(2*(x-2)*3*pi).

    Built on np.sin, so x may be a number or a NumPy array.
    '''
    phase = 2*(x-2)*3*pi
    amplitude = x**2
    return amplitude * np.sin(phase)

# Module-level call: plots the sample function f and its transforms with
# the default interval as soon as this file is run (or imported).
draw_graph(f)

