#!/usr/bin/env python
"""multiplexer.py -- talk to a group of peers as one unit.

A :class:`Multiplexer` holds one connection object per configured peer and
offers ``send``/``broadcast`` for writing and ``gather`` for collecting one
reply from every peer.  Three transports are available:

* :class:`Peer`    -- plain TCP or TLS over TCP,
* :class:`BLEPeer` -- Bluetooth LE, wrapped in a Noise XK session,
* :class:`USBPeer` -- USB CDC serial, wrapped in a Noise XK session.
"""

import asyncio
import select
import socket
import ssl
import struct
import sys
import time
from itertools import zip_longest

import serial
from pyoprf import noisexk
from serial_asyncio import create_serial_connection

try:
    from ble_serial.bluetooth.ble_client import BLE_client
except ImportError:
    BLE_client = None
try:
    import pyudev
except ImportError:
    pyudev = None


# Event loop reused by all synchronous wrappers in this module; see get_event_loop().
_sync_loop = None


def split_by_n(iterable, n):
    """Split *iterable* into consecutive tuples of length *n*.

    The last tuple is padded with ``None`` when the length of *iterable* is
    not a multiple of *n*.
    """
    return list(zip_longest(*[iter(iterable)] * n))


def get_event_loop():
    """Return the event loop that synchronous wrappers should drive.

    Inside a coroutine this is the running loop.  Otherwise one module-level
    loop is created on first use, installed as the current loop and returned
    on every later call.  Reusing a single loop matters: serial transports and
    BLE clients stay attached to the loop they were set up on, so replies are
    only received while that same loop runs.
    """
    global _sync_loop
    try:
        return asyncio.get_running_loop()
    except RuntimeError:
        pass  # not called from within a coroutine
    if _sync_loop is None or _sync_loop.is_closed():
        _sync_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(_sync_loop)
    return _sync_loop


class Peer:
    """A peer reached over TCP, optionally protected with TLS."""

    def __init__(self, name, addr, type="SSL", ssl_cert=None, timeout=5, alpn_proto=None):
        self.name = name
        # Currently only "TCP" or "SSL" (TLS over TCP).  Other transports could
        # be added, e.g. a dedicated Noise XK channel, or hybrid
        # McEliece+X25519 over USB or even UART.
        self.type = type
        self.address = addr  # currently only a (host, port) tuple for TCP
        self.ssl_cert = ssl_cert
        self.timeout = timeout  # socket timeout in seconds; None means blocking
        self.alpn_proto = alpn_proto or ["oprf/1"]
        self.state = "new"  # "new" -> "connected" -> "disconnected" / "closed"
        self.fd = None

    def _ssl_context(self):
        """Build the TLS client context used by :meth:`connect`."""
        ctx = ssl.create_default_context()
        ctx.minimum_version = ssl.TLSVersion.TLSv1_2
        ctx.set_alpn_protocols(self.alpn_proto)
        if self.ssl_cert:
            # Only for development; a production system should use proper certs!
            # Verification is switched off entirely below, so the loaded
            # certificate is not actually checked.
            ctx.load_verify_locations(self.ssl_cert)
            ctx.check_hostname = False  # must be cleared before CERT_NONE is accepted
            ctx.verify_mode = ssl.CERT_NONE
        else:
            ctx.load_default_certs()
            ctx.verify_mode = ssl.CERT_REQUIRED
            ctx.check_hostname = True
        return ctx

    def connect(self):
        """Open the TCP (and, for type "SSL", TLS) connection.

        Raises:
            ValueError: if already connected or the peer type is unsupported.

        A network or TLS failure is reported on stderr and leaves the peer
        unconnected instead of raising, so other peers can still be used.
        """
        if self.state == "connected":
            raise ValueError(f"{self.name} is already connected")

        if self.type not in {"SSL", "TCP"}:
            raise ValueError(f"Unsupported peer type: {self.type}")

        # Build the TLS context first so a bad ssl_cert cannot leak a socket.
        ctx = self._ssl_context() if self.type == "SSL" else None

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.settimeout(self.timeout)
        if ctx is not None:
            sock = ctx.wrap_socket(sock, server_hostname=self.address[0])
        self.fd = sock
        try:
            self.fd.connect(self.address)
        except OSError as e:  # covers refused connections, timeouts and TLS errors
            print(f"{self.name}: cannot connect to {self.address}: {e}", file=sys.stderr)
            self.fd.close()
            self.fd = None
            return
        self.state = "connected"

    def connected(self):
        return self.state == "connected"

    async def read_async(self, size):
        """Read exactly *size* bytes from the peer.

        Returns None if the peer is not connected.  If the remote end closes
        the connection first, the bytes received so far are returned and the
        peer is marked "disconnected".

        NOTE: ``recv`` blocks (up to ``self.timeout``) without yielding to the
        event loop, so reads from several TCP peers run one after another.
        """
        if not self.connected():
            return None

        buf = bytearray()
        while len(buf) < size:
            chunk = self.fd.recv(size - len(buf))
            if not chunk:  # remote end closed the connection
                self.state = "disconnected"
                break
            buf += chunk
        return bytes(buf)

    def read(self, *args, **kwargs):
        """Synchronous wrapper around :meth:`read_async`."""
        return get_event_loop().run_until_complete(self.read_async(*args, **kwargs))

    def send(self, msg):
        """Send *msg* in full; silently does nothing if not connected."""
        if not self.connected():
            return
        self.fd.sendall(msg)

    def close(self):
        """Shut down and close the socket; a no-op if never connected or already closed."""
        if self.state not in ("connected", "disconnected"):
            return
        if self.fd is not None and self.fd.fileno() != -1:
            try:
                self.fd.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass  # the other end may already be gone
            self.fd.close()
        self.state = "closed"


