LXStamper.py
import RNS
import RNS.vendor.umsgpack as msgpack

import os
import time
import math
import queue
import itertools
import multiprocessing

# Number of HKDF expansion rounds used to build a workblock for regular
# message stamps, propagation-node (PN) stamps and peering keys respectively.
# More rounds give a larger workblock, and every validity check hashes the
# whole workblock, so each check becomes more expensive.
WORKBLOCK_EXPAND_ROUNDS = 3000
WORKBLOCK_EXPAND_ROUNDS_PN = 1000
WORKBLOCK_EXPAND_ROUNDS_PEERING = 25
# Stamp length in bytes (RNS.Identity.HASHLENGTH is assumed to be in bits).
STAMP_SIZE = RNS.Identity.HASHLENGTH//8
# PN stamp batches up to this size are validated in-process; larger batches
# use a process pool with roughly one process per this many stamps (capped
# at the CPU count).
PN_VALIDATION_POOL_MIN_SIZE = 256

# In-progress stamp generation jobs keyed by message_id, used by cancel_work().
# Values are a boolean cancel flag for job_simple()/job_android(), and a
# [stop_event, result_queue] pair for job_linux().
active_jobs = {}


def stamp_workblock(material, expand_rounds=WORKBLOCK_EXPAND_ROUNDS):
    """Derive the workblock that stamps for ``material`` are computed against.

    The workblock is the concatenation of ``expand_rounds`` HKDF outputs
    (``length=256``) derived from ``material``. Round ``n`` is salted with
    ``full_hash(material + msgpack.packb(n))`` so that every round differs.

    :param material: Bytes to derive from (a message id, transient id or
        peering id in this module).
    :param expand_rounds: Number of HKDF rounds to concatenate.
    :returns: The workblock as bytes.
    """
    # Collect all chunks and join once: accumulating with ``bytes +=`` copies
    # the whole buffer on every round, which is quadratic in expand_rounds.
    return b"".join(RNS.Cryptography.hkdf(length=256,
                                          derive_from=material,
                                          salt=RNS.Identity.full_hash(material+msgpack.packb(n)),
                                          context=None)
                    for n in range(expand_rounds))


def stamp_value(workblock, stamp):
    """Return the value of ``stamp``.

    The value is the number of leading zero bits in the 256-bit hash of
    ``workblock + stamp``.
    """
    material = RNS.Identity.full_hash(workblock+stamp)
    i = int.from_bytes(material, byteorder="big")
    # For a 256-bit integer the leading-zero count is 256 - bit_length().
    # Unlike a shift-until-the-top-bit-is-set loop, this also terminates
    # (returning 256) for an all-zero hash.
    return 256 - i.bit_length()


def stamp_valid(stamp, target_cost, workblock):
    """Return True if ``stamp`` satisfies ``target_cost`` for ``workblock``.

    The stamp is valid when the big-endian integer value of
    ``full_hash(workblock + stamp)`` is at most ``2**(256 - target_cost)``.
    Apart from that single inclusive boundary value, this means the hash has
    at least ``target_cost`` leading zero bits.
    """
    target = 0b1 << 256-target_cost
    result = RNS.Identity.full_hash(workblock+stamp)
    return int.from_bytes(result, byteorder="big") <= target


def validate_peering_key(peering_id, peering_key, target_cost):
    """Return True if ``peering_key`` is a valid stamp of ``target_cost`` for ``peering_id``.

    Peering keys are checked against the small
    WORKBLOCK_EXPAND_ROUNDS_PEERING workblock.
    """
    workblock = stamp_workblock(peering_id, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PEERING)
    return stamp_valid(peering_key, target_cost, workblock)


