/ python / pyoprf / multiplexer.py
multiplexer.py
  1  #!/usr/bin/env python
  2  
  3  import ssl, socket, select, struct, asyncio, serial, sys, time
  4  from pyoprf import noisexk
  5  from itertools import zip_longest
  6  from serial_asyncio import create_serial_connection
  7  try:
  8      from ble_serial.bluetooth.ble_client import BLE_client
  9  except ImportError:
 10      BLE_client = None
 11  try:
 12      import pyudev
 13  except ImportError:
 14      pyudev = None
 15  
 16  def split_by_n(iterable, n):
 17      return list(zip_longest(*[iter(iterable)]*n))
 18  
 19  def get_event_loop():
 20      try:
 21          return asyncio.get_running_loop()
 22      except RuntimeError:
 23          loop = asyncio.new_event_loop()
 24          asyncio.set_event_loop(loop)
 25          return loop
 26  
 27  class Peer:
 28      def __init__(self, name, addr, type = "SSL", ssl_cert=None, timeout=5, alpn_proto=None):
 29          self.name = name
 30          self.type = type    # currently only TCP or SSL over TCP, but
 31                              # could be others like dedicated NOISE_XK,
 32                              # or hybrid mceliece+x25519 over USB or
 33                              # even UART
 34          self.address = addr # Currently only TCP host:port as a tuple
 35          self.ssl_cert = ssl_cert
 36          self.timeout = timeout
 37          self.alpn_proto = alpn_proto or ["oprf/1"]
 38          self.state = "new"
 39          self.fd = None
 40  
 41      def connect(self):
 42          if self.state == "connected":
 43              raise ValueError(f"{self.name} is already connected")
 44  
 45          if self.type not in {"SSL", "TCP"}:
 46              raise ValueError(f"Unsupported peer type: {self.type}")
 47  
 48          if self.type == "SSL":
 49             ctx = ssl.create_default_context()
 50             ctx.minimum_version = ssl.TLSVersion.TLSv1_2
 51             ctx.set_alpn_protocols(self.alpn_proto)
 52             if(self.ssl_cert):
 53                 ctx.load_verify_locations(self.ssl_cert) # only for dev, production system should use proper certs!
 54                 ctx.check_hostname=False                 # only for dev, production system should use proper certs!
 55                 ctx.verify_mode=ssl.CERT_NONE            # only for dev, production system should use proper certs!
 56             else:
 57                 ctx.load_default_certs()
 58                 ctx.verify_mode = ssl.CERT_REQUIRED
 59                 ctx.check_hostname = True
 60  
 61          s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 62          s.settimeout(self.timeout)
 63          if self.type == "SSL":
 64              self.fd = ctx.wrap_socket(s, server_hostname=self.address[0])
 65          try: self.fd.connect(self.address)
 66          except: return
 67          self.state="connected"
 68  
 69      def connected(self):
 70          return self.state == "connected"
 71  
 72      async def read_async(self,size):
 73          if not self.connected():
 74              return None
 75              #raise ValueError(f"{self.name} cannot read, is not connected")
 76  
 77          res = []
 78          read = 0
 79          while read<size and (len(res)==0 or len(res[-1])!=0):
 80            res.append(self.fd.recv(size-read))
 81            read+=len(res[-1])
 82  
 83          if len(res[-1])==0 and read<size:
 84              self.state = 'disconnected'
 85              #raise ValueError(f"short read for {self.name}, only {len(b''.join(res))} instead of expected {size} bytes")
 86          return b''.join(res)
 87  
 88      def read(self, *args, **kwargs):
 89          return get_event_loop().run_until_complete(self.read_async(*args, **kwargs))
 90  
 91      def send(self, msg):
 92          if not self.connected():
 93              return
 94              #raise ValueError(f"{self.name} cannot write, is not connected")
 95          self.fd.sendall(msg)
 96  
 97      def close(self):
 98          if self.state == "closed": return
 99          if not self.connected():
