# conversation.py
"""
Conversation session management for Ag3ntum.

Provides a ConversationSession class that wraps ClaudeSDKClient for
multi-turn interactive conversations with session continuity.

Usage:
    async with ConversationSession(options) as session:
        # First query: send the prompt and wait for the final ResultMessage
        result = await session.query_and_receive("Analyze this codebase")

        # Follow-up with context preserved
        result = await session.query_and_receive("What patterns did you find?")

        # Or stream the individual messages of a turn
        await session.query("Refactor the auth module")
        async for message in session.receive():
            ...
"""
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, AsyncIterator, Callable, Optional

from claude_agent_sdk import (
    AssistantMessage,
    ClaudeAgentOptions,
    ClaudeSDKClient,
    ResultMessage,
    SystemMessage,
    TextBlock,
    ToolResultBlock,
    ToolUseBlock,
    UserMessage,
)

from .schemas import SessionContext, TaskStatus, TokenUsage

logger = logging.getLogger(__name__)


@dataclass
class ConversationTurn:
    """
    Represents a single turn in a conversation.

    A turn consists of a user prompt and the assistant's response,
    along with metadata about tool usage and timing.
    """
    turn_number: int  # 1-based position of the turn within the session
    prompt: str
    response_text: str = ""  # Concatenated TextBlock text from the assistant
    tools_used: list[str] = field(default_factory=list)  # Tool names, in call order
    started_at: datetime = field(default_factory=datetime.now)
    completed_at: Optional[datetime] = None
    duration_ms: int = 0  # Copied from the SDK's ResultMessage; 0 if none arrived
    is_error: bool = False
    error_message: Optional[str] = None


@dataclass
class ConversationMetrics:
    """
    Accumulated metrics for a conversation session.

    Terminology:
        - conversation_turns: High-level user↔agent exchanges (prompts/responses)
        - agent_turns: Low-level SDK turns (tool calls, reasoning loops within a response)
    """
    conversation_turns: int = 0  # User prompt → agent response count
    agent_turns: int = 0  # SDK num_turns (tool calls, internal loops)
    total_duration_ms: int = 0
    total_cost_usd: float = 0.0
    total_tokens: int = 0  # Sum of the four token counters below
    input_tokens: int = 0
    output_tokens: int = 0
    cache_read_tokens: int = 0
    cache_creation_tokens: int = 0

    def add_result(self, result: ResultMessage) -> None:
        """
        Add metrics from a ResultMessage.

        Each ResultMessage counts as one conversation turn. Token counters
        that are missing from ``result.usage``, or present but ``None``,
        are counted as 0.

        Args:
            result: The ResultMessage from SDK.
        """
        self.conversation_turns += 1  # Each ResultMessage = one conversation turn
        self.agent_turns += result.num_turns  # SDK's internal turn count
        self.total_duration_ms += result.duration_ms
        if result.total_cost_usd:
            self.total_cost_usd += result.total_cost_usd

        usage = result.usage
        if usage:
            # `or 0` (instead of a .get() default) also covers keys that are
            # present with a None value, which would otherwise raise TypeError.
            self.input_tokens += usage.get("input_tokens") or 0
            self.output_tokens += usage.get("output_tokens") or 0
            self.cache_read_tokens += usage.get("cache_read_input_tokens") or 0
            self.cache_creation_tokens += (
                usage.get("cache_creation_input_tokens") or 0
            )
            self.total_tokens = (
                self.input_tokens
                + self.output_tokens
                + self.cache_read_tokens
                + self.cache_creation_tokens
            )

    def to_token_usage(self) -> TokenUsage:
        """Convert to TokenUsage schema."""
        return TokenUsage(
            input_tokens=self.input_tokens,
            output_tokens=self.output_tokens,
            cache_read_input_tokens=self.cache_read_tokens,
            cache_creation_input_tokens=self.cache_creation_tokens,
        )


