/ tests / channel.py
channel.py
  1  from __future__ import annotations
  2  import threading
  3  import RNS
  4  from RNS.Channel import MessageState, ChannelOutletBase, Channel, MessageBase
  5  import RNS.Buffer
  6  from RNS.vendor import umsgpack
  7  from typing import Callable
  8  import contextlib
  9  import typing
 10  import types
 11  import time
 12  import uuid
 13  import unittest
 14  
 15  
 16  class Packet:
 17      timeout = 1.0
 18  
 19      def __init__(self, raw: bytes):
 20          self.state = MessageState.MSGSTATE_NEW
 21          self.raw = raw
 22          self.packet_id = uuid.uuid4()
 23          self.tries = 0
 24          self.timeout_id = None
 25          self.lock = threading.RLock()
 26          self.instances = 0
 27          self.timeout_callback: Callable[[Packet], None] | None = None
 28          self.delivered_callback: Callable[[Packet], None] | None = None
 29  
 30      def set_timeout(self, callback: Callable[[Packet], None] | None, timeout: float):
 31          with self.lock:
 32              if timeout is not None:
 33                  self.timeout = timeout
 34              self.timeout_callback = callback
 35  
 36  
 37      def send(self):
 38          self.tries += 1
 39          self.state = MessageState.MSGSTATE_SENT
 40  
 41          def elapsed(timeout: float, timeout_id: uuid.uuid4):
 42              with self.lock:
 43                  self.instances += 1
 44              try:
 45                  time.sleep(timeout)
 46                  with self.lock:
 47                      if self.timeout_id == timeout_id:
 48                          self.timeout_id = None
 49                          self.state = MessageState.MSGSTATE_FAILED
 50                          if self.timeout_callback:
 51                              self.timeout_callback(self)
 52              finally:
 53                  with self.lock:
 54                      self.instances -= 1
 55  
 56          self.timeout_id = uuid.uuid4()
 57          threading.Thread(target=elapsed, name="Packet Timeout", args=[self.timeout, self.timeout_id],
 58                           daemon=True).start()
 59  
 60      def clear_timeout(self):
 61          self.timeout_id = None
 62  
 63      def set_delivered_callback(self, callback: Callable[[Packet], None]):
 64          self.delivered_callback = callback
 65          
 66      def delivered(self):
 67          with self.lock:
 68              self.state = MessageState.MSGSTATE_DELIVERED
 69              self.timeout_id = None
 70          if self.delivered_callback:
 71              self.delivered_callback(self)
 72  
 73  
 74  class ChannelOutletTest(ChannelOutletBase):
 75      def get_packet_state(self, packet: Packet) -> MessageState:
 76          return packet.state
 77  
 78      def set_packet_timeout_callback(self, packet: Packet, callback: Callable[[Packet], None] | None,
 79                                      timeout: float | None = None):
 80          packet.set_timeout(callback, timeout)
 81  
 82      def set_packet_delivered_callback(self, packet: Packet, callback: Callable[[Packet], None] | None):
 83          packet.set_delivered_callback(callback)
 84  
 85      def get_packet_id(self, packet: Packet) -> any:
 86          return packet.packet_id
 87  
 88      def __init__(self, mdu: int, rtt: float):
 89          self.link_id = uuid.uuid4()
 90          self.timeout_callbacks = 0
 91          self._mdu = mdu
 92          self._rtt = rtt
 93          self._usable = True
 94          self.packets = []
 95          self.lock = threading.RLock()
 96          self.packet_callback: Callable[[ChannelOutletBase, bytes], None] | None = None
 97  
 98      def send(self, raw: bytes) -> Packet:
 99          with self.lock:
100              packet = Packet(raw)
101              packet.send()
102              self.packets.append(packet)
103              return packet
104  
105      def resend(self, packet: Packet) -> Packet:
106          with self.lock:
107              packet.send()
108              return packet
109  
110      @property
111      def mdu(self):
112          return self._mdu
113  
114      @property
115      def rtt(self):
116          return self._rtt
117  
118      @property
119      def is_usable(self):
120          return self._usable
121  
122      def timed_out(self):
123          self.timeout_callbacks += 1
124  
125      def __str__(self):
126          return str(self.link_id)
127  
128  
class MessageTest(MessageBase):
    """Test message whose ``id`` and ``data`` travel on the wire while ``not_serialized`` does not."""

    MSGTYPE = 0xabcd

    def __init__(self):
        # Two independent random values: only ``id`` is part of the packed form.
        self.id = str(uuid.uuid4())
        self.data = "test"
        self.not_serialized = str(uuid.uuid4())

    def pack(self) -> bytes:
        """Serialize ``(id, data)`` with umsgpack."""
        wire_fields = (self.id, self.data)
        return umsgpack.packb(wire_fields)

    def unpack(self, raw):
        """Restore ``id`` and ``data`` from bytes produced by pack()."""
        wire_fields = umsgpack.unpackb(raw)
        self.id, self.data = wire_fields
