## L407_x
'''
Define functions SAE(M,P) and SAD(C,P) that implement
scaled arithmetic coding and decoding
'''

# Functions defined in previous labs
from cdi import *

def accumulate(P):
    """Return the cumulative distribution of a distribution P.

    P is a list of pairs (x, p), where x is a source symbol and p its
    probability.  The result is a list of pairs (x, s) in the same order as
    P, where s is the running total of the probabilities up to and including
    x.
    """
    totals = []
    running = 0     # cumulative probability so far
    for sym, prob in P:
        running = running + prob
        totals.append((sym, running))
    return totals

def IE(M, P):
    """Interval encoder: return the final interval (l, h) for message M.

    Starting from [0, 1), each symbol of M narrows the current interval to
    the sub-interval that the distribution P (a list of (symbol, probability)
    pairs) assigns to it.
    """
    cumulative = dict(accumulate(P))
    prob = dict(P)
    low, high = 0.0, 1.0
    width = high - low
    for sym in M:
        high = low + cumulative[sym] * width
        width = prob[sym] * width
        low = high - width
    return (low, high)

def dec2bin(x, nb=58):
    """Return the first nb bits of the binary expansion of x as a string.

    x is expected in [0, 1]; x == 1 yields a string of all '1' bits.  For x
    outside that range an error message string is returned instead of
    raising.
    """
    if x < 0 or x > 1:
        return "dec2bin: was expecting a number in [0,1)"
    digits = []
    for _ in range(nb):
        x = x * 2
        if x < 1:
            digits.append('0')
        else:
            digits.append('1')
            x -= 1
    return ''.join(digits)

def bin2dec(xb):
    """Return the number in [0, 1) whose binary expansion is the bits of xb.

    xb is a sequence of bits (e.g. the string '011'); bit k (1-based)
    contributes bit / 2**k.  An empty sequence gives 0.0.
    """
    value = 0.0
    for position, bit in enumerate(xb, start=1):
        digit = int(bit)
        if digit != 0:
            value += digit / 2**position
    return value

# Finding the bit code of an interval
def BE(a,b,nb=58):
    """Find a bit string whose binary value is intended to lie in [a, b).

    a, b: interval bounds; expected 0 <= a < b <= 1.
    nb:   number of bits used for the binary expansions of a and b.

    Returns a string of '0'/'1' characters.  For invalid bounds it returns
    the string "BE: was expecting 0<=a < b<=1" instead of raising.
    """
    if (a<0) or (a>=b) or (b>1): 
        return "BE: was expecting 0<=a < b<=1"
    l=dec2bin(a,nb)
    h=dec2bin(b,nb)
    # r = length of the common prefix of the nb-bit expansions of a and b.
    r=0
    while l[r]==h[r]:
        r += 1
        if r>=nb: 
            # a and b agree on all nb bits.  Execution continues with r == nb,
            # and the code below may then fail (e.g. IndexError from x[0]).
            print("BE: The bit precision is too low")
            break  
    # Candidate: the common prefix followed by a single '1'.
    h=h[:r]+'1'
    if bin2dec(h)<b: return h
    else: 
        # The candidate is not below b, so look for a code at or just above a.
        # The common prefix itself works when its value is exactly a.
        if bin2dec(l[:r])==a: return l[:r]
        # Bits of a after the first position where a and b differ.
        x=l[(r+1):]
        if x[0]=='0': return l[:r]+'01'
        # Keep a's run of '1's and turn the first '0' after it into '1'.
        # NOTE(review): x.index raises ValueError if x contains no '0'.
        j=x.index('0')
        return l[:r+1+j]+'1'
 
def AE(M, P, nb=58):
    """Arithmetic-encode message M under distribution P.

    Returns [len(M), bits], where bits is the code that BE produces (with nb
    bits of precision) for the final interval computed by IE.
    """
    low, high = IE(M, P)
    code = BE(low, high, nb)
    return [len(M), code]

def AD(C,P):
    """Decode an (unscaled) arithmetic code such as the one produced by AE.

    C is [N, bits]: N is the number of symbols to recover and bits the code
    string.  P is the list of (symbol, probability) pairs used for encoding.
    Returns the decoded message as a string.
    """
    count = C[0]
    value = bin2dec(C[1])
    cumulative = accumulate(P)
    prob = dict(P)
    low = 0
    high = 1
    message = ''
    for _ in range(count):
        width = high - low
        # Stop at the first symbol whose upper boundary lies above the code
        # value; if none does, the loop ends on the last symbol.
        for sym, cum in cumulative:
            if not (low + cum * width <= value):
                break
        message += sym
        high = low + cum * width
        low = high - prob[sym] * width
    return message