class ConversationSession:
    """
    Manages a multi-turn conversation session with Claude.

    Wraps ClaudeSDKClient to provide:
    - Session continuity across multiple exchanges
    - Turn tracking and metrics accumulation
    - Interrupt support for long-running tasks
    - Message processing with callbacks

    Usage:
        options = ClaudeAgentOptions(
            allowed_tools=["Read", "Write", "Bash"],
            permission_mode="acceptEdits"
        )

        async with ConversationSession(options) as session:
            # First query
            await session.query("List all Python files")
            async for message in session.receive():
                print(message)

            # Follow-up with context
            await session.query("Now count the lines in each")
            async for message in session.receive():
                print(message)

    Args:
        options: ClaudeAgentOptions for configuring the client.
        on_message: Optional callback for each message received.
        on_tool_start: Optional callback when a tool starts.
        on_tool_complete: Optional callback when a tool completes.
        on_turn_complete: Optional callback when a turn completes.
    """

    def __init__(
        self,
        options: Optional[ClaudeAgentOptions] = None,
        on_message: Optional[Callable[[Any], None]] = None,
        on_tool_start: Optional[Callable[[str, dict, str], None]] = None,
        on_tool_complete: Optional[Callable[[str, str, Any, bool], None]] = None,
        on_turn_complete: Optional[Callable[[ConversationTurn], None]] = None,
    ) -> None:
        """
        Initialize the conversation session.

        Args:
            options: ClaudeAgentOptions for the SDK client. Defaults to
                ``ClaudeAgentOptions()``.
            on_message: Callback for each message (message) -> None.
            on_tool_start: Callback for tool start (name, input, id) -> None.
            on_tool_complete: Callback for tool end
                (name, id, result, is_error) -> None. ``name`` is resolved
                from the matching tool-use block seen earlier in the same
                turn, or ``"unknown"`` if none was seen.
            on_turn_complete: Callback when a turn completes (turn) -> None.
        """
        self._options = options or ClaudeAgentOptions()
        self._client: Optional[ClaudeSDKClient] = None
        self._connected = False
        self._session_id: Optional[str] = None

        # Callbacks
        self._on_message = on_message
        self._on_tool_start = on_tool_start
        self._on_tool_complete = on_tool_complete
        self._on_turn_complete = on_turn_complete

        # Conversation state
        self._turns: list[ConversationTurn] = []
        self._current_turn: Optional[ConversationTurn] = None
        self._metrics = ConversationMetrics()
        self._last_result: Optional[ResultMessage] = None

        # Interrupt handling
        self._interrupted = False

    @property
    def session_id(self) -> Optional[str]:
        """Get the current session ID (set from the SDK's init message)."""
        return self._session_id

    @property
    def is_connected(self) -> bool:
        """Check if the session is connected."""
        return self._connected

    @property
    def turn_count(self) -> int:
        """Get the number of stored turns (successful and errored)."""
        return len(self._turns)

    @property
    def metrics(self) -> ConversationMetrics:
        """Get accumulated metrics."""
        return self._metrics

    @property
    def turns(self) -> list[ConversationTurn]:
        """Get a shallow copy of all stored turns."""
        return self._turns.copy()

    @property
    def last_result(self) -> Optional[ResultMessage]:
        """Get the last ResultMessage from the SDK."""
        return self._last_result

    @property
    def was_interrupted(self) -> bool:
        """Check if interrupt() was called since the last query()."""
        return self._interrupted

    async def connect(self, initial_prompt: Optional[str] = None) -> None:
        """
        Connect to Claude and optionally send an initial prompt.

        Calling this while already connected logs a warning and does nothing.

        Args:
            initial_prompt: Optional first prompt to send immediately.

        Raises:
            RuntimeError: If connection fails.
        """
        if self._connected:
            logger.warning("Already connected, ignoring connect() call")
            return

        self._client = ClaudeSDKClient(options=self._options)
        try:
            await self._client.connect(initial_prompt)
            self._connected = True
            logger.info("ConversationSession connected")
        except Exception as e:
            # Clean up client on connection failure
            self._client = None
            self._connected = False
            logger.error("Failed to connect: %s", e)
            raise RuntimeError(f"Failed to connect to Claude: {e}") from e

    async def disconnect(self) -> None:
        """
        Disconnect from Claude and clean up.

        Errors raised by the underlying client are logged and suppressed so
        that the session always ends in the disconnected state.
        """
        if not self._connected or not self._client:
            return

        try:
            await self._client.disconnect()
        except Exception as e:
            logger.warning("Error during disconnect: %s", e)
        finally:
            self._connected = False
            self._client = None
            logger.info(
                "ConversationSession disconnected after %d turns",
                self.turn_count,
            )

    async def query(self, prompt: str) -> None:
        """
        Send a query to Claude.

        This starts a new turn in the conversation. Use receive()
        to get the response messages.

        Args:
            prompt: The user prompt to send.

        Raises:
            RuntimeError: If not connected.
        """
        if not self._connected or not self._client:
            raise RuntimeError(
                "Not connected. Call connect() or use 'async with' context."
            )

        if self._current_turn is not None:
            # The previous turn never received its ResultMessage (e.g. the
            # caller stopped iterating receive() early); it is not stored.
            logger.warning(
                "Turn %d did not complete; discarding it",
                self._current_turn.turn_number,
            )

        # Start new turn
        self._current_turn = ConversationTurn(
            turn_number=len(self._turns) + 1,
            prompt=prompt,
            started_at=datetime.now(),
        )
        self._interrupted = False

        logger.debug(
            "Starting turn %d: %s...", self._current_turn.turn_number, prompt[:50]
        )

        await self._client.query(prompt)

    async def receive(self) -> AsyncIterator[Any]:
        """
        Receive messages from the current query.

        Yields messages until a ResultMessage is received, which indicates
        the query is complete; the ResultMessage itself is yielded last.

        Tool-use blocks in assistant messages trigger ``on_tool_start``.
        Tool-result blocks trigger ``on_tool_complete``, whether they arrive in
        an assistant or a user message; the SDK delivers tool results back to
        the model as user messages.

        Yields:
            Messages from Claude (AssistantMessage, UserMessage,
            SystemMessage, ResultMessage, etc.)

        Raises:
            RuntimeError: If not connected or no query() is active.

        Note:
            If an exception occurs during message processing, the current
            turn is marked as an error and stored for debugging.
        """
        if not self._connected or not self._client:
            raise RuntimeError("Not connected")

        if not self._current_turn:
            raise RuntimeError("No active query. Call query() first.")

        response_parts: list[str] = []
        # tool_use_id -> tool name, so result callbacks can report the name
        # (ToolResultBlock itself only carries the id).
        tool_names: dict[str, str] = {}

        try:
            async for message in self._client.receive_response():
                if isinstance(message, AssistantMessage):
                    self._handle_assistant_blocks(
                        message.content, response_parts, tool_names
                    )

                elif isinstance(message, UserMessage):
                    # UserMessage.content may be a plain string; only block
                    # lists can carry tool results.
                    if isinstance(message.content, list):
                        for block in message.content:
                            if isinstance(block, ToolResultBlock):
                                self._emit_tool_complete(block, tool_names)

                elif isinstance(message, SystemMessage):
                    # Extract session ID from init message
                    if message.subtype == "init":
                        self._session_id = message.data.get("session_id")
                        logger.debug("Session ID: %s", self._session_id)

                elif isinstance(message, ResultMessage):
                    self._complete_turn(message, response_parts)

                    # Invoke callback and yield final message
                    if self._on_message:
                        self._on_message(message)
                    yield message

                    # Terminate generator - ResultMessage signals query completion
                    return

                # For non-terminal messages: invoke callback and yield
                if self._on_message:
                    self._on_message(message)
                yield message

        except Exception as e:
            # Handle exception during message iteration
            logger.error("Error during receive: %s", e)

            if self._current_turn:
                # Mark turn as error and store it
                self._current_turn.completed_at = datetime.now()
                self._current_turn.is_error = True
                self._current_turn.error_message = str(e)
                self._current_turn.response_text = "".join(response_parts)
                self._turns.append(self._current_turn)
                self._current_turn = None

            raise

    def _handle_assistant_blocks(
        self,
        blocks: list[Any],
        response_parts: list[str],
        tool_names: dict[str, str],
    ) -> None:
        """Collect text and dispatch tool callbacks for an assistant message's blocks."""
        for block in blocks:
            if isinstance(block, TextBlock):
                response_parts.append(block.text)
            elif isinstance(block, ToolUseBlock):
                tool_names[block.id] = block.name
                self._current_turn.tools_used.append(block.name)
                if self._on_tool_start:
                    self._on_tool_start(block.name, block.input, block.id)
            elif isinstance(block, ToolResultBlock):
                self._emit_tool_complete(block, tool_names)

    def _emit_tool_complete(
        self, block: ToolResultBlock, tool_names: dict[str, str]
    ) -> None:
        """Invoke on_tool_complete for a tool result, resolving the tool name by id."""
        if self._on_tool_complete:
            self._on_tool_complete(
                tool_names.get(block.tool_use_id, "unknown"),
                block.tool_use_id,
                block.content,
                block.is_error or False,
            )

    def _complete_turn(
        self, result: ResultMessage, response_parts: list[str]
    ) -> None:
        """Finalize the current turn from its ResultMessage, store it, and notify."""
        turn = self._current_turn
        self._last_result = result
        turn.completed_at = datetime.now()
        turn.duration_ms = result.duration_ms
        turn.response_text = "".join(response_parts)
        turn.is_error = result.is_error

        if result.is_error:
            turn.error_message = result.result

        # Update metrics and store completed turn
        self._metrics.add_result(result)
        self._turns.append(turn)

        # Detach before the user callback runs: if the callback raises,
        # receive()'s error handler must not append this turn a second time.
        self._current_turn = None

        if self._on_turn_complete:
            self._on_turn_complete(turn)

        logger.debug(
            "Turn complete: %s SDK turns, %sms",
            result.num_turns,
            result.duration_ms,
        )

    async def interrupt(self) -> None:
        """
        Interrupt the current query.

        This sends a signal to stop Claude mid-execution.
        Use this for long-running tasks that need to be cancelled.
        Does nothing (besides logging a warning) when not connected.
        """
        if not self._connected or not self._client:
            logger.warning("Cannot interrupt: not connected")
            return

        self._interrupted = True
        await self._client.interrupt()
        logger.info("Sent interrupt signal")

    async def query_and_receive(self, prompt: str) -> ResultMessage:
        """
        Convenience method to send a query and collect all messages.

        The receive() stream is consumed to its end (it stops right after the
        ResultMessage), so the generator finishes cleanly rather than being
        left suspended.

        Args:
            prompt: The user prompt to send.

        Returns:
            The final ResultMessage.

        Raises:
            RuntimeError: If no ResultMessage is received.
        """
        await self.query(prompt)

        final: Optional[ResultMessage] = None
        async for message in self.receive():
            if isinstance(message, ResultMessage):
                final = message

        if final is None:
            raise RuntimeError("No ResultMessage received")
        return final

    def get_session_context(self, working_dir: str) -> SessionContext:
        """
        Create a SessionContext from current conversation state.

        Args:
            working_dir: The working directory for the session.

        Returns:
            SessionContext with accumulated metrics. ``session_id`` falls back
            to ``"unknown"`` before the SDK has reported one;
            ``cumulative_turns`` counts SDK agent turns, not conversation turns.
        """
        return SessionContext(
            session_id=self._session_id or "unknown",
            working_dir=working_dir,
            claude_session_id=self._session_id,
            cumulative_turns=self._metrics.agent_turns,
            cumulative_duration_ms=self._metrics.total_duration_ms,
            cumulative_cost_usd=self._metrics.total_cost_usd,
            cumulative_input_tokens=self._metrics.input_tokens,
            cumulative_output_tokens=self._metrics.output_tokens,
            cumulative_cache_creation_tokens=self._metrics.cache_creation_tokens,
            cumulative_cache_read_tokens=self._metrics.cache_read_tokens,
        )

    async def __aenter__(self) -> "ConversationSession":
        """Async context manager entry: connect without an initial prompt."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Async context manager exit: disconnect (exceptions are not suppressed)."""
        await self.disconnect()