100              return
101              #raise ValueError(f"{self.name} cannot close, is not connected")
102          if self.fd and self.fd.fileno() != -1:
103              try: self.fd.shutdown(socket.SHUT_RDWR)
104              except OSError: pass
105              self.fd.close()
106              self.state = "closed"
107          else:
108              # closed by other end.
109              self.state = "closed"
110  
class BLEPeer:
    """A peer reached over BLE, with traffic protected by a Noise XK session.

    Incoming notifications are queued by receive_callback(); read_raw() hands
    out queued fragments once at least the requested number of bytes arrived.
    """

    def __init__(self, name, addr, server_pk, client_sk, device="hci0", timeout=5):
        if BLE_client is None:
            raise ImportError("BLEPeer requires the optional 'ble_serial' package")
        self.name = name
        self.address = addr # the MAC address of the device
        self.server_pk = server_pk
        self.client_sk = client_sk
        self.timeout = timeout
        self.state = "new"
        self.ble = BLE_client(device, None)
        self.rx_buffer = []  # received fragments, oldest first
        self.rx_len = 0      # total bytes currently in rx_buffer
        # Bytes the pending read_raw() waits for (0 = no reader).  Must exist
        # before any notification arrives, since receive_callback reads it.
        self.rx_pected = 0
        self.rx_available = asyncio.Event()

    def receive_callback(self, value: bytes):
        """Queue a received BLE fragment and wake the reader once enough bytes arrived."""
        self.rx_buffer.append(value)
        self.rx_len += len(value)
        if self.rx_pected > 0 and self.rx_len >= self.rx_pected:
            self.rx_available.set()

    async def read_raw(self, size):
        """Wait until at least *size* bytes are queued and return them.

        Whole fragments are dequeued, so the result can be longer than *size*
        when a fragment straddles the boundary.  NOTE(review): such surplus
        bytes go to this caller instead of being kept for the next read.
        """
        # presumably waits out an overlapping read_raw() that has not cleared
        # the event yet
        while self.rx_available.is_set():
            await asyncio.sleep(0.001)
        self.rx_pected = size
        if self.rx_len < self.rx_pected:
            await self.rx_available.wait()
        rsize = 0
        ret = []
        while rsize < self.rx_pected:
            ret.append(self.rx_buffer.pop(0))
            rsize += len(ret[-1])
        self.rx_pected = 0
        self.rx_len -= rsize
        self.rx_available.clear()
        return b''.join(ret)

    async def _send(self, msg, mtu=20):
        """Write *msg* to the BLE write characteristic in chunks of at most *mtu* bytes."""
        for offset in range(0, len(msg), mtu):
            frag = bytes(msg[offset:offset + mtu])
            await self.ble.dev.write_gatt_char(self.ble.write_char, frag, self.ble.write_response_required)

    async def _connect(self):
        """Connect over BLE and run the Noise XK handshake as initiator.

        Raises ValueError if already connected.
        """
        if self.state == "connected":
            raise ValueError(f"{self.name} is already connected")

        self.ble.set_receiver(self.receive_callback)

        await self.ble.connect(self.address, "public", None, 10.0)
        await self.ble.setup_chars(None, None, "rw", False)

        self.session, msg = noisexk.initiator_session(self.client_sk, self.server_pk, dst=b"klutshnik ble tle")
        await self._send(msg)
        resp = await self.read_raw(48)  # handshake response from the device
        noisexk.initiator_session_complete(self.session, resp)
        # an encrypted empty message is sent right after completing the handshake
        ct = noisexk.send_msg(self.session, "")
        await self._send(ct)

        self.state = "connected"

    async def _disconnect(self):
        """Drop the BLE connection and mark the peer as disconnected."""
        await self.ble.disconnect()
        self.state = "disconnected"

    def connect(self):
        """Synchronously connect and perform the handshake."""
        get_event_loop().run_until_complete(self._connect())
        while not self.connected():
            time.sleep(0.001)

    def connected(self):
        """Return True while the peer is in the "connected" state."""
        return self.state == "connected"

    def read(self, size):
        """Synchronously read and decrypt one message with a *size*-byte plaintext.

        Reads size + 16 ciphertext bytes (presumably the AEAD tag overhead).
        Returns None if not connected.
        """
        if not self.connected():
            return None
        ct = get_event_loop().run_until_complete(self.read_raw(size + 16))
        return noisexk.read_msg(self.session, ct)

    async def read_async(self, size):
        """Async variant of read(); returns None if not connected."""
        if not self.connected():
            return None
        resp = await self.read_raw(size + 16)
        return noisexk.read_msg(self.session, resp)

    def send(self, msg):
        """Encrypt *msg* and send it with a 2-byte big-endian length prefix.

        Silently does nothing if not connected.
        """
        if not self.connected():
            return
        ct = noisexk.send_msg(self.session, msg)
        header = struct.pack(">H", len(ct))
        get_event_loop().run_until_complete(self._send(header + ct))

    def close(self):
        """Disconnect if connected; otherwise do nothing."""
        if not self.connected():
            return
        get_event_loop().run_until_complete(self._disconnect())
