## cdi_dct
## SXD 521
'''
1D Discrete Cosine Transform
'''

from cdi import * 
import numpy as np

# Short module-level aliases for the NumPy names used throughout this file.
array, ones, zeros = np.array, np.ones, np.zeros
mat = matrix = np.matrix
splice, stack = np.hstack, np.vstack  # horizontal / vertical concatenation
ls, meshgrid = np.linspace, np.meshgrid
sin, cos, sqrt, pi = np.sin, np.cos, np.sqrt, np.pi
uint8 = np.uint8

def dct1_matrix(n):
    """Return the orthonormal DCT-II matrix of order ``n`` as an ``np.matrix``.

    Row ``k`` is the k-th cosine basis vector:
    ``T[0, j] = sqrt(1/n)`` and ``T[k, j] = sqrt(2/n) * cos((2j + 1) k pi / (2n))``
    for ``k >= 1``. The rows are orthonormal, so ``T.T`` is the inverse of ``T``.
    The result is always 2-D (``1 x 1`` for ``n == 1``, ``0 x 0`` for ``n == 0``).

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError('dct1_matrix: order must be non-negative')
    if n == 0:
        return np.matrix(np.zeros((0, 0)))
    k = np.arange(n).reshape(-1, 1)  # frequency index, one per row
    j = np.arange(n)                 # sample index, one per column
    # Built in one vectorised step; growing the matrix row by row with
    # vstack would copy everything built so far on every row.
    T = np.cos((2 * j + 1) * k * np.pi / 2 / n) * np.sqrt(2 / n)
    T[0, :] = np.sqrt(1 / n)  # the DC row has its own normalisation
    return np.matrix(T)


def dct1(x):
    """Return the orthonormal 1-D DCT-II of the vector ``x``.

    ``y[k] = c_k * sum_j x[j] * cos((2j + 1) k pi / (2n))`` with
    ``c_0 = sqrt(1/n)`` and ``c_k = sqrt(2/n)`` for ``k >= 1``. This is the
    same convention ``dct2`` uses along each axis, and it matches
    ``scipy.fft.dct(x, norm='ortho')``.

    Parameters:
        x: 1-D sequence (list or array) of length n.

    Returns:
        Flat 1-D ndarray of length n.
    """
    C = np.asarray(dct1_matrix(len(x)))
    # Row k of C is basis vector k, so y = C @ x, i.e. x @ C.T for a row vector.
    return (np.asarray(x) @ C.T).ravel()


def idct1(y):
    """Invert ``dct1``: return the 1-D inverse DCT (DCT-III) of ``y``.

    ``x[j] = sum_k c_k * y[k] * cos((2j + 1) k pi / (2n))`` with the same
    ``c_k`` as in ``dct1``, so ``idct1(dct1(x))`` reproduces ``x``.

    Parameters:
        y: 1-D sequence (list or array) of DCT coefficients, length n.

    Returns:
        Flat 1-D ndarray of length n.
    """
    C = np.asarray(dct1_matrix(len(y)))
    # C is orthonormal, so its transpose undoes dct1: x = C.T @ y = y @ C.
    return (np.asarray(y) @ C).ravel()
    


'''
2D Discrete Cosine Transform
'''

def multiplier(n):
    """Return the DCT normalisation weight for frequency index ``n``.

    ``1`` for ``n > 0`` and ``1/sqrt(2)`` otherwise (the DC term, ``n == 0``).
    """
    return 1 if n > 0 else 1 / np.sqrt(2)


def _dct_basis(n):
    """Return the orthonormal ``n x n`` DCT-II matrix as a plain ndarray.

    ``C[k, i] = sqrt(2/n) * multiplier(k) * cos((2i + 1) k pi / (2n))``, so
    ``C @ v`` is the 1-D DCT of ``v`` and ``C.T @ w`` inverts it.
    """
    if n == 0:
        return np.zeros((0, 0))
    k = np.arange(n).reshape(-1, 1)  # frequency index (rows)
    i = np.arange(n)                 # sample index (columns)
    C = np.sqrt(2 / n) * np.cos((2 * i + 1) * k * np.pi / (2 * n))
    C[0, :] *= multiplier(0)
    return C


def dct2(X):
    """Return the orthonormal 2-D DCT-II of ``X`` as an ``np.matrix``.

    For an ``n x m`` input::

        Y[r, s] = sqrt(2/n) * sqrt(2/m) * multiplier(r) * multiplier(s)
                  * sum_{i,j} X[i, j] cos((2i+1) r pi / (2n)) cos((2j+1) s pi / (2m))

    For square input this is ``(2/n) * multiplier(r) * multiplier(s) * sum ...``.
    The transform is separable, so it is evaluated as ``C_n @ X @ C_m.T``
    (O(n^2 m + n m^2) work) rather than as a quadruple loop.

    Parameters:
        X: 2-D array-like (ndarray, np.matrix or nested list); need not be square.

    Raises:
        ValueError: if ``X`` is not 2-D.
    """
    A = np.asarray(X)
    if A.ndim != 2:
        raise ValueError('dct2: data has to be a 2-D matrix')
    n, m = A.shape
    return np.matrix(_dct_basis(n) @ A @ _dct_basis(m).T)


def idct2(Y):
    """Invert ``dct2``: return the 2-D inverse DCT of ``Y`` as an ``np.matrix``.

    For an ``n x m`` input::

        X[i, j] = sqrt(2/n) * sqrt(2/m)
                  * sum_{r,s} multiplier(r) multiplier(s) Y[r, s]
                    cos((2i+1) r pi / (2n)) cos((2j+1) s pi / (2m))

    Evaluated as ``C_n.T @ Y @ C_m``, so ``idct2(dct2(X))`` reproduces ``X``.

    Parameters:
        Y: 2-D array-like of DCT coefficients; need not be square.

    Raises:
        ValueError: if ``Y`` is not 2-D.
    """
    B = np.asarray(Y)
    if B.ndim != 2:
        raise ValueError('idct2: data has to be a 2-D matrix')
    n, m = B.shape
    return np.matrix(_dct_basis(n).T @ B @ _dct_basis(m))

''' Auxiliary functions'''

def Q(T, thresh=1, nd=0):
    """Quantise ``T``: round every entry, then zero out the small ones.

    Each entry is rounded with Python's ``round(value, nd)``. Any rounded
    entry whose magnitude is below ``thresh`` is replaced by 0.

    ``T`` goes through ``matrix(...)`` first, so scalars, 1-D and 2-D inputs
    are accepted. The result is a float ndarray: 1-D when the input has a
    single row, otherwise 2-D with the input's shape.
    """
    src = matrix(T)
    out = zeros(src.shape)
    for idx in np.ndindex(*src.shape):
        out[idx] = round(float(src[idx]), nd)
    # Thresholding is applied to the already-rounded values.
    out[abs(out) < thresh] = 0
    return out[0, :] if out.shape[0] == 1 else out

def reverse_rows(X):
    """Return ``X`` with the order of its rows (first axis) reversed.

    Plain slicing is used, so lists and tuples work as well as arrays and
    matrices; for NumPy inputs the result is a view, not a copy.
    """
    backwards = slice(None, None, -1)
    return X[backwards]
def reverse_cols(X):
    """Return ``X`` with the order of its columns reversed.

    Transposes, flips the (former) columns along the first axis and
    transposes back. ``X`` must therefore support ``.T``. For a 1-D array
    ``.T`` is a no-op, so the array itself is reversed.
    """
    transposed = X.T
    return transposed[::-1].T

def solomon(X):
    """Return the mirror-image extension of ``X``, doubling both dimensions.

    For a 2-D ``X`` (assumed; the concatenations below expect matching
    shapes) the result is the block matrix::

        [[X,          cols_reversed(X)],
         [rows_rev(X), rows_rev(cols_reversed(X))]]
    """
    top = splice([X, reverse_cols(X)])
    return stack([top, reverse_rows(top)])