def validate_pn_stamp(transient_data, target_cost):
    """Split PN transient data into message data plus stamp, and validate the stamp.

    ``transient_data`` is treated as message data followed by a
    STAMP_SIZE-byte stamp.

    :returns: ``(transient_id, lxm_data, value, stamp)`` when the stamp is
        valid. Here ``transient_id`` is ``full_hash(lxm_data)`` and ``value``
        is the stamp value. Returns ``(None, None, None, None)`` when the data
        is not longer than ``LXMF_OVERHEAD + STAMP_SIZE`` or when the stamp
        does not meet ``target_cost``.
    """
    # Local import; presumably avoids a circular import with LXMessage (TODO confirm).
    from .LXMessage import LXMessage
    if len(transient_data) <= LXMessage.LXMF_OVERHEAD+STAMP_SIZE:
        return None, None, None, None

    lxm_data = transient_data[:-STAMP_SIZE]
    stamp = transient_data[-STAMP_SIZE:]
    transient_id = RNS.Identity.full_hash(lxm_data)
    workblock = stamp_workblock(transient_id, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PN)

    if not stamp_valid(stamp, target_cost, workblock):
        return None, None, None, None

    value = stamp_value(workblock, stamp)
    return transient_id, lxm_data, value, stamp


def validate_pn_stamps_job_simple(transient_list, target_cost):
    """Validate PN stamps sequentially in the calling process.

    :returns: A list of ``[transient_id, lxm_data, value, stamp]`` for every
        entry with a valid stamp. Invalid entries are dropped.
    """
    validated_messages = []
    for transient_data in transient_list:
        transient_id, lxm_data, value, stamp_data = validate_pn_stamp(transient_data, target_cost)
        if transient_id:
            validated_messages.append([transient_id, lxm_data, value, stamp_data])

    return validated_messages


def validate_pn_stamps_job_multip(transient_list, target_cost):
    """Validate PN stamps in a "spawn"-context process pool.

    :returns: Same shape as validate_pn_stamps_job_simple(): a list of
        ``[transient_id, lxm_data, value, stamp]`` for valid entries.
    """
    cores = multiprocessing.cpu_count()
    # At least one process, so an empty batch does not hit Pool(0).
    pool_count = max(1, min(cores, math.ceil(len(transient_list) / PN_VALIDATION_POOL_MIN_SIZE)))

    RNS.log(f"Validating {len(transient_list)} stamps using {pool_count} processes...", RNS.LOG_VERBOSE)
    with multiprocessing.get_context("spawn").Pool(pool_count) as p:
        validated_entries = p.starmap(validate_pn_stamp, zip(transient_list, itertools.repeat(target_cost)))

    RNS.log(f"Validation pool completed for {len(transient_list)} stamps", RNS.LOG_VERBOSE)

    # Workers return tuples; convert to lists to match the in-process path.
    return [list(e) for e in validated_entries if e[0] is not None]


def validate_pn_stamps(transient_list, target_cost):
    """Validate a batch of PN transient data entries, dropping invalid ones.

    Small batches (up to PN_VALIDATION_POOL_MIN_SIZE), and all batches on
    Android, are validated in-process. Larger batches use a process pool.
    """
    non_mp_platform = RNS.vendor.platformutils.is_android()
    if len(transient_list) <= PN_VALIDATION_POOL_MIN_SIZE or non_mp_platform:
        return validate_pn_stamps_job_simple(transient_list, target_cost)
    else:
        return validate_pn_stamps_job_multip(transient_list, target_cost)


