/ LXMF / LXStamper.py
LXStamper.py
  1  import RNS
  2  import RNS.vendor.umsgpack as msgpack
  3  
  4  import os
  5  import time
  6  import math
  7  import itertools
  8  import multiprocessing
  9  
# Default number of HKDF expansion rounds used by stamp_workblock()
# and generate_stamp() when no explicit round count is given.
WORKBLOCK_EXPAND_ROUNDS         = 3000
# Expansion rounds used by validate_pn_stamp() for propagation (PN) stamps.
WORKBLOCK_EXPAND_ROUNDS_PN      = 1000
# Expansion rounds used by validate_peering_key() for peering keys.
WORKBLOCK_EXPAND_ROUNDS_PEERING = 25
# Number of trailing bytes that validate_pn_stamp() splits off
# the transient data and treats as the stamp.
STAMP_SIZE                      = RNS.Identity.HASHLENGTH//8
# Batches of at most this many entries are validated in-process by
# validate_pn_stamps(); larger batches get one pool process per this
# many entries, capped at the CPU count.
PN_VALIDATION_POOL_MIN_SIZE     = 256

# Stamp generation jobs in progress, keyed by message ID. job_simple() and
# job_android() store a boolean cancellation flag here, while job_linux()
# stores a [stop_event, result_queue] pair. cancel_work() reads the entry
# to signal the matching job.
active_jobs = {}
 17  
 18  def stamp_workblock(material, expand_rounds=WORKBLOCK_EXPAND_ROUNDS):
 19      wb_st = time.time()
 20      workblock = b""
 21      for n in range(expand_rounds):
 22          workblock += RNS.Cryptography.hkdf(length=256,
 23                                             derive_from=material,
 24                                             salt=RNS.Identity.full_hash(material+msgpack.packb(n)),
 25                                             context=None)
 26      wb_time = time.time() - wb_st
 27      # RNS.log(f"Stamp workblock size {RNS.prettysize(len(workblock))}, generated in {round(wb_time*1000,2)}ms", RNS.LOG_DEBUG)
 28  
 29      return workblock
 30  
 31  def stamp_value(workblock, stamp):
 32      value = 0
 33      bits = 256
 34      material = RNS.Identity.full_hash(workblock+stamp)
 35      i = int.from_bytes(material, byteorder="big")
 36      while ((i & (1 << (bits - 1))) == 0):
 37          i = (i << 1)
 38          value += 1
 39   
 40      return value
 41  
 42  def stamp_valid(stamp, target_cost, workblock):
 43      target = 0b1 << 256-target_cost
 44      result = RNS.Identity.full_hash(workblock+stamp)
 45      if int.from_bytes(result, byteorder="big") > target: return False
 46      else: return True
 47  
 48  def validate_peering_key(peering_id, peering_key, target_cost):
 49      workblock = stamp_workblock(peering_id, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PEERING)
 50      if not stamp_valid(peering_key, target_cost, workblock): return False
 51      else: return True
 52  
 53  def validate_pn_stamp(transient_data, target_cost):
 54      from .LXMessage import LXMessage
 55      if len(transient_data) <= LXMessage.LXMF_OVERHEAD+STAMP_SIZE: return None, None, None, None
 56      else:
 57          lxm_data     = transient_data[:-STAMP_SIZE]
 58          stamp        = transient_data[-STAMP_SIZE:]
 59          transient_id = RNS.Identity.full_hash(lxm_data)
 60          workblock    = stamp_workblock(transient_id, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PN)
 61          
 62          if not stamp_valid(stamp, target_cost, workblock): return None, None, None, None
 63          else:
 64              value = stamp_value(workblock, stamp)
 65              return transient_id, lxm_data, value, stamp
 66  
 67  def validate_pn_stamps_job_simple(transient_list, target_cost):
 68      validated_messages = []
 69      for transient_data in transient_list:
 70          transient_id, lxm_data, value, stamp_data = validate_pn_stamp(transient_data, target_cost)
 71          if transient_id: validated_messages.append([transient_id, lxm_data, value, stamp_data])
 72  
 73      return validated_messages
 74  
 75  def validate_pn_stamps_job_multip(transient_list, target_cost):
 76      cores      = multiprocessing.cpu_count()
 77      pool_count = min(cores, math.ceil(len(transient_list) / PN_VALIDATION_POOL_MIN_SIZE))
 78          
 79      RNS.log(f"Validating {len(transient_list)} stamps using {pool_count} processes...", RNS.LOG_VERBOSE)
 80      with multiprocessing.get_context("spawn").Pool(pool_count) as p:
 81          validated_entries = p.starmap(validate_pn_stamp, zip(transient_list, itertools.repeat(target_cost)))
 82      
 83      RNS.log(f"Validation pool completed for {len(transient_list)} stamps", RNS.LOG_VERBOSE)
 84  
 85      return [e for e in validated_entries if e[0] != None]
 86  
 87  def validate_pn_stamps(transient_list, target_cost):
 88      non_mp_platform = RNS.vendor.platformutils.is_android()
 89      if len(transient_list) <= PN_VALIDATION_POOL_MIN_SIZE or non_mp_platform: return validate_pn_stamps_job_simple(transient_list, target_cost)
 90      else:                                                                     return validate_pn_stamps_job_multip(transient_list, target_cost)
 91  
 92  def generate_stamp(message_id, stamp_cost, expand_rounds=WORKBLOCK_EXPAND_ROUNDS):
 93      RNS.log(f"Generating stamp with cost {stamp_cost} for {RNS.prettyhexrep(message_id)}...", RNS.LOG_DEBUG)
 94      workblock = stamp_workblock(message_id, expand_rounds=expand_rounds)
 95      
 96      start_time = time.time()
 97      stamp = None
 98      rounds = 0
 99      value = 0
