## L407
## Albert Puente Encinas


'''
Define functions SAE(M,P) and SAD(C,P) that implement
scaled arithmetic coding and decoding
'''

# Functions defined in previous labs
from cdi import *

# Cumulative distribution of a distribution P.
# P is a list of pairs (x,p),
# where x is a source symbol and p its probability.
def accumulate(P):
    """Return the cumulative distribution of P.

    P is a list of (symbol, probability) pairs. The result is a list of
    (symbol, running_total) pairs in the same order, where running_total is
    the sum of the probabilities up to and including that symbol.
    """
    running = 0
    cumulative = []
    for symbol, prob in P:
        running += prob
        cumulative.append((symbol, running))
    return cumulative

# Interval encoder
def IE(M, P):
    """Return the interval (low, high) that encodes message M under P.

    Starting from [0, 1), each symbol of M narrows the interval to the
    sub-interval that the cumulative distribution of P assigns to it.
    """
    cum = dict(accumulate(P))
    prob = dict(P)
    low, high = 0.0, 1.0
    width = high - low
    for sym in M:
        high = low + cum[sym] * width
        width = prob[sym] * width
        low = high - width
    return (low, high)

# Binary expansion of x in [0,1]: returns its first nb bits as a string
# (x == 1 yields nb ones)
def dec2bin(x, nb=58):
    """Return the first nb bits of the binary expansion of x as a string.

    x is expected in [0, 1]; x == 1 produces nb ones. Out-of-range input
    returns an error-message string instead of raising.
    """
    if x < 0 or x > 1:
        return "dec2bin: was expecting a number in [0,1)"
    bits = []
    for _ in range(nb):
        x = 2 * x
        bit = '0' if x < 1 else '1'
        bits.append(bit)
        if bit == '1':
            x -= 1
    return ''.join(bits)

# If xb is a string of bits, it returns the decimal number in [0,1)
# whose binary expansion is xb
def bin2dec(xb):
    """Return the number 0.xb, where xb is a string of bit characters.

    The j-th character (1-based) contributes int(char) / 2**j; zero
    characters are skipped. An empty string gives 0.0.
    """
    value = 0.0
    for position, digit in enumerate(xb, start=1):
        weight = int(digit)
        if weight != 0:
            value += weight / 2**position
    return value

# Finding the bit code of an interval
def BE(a,b,nb=58):
    """Return a short bit string whose value (as read by bin2dec) lies in [a, b).

    a and b are expanded to nb bits with dec2bin. On invalid input
    (not 0 <= a < b <= 1) an error-message string is returned instead of
    raising. NOTE(review): callers that concatenate the result (e.g. SAE)
    would silently embed that message; validate inputs first.
    """
    if (a<0) or (a>=b) or (b>1): 
        return "BE: was expecting 0<=a < b<=1"
    l=dec2bin(a,nb)
    h=dec2bin(b,nb)
    # r = length of the common prefix of the two nb-bit expansions
    r=0
    while l[r]==h[r]:
        r += 1
        if r>=nb: 
            # the expansions agree on all nb bits; carry on with r == nb
            print("BE: The bit precision is too low")
            break  
    # Candidate: the common prefix followed by '1'. When the expansions
    # differ at position r (a has a 0 there, b has a 1), this lies in (a, b].
    h=h[:r]+'1'
    if bin2dec(h)<b: return h
    else: 
        # The candidate is not below b: look for a code between a and it.
        # If a equals the common prefix exactly, the prefix itself is a code.
        if bin2dec(l[:r])==a: return l[:r]
        # Bits of a after the first differing position.
        # NOTE(review): x is empty when r+1 >= nb (x[0] raises IndexError),
        # and x.index('0') below raises ValueError when x is all ones.
        x=l[(r+1):]
        # a continues with a 0: prefix + '01' lies strictly between a and b.
        if x[0]=='0': return l[:r]+'01'
        # a continues with a run of 1s: copy a up to the first 0 after
        # position r and turn that 0 into a 1.
        j=x.index('0')
        return l[:r+1+j]+'1'
 
# Arithmetic encoder of a message
def AE(M, P, nb=58):
    """Arithmetic-encode message M with distribution P.

    Returns [len(M), bits], where bits is the code that BE produces (with
    nb bits of precision) for the interval computed by IE.
    """
    low, high = IE(M, P)
    code = BE(low, high, nb)
    return [len(M), code]

# Arithmetic decoder
def AD(C, P):
    """Decode an arithmetic code C = [N, bits] produced by AE under P.

    bits is turned into a single number with bin2dec. Each of the N steps
    picks the first symbol whose cumulative upper bound, scaled into the
    current interval, exceeds that number (or the last symbol if none
    does), then narrows the interval to that symbol's sub-interval.
    """
    N = C[0]
    value = bin2dec(C[1])
    cumulative = accumulate(P)
    prob = dict(P)
    low, high = 0, 1
    decoded = ''
    for _ in range(N):
        width = high - low
        for sym, cum in cumulative:
            if not (low + cum * width <= value):
                break
        decoded += sym
        high = low + cum * width
        low = high - prob[sym] * width
    return decoded

