# message_handler.py
"""
Message handling for the development Solace broker simulator.
Manages message storage, retrieval, and lifecycle.
"""

import time
import threading
import uuid
from typing import Dict, List, Any, Optional, Callable, Iterator
from dataclasses import dataclass, field
from collections import deque
from itertools import islice
import logging


@dataclass
class BrokerMessage:
    """Represents a message in the broker."""

    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    topic: str = ""
    payload: Any = None
    user_properties: Dict[str, Any] = field(default_factory=dict)
    # Creation time in seconds since the epoch (from time.time()).
    timestamp: float = field(default_factory=time.time)
    # Lifetime in seconds measured from ``timestamp``; None means it never expires.
    ttl_seconds: Optional[int] = None
    qos: int = 1

    @property
    def is_expired(self) -> bool:
        """Return True once more than ``ttl_seconds`` have passed since ``timestamp``."""
        if self.ttl_seconds is None:
            return False
        return time.time() > (self.timestamp + self.ttl_seconds)

    def to_dict(self) -> Dict[str, Any]:
        """Convert message to dictionary format.

        ``user_properties`` is shallow-copied so that mutating the returned
        dict does not modify the stored message.
        """
        return {
            "id": self.id,
            "topic": self.topic,
            "payload": self.payload,
            "user_properties": dict(self.user_properties),
            "timestamp": self.timestamp,
            "ttl_seconds": self.ttl_seconds,
            "qos": self.qos,
        }


