## L407 - Andrés Mingorance
'''
Define functions SAE(M,P) and SAD(C,P) that implement
scaled arithmetic coding and decoding
'''

# Functions defined in previous labs
from cdi import *

# Cumulative distribution of a distribution P.
def accumulate(P):
    """Return the running (cumulative) distribution of P.

    P is a list of (symbol, probability) pairs. The result is a list of
    (symbol, cumulative) pairs in the same order, where cumulative is the
    sum of the probability of that symbol and of every symbol before it.
    """
    running = 0  # partial sum of the probabilities seen so far
    cumulative = []
    for sym, prob in P:
        running += prob
        cumulative.append((sym, running))
    return cumulative

# Interval encoder
def IE(M, P):
    """Return the interval (l, h) that arithmetic coding assigns to message M.

    P is a list of (symbol, probability) pairs. Starting from [0, 1), each
    symbol of M narrows the interval to that symbol's sub-interval: the new
    upper end is l + cumulative(symbol) * width and the new width is
    probability(symbol) * width. An empty message yields (0.0, 1.0).
    """
    # Cumulative probability per symbol (same running sum as accumulate).
    upper = {}
    total = 0
    for sym, prob in P:
        total += prob
        upper[sym] = total
    prob_of = dict(P)

    low, high, width = 0.0, 1.0, 1.0
    for sym in M:
        high = low + upper[sym] * width
        width = prob_of[sym] * width
        low = high - width
    return (low, high)

# Binary expansion of x, truncated to nb bits.
def dec2bin(x,nb=58):
    """Return the first nb bits of the binary expansion of x as a string.

    x is expected in [0, 1]. The value 1 is accepted and yields a string of
    all '1's, so an interval upper end of 1 can be expanded too. Input
    outside [0, 1] returns an error message string instead of raising.
    """
    if x < 0 or x > 1:
        return "dec2bin: was expecting a number in [0,1)"
    digits = []
    for _ in range(nb):
        x = x + x  # doubling shifts the next bit into the integer part
        if x < 1:
            digits.append('0')
        else:
            digits.append('1')
            x = x - 1
    return ''.join(digits)

# Inverse of dec2bin: the number whose binary expansion is the bit string xb.
def bin2dec(xb):
    """Return the value of the bit string xb read as 0.b1 b2 b3 ... in base 2.

    Every character is converted with int(); the j-th character (1-based)
    contributes int(char) / 2**j. Terms are added left to right in floating
    point, starting from 0.0, so the empty string gives 0.0.
    """
    value = 0.0
    for position, bit in enumerate(xb, start=1):
        digit = int(bit)
        if digit != 0:
            value += digit / 2**position
    return value

# Finding the bit code of an interval
def BE(a,b,nb=58):
    """Return a short bit string whose value (as read by bin2dec) lies in [a, b).

    a, b: interval ends, expected to satisfy 0 <= a < b <= 1. Otherwise the
          error message string below is returned instead of raising.
    nb:   number of bits of the expansions of a and b examined first.

    If nb bits are not enough, the expansions are recomputed with twice as
    many bits. That covers two cases: a and b agree on all nb bits, or there
    is no room to place a code above a within nb bits. dec2bin is exact for
    doubles, and a double's expansion ends within 1074 bits, so this
    terminates for float input.

    Raises ValueError("BE: The bit precision is too low") if no code is
    found even after the precision exceeds several thousand bits.
    """
    if (a<0) or (a>=b) or (b>1):
        return "BE: was expecting 0<=a < b<=1"
    max_bits = 1 << 12  # far beyond the 1074 bits a double can need
    while True:
        l = dec2bin(a, nb)
        h = dec2bin(b, nb)
        # r = length of the common prefix c of the two expansions.
        r = 0
        while r < nb and l[r] == h[r]:
            r += 1
        if r < nb:
            # Since a < b, a has a 0 and b a 1 at position r.
            code = h[:r] + '1'          # 0.c1
            if bin2dec(code) < b:
                return code
            if bin2dec(l[:r]) == a:     # a is exactly 0.c
                return l[:r]
            rest = l[r+1:]              # bits of a after 0.c0
            if rest[:1] == '0':
                return l[:r] + '01'
            j = rest.find('0')
            if j >= 0:
                # Turn the first 0 of a after the split point into a 1.
                return l[:r+1+j] + '1'
        # nb bits could not separate a from b, or left no room above a:
        # retry with more bits instead of returning a possibly wrong code.
        if nb >= max_bits:
            raise ValueError("BE: The bit precision is too low")
        nb = max(2 * nb, 1)
 