100  
101      if RNS.vendor.platformutils.is_windows() or RNS.vendor.platformutils.is_darwin(): stamp, rounds = job_simple(stamp_cost, workblock, message_id)
102      elif RNS.vendor.platformutils.is_android(): stamp, rounds = job_android(stamp_cost, workblock, message_id)
103      else: stamp, rounds = job_linux(stamp_cost, workblock, message_id)
104      
105      duration = time.time() - start_time
106      speed = rounds/duration
107      if stamp != None: value = stamp_value(workblock, stamp)
108  
109      RNS.log(f"Stamp with value {value} generated in {RNS.prettytime(duration)}, {rounds} rounds, {int(speed)} rounds per second", RNS.LOG_DEBUG)
110  
111      return stamp, value
112  
def cancel_work(message_id):
    """Ask the stamp generation job for ``message_id`` to stop.

    On Windows, macOS and Android the job polls a boolean flag in
    ``active_jobs``, so the flag is raised. On other platforms the job
    entry holds a stop event and a result queue: the event is set, a
    ``None`` result is queued to release the waiting job, and the entry
    is removed. Nothing happens when no job is registered for the ID.

    Errors raised while signalling are logged and not propagated.
    """
    platformutils = RNS.vendor.platformutils
    flag_based = platformutils.is_windows() or platformutils.is_darwin() or platformutils.is_android()

    try:
        if message_id in active_jobs:
            if flag_based:
                active_jobs[message_id] = True
            else:
                stop_event = active_jobs[message_id][0]
                result_queue = active_jobs[message_id][1]
                stop_event.set()
                result_queue.put(None)
                active_jobs.pop(message_id, None)

    except Exception as e:
        # The message must be an f-string, otherwise a literal "{e}"
        # is logged in place of the actual error.
        RNS.log(f"Error while terminating stamp generation workers: {e}", RNS.LOG_ERROR)
        RNS.trace_exception(e)