## C = SAE(M,P)
def SAE(M, P):
    """Scaled arithmetic encoder.

    M is the message (a sequence of symbols) and P the probability
    distribution, a list of pairs (s, p) where s is a symbol and p its
    probability.

    Returns [len(M), C], where C is the bit string: the bits emitted while
    renormalising the interval (E1: '0', E2: '1'), followed by BE's code for
    the final interval.

    Raises KeyError if a symbol of M does not appear in P, and ValueError if
    the working interval collapses to zero width (a message symbol with
    probability 0, or floating-point exhaustion).  Without that check the
    renormalisation loop below would never terminate.
    """
    # Build both lookup tables in a single pass so that an iterator P is not
    # consumed twice.  Cumulative sums are formed exactly as accumulate() does.
    prob = {}
    cum = {}
    running = 0
    for sym, p in P:
        running += p
        prob[sym] = p
        cum[sym] = running

    bits = []
    l = 0.0
    h = 1.0
    u = h - l
    for x in M:
        h = l + cum[x] * u
        u = prob[x] * u
        l = h - u
        if not h > l:
            # With l == h no renormalisation can ever make l < 0.5 < h.
            raise ValueError(
                "SAE: interval collapsed to zero width "
                "(zero-probability symbol or precision loss)")
        while True:
            while h <= 0.5:                 # E1: interval in the lower half
                bits.append('0')
                h *= 2
                l *= 2
                u = h - l
            while l >= 0.5:                 # E2: interval in the upper half
                bits.append('1')
                h = min(2 * h - 1, 1.0)
                l = max(2 * l - 1, 0.0)
                u = h - l
            if l < 0.5 < h:
                break
    bits.append(BE(max(l, 0.0), min(h, 1.0)))

    return [len(M), ''.join(bits)]

def symbol(w, l, h, S):
    """Return the symbol whose sub-interval of [l, h) contains bin2dec(w).

    S is a list of (symbol, cumulative probability) pairs, e.g. the output
    of accumulate().  The result is the first symbol whose upper boundary
    l + s*(h - l) is strictly greater than the value of w, or the last symbol
    of S when no boundary is.
    """
    width = h - l
    target = bin2dec(w)
    for sym, cum in S:
        if not (l + cum * width <= target):
            return sym
    return sym

## M = SAD(C,P), C the coded message
def _code_at_least(code, pos, value):
    """Return True iff the binary fraction 0.code[pos:] is >= value, exactly.

    value is a float (so its denominator is a power of two); bits past the
    end of code count as '0'.
    """
    num, den = value.as_integer_ratio()
    depth = den.bit_length() - 1        # den == 2**depth
    window = code[pos:pos + depth].ljust(depth, '0')
    # floor(v * 2**depth) >= num  <=>  v >= num / 2**depth, since num is an int.
    return int(window or '0', 2) >= num


def SAD(C, P):
    """Scaled arithmetic decoder, the inverse of SAE.

    C is [N, bits] as returned by SAE (N = number of symbols to decode), and
    P is the same list of (symbol, probability) pairs given to the encoder.
    Returns the decoded message as a string.

    The decoder replays SAE's floating-point interval updates and
    renormalisation step for step, so both sides compute identical
    boundaries.  Each symbol is chosen by comparing the code value exactly
    with those boundaries.  A fixed-width window of code bits would truncate
    the value and could fall below a boundary the real code reaches, which
    decodes the wrong symbol.

    Raises ValueError if N > 0 but P is empty, or if the interval collapses
    to zero width (inconsistent code/distribution).  Renormalisation would
    otherwise never terminate in that case.
    """
    N, bits = C
    # One pass over P: boundaries in P's order plus the same per-symbol
    # lookups SAE uses.  Cumulative sums are formed exactly as accumulate()
    # forms them.
    table = []
    prob = {}
    cum = {}
    running = 0
    for sym, p in P:
        running += p
        table.append((sym, running))
        prob[sym] = p
        cum[sym] = running
    if N > 0 and not table:
        raise ValueError("SAD: empty probability distribution")

    out = []
    pos = 0             # index of the first code bit not yet shifted out
    l = 0.0
    h = 1.0
    u = h - l
    for _ in range(N):
        # Stop at the first symbol whose upper boundary (the same expression
        # SAE uses for h) lies strictly above the code value; if none does,
        # the loop ends on the last symbol.
        for s, c in table:
            if not _code_at_least(bits, pos, l + c * u):
                break
        out.append(s)
        h = l + cum[s] * u
        u = prob[s] * u
        l = h - u
        if not h > l:
            raise ValueError("SAD: interval collapsed to zero width")
        # Renormalise exactly as SAE does, dropping one code bit per doubling.
        while True:
            while h <= 0.5:
                pos += 1
                h *= 2
                l *= 2
                u = h - l
            while l >= 0.5:
                pos += 1
                h = min(2 * h - 1, 1.0)
                l = max(2 * l - 1, 0.0)
                u = h - l
            if l < 0.5 < h:
                break
    return ''.join(out)

# Demo: encode a test message with SAE, decode it with SAD, and compare.
P = [('a', 0.2), ('b', 0.3), ('c', 0.3), ('d', 0.2)]
M = 'abbbdacccdababcda' * 10
print('Original Message: ', M)

C = SAE(M, P)
print('Encoded Message: ', C[1])

M2 = SAD(C, P)
print('Decoded Message: ', M2)

print('Encoding/Decoding Succesful!' if M == M2
      else 'Something went wrong at Encoding/Decoding')