142  
143  
class SystemMessage(MessageBase):
    """Payload-free message type used by the system-message registration test."""

    MSGTYPE = 0xf000

    def pack(self) -> bytes:
        """Return an empty payload."""
        return b""

    def unpack(self, raw):
        """Ignore *raw*; there is nothing to restore."""
        return None
152  
153  
class ProtocolHarness(contextlib.AbstractContextManager):
    """Bundle a test outlet with the Channel built on it; usable as a context manager."""

    def __init__(self, rtt: float):
        """Create a 500-byte-MDU test outlet with round-trip time *rtt* and a Channel on top of it."""
        self.outlet = ChannelOutletTest(mdu=500, rtt=rtt)
        self.channel = Channel(self.outlet)
        # NOTE: this assigns the *class* attribute, so it changes the default
        # timeout of every Packet in the process, not just this harness's.
        Packet.timeout = self.channel._get_packet_timeout_time(1)

    def cleanup(self):
        """Shut the channel down."""
        self.channel._shutdown()

    def __exit__(self, __exc_type: typing.Type[BaseException] | None, __exc_value: BaseException | None,
                 __traceback: types.TracebackType | None) -> bool:
        """Clean up on leaving the ``with`` block; returning False lets any exception propagate.

        All three arguments are None when the block exits without an exception.
        """
        self.cleanup()
        return False
