## L416_martin

'''
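Haar wavelet algorithms (continued): trend/fluctuation operators, recursive and iterative multi-level Haar transforms, and signal energy.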
'''

from math import *
from numpy import * 

def trend(f):
    '''
    Single-level Haar trend of f.

    Each consecutive pair (f[2j], f[2j+1]) becomes its sum divided by
    sqrt(2). A trailing unpaired value (odd length) is ignored.

    NOTE: the multi-level trend(f, r=1) defined later in this module
    rebinds this name, so after import callers get that later version.
    '''
    root2 = 2 ** 0.5
    return [(f[i] + f[i + 1]) / root2 for i in range(0, len(f) - 1, 2)]
    
def fluct(f):
    '''
    Single-level Haar fluctuation of f.

    Each consecutive pair (f[2j], f[2j+1]) becomes its difference
    divided by sqrt(2). A trailing unpaired value (odd length) is ignored.

    NOTE: the multi-level fluct(f, r=1) defined later in this module
    rebinds this name, so after import callers get that later version.
    '''
    root2 = 2 ** 0.5
    return [(f[i] - f[i + 1]) / root2 for i in range(0, len(f) - 1, 2)]
    
def haar(f,r=1):
    '''
    Recursive r-level Haar transform of f.

    Level 1 is trend(f) followed by fluct(f). Level r re-transforms the
    leading trend part (the first len(f) // 2**(r-1) values) of the
    level-(r-1) result and keeps the fluctuations already computed.

    r <= 0 returns f unchanged. If len(f) is not divisible by 2**r an
    error message string is returned, following the convention of
    trend, fluct and HaarT.
    '''
    if r <= 0: return f
    N = len(f)
    # Validate up front: otherwise trend/fluct return error strings that
    # either get concatenated into a garbled result (r == 1) or raise a
    # TypeError when added to a list (r > 1).
    if N % 2**r: return 'haar: %d is not divisible by 2**%d' %(N,r)
    if r == 1: return trend(f) + fluct(f)
    m = N // 2**(r-1)          # length of the level-(r-1) trend part
    h = haar(f,r-1); a = h[:m]
    return trend(a) + fluct(a) + h[m:]
    
def energy(f):
    '''
    Energy of the signal f: the sum of the squares of its values.

    The built-in sum is used explicitly because ``from numpy import *``
    at the top of this module shadows ``sum`` with numpy.sum, which
    emits a DeprecationWarning when given a generator.
    '''
    import builtins
    return builtins.sum(t**2 for t in f)

# Multi-level (level r) trend and fluctuation functions; these redefine the single-level trend/fluct above
def trend(f,r=1):
    '''
    Level-r Haar trend of f.

    Applies the single-level trend (pairwise sums divided by sqrt(2))
    r times, halving the length each time. r <= 0 returns f unchanged.
    If len(f) is not divisible by 2**r an error message string is
    returned instead of a list.
    '''
    if r == 0: return f
    N = len(f)
    if N % 2**r: return 'trend: %d is not divisible by 2**%d' %(N,r)
    root2 = 2**(1/2)   # hoisted: the same constant at every level
    while r >= 1:
        N = N // 2
        f = [(f[2*j] + f[2*j+1]) / root2 for j in range(N)]
        r -= 1
    return f

def fluct(f,r=1):
    '''
    Level-r Haar fluctuation of f.

    Takes the level-(r-1) trend of f and returns the differences of its
    consecutive pairs divided by sqrt(2). r == 0 gives a list of len(f)
    zeros. If len(f) is not divisible by 2**r an error message string is
    returned instead of a list.
    '''
    if r == 0:
        return [0] * len(f)
    n = len(f)
    if n % 2**r:
        return 'fluct: %d is not divisible by 2**%d' % (n, r)
    coarse = trend(f, r - 1)
    root2 = 2 ** 0.5
    return [(coarse[k] - coarse[k + 1]) / root2
            for k in range(0, len(coarse) - 1, 2)]

# Iterative form of the multi-level Haar transform (compare the recursive haar above)
def HaarT(f,r=1):
    '''
    Iterative r-level Haar transform of f.

    Each pass splits the current trend into its level-1 trend and
    fluctuation. The result is the final trend followed by the
    fluctuations, coarsest level first. This is the same layout the
    recursive haar produces. r == 0 returns f itself. If len(f) is not
    divisible by 2**r an error message string is returned.
    '''
    if r == 0:
        return f
    approx = list(f)
    size = len(approx)
    if size % 2**r:
        return 'HaarT: %d is not divisible by 2**%d' % (size, r)
    target = size // 2**r
    levels = []                       # fluctuations, finest level first
    while len(approx) > target:
        levels.append(fluct(approx))
        approx = trend(approx)
    details = [x for level in reversed(levels) for x in level]
    return approx + details


# Self-test for levels 0-4: the recursive (haar) and iterative (HaarT)
# transforms must agree. NOTE: this runs at import time.
def f(x):
    # 40*atan(1) == 10*pi: five full cosine periods over [0, 1).
    return 10*x**2 * (1-x)**3 * cos(40*atan(1)*x)

# linspace with endpoint=False yields exactly 10000 = 625 * 2**4 samples,
# so every tested level divides the length. arange with a float step does
# not guarantee how many samples it returns.
X = linspace(0, 1, 10000, endpoint=False)

samples = f(X)   # computed once and reused by both transforms
for i in range(5):
    # array_equal handles both the level-0 arrays and the list results,
    # without relying on numpy's `all` shadowing the builtin.
    assert array_equal(haar(samples, i), HaarT(samples, i)), \
        'haar and HaarT differ at level %d' % i



