## L407_x
'''
Define functions SAE(M,P) and SAD(C,P) that implement
scaled arithmetic coding and decoding
'''

# Functions defined in previous labs
from cdi import *

# Cumulative distribution of a distribution P.
# P is a list of pairs (x,p),
# where x is a source symbol and p its probability.
def accumulate(P):
    """Return the cumulative distribution of P.

    Each symbol is paired with the running total of the probabilities up to
    and including it, in the order the pairs appear in P.
    """
    table = []
    running = 0
    for sym, prob in P:
        running += prob
        table.append((sym, running))
    return table

# Interval encoder
def IE(M, P):
    """Return the final interval (l, h) of plain arithmetic coding of M.

    M -- the message, an iterable of symbols.
    P -- probability distribution, a list of (symbol, probability) pairs.
         A symbol's subinterval of the current interval is the slice that
         ends at its cumulative probability (from accumulate) and has
         width equal to its probability.
    """
    upper = dict(accumulate(P))
    prob = dict(P)
    low, high = 0.0, 1.0
    width = high - low
    for ch in M:
        high = low + upper[ch] * width
        width = prob[ch] * width
        low = high - width
    return (low, high)

# Binary expansion of x in [0,1]
def dec2bin(x,nb=58):
    """Return the first nb bits of the binary expansion of x as a '0'/'1' string.

    Values outside [0, 1] yield an error-message string instead of bits.
    x == 1 is accepted and produces nb ones; BE relies on this for the
    upper end point b = 1.
    """
    if x < 0 or x > 1:
        return "dec2bin: was expecting a number in [0,1)"
    bits = []
    for _ in range(nb):
        x = 2 * x
        bit = '0' if x < 1 else '1'
        if bit == '1':
            x -= 1
        bits.append(bit)
    return ''.join(bits)

# If xb is a string of bits, return the number whose binary
# expansion is 0.xb
def bin2dec(xb):
    """Return the value of the binary fraction 0.xb as a float.

    Each element of xb is converted with int(); the j-th element (1-based)
    contributes digit / 2**j.  An empty xb gives 0.0.
    """
    value = 0.0
    weight = 1
    for digit in map(int, xb):
        weight *= 2  # weight == 2**j for the j-th digit
        if digit != 0:
            value += digit / weight
    return value

# Finding the bit code of an interval
def BE(a,b,nb=58):
    """Return a short bit string whose binary-fraction value lies in [a, b).

    a, b -- interval end points; expected to satisfy 0 <= a < b <= 1.
    nb   -- number of bits of a and b examined (passed to dec2bin).

    Returns the bit string, or an error-message string (not an exception)
    when the bounds are invalid.

    NOTE(review): the last branch may raise. x.index('0') raises ValueError
    when all bits of a after the first differing position are '1'. x[0]
    raises IndexError when the common prefix uses up all nb bits, because
    x is then empty. TODO: confirm whether callers can reach these cases.
    """
    if (a<0) or (a>=b) or (b>1): 
        return "BE: was expecting 0<=a < b<=1"
    l=dec2bin(a,nb)
    h=dec2bin(b,nb)
    # r = length of the common binary prefix of a and b (capped at nb)
    r=0
    while l[r]==h[r]:
        r += 1
        if r>=nb: 
            print("BE: The bit precision is too low")
            break  
    # Candidate code: the common prefix followed by a single '1'.
    # Since a < b, a has a '0' at position r, so the candidate is above a's bits.
    h=h[:r]+'1'
    # Accept the candidate if it is also strictly below b.
    if bin2dec(h)<b: return h
    else: 
        # Here b is not above the candidate, so build a point from a's bits.
        # The bare common prefix works when it equals a exactly.
        if bin2dec(l[:r])==a: return l[:r]
        # x = bits of a after its '0' at position r
        x=l[(r+1):]
        # a = prefix 0 0 ...: prefix 0 1 lies above a and below b
        if x[0]=='0': return l[:r]+'01'
        # a = prefix 0 1...1 0 ...: set the first '0' of x to '1'
        j=x.index('0')
        return l[:r+1+j]+'1'
 
# Arithmetic encoder of a message
def AE(M, P, nb=58):
    """Arithmetic-encode the message M with the distribution P.

    Returns [len(M), code], where code is the bit string chosen by BE
    (using nb bits of precision) inside the interval IE(M, P).
    """
    low, high = IE(M, P)
    code = BE(low, high, nb)
    return [len(M), code]

