from easy1 import *
def bitsum(n):
    """Return the parity (0 or 1) of the number of set bits in ``n``.

    Despite the name, this is not the full bit count: only the count
    modulo 2 is returned.
    """
    ones = bin(n).count("1")  # the "0b" / "-0b" prefix holds no '1'
    return ones & 1
def lin(S, i, j, x):
    """Evaluate one linear relation of S-box ``S`` at input ``x``.

    Args:
        S: the S-box, called as ``S(x)``.
        i: input mask.
        j: output mask.
        x: input value.

    Returns:
        ``bitsum(i & x) ^ bitsum(j & S(x))``: 0 when the masked input
        parity equals the masked output parity, otherwise 1 (given that
        ``bitsum`` returns a parity bit).
    """
    input_parity = bitsum(i & x)
    output_parity = bitsum(j & S(x))
    return input_parity ^ output_parity
def linappr(S, l, m):
    """Build the linear approximation table of S-box ``S``.

    Args:
        S: the S-box, called as ``S(x)`` for every ``x`` in ``range(2**l)``.
        l: input length in bits (input masks range over ``range(2**l)``).
        m: output length in bits (output masks range over ``range(2**m)``).

    Returns:
        dict mapping ``(i, j)`` (input mask, output mask) to
        ``count - 2**(l-1)``. Here ``count`` is the number of inputs ``x``
        for which ``parity(i & x) == parity(j & S(x))``. A value of 0 means
        the relation holds for exactly half the inputs, i.e. there is no bias.
    """

    def _parity(v):
        # Parity of the popcount of v (same result as the file's bitsum()).
        return bin(v).count("1") & 1

    n_inputs = 2 ** l
    # S(x) does not depend on the masks, so evaluate the S-box once per
    # input instead of once per (i, j, x) triple.
    outputs = [S(x) for x in range(n_inputs)]
    T = {}
    for i in range(n_inputs):
        # The input-side parity depends only on (i, x), so reuse it for all j.
        in_par = [_parity(i & x) for x in range(n_inputs)]
        for j in range(2 ** m):
            matches = sum(
                in_par[x] == _parity(j & y) for x, y in enumerate(outputs)
            )
            T[(i, j)] = matches - 2 ** (l - 1)
    return T
# Build the linear approximation table for the S-box imported from easy1,
# using 6-bit input and output masks. NOTE(review): assumes `sbox` is
# callable on 6-bit inputs, as linappr calls it as S(x); confirm in easy1.
T = linappr(sbox, 6, 6)
# Rank (mask pair, bias) items by |bias|, strongest first. sorted() works on
# both Python 2 and 3. In Python 3, dict.items() returns a view without
# .sort(), and list.sort() no longer accepts a cmp function.
L = sorted(T.items(), key=lambda item: abs(item[1]), reverse=True)