class BLEPeer:
    """A peer reached over Bluetooth LE through a Noise XK encrypted session."""

    def __init__(self, name, addr, server_pk, client_sk, device="hci0", timeout=5):
        if BLE_client is None:
            raise ImportError("BLE peers require the ble_serial package")
        self.name = name
        self.address = addr  # the MAC address of the device
        self.server_pk = server_pk
        self.client_sk = client_sk
        self.timeout = timeout
        self.state = "new"
        self.ble = BLE_client(device, None)
        # Receive buffer, filled by receive_callback and drained by read_raw.
        self.rx_buffer = []
        self.rx_len = 0     # total number of bytes currently in rx_buffer
        self.rx_pected = 0  # bytes the pending read_raw waits for; 0 = nobody waiting
        self.rx_available = asyncio.Event()

    def receive_callback(self, value: bytes):
        """Buffer incoming data and wake the pending reader once enough has arrived."""
        self.rx_buffer.append(value)
        self.rx_len += len(value)
        if self.rx_pected > 0 and self.rx_len >= self.rx_pected:
            self.rx_available.set()

    async def read_raw(self, size):
        """Return at least *size* received bytes, waiting for them if necessary.

        Whole received chunks are returned, so the result can be longer than
        *size* when a chunk straddles the boundary.
        """
        # Wait until a previous reader has consumed its data and cleared the event.
        while self.rx_available.is_set():
            await asyncio.sleep(0.001)
        self.rx_pected = size
        if self.rx_len < self.rx_pected:
            await self.rx_available.wait()
        rsize = 0
        ret = []
        while rsize < self.rx_pected:
            ret.append(self.rx_buffer.pop(0))
            rsize += len(ret[-1])
        self.rx_pected = 0
        self.rx_len -= rsize
        self.rx_available.clear()
        return b''.join(ret)

    async def _send(self, msg, mtu=20):
        """Write *msg* to the GATT write characteristic in fragments of at most *mtu* bytes."""
        for frag in split_by_n(msg, mtu):
            frag = bytes(c for c in frag if c is not None)  # drop zip_longest padding
            await self.ble.dev.write_gatt_char(self.ble.write_char, frag, self.ble.write_response_required)

    async def _connect(self):
        """Connect over BLE and run the Noise XK handshake as initiator."""
        if self.state == "connected":
            raise ValueError(f"{self.name} is already connected")

        self.ble.set_receiver(self.receive_callback)

        await self.ble.connect(self.address, "public", None, 10.0)
        await self.ble.setup_chars(None, None, "rw", False)

        self.session, msg = noisexk.initiator_session(self.client_sk, self.server_pk, dst=b"klutshnik ble tle")
        await self._send(msg)
        resp = await self.read_raw(48)  # the responder's 48-byte handshake reply
        noisexk.initiator_session_complete(self.session, resp)
        ct = noisexk.send_msg(self.session, "")
        await self._send(ct)

        self.state = "connected"

    async def _disconnect(self):
        await self.ble.disconnect()
        self.state = "disconnected"

    def connect(self):
        """Synchronously connect and establish the encrypted session."""
        get_event_loop().run_until_complete(self._connect())

    def connected(self):
        return self.state == "connected"

    def read(self, size):
        """Synchronous wrapper around :meth:`read_async`."""
        return get_event_loop().run_until_complete(self.read_async(size))

    async def read_async(self, size):
        """Read and decrypt a *size*-byte message; None if not connected.

        The ciphertext is read as *size* + 16 bytes (16 bytes of overhead per
        encrypted message).
        """
        if not self.connected():
            return None
        resp = await self.read_raw(size + 16)
        return noisexk.read_msg(self.session, resp)

    def send(self, msg):
        """Encrypt *msg* and send it prefixed with its 2-byte big-endian ciphertext length."""
        if not self.connected():
            return
        ct = noisexk.send_msg(self.session, msg)
        header = struct.pack(">H", len(ct))
        get_event_loop().run_until_complete(self._send(header + ct))

    def close(self):
        """Disconnect the BLE link; a no-op if not connected."""
        if self.state == "closed":
            return
        if not self.connected():
            return
        get_event_loop().run_until_complete(self._disconnect())
        self.state = "closed"