class MessageHandler:
    """Handles message storage, retrieval, and lifecycle management.

    Messages are kept in a bounded FIFO of ``max_queue_size`` entries plus an
    id -> message index. When the FIFO is full, the oldest message is evicted
    from both structures. Public methods serialize on one re-entrant lock.
    """

    def __init__(self, max_queue_size: int = 1000, default_ttl_seconds: int = 300):
        self.max_queue_size = max_queue_size
        self.default_ttl_seconds = default_ttl_seconds

        # Message storage: insertion-ordered queue plus an id index that
        # _append_message keeps in sync with queue evictions.
        self._messages: deque = deque(maxlen=max_queue_size)
        self._message_index: Dict[str, BrokerMessage] = {}

        # Message capture for testing (grows until clear_captured_messages)
        self._captured_messages: List[BrokerMessage] = []
        self._capture_enabled = True

        # Threading: re-entrant, so listeners (which run while the lock is
        # held) can call back into this handler from the same thread.
        self._lock = threading.RLock()
        self._logger = logging.getLogger(f"{__name__}.MessageHandler")

        # Message listeners for testing
        self._message_listeners: List[Callable[[BrokerMessage], None]] = []

    def _append_message(self, message: BrokerMessage) -> None:
        """Append to the bounded queue, dropping any evicted message from the index.

        Caller must hold ``self._lock``.
        """
        maxlen = self._messages.maxlen
        if maxlen is not None and len(self._messages) >= maxlen:
            if not self._messages:
                # maxlen == 0: the queue retains nothing, so the index must not either.
                return
            # Evict explicitly instead of letting deque drop it silently, so the
            # index does not keep (and keep serving) a message no longer stored.
            evicted = self._messages.popleft()
            self._message_index.pop(evicted.id, None)
        self._messages.append(message)
        self._message_index[message.id] = message

    def store_message(
        self,
        topic: str,
        payload: Any,
        user_properties: Optional[Dict[str, Any]] = None,
        ttl_seconds: Optional[int] = None,
        qos: int = 1,
    ) -> BrokerMessage:
        """
        Store a message in the broker.

        Args:
            topic: Message topic
            payload: Message payload
            user_properties: Optional user properties (shallow-copied)
            ttl_seconds: Time to live in seconds; None uses ``default_ttl_seconds``.
                An explicit 0 is honoured.
            qos: Quality of service level

        Returns:
            The stored message
        """
        with self._lock:
            message = BrokerMessage(
                topic=topic,
                payload=payload,
                # Copy so later mutation of the caller's dict does not alter the message.
                user_properties=dict(user_properties) if user_properties else {},
                # ``is None`` rather than ``or`` so that ttl_seconds=0 is not replaced.
                ttl_seconds=(
                    self.default_ttl_seconds if ttl_seconds is None else ttl_seconds
                ),
                qos=qos,
            )

            # Add to storage
            self._append_message(message)

            # Capture for testing if enabled
            if self._capture_enabled:
                self._captured_messages.append(message)

            # Notify listeners. Iterate over a snapshot because a listener may
            # add or remove listeners re-entrantly (the lock is an RLock).
            for listener in list(self._message_listeners):
                try:
                    listener(message)
                except Exception as e:
                    # Best-effort: one failing listener must not affect storage or others.
                    self._logger.exception("Error in message listener: %s", e)

            # str(payload) may be expensive for large payloads; only build it when logged.
            if self._logger.isEnabledFor(logging.DEBUG):
                self._logger.debug(
                    "Stored message %s on topic '%s' (payload size: %d chars)",
                    message.id,
                    topic,
                    len(str(payload)),
                )

            return message

    def get_message(self, message_id: str) -> Optional[BrokerMessage]:
        """
        Retrieve a message by ID.

        Args:
            message_id: Message identifier

        Returns:
            The message if found and not expired, None otherwise
        """
        with self._lock:
            message = self._message_index.get(message_id)
            if message and not message.is_expired:
                return message
            return None

    def _iter_active_newest_first(self) -> Iterator[BrokerMessage]:
        """Yield stored, non-expired messages from newest to oldest.

        Caller must hold ``self._lock`` and fully consume the iterator while holding it.
        """
        return (m for m in reversed(self._messages) if not m.is_expired)

    def get_messages_by_topic(self, topic: str, limit: Optional[int] = None) -> List[BrokerMessage]:
        """
        Get messages for a specific topic, newest first.

        Args:
            topic: Topic to search for
            limit: Maximum number of messages to return; None means no limit,
                and 0 (or a negative value) returns an empty list.

        Returns:
            List of non-expired messages for the topic
        """
        with self._lock:
            matches = (m for m in self._iter_active_newest_first() if m.topic == topic)
            if limit is None:
                return list(matches)
            return list(islice(matches, max(limit, 0)))

    def get_recent_messages(self, limit: int = 10) -> List[BrokerMessage]:
        """
        Get the most recent non-expired messages, newest first.

        Args:
            limit: Maximum number of messages to return; 0 (or negative) returns []

        Returns:
            List of recent messages
        """
        with self._lock:
            return list(islice(self._iter_active_newest_first(), max(limit, 0)))

    def cleanup_expired_messages(self) -> int:
        """
        Remove expired messages from storage.

        Runs in a single O(n) pass; each message's expiry is evaluated once.

        Returns:
            Number of messages removed
        """
        with self._lock:
            kept: List[BrokerMessage] = []
            removed_count = 0

            for message in self._messages:
                if message.is_expired:
                    self._message_index.pop(message.id, None)
                    removed_count += 1
                else:
                    kept.append(message)

            if removed_count > 0:
                # Rebuild in place so the deque object (and its maxlen) is preserved.
                self._messages.clear()
                self._messages.extend(kept)
                self._logger.debug(f"Cleaned up {removed_count} expired messages")

            return removed_count

    def clear_all_messages(self) -> int:
        """
        Clear all stored messages. Captured messages are left untouched.

        Returns:
            Number of messages cleared
        """
        with self._lock:
            count = len(self._messages)
            self._messages.clear()
            self._message_index.clear()

            if count > 0:
                self._logger.info(f"Cleared all {count} messages")

            return count

    def get_message_count(self) -> int:
        """Get the current number of stored messages (expired ones included until cleanup)."""
        with self._lock:
            return len(self._messages)

    # Testing utilities

    def get_captured_messages(self) -> List[BrokerMessage]:
        """Get a copy of all captured messages for testing."""
        with self._lock:
            return self._captured_messages.copy()

    def clear_captured_messages(self) -> None:
        """Clear captured messages."""
        with self._lock:
            self._captured_messages.clear()
            self._logger.debug("Cleared captured messages")

    def set_capture_enabled(self, enabled: bool) -> None:
        """Enable or disable message capture."""
        with self._lock:
            self._capture_enabled = enabled
            self._logger.debug(f"Message capture {'enabled' if enabled else 'disabled'}")

    def add_message_listener(self, listener: Callable[[BrokerMessage], None]) -> None:
        """Add a listener called with each newly stored message (for testing)."""
        with self._lock:
            self._message_listeners.append(listener)
            self._logger.debug("Added message listener")

    def remove_message_listener(self, listener: Callable[[BrokerMessage], None]) -> None:
        """Remove a message listener; a no-op if it is not registered."""
        with self._lock:
            if listener in self._message_listeners:
                self._message_listeners.remove(listener)
                self._logger.debug("Removed message listener")

    def clear_message_listeners(self) -> None:
        """Clear all message listeners."""
        with self._lock:
            count = len(self._message_listeners)
            self._message_listeners.clear()
            if count > 0:
                self._logger.debug(f"Cleared {count} message listeners")

    def find_messages_by_payload(self, payload_filter: Callable[[Any], bool]) -> List[BrokerMessage]:
        """
        Find messages by payload content, oldest first.

        Args:
            payload_filter: Function that returns True for matching payloads.
                Exceptions it raises are logged at DEBUG and treated as no match.

        Returns:
            List of matching non-expired messages
        """
        with self._lock:
            matching_messages = []

            for message in self._messages:
                if message.is_expired:
                    continue

                try:
                    if payload_filter(message.payload):
                        matching_messages.append(message)
                except Exception as e:
                    self._logger.debug(f"Error in payload filter: {e}")

            return matching_messages

    def get_statistics(self) -> Dict[str, Any]:
        """Get message handler statistics."""
        with self._lock:
            total_messages = len(self._messages)
            expired_count = sum(1 for msg in self._messages if msg.is_expired)

            return {
                "total_messages": total_messages,
                "active_messages": total_messages - expired_count,
                "expired_messages": expired_count,
                "captured_messages": len(self._captured_messages),
                "message_listeners": len(self._message_listeners),
                "capture_enabled": self._capture_enabled,
                "max_queue_size": self.max_queue_size,
                "default_ttl_seconds": self.default_ttl_seconds,
            }