/ attacks / attack.py
attack.py
  1  #!/usr/bin/env python
  2  
  3  import angr, claripy, IPython, sys, time, datetime, logging
  4  from claripy import BVS, BVV, Solver, Or, RotateLeft, RotateRight, And, If
  5  from angr.block import CapstoneInsn
  6  #from angr.procedures.libc.memset import memset
  7  #from angr.procedures.libc.memcpy import memcpy
  8  from binascii import unhexlify
  9  from inspect import getframeinfo, stack
 10  
 11  l = logging.getLogger("angr")
 12  
 13  # silence some annoying logs
 14  l.setLevel("ERROR")
 15  
 16  from itertools import zip_longest
 17  def split_by_n(iterable, n):
 18      return zip_longest(*[iter(iterable)]*n, fillvalue='')
 19  
# Substitution table: 8 rows of 16 bytes.  get_sboxes() splits it into one row
# per byte position; within a row the high nibbles form one 4-bit S-box ('hi')
# and the low nibbles another ('lo').  sbt() (step 10) maps byte i of the
# cipher block through row i: its high nibble via 'hi', its low nibble via 'lo'.
SBOX = [0xF4, 0xAF, 0x8A, 0xD1, 0x3B, 0x02, 0xE8, 0x20,   0xCD, 0x65, 0x96, 0x1C, 0x47, 0xB3, 0x79, 0x5E,
        0x18, 0x8B, 0xE3, 0xAE, 0x7D, 0x4A, 0x94, 0xDF,   0x69, 0x30, 0xBC, 0x56, 0xF5, 0x07, 0x21, 0xC2,
        0xBD, 0x72, 0x9C, 0x59, 0xAE, 0x17, 0xF3, 0x61,   0x24, 0xC8, 0x40, 0xDF, 0xE6, 0x8A, 0x35, 0x0B,
        0x27, 0x4D, 0x56, 0xC8, 0x91, 0xB3, 0x70, 0x84,   0xF5, 0xEF, 0xD2, 0xAE, 0x3A, 0x1C, 0x0B, 0x69,
        0x47, 0x9F, 0x80, 0x5C, 0x0A, 0x68, 0xA1, 0xEB,   0xB9, 0x2D, 0x75, 0xF3, 0x1E, 0x32, 0xD6, 0xC4,
        0xB3, 0xAE, 0xED, 0x09, 0x91, 0xD4, 0x38, 0x26,   0x6A, 0xC0, 0xFB, 0x75, 0x82, 0x5F, 0x4C, 0x17,
        0x59, 0x27, 0x16, 0x4D, 0xDB, 0xEF, 0x04, 0x9C,   0xF0, 0xB8, 0x62, 0xCE, 0x3A, 0xA1, 0x73, 0x85,
        0x18, 0x5D, 0x47, 0x6E, 0xC5, 0xA0, 0x9B, 0xFA,   0x32, 0xE3, 0x8C, 0x01, 0xDF, 0x74, 0x29, 0xB6]
 28  
 29  lupb = False
 30  
 31  threadid=None
 32  if len(sys.argv) == 2:
 33      if 0<= int(sys.argv[1])< 4:
 34          threadid=int(sys.argv[1])
 35      else:
 36          print("thread id must be between 0 & 3")
 37          sys.exit(1)
 38  
 39  asserted = set()
 40  def assert_once(v):
 41      caller = getframeinfo(stack()[1][0])
 42      if caller.lineno in asserted: return
 43      asserted.add(caller.lineno)
 44      assert v
 45  
 46  def concat(l):
 47      l = list(l)
 48      return l[0].concat(*l[1:])
 49  
 50  def set_byte(array, pos, val):
 51      if pos < ((len(array) // 8) - 1):
 52          return concat((array.get_bytes(0,pos), val, array.get_bytes(pos+1,7-pos)))
 53      else:
 54          return concat((array.get_bytes(0,pos), val))
 55  
 56  def get_sboxes(sbox):
 57      return [{'hi': tuple((x>>4) for x in t),
 58              'lo': tuple((x&0xf) for x in t)}
 59              for t in split_by_n(sbox,16)]
 60  
 61  def moebius(f,n):
 62      blocksize=1
 63      for step in range(1,n+1):
 64          source=0
 65          while(source < (1<<n)):
 66              target = source + blocksize
 67              for i in range(blocksize):
 68                  f[target+i]^=f[source+i]
 69              source+=2*blocksize
 70          blocksize*=2
 71  
 72  def get_anfs():
 73     mappings = get_sboxes(SBOX)
 74     res = {}
 75  
 76     fs=[{'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 77         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 78         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 79         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 80         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 81         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 82         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 83         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]}]
 84  
 85     ms=[{'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 86         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 87         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 88         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 89         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 90         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 91         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]},
 92         {'hi': [[0]*16,[0]*16,[0]*16,[0]*16], 'lo': [[0]*16,[0]*16,[0]*16,[0]*16]}]
 93  
 94     for p, m in enumerate(mappings):
 95        #print("position", p)
 96        for nib in ('hi','lo'):
 97           #print("nibble", nib)
 98           for bit in range(4):
 99               #print("bit", bit)
100               ones=0
101               for i in range(16):
102                   b = (m[nib][i] >> bit) & 1
103                   fs[p][nib][bit][i]=b
104                   ms[p][nib][bit][i]=b
105                   ones+=b
106               #print(''.join(f"{b}" for b in fs[p][nib][bit]), ones)
107               moebius(ms[p][nib][bit],4)
108               #print(''.join(f"{b}" for b in ms[p][nib][bit]), ms[p][nib][bit].count(1))
109               #print()
110  
111     # these have odd number of 1 constant terms in their polinomial
112     odd_const_terms = [(0, 'hi', 0), (0, 'hi', 1), (0, 'hi', 2), (0, 'hi', 3), (0, 'lo', 2),
113                        (1, 'hi', 0), (1, 'lo', 3), (2, 'hi', 0), (2, 'hi', 1), (2, 'hi', 3),
114                        (2, 'lo', 0), (2, 'lo', 2), (2, 'lo', 2), (2, 'lo', 3), (3, 'hi', 1),
115                        (3, 'lo', 0), (3, 'lo', 1), (3, 'lo', 2), (4, 'hi', 2), (4, 'lo', 0),
116                        (4, 'lo', 1), (4, 'lo', 2), (5, 'hi', 0), (5, 'hi', 1), (5, 'hi', 3),
117                        (5, 'lo', 0), (5, 'lo', 1), (6, 'hi', 0), (6, 'hi', 2), (6, 'lo', 0),
118                        (6, 'lo', 3), (7, 'hi', 0), (7, 'lo', 3)]
119     for p in range(8):
120        for nib in ('hi','lo'):
121           for j, (f, m) in enumerate(zip(fs[p][nib],ms[p][nib])):
122               f_ = ' ^ '.join(c for c in ['&'.join(f"x{i}" for i, x in enumerate(reversed(f'{a:04b}')) if x =='1') for a in range(16) if m[a]==1] if c)
123               # these have odd number of 1 constant terms in their polinomial
124               if (p,nib,j) in odd_const_terms:
125                   f_ = '1 ^ ' + f_
126               #print((p,nib,j), f_)
127               evaled =''.join(str(eval(f_, {f"x{i}":int(x) for i, x in enumerate(reversed(f'{a:04b}'))})) for a in range(16))
128               #print('joined', ''.join(f"{b}" for b in f))
129               #print('evaled', evaled)
130               assert evaled==''.join(f"{b}" for b in f)
131  
132           f_ = []
133           for j, (f, m) in enumerate(zip(fs[p][nib],ms[p][nib])):
134               f_.append(' ^ '.join('('+c+')' for c in ['&'.join(f"x[{i}]" for i, x in enumerate(reversed(f'{a:04b}')) if x =='1') for a in range(16) if m[a]==1] if c))
135               if (p,nib,j) in odd_const_terms:
136                   f_[-1] = '1 ^ ' + f_[-1]
137           #print((p,nib), ', '.join(reversed(f_)))
138           res[(p,nib)] = ', '.join(reversed(f_))
139     return res
140  
def map4x4(anfs, pos, nib, x):
    """Apply the 4-bit S-box half ``(pos, nib)`` to the input bits ``x``.

    ``anfs`` is a table as produced by ``get_anfs``; ``x`` must support
    ``x[0]``..``x[3]`` (bit 0 = least significant), e.g. a tuple of ints or a
    4-bit claripy BV.  Returns the tuple of output bits, MSB first.
    The expressions are generated internally, never from external input.
    """
    expression = anfs[pos, nib]
    return eval(expression, {"x": x})
144  
def boxperm(i, k, tmp):
    """Integer reference model of one box-permutation step (cf. sbt()).

    Args:
        i: direction selector 0..3 (any other value behaves like 3).
        k: the nibble being updated.
        tmp: the partner nibble.

    Returns:
        The new value for ``k``.  Unless ``k`` sits on the edge checked for
        direction ``i``, it simply moves by -4, -1, +4 or +1 (mod 16);
        otherwise it is recomputed from ``k`` and ``tmp`` exactly as the
        disassembled routine does (address comments refer to that code).
    """
    rot_tmp = (tmp << 6) | (tmp >> 2)  # partner rotated; each branch masks it as the original
    low_tmp = tmp & 3
    if i == 0:
        if k & 0xc:
            return (k - 4) & 0xf
        k1 = (k + low_tmp) % 0x100  # 1f6f
        return (((rot_tmp + k1) & 3) | 0xc) % 0x100
    if i == 1:
        if k & 3:
            return (k - 1) & 0xf
        k1 = (((k << 6) | (k >> 2)) & 0xff) + low_tmp  # 1fa2 .. 1fa6, 1fb7, 1fbb
        t = ((rot_tmp & 0xff) + k1) & 3
        return ((((t >> 6) | (t << 2)) & 0xff) | 3) & 0xff  # 1fbd .. 1fc8
    if i == 2:
        if (k & 0xc) != 0xc:
            return (k + 4) & 0xf
        k1 = ((k & 3) + low_tmp) % 0x100  # 1fe3 .. 1fe7, 1ff8, 1fbb
        return ((rot_tmp % 0x100) + k1) & 3  # 1ffe .. 2005
    if (k & 3) != 3:
        return (k + 1) & 0xf
    k1 = (((k >> 2) & 3) + low_tmp) % 0x100  # 202c, 203d, 2041
    t = ((rot_tmp % 0x100) + k1) & 3  # 2043 .. 2048
    return ((t << 2) | (t >> 6)) % 0x100  # 204a .. 204c
171  
def get_boxperm_anfs(table=None):
    """Compute the ANF of the 10-bit -> 4-bit box-permutation lookup table.

    Args:
        table: 1024-entry lookup table; defaults to the module-level
            ``bpsbox``.  That table is built as ``boxperm(i, k, tmp)`` for
            i, k, tmp in nested order, so index bits 0-3 are ``tmp``, bits
            4-7 are ``k`` and bits 8-9 are ``i``.

    Returns:
        A string of four comma-separated Python expressions over
        ``x[0]``..``x[9]`` (``x[0]`` = least-significant index bit), output
        MSB first; evaluate it with ``map_boxperm``.

    Raises:
        AssertionError: if a generated expression disagrees with the table
            (internal consistency check).
    """
    if table is None:
        table = bpsbox
    exprs = []
    for bit in range(4):
        truth = [(table[i] >> bit) & 1 for i in range(1024)]
        anf = list(truth)
        moebius(anf, 10)
        # Coefficient a (a != 0) selects the AND of the inputs set in a.
        terms = ['(' + '&'.join(f"x[{i}]" for i in range(10) if (a >> i) & 1) + ')'
                 for a in range(1, 1024) if anf[a]]
        # anf[0] is the constant term (= output bit for index 0); derived here
        # instead of the former hard-coded [2, 3].
        if anf[0]:
            terms.insert(0, '1')
        expr = ' ^ '.join(terms) if terms else '0'
        # Compile once; the expression is checked against all 1024 inputs.
        code = compile(expr, f"<boxperm anf {bit}>", "eval")
        for a in range(1024):
            bits = tuple((a >> i) & 1 for i in range(10))
            if eval(code, {"x": bits}) != truth[a]:
                raise AssertionError(f"boxperm ANF mismatch for bit {bit} at input {a}")
        exprs.append(expr)
    return ', '.join(reversed(exprs))
211  
def map_boxperm(anf, x):
    """Evaluate the box-permutation ANF string ``anf`` on the bits ``x``.

    ``x`` must support ``x[0]``..``x[9]``; the result is the tuple of output
    bits, most-significant first.
    """
    scope = {"x": x}
    return eval(anf, scope)
214  
def getFuncAddress(cfg, funcName, plt=None):
    """Return the address of the first function in ``cfg`` named ``funcName``.

    When ``plt`` is not None, only functions whose ``is_plt`` flag equals it
    match.  A found address is logged with ``l.info``.

    Raises:
        Exception: if no function matches.
    """
    def matches(func):
        return funcName == func.name and (plt is None or func.is_plt == plt)

    candidates = [addr for addr, func in cfg.kb.functions.items() if matches(func)]
    if not candidates:
        raise Exception("No address found for function : "+funcName)
    addr = candidates[0]
    l.info("Found "+funcName+"'s address at "+hex(addr)+"!")
    return addr
225  
def getRetAddr(proj, fn):
    """Return the address of the first ``ret`` instruction at or after ``fn``.

    Disassembles up to 1000 bytes starting at ``fn`` using the project's
    capstone instance (``proj.arch.capstone``).

    Raises:
        ValueError: if no ``ret`` is found in that window.
    """
    # let's disasm with capstone to search targets
    insn_bytes = proj.loader.memory.load(fn, 1000)
    for cs_insn in proj.arch.capstone.disasm(insn_bytes, fn):
        ins = CapstoneInsn(cs_insn)
        if ins.mnemonic == "ret":
            l.info(f"Found lfsr's return address at 0x{ins.address:x}!")
            return ins.address
    # f-string: the message previously printed the literal text "{fn}".
    raise ValueError(f"failed to find ret op in {fn:#x}")
235  
236  
def boxperm0(nibble_swap_index, cipherblock, thingy):
    """Symbolic counterpart of ``boxperm(0, k, tmp)`` for its non-trivial case.

    sbt() uses it when the direction bits are 0 and bits 2-3 of the current
    nibble are clear.  The partner nibble is taken from index
    ``nibble_swap_index ^ 8`` of ``cipherblock`` (an even nibble index selects
    the high nibble of its byte); ``thingy`` is the current nibble as a BV.
    The result always has bits 2-3 set (``| 0xc``).  Hex comments are
    presumably addresses in the disassembled original routine.
    """
    tmp = cipherblock.get_byte((nibble_swap_index ^ 8) >> 1)
    if not (nibble_swap_index & 1):
        tmp = tmp.LShR(4)
    tmp &= 0xf
    thingy += tmp & 3 # 1f6f
    #return concat((claripy.BVV(3,6), thingy[1:0]^tmp[3:2]))
    return ((((tmp << 6) | (tmp.LShR(2))) + thingy) & 3) | 0xc; # 1f78 .. 1f7e
245  
def boxperm1(nibble_swap_index, cipherblock, thingy):
    """Symbolic counterpart of ``boxperm(1, k, tmp)`` for its non-trivial case.

    sbt() uses it when the direction bits are 1 and bits 0-1 of the current
    nibble are clear.  The partner nibble comes from index
    ``nibble_swap_index ^ 2`` (even index = high nibble of its byte).  The
    result always has bits 0-1 set (``| 3``).  Hex comments are presumably
    addresses in the disassembled original routine.
    """
    tmp = cipherblock.get_byte((nibble_swap_index ^ 2) >> 1)
    if not (nibble_swap_index & 1):
        tmp = tmp.LShR(4)       # 1fb3
    tmp &= 0xf          # 1fb4

    # Rotate the current nibble right by 2 within 8 bits, then add the
    # partner's low two bits.
    thingy = ((thingy << 6) | (thingy.LShR(2))) & 0xff # 1fa2 .. 1fa6
    thingy += tmp & 3                               # 1fb7, 1fbb
    tmp = (((tmp << 6) | (tmp.LShR(2))) + thingy) & 3
    return ((tmp.LShR(6) | (tmp << 2)) & 0xff) | 3   # 1fbd .. 1fc8
256  
def boxperm2(nibble_swap_index, cipherblock, thingy):
    """Symbolic counterpart of ``boxperm(2, k, tmp)`` for its non-trivial case.

    sbt() uses it when the direction bits are 2 and bits 2-3 of the current
    nibble are both set.  The partner nibble comes from index
    ``nibble_swap_index ^ 4`` (even index = high nibble of its byte).  The
    result lies in 0..3.  Hex comments are presumably addresses in the
    disassembled original routine.
    """
    tmp = cipherblock.get_byte((nibble_swap_index ^ 4) >> 1) # 1fe9 .. 1ff1
    if not (nibble_swap_index & 1):
        tmp = tmp.LShR(4) # 1ff4
    tmp &= 0xf          # 1ff5

    thingy &= 3         # 1fe3 .. 1fe7
    thingy += tmp & 3   # 1ff8, 1fbb
    #return concat((claripy.BVV(0,6), thingy[1:0]^tmp[3:2]))
    return ((((tmp << 6) | (tmp.LShR(2))) + thingy) & 3) & 0xff # 1ffe .. 2005
267  
def boxperm3(nibble_swap_index, cipherblock, thingy):
    """Symbolic counterpart of ``boxperm(3, k, tmp)`` for its non-trivial case.

    sbt() uses it when the direction bits are 3 and bits 0-1 of the current
    nibble are both set.  The partner is the other nibble of the same byte
    (index ``nibble_swap_index ^ 1``), so the high/low selection is the
    inverse of the other boxperm helpers.  The result has only bits 2-3
    possibly set.  Hex comments are presumably addresses in the disassembled
    original routine.
    """
    tmp = cipherblock.get_byte((nibble_swap_index ^ 1) >> 1)  # .. 202e..2036
    if nibble_swap_index & 1:
        tmp = tmp.LShR(4) #2039
    tmp &= 0xf          # 203a

    thingy = (thingy.LShR(2)) & 3  # 202c
    thingy += tmp & 3           # 203d, 2041
    tmp = ((((tmp << 6) | (tmp.LShR(2))) + thingy) & 3) # 2043 .. 2048
    return ((tmp << 2) | (tmp.LShR(6))) & 0xff # 204a .. 204c
278  
def sbt(s, ctrlvar, state, lupb=False):
    """Build the symbolic result of one run of the block transformation.

    Args:
        s: the claripy Solver; only referenced by the commented-out
            debugging code below.
        ctrlvar: 56-bit control value (7 bytes).  Each round rotates its two
            28-bit halves and takes bits from it for the nibble-swap control
            words and the nibble switch.
        state: 64-bit (8-byte) value the initial cipher block is derived from.
        lupb: True evaluates the box permutation through the global ANF
            ``bp_anf`` (``map_boxperm``); False uses nested ``If`` expressions
            over ``boxperm0``..``boxperm3``.  NOTE(review): callers that omit
            it get False regardless of the module-level ``lupb`` flag.

    Returns:
        ``(cipherblock, state)``: the 64-bit block after 8 rounds, and the
        block produced by the initial mixing step.  ``state`` is rotated by
        one byte per round, 8 times in total, so it ends where it started.

    Side effects: prints timing progress and updates the global ``prev``.
    Hex comments are presumably addresses in the disassembled original.
    """
    global prev
    cur = time.time()
    delta = datetime.timedelta(seconds=cur - prev)
    prev = cur
    print(f"[s] {delta} running core sbt")

    # Initial mixing: 8 times, drop byte 0 and append a new byte
    # ((b0 ^ b4) >> 1, with bit 7 taken from bit 0 of (b1 ^ b5)).
    cipherblock = concat(state.get_byte(i) for i in range(8))
    for i in range(8):
        tmp = ((cipherblock.get_byte(0) ^ cipherblock.get_byte(4)).LShR(1)) | ((cipherblock.get_byte(1) ^ cipherblock.get_byte(5)) << 7)
        cipherblock = concat((cipherblock.get_bytes(1,7), tmp))
    # 1d3c
    #assert_once(s.solution(cipherblock, 0x86820280c3c181c0, extra_constraints=[key == b'\x01'*15]))
    state = cipherblock

    # step 2 fixed bit permutation
    # (64 output bits, MSB first, each given as (byte, bit) of the mixed block)
    cipherblock = concat((
        cipherblock.get_byte(by)[bi]
        for by, bi in ((4,4),(1,5),(3,1),(7,6),(0,5),(6,0),(5,7),(2,3),
                       (3,2),(0,7),(7,1),(1,0),(6,3),(4,5),(5,4),(2,0),
                       (7,5),(3,3),(1,3),(5,1),(0,3),(2,4),(6,2),(4,1),
                       (5,5),(3,0),(0,1),(4,3),(1,6),(6,7),(2,2),(7,3),
                       (2,5),(6,4),(4,7),(0,6),(5,6),(7,7),(3,5),(1,2),
                       (6,5),(4,2),(3,6),(5,2),(1,7),(2,6),(7,4),(0,2),
                       (5,0),(7,2),(1,4),(3,4),(4,6),(6,1),(0,0),(2,1),
                       (6,6),(0,4),(1,1),(7,0),(3,7),(2,7),(4,0),(5,3))
    ))
    #assert_once(s.solution(cipherblock, 0x164001242c09892a, extra_constraints=[key == b'\x01'*15]))

    # Bit 7 of this byte selects the ctrlvar rotation amount; the byte is
    # itself rotated left by one bit every round.
    ctrlvar_iter_ctrl = 0x9c
    for round in range(8):
        cur = time.time()
        delta = datetime.timedelta(seconds=cur - prev)
        prev = cur
        print(f"[{round}] {delta} running sbt round")

        # see updctrlvar.py
        # Rotate each 28-bit half of ctrlvar right by 5 (bit 7 set) or by 2.
        if (ctrlvar_iter_ctrl & 0x80): # 5
            ctrlvar = concat((ctrlvar[32:28], ctrlvar[55:33], ctrlvar[4:0], ctrlvar[27:5]))
        else: #2
            ctrlvar = concat((ctrlvar[29:28],ctrlvar[55:30],ctrlvar[1:0],ctrlvar[27:2]))

        #re = s.eval(ctrlvar,1, extra_constraints=[key == b'\x01'*15])[0]
        #print(f"cv: {re:014x}")

        ctrlvar_iter_ctrl = ((ctrlvar_iter_ctrl << 1) | (ctrlvar_iter_ctrl >> 7)) & 0xff # 1e5c .. 1e5f

        #re = s.eval(state,1, extra_constraints=[key == b'\x01'*15])[0]
        #print(f"state0: {re:016x}")

        # Control words: state bytes 3, 2, 1, 0 (reversed order).
        nibble_swap_ctrl_words = concat((state.get_byte(3-i) for i in range(4)))

        #re = s.eval(nibble_swap_ctrl_words,1, extra_constraints=[key == b'\x01'*15])[0]
        #print(f"nscw0: {re:08x}")

        # Rotate the state right by one byte.
        state = concat((state.get_byte(7),state.get_bytes(0,7)))

        #re = s.eval(state,1, extra_constraints=[key == b'\x01'*15])[0]
        #print(f"state1: {re:016x}")

        # XOR each control byte with 8 selected (byte, bit) positions of ctrlvar.
        nibble_swap_ctrl_words = concat((
            (nibble_swap_ctrl_words.get_byte(0) ^ concat((ctrlvar.get_byte(by)[bi] for by,bi in
                                                        ((1,5), (4,1), (1,2), (5,6), (2,7), (5,3), (2,4), (5,0))))),
            (nibble_swap_ctrl_words.get_byte(1) ^ concat((ctrlvar.get_byte(by)[bi] for by,bi in
                                                        ((2,1), (6,5), (3,6), (6,2), (0,7), (3,3), (0,4), (3,0))))),
            (nibble_swap_ctrl_words.get_byte(2) ^ concat((ctrlvar.get_byte(by)[bi] for by,bi in
                                                        ((0,1), (4,5), (1,6), (4,2), (1,3), (5,7), (1,0), (5,4))))),
            (nibble_swap_ctrl_words.get_byte(3) ^ concat((ctrlvar.get_byte(by)[bi] for by,bi in
                                                        ((2,5), (5,1), (2,2), (6,6), (3,7), (6,3), (3,4), (6,0))))) # 1f2c
        ))
        #re = s.eval(nibble_swap_ctrl_words,1, extra_constraints=[key == b'\x01'*15])[0]
        #print(f"nscw: {re:08x}")
        #assert_once(s.solution(nibble_swap_ctrl_words, 0x840a8a82, extra_constraints=[key == b'\x01'*15]))

        # Box permutation over the 16 nibbles: each control byte handles 4
        # consecutive nibble indices; its top two bits (tmp1) choose the
        # partner nibble (index ^ 8, ^ 2, ^ 4 or ^ 1) and update rule, then
        # the control byte is rotated left by 2 for the next nibble.
        nibble_swap_index = 0
        for i in range(4):
            nibble_swap_ctrl = nibble_swap_ctrl_words.get_byte(i)
            while(True):
                #re = s.eval(nibble_swap_ctrl,1, extra_constraints=[key == b'\x01'*15])[0]
                #print(f"nibble_swap_ctrl: {re:02x} {re>>6}")
                # Current nibble: even index = high nibble of its byte.
                tmp = cipherblock.get_byte(nibble_swap_index >> 1)
                #re = s.eval(tmp,1, extra_constraints=[key == b'\x01'*15])[0]
                #print(f"nsi: {0x2d+(nibble_swap_index >> 1):02x} {re:02x} ")
                if not (nibble_swap_index & 1):
                    tmp = tmp.LShR(4)
                tmp &= 0xf
                thingy = tmp

                #re = s.eval(thingy,1, extra_constraints=[key == b'\x01'*15])[0]
                #print(f"thingy: {re:02x}")

                #re = s.eval(tmp,1, extra_constraints=[key == b'\x01'*15])[0]
                #print(f"tmp: {re:02x}")

                tmp1 = nibble_swap_ctrl.LShR(6)

                #tmp1map = {8241}[tmp1]
                #tmpindex = concat(( (tmp1[0]^1) & (tmp1[1]^1),
                #                    (tmp1[0]^1) &  tmp1[1],
                #                     tmp1[0]    & (tmp1[1]^1),
                #                     tmp1[0]    &  tmp1[1]))
                #tmpindex = BVV(0,4).concat((nibble_swap_index ^ tmpindex) >> 1)
                #tmpindex=((64 + 7) // 8 - 1 - tmpindex)*8+7
                #re = s.eval(tmp1,1, extra_constraints=[key == b'\x01'*15])[0]
                #re1 = s.eval(tmpindex,1, extra_constraints=[key == b'\x01'*15])[0]
                #print(f"tmp1: {re:02x} -> {re1:02x}")
                #xtmp = cipherblock.Extract(tmpindex+7, tmpindex, cipherblock)
                #xtmp = cipherblock.LShR(tmpindex)
                if lupb:
                   # xtmp: partner nibble picked by tmp1; the ANF input is the
                   # 10-bit value tmp1[1:0] : thingy[3:0] : xtmp[3:0].
                   xtmp = If(tmp1 == 0,
                             cipherblock.get_byte((nibble_swap_index ^ 8) >> 1).LShR(4 * (1 ^ nibble_swap_index&1)),
                          If(tmp1 == 1,
                             cipherblock.get_byte((nibble_swap_index ^ 2) >> 1).LShR(4 * (1 ^ nibble_swap_index&1)),
                          If(tmp1 == 2,
                             cipherblock.get_byte((nibble_swap_index ^ 4) >> 1).LShR(4 * (1 ^ nibble_swap_index&1)),
                             cipherblock.get_byte((nibble_swap_index ^ 1) >> 1).LShR(4 * (nibble_swap_index&1)))))

                   #ri = s.eval(tmp1[1:0],1, extra_constraints=[key == b'\x01'*15])[0]
                   #rtmp = s.eval(xtmp[3:0],1, extra_constraints=[key == b'\x01'*15])[0]
                   #rthingy = s.eval(thingy[3:0],1, extra_constraints=[key == b'\x01'*15])[0]
                   #print(f"{ri:02x} {rtmp:04b} {rthingy:04b}")
                   thingy = BVV(0,4).concat(concat(map_boxperm(bp_anf, concat((tmp1[1:0], thingy[3:0], xtmp[3:0])))))
                else:
                   # Same rule as boxperm(): step -4/-1/+4/+1 unless the nibble
                   # sits on the edge for this direction, else use boxperm0..3.
                   thingy = If((tmp1) == 0,
                               If((tmp&0xc) == 0,
                                  boxperm0(nibble_swap_index, cipherblock, thingy),
                                  (thingy - 4) & 0xff),
                            If((tmp1) == 1,
                               If((tmp&0x3) == 0,
                                  boxperm1(nibble_swap_index, cipherblock, thingy),
                                  (thingy - 1) & 0xff),
                            If((tmp1) == 2,
                               If((tmp&0xc) == 0xc,
                                  boxperm2(nibble_swap_index, cipherblock, thingy),
                                  (thingy+4) & 0xff),
                               If((tmp&0x3) == 3,
                                  boxperm3(nibble_swap_index, cipherblock, thingy),
                                  (thingy+1) & 0xff)
                               )))

                #re = s.eval(thingy,1, extra_constraints=[key == b'\x01'*15])[0]
                #print(f"thingy: {re:02x}")

                # Write the new nibble back: clear the target nibble of the
                # byte, then OR in thingy (moved to the high nibble for even
                # indices).
                r1 = nibble_swap_index >> 1
                tmp = cipherblock.get_byte(r1) # 2058
                if not (nibble_swap_index & 1):
                  tmp &= 0xf
                  cipherblock = set_byte(cipherblock, r1, tmp)  # 205d
                  tmp = thingy
                  tmp = ((tmp << 4) | (tmp.LShR(4))) & 0xff   # 2060
                else:
                  tmp &= 0xf0       # 2063 ..
                  cipherblock = set_byte(cipherblock, r1, tmp)
                  tmp = thingy      # 2066

                cipherblock = set_byte(cipherblock, r1, cipherblock.get_byte(r1) | tmp)

                #re = s.eval(cipherblock,1, extra_constraints=[key == b'\x01'*15])[0]
                #print(f"cb: {re:016x}")
                nibble_swap_ctrl = ((nibble_swap_ctrl << 2) | (nibble_swap_ctrl.LShR(6))) & 0xff # 2072 .. 2076
                nibble_swap_index+=1
                if((nibble_swap_index & 3) == 0): break
        #print("eobp ", end='')
        #re = s.eval(cipherblock,1, extra_constraints=[key == b'\x01'*15])[0]
        #print(f"cb: {re:016x}")
        #assert_once(s.solution(cipherblock, 0x52ffec68684dc5de, extra_constraints=[key == b'\x01'*15]))
        # step 6 fix byte permutation
        cipherblock = concat((
            cipherblock.get_byte(3),
            cipherblock.get_byte(5),
            cipherblock.get_byte(1),
            cipherblock.get_byte(4),
            cipherblock.get_byte(6),
            cipherblock.get_byte(0),
            cipherblock.get_byte(7),
            cipherblock.get_byte(2)))

        #print("fbp ", end='')
        #re = s.eval(cipherblock,1, extra_constraints=[key == b'\x01'*15])[0]
        #print(f"cb: {re:016x}")

        #step 8 nibble switch
        # tmp is a control byte built from 8 ctrlvar bits; each iteration
        # rotates it right through cf.  The two exchanges leave byte i with
        # its nibbles swapped iff the bit rotated out (cf) is 1, and restore
        # tmp to the control byte.
        tmp = concat((ctrlvar.get_byte(by)[bi] for by, bi in ((3, 5), (6, 1), (0, 6), (3, 2), (0, 3), (4, 7), (0, 0), (4, 4))))
        cf = ctrlvar.get_byte(6)[1]
        for i in range(8):
            tbit = tmp[0];
            tmp = concat((cf, tmp[7:1]))
            cf = tbit
            tmp1 = cipherblock.get_byte(i)
            cipherblock = set_byte(cipherblock,i,tmp)
            tmp = tmp1
            tmp = If(cf == 1,
                     ((tmp & 0xf) << 4) | ((tmp & 0xf0).LShR(4)),
                     tmp)
            tmp1 = cipherblock.get_byte(i)
            cipherblock = set_byte(cipherblock,i,tmp)
            tmp = tmp1

        #re = s.eval(cipherblock,1, extra_constraints=[key == b'\x01'*15])[0]
        #print(f"ns cb: {re:016x}")

        # step 10 SBOXes
        # Byte i: high nibble through S-box row i 'hi', low nibble through 'lo'.
        for i in range(8):
            tmp = cipherblock.get_byte(i)
            cipherblock = set_byte(cipherblock,i,concat((map4x4(sbox_anfs,i,"hi",tmp[7:4]) + map4x4(sbox_anfs,i,"lo",tmp[3:0]))))

        #assert_once(s.solution(cipherblock, 0xed770b7518dda32f, extra_constraints=[key == b'\x01'*15]))
        #if round == 3:
        #    assert_once(s.solution(cipherblock, 0x50c7f3128eac6912, extra_constraints=[key == b'\x01'*15]))

        #re = s.eval(cipherblock,1, extra_constraints=[key == b'\x01'*15])[0]
        #print(f"rnd cb: {re:016x}")

    #re = s.eval(cipherblock,1, extra_constraints=[key == b'\x01'*15])[0]
    #print(f"cb: {re:016x}")

    return cipherblock, state
496  
def unscramble(ct):
    """Decode the letter groups of a scrambled message into 6-bit values.

    ``ct`` is a bytes object of space-separated groups.  Its last byte (a
    terminator) is dropped, and its first group is skipped; this script keeps
    that group as the nonce.  Each byte ``c`` contributes the four bits of
    ``(c - 1) & 0xf``, most-significant first.  The bits are dealt
    round-robin into three 8-bit registers, each shifting new bits in at the
    top.  After every group the registers yield three 6-bit outputs.
    Register contents carry over between groups, as in the emulated assembly
    routine (rlc / xchg / rrc through a carry flag).

    Returns:
        list of ints in range(64), three per group.
    """
    regs = [0, 0, 0]
    decoded = []
    for group in ct[:-1].split(b' ')[1:]:
        slot = 0
        for ch in group:
            nibble = (ch - 1) & 0xf
            for shift in (3, 2, 1, 0):
                bit = (nibble >> shift) & 1
                regs[slot] = (bit << 7) | (regs[slot] >> 1)
                slot = (slot + 1) % 3
        # For a 5-letter group, registers 0 and 1 got 7 fresh bits (7..1) and
        # register 2 got 6 (7..2); drop the stale low bits.
        decoded += [(regs[0] >> 1) & 0x3f,
                    (regs[1] >> 1) & 0x3f,
                    (regs[2] >> 2) & 0x3f]
    return decoded
541  
# Progress timestamps: `prev` holds the time of the last progress message and
# is also read/updated by sbt() through `global prev`.
prev = start = time.time()
print("[S] 0:00:00.000000 calculating algebraic normal forms for SBOXes...")
sbox_anfs = get_anfs()
#for k, v in sbox_anfs.items():
#    print(k, v)

cur = time.time()
delta = datetime.timedelta(seconds=cur - prev)
prev = cur
print(f"[A] {delta} verifying correctness of ANF equations (comparing solutions of the equations with the lookup from the sbox table)...")
# Cross-check every S-box half on all 16 inputs: expected bits `r` are MSB
# first, the input tuple is LSB first (x[0] = bit 0).  The loop variable `x`
# is rebound to its bit tuple; the range iterator itself is unaffected.
# NOTE(review): `assert` is skipped under `python -O`.
for p, m in enumerate(get_sboxes(SBOX)):
    for nib in ('hi', 'lo'):
        for x in range(16):
            r = tuple(int(c) for c in (f'{m[nib][x]:04b}'))
            x = tuple(int(c) for c in reversed(f'{x:04b}'))
            assert map4x4(sbox_anfs, p, nib, x) == r

cur = time.time()
delta = datetime.timedelta(seconds=cur - prev)
prev = cur
print(f"[B] {delta} calculating algebraic normal forms for boxperm...")
# Lookup table indexed by i*256 + k*16 + tmp (comprehension order), i.e.
# bits 0-3 = tmp, bits 4-7 = k, bits 8-9 = i.
bpsbox = [boxperm(i,k,tmp) for i in range(4) for k in range(16) for tmp in range(16)]
#print(''.join(f"{x:02x}" for x in bpsbox))
bp_anf = get_boxperm_anfs()
#for k, v in sbox_anfs.items():
#    print(k, v)

cur = time.time()
delta = datetime.timedelta(seconds=cur - prev)
prev = cur
print(f"[C] {delta} verifying correctness of ANF equations (comparing solutions of the equations with the lookup from the sbox table)...")
# Same cross-check for the 10-bit box-permutation ANF over all 1024 inputs.
for x in range(1024):
    r = tuple(int(c) for c in (f'{bpsbox[x]:04b}'))
    x = tuple(int(c) for c in reversed(f'{x:010b}'))
    assert map_boxperm(bp_anf, x) == r
577  
cur = time.time()
delta = datetime.timedelta(seconds=cur - prev)
prev = cur
print(f"[M] {delta} building model (lookup/ifs: {lupb})")

# NOTE(review): timeout is presumably in milliseconds (~20 days) and
# max_memory in bytes (110 GiB); confirm against the claripy backend.
s=Solver(timeout=1728000000, max_memory=110 * 1024**3)

# Active test case: scrambled message (first group = nonce, '\xfe' = end
# marker) with its known control value and keystream blocks.
# NOTE(review): next_nonce is not referenced in the visible part of the file.
ciphertext="GDMIG BGAMF HIGIK ACDGJ GGJOL\xfe"
testvec_cv0 = 0x6e3897f14b9b20
testvec_cbs = [0x2440e35e2fc88a7f, 0x367a2ec0ab0789cb]
next_nonce = "EPNDD"
# prevcb0[5]={0xe6, 0x3c, 0x32, 0x65, 0x9a, 0x26, 0x65, 0x76};
#cb0 2440e35e2fc88a7f
# plaintext = "2PADDING..__"

# Alternative test cases:
#ciphertext = "BAMLK GAFFJ EOBCP EEIMP CDAFP\xfe"
#testvec_cv0 = 0x19bea3035951f5
#testvec_cbs = [0xa074c5c8ec1a477e, 0xf4e8703e0308a769]
#ciphertext="FBFDN MLBPE IJCFF DKDJB LFNNP DMCOM LBMKM GGOKL IJHCO JOEDE LFCCK PIHDP MMHCB DJGNE CGOCL NCAAB JJAIC GPJCH GHBNP FDKNA MGJHP NICPF\xfe"
#testvec_cv0 = 0x06157a7cbc4555
#testvec_cbs = [0xac18ad815975afd9, 0x69e9fb2c9994938f, 0x2c7c4f499eda0eab, 0xa022bfd6af9bcbed,
#              0x2c07677a2e16a042, 0x63b79ece15f3abed, 0x64d45fc6c5998530, 0x3557dcce84846f57]

print(f"[i] ciphertext: {ciphertext[:-1]}")
# Reduce every character to its low 6 bits ('A' -> 1, ' ' -> 0x20); the
# end marker becomes 0x1f.
ciphertext = bytes([(ord(c) & 0x3f) if c!='\xfe' else 0x1f
                        for c in ciphertext])
nonce = ciphertext[:5]
# From here on `ciphertext` is the list of decoded 6-bit values.
ciphertext = unscramble(ciphertext)
606  
607  # key schedule from password key
608  ################################
609  
610  ###### input_key is symbolic, and 1st and only param to the tgt fn
611  #####key = BVS("k",15*8, explicit_name="k")
612  #####
613  #####for b in key.chop(8):
614  #####    s.add(b>0)
615  #####    s.add(Or(b <= 26,
616  #####             b == b' ',
617  #####             b == b'.',
618  #####             b == b'-',
619  #####             b == b',',
620  #####             And(b'9' >= b, b >= b'0')))
621  #####
622  ###### parametric_entry(nonce,key)
623  ###### sbt_init(nonce, key)
624  #####state = concat([nonce[i] ^ key.get_byte(i) if i<3 else key.get_byte(i) for i in range(8)])
625  #####ctrlvar = key.get_bytes(8,7)
626  #####
627  ######assert(s.solution(cipherblock, 0x0d04040101010101, extra_constraints=[key == b'\x01'*15]))
628  #####
629  ###### sbt
630  #####cipherblock, state = sbt(s, ctrlvar, state)
631  #####ctrlvar = cipherblock.get_bytes(0,7)
632  
633  # keystream generation
634  ######################
635  
# Symbolic 56-bit control value: the unknown the solver is asked about.
ctrlvar = BVS('cv', 7*8, explicit_name='cv')
# Fixed 8-byte starting state for the first transformation.
state    = BVV(bytes([0xf5, 0xc0, 0x7a, 0x10, 0x8a, 0xaf, 0x17, 0xcf]))
# sbt again
# lupb is passed explicitly; sbt()'s own default (False) would otherwise
# ignore the module-level switch reported in the "[M]" progress message.
cipherblock, state = sbt(s, ctrlvar, state, lupb=lupb)
#assert s.solution(cipherblock, testvec_cbs[0], extra_constraints=[ctrlvar == testvec_cv0])
# Next state: first 5 bytes of the returned state followed by 3 nonce bytes.
state = state.get_bytes(0,5).concat(nonce[:3])
# sbt_init, parametric_entry end.

# One keystream block per started 8-value chunk of the decoded ciphertext;
# cipherblocks[0] already covers the first chunk.  Ceiling division avoids an
# extra block when len(ciphertext) is a multiple of 8: that block would have
# no ciphertext values, and the plaintext constraints further down would then
# call concat() on an empty sequence.
cipherblocks=[cipherblock]
for i in range(1, (len(ciphertext) + 7) // 8):
    cipherblock, state = sbt(s,ctrlvar,state,lupb=lupb)
    #re = s.eval(cipherblock,1, extra_constraints=[ctrlvar == testvec_cv0])[0]
    #print(f"{re:016x} == {testvec_cbs[i]:016x}")
    #assert re == testvec_cbs[i]
    #print(f"[?] block {i} ok")
    cipherblocks.append(cipherblock)
652  
cur = time.time()
delta = datetime.timedelta(seconds=cur - prev)
prev = cur
print(f"[?] {delta} verifying testvector")
#re = s.eval(cipherblock,1, extra_constraints=[ctrlvar == 0x19bea3035951f5])[0]
#assert f"{re:016x}" == 'a074c5c8ec1a477e'
#re = s.eval(cipherblock1,1, extra_constraints=[ctrlvar == 0x19bea3035951f5])[0]
#print(f"cb1: {re:016x}")
#assert f"{re:016x}" == 'f4e8703e0308a769'
# the following times out? the previous is almost instant
# (NOTE(review): refers to the commented-out solver checks above.)
# Nothing below touches the solver: decrypt with the known test-vector
# keystream blocks and print the plaintext characters (values 1..26 are shown
# as 'A'..'Z' by setting bit 6), then the keystream bytes' low 6 bits in hex.
for j, cb in enumerate(testvec_cbs):
   cb = cb.to_bytes(8,byteorder="big")
   for i in range( min(8, len(ciphertext)-8*j) ):
       p = (cb[i] ^ ciphertext[i+j*8]) & 0x3f
       print(f"{chr( p|(0x40 if p <= 26 else 0))}", end=' ')
print()
for j, cb in enumerate(testvec_cbs):
   cb = cb.to_bytes(8,byteorder="big")
   for i in range( min(8, len(ciphertext)-8*j) ):
       p = cb[i] & 0x3f
       print(f"{p:02x}", end='')
print()
675  
cur = time.time()
delta = datetime.timedelta(seconds=cur - prev)
prev = cur
print(f"[c] {delta} constraining to ciphertext")

# For every keystream block, build the symbolic plaintext (keystream byte XOR
# ciphertext byte, low 6 bits) and restrict each plaintext byte to: nonzero,
# and one of <= 26, ' ', '.', '-', ',' or a digit '0'..'9'. Only the final
# block may additionally hold 0x1f (presumably padding; see the commented-out
# padding test further down).
plaintext_blocks = []
final_block_idx = len(cipherblocks) - 1
for j, cb in enumerate(cipherblocks):
    block_len = min(8, len(ciphertext) - 8*j)
    pt_block = concat(((cb.get_byte(i) ^ ciphertext[i+j*8]) & 0x3f for i in range(block_len)))
    #print(f"len(pt[{j}]) = {len(pt_block)}")
    for b in pt_block.chop(8):
        s.add(b != 0)
        alphabet = [b <= 26, b == b' ', b == b'.', b == b'-', b == b',']
        if j == final_block_idx:
            alphabet.append(b == 0x1f)
        alphabet.append(And(b'9' >= b, b >= b'0'))
        s.add(Or(*alphabet))
    plaintext_blocks.append(pt_block)
703  
# Test for the known padding: a double "_" (value 0x1f) in plaintext block 1.
#s.add(plaintext_blocks[1].chop(8)[3] == 0x1f)
#s.add(plaintext_blocks[1].chop(8)[2] == 0x1f)
# The following 2 constraints are equivalent to the above, but more "exact".
#s.add(cipherblocks[1].get_byte(2)[5:0] == 0x2e)
#s.add(cipherblocks[1].get_byte(3)[5:0] == 0x00)

# Pin the low 6 bits of each of the 8 bytes of the first keystream block.
for byte_idx, low6 in enumerate((0x24, 0x00, 0x23, 0x1e, 0x2f, 0x08, 0x0a, 0x3f)):
    s.add(cipherblocks[0].get_byte(byte_idx)[5:0] == low6)
719  
# Test for a known succeeding nonce with a known delta for TL0.
# for offset in range(16): print([hex((c ^ (offset + i*15)) & 0xf) for i, c in enumerate(nonce)])
# One of these rows has the correct cb0[:5].
next_nonce = [(ord(x) & 0x3f) - 1 for x in next_nonce]
cbn = cipherblocks[0].chop(4)
cbnc = [[(c ^ (offset + 15*i)) & 0xf for i, c in enumerate(next_nonce)]
        for offset in range(16)]
# cbn[2*k + 1] is the low nibble of byte k of the first keystream block
# (claripy's chop() lists the most-significant piece first). Require that for
# at least one of the 16 offset rows, the low nibbles of bytes 0..4 all match
# that row's first five candidate values.
s.add(Or(*[
    And(*[cbn[2*k + 1] == row[k] for k in range(5)])
    for row in cbnc
]))
744  
#print(s.constraints)
cur = time.time()
delta = datetime.timedelta(seconds=cur - prev)
prev = cur
print(f"[?] {delta} testing decryption")

# Sanity check: solve plaintext block 0 while forcing the first keystream
# block to the known test-vector value, then print it as hex and as bytes
# (values < 27 get 0x40 OR'ed in for display).
re = s.eval(plaintext_blocks[0], 1, extra_constraints=[cipherblocks[0] == testvec_cbs[0]])[0]
#re = s.eval(plaintext_blocks[0], 1, extra_constraints=[ctrlvar == testvec_cv0])[0]
#re = s.eval(plaintext_blocks[0], 1, extra_constraints=[ctrlvar == testvec_cv0, cipherblocks[0] == testvec_cbs[0]])[0]
print(f"{re:016x}")
print(bytes(map(lambda c: c | 0x40 if c < 27 else c, re.to_bytes(8, byteorder='big'))))

# When a thread id (0..3, from argv) was given, fix bits 2..1 of ctrlvar to it
# so that several processes can split the search space.
if threadid is not None:
    print(f"[t]                constraining ctrlvar to threadid: {threadid}")
    s.add(ctrlvar[2:1] == threadid)
760  
cur = time.time()
delta = datetime.timedelta(seconds=cur - prev)
prev = cur
print(f"[S] {delta} solving for ctrlvar")

# Enumerate every ctrlvar value consistent with the constraints: take one
# model, decrypt each plaintext block under it, print the result, then
# exclude that value and ask the solver again until it becomes unsat.
i = 0
while s.satisfiable():
    cur = time.time()
    delta = datetime.timedelta(seconds=cur - prev)
    prev = cur
    print(f'[!] {delta} #{i} is sat')
    try:
        sol = s.eval(ctrlvar, 1)
    except claripy.errors.ClaripyZ3Error:
        # Report elapsed/total time before propagating the solver failure.
        cur = time.time()
        delta = datetime.timedelta(seconds=cur - prev)
        prev = cur
        print(f"[.] {delta} z3 exception total running time: {str(datetime.timedelta(seconds=cur - start))}")
        raise
    cur = time.time()
    delta = datetime.timedelta(seconds=cur - prev)
    prev = cur
    print(f'[!] {delta} #{i} we have a solution for the ctrlvar: {sol[0]:014x}')
    candidate = sol[0]
    pt = []
    for j, pt_block in enumerate(plaintext_blocks):
        block = s.eval(pt_block, 1, extra_constraints=[ctrlvar == candidate])[0]
        cur = time.time()
        delta = datetime.timedelta(seconds=cur - prev)
        prev = cur
        print(f"{delta} solved block {j} {block:016x}")
        # Values < 27 get 0x40 OR'ed in for display ('@', 'A'..'Z').
        block_len = min(8, len(ciphertext) - 8*j)
        pt.extend(c | 0x40 if c < 27 else c
                  for c in block.to_bytes(block_len, byteorder='big'))
    pt = bytes(pt)

    cur = time.time()
    delta = datetime.timedelta(seconds=cur - prev)
    prev = cur
    print(f"                   plaintext is:", repr(pt))
    s.add(ctrlvar != candidate)
    i += 1

cur = time.time()
delta = datetime.timedelta(seconds=cur - prev)
print(f"[.] {delta} total running time: {str(datetime.timedelta(seconds=cur - start))}")
803  delta = datetime.timedelta(seconds=cur - prev)
804  print(f"[.] {delta} total running time: {str(datetime.timedelta(seconds=cur - start))}")