ws.py
"""Websocket functionality: pool-based fan-out of events to subscribed websockets."""

import logging

from fastapi import WebSocket
from starlette.websockets import WebSocketState

from spoolman.api.v1.models import Event

logger = logging.getLogger(__name__)


def _client_host(websocket: WebSocket) -> str:
    """Return the remote host of *websocket* for logging, or ``"?"`` if unknown."""
    return websocket.client.host if websocket.client else "?"


class SubscriptionTree:
    """Subscription tree.

    This is a tree structure that allows us to efficiently send messages to
    all websockets that are subscribed to a certain pool of events.

    You can subscribe to different levels of the tree, for example:
    - ("vendor", "1") will subscribe to events for vendor 1
    - ("vendor",) will subscribe to events for all vendors
    - () will subscribe to events for all vendors, filaments and spools

    An event sent on a path is delivered to the subscribers of every node along
    that path: the root, each intermediate node, and the node for the full path.
    Nodes left with no subscribers and no children are pruned so the tree does
    not keep growing as clients come and go.
    """

    def __init__(self) -> None:
        """Initialize an empty node."""
        # Child nodes keyed by the next path segment.
        self.children: dict[str, SubscriptionTree] = {}
        # Websockets subscribed to exactly the pool this node represents.
        self.subscribers: set[WebSocket] = set()

    def _is_empty(self) -> bool:
        """Return True if this node has neither subscribers nor children."""
        return not self.subscribers and not self.children

    def add(self, path: tuple[str, ...], websocket: WebSocket) -> None:
        """Add a websocket to the node at *path*, creating nodes as needed."""
        if not path:
            self.subscribers.add(websocket)
            return
        head, rest = path[0], path[1:]
        child = self.children.get(head)
        if child is None:
            child = self.children[head] = SubscriptionTree()
        child.add(rest, websocket)

    def remove(self, path: tuple[str, ...], websocket: WebSocket) -> None:
        """Remove a websocket from the node at *path*.

        Removing a websocket that is not subscribed there is a no-op, because
        send() may already have dropped a dead websocket before its owner gets
        around to unsubscribing it.
        """
        if not path:
            self.subscribers.discard(websocket)
            return
        head, rest = path[0], path[1:]
        child = self.children.get(head)
        if child is None:
            return
        child.remove(rest, websocket)
        if child._is_empty():
            # Prune so per-entity pools such as ("vendor", "1") don't accumulate.
            del self.children[head]

    async def send(self, path: tuple[str, ...], evt: Event) -> None:
        """Send *evt* to every websocket subscribed along *path*.

        The subscribers of this node receive the event, then the event is
        forwarded to the child named by ``path[0]`` with the remaining path, and
        so on. Clients found disconnected, or whose send fails, are dropped and
        logged; one bad client never aborts delivery to the others.
        """
        # Serialize once instead of once per subscriber.
        await self._broadcast(path, evt.json(), ())

    async def _broadcast(
        self,
        path: tuple[str, ...],
        payload: str,
        prefix: tuple[str, ...],
    ) -> None:
        """Deliver *payload* to this node's subscribers, then recurse along *path*.

        *prefix* is the path from the root to this node (the pool this node's
        subscribers listen on); it is only used for log messages.
        """
        pool = ",".join(prefix)
        # Iterate over a snapshot: the set is modified in this loop, and other
        # coroutines may connect/disconnect while we await send_text().
        for websocket in list(self.subscribers):
            if websocket not in self.subscribers:
                # Unsubscribed by someone else while we were awaiting.
                continue
            states = (websocket.client_state, websocket.application_state)
            if WebSocketState.DISCONNECTED in states:
                # A bad disconnection may have occurred
                self.subscribers.discard(websocket)
                logger.info(
                    "Forcing disconnection of client %s on pool %s",
                    _client_host(websocket),
                    pool,
                )
            elif all(state == WebSocketState.CONNECTED for state in states):
                try:
                    await websocket.send_text(payload)
                except Exception:  # noqa: BLE001
                    # A broken client must not abort the fan-out to everyone
                    # else, nor propagate to whoever published the event.
                    logger.warning(
                        "Failed to send event to client %s on pool %s, dropping it",
                        _client_host(websocket),
                        pool,
                        exc_info=True,
                    )
                    self.subscribers.discard(websocket)

        # Send the message further down the tree
        if not path:
            return
        head, rest = path[0], path[1:]
        child = self.children.get(head)
        if child is None:
            return
        await child._broadcast(rest, payload, (*prefix, head))
        # Dead clients dropped above may have left the child empty. Only prune
        # if the node was not replaced concurrently.
        if child._is_empty() and self.children.get(head) is child:
            del self.children[head]


class WebsocketManager:
    """Websocket manager.

    Facade over a single SubscriptionTree that also logs subscription changes.
    """

    def __init__(self) -> None:
        """Initialize with an empty subscription tree."""
        self.tree = SubscriptionTree()

    def connect(self, pool: tuple[str, ...], websocket: WebSocket) -> None:
        """Subscribe *websocket* to events on *pool*."""
        self.tree.add(pool, websocket)
        logger.info(
            "Client %s is now listening on pool %s",
            _client_host(websocket),
            ",".join(pool),
        )

    def disconnect(self, pool: tuple[str, ...], websocket: WebSocket) -> None:
        """Unsubscribe *websocket* from *pool*.

        Safe to call even if the websocket was already dropped during a send.
        """
        self.tree.remove(pool, websocket)
        logger.info(
            "Client %s has stopped listening on pool %s",
            _client_host(websocket),
            ",".join(pool),
        )

    async def send(self, pool: tuple[str, ...], evt: Event) -> None:
        """Send *evt* to all websockets subscribed to *pool* or any ancestor pool."""
        await self.tree.send(pool, evt)


# Module-level singleton shared by the application.
websocket_manager = WebsocketManager()