from cdi_wavelets import *
import time
import numpy as np
import matplotlib.pyplot as plt

def f(x):
    """Test signal sin(x) * cos(x)**2, evaluated element-wise on x."""
    sine, cosine = np.sin(x), np.cos(x)
    return sine * cosine ** 2

r = 3  # level used for the trend and for the 2**r length factor below
x = np.arange(0, 2 ** 10, 0.01)

# D4 trend of the sampled f at level r, truncated to its first 256 coefficients.
a = D4trend(f(x), r)
a = a[0:256]

# Format the actual r instead of a hard-coded 3 so the message stays correct
# if r is changed.
print('a = D4trend(f(x),%d),length a: %d' % (r, len(a)))

L = len(a)
# Length of the signal reconstructed from the L level-r coefficients.
N = L * 2 ** r

## decompression


'''
For the linear combination, it is handy to have a single function that we can apply to all three cases (D2, D4 and D6).
'''
def lc(a,V):
    """Return the linear combination sum_i a[i] * V[i] as a new float array.

    The result has the length of V[0]. Coefficients and vectors are paired
    with zip, so anything beyond the shorter of ``a`` and ``V`` is ignored.
    """
    total = np.zeros(len(V[0]))
    for coeff, vec in zip(a, V):
        total += coeff * vec
    return total
    
'''
Now we can compute the three decompressions
'''
# decompress() looks up its basis builder by the name 'D%dVW' % e; expose the
# Haar builder under the name D2VW so that e=2 selects it.
D2VW = HaarVWA

def timing(func):
    """Decorate a ``(a, r, N, e)`` function so every call prints its duration.

    The wrapper forwards the four arguments unchanged, returns the wrapped
    function's result, and prints ``Decompressing with D<e> lasts: <seconds>s``.
    """
    import functools

    # time.clock() was removed in Python 3.8. perf_counter() is the intended
    # interval timer; fall back to time.time() on interpreters without it.
    clock = getattr(time, 'perf_counter', time.time)

    @functools.wraps(func)  # keep the wrapped function's name and docstring
    def calc_time(a, r, N, e):
        time_start = clock()
        ret = func(a, r, N, e)
        time_end = clock()
        print('Decompressing with D%d lasts: %fs' % (e, (time_end - time_start)))
        return ret

    return calc_time

@timing
def decompress(a, r, N, e):
    """Reconstruct a signal from the level-r coefficients ``a`` with the D<e> basis.

    Looks up the basis builder named ``D<e>VW`` (e.g. ``D4VW``) in this module,
    takes ``D<e>VW(N)[0][r]`` as the level-r vectors, and returns their linear
    combination with coefficients ``a`` (via ``lc``).

    Raises:
        ValueError: if no builder named ``D<e>VW`` is defined.
    """
    # Explicit name lookup instead of eval() on a formatted string: same
    # dispatch, no code evaluation, and a clear error for an unsupported e.
    builder = globals().get('D%dVW' % e)
    if builder is None:
        raise ValueError('No basis builder D%dVW is defined' % e)
    V = builder(N)[0][r]
    return lc(a, V)


# Reconstruct with D6, D4 and D2 (in that order). Each call is equivalent to
# lc(a, D<e>VW(N)[0][r]) and, through @timing, prints its own duration.
f6, f4, f2 = (decompress(a, r, N, e) for e in (6, 4, 2))



def H2(x):
    """Filter x with the two equal taps [1/sqrt(2), 1/sqrt(2)].

    Uses the module-level ``filter(taps, signal)`` helper (not the builtin;
    presumably provided by cdi_wavelets), the same way H4/H6 are expected to.
    """
    tap = 1 / sqrt(2)
    return filter([tap, tap], x)
   
def HFDx(a, r=1, x=2):
    '''Apply ``r`` rounds of ``H<x>(U_(a))`` to ``a`` and return the result as an array.

    ``H<x>`` (e.g. ``H2``, ``H4``, ``H6``) and ``U_`` are looked up by name in
    this module; ``U_`` is presumably the upsampling helper from cdi_wavelets.
    With ``r=0`` the input is returned unchanged (converted with ``array``).

    NOTE(review): this was documented as a "generic high pass filter", but the
    visible H2 uses two equal positive taps (an averaging filter); confirm
    what H4/H6 are before relying on that description.

    Raises:
        ValueError: if no filter named ``H<x>`` is defined.
    '''
    # Resolve the filter once, instead of eval()-ing a formatted string on
    # every iteration.
    filt = globals().get('H%d' % x)
    if filt is None:
        raise ValueError('No filter H%d is defined' % x)
    for _ in range(r):
        a = filt(U_(a))
    return array(a)


# The same reconstructions computed independently with HFDx (r rounds of
# H<e>(U_(...))), in the order 6, 4, 2, for comparison with f6/f4/f2.
h6, h4, h2 = (HFDx(a, r, order) for order in (6, 4, 2))

'''
Checking for agreement between the two solutions
'''
# Compare the basis-combination reconstructions (f6/f4/f2) with the
# filter-based ones (h6/h4/h2).
nd = 15  # required number of matching decimal places
eps = 10 ** (-nd)

# Largest absolute pointwise difference for each wavelet. np.max propagates
# NaN, whereas the builtin max() can silently drop it depending on position.
m6 = np.max(np.abs(f6 - h6))
m4 = np.max(np.abs(f4 - h4))
m2 = np.max(np.abs(f2 - h2))
m = np.max([m6, m4, m2])

# "m <= eps" (rather than "not m > eps") makes a NaN difference count as
# disagreement instead of agreement.
agreement = bool(m <= eps)

if agreement:
    print('The two methods agree up to %d decimal places' % nd)
else:
    print('The two methods disagree at or before decimal place %d' % nd)


# Summary text printed before the plots. NOTE(review): the timings quoted here
# are hard-coded observations, not values measured by this run.
conclussions =  '''
As we can observe D2 has a lot more loss with respect D4 and D6.
If we compare the running time of the decompressions, D6 (~17s) is slower than D4 and D2 (~12s).
'''

# Zero-pad the coefficients to 1024 entries for the first panel. list(a) makes
# this a concatenation whether D4trend returned a list or a numpy array
# (ndarray + list would broadcast element-wise and fail on mismatched lengths).
a = list(a) + (1024 - len(a)) * [0]
print(conclussions)
plt.close('all')
plt.figure("Comparing decompressions")

# NOTE(review): every panel is clipped to x in [0, 1024], while f6/f4/f2
# presumably have N = L * 2**r samples (up to 2048 here); confirm the clipping
# is intended.
plt.subplot(4, 1, 1)
plt.axis([0, 1024, min(a), max(a)])
plt.plot(a)

plt.subplot(4, 1, 2)
plt.axis([0, 1024, min(f6), max(f6)])
plt.plot(f6, 'r')

plt.subplot(4, 1, 3)
plt.axis([0, 1024, min(f4), max(f4)])
plt.plot(f4, 'g')

plt.subplot(4, 1, 4)
plt.axis([0, 1024, min(f2), max(f2)])
plt.plot(f2, 'y')

plt.show()
