rpc.py
  1  # SPDX-FileCopyrightText: Copyright (c) 2021 Melissa LeBlanc-Williams for Adafruit Industries
  2  #
  3  # SPDX-License-Identifier: Unlicense
  4  """
  5  USB CDC Remote Procedure Call class
  6  """
  7  
  8  import time
  9  import json
 10  try:
 11      import serial
 12      import adafruit_board_toolkit.circuitpython_serial
 13      json_decode_exception = json.decoder.JSONDecodeError
 14  except ImportError:
 15      import usb_cdc as serial
 16      json_decode_exception = ValueError
 17  
 18  RESPONSE_TIMEOUT=5
 19  DATA_TIMEOUT=0.5
 20  
 21  class RpcError(Exception):
 22      """For RPC Specific Errors"""
 23      pass
 24  
 25  class _Rpc:
 26      def __init__(self):
 27          self._serial = None
 28  
 29      @staticmethod
 30      def create_response_packet(error=False, error_type="RPC", message=None, return_val=None):
 31          return {
 32              "error": error,
 33              "error_type": error_type if error else None,
 34              "message": message,
 35              "return_val": return_val
 36          }
 37  
 38      @staticmethod
 39      def create_request_packet(function, args=[], kwargs={}):
 40          return {
 41              "function": function,
 42              "args": args,
 43              "kwargs": kwargs
 44          }
 45  
 46      def _wait_for_packet(self, timeout=None):
 47          incoming_packet = b""
 48          if timeout is not None:
 49              response_start_time = time.monotonic()
 50          while True:
 51              if incoming_packet:
 52                  data_start_time = time.monotonic()
 53              while not self._serial.in_waiting:
 54                  if incoming_packet and (time.monotonic() - data_start_time) >= DATA_TIMEOUT:
 55                      incoming_packet = b""
 56                  if not incoming_packet and timeout is not None:
 57                      if (time.monotonic() - response_start_time) >= timeout:
 58                          return self.create_response_packet(error=True, message="Timed out waiting for response")
 59                  time.sleep(0.001)
 60              data = self._serial.read(self._serial.in_waiting)
 61              if data:
 62                  try:
 63                      incoming_packet += data
 64                      packet = json.loads(incoming_packet)
 65                      # json can try to be clever with missing braces, so make sure we have everything
 66                      if sorted(tuple(packet.keys())) == sorted(self._packet_format()):
 67                          return packet
 68                  except json_decode_exception:
 69                      pass # Incomplete packet
 70  
 71  class RpcClient(_Rpc):
 72      def __init__(self):
 73          super().__init__()
 74          self._serial = serial.data
 75      
 76      def _packet_format(self):
 77          return self.create_response_packet().keys()
 78  
 79      def call(self, function, *args, **kwargs):
 80          packet = self.create_request_packet(function, args, kwargs)
 81          self._serial.write(bytes(json.dumps(packet), "utf-8"))
 82          # Wait for response packet to indicate success
 83          return self._wait_for_packet(RESPONSE_TIMEOUT)
 84  
 85  class RpcServer(_Rpc):
 86      def __init__(self, handler, baudrate=9600):
 87          super().__init__()
 88          self._serial = self.init_serial(baudrate)
 89          self._handler = handler
 90  
 91      def _packet_format(self):
 92          return self.create_request_packet(None).keys()
 93  
 94      def init_serial(self, baudrate):
 95          port = self.detect_port()
 96  
 97          return serial.Serial(
 98              port,
 99              baudrate,
100              parity='N',
101              rtscts=False,
102              xonxoff=False,
103              exclusive=True,
104          )
105  
106      def detect_port(self):
107          """
108          Detect the port automatically
109          """
110          comports = adafruit_board_toolkit.circuitpython_serial.data_comports()
111          ports = [comport.device for comport in comports]
112          if len(ports) >= 1:
113              if len(ports) > 1:
114                  print("Multiple devices detected, using the first detected port.")
115              return ports[0]
116          raise RuntimeError("Unable to find any CircuitPython Devices with the CDC Data port enabled.")
117  
118      def loop(self, timeout=None):
119          packet = self._wait_for_packet(timeout)
120          if "error" not in packet:
121              response_packet = self._handler(packet)
122              self._serial.write(bytes(json.dumps(response_packet), "utf-8"))
123      
124      def close_serial(self):
125          if self._serial is not None:
126              self._serial.close()