/ attacks / attack-4way.py
attack-4way.py
  1  #!/usr/bin/env python
  2  
  3  import angr, claripy, IPython, sys, time, datetime, logging
  4  from claripy import BVS, BVV, Solver, Or, RotateLeft, RotateRight, And, If
  5  from angr.block import CapstoneInsn
  6  #from angr.procedures.libc.memset import memset
  7  #from angr.procedures.libc.memcpy import memcpy
  8  from binascii import unhexlify
  9  from inspect import getframeinfo, stack
 10  
 11  l = logging.getLogger("angr")
 12  
 13  # silence some annoying logs
 14  l.setLevel("ERROR")
 15  
 16  from itertools import zip_longest
 17  def split_by_n(iterable, n):
 18      return zip_longest(*[iter(iterable)]*n, fillvalue='')
 19  
# 128-entry substitution table. get_sboxes() cuts it into 8 rows of 16
# entries, one row per cipherblock byte position in sbt()'s SBOX step. Within a
# row, the high nibbles and the low nibbles of the entries form two separate
# 4-bit -> 4-bit tables. sbt() feeds the "hi" table with the byte's high
# nibble and the "lo" table with its low nibble.
SBOX = [0xF4, 0xAF, 0x8A, 0xD1, 0x3B, 0x02, 0xE8, 0x20,   0xCD, 0x65, 0x96, 0x1C, 0x47, 0xB3, 0x79, 0x5E,
        0x18, 0x8B, 0xE3, 0xAE, 0x7D, 0x4A, 0x94, 0xDF,   0x69, 0x30, 0xBC, 0x56, 0xF5, 0x07, 0x21, 0xC2,
        0xBD, 0x72, 0x9C, 0x59, 0xAE, 0x17, 0xF3, 0x61,   0x24, 0xC8, 0x40, 0xDF, 0xE6, 0x8A, 0x35, 0x0B,
        0x27, 0x4D, 0x56, 0xC8, 0x91, 0xB3, 0x70, 0x84,   0xF5, 0xEF, 0xD2, 0xAE, 0x3A, 0x1C, 0x0B, 0x69,
        0x47, 0x9F, 0x80, 0x5C, 0x0A, 0x68, 0xA1, 0xEB,   0xB9, 0x2D, 0x75, 0xF3, 0x1E, 0x32, 0xD6, 0xC4,
        0xB3, 0xAE, 0xED, 0x09, 0x91, 0xD4, 0x38, 0x26,   0x6A, 0xC0, 0xFB, 0x75, 0x82, 0x5F, 0x4C, 0x17,
        0x59, 0x27, 0x16, 0x4D, 0xDB, 0xEF, 0x04, 0x9C,   0xF0, 0xB8, 0x62, 0xCE, 0x3A, 0xA1, 0x73, 0x85,
        0x18, 0x5D, 0x47, 0x6E, 0xC5, 0xA0, 0x9B, 0xFA,   0x32, 0xE3, 0x8C, 0x01, 0xDF, 0x74, 0x29, 0xB6]
 28  
 29  lupb = False
 30  
 31  threadid=None
 32  if len(sys.argv) == 2:
 33      if 0<= int(sys.argv[1])< 4:
 34          threadid=int(sys.argv[1])
 35      else:
 36          print("thread id must be between 0 & 3")
 37          sys.exit(1)
 38  
 39  asserted = set()
 40  def assert_once(v):
 41      caller = getframeinfo(stack()[1][0])
 42      if caller.lineno in asserted: return
 43      asserted.add(caller.lineno)
 44      assert v
 45  
 46  def concat(l):
 47      l = list(l)
 48      return l[0].concat(*l[1:])
 49  
 50  def set_byte(array, pos, val):
 51      if pos < ((len(array) // 8) - 1):
 52          return concat((array.get_bytes(0,pos), val, array.get_bytes(pos+1,7-pos)))
 53      else:
 54          return concat((array.get_bytes(0,pos), val))
 55  
 56  def get_sboxes(sbox):
 57      return [{'hi': tuple((x>>4) for x in t),
 58              'lo': tuple((x&0xf) for x in t)}
 59              for t in split_by_n(sbox,16)]
 60  
 61  def moebius(f,n):
 62      blocksize=1
 63      for step in range(1,n+1):
 64          source=0
 65          while(source < (1<<n)):
 66              target = source + blocksize
 67              for i in range(blocksize):
 68                  f[target+i]^=f[source+i]
 69              source+=2*blocksize
 70          blocksize*=2
 71  
 72  def get_anfs():
 73     mappings = get_sboxes(SBOX)
 74     res = {}
 75  
 76     fs=[{'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 77         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 78         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 79         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 80         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 81         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 82         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 83         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]}]
 84  
 85     ms=[{'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 86         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 87         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 88         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 89         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 90         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 91         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 92         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]}]
 93  
 94     for p, m in enumerate(mappings):
 95        #print("position", p)
 96        for nib in ('hi','lo'):
 97           #print("nibble", nib)
 98           for bit in range(4):
 99               #print("bit", bit)
100               ones=0
101               for i in range(16):
102                   b = (m[nib][i] >> bit) & 1
103                   fs[p][nib][bit][i]=b
104                   ms[p][nib][bit][i]=b
105                   ones+=b
106               #print(''.join(f"{b}" for b in fs[p][nib][bit]), ones)
107               moebius(ms[p][nib][bit],4)
108               #print(''.join(f"{b}" for b in ms[p][nib][bit]), ms[p][nib][bit].count(1))
109               #print()
110  
111     # these have odd number of 1 constant terms in their polinomial
112     odd_const_terms = [(0, 'hi', 0), (0, 'hi', 1), (0, 'hi', 2), (0, 'hi', 3), (0, 'lo', 2),
113                        (1, 'hi', 0), (1, 'lo', 3), (2, 'hi', 0), (2, 'hi', 1), (2, 'hi', 3),
114                        (2, 'lo', 0), (2, 'lo', 2), (2, 'lo', 2), (2, 'lo', 3), (3, 'hi', 1),
115                        (3, 'lo', 0), (3, 'lo', 1), (3, 'lo', 2), (4, 'hi', 2), (4, 'lo', 0),
116                        (4, 'lo', 1), (4, 'lo', 2), (5, 'hi', 0), (5, 'hi', 1), (5, 'hi', 3),
117                        (5, 'lo', 0), (5, 'lo', 1), (6, 'hi', 0), (6, 'hi', 2), (6, 'lo', 0),
118                        (6, 'lo', 3), (7, 'hi', 0), (7, 'lo', 3)]
119     for p in range(8):
120        for nib in ('hi','lo'):
121           for j, (f, m) in enumerate(zip(fs[p][nib],ms[p][nib])):
122               f_ = ' ^ '.join(c for c in ['&'.join(f"x{i}" for i, x in enumerate(reversed(f'{a:04b}')) if x =='1') for a in range(16) if m[a]==1] if c)
123               # these have odd number of 1 constant terms in their polinomial
124               if (p,nib,j) in odd_const_terms:
125                   f_ = '1 ^ ' + f_
126               #print((p,nib,j), f_)
127               evaled =''.join(str(eval(f_, {f"x{i}":int(x) for i, x in enumerate(reversed(f'{a:04b}'))})) for a in range(16))
128               #print('joined', ''.join(f"{b}" for b in f))
129               #print('evaled', evaled)
130               assert evaled==''.join(f"{b}" for b in f)
131  
132           f_ = []
133           for j, (f, m) in enumerate(zip(fs[p][nib],ms[p][nib])):
134               f_.append(' ^ '.join('('+c+')' for c in ['&'.join(f"x[{i}]" for i, x in enumerate(reversed(f'{a:04b}')) if x =='1') for a in range(16) if m[a]==1] if c))
135               if (p,nib,j) in odd_const_terms:
136                   f_[-1] = '1 ^ ' + f_[-1]
137           #print((p,nib), ', '.join(reversed(f_)))
138           res[(p,nib)] = ', '.join(reversed(f_))
139     return res
140  
def map4x4(anfs, pos, nib, x):
    """Evaluate the S-box polynomial for ``(pos, nib)`` on the input bits *x*.

    *anfs* maps ``(pos, nib)`` to an expression string as built by get_anfs().
    *x* is anything indexable by bit number (a tuple of ints or a 4-bit
    bit-vector). Returns whatever the expression evaluates to. For get_anfs()
    output, that is the four output bits as a tuple, most significant first.
    """
    expression = anfs[pos, nib]
    return eval(expression, {"x": x})
144  
def boxperm(i, k, tmp):
    """Concrete integer reference model of one box-permutation step.

    Args:
        i: rule selector 0..3 (any other value behaves like 3).
        k: the nibble being rewritten.
        tmp: the partner nibble that gets mixed in.

    Returns:
        The new nibble value, always in 0..15.

    Each rule normally just steps *k* (-4, -1, +4, +1). Only when *k* sits on
    that rule's edge (bits 2-3 clear, bits 0-1 clear, bits 2-3 set, bits 0-1
    set) is the partner nibble mixed in instead.
    """
    def ror2(v):
        # 8-bit rotate right by two (== rotate left by six)
        return ((v << 6) | (v >> 2)) & 0xff

    low2 = tmp & 3
    if i == 0:
        if k & 0xc:
            return (k - 4) & 0xf
        return ((ror2(tmp) + k + low2) & 3) | 0xc              # 1f6f .. 1f7e
    if i == 1:
        if k & 3:
            return (k - 1) & 0xf
        mixed = (ror2(tmp) + ror2(k) + low2) & 3               # 1fa2 .. 1fbb
        return (mixed << 2) | 3                                # 1fbd .. 1fc8
    if i == 2:
        if k & 0xc != 0xc:
            return (k + 4) & 0xf
        return (ror2(tmp) + (k & 3) + low2) & 3                # 1fe3 .. 2005
    if k & 3 != 3:
        return (k + 1) & 0xf
    mixed = (ror2(tmp) + ((k >> 2) & 3) + low2) & 3            # 202c .. 2048
    return mixed << 2                                          # 204a .. 204c
171  
def get_boxperm_anfs(table=None):
    """Compute the ANF of the 10-bit -> 4-bit box-permutation lookup table.

    Args:
        table: 1024-entry sequence of 4-bit outputs indexed by a 10-bit input.
            It defaults to the module-level ``bpsbox``, which is filled as
            ``boxperm(i, k, tmp)`` with ``i`` outermost and ``tmp`` innermost.
            Index bits 0-3 are therefore ``tmp``, 4-7 ``k`` and 8-9 ``i``.

    Returns:
        A string ``"b3, b2, b1, b0"`` of four XOR-of-AND polynomials over
        ``x[0] .. x[9]``, most significant output bit first. It is meant to be
        ``eval``-ed with ``x`` bound to the input bits (see map_boxperm()).

    Each polynomial is checked against *table* on all 1024 inputs before it
    is returned. A mismatch raises AssertionError.
    """
    if table is None:
        table = bpsbox
    nvars = 10
    size = 1 << nvars

    def monomials(coeffs, var):
        # Non-constant monomials with coefficient 1; bit i of a selects x_i.
        return ['&'.join(var.format(i) for i in range(nvars) if (a >> i) & 1)
                for a in range(1, size) if coeffs[a] == 1]

    def with_const(expr, coeffs):
        # The constant term is coeffs[0] (the output at input 0). Deriving it
        # replaces the hard-coded list of "odd constant term" bits.
        if coeffs[0] == 1:
            return '1 ^ ' + expr if expr else '1'
        return expr or '0'

    exprs = []
    for bit in range(4):
        truth = [(table[a] >> bit) & 1 for a in range(size)]
        coeffs = list(truth)
        moebius(coeffs, nvars)

        # self-check: evaluate the polynomial on every input
        check = with_const(' ^ '.join(monomials(coeffs, 'x{}')), coeffs)
        evaled = ''.join(
            str(eval(check, {f"x{i}": (a >> i) & 1 for i in range(nvars)}))
            for a in range(size))
        assert evaled == ''.join(str(b) for b in truth)

        exprs.append(with_const(
            ' ^ '.join('(' + c + ')' for c in monomials(coeffs, 'x[{}]')),
            coeffs))
    # most significant output bit first
    return ', '.join(reversed(exprs))
211  
def map_boxperm(anf, x):
    """Evaluate the box-permutation expression string *anf* on input bits *x*.

    *x* must be indexable by bit number. For get_boxperm_anfs() output, the
    result is a 4-tuple of bits, most significant first.
    """
    env = {"x": x}
    return eval(anf, env)
214  
def getFuncAddress(cfg, funcName, plt=None):
    """Return the address of the first function in *cfg* named *funcName*.

    When *plt* is not None, only functions whose ``is_plt`` flag equals *plt*
    qualify. A hit is logged through ``l``. If nothing matches, a plain
    Exception is raised.
    """
    for addr, func in cfg.kb.functions.items():
        if funcName != func.name:
            continue
        if plt is not None and func.is_plt != plt:
            continue
        l.info("Found " + funcName + "'s address at " + hex(addr) + "!")
        return addr
    raise Exception("No address found for function : " + funcName)
225  
def getRetAddr(proj, fn, max_len=1000):
    """Return the address of the first ``ret`` instruction at or after *fn*.

    Disassembles up to *max_len* bytes of the loaded image starting at address
    *fn*, using the project's capstone instance.

    Raises:
        ValueError: no ``ret`` was found within the disassembled window.
    """
    # let's disasm with capstone to search targets
    insn_bytes = proj.loader.memory.load(fn, max_len)
    for cs_insn in proj.arch.capstone.disasm(insn_bytes, fn):
        ins = CapstoneInsn(cs_insn)
        if ins.mnemonic == "ret":
            l.info(f"Found lfsr's return address at 0x{ins.address:x}!")
            return ins.address
    # f-string, so the searched address actually appears in the message
    raise ValueError(f"failed to find ret op in 0x{fn:x}")
235  
236  
def boxperm0(nibble_swap_index, cipherblock, thingy):
    """Symbolic form of the i == 0 rule of boxperm().

    sbt() uses it when the current nibble has bits 2-3 clear.

    Args:
        nibble_swap_index: index of the nibble being rewritten. sbt() walks
            0..15; an even index selects the high nibble of byte
            ``index >> 1``, an odd one the low nibble.
        cipherblock: the block bit-vector, read with get_byte().
        thingy: the current nibble value as a bit-vector.

    Returns:
        A bit-vector expression with value in 0xc..0xf. Its low two bits mix
        the current nibble with the partner nibble at ``nibble_swap_index ^ 8``
        (four bytes away).

    The trailing hex numbers look like the firmware addresses this code
    transliterates (TODO confirm).
    """
    # fetch the partner nibble (8 nibbles away)
    tmp = cipherblock.get_byte((nibble_swap_index ^ 8) >> 1)
    if not (nibble_swap_index & 1):
        tmp = tmp.LShR(4)
    tmp &= 0xf
    thingy += tmp & 3 # 1f6f
    #return concat((claripy.BVV(3,6), thingy[1:0]^tmp[3:2]))
    # (ror2(tmp) + thingy) & 3, with the upper two bits forced to 1
    return ((((tmp << 6) | (tmp.LShR(2))) + thingy) & 3) | 0xc; # 1f78 .. 1f7e
245  
def boxperm1(nibble_swap_index, cipherblock, thingy):
    """Symbolic form of the i == 1 rule of boxperm().

    sbt() uses it when the current nibble has bits 0-1 clear.

    The partner nibble is ``nibble_swap_index ^ 2`` (one byte away); an even
    index selects the high nibble of byte ``index >> 1``. *thingy* is the
    current nibble as a bit-vector. The result has its low two bits set and
    bits 2-3 mixed from the rotated current and partner nibbles, i.e. one of
    0x3, 0x7, 0xb, 0xf.
    """
    tmp = cipherblock.get_byte((nibble_swap_index ^ 2) >> 1)
    if not (nibble_swap_index & 1):
        tmp = tmp.LShR(4)       # 1fb3
    tmp &= 0xf          # 1fb4

    # rotate the current nibble's byte right by two, then add the partner's low bits
    thingy = ((thingy << 6) | (thingy.LShR(2))) & 0xff # 1fa2 .. 1fa6
    thingy += tmp & 3                               # 1fb7, 1fbb
    tmp = (((tmp << 6) | (tmp.LShR(2))) + thingy) & 3
    # move the 2-bit mix up to bits 2-3 and set bits 0-1
    return ((tmp.LShR(6) | (tmp << 2)) & 0xff) | 3   # 1fbd .. 1fc8
256  
def boxperm2(nibble_swap_index, cipherblock, thingy):
    """Symbolic form of the i == 2 rule of boxperm().

    sbt() uses it when the current nibble has bits 2-3 both set.

    The partner nibble is ``nibble_swap_index ^ 4`` (two bytes away); an even
    index selects the high nibble. *thingy* is the current nibble as a
    bit-vector. The result is a value in 0..3 that mixes the current nibble's
    low bits with the partner nibble.
    """
    tmp = cipherblock.get_byte((nibble_swap_index ^ 4) >> 1) # 1fe9 .. 1ff1
    if not (nibble_swap_index & 1):
        tmp = tmp.LShR(4) # 1ff4
    tmp &= 0xf          # 1ff5

    thingy &= 3         # 1fe3 .. 1fe7
    thingy += tmp & 3   # 1ff8, 1fbb
    #return concat((claripy.BVV(0,6), thingy[1:0]^tmp[3:2]))
    return ((((tmp << 6) | (tmp.LShR(2))) + thingy) & 3) & 0xff # 1ffe .. 2005
267  
def boxperm3(nibble_swap_index, cipherblock, thingy):
    """Symbolic form of the i == 3 rule of boxperm().

    sbt() uses it when the current nibble has bits 0-1 both set.

    The partner is the other nibble of the same byte (``nibble_swap_index ^ 1``).
    That is why the shift condition is inverted compared to boxperm0..2: an odd
    index's partner is the high nibble. *thingy* is the current nibble as a
    bit-vector. The result is the 2-bit mix placed in bits 2-3, i.e. one of
    0x0, 0x4, 0x8, 0xc.
    """
    tmp = cipherblock.get_byte((nibble_swap_index ^ 1) >> 1)  # .. 202e..2036
    if nibble_swap_index & 1:
        tmp = tmp.LShR(4) #2039
    tmp &= 0xf          # 203a

    # bits 2-3 of the current nibble, plus the partner's low bits
    thingy = (thingy.LShR(2)) & 3  # 202c
    thingy += tmp & 3           # 203d, 2041
    tmp = ((((tmp << 6) | (tmp.LShR(2))) + thingy) & 3) # 2043 .. 2048
    return ((tmp << 2) | (tmp.LShR(6))) & 0xff # 204a .. 204c
278  
def sbt(s, ctrlvar, state, lupb=False):
    """Symbolically model one invocation of the "sbt" block transformation.

    Everything is built from claripy bit-vector operations. The result is
    therefore an expression over whatever symbols *ctrlvar* / *state* contain.

    Args:
        s: claripy Solver; only referenced by the commented-out debug checks.
        ctrlvar: 56-bit control value (bits 55..0). Each round rotates both
            28-bit halves right, by 5 or by 2. The amounts come from bit 7 of
            ctrlvar_iter_ctrl and run 5,2,2,5,5,5,2,2 (28 in total), so after
            the 8 rounds each half is back to its starting value.
        state: 8-byte input, read with ``get_byte(0..7)``.
        lupb: if True, the per-nibble box permutation is modelled with the
            ANF in the module-level ``bp_anf``. Otherwise it uses nested
            ``If`` expressions over boxperm0..boxperm3.

    Returns:
        ``(ctrlvar, cipherblock, state)``: the control value after the 8
        rotations, the 64-bit output block, and the 64-bit value produced by
        the initial shift/XOR step. That value supplies the per-round
        nibble-swap control words and is rotated one byte per round, so it is
        back where it started after 8 rounds.

    Side effects: prints timing lines and updates the global ``prev``.
    """
    global prev
    cur = time.time()
    delta = datetime.timedelta(seconds=cur - prev)
    prev = cur
    print(f"[s] {delta} running core sbt")

    # Step 1: eight passes, each dropping byte 0 and appending a byte mixed
    # from bytes 0/4 (shifted right) and 1/5 (only bit 0 survives the 8-bit << 7).
    cipherblock = concat(state.get_byte(i) for i in range(8))
    for i in range(8):
        tmp = ((cipherblock.get_byte(0) ^ cipherblock.get_byte(4)).LShR(1)) | ((cipherblock.get_byte(1) ^ cipherblock.get_byte(5)) << 7)
        cipherblock = concat((cipherblock.get_bytes(1,7), tmp))
    # 1d3c
    #assert_once(s.solution(cipherblock, 0x86820280c3c181c0, extra_constraints=[key == b'\x01'*15]))
    # this value becomes the source of the per-round control words
    state = cipherblock

    # step 2 fixed bit permutation
    # (each pair is (byte, bit) of the source; the first pair becomes the MSB)
    cipherblock = concat((
        cipherblock.get_byte(by)[bi]
        for by, bi in ((4,4),(1,5),(3,1),(7,6),(0,5),(6,0),(5,7),(2,3),
                       (3,2),(0,7),(7,1),(1,0),(6,3),(4,5),(5,4),(2,0),
                       (7,5),(3,3),(1,3),(5,1),(0,3),(2,4),(6,2),(4,1),
                       (5,5),(3,0),(0,1),(4,3),(1,6),(6,7),(2,2),(7,3),
                       (2,5),(6,4),(4,7),(0,6),(5,6),(7,7),(3,5),(1,2),
                       (6,5),(4,2),(3,6),(5,2),(1,7),(2,6),(7,4),(0,2),
                       (5,0),(7,2),(1,4),(3,4),(4,6),(6,1),(0,0),(2,1),
                       (6,6),(0,4),(1,1),(7,0),(3,7),(2,7),(4,0),(5,3))
    ))
    #assert_once(s.solution(cipherblock, 0x164001242c09892a, extra_constraints=[key == b'\x01'*15]))

    # bit 7 of this byte picks the ctrlvar rotation amount; it is rotated left every round
    ctrlvar_iter_ctrl = 0x9c
    for round in range(8):
        cur = time.time()
        delta = datetime.timedelta(seconds=cur - prev)
        prev = cur
        print(f"[{round}] {delta} running sbt round")

        # rotate both 28-bit halves of ctrlvar right by 5 (bit 7 set) or by 2
        # see updctrlvar.py
        if (ctrlvar_iter_ctrl & 0x80): # 5
            ctrlvar = concat((ctrlvar[32:28], ctrlvar[55:33], ctrlvar[4:0], ctrlvar[27:5]))
        else: #2
            ctrlvar = concat((ctrlvar[29:28],ctrlvar[55:30],ctrlvar[1:0],ctrlvar[27:2]))

        #re = s.eval(ctrlvar,1, extra_constraints=[key == b'\x01'*15])[0]
        #print(f"cv: {re:014x}")

        ctrlvar_iter_ctrl = ((ctrlvar_iter_ctrl << 1) | (ctrlvar_iter_ctrl >> 7)) & 0xff # 1e5c .. 1e5f

        #re = s.eval(state,1, extra_constraints=[key == b'\x01'*15])[0]
        #print(f"state0: {re:016x}")

        # control words = state bytes 3,2,1,0 (taken before the rotation below)
        nibble_swap_ctrl_words = concat((state.get_byte(3-i) for i in range(4)))

        #re = s.eval(nibble_swap_ctrl_words,1, extra_constraints=[key == b'\x01'*15])[0]
        #print(f"nscw0: {re:08x}")

        # rotate state right by one byte
        state = concat((state.get_byte(7),state.get_bytes(0,7)))

        #re = s.eval(state,1, extra_constraints=[key == b'\x01'*15])[0]
        #print(f"state1: {re:016x}")

        # XOR every control word with 8 selected ctrlvar bits
        nibble_swap_ctrl_words = concat((
            (nibble_swap_ctrl_words.get_byte(0) ^ concat((ctrlvar.get_byte(by)[bi] for by,bi in
                                                        ((1,5), (4,1), (1,2), (5,6), (2,7), (5,3), (2,4), (5,0))))),
            (nibble_swap_ctrl_words.get_byte(1) ^ concat((ctrlvar.get_byte(by)[bi] for by,bi in
                                                        ((2,1), (6,5), (3,6), (6,2), (0,7), (3,3), (0,4), (3,0))))),
            (nibble_swap_ctrl_words.get_byte(2) ^ concat((ctrlvar.get_byte(by)[bi] for by,bi in
                                                        ((0,1), (4,5), (1,6), (4,2), (1,3), (5,7), (1,0), (5,4))))),
            (nibble_swap_ctrl_words.get_byte(3) ^ concat((ctrlvar.get_byte(by)[bi] for by,bi in
                                                        ((2,5), (5,1), (2,2), (6,6), (3,7), (6,3), (3,4), (6,0))))) # 1f2c
        ))
        #re = s.eval(nibble_swap_ctrl_words,1, extra_constraints=[key == b'\x01'*15])[0]
        #print(f"nscw: {re:08x}")
        #assert_once(s.solution(nibble_swap_ctrl_words, 0x840a8a82, extra_constraints=[key == b'\x01'*15]))

        # Box permutation over all 16 nibbles: each control word serves 4
        # consecutive nibbles, and its top two bits (rotated in two at a time)
        # select which boxperm rule rewrites the current nibble.
        nibble_swap_index = 0
        for i in range(4):
            nibble_swap_ctrl = nibble_swap_ctrl_words.get_byte(i)
            while(True):
                #re = s.eval(nibble_swap_ctrl,1, extra_constraints=[key == b'\x01'*15])[0]
                #print(f"nibble_swap_ctrl: {re:02x} {re>>6}")
                # current nibble: even index = high nibble of byte index>>1
                tmp = cipherblock.get_byte(nibble_swap_index >> 1)
                #re = s.eval(tmp,1, extra_constraints=[key == b'\x01'*15])[0]
                #print(f"nsi: {0x2d+(nibble_swap_index >> 1):02x} {re:02x} ")
                if not (nibble_swap_index & 1):
                    tmp = tmp.LShR(4)
                tmp &= 0xf
                thingy = tmp

                #re = s.eval(thingy,1, extra_constraints=[key == b'\x01'*15])[0]
                #print(f"thingy: {re:02x}")

                #re = s.eval(tmp,1, extra_constraints=[key == b'\x01'*15])[0]
                #print(f"tmp: {re:02x}")

                # rule selector 0..3
                tmp1 = nibble_swap_ctrl.LShR(6)

                #tmp1map = {8241}[tmp1]
                #tmpindex = concat(( (tmp1[0]^1) & (tmp1[1]^1),
                #                    (tmp1[0]^1) &  tmp1[1],
                #                     tmp1[0]    & (tmp1[1]^1),
                #                     tmp1[0]    &  tmp1[1]))
                #tmpindex = BVV(0,4).concat((nibble_swap_index ^ tmpindex) >> 1)
                #tmpindex=((64 + 7) // 8 - 1 - tmpindex)*8+7
                #re = s.eval(tmp1,1, extra_constraints=[key == b'\x01'*15])[0]
                #re1 = s.eval(tmpindex,1, extra_constraints=[key == b'\x01'*15])[0]
                #print(f"tmp1: {re:02x} -> {re1:02x}")
                #xtmp = cipherblock.Extract(tmpindex+7, tmpindex, cipherblock)
                #xtmp = cipherblock.LShR(tmpindex)
                if lupb:
                   # pick the partner nibble for the selected rule, then evaluate
                   # the 10-bit -> 4-bit ANF on (selector, current, partner)
                   xtmp = If(tmp1 == 0,
                             cipherblock.get_byte((nibble_swap_index ^ 8) >> 1).LShR(4 * (1 ^ nibble_swap_index&1)),
                          If(tmp1 == 1,
                             cipherblock.get_byte((nibble_swap_index ^ 2) >> 1).LShR(4 * (1 ^ nibble_swap_index&1)),
                          If(tmp1 == 2,
                             cipherblock.get_byte((nibble_swap_index ^ 4) >> 1).LShR(4 * (1 ^ nibble_swap_index&1)),
                             cipherblock.get_byte((nibble_swap_index ^ 1) >> 1).LShR(4 * (nibble_swap_index&1)))))

                   #ri = s.eval(tmp1[1:0],1, extra_constraints=[key == b'\x01'*15])[0]
                   #rtmp = s.eval(xtmp[3:0],1, extra_constraints=[key == b'\x01'*15])[0]
                   #rthingy = s.eval(thingy[3:0],1, extra_constraints=[key == b'\x01'*15])[0]
                   #print(f"{ri:02x} {rtmp:04b} {rthingy:04b}")
                   thingy = BVV(0,4).concat(concat(map_boxperm(bp_anf, concat((tmp1[1:0], thingy[3:0], xtmp[3:0])))))
                else:
                   # same rules as boxperm(): step the nibble unless it sits on
                   # the rule's edge, in which case mix in the partner nibble
                   thingy = If((tmp1) == 0,
                               If((tmp&0xc) == 0,
                                  boxperm0(nibble_swap_index, cipherblock, thingy),
                                  (thingy - 4) & 0xff),
                            If((tmp1) == 1,
                               If((tmp&0x3) == 0,
                                  boxperm1(nibble_swap_index, cipherblock, thingy),
                                  (thingy - 1) & 0xff),
                            If((tmp1) == 2,
                               If((tmp&0xc) == 0xc,
                                  boxperm2(nibble_swap_index, cipherblock, thingy),
                                  (thingy+4) & 0xff),
                               If((tmp&0x3) == 3,
                                  boxperm3(nibble_swap_index, cipherblock, thingy),
                                  (thingy+1) & 0xff)
                               )))

                #re = s.eval(thingy,1, extra_constraints=[key == b'\x01'*15])[0]
                #print(f"thingy: {re:02x}")

                # write the new nibble back: clear its half of byte r1, then OR it in
                r1 = nibble_swap_index >> 1
                tmp = cipherblock.get_byte(r1) # 2058
                if not (nibble_swap_index & 1):
                  tmp &= 0xf
                  cipherblock = set_byte(cipherblock, r1, tmp)  # 205d
                  tmp = thingy
                  tmp = ((tmp << 4) | (tmp.LShR(4))) & 0xff   # 2060
                else:
                  tmp &= 0xf0       # 2063 ..
                  cipherblock = set_byte(cipherblock, r1, tmp)
                  tmp = thingy      # 2066

                cipherblock = set_byte(cipherblock, r1, cipherblock.get_byte(r1) | tmp)

                #re = s.eval(cipherblock,1, extra_constraints=[key == b'\x01'*15])[0]
                #print(f"cb: {re:016x}")
                # bring the next 2-bit selector to the top
                nibble_swap_ctrl = ((nibble_swap_ctrl << 2) | (nibble_swap_ctrl.LShR(6))) & 0xff # 2072 .. 2076
                nibble_swap_index+=1
                if((nibble_swap_index & 3) == 0): break
        #print("eobp ", end='')
        #re = s.eval(cipherblock,1, extra_constraints=[key == b'\x01'*15])[0]
        #print(f"cb: {re:016x}")
        #assert_once(s.solution(cipherblock, 0x52ffec68684dc5de, extra_constraints=[key == b'\x01'*15]))
        # step 6 fix byte permutation
        cipherblock = concat((
            cipherblock.get_byte(3),
            cipherblock.get_byte(5),
            cipherblock.get_byte(1),
            cipherblock.get_byte(4),
            cipherblock.get_byte(6),
            cipherblock.get_byte(0),
            cipherblock.get_byte(7),
            cipherblock.get_byte(2)))

        #print("fbp ", end='')
        #re = s.eval(cipherblock,1, extra_constraints=[key == b'\x01'*15])[0]
        #print(f"cb: {re:016x}")

        # Step 8: a control byte built from ctrlvar bits is rotated right
        # through a carry (cf). Byte i of the block has its nibbles swapped
        # when the carry shifted out is 1. The tmp/tmp1 exchanges look like a
        # transliteration of register xchg instructions; their net effect is
        # that byte i becomes the (maybe swapped) original and tmp is the
        # control byte again.
        #step 8 nibble switch
        tmp = concat((ctrlvar.get_byte(by)[bi] for by, bi in ((3, 5), (6, 1), (0, 6), (3, 2), (0, 3), (4, 7), (0, 0), (4, 4))))
        cf = ctrlvar.get_byte(6)[1]
        for i in range(8):
            tbit = tmp[0];
            tmp = concat((cf, tmp[7:1]))
            cf = tbit
            tmp1 = cipherblock.get_byte(i)
            cipherblock = set_byte(cipherblock,i,tmp)
            tmp = tmp1
            tmp = If(cf == 1,
                     ((tmp & 0xf) << 4) | ((tmp & 0xf0).LShR(4)),
                     tmp)
            tmp1 = cipherblock.get_byte(i)
            cipherblock = set_byte(cipherblock,i,tmp)
            tmp = tmp1

        #re = s.eval(cipherblock,1, extra_constraints=[key == b'\x01'*15])[0]
        #print(f"ns cb: {re:016x}")

        # step 10 SBOXes
        # byte i: high nibble through the (i, 'hi') table, low nibble through (i, 'lo')
        for i in range(8):
            tmp = cipherblock.get_byte(i)
            cipherblock = set_byte(cipherblock,i,concat((map4x4(sbox_anfs,i,"hi",tmp[7:4]) + map4x4(sbox_anfs,i,"lo",tmp[3:0]))))

        #assert_once(s.solution(cipherblock, 0xed770b7518dda32f, extra_constraints=[key == b'\x01'*15]))
        #if round == 3:
        #    assert_once(s.solution(cipherblock, 0x50c7f3128eac6912, extra_constraints=[key == b'\x01'*15]))

        #re = s.eval(cipherblock,1, extra_constraints=[key == b'\x01'*15])[0]
        #print(f"rnd cb: {re:016x}")

    #re = s.eval(cipherblock,1, extra_constraints=[key == b'\x01'*15])[0]
    #print(f"cb: {re:016x}")

    return ctrlvar, cipherblock, state
496  
def unscramble(ct):
   """Decode the letter-coded ciphertext *ct* into a list of 6-bit values.

   *ct* is expected to be bytes laid out like the module-level ciphertext:
   space-separated 5-character groups (characters already masked with 0x3f),
   followed by one terminator byte that ``ct[:-1]`` drops. The first group
   is skipped; the caller treats it as the nonce.

   Each character contributes 4 bits. After subtracting 1 and swapping
   nibbles, four rotate-left-through-carry steps shift its bits out one at a
   time, and each bit is rotated into ``buf[r0]`` with r0 cycling 0, 1, 2.
   After every group the three buffer bytes become three 6-bit outputs.

   The CY / xchg juggling looks like a step-by-step transliteration of
   8051-style ``rlc`` / ``rrc`` / ``xchg`` code (TODO confirm against the
   firmware), so statement order matters.

   Returns:
       list of ints in 0..0x3f, three per group after the first.
   """
   CY = 0
   output = []
   buf = [0,0,0]
   for grp in ct[:-1].split(b' ')[1:]:
     r0 = 0
     for tmp in grp:
        tmp-=1
        tmp = ((tmp >> 4) | (tmp << 4)) & 0xff
        CY = 1
        for j in range(4):
          # rlc a
          tbit = CY
          # CY holds 0x80/0 here (not 1/0); the rrc below ORs it in
          # directly as bit 7, which is exactly what it needs
          CY = tmp & 0x80
          tmp = ((tmp << 1) & 0xff) |  tbit
          # xchg a,intmemabc[r0]
          tmp1 = tmp
          tmp = buf[r0]
          buf[r0] = tmp1
          # rrc a
          tbit = CY
          CY = tmp & 1
          tmp = tbit | (tmp >> 1)
          # xchg a,intmemabc[r0]
          tmp1 = tmp
          tmp = buf[r0]
          buf[r0] = tmp1
          r0 +=1
          # carry as left by comparing r0 with 3 (set while r0 < 3)
          CY = 1 if r0<3 else 0
          if(r0==3): r0=0
     # turn the three buffer bytes into three 6-bit output values
     for r0 in range(3):
         tmp = buf[r0]
         CY = 1 if r0<2 else 0
         if(r0==2):
           # rrc a
           tbit = CY
           CY = tmp & 1
           tmp = (tbit << 7) | (tmp >> 1)
         # rrc a
         tbit = CY
         CY = tmp & 1
         tmp = ((tbit << 7) | (tmp >> 1)) & 0x3f
         output.append(tmp)
   return output
541  
# --- precomputation: ANFs for the S-boxes and for the box permutation -------
prev = start = time.time()
print("[S] 0:00:00.000000 calculating algebraic normal forms for SBOXes...")
sbox_anfs = get_anfs()

cur = time.time()
delta = datetime.timedelta(seconds=cur - prev)
prev = cur
print(f"[A] {delta} verifying correctness of ANF equations (comparing solutions of the equations with the lookup from the sbox table)...")
# Cross-check every 4-bit S-box polynomial against the raw table. Inputs are
# given LSB-first (bits[0] is bit 0); outputs come back MSB-first.
for p, m in enumerate(get_sboxes(SBOX)):
    for nib in ('hi', 'lo'):
        for value in range(16):
            expected = tuple(int(c) for c in f'{m[nib][value]:04b}')
            in_bits = tuple((value >> b) & 1 for b in range(4))
            assert map4x4(sbox_anfs, p, nib, in_bits) == expected

cur = time.time()
delta = datetime.timedelta(seconds=cur - prev)
prev = cur
print(f"[B] {delta} calculating algebraic normal forms for boxperm...")
# 1024-entry table indexed by (rule << 8) | (nibble << 4) | partner nibble
bpsbox = [boxperm(i, k, tmp) for i in range(4) for k in range(16) for tmp in range(16)]
bp_anf = get_boxperm_anfs()

cur = time.time()
delta = datetime.timedelta(seconds=cur - prev)
prev = cur
print(f"[C] {delta} verifying correctness of ANF equations (comparing solutions of the equations with the lookup from the sbox table)...")
for index, value in enumerate(bpsbox):
    expected = tuple(int(c) for c in f'{value:04b}')
    in_bits = tuple((index >> b) & 1 for b in range(10))
    assert map_boxperm(bp_anf, in_bits) == expected
576  
cur = time.time()
delta = datetime.timedelta(seconds=cur - prev)
prev = cur
print(f"[M] {delta} building model (lookup/ifs: {lupb})")

# claripy solver; the timeout and memory limit are passed through unchanged
# (if the timeout is in milliseconds, this is ~200 days).
s = Solver(timeout=17_280_000_000, max_memory=110 * 1024 ** 3)

# key schedule from password key
################################

# Symbolic 15-byte password key. It is only referenced by the commented-out
# debugging checks (here and in sbt()); the attack below solves for ctrlvar
# directly.
key = BVS("k", 15 * 8, explicit_name="k")

# Optional charset restriction for the key:
#for b in key.chop(8):
#    s.add(Or(And(b <= 26, b>0),
#             b == b' ',
#             b == b'.',
#             b == b'-',
#             b == b',',
#             And(b'9' >= b, b >= b'0')))

# Original key-schedule derivation, kept for reference:
# parametric_entry(nonce,key)
# sbt_init(nonce, key)
#tmp = [ciphertext.get_byte(i) ^ key.get_byte(i) if i<3 else key.get_byte(i) for i in range(15)]
#cipherblock = concat(tmp[:8])
#ctrlvar = concat(tmp[8:])
#state = cipherblock
#assert(s.solution(cipherblock, 0x0d04040101010101, extra_constraints=[key == b'\x01'*15]))
# sbt
#ctrlvar, cipherblock, state = sbt(s, ctrlvar, state)
#ctrlvar = concat(cipherblock.get_byte(i) for i in range(7))
610  
# keystream generation
######################

# earlier test vector, kept for reference:
#ciphertext = "BAMLK GAFFJ EOBCP EEIMP CDAFP\xfe"
#testvec_cv0 = 0x19bea3035951f5
#testvec_cbs = [0xa074c5c8ec1a477e, 0xf4e8703e0308a769]
ciphertext="FBFDN MLBPE IJCFF DKDJB LFNNP DMCOM LBMKM GGOKL IJHCO JOEDE LFCCK PIHDP MMHCB DJGNE CGOCL NCAAB JJAIC GPJCH GHBNP FDKNA MGJHP NICPF\xfe"
testvec_cv0 = 0x06157a7cbc4555
testvec_cbs = [0xac18ad815975afd9, 0x69e9fb2c9994938f, 0x2c7c4f499eda0eab, 0xa022bfd6af9bcbed,
               0x2c07677a2e16a042, 0x63b79ece15f3abed, 0x64d45fc6c5998530, 0x3557dcce84846f57]

print(f"[i] ciphertext: {ciphertext[:-1]}")
# keep the low 6 bits of every character; the '\xfe' terminator stays 0xfe
ciphertext = bytes([(ord(c) & 0x3f) if c != '\xfe' else 0xfe for c in ciphertext])
# the first 5-character group is the nonce (unscramble() skips it)
nonce = ciphertext[:5]
ciphertext = unscramble(ciphertext)

# the unknown 56-bit control value the attack solves for
ctrlvar0 = BVS('cv', 7*8, explicit_name='cv')
state    = BVV(bytes([0xf5, 0xc0, 0x7a, 0x10, 0x8a, 0xaf, 0x17, 0xcf]))
# sbt again (honour the global lupb model switch)
ctrlvar, cipherblock, state = sbt(s, ctrlvar0, state, lupb=lupb)
#assert s.solution(cipherblock, 0xa074c5c8ec1a477e, extra_constraints=[ctrlvar0 == testvec_cv0])
state = state.get_bytes(0,5).concat(nonce[:3])
# sbt_init, parametric_entry end.

# One keystream block per 8 ciphertext bytes, rounding up. A trailing
# partial block still gets one, but a length that is an exact multiple of 8
# no longer produces an empty extra block (that made concat() fail when the
# plaintext constraints are built).
nblocks = (len(ciphertext) + 7) // 8
cipherblocks = [cipherblock]
for i in range(1, nblocks):
    ctrlvar, cipherblock, state = sbt(s, ctrlvar, state, lupb=lupb)
    #re = s.eval(cipherblock,1, extra_constraints=[ctrlvar0 == testvec_cv0])[0]
    #assert re == testvec_cbs[i]
    #print(f"[?] block {i} ok")
    cipherblocks.append(cipherblock)
643  
cur = time.time()
delta = datetime.timedelta(seconds=cur - prev)
prev = cur
print(f"[?] {delta} verifing testvector")
# Decrypt with the known keystream blocks of the test vector. Values 0..26
# are letters and get 0x40 added ('A' == 0x41); the rest print as-is.
for block_no, word in enumerate(testvec_cbs):
    keystream = word.to_bytes(8, byteorder="big")
    base = 8 * block_no
    for offset in range(min(8, len(ciphertext) - base)):
        value = (keystream[offset] ^ ciphertext[base + offset]) & 0x3f
        shown = chr(value | 0x40) if value <= 26 else chr(value)
        print(f"{shown}", end='')
print()
660  
cur = time.time()
delta = datetime.timedelta(seconds=cur - prev)
prev = cur
print(f"[c] {delta} constraining to ciphertext")

# Symbolic plaintext = (keystream byte ^ ciphertext value) & 0x3f. Every byte
# must be non-zero and belong to the expected charset: 1..26 (letters), space,
# '.', '-', ',' or a digit. The last, possibly short, block also admits 0x1f
# (presumably a padding/terminator code; confirm).
plaintext_blocks = []
last_block = len(cipherblocks) - 1
for j, cb in enumerate(cipherblocks):
    base = 8 * j
    count = min(8, len(ciphertext) - base)
    pt_block = concat((cb.get_byte(i) ^ ciphertext[base + i]) & 0x3f for i in range(count))
    for b in pt_block.chop(8):
        s.add(b != 0)
        allowed = [b <= 26, b == b' ', b == b'.', b == b'-', b == b',']
        if j == last_block:
            allowed.append(b == 0x1f)
        allowed.append(And(b'9' >= b, b >= b'0'))
        s.add(Or(*allowed))
    plaintext_blocks.append(pt_block)

# Sanity check of the decryption path against the known test vector:
#re = s.eval(plaintext_blocks[0], 1, extra_constraints=[ctrlvar0 == testvec_cv0, cipherblocks[0] == testvec_cbs[0]])[0]
#print(f"{re:016x}")
#print(bytes(c|0x40 if c<27 else c for c in re.to_bytes(8, byteorder='big')))

# pin two ctrlvar bits so that up to four runs can split the search
if threadid is not None:
    print(f"[t]                constraining ctrlvar to threadid: {threadid}")
    s.add(ctrlvar0[2:1] == threadid)

cur = time.time()
delta = datetime.timedelta(seconds=cur - prev)
prev = cur
print(f"[S] {delta} solving for ctrlvar")
707  
# Enumerate every ctrlvar value consistent with the constraints. For each one,
# print it and the decoded plaintext, exclude it and ask the solver again.
i = 0
while s.satisfiable():
    try:
        sol = s.eval(ctrlvar0, 1)
    except claripy.errors.ClaripyZ3Error:
        # Report timings as of now. `cur` still holds the previous checkpoint,
        # so it must not be used for the total running time.
        now = time.time()
        delta = datetime.timedelta(seconds=now - prev)
        print(f"[.] {delta} z3 exception total running time: {str(datetime.timedelta(seconds=now - start))}")
        raise
    cur = time.time()
    delta = datetime.timedelta(seconds=cur - prev)
    prev = cur
    print(f'[!] {delta} #{i} we have a solution for the ctrlvar: {sol[0]:014x}')
    pt = []
    for j, pt_block in enumerate(plaintext_blocks):
        block = s.eval(pt_block, 1, extra_constraints=[ctrlvar0 == sol[0]])[0]
        cur = time.time()
        delta = datetime.timedelta(seconds=cur - prev)
        prev = cur
        print(f"{delta} solved block {j} {block:016x}")
        # values 1..26 are letters: add 0x40 to get 'A'..'Z'
        for c in block.to_bytes(min(8, len(ciphertext) - 8 * j), byteorder='big'):
            pt.append(c | 0x40 if c < 27 else c)
    pt = bytes(pt)

    cur = time.time()
    delta = datetime.timedelta(seconds=cur - prev)
    prev = cur
    print("                   plaintext is:", repr(pt))
    s.add(ctrlvar0 != sol[0])
    i += 1

cur = time.time()
delta = datetime.timedelta(seconds=cur - prev)
print(f"[.] {delta} total running time: {str(datetime.timedelta(seconds=cur - start))}")