168  
169  
170  class TestChannel(unittest.TestCase):
171      def setUp(self) -> None:
172          print("")
173          self.rtt = 0.01
174          self.h = ProtocolHarness(self.rtt)
175  
176      def tearDown(self) -> None:
177          self.h.cleanup()
178  
179      def test_send_one_retry(self):
180          print("Channel test one retry")
181          message = MessageTest()
182  
183          self.assertEqual(0, len(self.h.outlet.packets))
184  
185          envelope = self.h.channel.send(message)
186  
187          self.assertIsNotNone(envelope)
188          self.assertIsNotNone(envelope.raw)
189          self.assertEqual(1, len(self.h.outlet.packets))
190          self.assertIsNotNone(envelope.packet)
191          self.assertTrue(envelope in self.h.channel._tx_ring)
192          self.assertTrue(envelope.tracked)
193  
194          packet = self.h.outlet.packets[0]
195  
196          self.assertEqual(envelope.packet, packet)
197          self.assertEqual(1, envelope.tries)
198          self.assertEqual(1, packet.tries)
199          self.assertEqual(1, packet.instances)
200          self.assertEqual(MessageState.MSGSTATE_SENT, packet.state)
201          self.assertEqual(envelope.raw, packet.raw)
202  
203          time.sleep(self.h.channel._get_packet_timeout_time(1) * 1.1)
204  
205          self.assertEqual(1, len(self.h.outlet.packets))
206          self.assertEqual(2, envelope.tries)
207          self.assertEqual(2, packet.tries)
208          self.assertEqual(1, packet.instances)
209  
210          time.sleep(self.h.channel._get_packet_timeout_time(2) * 1.1)
211  
212          self.assertEqual(1, len(self.h.outlet.packets))
213          self.assertEqual(self.h.outlet.packets[0], packet)
214          self.assertEqual(3, envelope.tries)
215          self.assertEqual(3, packet.tries)
216          self.assertEqual(1, packet.instances)
217          self.assertEqual(MessageState.MSGSTATE_SENT, packet.state)
218  
219          packet.delivered()
220  
221          self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state)
222  
223          time.sleep(self.h.channel._get_packet_timeout_time(3) * 1.1)
224  
225          self.assertEqual(1, len(self.h.outlet.packets))
226          self.assertEqual(3, envelope.tries)
227          self.assertEqual(3, packet.tries)
228          self.assertEqual(0, packet.instances)
229          self.assertFalse(envelope.tracked)
230  
231      def test_send_timeout(self):
232          print("Channel test retry count exceeded")
233          message = MessageTest()
234  
235          self.assertEqual(0, len(self.h.outlet.packets))
236  
237          envelope = self.h.channel.send(message)
238  
239          self.assertIsNotNone(envelope)
240          self.assertIsNotNone(envelope.raw)
241          self.assertEqual(1, len(self.h.outlet.packets))
242          self.assertIsNotNone(envelope.packet)
243          self.assertTrue(envelope in self.h.channel._tx_ring)
244          self.assertTrue(envelope.tracked)
245  
246          packet = self.h.outlet.packets[0]
247  
248          self.assertEqual(envelope.packet, packet)
249          self.assertEqual(1, envelope.tries)
250          self.assertEqual(1, packet.tries)
251          self.assertEqual(1, packet.instances)
252          self.assertEqual(MessageState.MSGSTATE_SENT, packet.state)
253          self.assertEqual(envelope.raw, packet.raw)
254  
255          time.sleep(self.h.channel._get_packet_timeout_time(1))
256          time.sleep(self.h.channel._get_packet_timeout_time(2))
257          time.sleep(self.h.channel._get_packet_timeout_time(3))
258          time.sleep(self.h.channel._get_packet_timeout_time(4))
259          time.sleep(self.h.channel._get_packet_timeout_time(5) * 1.1)
260  
261          self.assertEqual(1, len(self.h.outlet.packets))
262          self.assertEqual(5, envelope.tries)
263          self.assertEqual(5, packet.tries)
264          self.assertEqual(0, packet.instances)
265          self.assertEqual(MessageState.MSGSTATE_FAILED, packet.state)
266          self.assertFalse(envelope.tracked)
267  
268      def test_multiple_handler(self):
269          print("Channel test multiple handler short circuit")
270  
271          handler1_called = 0
272          handler1_return = True
273          handler2_called = 0
274  
275          def handler1(msg: MessageBase):
276              nonlocal handler1_called, handler1_return
277              self.assertIsInstance(msg, MessageTest)
278              handler1_called += 1
279              return handler1_return
280  
281          def handler2(msg: MessageBase):
282              nonlocal handler2_called
283              self.assertIsInstance(msg, MessageTest)
284              handler2_called += 1
285  
286          message = MessageTest()
287          self.h.channel.register_message_type(MessageTest)
288          self.h.channel.add_message_handler(handler1)
289          self.h.channel.add_message_handler(handler2)
290          envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=0)
291          raw = envelope.pack()
292          self.h.channel._receive(raw)
293  
294          time.sleep(0.5)
295  
296          self.assertEqual(1, handler1_called)
297          self.assertEqual(0, handler2_called)
298  
299          handler1_return = False
300          envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=1)
301          raw = envelope.pack()
302          self.h.channel._receive(raw)
303  
304          time.sleep(0.5)
305  
306          self.assertEqual(2, handler1_called)
307          self.assertEqual(1, handler2_called)
308  
309      def test_system_message_check(self):
310          print("Channel test register system message")
311          with self.assertRaises(RNS.Channel.ChannelException):
312              self.h.channel.register_message_type(SystemMessage)
313          self.h.channel._register_message_type(SystemMessage, is_system_type=True)
314  
315  
316      def eat_own_dog_food(self, message: MessageBase, checker: typing.Callable[[MessageBase], None]):
317          decoded: [MessageBase] = []
318  
319          def handle_message(message: MessageBase):
320              decoded.append(message)
321  
322          self.h.channel.register_message_type(message.__class__)
323          self.h.channel.add_message_handler(handle_message)
324          self.assertEqual(len(self.h.outlet.packets), 0)
325  
326          envelope = self.h.channel.send(message)
327          time.sleep(self.h.channel._get_packet_timeout_time(1) * 0.5)
328  
329          self.assertIsNotNone(envelope)
330          self.assertIsNotNone(envelope.raw)
331          self.assertEqual(1, len(self.h.outlet.packets))
332          self.assertIsNotNone(envelope.packet)
333          self.assertTrue(envelope in self.h.channel._tx_ring)
334          self.assertTrue(envelope.tracked)
335  
336          packet = self.h.outlet.packets[0]
337  
338          self.assertEqual(envelope.packet, packet)
339          self.assertEqual(1, envelope.tries)
340          self.assertEqual(1, packet.tries)
341          self.assertEqual(1, packet.instances)
342          self.assertEqual(MessageState.MSGSTATE_SENT, packet.state)
343          self.assertEqual(envelope.raw, packet.raw)
344  
345          packet.delivered()
346  
347          self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state)
348  
349          time.sleep(self.h.channel._get_packet_timeout_time(1))
350  
351          self.assertEqual(1, len(self.h.outlet.packets))
352          self.assertEqual(1, envelope.tries)
353          self.assertEqual(1, packet.tries)
354          self.assertEqual(0, packet.instances)
355          self.assertFalse(envelope.tracked)
356  
357          self.assertEqual(len(self.h.outlet.packets), 1)
358          self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state)
359          self.assertFalse(envelope.tracked)
360          self.assertEqual(0, len(decoded))
361  
362          self.h.channel._receive(packet.raw)
363  
364          time.sleep(0.5)
365  
366          self.assertEqual(1, len(decoded))
367  
368          rx_message = decoded[0]
369  
370          self.assertIsNotNone(rx_message)
371          self.assertIsInstance(rx_message, message.__class__)
372          checker(rx_message)
373  
374      def test_send_receive_message_test(self):
375          print("Channel test send and receive message")
376          message = MessageTest()
377  
378          def check(rx_message: MessageBase):
379              self.assertIsInstance(rx_message, message.__class__)
380              self.assertEqual(message.id, rx_message.id)
381              self.assertEqual(message.data, rx_message.data)
382              self.assertNotEqual(message.not_serialized, rx_message.not_serialized)
383  
384          self.eat_own_dog_food(message, check)
385  
386      def test_buffer_small_bidirectional(self):
387          data = "Hello\n"
388          with RNS.Buffer.create_bidirectional_buffer(0, 0, self.h.channel) as buffer:
389              count = buffer.write(data.encode("utf-8"))
390              buffer.flush()
391  
392              self.assertEqual(len(data), count)
393              self.assertEqual(1, len(self.h.outlet.packets))
394  
395              packet = self.h.outlet.packets[0]
396              self.h.channel._receive(packet.raw)
397              time.sleep(0.2)
398              result = buffer.readline()
399  
400              self.assertIsNotNone(result)
401              self.assertEqual(len(result), len(data))
402  
403              decoded = result.decode("utf-8")
404  
405              self.assertEqual(data, decoded)
406  
407      def test_buffer_big(self):
408          writer = RNS.Buffer.create_writer(15, self.h.channel)
409          reader = RNS.Buffer.create_reader(15, self.h.channel)
410          data = "01234556789"*1024*5  # 50 KB
411          count = 0
412          write_finished = False
413  
414          def write_thread():
415              nonlocal count, write_finished
416              count = writer.write(data.encode("utf-8"))
417              writer.flush()
418              writer.close() # TODO: Workaround for https://github.com/python/cpython/issues/138720
419              write_finished = True
420          threading.Thread(target=write_thread, name="Write Thread", daemon=True).start()
421  
422          while not write_finished or next(filter(lambda x: x.state != MessageState.MSGSTATE_DELIVERED,
423                                                  self.h.outlet.packets), None) is not None:
424              with self.h.outlet.lock:
425                  for packet in self.h.outlet.packets:
426                      if packet.state != MessageState.MSGSTATE_DELIVERED:
427                          self.h.channel._receive(packet.raw)
428                          packet.delivered()
429              time.sleep(0.0001)
430  
431          self.assertEqual(len(data), count)
432  
433          read_finished = False
434          result = bytes()
435  
436          def read_thread():
437              nonlocal read_finished, result
438              result = reader.read()
439              read_finished = True
440          threading.Thread(target=read_thread, name="Read Thread", daemon=True).start()
441  
442          timeout_at = time.time() + 7
443          while not read_finished and time.time() < timeout_at:
444              time.sleep(0.001)
445  
446          self.assertTrue(read_finished)
447          self.assertEqual(len(data), len(result))
448  
449          decoded = result.decode("utf-8")
450  
451          self.assertSequenceEqual(data, decoded)
452  
453      def test_buffer_small_with_callback(self):
454          callbacks = 0
455          last_cb_value = None
456  
457          def callback(ready: int):
458              nonlocal callbacks, last_cb_value
459              callbacks += 1
460              last_cb_value = ready
461  
462          data = "Hello\n"
463          with RNS.RawChannelWriter(0, self.h.channel) as writer, RNS.RawChannelReader(0, self.h.channel) as reader:
464              reader.add_ready_callback(callback)
465              count = writer.write(data.encode("utf-8"))
466              writer.flush()
467  
468              self.assertEqual(len(data), count)
469              self.assertEqual(1, len(self.h.outlet.packets))
470  
471              packet = self.h.outlet.packets[0]
472              self.h.channel._receive(packet.raw)
473              packet.delivered()
474  
475              self.assertEqual(1, callbacks)
476              self.assertEqual(len(data), last_cb_value)
477  
478              result = reader.readline()
479  
480              self.assertIsNotNone(result)
481              self.assertEqual(len(result), len(data))
482  
483              decoded = result.decode("utf-8")
484  
485              self.assertEqual(data, decoded)
486              self.assertEqual(1, len(self.h.outlet.packets))
487  
488              result = reader.read(1)
489  
490              self.assertIsNone(result)
491              self.assertTrue(self.h.channel.is_ready_to_send())
492  
493              writer.close()
494  
495              self.assertEqual(2, len(self.h.outlet.packets))
496  
497              packet = self.h.outlet.packets[1]
498              self.h.channel._receive(packet.raw)
499              packet.delivered()
500  
501              result = reader.read(1)
502  
503              self.assertIsNotNone(result)
504              self.assertTrue(len(result) == 0)
505  
506  
if __name__ == "__main__":
    # Executed as a script: run this module's tests with per-test reporting.
    unittest.main(verbosity=2)