# Arithmetic encoder of a message
def AE(M, P, nb=58):
    """Encode message M with distribution P (list of (symbol, probability)).

    Returns [len(M), code], where code is BE's bit string for the interval
    that IE assigns to M; nb is passed through to BE.
    """
    low, high = IE(M, P)
    code = BE(low, high, nb)
    return [len(M), code]

# Arithmetic decoder
def AD(C,P):
    """Decode C = [N, bits], as produced by AE, back into the message string.

    The code value x = bin2dec(bits) is compared against the boundaries of
    IE's interval narrowing. At each of the N steps the decoded symbol is the
    first one whose sub-interval upper end l + cumulative * u lies above x,
    or the last symbol if none does.
    """
    N = C[0]
    x = bin2dec(C[1])
    S = accumulate(P)
    P = dict(P)
    # Track the width u exactly as IE does (u = P[symbol] * u) rather than
    # recomputing it as h - l. In floating point, h - l can differ from IE's
    # u in the last bits. The decoder's boundaries would then drift away from
    # the encoder's, and a code value near a boundary could be mis-decoded.
    l = 0.0; u = 1.0
    M = []
    for _ in range(N):
        for a, s in S:
            if l + s*u > x:
                break
        M.append(a)
        h = l + s*u
        u = P[a]*u
        l = h - u
    return ''.join(M)

## C = SAE(M,P)
# M is the message and P is the probability
# distribution, a list of pairs (s,p),
# where s is a symbol and p its probability

def SAE(M, P) :
    """Scaled arithmetic encoder.

    Narrows the interval [l, h) symbol by symbol, like IE. Whenever the
    interval lies entirely in the lower half of [0, 1), a '0' is emitted and
    the interval is doubled. Whenever it lies entirely in the upper half, a
    '1' is emitted and the interval is mapped back onto [0, 1).

    Returns [len(M), bits]: the emitted bits followed by BE's code for the
    final interval, which straddles 0.5.

    Raises ValueError if the interval collapses to a point (l >= h), e.g. a
    symbol with probability 0 or exhausted float precision. Without this
    check the rescaling loop would never terminate.

    NOTE: an interval that keeps straddling 0.5 while shrinking is never
    rescaled (there is no E3/underflow scaling). Long runs of that situation
    can therefore still exhaust double precision.
    """
    bits = []  # emitted bits, joined once at the end
    S = accumulate(P)
    P = dict(P); S = dict(S)
    l = 0.0; h = 1.0; u = h - l
    for x in M :
        h = l + S[x] * u
        u *= P[x]
        l = h - u
        while h <= 0.5 or l >= 0.5 :
            # With l >= h the loop can never reach l < 0.5 < h.
            if not l < h :
                raise ValueError("SAE: the coding interval collapsed to a point")
            if h <= 0.5 :
                bits.append('0')
                h *= 2; l *= 2; u *= 2
            else : # l >= 0.5
                bits.append('1')
                h = min(h*2-1, 1); l = max(l*2-1, 0); u = h - l
    bits.append(BE(max(l, 0), min(h, 1)))
    return [len(M), ''.join(bits)]

## M = SAD(C,P), C the coded message
# Scaled arithmetic decoder
# C = [N,b], where N is the length
# of the original message, and b its
# binary string

