## L407_x
'''
Define functions SAE(M,P) and SAD(C,P) that implement
scaled arithmetic coding and decoding
'''

# Functions defined in previous labs
from cdi import *

# Cumulative distribution of a distribution P.
# P is a list of pairs (x,p),
# where x is a source symbol and p its probability.
def accumulate(P):
    """Return the cumulative distribution of P.

    P is a list of (symbol, probability) pairs. The result is a list of
    (symbol, cumulative) pairs in the same order, where each cumulative
    value is the running sum of the probabilities up to and including
    that symbol.
    """
    cumulative = []
    running_total = 0
    for symbol, prob in P:
        running_total = running_total + prob
        cumulative.append((symbol, running_total))
    return cumulative

# Interval encoder
def IE(M, P):
    """Interval encoder: return the interval (l, h) assigned to message M.

    P is a list of (symbol, probability) pairs. Starting from [0.0, 1.0),
    each symbol of M narrows the current interval to the slice that the
    cumulative distribution of P assigns to that symbol.
    """
    cdf = dict(accumulate(P))
    prob = dict(P)
    low, high = 0.0, 1.0
    width = high - low
    for symbol in M:
        # Upper end of the symbol's slice first, then the new width,
        # then the lower end derived from them.
        high = low + cdf[symbol] * width
        width = prob[symbol] * width
        low = high - width
    return (low, high)

# Binary expansion of x in [0,1], truncated to nb bits (x == 1 gives all ones)
def dec2bin(x, nb=58):
    """Return the first nb bits of the binary expansion of x as a string.

    Inputs with x < 0 or x > 1 are rejected by returning an error-message
    string. Note that x == 1 is accepted and produces a string of ones.
    """
    if x < 0 or x > 1:
        return "dec2bin: was expecting a number in [0,1)"
    bits = []
    for _ in range(nb):
        x = 2 * x
        if x < 1:
            bits.append('0')
            continue
        bits.append('1')
        x -= 1
    return ''.join(bits)

# If xb is a string of bits, it returns the decimal number in [0,1)
# whose binary expansion is xb
def bin2dec(xb):
    """Return the number in [0, 1] whose binary expansion is the bit string xb.

    Bit k (1-based, left to right) contributes int(bit) / 2**k. The terms
    are added to 0.0 from the most significant bit down.
    """
    total = 0.0
    for exponent, bit in enumerate(xb, start=1):
        digit = int(bit)
        if digit != 0:
            total += digit / 2**exponent
    return total

# Finding the bit code of an interval
def BE(a, b, nb=58):
    """Return a bit string whose binary value lies in the interval [a, b).

    a, b : interval bounds, expected to satisfy 0 <= a < b <= 1.
    nb   : number of bits of a and b to examine initially.

    The code is built from the longest common prefix of the nb-bit binary
    expansions of a and b. When nb bits are not enough to place a code
    inside [a, b) -- the expansions agree on all nb bits, or the remaining
    bits of a are all ones -- the search is repeated with twice as many
    bits. Previously these cases raised IndexError/ValueError or returned
    a code below a.

    Returns the string "BE: was expecting 0<=a < b<=1" if the bounds are
    invalid.
    """
    if (a < 0) or (a >= b) or (b > 1):
        return "BE: was expecting 0<=a < b<=1"
    l = dec2bin(a, nb)
    h = dec2bin(b, nb)
    # r = length of the common prefix of the two nb-bit expansions.
    r = 0
    while r < nb and l[r] == h[r]:
        r += 1
    if r >= nb:
        print("BE: The bit precision is too low")
    # When a and b differ at bit r (a has 0, b has 1), prefix + '1' is
    # always >= a; the ">= a" check only matters when they never differed
    # within nb bits, where the candidate could fall below a.
    code = h[:r] + '1'
    if a <= bin2dec(code) < b:
        return code
    # a is exactly representable by the common prefix.
    if bin2dec(l[:r]) == a:
        return l[:r]
    # Bits of a after the first differing position.
    x = l[(r + 1):]
    if '0' not in x:
        # All available bits of a are ones (or there are none): a code
        # above a cannot be built yet, so expand with more precision.
        return BE(a, b, 2 * max(nb, 1))
    if x[0] == '0':
        return l[:r] + '01'
    # Skip the run of ones in a and set the bit at its first zero.
    j = x.index('0')
    return l[:r + 1 + j] + '1'
 
# Arithmetic encoder of a message
def AE(M, P, nb=58):
    """Arithmetic-encode message M with the distribution P.

    Returns [len(M), code]. code is the bit string that BE produces for
    the interval IE assigns to M; nb is passed through to BE.
    """
    low, high = IE(M, P)
    length = len(M)
    return [length, BE(low, high, nb)]

# Arithmetic decoder
def AD(C, P):
    """Arithmetic decoder: recover the message from C = [length, bits].

    The bit string is converted to a number, then `length` symbols are
    decoded. Each step picks the first symbol whose cumulative upper bound
    exceeds that number, falling back to the last symbol. The current
    interval is then narrowed with the same arithmetic IE uses.
    Symbols are concatenated with '+=', starting from an empty string.
    """
    length = C[0]
    bits = C[1]
    value = bin2dec(bits)
    cdf = accumulate(P)
    prob = dict(P)
    low, high, decoded = 0, 1, ''
    for _ in range(length):
        width = high - low
        for symbol, cum in cdf:
            if not (low + cum * width <= value):
                break
        decoded += symbol
        high = low + cum * width
        low = high - prob[symbol] * width
    return decoded

## C = SAE(M,P)
# M is the message and P is the probability
# distribution, a list of pairs (s,p),
# where s is a symbol and p its probability



## M = SAD(C,P), C the coded message



