from easy1 import *

def bitsum(n):
    """Return the parity of the '1' digits in the binary form of n.

    The result is 1 when bin(n) contains an odd number of '1' digits and
    0 when it contains an even number.
    """
    # The '0b' (or '-0b') prefix produced by bin() contains no '1', so
    # counting over the whole string gives the same result as slicing it off.
    ones = bin(n).count('1')
    return ones & 1

def lin(S, i, j, x):
    """Evaluate one linear relation of an S-box at a single input.

    S: S-box (called as S(x)), i: input mask, j: output mask, x: input.

    Returns the XOR of the parity of the masked input (i & x) and the
    parity of the masked output (j & S(x)); the result is 0 exactly when
    the two parities agree.
    """
    input_parity = bitsum(i & x)
    output_parity = bitsum(j & S(x))
    return input_parity ^ output_parity

def linappr(S, l, m):
    """Build the linear approximation table of an S-box.

    S: S-box (called as S(x) for every x in range(2**l)),
    l: input length in bits, m: output length in bits.

    Returns a dict keyed by (input mask, output mask) for every mask pair
    in range(2**l) x range(2**m). Each value is the number of inputs x for
    which the parity of (input mask & x) equals the parity of
    (output mask & S(x)), minus 2**(l-1), i.e. the signed deviation from
    half of the 2**l inputs.

    S is evaluated once per input and the results are reused for every
    mask pair, so S is assumed to be deterministic.
    """
    inputs = range(2**l)
    half = 2**(l-1)

    # S(x) does not depend on the masks: compute it once instead of once
    # per (input mask, output mask, x) triple.
    outputs = [S(x) for x in inputs]

    T = {}
    for i in inputs:
        # The input-side parity depends only on i and x, not on j.
        in_parity = [bitsum(i & x) for x in inputs]
        for j in range(2**m):
            agree = sum(
                1 for x in inputs
                if in_parity[x] == bitsum(j & outputs[x])
            )
            T[(i, j)] = agree - half
    return T




# Linear approximation table for `sbox` with 6 input bits and 6 output bits.
# NOTE(review): `sbox` is presumably provided by the easy1 star import -- confirm.
T = linappr(sbox,6,6)

# Table entries as ((input mask, output mask), value) tuples, ordered by
# decreasing absolute value. sorted() with a key function runs on both
# Python 2 and Python 3, whereas calling .sort(cmp_function) on
# dict.items() only works on Python 2. The sort is stable, so entries with
# equal absolute value keep the order in which T.items() yields them.
L = sorted(T.items(), key=lambda entry: abs(entry[1]), reverse=True)