channel.py
1 from __future__ import annotations 2 import threading 3 import RNS 4 from RNS.Channel import MessageState, ChannelOutletBase, Channel, MessageBase 5 import RNS.Buffer 6 from RNS.vendor import umsgpack 7 from typing import Callable 8 import contextlib 9 import typing 10 import types 11 import time 12 import uuid 13 import unittest 14 15 16 class Packet: 17 timeout = 1.0 18 19 def __init__(self, raw: bytes): 20 self.state = MessageState.MSGSTATE_NEW 21 self.raw = raw 22 self.packet_id = uuid.uuid4() 23 self.tries = 0 24 self.timeout_id = None 25 self.lock = threading.RLock() 26 self.instances = 0 27 self.timeout_callback: Callable[[Packet], None] | None = None 28 self.delivered_callback: Callable[[Packet], None] | None = None 29 30 def set_timeout(self, callback: Callable[[Packet], None] | None, timeout: float): 31 with self.lock: 32 if timeout is not None: 33 self.timeout = timeout 34 self.timeout_callback = callback 35 36 37 def send(self): 38 self.tries += 1 39 self.state = MessageState.MSGSTATE_SENT 40 41 def elapsed(timeout: float, timeout_id: uuid.uuid4): 42 with self.lock: 43 self.instances += 1 44 try: 45 time.sleep(timeout) 46 with self.lock: 47 if self.timeout_id == timeout_id: 48 self.timeout_id = None 49 self.state = MessageState.MSGSTATE_FAILED 50 if self.timeout_callback: 51 self.timeout_callback(self) 52 finally: 53 with self.lock: 54 self.instances -= 1 55 56 self.timeout_id = uuid.uuid4() 57 threading.Thread(target=elapsed, name="Packet Timeout", args=[self.timeout, self.timeout_id], 58 daemon=True).start() 59 60 def clear_timeout(self): 61 self.timeout_id = None 62 63 def set_delivered_callback(self, callback: Callable[[Packet], None]): 64 self.delivered_callback = callback 65 66 def delivered(self): 67 with self.lock: 68 self.state = MessageState.MSGSTATE_DELIVERED 69 self.timeout_id = None 70 if self.delivered_callback: 71 self.delivered_callback(self) 72 73 74 class ChannelOutletTest(ChannelOutletBase): 75 def get_packet_state(self, packet: Packet) -> MessageState: 76 return packet.state 77 78 def set_packet_timeout_callback(self, packet: Packet, callback: Callable[[Packet], None] | None, 79 timeout: float | None = None): 80 packet.set_timeout(callback, timeout) 81 82 def set_packet_delivered_callback(self, packet: Packet, callback: Callable[[Packet], None] | None): 83 packet.set_delivered_callback(callback) 84 85 def get_packet_id(self, packet: Packet) -> any: 86 return packet.packet_id 87 88 def __init__(self, mdu: int, rtt: float): 89 self.link_id = uuid.uuid4() 90 self.timeout_callbacks = 0 91 self._mdu = mdu 92 self._rtt = rtt 93 self._usable = True 94 self.packets = [] 95 self.lock = threading.RLock() 96 self.packet_callback: Callable[[ChannelOutletBase, bytes], None] | None = None 97 98 def send(self, raw: bytes) -> Packet: 99 with self.lock: 100 packet = Packet(raw) 101 packet.send() 102 self.packets.append(packet) 103 return packet 104 105 def resend(self, packet: Packet) -> Packet: 106 with self.lock: 107 packet.send() 108 return packet 109 110 @property 111 def mdu(self): 112 return self._mdu 113 114 @property 115 def rtt(self): 116 return self._rtt 117 118 @property 119 def is_usable(self): 120 return self._usable 121 122 def timed_out(self): 123 self.timeout_callbacks += 1 124 125 def __str__(self): 126 return str(self.link_id) 127 128 129 class MessageTest(MessageBase): 130 MSGTYPE = 0xabcd 131 132 def __init__(self): 133 self.id = str(uuid.uuid4()) 134 self.data = "test" 135 self.not_serialized = str(uuid.uuid4()) 136 137 def pack(self) -> bytes: 138 return umsgpack.packb((self.id, self.data)) 139 140 def unpack(self, raw): 141 self.id, self.data = umsgpack.unpackb(raw) 142 143 144 class SystemMessage(MessageBase): 145 MSGTYPE = 0xf000 146 147 def pack(self) -> bytes: 148 return bytes() 149 150 def unpack(self, raw): 151 pass 152 153 154 class ProtocolHarness(contextlib.AbstractContextManager): 155 def __init__(self, rtt: float): 156 self.outlet = ChannelOutletTest(mdu=500, rtt=rtt) 157 self.channel = Channel(self.outlet) 158 Packet.timeout = self.channel._get_packet_timeout_time(1) 159 160 def cleanup(self): 161 self.channel._shutdown() 162 163 def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException, 164 __traceback: types.TracebackType) -> bool: 165 # self._log.debug(f"__exit__({__exc_type}, {__exc_value}, {__traceback})") 166 self.cleanup() 167 return False 168 169 170 class TestChannel(unittest.TestCase): 171 def setUp(self) -> None: 172 print("") 173 self.rtt = 0.01 174 self.h = ProtocolHarness(self.rtt) 175 176 def tearDown(self) -> None: 177 self.h.cleanup() 178 179 def test_send_one_retry(self): 180 print("Channel test one retry") 181 message = MessageTest() 182 183 self.assertEqual(0, len(self.h.outlet.packets)) 184 185 envelope = self.h.channel.send(message) 186 187 self.assertIsNotNone(envelope) 188 self.assertIsNotNone(envelope.raw) 189 self.assertEqual(1, len(self.h.outlet.packets)) 190 self.assertIsNotNone(envelope.packet) 191 self.assertTrue(envelope in self.h.channel._tx_ring) 192 self.assertTrue(envelope.tracked) 193 194 packet = self.h.outlet.packets[0] 195 196 self.assertEqual(envelope.packet, packet) 197 self.assertEqual(1, envelope.tries) 198 self.assertEqual(1, packet.tries) 199 self.assertEqual(1, packet.instances) 200 self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) 201 self.assertEqual(envelope.raw, packet.raw) 202 203 time.sleep(self.h.channel._get_packet_timeout_time(1) * 1.1) 204 205 self.assertEqual(1, len(self.h.outlet.packets)) 206 self.assertEqual(2, envelope.tries) 207 self.assertEqual(2, packet.tries) 208 self.assertEqual(1, packet.instances) 209 210 time.sleep(self.h.channel._get_packet_timeout_time(2) * 1.1) 211 212 self.assertEqual(1, len(self.h.outlet.packets)) 213 self.assertEqual(self.h.outlet.packets[0], packet) 214 self.assertEqual(3, envelope.tries) 215 self.assertEqual(3, packet.tries) 216 self.assertEqual(1, packet.instances) 217 self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) 218 219 packet.delivered() 220 221 self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state) 222 223 time.sleep(self.h.channel._get_packet_timeout_time(3) * 1.1) 224 225 self.assertEqual(1, len(self.h.outlet.packets)) 226 self.assertEqual(3, envelope.tries) 227 self.assertEqual(3, packet.tries) 228 self.assertEqual(0, packet.instances) 229 self.assertFalse(envelope.tracked) 230 231 def test_send_timeout(self): 232 print("Channel test retry count exceeded") 233 message = MessageTest() 234 235 self.assertEqual(0, len(self.h.outlet.packets)) 236 237 envelope = self.h.channel.send(message) 238 239 self.assertIsNotNone(envelope) 240 self.assertIsNotNone(envelope.raw) 241 self.assertEqual(1, len(self.h.outlet.packets)) 242 self.assertIsNotNone(envelope.packet) 243 self.assertTrue(envelope in self.h.channel._tx_ring) 244 self.assertTrue(envelope.tracked) 245 246 packet = self.h.outlet.packets[0] 247 248 self.assertEqual(envelope.packet, packet) 249 self.assertEqual(1, envelope.tries) 250 self.assertEqual(1, packet.tries) 251 self.assertEqual(1, packet.instances) 252 self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) 253 self.assertEqual(envelope.raw, packet.raw) 254 255 time.sleep(self.h.channel._get_packet_timeout_time(1)) 256 time.sleep(self.h.channel._get_packet_timeout_time(2)) 257 time.sleep(self.h.channel._get_packet_timeout_time(3)) 258 time.sleep(self.h.channel._get_packet_timeout_time(4)) 259 time.sleep(self.h.channel._get_packet_timeout_time(5) * 1.1) 260 261 self.assertEqual(1, len(self.h.outlet.packets)) 262 self.assertEqual(5, envelope.tries) 263 self.assertEqual(5, packet.tries) 264 self.assertEqual(0, packet.instances) 265 self.assertEqual(MessageState.MSGSTATE_FAILED, packet.state) 266 self.assertFalse(envelope.tracked) 267 268 def test_multiple_handler(self): 269 print("Channel test multiple handler short circuit") 270 271 handler1_called = 0 272 handler1_return = True 273 handler2_called = 0 274 275 def handler1(msg: MessageBase): 276 nonlocal handler1_called, handler1_return 277 self.assertIsInstance(msg, MessageTest) 278 handler1_called += 1 279 return handler1_return 280 281 def handler2(msg: MessageBase): 282 nonlocal handler2_called 283 self.assertIsInstance(msg, MessageTest) 284 handler2_called += 1 285 286 message = MessageTest() 287 self.h.channel.register_message_type(MessageTest) 288 self.h.channel.add_message_handler(handler1) 289 self.h.channel.add_message_handler(handler2) 290 envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=0) 291 raw = envelope.pack() 292 self.h.channel._receive(raw) 293 294 time.sleep(0.5) 295 296 self.assertEqual(1, handler1_called) 297 self.assertEqual(0, handler2_called) 298 299 handler1_return = False 300 envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=1) 301 raw = envelope.pack() 302 self.h.channel._receive(raw) 303 304 time.sleep(0.5) 305 306 self.assertEqual(2, handler1_called) 307 self.assertEqual(1, handler2_called) 308 309 def test_system_message_check(self): 310 print("Channel test register system message") 311 with self.assertRaises(RNS.Channel.ChannelException): 312 self.h.channel.register_message_type(SystemMessage) 313 self.h.channel._register_message_type(SystemMessage, is_system_type=True) 314 315 316 def eat_own_dog_food(self, message: MessageBase, checker: typing.Callable[[MessageBase], None]): 317 decoded: [MessageBase] = [] 318 319 def handle_message(message: MessageBase): 320 decoded.append(message) 321 322 self.h.channel.register_message_type(message.__class__) 323 self.h.channel.add_message_handler(handle_message) 324 self.assertEqual(len(self.h.outlet.packets), 0) 325 326 envelope = self.h.channel.send(message) 327 time.sleep(self.h.channel._get_packet_timeout_time(1) * 0.5) 328 329 self.assertIsNotNone(envelope) 330 self.assertIsNotNone(envelope.raw) 331 self.assertEqual(1, len(self.h.outlet.packets)) 332 self.assertIsNotNone(envelope.packet) 333 self.assertTrue(envelope in self.h.channel._tx_ring) 334 self.assertTrue(envelope.tracked) 335 336 packet = self.h.outlet.packets[0] 337 338 self.assertEqual(envelope.packet, packet) 339 self.assertEqual(1, envelope.tries) 340 self.assertEqual(1, packet.tries) 341 self.assertEqual(1, packet.instances) 342 self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) 343 self.assertEqual(envelope.raw, packet.raw) 344 345 packet.delivered() 346 347 self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state) 348 349 time.sleep(self.h.channel._get_packet_timeout_time(1)) 350 351 self.assertEqual(1, len(self.h.outlet.packets)) 352 self.assertEqual(1, envelope.tries) 353 self.assertEqual(1, packet.tries) 354 self.assertEqual(0, packet.instances) 355 self.assertFalse(envelope.tracked) 356 357 self.assertEqual(len(self.h.outlet.packets), 1) 358 self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state) 359 self.assertFalse(envelope.tracked) 360 self.assertEqual(0, len(decoded)) 361 362 self.h.channel._receive(packet.raw) 363 364 time.sleep(0.5) 365 366 self.assertEqual(1, len(decoded)) 367 368 rx_message = decoded[0] 369 370 self.assertIsNotNone(rx_message) 371 self.assertIsInstance(rx_message, message.__class__) 372 checker(rx_message) 373 374 def test_send_receive_message_test(self): 375 print("Channel test send and receive message") 376 message = MessageTest() 377 378 def check(rx_message: MessageBase): 379 self.assertIsInstance(rx_message, message.__class__) 380 self.assertEqual(message.id, rx_message.id) 381 self.assertEqual(message.data, rx_message.data) 382 self.assertNotEqual(message.not_serialized, rx_message.not_serialized) 383 384 self.eat_own_dog_food(message, check) 385 386 def test_buffer_small_bidirectional(self): 387 data = "Hello\n" 388 with RNS.Buffer.create_bidirectional_buffer(0, 0, self.h.channel) as buffer: 389 count = buffer.write(data.encode("utf-8")) 390 buffer.flush() 391 392 self.assertEqual(len(data), count) 393 self.assertEqual(1, len(self.h.outlet.packets)) 394 395 packet = self.h.outlet.packets[0] 396 self.h.channel._receive(packet.raw) 397 time.sleep(0.2) 398 result = buffer.readline() 399 400 self.assertIsNotNone(result) 401 self.assertEqual(len(result), len(data)) 402 403 decoded = result.decode("utf-8") 404 405 self.assertEqual(data, decoded) 406 407 def test_buffer_big(self): 408 writer = RNS.Buffer.create_writer(15, self.h.channel) 409 reader = RNS.Buffer.create_reader(15, self.h.channel) 410 data = "01234556789"*1024*5 # 50 KB 411 count = 0 412 write_finished = False 413 414 def write_thread(): 415 nonlocal count, write_finished 416 count = writer.write(data.encode("utf-8")) 417 writer.flush() 418 writer.close() # TODO: Workaround for https://github.com/python/cpython/issues/138720 419 write_finished = True 420 threading.Thread(target=write_thread, name="Write Thread", daemon=True).start() 421 422 while not write_finished or next(filter(lambda x: x.state != MessageState.MSGSTATE_DELIVERED, 423 self.h.outlet.packets), None) is not None: 424 with self.h.outlet.lock: 425 for packet in self.h.outlet.packets: 426 if packet.state != MessageState.MSGSTATE_DELIVERED: 427 self.h.channel._receive(packet.raw) 428 packet.delivered() 429 time.sleep(0.0001) 430 431 self.assertEqual(len(data), count) 432 433 read_finished = False 434 result = bytes() 435 436 def read_thread(): 437 nonlocal read_finished, result 438 result = reader.read() 439 read_finished = True 440 threading.Thread(target=read_thread, name="Read Thread", daemon=True).start() 441 442 timeout_at = time.time() + 7 443 while not read_finished and time.time() < timeout_at: 444 time.sleep(0.001) 445 446 self.assertTrue(read_finished) 447 self.assertEqual(len(data), len(result)) 448 449 decoded = result.decode("utf-8") 450 451 self.assertSequenceEqual(data, decoded) 452 453 def test_buffer_small_with_callback(self): 454 callbacks = 0 455 last_cb_value = None 456 457 def callback(ready: int): 458 nonlocal callbacks, last_cb_value 459 callbacks += 1 460 last_cb_value = ready 461 462 data = "Hello\n" 463 with RNS.RawChannelWriter(0, self.h.channel) as writer, RNS.RawChannelReader(0, self.h.channel) as reader: 464 reader.add_ready_callback(callback) 465 count = writer.write(data.encode("utf-8")) 466 writer.flush() 467 468 self.assertEqual(len(data), count) 469 self.assertEqual(1, len(self.h.outlet.packets)) 470 471 packet = self.h.outlet.packets[0] 472 self.h.channel._receive(packet.raw) 473 packet.delivered() 474 475 self.assertEqual(1, callbacks) 476 self.assertEqual(len(data), last_cb_value) 477 478 result = reader.readline() 479 480 self.assertIsNotNone(result) 481 self.assertEqual(len(result), len(data)) 482 483 decoded = result.decode("utf-8") 484 485 self.assertEqual(data, decoded) 486 self.assertEqual(1, len(self.h.outlet.packets)) 487 488 result = reader.read(1) 489 490 self.assertIsNone(result) 491 self.assertTrue(self.h.channel.is_ready_to_send()) 492 493 writer.close() 494 495 self.assertEqual(2, len(self.h.outlet.packets)) 496 497 packet = self.h.outlet.packets[1] 498 self.h.channel._receive(packet.raw) 499 packet.delivered() 500 501 result = reader.read(1) 502 503 self.assertIsNotNone(result) 504 self.assertTrue(len(result) == 0) 505 506 507 if __name__ == '__main__': 508 unittest.main(verbosity=2)