message_handler.py
  1  """
  2  Message handling for the development Solace broker simulator.
  3  Manages message storage, retrieval, and lifecycle.
  4  """
  5  
  6  import time
  7  import threading
  8  import uuid
  9  from typing import Dict, List, Any, Optional, Callable
 10  from dataclasses import dataclass, field
 11  from collections import deque
 12  import logging
 13  
 14  
 15  @dataclass
 16  class BrokerMessage:
 17      """Represents a message in the broker."""
 18      id: str = field(default_factory=lambda: str(uuid.uuid4()))
 19      topic: str = ""
 20      payload: Any = None
 21      user_properties: Dict[str, Any] = field(default_factory=dict)
 22      timestamp: float = field(default_factory=time.time)
 23      ttl_seconds: Optional[int] = None
 24      qos: int = 1
 25      
 26      @property
 27      def is_expired(self) -> bool:
 28          """Check if the message has expired."""
 29          if self.ttl_seconds is None:
 30              return False
 31          return time.time() > (self.timestamp + self.ttl_seconds)
 32      
 33      def to_dict(self) -> Dict[str, Any]:
 34          """Convert message to dictionary format."""
 35          return {
 36              "id": self.id,
 37              "topic": self.topic,
 38              "payload": self.payload,
 39              "user_properties": self.user_properties,
 40              "timestamp": self.timestamp,
 41              "ttl_seconds": self.ttl_seconds,
 42              "qos": self.qos,
 43          }
 44  
 45  
 46  class MessageHandler:
 47      """Handles message storage, retrieval, and lifecycle management."""
 48      
 49      def __init__(self, max_queue_size: int = 1000, default_ttl_seconds: int = 300):
 50          self.max_queue_size = max_queue_size
 51          self.default_ttl_seconds = default_ttl_seconds
 52          
 53          # Message storage
 54          self._messages: deque = deque(maxlen=max_queue_size)
 55          self._message_index: Dict[str, BrokerMessage] = {}
 56          
 57          # Message capture for testing
 58          self._captured_messages: List[BrokerMessage] = []
 59          self._capture_enabled = True
 60          
 61          # Threading
 62          self._lock = threading.RLock()
 63          self._logger = logging.getLogger(f"{__name__}.MessageHandler")
 64          
 65          # Message listeners for testing
 66          self._message_listeners: List[Callable[[BrokerMessage], None]] = []
 67      
 68      def store_message(
 69          self, 
 70          topic: str, 
 71          payload: Any, 
 72          user_properties: Optional[Dict[str, Any]] = None,
 73          ttl_seconds: Optional[int] = None,
 74          qos: int = 1
 75      ) -> BrokerMessage:
 76          """
 77          Store a message in the broker.
 78          
 79          Args:
 80              topic: Message topic
 81              payload: Message payload
 82              user_properties: Optional user properties
 83              ttl_seconds: Time to live in seconds
 84              qos: Quality of service level
 85              
 86          Returns:
 87              The stored message
 88          """
 89          with self._lock:
 90              message = BrokerMessage(
 91                  topic=topic,
 92                  payload=payload,
 93                  user_properties=user_properties or {},
 94                  ttl_seconds=ttl_seconds or self.default_ttl_seconds,
 95                  qos=qos
 96              )
 97              
 98              # Add to storage
 99              self._messages.append(message)