224  
class Serial(asyncio.Protocol):
    """asyncio protocol that buffers bytes received from a serial port.

    data_received() queues incoming chunks; read_raw() dequeues whole chunks
    once enough bytes are buffered or the port has been closed.
    """

    def __init__(self, *args, **kwargs):
        self.rx_buffer = []  # received chunks, oldest first
        self.rx_len = 0      # total bytes currently in rx_buffer
        # Bytes the pending read_raw() waits for (0 = no reader).  Must exist
        # before the first data_received() call, which reads it.
        self.rx_pected = 0
        self.rx_available = asyncio.Event()
        # Stays True once the port closed, so later reads do not spin forever
        # after the first read has cleared rx_available.
        self.port_lost = False
        super().__init__(*args, **kwargs)

    def connection_made(self, transport):
        """Keep the transport and assert DTR on the serial line."""
        self.transport = transport
        transport.serial.dtr = True

    def data_received(self, data):
        """Queue *data* and signal the reader once enough bytes are buffered."""
        self.rx_buffer.append(data)
        self.rx_len += len(data)
        if self.rx_pected > 0 and self.rx_len >= self.rx_pected:
            self.rx_available.set()

    def connection_lost(self, exc):
        """Report the closed port and release any reader waiting in read_raw()."""
        print('port closed', file=sys.stderr)
        self.port_lost = True
        self.rx_available.set()

    async def read_raw(self, size):
        """Return buffered data once *size* bytes arrived or the port closed.

        Whole chunks are dequeued, so the result can be longer than *size*, or
        shorter (possibly empty) if the port was closed first.
        """
        self.rx_pected = size
        while (self.rx_len < self.rx_pected
               and not (self.rx_available.is_set() or self.port_lost)):
            await asyncio.sleep(0.001)
        rsize = 0
        ret = []
        while rsize < self.rx_pected and self.rx_buffer:
            ret.append(self.rx_buffer.pop(0))
            rsize += len(ret[-1])
        # len(ret) > size can happen when "OK" is expected but b'\x00\x04fail'
        # is sent (18 vs 22 bytes).
        self.rx_pected = 0
        self.rx_len -= rsize
        self.rx_available.clear()
        return b''.join(ret)