# Arithmetic decoder
def AD(C,P):
    """Decode a code produced by AE back into the original message.

    C -- [N, bits]: N is the length of the original message and bits the
         code bit string, read as a binary fraction with bin2dec.
    P -- probability distribution, a list of (symbol, probability) pairs.

    Returns the decoded message as a string of N symbols.

    The interval update performs the same floating-point operations as IE
    (u = P[a]*u; l = h - u). Recomputing the width as h - l instead can
    differ from P[a]*u by a rounding error, so the decoder's intervals
    would drift away from the encoder's.
    """
    N = C[0]
    x = bin2dec(C[1])
    S = accumulate(P)
    P = dict(P)
    l = 0.0; u = 1.0
    M = []
    for _ in range(N):
        # Pick the first symbol whose upper boundary l + s*u lies strictly
        # above x. Intervals are half-open, so a tie goes to the next symbol.
        for (a, s) in S:
            if l + s*u <= x:
                continue
            break
        M.append(a)
        h = l + s*u
        u = P[a]*u
        l = h - u
    return ''.join(M)

## C = SAE(M,P)
# M is the message and P is the probability
# distribution, a list of pairs (s,p),
# where s is a symbol and p its probability

def SAE(M,P):
    """Scaled arithmetic encoder.

    After each symbol narrows the interval [l, h), the interval is rescaled:
    - If it lies in [0, 0.5), a '0' is emitted and the interval is doubled.
    - If it lies in [0.5, 1), a '1' is emitted and it is mapped by
      y -> 2y - 1.
    After the last symbol, BE encodes the remaining interval.

    Returns [len(M), bits], with bits a string of '0'/'1'.
    Raises KeyError for a symbol of M missing from P. Raises ValueError if
    the interval becomes empty (zero-probability symbol or underflow); the
    rescaling loop could otherwise never terminate.
    """
    bits = []
    S = accumulate(P)
    P = dict(P); S = dict(S)
    l = 0.0; h = 1.0; u = h-l
    for x in M:
        h = l+S[x]*u; u = P[x]*u; l = h-u
        if not l < h:
            raise ValueError("SAE: empty interval (zero-probability symbol or precision underflow)")
        while True:
            # interval inside the lower half: emit 0 and zoom in
            while h <= 0.5:
                bits.append('0')
                h = 2*h; l = 2*l; u = 2*u
            # interval inside the upper half: emit 1 and zoom in
            while l >= 0.5:
                bits.append('1')
                h = min(2*h-1,1); l = max(2*l-1,0)
                # width of the rescaled interval: h - l (not h - 1, which
                # is negative and corrupts every later update)
                u = h-l
            if l < 0.5 < h: break
    bits.append(BE(max(l,0),min(h,1)))
    return [len(M),''.join(bits)]

## M = SAD(C,P), C the coded message

# Scaled arithmetic decoder
# C = [N,x], where N is the length
# of the original message and x is
# the code binary string

def SAD(C,P,nb=32):
    """Scaled arithmetic decoder, the inverse of SAE.

    C  -- [N, x] as returned by SAE: N is the message length, x the bit string.
    P  -- probability distribution, a list of (symbol, probability) pairs.
    nb -- number of code bits examined when choosing each symbol
          (default 32, the value previously hard-coded).

    Returns the decoded message as a string.
    Raises ValueError if the interval becomes empty; the rescaling loop
    could otherwise never terminate.
    """
    [N,x] = C
    S = accumulate(P)
    D = dict(S)
    P = dict(P)
    l = 0.0; h = 1.0; u = h-l
    pos = 0  # index of the first code bit not yet shifted out
    M = []
    for _ in range(N):
        s = symbol(x[pos:pos+nb], l, h, S)
        M.append(s)
        h = l+D[s]*u; u = P[s]*u; l = h-u
        if not l < h:
            raise ValueError("SAD: empty interval (zero-probability symbol or precision underflow)")
        # Mirror SAE's rescaling; each rescale consumes one code bit.
        # Advancing pos replaces the O(len(x)) slice x = x[1:].
        while True:
            while h <= 0.5:
                pos += 1
                l = 2*l; h = 2*h; u = 2*u
            while l >= 0.5:
                pos += 1
                l = 2*l-1; h = 2*h-1; u = h-l
            if l < 0.5 < h: break
    return ''.join(M)

def symbol(w,l,h,S):
    """Return the symbol whose subinterval of [l, h) contains the value of w.

    w    -- bit string, read as a binary fraction with bin2dec.
    l, h -- current interval end points.
    S    -- cumulative distribution from accumulate: (symbol, upper) pairs.

    Returns the first symbol whose scaled upper bound l + upper*(h - l) is
    not <= the value. If there is none, returns the last symbol of S.
    """
    width = h - l
    target = bin2dec(w)
    for candidate, upper in S:
        if not (l + upper * width <= target):
            break
    return candidate

## tests

# Sample distribution (symbol, probability) and message for manual tests
P = [
    ('a', 0.25),
    ('b', 0.4),
    ('c', 0.15),
    ('d', 0.1),
    ('e', 0.1),
]
M = 'abcebabcbdbeb'