def SAD(C, P) :
    """Scaled arithmetic decoder: inverse of SAE.

    Replays SAE's interval arithmetic step by step. At each step the next
    symbol is found by reading the not-yet-consumed code bits as a number y
    and comparing it with the symbol boundaries l + cumulative * u. Every
    rescaling of the interval consumes one code bit, matching the one bit
    SAE emits per rescaling.

    Raises ValueError if the interval collapses to a point (l >= h) instead
    of looping forever.
    """
    [N, x] = C
    S = accumulate(P)
    P = dict(P)
    D = dict(S)
    # Look-ahead window. bin2dec is exact for up to 53 bits, because every
    # k / 2**53 in [0, 1) is a double. So y is the exact truncation of the
    # remaining code value: never above it, and below it by less than 2**-53.
    # A shorter window, e.g. 32 bits, can leave y below the lower boundary
    # of the correct symbol when the code value lies just above that
    # boundary. That happens, for instance, when a symbol is followed by a
    # long run of the first symbol, and the previous symbol is then decoded.
    nb = 53
    pos = 0  # index in x of the first unconsumed bit (no re-slicing of x)
    M = []
    l = 0.0; h = 1.0; u = h - l
    for _ in range(N) :
        y = bin2dec(x[pos:pos + nb])
        # Take the first symbol whose upper boundary exceeds y, or the last
        # symbol if none does. The boundaries use the tracked width u, so
        # they are computed with the same expression as SAE's h = l + S[x]*u.
        for s, cum in S :
            if l + cum*u > y :
                break
        M.append(s)
        h = l + D[s]*u; u = P[s]*u; l = h - u
        while h <= 0.5 or l >= 0.5 :
            # With l >= h the loop can never reach l < 0.5 < h.
            if not l < h :
                raise ValueError("SAD: the decoding interval collapsed to a point")
            pos += 1
            if h <= 0.5 : h *= 2; l *= 2; u *= 2
            else : h = min(h*2-1, 1); l = max(l*2-1, 0); u = h - l
    return ''.join(M)


def symbol(w, l, h, S) :
    """Return the symbol selected by the bit window w inside the interval [l, h).

    w is read as a number y = bin2dec(w). S is a cumulative list as built by
    accumulate. The result is one of the following:
      - the first symbol whose boundary l + cumulative * (h - l) is greater
        than y;
      - the last symbol if no boundary is;
      - '' when S is empty.
    """
    width = h - l
    target = bin2dec(w)
    found = ''
    for found, bound in S :
        if target < l + bound*width :
            break
    return found


# Sample source: five symbols whose probabilities add up to 1.
P = [
    ('a', 0.25),
    ('b', 0.4),
    ('c', 0.15),
    ('d', 0.1),
    ('e', 0.1),
]

def test(P, times=1000, length = 100) :
    """Round-trip `times` random messages of `length` symbols through SAE/SAD.

    Each symbol is P[rd_int(0, len(P) - 1)][0]. rd_int comes from cdi and
    is not visible here; presumably it returns a random integer in the
    closed range given (TODO confirm), so symbols are picked by index, not
    by their probabilities.

    Raises Exception on the first message that does not decode to itself.
    """
    print("Testing", times, "encoding/decoding of messages from P:", P, "of length", length, "...")
    last = len(P) - 1
    for step in range(times) :
        msg = ''.join(P[rd_int(0, last)][0] for _ in range(length))
        decoded = SAD(SAE(msg, P), P)
        if decoded != msg :
            raise Exception("Decoding went wrong on step", step)
    print("All tests passed")

test(P)                          # 1000 messages of length 100
test(P, times=100, length=1000)  # 100 messages of length 1000
test(P, times=5, length=33333)   # 5 long messages

# Uniform distribution over the 26 lowercase letters plus the space.
abc = 'abcdefghijklmnopqrstuvwxyz '
test([(ch, 1/len(abc)) for ch in abc])
