/ spoolman / ws.py
ws.py
  1  """Websocket functionality."""
  2  
  3  import logging
  4  
  5  from fastapi import WebSocket
  6  from starlette.websockets import WebSocketState
  7  
  8  from spoolman.api.v1.models import Event
  9  
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
 11  
 12  
 13  class SubscriptionTree:
 14      """Subscription tree.
 15  
 16      This is a tree structure that allows us to efficiently send messages to
 17      all websockets that are subscribed to a certain pool of events.
 18  
 19      You can subscribe to different levels of the tree, for example:
 20      - ("vendor", "1") will subscribe to events for vendor 1
 21      - ("vendor") will subscribe to events for all vendors
 22      - () will subscribe to events for all vendors, filaments and spools
 23      """
 24  
 25      def __init__(self) -> None:
 26          """Initialize."""
 27          self.children: dict[str, SubscriptionTree] = {}
 28          self.subscribers: set[WebSocket] = set()
 29  
 30      def add(self, path: tuple[str, ...], websocket: WebSocket) -> None:
 31          """Add a websocket to the subscription tree."""
 32          if len(path) == 0:
 33              self.subscribers.add(websocket)
 34          else:
 35              if path[0] not in self.children:
 36                  self.children[path[0]] = SubscriptionTree()
 37              self.children[path[0]].add(path[1:], websocket)
 38  
 39      def remove(self, path: tuple[str, ...], websocket: WebSocket) -> None:
 40          """Remove a websocket from the subscription tree."""
 41          if len(path) == 0:
 42              self.subscribers.remove(websocket)
 43          elif path[0] in self.children:
 44              self.children[path[0]].remove(path[1:], websocket)
 45  
 46      async def send(self, path: tuple[str, ...], evt: Event) -> None:
 47          """Send a message to all websockets in this branch of the tree."""
 48          # Broadcast to all subscribers on this level
 49          for websocket in self.subscribers:
 50              if (
 51                  websocket.client_state == WebSocketState.DISCONNECTED  # noqa: PLR1714
 52                  or websocket.application_state == WebSocketState.DISCONNECTED
 53              ):
 54                  # A bad disconnection may have occurred
 55                  self.remove(path, websocket)
 56                  logger.info(
 57                      "Forcing disconnection of client %s on pool %s",
 58                      websocket.client.host if websocket.client else "?",
 59                      ",".join(path),
 60                  )
 61              elif (
 62                  websocket.client_state == WebSocketState.CONNECTED
 63                  and websocket.application_state == WebSocketState.CONNECTED
 64              ):
 65                  await websocket.send_text(evt.json())
 66  
 67          # Send the message further down the tree
 68          if len(path) > 0 and path[0] in self.children:
 69              await self.children[path[0]].send(path[1:], evt)
 70  
 71  
 72  class WebsocketManager:
 73      """Websocket manager."""
 74  
 75      def __init__(self) -> None:
 76          """Initialize."""
 77          self.tree = SubscriptionTree()
 78  
 79      def connect(self, pool: tuple[str, ...], websocket: WebSocket) -> None:
 80          """Connect a websocket."""
 81          self.tree.add(pool, websocket)
 82          logger.info(
 83              "Client %s is now listening on pool %s",
 84              websocket.client.host if websocket.client else "?",
 85              ",".join(pool),
 86          )
 87  
 88      def disconnect(self, pool: tuple[str, ...], websocket: WebSocket) -> None:
 89          """Disconnect a websocket."""
 90          self.tree.remove(pool, websocket)
 91          logger.info(
 92              "Client %s has stopped listening on pool %s",
 93              websocket.client.host if websocket.client else "?",
 94              ",".join(pool),
 95          )
 96  
 97      async def send(self, pool: tuple[str, ...], evt: Event) -> None:
 98          """Send a message to all websockets in a pool."""
 99          await self.tree.send(pool, evt)
100  
101  
# Module-level WebsocketManager instance, created once when this module is imported.
websocket_manager = WebsocketManager()