def generate_stamp(message_id, stamp_cost, expand_rounds=WORKBLOCK_EXPAND_ROUNDS):
    """Search for a stamp of at least ``stamp_cost`` for ``message_id``.

    The search strategy depends on the platform: job_simple() on Windows and
    macOS, job_android() on Android, and job_linux() elsewhere.

    :returns: ``(stamp, value)``. ``stamp`` is None (and ``value`` 0) if the
        search was cancelled via cancel_work().
    """
    RNS.log(f"Generating stamp with cost {stamp_cost} for {RNS.prettyhexrep(message_id)}...", RNS.LOG_DEBUG)
    workblock = stamp_workblock(message_id, expand_rounds=expand_rounds)

    start_time = time.time()
    value = 0

    if RNS.vendor.platformutils.is_windows() or RNS.vendor.platformutils.is_darwin():
        stamp, rounds = job_simple(stamp_cost, workblock, message_id)
    elif RNS.vendor.platformutils.is_android():
        stamp, rounds = job_android(stamp_cost, workblock, message_id)
    else:
        stamp, rounds = job_linux(stamp_cost, workblock, message_id)

    duration = time.time() - start_time
    # A cheap stamp can be found within one tick of a coarse clock, so the
    # measured duration may be exactly 0.
    speed = rounds/duration if duration > 0 else 0
    if stamp is not None:
        value = stamp_value(workblock, stamp)

    RNS.log(f"Stamp with value {value} generated in {RNS.prettytime(duration)}, {rounds} rounds, {int(speed)} rounds per second", RNS.LOG_DEBUG)

    return stamp, value


def cancel_work(message_id):
    """Request cancellation of an in-progress stamp search for ``message_id``.

    Does nothing if no job is registered for ``message_id``.

    - On the job_simple()/job_android() platforms, this sets the job's
      cancel flag, which the job polls.
    - On the job_linux() path, it sets the workers' stop event and puts None
      on the result queue to unblock the waiting parent.

    Errors are logged, not raised.
    """
    if (RNS.vendor.platformutils.is_windows() or RNS.vendor.platformutils.is_darwin()
            or RNS.vendor.platformutils.is_android()):
        try:
            if message_id in active_jobs:
                active_jobs[message_id] = True

        except Exception as e:
            RNS.log(f"Error while terminating stamp generation workers: {e}", RNS.LOG_ERROR)
            RNS.trace_exception(e)

    else:
        try:
            # Pop in a single step, so that job_linux() removing its own entry
            # concurrently cannot make a separate lookup/pop fail.
            job = active_jobs.pop(message_id, None)
            if job is not None:
                stop_event, result_queue = job
                stop_event.set()
                result_queue.put(None)

        except Exception as e:
            RNS.log(f"Error while terminating stamp generation workers: {e}", RNS.LOG_ERROR)
            RNS.trace_exception(e)