144  
def job_simple(stamp_cost, workblock, message_id):
    """Search for a stamp in the calling process.

    A simple, single-process stamp generator. It should work on any
    platform and is used as a fall-back in case of limited
    multiprocessing and/or acceleration support.

    Random 32-byte candidates are tried until one hashes at or below the
    target for ``stamp_cost``, or until the cancellation flag stored
    under ``message_id`` in ``active_jobs`` is raised.

    :returns: ``(stamp, rounds)``. ``stamp`` is ``None`` if the job was
              cancelled, and ``rounds`` counts the discarded candidates.
    """
    platform = RNS.vendor.platformutils.get_platform()
    RNS.log(f"Running stamp generation on {platform}, work limited to single CPU core. This will be slower than ideal.", RNS.LOG_WARNING)

    rounds = 0
    pstamp = os.urandom(256//8)
    st = time.time()

    # The target depends only on the cost, so it is computed once
    # instead of on every candidate.
    target = 1 << (256 - stamp_cost)

    def sv(s):
        result = RNS.Identity.full_hash(workblock+s)
        return int.from_bytes(result, byteorder="big") <= target

    active_jobs[message_id] = False

    try:
        while not sv(pstamp) and not active_jobs[message_id]:
            pstamp = os.urandom(256//8); rounds += 1
            if rounds % 2500 == 0:
                elapsed = time.time()-st
                speed = rounds / elapsed if elapsed > 0 else 0
                RNS.log(f"Stamp generation running. {rounds} rounds completed so far, {int(speed)} rounds per second", RNS.LOG_DEBUG)

        # A raised flag wins even if the last candidate was valid.
        if active_jobs[message_id]:
            pstamp = None

    finally:
        # Always drop the job entry, including when hashing raised.
        active_jobs.pop(message_id, None)

    return pstamp, rounds
178  
def job_linux(stamp_cost, workblock, message_id):
    """Search for a stamp using several forked worker processes.

    One worker is started per CPU core, or per two cores when more than
    12 are available. Each worker tries random 32-byte candidates until
    it finds a valid one or the shared stop event is set. The first
    result put on the result queue is used. While the search runs, the
    ``[stop_event, result_queue]`` pair is registered in ``active_jobs``
    under ``message_id`` so that cancel_work() can interrupt it.

    :returns: ``(stamp, total_rounds)``. ``stamp`` is ``None`` when the
              job was cancelled, and ``total_rounds`` sums the round
              counts reported by the workers.
    """
    allow_kill = True
    stamp = None
    total_rounds = 0
    cores = multiprocessing.cpu_count()
    jobs = cores if cores <= 12 else int(cores/2)

    # Create the primitives from the same context that starts the
    # workers. Objects tied to one start-method context are not
    # guaranteed to be usable from processes of another.
    ctx = multiprocessing.get_context("fork")
    stop_event   = ctx.Event()
    result_queue = ctx.Queue(1)
    rounds_queue = ctx.Queue()

    def job(stop_event, pn, sc, wb):
        rounds = 0
        pstamp = os.urandom(256//8)
        target = 1 << (256 - sc)

        def sv(s):
            result = RNS.Identity.full_hash(wb+s)
            return int.from_bytes(result, byteorder="big") <= target

        while not stop_event.is_set() and not sv(pstamp):
            pstamp = os.urandom(256//8); rounds += 1

        # Only report a stamp if no other worker or a cancellation
        # got there first.
        if not stop_event.is_set():
            stop_event.set()
            result_queue.put(pstamp)
        rounds_queue.put(rounds)

    job_procs = []
    RNS.log(f"Starting {jobs} stamp generation workers", RNS.LOG_DEBUG)
    for jpn in range(jobs):
        process = ctx.Process(target=job, kwargs={"stop_event": stop_event, "pn": jpn, "sc": stamp_cost, "wb": workblock}, daemon=True)
        job_procs.append(process)
        process.start()

    active_jobs[message_id] = [stop_event, result_queue]

    stamp = result_queue.get()

    # The stop event is set by now, whether by a worker or by
    # cancel_work(), so the job no longer needs to be reachable.
    # Removing it keeps finished jobs from piling up in active_jobs.
    active_jobs.pop(message_id, None)

    # Collect any potential spurious
    # results from worker queue.
    try:
        while True: result_queue.get_nowait()
    except Exception: pass

    for j in range(jobs):
        nrounds = 0
        try:
            nrounds = rounds_queue.get(timeout=2)
        except Exception as e:
            RNS.log(f"Failed to get round stats part {j}: {e}", RNS.LOG_ERROR)
        total_rounds += nrounds

    all_exited = False
    exit_timeout = time.time() + 5
    while time.time() < exit_timeout:
        if not any(p.is_alive() for p in job_procs):
            all_exited = True
            break
        time.sleep(0.1)

    if not all_exited:
        RNS.log("Stamp generation IPC timeout, possible worker deadlock. Terminating remaining processes.", RNS.LOG_ERROR)
        if allow_kill:
            for process in job_procs:
                process.kill()
        else:
            # Keep the two-element shape that callers unpack.
            return None, total_rounds

    else:
        for process in job_procs:
            process.join()

    return stamp, total_rounds
259  
def job_android(stamp_cost, workblock, message_id):
    """Search for a stamp using repeated batches of short-lived workers.

    Semaphore support is flaky to non-existent on Android, so workloads
    are dispatched and managed manually here, with progress checked
    between batches. Each batch starts one process per CPU core, and
    every worker tries up to 1000 random candidates and reports
    ``[stamp_or_None, rounds]`` through a manager dict. Batches repeat
    until a stamp is found or the cancellation flag stored under
    ``message_id`` in ``active_jobs`` is raised.

    SHA-256 from PyNaCl is used for hashing when that package can be
    imported, otherwise ``RNS.Identity.full_hash`` is used.

    :returns: ``(stamp, total_rounds)``. ``stamp`` is ``None`` if the
              job was cancelled before a stamp was found.
    """
    stamp = None
    start_time = time.time()
    total_rounds = 0
    rounds_per_worker = 1000

    use_nacl = False
    try:
        import nacl.encoding
        import nacl.hash
        use_nacl = True
    except Exception:
        pass

    if use_nacl:
        def full_hash(m):
            return nacl.hash.sha256(m, encoder=nacl.encoding.RawEncoder)
    else:
        def full_hash(m):
            return RNS.Identity.full_hash(m)

    def sv(s, target, w):
        result = full_hash(w+s)
        return int.from_bytes(result, byteorder="big") <= target

    wm = multiprocessing.Manager()
    jobs = multiprocessing.cpu_count()

    def job(procnum=None, results_dict=None, wb=None, sc=None, jr=None):
        try:
            rounds = 0
            found_stamp = None
            target = 1 << (256 - sc)

            while True:
                pstamp = os.urandom(256//8)
                rounds += 1
                if sv(pstamp, target, wb):
                    found_stamp = pstamp
                    break

                if rounds >= jr:
                    break

            results_dict[procnum] = [found_stamp, rounds]
        except Exception as e:
            RNS.log(f"Stamp generation worker error: {e}", RNS.LOG_ERROR)
            RNS.trace_exception(e)

    active_jobs[message_id] = False

    RNS.log(f"Dispatching {jobs} workers for stamp generation...", RNS.LOG_DEBUG) # TODO: Remove

    try:
        results_dict = wm.dict()
        while stamp is None and not active_jobs[message_id]:
            job_procs = []
            try:
                # Start every batch from an empty result set, so a worker
                # that failed to report is not counted again using its
                # entry from an earlier batch.
                results_dict.clear()

                for pnum in range(jobs):
                    pargs = {"procnum":pnum, "results_dict": results_dict, "wb": workblock, "sc":stamp_cost, "jr":rounds_per_worker}
                    process = multiprocessing.Process(target=job, kwargs=pargs)
                    job_procs.append(process)
                    process.start()

                for process in job_procs:
                    process.join()

                for j in results_dict:
                    r = results_dict[j]
                    total_rounds += r[1]
                    if r[0] is not None:
                        stamp = r[0]

                if stamp is None:
                    elapsed = time.time() - start_time
                    speed = total_rounds/elapsed if elapsed > 0 else 0
                    RNS.log(f"Stamp generation running. {total_rounds} rounds completed so far, {int(speed)} rounds per second", RNS.LOG_DEBUG)

            except Exception as e:
                RNS.log(f"Stamp generation job error: {e}")
                RNS.trace_exception(e)

    finally:
        # Drop the job entry and stop the manager process on every exit
        # path instead of leaving the latter to garbage collection.
        active_jobs.pop(message_id, None)
        try:
            wm.shutdown()
        except Exception as e:
            RNS.log(f"Could not shut down stamp generation manager: {e}", RNS.LOG_ERROR)

    return stamp, total_rounds
355  
356  # def stamp_value_linear(workblock, stamp):
357  #     value = 0
358  #     bits = 256
359  #     material = RNS.Identity.full_hash(workblock+stamp)
360  #     s = int.from_bytes(material, byteorder="big")
361  #     return s.bit_count()
362  
# Manual test entry point: takes the stamp cost as the first command-line
# argument, generates one stamp for each workblock size, and then times
# validation of a batch of random transient data.
if __name__ == "__main__":
    import sys

    # sys.exit() is used rather than the exit() helper, which is injected
    # by the site module and is not present in every interpreter setup.
    if len(sys.argv) < 2:
        RNS.log("No cost argument provided", RNS.LOG_ERROR)
        sys.exit(1)

    try:
        cost = int(sys.argv[1])
    except ValueError as e:
        RNS.log(f"Invalid cost argument provided: {e}", RNS.LOG_ERROR)
        sys.exit(1)

    RNS.loglevel = RNS.LOG_DEBUG
    RNS.log("Testing LXMF stamp generation", RNS.LOG_DEBUG)
    message_id = os.urandom(32)
    generate_stamp(message_id, cost)

    RNS.log("", RNS.LOG_DEBUG)
    RNS.log("Testing propagation stamp generation", RNS.LOG_DEBUG)
    message_id = os.urandom(32)
    generate_stamp(message_id, cost, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PN)

    RNS.log("", RNS.LOG_DEBUG)
    RNS.log("Testing peering key generation", RNS.LOG_DEBUG)
    message_id = os.urandom(32)
    generate_stamp(message_id, cost, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PEERING)

    # The timing covers building the random batch as well as validating it.
    st = time.time(); count = 10000
    transient_list = [os.urandom(256) for _ in range(count)]
    validate_pn_stamps(transient_list, 5)
    dt = time.time()-st; mps = count/dt
    RNS.log(f"Validated {count} PN stamps in {RNS.prettytime(dt)}, {round(mps,1)} m/s", RNS.LOG_DEBUG)