class Serial(asyncio.Protocol):
    """asyncio protocol buffering data read from a USB CDC serial port."""

    def __init__(self, *args, **kwargs):
        self.rx_buffer = []
        self.rx_len = 0     # total number of bytes currently in rx_buffer
        self.rx_pected = 0  # bytes the pending read_raw waits for; 0 = nobody waiting
        self.rx_available = asyncio.Event()
        super().__init__(*args, **kwargs)

    def connection_made(self, transport):
        self.transport = transport
        transport.serial.dtr = True  # USBPeer.close() drops DTR again

    def data_received(self, data):
        self.rx_buffer.append(data)
        self.rx_len += len(data)
        if self.rx_pected > 0 and self.rx_len >= self.rx_pected:
            self.rx_available.set()

    def connection_lost(self, exc):
        print('port closed', file=sys.stderr)
        # Release a pending read_raw so it returns what it has instead of polling forever.
        self.rx_available.set()

    async def read_raw(self, size):
        """Return up to (typically at least) *size* received bytes.

        Polls until *size* bytes are buffered or the port was closed.  Whole
        received chunks are returned, so the result may be longer than *size*;
        this happens e.g. when an encrypted "OK" (18 bytes) is expected but the
        device sends an encrypted b"\\x00\\x04fail" (22 bytes) instead.
        """
        self.rx_pected = size
        while self.rx_len < self.rx_pected and not self.rx_available.is_set():
            await asyncio.sleep(0.001)
        rsize = 0
        ret = []
        while rsize < self.rx_pected:
            if not self.rx_buffer:
                break
            ret.append(self.rx_buffer.pop(0))
            rsize += len(ret[-1])
        self.rx_pected = 0
        self.rx_len -= rsize
        self.rx_available.clear()
        return b''.join(ret)


class USBPeer:
    """A peer reached over USB CDC serial through a Noise XK encrypted session."""

    def __init__(self, name, serno, server_pk, client_sk, timeout=5):
        self.name = name
        self.serno = serno  # the serial number of the usb device
        self.server_pk = server_pk
        self.client_sk = client_sk
        self.timeout = timeout
        self.state = "new"

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        if name == "address":
            # self.path is only set by connect(); avoid recursing into __getattr__.
            return f"usb-cdc device #{self.serno} at {self.__dict__.get('path')}"
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")

    def find_usb_port(self):
        """Return the device node of the second tty whose ID_SERIAL_SHORT equals self.serno.

        Returns None if there is no such device.
        """
        if pyudev is None:
            raise ImportError("USB peers require the pyudev package")
        context = pyudev.Context()
        seen = 0
        for device in context.list_devices(subsystem='tty'):
            if device.get('ID_SERIAL_SHORT') != self.serno:
                continue
            # NOTE(review): the device presumably exposes two tty interfaces and
            # the second one carries this protocol -- confirm against the firmware.
            if seen == 1:
                return device.device_node
            seen += 1
        return None

    async def _connect(self):
        """Run the Noise XK handshake as initiator over the opened serial port."""
        if self.state == "connected":
            raise ValueError(f"{self.name} is already connected")

        self.session, msg = noisexk.initiator_session(self.client_sk, self.server_pk, dst=b"klutshnik ble tle")
        self.transport.serial.write(msg)
        resp = await self.protocol.read_raw(48)  # the responder's 48-byte handshake reply
        noisexk.initiator_session_complete(self.session, resp)
        ct = noisexk.send_msg(self.session, "")
        self.transport.serial.write(ct)

        self.state = "connected"

    def connect(self):
        """Open the serial port of the device and establish the encrypted session.

        Raises:
            ValueError: if already connected.
            serial.SerialException: if no matching USB device is found.
        """
        if self.state == "connected":
            # Checked here too, so a second serial connection is never opened.
            raise ValueError(f"{self.name} is already connected")
        self.path = self.find_usb_port()
        if self.path is None:
            raise serial.SerialException(f"{self.name}: no usb device with serial number {self.serno} found")
        loop = get_event_loop()
        coro = create_serial_connection(loop, Serial, self.path, baudrate=115200)
        self.transport, self.protocol = loop.run_until_complete(coro)
        loop.run_until_complete(self._connect())

    def connected(self):
        return self.state == "connected"

    async def read_async(self, size):
        """Read and decrypt a *size*-byte message; None if not connected.

        Raises:
            ValueError: if fewer than *size* + 16 ciphertext bytes arrive (for
                example because the port was closed); the peer is then marked
                "disconnected".
        """
        if not self.connected():
            return None
        ct = await self.protocol.read_raw(size + 16)
        if len(ct) < size + 16:
            self.state = 'disconnected'
            raise ValueError(f"short read for {self.name}, only {len(ct)} instead of expected {size + 16} bytes")
        return noisexk.read_msg(self.session, ct)

    def read(self, *args, **kwargs):
        """Synchronous wrapper around :meth:`read_async`."""
        return get_event_loop().run_until_complete(self.read_async(*args, **kwargs))

    def send(self, msg):
        """Encrypt *msg* and send it prefixed with its 2-byte big-endian ciphertext length."""
        if not self.connected():
            return
        ct = noisexk.send_msg(self.session, msg)
        header = struct.pack(">H", len(ct))
        self.transport.serial.write(header + ct)

    def close(self):
        """Drop DTR and close the serial transport; a no-op if never connected or already closed."""
        if self.state not in ("connected", "disconnected"):
            return
        if self.transport.serial is not None:
            self.transport.serial.dtr = False
        self.transport.close()
        self.state = "closed"