def job_simple(stamp_cost, workblock, message_id):
    """Search for a stamp in the calling process, on a single CPU core.

    This is a simple, single-process stamp generator. It should work on any
    platform and serves as a fallback where multiprocessing and/or
    acceleration support is limited.

    :returns: ``(stamp, rounds)``. ``stamp`` is None if the search was
        cancelled via cancel_work().
    """
    platform = RNS.vendor.platformutils.get_platform()
    RNS.log(f"Running stamp generation on {platform}, work limited to single CPU core. This will be slower than ideal.", RNS.LOG_WARNING)

    rounds = 0
    pstamp = os.urandom(256//8)
    st = time.time()
    # Same acceptance rule as stamp_valid(), with the target hoisted out of the loop.
    target = 0b1<<256-stamp_cost

    def sv(s, w):
        return int.from_bytes(RNS.Identity.full_hash(w+s), byteorder="big") <= target

    active_jobs[message_id] = False
    try:
        # Poll the cancel flag set by cancel_work() between attempts.
        while not sv(pstamp, workblock) and not active_jobs[message_id]:
            pstamp = os.urandom(256//8)
            rounds += 1
            if rounds % 2500 == 0:
                speed = rounds / (time.time()-st)
                RNS.log(f"Stamp generation running. {rounds} rounds completed so far, {int(speed)} rounds per second", RNS.LOG_DEBUG)

        if active_jobs[message_id]:
            pstamp = None

    finally:
        active_jobs.pop(message_id, None)

    return pstamp, rounds


def job_linux(stamp_cost, workblock, message_id):
    """Search for a stamp using several forked worker processes.

    Uses every core on machines with up to 12 cores, and half of them on
    larger machines.

    :returns: ``(stamp, total_rounds)``. ``stamp`` is None if the search was
        cancelled via cancel_work().
    """
    allow_kill = True
    stamp = None
    total_rounds = 0
    cores = multiprocessing.cpu_count()
    jobs = cores if cores <= 12 else int(cores/2)
    # The worker target is a nested function, which spawn/forkserver would
    # have to pickle, so workers must be forked. The IPC primitives are taken
    # from the same context.
    ctx = multiprocessing.get_context("fork")
    stop_event = ctx.Event()
    result_queue = ctx.Queue(1)
    rounds_queue = ctx.Queue()

    def job(stop_event, pn, sc, wb):
        rounds = 0
        pstamp = os.urandom(256//8)
        target = 0b1<<256-sc

        def sv(s, w):
            return int.from_bytes(RNS.Identity.full_hash(w+s), byteorder="big") <= target

        while not stop_event.is_set() and not sv(pstamp, wb):
            pstamp = os.urandom(256//8)
            rounds += 1

        # The first finisher stops the others and publishes its stamp. Two
        # workers can both pass this check, hence the queue drain in the parent.
        if not stop_event.is_set():
            stop_event.set()
            result_queue.put(pstamp)
        rounds_queue.put(rounds)

    job_procs = []
    RNS.log(f"Starting {jobs} stamp generation workers", RNS.LOG_DEBUG)
    for jpn in range(jobs):
        process = ctx.Process(target=job, kwargs={"stop_event": stop_event, "pn": jpn, "sc": stamp_cost, "wb": workblock}, daemon=True)
        job_procs.append(process)
        process.start()

    active_jobs[message_id] = [stop_event, result_queue]

    try:
        # Blocks until a worker publishes a stamp, or cancel_work() puts None.
        stamp = result_queue.get()

        # Collect any potential spurious
        # results from worker queue.
        try:
            while True:
                result_queue.get_nowait()
        except queue.Empty:
            pass

        for j in range(jobs):
            nrounds = 0
            try:
                nrounds = rounds_queue.get(timeout=2)
            except Exception as e:
                RNS.log(f"Failed to get round stats part {j}: {e}", RNS.LOG_ERROR)
            total_rounds += nrounds

        all_exited = False
        exit_timeout = time.time() + 5
        while time.time() < exit_timeout:
            if not any(p.is_alive() for p in job_procs):
                all_exited = True
                break
            time.sleep(0.1)

        if not all_exited:
            RNS.log("Stamp generation IPC timeout, possible worker deadlock. Terminating remaining processes.", RNS.LOG_ERROR)
            if allow_kill:
                for process in job_procs:
                    process.kill()
            else:
                # Keep the (stamp, rounds) shape that generate_stamp() unpacks.
                return None, total_rounds

        else:
            for process in job_procs:
                process.join()

    finally:
        active_jobs.pop(message_id, None)

    return stamp, total_rounds


def job_android(stamp_cost, workblock, message_id):
    """Search for a stamp on Android using repeated batches of worker processes.

    Each round starts one process per CPU, each trying up to
    ``rounds_per_worker`` random stamps. Rounds repeat until a stamp is found
    or cancel_work() sets the job's cancel flag.

    :returns: ``(stamp, total_rounds)``. ``stamp`` is None if cancelled.
    """
    # Semaphore support is flaky to non-existent on
    # Android, so we need to manually dispatch and
    # manage workloads here, while periodically
    # checking in on the progress.

    stamp = None
    start_time = time.time()
    total_rounds = 0
    rounds_per_worker = 1000

    # Prefer PyNaCl's SHA-256 when it is importable; otherwise fall back to
    # RNS.Identity.full_hash.
    use_nacl = False
    try:
        import nacl.encoding
        import nacl.hash
        use_nacl = True
    except Exception:
        pass

    if use_nacl:
        def full_hash(m):
            return nacl.hash.sha256(m, encoder=nacl.encoding.RawEncoder)
    else:
        def full_hash(m):
            return RNS.Identity.full_hash(m)

    def sv(s, c, w):
        target = 0b1<<256-c
        m = w+s
        result = full_hash(m)
        return int.from_bytes(result, byteorder="big") <= target

    jobs = multiprocessing.cpu_count()

    def job(procnum=None, results_dict=None, wb=None, sc=None, jr=None):
        try:
            rounds = 0
            found_stamp = None

            while True:
                pstamp = os.urandom(256//8)
                rounds += 1
                if sv(pstamp, sc, wb):
                    found_stamp = pstamp
                    break

                if rounds >= jr:
                    break

            results_dict[procnum] = [found_stamp, rounds]
        except Exception as e:
            RNS.log(f"Stamp generation worker error: {e}", RNS.LOG_ERROR)
            RNS.trace_exception(e)

    active_jobs[message_id] = False

    RNS.log(f"Dispatching {jobs} workers for stamp generation...", RNS.LOG_DEBUG)

    try:
        # The context manager shuts the Manager's server process down when done.
        with multiprocessing.Manager() as wm:
            results_dict = wm.dict()
            while stamp is None and not active_jobs[message_id]:
                job_procs = []
                try:
                    # Drop last round's entries so a worker that failed this
                    # round is not counted twice.
                    results_dict.clear()
                    for pnum in range(jobs):
                        pargs = {"procnum":pnum, "results_dict": results_dict, "wb": workblock, "sc":stamp_cost, "jr":rounds_per_worker}
                        process = multiprocessing.Process(target=job, kwargs=pargs)
                        job_procs.append(process)
                        process.start()

                    for process in job_procs:
                        process.join()

                    for j in results_dict:
                        r = results_dict[j]
                        total_rounds += r[1]
                        if r[0] is not None:
                            stamp = r[0]

                    if stamp is None:
                        elapsed = time.time() - start_time
                        speed = total_rounds/elapsed
                        RNS.log(f"Stamp generation running. {total_rounds} rounds completed so far, {int(speed)} rounds per second", RNS.LOG_DEBUG)

                except Exception as e:
                    RNS.log(f"Stamp generation job error: {e}", RNS.LOG_ERROR)
                    RNS.trace_exception(e)

    finally:
        active_jobs.pop(message_id, None)

    return stamp, total_rounds


if __name__ == "__main__":
    import sys
    if len(sys.argv) < 2:
        RNS.log("No cost argument provided", RNS.LOG_ERROR)
        sys.exit(1)

    try:
        cost = int(sys.argv[1])
    except ValueError as e:
        RNS.log(f"Invalid cost argument provided: {e}", RNS.LOG_ERROR)
        sys.exit(1)

    RNS.loglevel = RNS.LOG_DEBUG
    RNS.log("Testing LXMF stamp generation", RNS.LOG_DEBUG)
    message_id = os.urandom(32)
    generate_stamp(message_id, cost)

    RNS.log("", RNS.LOG_DEBUG)
    RNS.log("Testing propagation stamp generation", RNS.LOG_DEBUG)
    message_id = os.urandom(32)
    generate_stamp(message_id, cost, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PN)

    RNS.log("", RNS.LOG_DEBUG)
    RNS.log("Testing peering key generation", RNS.LOG_DEBUG)
    message_id = os.urandom(32)
    generate_stamp(message_id, cost, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PEERING)

    transient_list = []
    st = time.time(); count = 10000
    for i in range(count):
        transient_list.append(os.urandom(256))
    validate_pn_stamps(transient_list, 5)
    dt = time.time()-st; mps = count/dt
    RNS.log(f"Validated {count} PN stamps in {RNS.prettytime(dt)}, {round(mps,1)} m/s", RNS.LOG_DEBUG)