## C = SAE(M,P)
# M is the message and P is the probability
# distribution, a list of pairs (s,p),
# where s is a symbol and p its probability

def SAE(M, P):
    """Scaled arithmetic encoder.

    M is the message and P the probability distribution, a list of
    (symbol, probability) pairs. Returns [len(M), C], where C is the code
    bit string.

    After each symbol narrows the interval [l, h), the interval is
    renormalised. While it lies in [0, 0.5], a '0' is emitted and the
    interval is doubled. While it lies in [0.5, 1], a '1' is emitted and it
    is mapped through x -> 2x - 1. The final interval, which straddles 0.5,
    is closed with BE.

    NOTE: an interval that keeps straddling 0.5 is never rescaled, so long
    runs of that kind are limited by float precision.

    Raises ValueError if the interval collapses (h <= l), e.g. when a
    zero-probability symbol is encoded. The rescaling loop would otherwise
    never terminate.
    """
    bits = []
    S = dict(accumulate(P))
    P = dict(P)
    l = 0.0
    h = 1.0
    u = h - l
    for x in M:
        h = l + S[x]*u
        u = P[x]*u
        l = h - u
        # With l < h, doubling and 2x - 1 keep l < h and widen the interval,
        # so the loop below terminates; without it, it spins forever.
        if not (l < h):
            raise ValueError("SAE: the coding interval collapsed "
                             "(zero-probability symbol or precision loss)")
        # The arithmetic here must match SAD's exactly so that the decoder
        # reproduces the same (l, h, u) sequence.
        while h <= 0.5 or l >= 0.5:
            if h <= 0.5:
                bits.append('0')
                h = 2*h
                l = 2*l
                u = 2*u
            else:
                bits.append('1')
                h = min(2*h - 1, 1)
                l = max(2*l - 1, 0)
                u = h - l
    # Here l < 0.5 < h, so BE receives a valid interval 0 <= a < b <= 1.
    bits.append(BE(max(l, 0), min(h, 1)))
    return [len(M), ''.join(bits)]


## M = SAD(C,P), C the coded message
'''
Scaled arithmetic decoder
C = [N, b], where N is the length
of the original message and b is 
the code binary string.
'''

def symbol(w, l, h, S):
    """Return the symbol of cumulative list S selected by bit window w.

    w is read as a number with bin2dec. The result is the first symbol
    whose cumulative upper bound, scaled into [l, h), exceeds that number,
    or the last symbol of S when none does.
    """
    width = h - l
    target = bin2dec(w)
    for sym, upper in S:
        if l + upper*width > target:
            return sym
    return sym

def SAD(C, P, nb=32):
    """Scaled arithmetic decoder for codes produced by SAE.

    C = [N, b], where N is the length of the original message and b is the
    code bit string. P is the distribution that was given to SAE. nb is the
    number of code bits read to choose each symbol (default 32).

    The decoder repeats SAE's interval arithmetic exactly. Every bit that
    SAE emitted while rescaling is consumed here by advancing a read
    position, rather than re-slicing b, which would be quadratic in len(b).

    Raises ValueError if the interval collapses (h <= l). The rescaling
    loop would otherwise never terminate.
    """
    N, x = C
    S = accumulate(P)
    D = dict(S)
    P = dict(P)
    M = []
    pos = 0       # index of the next unconsumed code bit
    l = 0.0
    h = 1.0
    u = h - l
    for _ in range(N):
        s = symbol(x[pos:pos + nb], l, h, S)
        M.append(s)
        h = l + D[s]*u
        u = P[s]*u
        l = h - u
        if not (l < h):
            raise ValueError("SAD: the coding interval collapsed "
                             "(zero-probability symbol or precision loss)")
        # Mirror of SAE's renormalisation; each step consumes one code bit.
        while h <= 0.5 or l >= 0.5:
            pos += 1
            if h <= 0.5:
                l = 2*l
                h = 2*h
                u = 2*u
            else:
                l = max(2*l - 1, 0)
                h = min(2*h - 1, 1)
                u = h - l
    return ''.join(M)

## Simple test

P = [('a', 0.25), ('b', 0.4), ('c', 0.15), ('d', 0.1), ('e', 0.1)]
M = 'badbbdcbabea' * 10
print('Original message:', M)

# Encode, decode, and check that the round trip reproduces the message.
C = SAE(M, P)
print('Codification:', C)
D = SAD(C, P)
print('Decodification:', D)
print('It works.' if M == D else 'Not working.')

## Complex text test with equal probabilities

chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ abcdefghijklmnopqrstuvwxyz,.'
# Uniform distribution over the alphabet above.
P = [(ch, 1/len(chars)) for ch in chars]

M = 'Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.'
print('Original message:', M)

# Encode, decode, and check that the round trip reproduces the message.
C = SAE(M, P)
print('Codification:', C)
D = SAD(C, P)
print('Decodification:', D)
print('It works.' if M == D else 'Not working.')
