/ src / core / conversation.py
conversation.py
  1  """
  2  Conversation session management for Ag3ntum.
  3  
  4  Provides a ConversationSession class that wraps ClaudeSDKClient for
  5  multi-turn interactive conversations with session continuity.
  6  
Usage:
    async with ConversationSession(options) as session:
        # First query: send the prompt and wait for the final ResultMessage
        result = await session.query_and_receive("Analyze this codebase")

        # Follow-up with context preserved
        result = await session.query_and_receive("What patterns did you find?")

        # Continue conversation...
        result = await session.query_and_receive("Refactor the auth module")

    query() itself returns None; to stream intermediate messages, call
    query() and then iterate receive().
 17  """
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, AsyncIterator, Callable, Optional

from claude_agent_sdk import (
    AssistantMessage,
    ClaudeAgentOptions,
    ClaudeSDKClient,
    ResultMessage,
    SystemMessage,
    TextBlock,
    ToolResultBlock,
    ToolUseBlock,
    UserMessage,
)

from .schemas import SessionContext, TaskStatus, TokenUsage
 35  
 36  logger = logging.getLogger(__name__)
 37  
 38  
 39  @dataclass
 40  class ConversationTurn:
 41      """
 42      Represents a single turn in a conversation.
 43  
 44      A turn consists of a user prompt and the assistant's response,
 45      along with metadata about tool usage and timing.
 46      """
 47      turn_number: int
 48      prompt: str
 49      response_text: str = ""
 50      tools_used: list[str] = field(default_factory=list)
 51      started_at: datetime = field(default_factory=datetime.now)
 52      completed_at: Optional[datetime] = None
 53      duration_ms: int = 0
 54      is_error: bool = False
 55      error_message: Optional[str] = None
 56  
 57  
 58  @dataclass
 59  class ConversationMetrics:
 60      """
 61      Accumulated metrics for a conversation session.
 62  
 63      Terminology:
 64      - conversation_turns: High-level user↔agent exchanges (prompts/responses)
 65      - agent_turns: Low-level SDK turns (tool calls, reasoning loops within a response)
 66      """
 67      conversation_turns: int = 0  # User prompt → agent response count
 68      agent_turns: int = 0  # SDK num_turns (tool calls, internal loops)
 69      total_duration_ms: int = 0
 70      total_cost_usd: float = 0.0
 71      total_tokens: int = 0
 72      input_tokens: int = 0
 73      output_tokens: int = 0
 74      cache_read_tokens: int = 0
 75      cache_creation_tokens: int = 0
 76  
 77      def add_result(self, result: ResultMessage) -> None:
 78          """
 79          Add metrics from a ResultMessage.
 80  
 81          Args:
 82              result: The ResultMessage from SDK.
 83          """
 84          self.conversation_turns += 1  # Each ResultMessage = one conversation turn
 85          self.agent_turns += result.num_turns  # SDK's internal turn count
 86          self.total_duration_ms += result.duration_ms
 87          if result.total_cost_usd:
 88              self.total_cost_usd += result.total_cost_usd
 89  
 90          if result.usage:
 91              self.input_tokens += result.usage.get("input_tokens", 0)
 92              self.output_tokens += result.usage.get("output_tokens", 0)
 93              self.cache_read_tokens += result.usage.get(
 94                  "cache_read_input_tokens", 0
 95              )
 96              self.cache_creation_tokens += result.usage.get(
 97                  "cache_creation_input_tokens", 0
 98              )
 99              self.total_tokens = (
100                  self.input_tokens
101                  + self.output_tokens
102                  + self.cache_read_tokens
103                  + self.cache_creation_tokens
104              )
105  
106      def to_token_usage(self) -> TokenUsage:
107          """Convert to TokenUsage schema."""
108          return TokenUsage(
109              input_tokens=self.input_tokens,
110              output_tokens=self.output_tokens,
111              cache_read_input_tokens=self.cache_read_tokens,
112              cache_creation_input_tokens=self.cache_creation_tokens,
113          )
114  
115  
class ConversationSession:
    """
    Manages a multi-turn conversation session with Claude.

    Wraps ClaudeSDKClient to provide:
    - Session continuity across multiple exchanges
    - Turn tracking and metrics accumulation
    - Interrupt support for long-running tasks
    - Message processing with callbacks

    Usage:
        options = ClaudeAgentOptions(
            allowed_tools=["Read", "Write", "Bash"],
            permission_mode="acceptEdits"
        )

        async with ConversationSession(options) as session:
            # First query
            await session.query("List all Python files")
            async for message in session.receive():
                print(message)

            # Follow-up with context
            await session.query("Now count the lines in each")
            async for message in session.receive():
                print(message)

    Args:
        options: ClaudeAgentOptions for configuring the client.
        on_message: Optional callback for each message received.
        on_tool_start: Optional callback when a tool starts.
        on_tool_complete: Optional callback when a tool completes.
        on_turn_complete: Optional callback when a turn completes.
    """

    def __init__(
        self,
        options: Optional[ClaudeAgentOptions] = None,
        on_message: Optional[Callable[[Any], None]] = None,
        on_tool_start: Optional[Callable[[str, dict, str], None]] = None,
        on_tool_complete: Optional[Callable[[str, str, Any, bool], None]] = None,
        on_turn_complete: Optional[Callable[[ConversationTurn], None]] = None,
    ) -> None:
        """
        Initialize the conversation session.

        Args:
            options: ClaudeAgentOptions for the SDK client.
            on_message: Callback for each message (message) -> None.
            on_tool_start: Callback for tool start (name, input, id) -> None.
            on_tool_complete: Callback for tool end (name, id, result, is_error) -> None.
            on_turn_complete: Callback when a turn completes (turn) -> None.
        """
        self._options = options or ClaudeAgentOptions()
        self._client: Optional[ClaudeSDKClient] = None
        self._connected = False
        self._session_id: Optional[str] = None

        # Callbacks
        self._on_message = on_message
        self._on_tool_start = on_tool_start
        self._on_tool_complete = on_tool_complete
        self._on_turn_complete = on_turn_complete

        # Conversation state
        self._turns: list[ConversationTurn] = []
        self._current_turn: Optional[ConversationTurn] = None
        self._metrics = ConversationMetrics()
        self._last_result: Optional[ResultMessage] = None

        # Interrupt handling
        self._interrupted = False

    @property
    def session_id(self) -> Optional[str]:
        """Get the current session ID."""
        return self._session_id

    @property
    def is_connected(self) -> bool:
        """Check if the session is connected."""
        return self._connected

    @property
    def turn_count(self) -> int:
        """Get the number of recorded turns (including failed ones)."""
        return len(self._turns)

    @property
    def metrics(self) -> ConversationMetrics:
        """Get accumulated metrics."""
        return self._metrics

    @property
    def turns(self) -> list[ConversationTurn]:
        """Get a shallow copy of all recorded turns."""
        return self._turns.copy()

    @property
    def last_result(self) -> Optional[ResultMessage]:
        """Get the last ResultMessage from the SDK."""
        return self._last_result

    @property
    def was_interrupted(self) -> bool:
        """Check if the session was interrupted."""
        return self._interrupted

    async def connect(self, initial_prompt: Optional[str] = None) -> None:
        """
        Connect to Claude and optionally send an initial prompt.

        Calling this while already connected logs a warning and does nothing.

        Args:
            initial_prompt: Optional first prompt to send immediately.

        Raises:
            RuntimeError: If connection fails. A best-effort disconnect of the
                partially connected client is attempted before raising.
        """
        if self._connected:
            logger.warning("Already connected, ignoring connect() call")
            return

        client = ClaudeSDKClient(options=self._options)
        self._client = client
        try:
            await client.connect(initial_prompt)
        except Exception as e:
            self._client = None
            self._connected = False
            logger.error(f"Failed to connect: {e}")
            # The client may have started part of its machinery before the
            # failure; tear it down instead of just dropping the reference.
            try:
                await client.disconnect()
            except Exception as cleanup_error:
                logger.debug(
                    f"Ignoring error while cleaning up failed connect: "
                    f"{cleanup_error}"
                )
            raise RuntimeError(f"Failed to connect to Claude: {e}") from e

        self._connected = True
        logger.info("ConversationSession connected")

    async def disconnect(self) -> None:
        """Disconnect from Claude and clean up (no-op if not connected)."""
        if not self._connected or not self._client:
            return

        try:
            await self._client.disconnect()
        except Exception as e:
            logger.warning(f"Error during disconnect: {e}")
        finally:
            self._connected = False
            self._client = None
            logger.info(
                f"ConversationSession disconnected after {self.turn_count} turns"
            )

    async def query(self, prompt: str) -> None:
        """
        Send a query to Claude.

        This starts a new turn in the conversation and resets the interrupt
        flag. Use receive() to get the response messages.

        Args:
            prompt: The user prompt to send.

        Raises:
            RuntimeError: If not connected.
        """
        if not self._connected or not self._client:
            raise RuntimeError(
                "Not connected. Call connect() or use 'async with' context."
            )

        # Start new turn
        self._current_turn = ConversationTurn(
            turn_number=len(self._turns) + 1,
            prompt=prompt,
            started_at=datetime.now(),
        )
        self._interrupted = False

        logger.debug(f"Starting turn {self._current_turn.turn_number}: {prompt[:50]}...")

        await self._client.query(prompt)

    async def receive(self) -> AsyncIterator[Any]:
        """
        Receive messages from the current query.

        Yields messages until a ResultMessage is received, which indicates
        the query is complete; the ResultMessage is yielded last.

        Tool results are reported through ``on_tool_complete`` whether they
        arrive inside an AssistantMessage or a UserMessage. The tool name is
        resolved from the ToolUseBlock with the same id seen earlier in this
        turn, falling back to "unknown".

        Yields:
            Messages from Claude (AssistantMessage, UserMessage,
            SystemMessage, ResultMessage, ...).

        Raises:
            RuntimeError: If not connected or no query is active.

        Note:
            If an exception occurs during message processing, the current
            turn is marked as an error, stored for debugging, and the
            exception is re-raised.
        """
        if not self._connected or not self._client:
            raise RuntimeError("Not connected")

        if not self._current_turn:
            raise RuntimeError("No active query. Call query() first.")

        turn = self._current_turn
        response_parts: list[str] = []
        tool_names: dict[str, str] = {}  # tool_use_id -> tool name, this turn

        try:
            async for message in self._client.receive_response():
                if isinstance(message, AssistantMessage):
                    for block in message.content:
                        if isinstance(block, TextBlock):
                            response_parts.append(block.text)
                        elif isinstance(block, ToolUseBlock):
                            tool_names[block.id] = block.name
                            turn.tools_used.append(block.name)
                            if self._on_tool_start:
                                self._on_tool_start(block.name, block.input, block.id)
                        elif isinstance(block, ToolResultBlock):
                            self._notify_tool_complete(block, tool_names)

                elif isinstance(message, UserMessage):
                    # Tool results are delivered as user-role messages; their
                    # content may also be a plain string, which has no blocks.
                    if isinstance(message.content, list):
                        for block in message.content:
                            if isinstance(block, ToolResultBlock):
                                self._notify_tool_complete(block, tool_names)

                elif isinstance(message, SystemMessage):
                    # Extract session ID from init message
                    if message.subtype == "init":
                        self._session_id = message.data.get("session_id")
                        logger.debug(f"Session ID: {self._session_id}")

                elif isinstance(message, ResultMessage):
                    # Turn complete - record it, yield the result, and stop
                    self._complete_turn(turn, message, response_parts)
                    if self._on_message:
                        self._on_message(message)
                    yield message
                    return

                # For non-terminal messages: invoke callback and yield
                if self._on_message:
                    self._on_message(message)
                yield message

        except Exception as e:
            logger.error(f"Error during receive: {e}")
            # Only record the turn if it has not already been stored by
            # _complete_turn (e.g. when on_turn_complete itself raised).
            if self._current_turn is turn:
                self._record_failed_turn(turn, e, response_parts)
            raise

    def _notify_tool_complete(
        self, block: ToolResultBlock, tool_names: dict[str, str]
    ) -> None:
        """Invoke on_tool_complete for a tool result, naming the tool by its id."""
        if self._on_tool_complete:
            self._on_tool_complete(
                tool_names.get(block.tool_use_id, "unknown"),
                block.tool_use_id,
                block.content,
                block.is_error or False,
            )

    def _complete_turn(
        self,
        turn: ConversationTurn,
        result: ResultMessage,
        response_parts: list[str],
    ) -> None:
        """Finalize ``turn`` from its ResultMessage, store it, and fire callbacks."""
        self._last_result = result
        turn.completed_at = datetime.now()
        turn.duration_ms = result.duration_ms
        turn.response_text = "".join(response_parts)
        turn.is_error = result.is_error

        if result.is_error:
            turn.error_message = result.result

        # Update metrics; if this raises, the turn is still current and the
        # caller records it as a failed turn.
        self._metrics.add_result(result)

        # Detach before running user callbacks so an exception raised by a
        # callback cannot cause the same turn to be recorded a second time.
        self._current_turn = None
        self._turns.append(turn)

        if self._on_turn_complete:
            self._on_turn_complete(turn)

        logger.debug(
            f"Turn complete: {result.num_turns} SDK turns, "
            f"{result.duration_ms}ms"
        )

    def _record_failed_turn(
        self,
        turn: ConversationTurn,
        error: Exception,
        response_parts: list[str],
    ) -> None:
        """Mark ``turn`` as failed with ``error`` and store it for debugging."""
        self._current_turn = None
        turn.completed_at = datetime.now()
        # No ResultMessage arrived, so fall back to wall-clock duration.
        turn.duration_ms = int(
            (turn.completed_at - turn.started_at).total_seconds() * 1000
        )
        turn.is_error = True
        turn.error_message = str(error)
        turn.response_text = "".join(response_parts)
        self._turns.append(turn)

    async def interrupt(self) -> None:
        """
        Interrupt the current query.

        This sends a signal to stop Claude mid-execution.
        Use this for long-running tasks that need to be cancelled.
        Logs a warning and does nothing when not connected.
        """
        if not self._connected or not self._client:
            logger.warning("Cannot interrupt: not connected")
            return

        self._interrupted = True
        await self._client.interrupt()
        logger.info("Sent interrupt signal")

    async def query_and_receive(self, prompt: str) -> ResultMessage:
        """
        Convenience method to send a query and collect all messages.

        The response stream is consumed to completion (receive() ends right
        after the ResultMessage) so the underlying generators finish inside
        this call rather than being left suspended for later finalization.

        Args:
            prompt: The user prompt to send.

        Returns:
            The final ResultMessage.

        Raises:
            RuntimeError: If no ResultMessage is received.
        """
        await self.query(prompt)

        result: Optional[ResultMessage] = None
        async for message in self.receive():
            if isinstance(message, ResultMessage):
                result = message

        if result is None:
            raise RuntimeError("No ResultMessage received")
        return result

    def get_session_context(self, working_dir: str) -> SessionContext:
        """
        Create a SessionContext from current conversation state.

        Args:
            working_dir: The working directory for the session.

        Returns:
            SessionContext with accumulated metrics.
        """
        return SessionContext(
            session_id=self._session_id or "unknown",
            working_dir=working_dir,
            claude_session_id=self._session_id,
            cumulative_turns=self._metrics.agent_turns,
            cumulative_duration_ms=self._metrics.total_duration_ms,
            cumulative_cost_usd=self._metrics.total_cost_usd,
            cumulative_input_tokens=self._metrics.input_tokens,
            cumulative_output_tokens=self._metrics.output_tokens,
            cumulative_cache_creation_tokens=self._metrics.cache_creation_tokens,
            cumulative_cache_read_tokens=self._metrics.cache_read_tokens,
        )

    async def __aenter__(self) -> "ConversationSession":
        """Async context manager entry: connect and return the session."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Async context manager exit: disconnect (exceptions are not suppressed)."""
        await self.disconnect()