100              self._message_index[message.id] = message
101              
102              # Capture for testing if enabled
103              if self._capture_enabled:
104                  self._captured_messages.append(message)
105              
106              # Notify listeners
107              for listener in self._message_listeners:
108                  try:
109                      listener(message)
110                  except Exception as e:
111                      self._logger.error(f"Error in message listener: {e}")
112              
113              self._logger.debug(
114                  f"Stored message {message.id} on topic '{topic}' "
115                  f"(payload size: {len(str(payload))} chars)"
116              )
117              
118              return message
119      
120      def get_message(self, message_id: str) -> Optional[BrokerMessage]:
121          """
122          Retrieve a message by ID.
123          
124          Args:
125              message_id: Message identifier
126              
127          Returns:
128              The message if found and not expired, None otherwise
129          """
130          with self._lock:
131              message = self._message_index.get(message_id)
132              if message and not message.is_expired:
133                  return message
134              return None
135      
136      def get_messages_by_topic(self, topic: str, limit: Optional[int] = None) -> List[BrokerMessage]:
137          """
138          Get messages for a specific topic.
139          
140          Args:
141              topic: Topic to search for
142              limit: Maximum number of messages to return
143              
144          Returns:
145              List of messages for the topic
146          """
147          with self._lock:
148              messages = []
149              count = 0
150              
151              for message in reversed(self._messages):
152                  if message.is_expired:
153                      continue
154                      
155                  if message.topic == topic:
156                      messages.append(message)
157                      count += 1
158                      
159                      if limit and count >= limit:
160                          break
161              
162              return messages
163      
164      def get_recent_messages(self, limit: int = 10) -> List[BrokerMessage]:
165          """
166          Get the most recent messages.
167          
168          Args:
169              limit: Maximum number of messages to return
170              
171          Returns:
172              List of recent messages
173          """
174          with self._lock:
175              messages = []
176              count = 0
177              
178              for message in reversed(self._messages):
179                  if not message.is_expired:
180                      messages.append(message)
181                      count += 1
182                      
183                      if count >= limit:
184                          break
185              
186              return messages
187      
188      def cleanup_expired_messages(self) -> int:
189          """
190          Remove expired messages from storage.
191          
192          Returns:
193              Number of messages removed
194          """
195          with self._lock:
196              removed_count = 0
197              messages_to_remove = []
198              
199              for message in self._messages:
200                  if message.is_expired:
201                      messages_to_remove.append(message)
202              
203              for message in messages_to_remove:
204                  self._messages.remove(message)
205                  if message.id in self._message_index:
206                      del self._message_index[message.id]
207                  removed_count += 1
208              
209              if removed_count > 0:
210                  self._logger.debug(f"Cleaned up {removed_count} expired messages")
211              
212              return removed_count
213      
214      def clear_all_messages(self) -> int:
215          """
216          Clear all stored messages.
217          
218          Returns:
219              Number of messages cleared
220          """
221          with self._lock:
222              count = len(self._messages)
223              self._messages.clear()
224              self._message_index.clear()
225              
226              if count > 0:
227                  self._logger.info(f"Cleared all {count} messages")
228              
229              return count
230      
231      def get_message_count(self) -> int:
232          """Get the current number of stored messages."""
233          with self._lock:
234              return len(self._messages)
235      
236      # Testing utilities
237      
238      def get_captured_messages(self) -> List[BrokerMessage]:
239          """Get all captured messages for testing."""
240          with self._lock:
241              return self._captured_messages.copy()
242      
243      def clear_captured_messages(self) -> None:
244          """Clear captured messages."""
245          with self._lock:
246              self._captured_messages.clear()
247              self._logger.debug("Cleared captured messages")
248      
249      def set_capture_enabled(self, enabled: bool) -> None:
250          """Enable or disable message capture."""
251          with self._lock:
252              self._capture_enabled = enabled
253              self._logger.debug(f"Message capture {'enabled' if enabled else 'disabled'}")
254      
255      def add_message_listener(self, listener: Callable[[BrokerMessage], None]) -> None:
256          """Add a message listener for testing."""
257          with self._lock:
258              self._message_listeners.append(listener)
259              self._logger.debug("Added message listener")
260      
261      def remove_message_listener(self, listener: Callable[[BrokerMessage], None]) -> None:
262          """Remove a message listener."""
263          with self._lock:
264              if listener in self._message_listeners:
265                  self._message_listeners.remove(listener)
266                  self._logger.debug("Removed message listener")
267      
268      def clear_message_listeners(self) -> None:
269          """Clear all message listeners."""
270          with self._lock:
271              count = len(self._message_listeners)
272              self._message_listeners.clear()
273              if count > 0:
274                  self._logger.debug(f"Cleared {count} message listeners")
275      
276      def find_messages_by_payload(self, payload_filter: Callable[[Any], bool]) -> List[BrokerMessage]:
277          """
278          Find messages by payload content.
279          
280          Args:
281              payload_filter: Function that returns True for matching payloads
282              
283          Returns:
284              List of matching messages
285          """
286          with self._lock:
287              matching_messages = []
288              
289              for message in self._messages:
290                  if message.is_expired:
291                      continue
292                      
293                  try:
294                      if payload_filter(message.payload):
295                          matching_messages.append(message)
296                  except Exception as e:
297                      self._logger.debug(f"Error in payload filter: {e}")
298              
299              return matching_messages
300      
301      def get_statistics(self) -> Dict[str, Any]:
302          """Get message handler statistics."""
303          with self._lock:
304              total_messages = len(self._messages)
305              expired_count = sum(1 for msg in self._messages if msg.is_expired)
306              
307              return {
308                  "total_messages": total_messages,
309                  "active_messages": total_messages - expired_count,
310                  "expired_messages": expired_count,
311                  "captured_messages": len(self._captured_messages),
312                  "message_listeners": len(self._message_listeners),
313                  "capture_enabled": self._capture_enabled,
314                  "max_queue_size": self.max_queue_size,
315                  "default_ttl_seconds": self.default_ttl_seconds,
316              }