class Multiplexer:
    """A group of peers addressed together.

    *peers* maps a peer name to its configuration dict.  The transport is
    chosen from the keys present: ``port`` (with ``host``) -> :class:`Peer`,
    ``bleaddr`` -> :class:`BLEPeer`, ``usb_serial`` -> :class:`USBPeer`.
    Usable as a context manager; all peers are closed on exit.
    """

    def __init__(self, peers, alpn_proto=None):
        # Install the shared event loop before any peer creates asyncio objects.
        get_event_loop()
        self.peers = []
        for name, p in peers.items():
            if 'port' in p:
                # NOTE: a missing 'timeout' passes None, i.e. sockets without a
                # timeout, not Peer's default of 5 seconds.
                peer = Peer(name,
                            (p['host'], p['port']),
                            type=p.get("type", "SSL"),
                            ssl_cert=p.get('ssl_cert'),
                            timeout=p.get('timeout'),
                            alpn_proto=alpn_proto)
            elif 'bleaddr' in p:
                peer = BLEPeer(name,
                               p['bleaddr'],
                               p['device_pk'],
                               p['client_sk'],
                               timeout=p.get('timeout'))
            elif 'usb_serial' in p:
                peer = USBPeer(name,
                               p['usb_serial'],
                               p['device_pk'],
                               p['client_sk'],
                               timeout=p.get('timeout'))
            else:
                raise ValueError(f"cannot decide type of peer: {name}")
            self.peers.append(peer)

    def __getitem__(self, idx):
        return self.peers[idx]

    def __iter__(self):
        yield from self.peers

    def __len__(self):
        return len(self.peers)

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        if exception_type is not None:
            # stderr, so diagnostics never mix with a tool's stdout output.
            print("exception caught", exception_type, exception_value, exception_traceback, file=sys.stderr)
        self.close()  # returns None: the exception (if any) propagates

    def connect(self):
        for p in self.peers:
            p.connect()

    def send(self, idx, msg):
        self.peers[idx].send(msg)

    def broadcast(self, msg):
        for p in self.peers:
            p.send(msg)

    async def gather_async(self, expected_msg_len, n=None, proc=None):
        """Read one *expected_msg_len*-byte reply from every peer.

        Args:
            expected_msg_len: plaintext length of each expected reply.
            n: minimum number of usable replies (default: all peers).
            proc: optional callable applied to each usable reply; when it
                returns None the unprocessed reply is kept.

        Returns:
            A list with one entry per peer (in peer order).  The entry is None
            for peers that raised, were not connected, or answered
            b'\\x00\\x04fail'.

        Raises:
            ValueError: if fewer than *n* usable replies were gathered.
        """
        results = await asyncio.gather(
            *(peer.read_async(expected_msg_len) for peer in self.peers), return_exceptions=True
        )
        for i, (peer, res) in enumerate(zip(self.peers, results)):
            if isinstance(res, BaseException):
                print(f"client {peer.name} returned exception: {res}", file=sys.stderr)
                results[i] = None
                continue
            if res is None or res == b'\x00\x04fail':
                results[i] = None
                continue
            if proc:
                processed = proc(res)
                if processed is not None:
                    results[i] = processed

        if n is None:
            n = len(self.peers)
        if sum(r is not None for r in results) < n:
            raise ValueError(f"not enough responses gathered: {results}")
        return results

    def gather(self, *args, **kwargs):
        """Synchronous wrapper around :meth:`gather_async`."""
        return get_event_loop().run_until_complete(self.gather_async(*args, **kwargs))

    def close(self):
        for p in self.peers:
            p.close()