#!/usr/bin/env python
# -*- coding: utf-8 -*-
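# "EASY1" cipher from Christopher Swenson's book "Modern Cryptanalysis"
# 36-bit blocks
# 18-bit key; the code is keyed with the 36-bit value obtained by writing
# the 18-bit key twice. Write it in binary to avoid mistakes, for example
# (18-bit key 0b111100001111000011 written twice):
# K = 0b111100001111000011111100001111000011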

# S-box 6 bits --> 6 bits, 64 valeurs de 0 à 63

s = [
    16, 42, 28, 3, 26, 0, 31, 46,
    27, 14, 49, 62, 37, 56, 23, 6,
    40, 48, 53, 8, 20, 25, 33, 1,
    2, 63, 15, 34, 55, 21, 39, 57,
    54, 45, 47, 13, 7, 44, 61, 9,
    60, 32, 22, 29, 52, 19, 12, 50,
    5, 51, 11, 18, 59, 41, 36, 30,
    17, 38, 10, 4, 58, 43, 35, 24,
]

# P-box on 36 bits: input bit i is moved to output bit p[i].
p = [
    24, 5, 15, 23, 14, 32,
    19, 18, 26, 17, 6, 12,
    34, 9, 8, 20, 28, 0,
    2, 21, 29, 11, 33, 22,
    30, 31, 1, 25, 3, 35,
    16, 13, 27, 7, 10, 4,
]

# 36 one-bits: the width of one cipher block.
mask = (1 << 36) - 1

def sbox(x):
    """Substitute the 6-bit value x through the S-box table ``s``."""
    substituted = s[x]
    return substituted

def pbox(x, perm=None):
    """Apply the P-box: bit i of x is moved to bit perm[i] of the result.

    Only the low len(perm) bits of x are read; any higher bits are dropped.
    Written without Python 2 long literals (``0L``), so it runs unchanged on
    Python 2 and Python 3.

    :param x: non-negative integer block (36 bits for the EASY1 P-box)
    :param perm: bit permutation to apply (must be a permutation of
        range(len(perm))); defaults to the module-level P-box ``p``
    :return: the permuted integer
    """
    if perm is None:
        perm = p
    y = 0
    for i, target in enumerate(perm):
        if (x >> i) & 1:
            y |= 1 << target
    return y

def apbox(x, perm=None):
    """Inverse P-box: bit perm[j] of x is moved back to bit j of the result.

    This undoes pbox() for the same permutation. Output bit j is gathered
    directly from input bit perm[j], which avoids a linear ``perm.index``
    search per bit (the old O(n^2) approach). Bits of x at positions
    >= len(perm) are ignored.

    :param x: non-negative integer block (36 bits for the EASY1 P-box)
    :param perm: bit permutation that was applied (must be a permutation of
        range(len(perm))); defaults to the module-level P-box ``p``
    :return: the un-permuted integer
    """
    if perm is None:
        perm = p
    y = 0
    for j, source in enumerate(perm):
        if (x >> source) & 1:
            y |= 1 << j
    return y

def asbox(x):
    """Inverse S-box: return the position at which value x occurs in ``s``.

    Raises ValueError (from list.index) if x does not occur in ``s``.
    """
    preimage = s.index(x)
    return preimage





# 36 bits --> 6*6 bits
# Equivalent to writing x in base 64. The least significant digits
# come first in the list (!)

def demux(x):
    """Split x into six 6-bit digits, least significant digit first.

    Bits above position 35 are discarded.
    """
    digits = []
    for _ in range(6):
        digits.append(x & 0x3F)
        x >>= 6
    return digits

def mux(xx):
    """Recombine 6-bit digits (least significant first) into one integer.

    Inverse of demux(). Accepts any iterable of digits, including a lazy
    ``map`` object (which is what ``map`` returns on Python 3). Written
    without Python 2 long literals, so it runs on Python 2 and 3.

    :param xx: iterable of integers; digit i is placed at bit offset 6*i
    :return: the combined integer (0 for an empty iterable)
    """
    y = 0
    for i, digit in enumerate(xx):
        # XOR as in the original; for digits in 0..63 this equals OR.
        y ^= digit << (6 * i)
    return y

# XOR avec la clef k

def mix(xx, k):
    """XOR each 6-bit digit of xx with the matching 6-bit digit of key k.

    :param xx: indexable sequence of at least six digits (least significant first)
    :param k: integer key, split into six digits with demux()
    :return: new list of six XORed digits
    """
    subkeys = demux(k)
    mixed = []
    for i, sub in enumerate(subkeys):
        mixed.append(xx[i] ^ sub)
    return mixed


# pour les chaines

def strtoblocks(x):
    """Split a 9-character string into two 36-bit integer blocks.

    The 9 characters are read as 9 bytes (72 bits, first character most
    significant). The high 36 bits form the first block and the low 36 bits
    the second.

    :param x: string of exactly 9 characters whose codes fit in one byte
        (a bytes/bytearray object of length 9 is also accepted)
    :return: tuple (high_block, low_block)
    :raises ValueError: if x does not have exactly 9 characters, or a
        character code is outside 0..255 (it would not fit the 72-bit layout)
    """
    if len(x) != 9:
        raise ValueError('Block string must have exactly 9 characters')
    n = 0
    for c in x:
        # Python 3 bytes iterate as ints; str iterates as 1-char strings.
        code = c if isinstance(c, int) else ord(c)
        if not 0 <= code <= 0xFF:
            raise ValueError('Character code %d does not fit in one byte' % code)
        n = (n << 8) | code
    return n >> 36, n & 0xFFFFFFFFF

def blockstostr(y, z):
    """Join two 36-bit blocks into a 9-character string (inverse of strtoblocks).

    The 72-bit value (y << 36) | z is emitted as 9 characters, most
    significant byte first.

    :param y: high 36-bit block
    :param z: low 36-bit block
    :return: string of 9 characters with codes in 0..255
    :raises ValueError: if a block is negative or wider than 36 bits
        (previously such values silently produced a garbled string)
    """
    for block in (y, z):
        if not 0 <= block <= 0xFFFFFFFFF:
            raise ValueError('Block must be a non-negative 36-bit integer')
    n = (y << 36) | z
    return ''.join(chr((n >> (8 * (8 - i))) & 0xFF) for i in range(9))



class easy1:
    """EASY1 block cipher operating on 36-bit integer blocks.

    Each round splits the block into six 6-bit digits and passes each digit
    through the S-box. It then applies the 36-bit P-box and XORs every digit
    with the matching 6-bit digit of the key. Decryption applies the inverse
    steps in reverse order.

    :param key: integer key, consumed as six 6-bit digits by mix()
    :param Nrounds: number of rounds applied to each block
    """

    def __init__(self, key, Nrounds):
        self.k = key       # key XORed into the block at every round
        self.N = Nrounds   # number of rounds per block

    def round(self, x, k):
        """Apply one encryption round to block x using key k.

        Steps: S-box on each 6-bit digit, P-box, then XOR with k.
        """
        # Build a list rather than a lazy map object so mux() works on
        # both Python 2 and Python 3.
        substituted = mux([sbox(digit) for digit in demux(x)])
        permuted = demux(pbox(substituted))
        # Use the k argument (not self.k) so round() mirrors unround().
        return mux(mix(permuted, k))

    def unround(self, c, k):
        """Invert round(): XOR with k, inverse P-box, then inverse S-box."""
        unmixed = mix(demux(c), k)
        unpermuted = demux(apbox(mux(unmixed)))
        # 'digit' rather than 's', so the module-level S-box is not shadowed.
        return mux([asbox(digit) for digit in unpermuted])

    def encrypt_block(self, x):
        """Encrypt one 36-bit block: self.N rounds under self.k."""
        for _ in range(self.N):
            x = self.round(x, self.k)
        return x

    def decrypt_block(self, c):
        """Decrypt one 36-bit block: self.N inverse rounds under self.k."""
        x = c
        for _ in range(self.N):
            x = self.unround(x, self.k)
        return x

    def _process(self, text, block_fn):
        """Apply block_fn to both 36-bit blocks of every 9-character chunk.

        Each chunk of 9 characters is 72 bits, i.e. two 36-bit blocks, so the
        text length must be a multiple of 9.

        :raises ValueError: if len(text) is not a multiple of 9
        """
        if len(text) % 9:
            raise ValueError('Input length must be multiple of 9')
        chunks = []
        # Step through the text by 9 characters. The old
        # range(len(t)/9) fails on Python 3, where '/' is float division.
        for start in range(0, len(text), 9):
            high, low = strtoblocks(text[start:start + 9])
            chunks.append(blockstostr(block_fn(high), block_fn(low)))
        return ''.join(chunks)

    def encrypt(self, t):
        """Encrypt a string whose length is a multiple of 9 characters."""
        return self._process(t, self.encrypt_block)

    def decrypt(self, c):
        """Decrypt a string whose length is a multiple of 9 characters."""
        return self._process(c, self.decrypt_block)

        
            
if __name__ == "__main__":
    # Demo: a 36-bit key, two rounds, block and string round trips.
    # print(...) with a single argument behaves the same on Python 2 and 3.
    K = 0b111100001111000011111100001111000011
    e = easy1(K, 2)
    y = e.encrypt_block(0xffff)
    print(y)
    x = e.decrypt_block(y)
    print(hex(x))
    z = e.encrypt_block(0x1111)
    print(z)
    t = 'abcdefghi123456789' * 4
    c = e.encrypt(t)
    # The ciphertext may contain non-printable characters, so show its repr.
    print(repr(c))
    d = e.decrypt(c)
    print(d)