278  
class USBPeer:
    """A peer reached over a USB CDC serial port, protected by Noise XK."""

    def __init__(self, name, serno, server_pk, client_sk, timeout=5):
        self.name = name
        self.serno = serno # the serial number of the usb device
        self.server_pk = server_pk
        self.client_sk = client_sk
        self.timeout = timeout
        self.state = "new"

    def __getattr__(self, name):
        # Only invoked for attributes missing from the instance: synthesize a
        # human-readable "address" and fail normally for everything else.
        # __dict__ is used so this cannot recurse before __init__/connect ran.
        if name == "address":
            return f"usb-cdc device #{self.__dict__.get('serno')} at {self.__dict__.get('path')}"
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")

    def find_usb_port(self):
        """Return the device node of the second tty matching self.serno, or None.

        Raises ImportError when the optional pyudev package is unavailable.
        """
        if pyudev is None:
            raise ImportError("USBPeer requires the optional 'pyudev' package")
        context = pyudev.Context()
        seen = 0
        for device in context.list_devices(subsystem='tty'):
            if device.get('ID_SERIAL_SHORT') != self.serno:
                continue
            # NOTE(review): the second matching tty is used; presumably the
            # device exposes two CDC ports and the second carries the protocol.
            if seen == 1:
                return device.device_node
            seen += 1
        return None

    async def _connect(self):
        """Run the Noise XK handshake as initiator over the open serial port.

        Raises ValueError if already connected.
        """
        if self.state == "connected":
            raise ValueError(f"{self.name} is already connected")

        self.session, msg = noisexk.initiator_session(self.client_sk, self.server_pk, dst=b"klutshnik ble tle")
        self.transport.serial.write(msg)
        resp = await self.protocol.read_raw(48)  # handshake response from the device
        noisexk.initiator_session_complete(self.session, resp)
        # an encrypted empty message is sent right after completing the handshake
        ct = noisexk.send_msg(self.session, "")
        self.transport.serial.write(ct)

        self.state = "connected"

    def connect(self):
        """Locate the tty, open it at 115200 baud and perform the handshake."""
        self.path = self.find_usb_port()
        loop = get_event_loop()
        coro = create_serial_connection(loop, Serial, self.path, baudrate=115200)
        self.transport, self.protocol = loop.run_until_complete(coro)
        loop.run_until_complete(self._connect())
        while not self.connected():
            time.sleep(0.001)

    def connected(self):
        """Return True while the peer is in the "connected" state."""
        return self.state == "connected"

    async def read_async(self, size):
        """Read and decrypt one message whose plaintext is *size* bytes.

        Expects size + 16 ciphertext bytes (presumably the AEAD tag overhead).
        Returns None if not connected.  Raises ValueError (and marks the peer
        'disconnected') when fewer bytes arrive.
        """
        if not self.connected():
            return None
        expected = size + 16
        ct = await self.protocol.read_raw(expected)
        if len(ct) < expected:
            self.state = 'disconnected'
            raise ValueError(f"short read for {self.name}, only {len(ct)} instead of expected {expected} bytes")
        return noisexk.read_msg(self.session, ct)

    def read(self, *args, **kwargs):
        """Synchronous wrapper around read_async()."""
        return get_event_loop().run_until_complete(self.read_async(*args, **kwargs))

    def send(self, msg):
        """Encrypt *msg* and write it with a 2-byte big-endian length prefix.

        Silently does nothing if not connected.
        """
        if not self.connected():
            return
        ct = noisexk.send_msg(self.session, msg)
        header = struct.pack(">H", len(ct))
        self.transport.serial.write(header + ct)

    def close(self):
        """Drop DTR and close the serial transport; does nothing if not connected."""
        if not self.connected():
            return
        if self.transport.serial is not None:
            self.transport.serial.dtr = False
        self.transport.close()
        self.state = "closed"
364  
class Multiplexer:
    """Fan messages out to, and gather replies from, a set of peers.

    *peers* maps a peer name to a config dict.  The transport is chosen by the
    keys present: 'port' (TCP/TLS, also needs 'host'), 'bleaddr' (BLE) or
    'usb_serial' (USB CDC); BLE and USB entries also need 'device_pk' and
    'client_sk'.  Raises ValueError for an entry matching none of these.
    """

    def __init__(self, peers, alpn_proto=None):
        self._ensure_event_loop()
        self.peers = [self._make_peer(name, cfg, alpn_proto) for name, cfg in peers.items()]

    @staticmethod
    def _ensure_event_loop():
        """Install a fresh event loop for this thread unless a usable one exists."""
        import warnings

        try:
            asyncio.get_running_loop()
            return
        except RuntimeError:
            pass
        with warnings.catch_warnings():
            # Depending on the Python version this may warn or raise when no
            # loop is installed; avoids relying on private policy attributes.
            warnings.simplefilter("ignore", DeprecationWarning)
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = None
        if loop is None or loop.is_closed():
            asyncio.set_event_loop(asyncio.new_event_loop())

    @staticmethod
    def _make_peer(name, cfg, alpn_proto):
        """Instantiate the peer class selected by the keys of *cfg*."""
        # NOTE(review): cfg.get('timeout') passes None when unset, overriding
        # the peer classes' default of 5 (blocking sockets for TCP/TLS) - confirm.
        if 'port' in cfg:
            return Peer(name,
                        (cfg['host'], cfg['port']),
                        type=cfg.get("type", "SSL"),
                        ssl_cert=cfg.get('ssl_cert'),
                        timeout=cfg.get('timeout'),
                        alpn_proto=alpn_proto)
        if 'bleaddr' in cfg:
            return BLEPeer(name, cfg['bleaddr'], cfg['device_pk'], cfg['client_sk'],
                           timeout=cfg.get('timeout'))
        if 'usb_serial' in cfg:
            return USBPeer(name, cfg['usb_serial'], cfg['device_pk'], cfg['client_sk'],
                           timeout=cfg.get('timeout'))
        raise ValueError(f"cannot decide type of peer: {name}")

    def __getitem__(self, idx):
        return self.peers[idx]

    def __iter__(self):
        return iter(self.peers)

    def __len__(self):
        return len(self.peers)

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        """Report an escaping exception on stderr and close all peers."""
        if exception_type is not None:
            print("exception caught", exception_type, exception_value, exception_traceback, file=sys.stderr)
        self.close()

    def connect(self):
        """Connect every peer in order."""
        for p in self.peers:
            p.connect()

    def send(self, idx, msg):
        """Send *msg* to the peer at position *idx*."""
        self.peers[idx].send(msg)

    def broadcast(self, msg):
        """Send the same *msg* to every peer."""
        for p in self.peers:
            p.send(msg)

    async def gather_async(self, expected_msg_len, n=None, proc=None):
        """Read one *expected_msg_len*-byte reply from every peer concurrently.

        Exceptions raised by a peer and the 'fail' reply are mapped to None.
        Other non-None replies are passed through *proc* when given; if *proc*
        returns None the unprocessed reply is kept.
        Raises ValueError when fewer than *n* (default: all) replies are non-None.
        """
        results = await asyncio.gather(
            *(peer.read_async(expected_msg_len) for peer in self.peers),
            return_exceptions=True,
        )
        replies = []
        for peer, res in zip(self.peers, results):
            if isinstance(res, BaseException):
                print(f"client {peer.name} returned exception: {res}", file=sys.stderr)
                res = None
            elif res == b'\x00\x04fail':
                res = None
            elif res is not None and proc:
                processed = proc(res)
                if processed is not None:
                    res = processed
            replies.append(res)

        if n is None:
            n = len(self.peers)
        if sum(1 for r in replies if r is not None) < n:
            raise ValueError(f"not enough responses gathered: {replies}")
        return replies

    def gather(self, *args, **kwargs):
        """Synchronous wrapper around gather_async()."""
        return get_event_loop().run_until_complete(self.gather_async(*args, **kwargs))

    def close(self):
        """Close every peer."""
        for p in self.peers:
            p.close()