# mlflow/claude_code/tracing.py
  1  """MLflow tracing integration for Claude Code interactions."""
  2  
  3  import dataclasses
  4  import json
  5  import logging
  6  import os
  7  import sys
  8  from datetime import datetime
  9  from pathlib import Path
 10  from typing import Any
 11  
 12  import dateutil.parser
 13  
 14  import mlflow
 15  from mlflow.claude_code.config import (
 16      MLFLOW_TRACING_ENABLED,
 17      get_env_var,
 18  )
 19  from mlflow.entities import SpanType
 20  from mlflow.environment_variables import (
 21      MLFLOW_EXPERIMENT_ID,
 22      MLFLOW_EXPERIMENT_NAME,
 23      MLFLOW_TRACKING_URI,
 24  )
 25  from mlflow.telemetry.events import AutologgingEvent
 26  from mlflow.telemetry.track import _record_event
 27  from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey
 28  from mlflow.tracing.provider import _get_trace_exporter
 29  from mlflow.tracing.trace_manager import InMemoryTraceManager
 30  
 31  # ============================================================================
 32  # CONSTANTS
 33  # ============================================================================
 34  
 35  # Used multiple times across the module
 36  NANOSECONDS_PER_MS = 1e6
 37  NANOSECONDS_PER_S = 1e9
 38  MAX_PREVIEW_LENGTH = 1000
 39  
 40  MESSAGE_TYPE_USER = "user"
 41  MESSAGE_TYPE_ASSISTANT = "assistant"
 42  CONTENT_TYPE_TEXT = "text"
 43  CONTENT_TYPE_TOOL_USE = "tool_use"
 44  CONTENT_TYPE_TOOL_RESULT = "tool_result"
 45  MESSAGE_FIELD_CONTENT = "content"
 46  MESSAGE_FIELD_TYPE = "type"
 47  MESSAGE_FIELD_MESSAGE = "message"
 48  MESSAGE_FIELD_TIMESTAMP = "timestamp"
 49  MESSAGE_FIELD_TOOL_USE_RESULT = "toolUseResult"
 50  MESSAGE_FIELD_COMMAND_NAME = "commandName"
 51  MESSAGE_TYPE_QUEUE_OPERATION = "queue-operation"
 52  QUEUE_OPERATION_ENQUEUE = "enqueue"
 53  METADATA_KEY_CLAUDE_CODE_VERSION = "mlflow.claude_code_version"
 54  
 55  # Custom logging level for Claude tracing
 56  CLAUDE_TRACING_LEVEL = logging.WARNING - 5
 57  
 58  
 59  # ============================================================================
 60  # LOGGING AND SETUP
 61  # ============================================================================
 62  
 63  
 64  def setup_logging() -> logging.Logger:
 65      """Set up logging directory and return configured logger.
 66  
 67      Creates .claude/mlflow directory structure and configures file-based logging
 68      with INFO level. Prevents log propagation to avoid duplicate messages.
 69      """
 70      # Create logging directory structure
 71      log_dir = Path(os.getcwd()) / ".claude" / "mlflow"
 72      log_dir.mkdir(parents=True, exist_ok=True)
 73  
 74      logger = logging.getLogger(__name__)
 75      logger.handlers.clear()  # Remove any existing handlers
 76  
 77      # Configure file handler with timestamp formatting
 78      log_file = log_dir / "claude_tracing.log"
 79      file_handler = logging.FileHandler(log_file)
 80      file_handler.setFormatter(
 81          logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
 82      )
 83      logger.addHandler(file_handler)
 84      logging.addLevelName(CLAUDE_TRACING_LEVEL, "CLAUDE_TRACING")
 85      logger.setLevel(CLAUDE_TRACING_LEVEL)
 86      logger.propagate = False  # Prevent duplicate log messages
 87  
 88      return logger
 89  
 90  
 91  _MODULE_LOGGER: logging.Logger | None = None
 92  
 93  
 94  def get_logger() -> logging.Logger:
 95      """Get the configured module logger."""
 96      global _MODULE_LOGGER
 97  
 98      if _MODULE_LOGGER is None:
 99          _MODULE_LOGGER = setup_logging()
100      return _MODULE_LOGGER
101  
102  
def setup_mlflow() -> None:
    """Configure the MLflow tracking URI and experiment for Claude Code tracing.

    No-op when tracing is disabled. An experiment ID takes precedence over an
    experiment name; failures to select the experiment are logged, not raised.
    """
    if not is_tracing_enabled():
        return

    # Point MLflow at the tracking server configured via environment/settings
    mlflow.set_tracking_uri(get_env_var(MLFLOW_TRACKING_URI.name))

    exp_id = get_env_var(MLFLOW_EXPERIMENT_ID.name)
    exp_name = get_env_var(MLFLOW_EXPERIMENT_NAME.name)
    try:
        if exp_id:
            mlflow.set_experiment(experiment_id=exp_id)
        elif exp_name:
            mlflow.set_experiment(exp_name)
    except Exception as e:
        get_logger().warning("Failed to set experiment: %s", e)

    _record_event(AutologgingEvent, {"flavor": "claude_code"})
124  
125  
def is_tracing_enabled() -> bool:
    """Check if MLflow Claude tracing is enabled via environment variable."""
    flag = get_env_var(MLFLOW_TRACING_ENABLED).lower()
    return flag in {"true", "1", "yes"}
129  
130  
131  # ============================================================================
132  # INPUT/OUTPUT UTILITIES
133  # ============================================================================
134  
135  
def read_hook_input() -> dict[str, Any]:
    """Read JSON input from stdin for Claude Code hook processing."""
    raw = sys.stdin.read()
    try:
        return json.loads(raw)
    except json.JSONDecodeError as e:
        # Re-raise with a hook-specific message while keeping the original cause
        raise json.JSONDecodeError(f"Failed to parse hook input: {e}", raw, 0) from e
143  
144  
def read_transcript(transcript_path: str) -> list[dict[str, Any]]:
    """Read and parse a Claude Code conversation transcript from JSONL file."""
    with open(transcript_path, encoding="utf-8") as f:
        # Skip blank lines; every remaining line is one JSON object
        return [json.loads(line) for line in f if line.strip()]
150  
151  
152  def get_hook_response(error: str | None = None, **kwargs) -> dict[str, Any]:
153      """Build hook response dictionary for Claude Code hook protocol.
154  
155      Args:
156          error: Error message if hook failed, None if successful
157          kwargs: Additional fields to include in response
158  
159      Returns:
160          Hook response dictionary
161      """
162      if error is not None:
163          return {"continue": False, "stopReason": error, **kwargs}
164      return {"continue": True, **kwargs}
165  
166  
167  # ============================================================================
168  # TIMESTAMP AND CONTENT PARSING UTILITIES
169  # ============================================================================
170  
171  
def parse_timestamp_to_ns(timestamp: str | int | float | None) -> int | None:
    """Convert various timestamp formats to nanoseconds since Unix epoch.

    Args:
        timestamp: ISO string, Unix timestamp (seconds or milliseconds), or
            a value already in nanoseconds.

    Returns:
        Nanoseconds since Unix epoch, or None if parsing fails.
    """
    if not timestamp:
        return None

    if isinstance(timestamp, str):
        try:
            dt = dateutil.parser.parse(timestamp)
            return int(dt.timestamp() * NANOSECONDS_PER_S)
        except Exception:
            get_logger().warning("Could not parse timestamp: %s", timestamp)
            return None

    if not isinstance(timestamp, (int, float)):
        return None

    # Heuristic: magnitude decides the unit — seconds, then milliseconds,
    # otherwise the value is assumed to already be nanoseconds.
    if timestamp < 1e10:
        return int(timestamp * NANOSECONDS_PER_S)
    if timestamp < 1e13:
        return int(timestamp * NANOSECONDS_PER_MS)
    return int(timestamp)
199  
200  
def extract_text_content(content: str | list[dict[str, Any]] | Any) -> str:
    """Extract text content from Claude message content (handles both string and list formats).

    Args:
        content: Either a string or list of content parts from Claude API

    Returns:
        Extracted text content, empty string if none found
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Join the text of every "text"-typed part; ignore everything else
        pieces = (
            part.get(CONTENT_TYPE_TEXT, "")
            for part in content
            if isinstance(part, dict) and part.get(MESSAGE_FIELD_TYPE) == CONTENT_TYPE_TEXT
        )
        return "\n".join(pieces)
    return str(content)
220  
221  
def find_last_user_message_index(transcript: list[dict[str, Any]]) -> int | None:
    """Find the index of the last actual user message (ignoring tool results and empty messages).

    Walks the transcript backward and skips entries that look like user
    messages but are really machine-generated: tool results, skill content
    injections, local command output, and empty content.

    Args:
        transcript: List of conversation entries from Claude Code transcript

    Returns:
        Index of last user message, or None if not found
    """
    # Iterate from the end so the first match is the LAST real user message
    for i in range(len(transcript) - 1, -1, -1):
        entry = transcript[i]
        if entry.get(MESSAGE_FIELD_TYPE) == MESSAGE_TYPE_USER and not entry.get(
            MESSAGE_FIELD_TOOL_USE_RESULT
        ):
            # Skip skill content injections: a user message immediately following
            # a Skill tool result (which has toolUseResult with commandName)
            if (
                i > 0
                and isinstance(
                    prev_tool_result := transcript[i - 1].get(MESSAGE_FIELD_TOOL_USE_RESULT), dict
                )
                and prev_tool_result.get(MESSAGE_FIELD_COMMAND_NAME)
            ):
                continue

            msg = entry.get(MESSAGE_FIELD_MESSAGE, {})
            content = msg.get(MESSAGE_FIELD_CONTENT, "")

            # Skip user entries whose first content part is a tool_result
            # (these carry tool output back to the model, not a real prompt)
            if isinstance(content, list) and len(content) > 0:
                if (
                    isinstance(content[0], dict)
                    and content[0].get(MESSAGE_FIELD_TYPE) == CONTENT_TYPE_TOOL_RESULT
                ):
                    continue

            # Skip echoed local command output injected into the conversation
            if isinstance(content, str) and "<local-command-stdout>" in content:
                continue

            # Skip empty or whitespace-only messages
            if not content or (isinstance(content, str) and content.strip() == ""):
                continue

            return i
    return None
265  
266  
267  # ============================================================================
268  # TRANSCRIPT PROCESSING HELPERS
269  # ============================================================================
270  
271  
def _get_next_timestamp_ns(transcript: list[dict[str, Any]], current_idx: int) -> int | None:
    """Get the timestamp of the next entry for duration calculation."""
    # First entry after current_idx that carries a truthy timestamp wins
    return next(
        (
            parse_timestamp_to_ns(ts)
            for entry in transcript[current_idx + 1 :]
            if (ts := entry.get(MESSAGE_FIELD_TIMESTAMP))
        ),
        None,
    )
278  
279  
def _extract_content_and_tools(content: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
    """Extract text content and tool uses from assistant response content."""
    # Non-list content (or non-dict parts) contributes nothing
    dict_parts = [p for p in content if isinstance(p, dict)] if isinstance(content, list) else []

    text_content = "".join(
        p.get(CONTENT_TYPE_TEXT, "")
        for p in dict_parts
        if p.get(MESSAGE_FIELD_TYPE) == CONTENT_TYPE_TEXT
    )
    tool_uses = [p for p in dict_parts if p.get(MESSAGE_FIELD_TYPE) == CONTENT_TYPE_TOOL_USE]
    return text_content, tool_uses
294  
295  
def _find_tool_results(transcript: list[dict[str, Any]], start_idx: int) -> dict[str, Any]:
    """Find tool results following the current assistant response.

    Scans entries after ``start_idx`` and stops at the next assistant
    response: the results for this response's tool uses appear in the user
    entries before the next assistant turn.

    Args:
        transcript: List of conversation entries from Claude Code transcript
        start_idx: Index of the assistant entry whose tool results to collect

    Returns:
        A mapping from tool_use_id to tool result content.
    """
    tool_results: dict[str, Any] = {}

    # Look for tool results in subsequent entries
    for i in range(start_idx + 1, len(transcript)):
        entry = transcript[i]

        # Stop looking once we hit the next assistant response.
        # BUG FIX: this check previously ran *after* a `continue` for
        # non-user entries, making it unreachable, so the scan walked the
        # entire remainder of the transcript instead of stopping here.
        if entry.get(MESSAGE_FIELD_TYPE) == MESSAGE_TYPE_ASSISTANT:
            break

        if entry.get(MESSAGE_FIELD_TYPE) != MESSAGE_TYPE_USER:
            continue

        msg = entry.get(MESSAGE_FIELD_MESSAGE, {})
        content = msg.get(MESSAGE_FIELD_CONTENT, [])

        if isinstance(content, list):
            for part in content:
                if (
                    isinstance(part, dict)
                    and part.get(MESSAGE_FIELD_TYPE) == CONTENT_TYPE_TOOL_RESULT
                ):
                    tool_use_id = part.get("tool_use_id")
                    result_content = part.get("content", "")
                    if tool_use_id:
                        tool_results[tool_use_id] = result_content

    return tool_results
328  
329  
def _get_input_messages(transcript: list[dict[str, Any]], current_idx: int) -> list[dict[str, Any]]:
    """Get all messages between the previous text-bearing assistant response and the current one.

    Claude Code emits separate transcript entries for text and tool_use content.
    A typical sequence looks like:
        assistant [text]        ← previous LLM boundary (stop here)
        assistant [tool_use]    ← include
        user [tool_result]      ← include
        assistant [tool_use]    ← include
        user [tool_result]      ← include
        assistant [text]        ← current (the span we're building inputs for)

    We walk backward and collect everything, only stopping when we hit an
    assistant entry that contains text content (which marks the previous LLM span).

    Args:
        transcript: List of conversation entries from Claude Code transcript
        current_idx: Index of the current assistant response

    Returns:
        List of messages in Anthropic format
    """
    messages = []
    for i in range(current_idx - 1, -1, -1):
        entry = transcript[i]
        msg = entry.get(MESSAGE_FIELD_MESSAGE, {})

        # Stop at a previous assistant entry that has text content (previous LLM span)
        if entry.get(MESSAGE_FIELD_TYPE) == MESSAGE_TYPE_ASSISTANT:
            content = msg.get(MESSAGE_FIELD_CONTENT, [])
            has_text = False
            if isinstance(content, str):
                has_text = bool(content.strip())
            elif isinstance(content, list):
                # Any "text"-typed part counts as a text-bearing response
                has_text = any(
                    isinstance(p, dict) and p.get(MESSAGE_FIELD_TYPE) == CONTENT_TYPE_TEXT
                    for p in content
                )
            if has_text:
                break

        # Include steer messages (queue-operation enqueue) as user messages
        if (
            entry.get(MESSAGE_FIELD_TYPE) == MESSAGE_TYPE_QUEUE_OPERATION
            and entry.get("operation") == QUEUE_OPERATION_ENQUEUE
            and (steer_content := entry.get(MESSAGE_FIELD_CONTENT))
        ):
            messages.append({"role": "user", "content": steer_content})
            continue

        # Regular entries: include anything carrying both a role and content
        if msg.get("role") and msg.get(MESSAGE_FIELD_CONTENT):
            messages.append(msg)
    # Collected newest-to-oldest while walking backward; restore chronological order
    messages.reverse()
    return messages
384  
385  
def _build_usage_dict(usage: dict[str, Any]) -> dict[str, int]:
    """Normalize a Claude Code usage payload into the CHAT_USAGE schema.

    Mirrors ``mlflow.anthropic.autolog``: ``input_tokens`` is the non-cached
    input count, cache tokens appear only as separate optional keys (so cache
    hit rate can be derived), and ``total_tokens`` is
    ``input_tokens + output_tokens`` with cache tokens excluded.
    """
    input_tokens = usage.get("input_tokens", 0)
    output_tokens = usage.get("output_tokens", 0)

    usage_dict: dict[str, int] = {
        TokenUsageKey.INPUT_TOKENS: input_tokens,
        TokenUsageKey.OUTPUT_TOKENS: output_tokens,
        TokenUsageKey.TOTAL_TOKENS: input_tokens + output_tokens,
    }

    cache_read = usage.get("cache_read_input_tokens")
    if cache_read is not None:
        usage_dict[TokenUsageKey.CACHE_READ_INPUT_TOKENS] = cache_read

    cache_created = usage.get("cache_creation_input_tokens")
    if cache_created is not None:
        usage_dict[TokenUsageKey.CACHE_CREATION_INPUT_TOKENS] = cache_created

    return usage_dict
409  
410  
def _set_token_usage_attribute(span, usage: dict[str, Any]) -> None:
    """Set token usage on a span using the standardized CHAT_USAGE attribute.

    Args:
        span: The MLflow span to set token usage on
        usage: Dictionary containing token usage info from Claude Code transcript
    """
    # Empty/missing usage payloads are silently ignored
    if usage:
        span.set_attribute(SpanAttributeKey.CHAT_USAGE, _build_usage_dict(usage))
422  
423  
def _create_llm_and_tool_spans(
    parent_span, transcript: list[dict[str, Any]], start_idx: int
) -> None:
    """Create LLM and tool spans for assistant responses with proper timing.

    Iterates the transcript from ``start_idx``, creating one LLM span per
    text-only assistant response and one tool span per tool_use part, all
    parented under ``parent_span``.

    Args:
        parent_span: Root conversation span the child spans attach to
        transcript: List of conversation entries from Claude Code transcript
        start_idx: Index to start scanning from (after the last user message)
    """
    for i in range(start_idx, len(transcript)):
        entry = transcript[i]
        if entry.get(MESSAGE_FIELD_TYPE) != MESSAGE_TYPE_ASSISTANT:
            continue

        # NOTE(review): parse_timestamp_to_ns may return None for a missing or
        # unparsable timestamp; the arithmetic below assumes every assistant
        # entry has a valid timestamp — confirm upstream guarantees this.
        timestamp_ns = parse_timestamp_to_ns(entry.get(MESSAGE_FIELD_TIMESTAMP))

        # Calculate duration based on next timestamp or use default
        if next_timestamp_ns := _get_next_timestamp_ns(transcript, i):
            duration_ns = next_timestamp_ns - timestamp_ns
        else:
            duration_ns = int(1000 * NANOSECONDS_PER_MS)  # 1 second default

        msg = entry.get(MESSAGE_FIELD_MESSAGE, {})
        content = msg.get(MESSAGE_FIELD_CONTENT, [])
        usage = msg.get("usage", {})

        # First check if we have meaningful content to create a span for
        text_content, tool_uses = _extract_content_and_tools(content)

        # Only create LLM span if there's text content (no tools)
        llm_span = None
        if text_content and text_content.strip() and not tool_uses:
            # Everything since the previous text-bearing assistant entry
            # becomes this LLM call's input messages
            messages = _get_input_messages(transcript, i)

            llm_span = mlflow.start_span_no_context(
                name="llm",
                parent_span=parent_span,
                span_type=SpanType.LLM,
                start_time_ns=timestamp_ns,
                inputs={
                    "model": msg.get("model", "unknown"),
                    "messages": messages,
                },
                attributes={
                    "model": msg.get("model", "unknown"),
                    SpanAttributeKey.MESSAGE_FORMAT: "anthropic",
                },
            )

            # Set token usage using the standardized CHAT_USAGE attribute
            _set_token_usage_attribute(llm_span, usage)

            # Output in Anthropic response format for Chat UI rendering
            llm_span.set_outputs({
                "type": "message",
                "role": "assistant",
                "content": content,
            })
            llm_span.end(end_time_ns=timestamp_ns + duration_ns)

        # Create tool spans with proportional timing and actual results.
        # The entry's duration is split evenly across its tool uses since
        # the transcript does not record per-tool timing.
        if tool_uses:
            tool_results = _find_tool_results(transcript, i)
            tool_duration_ns = duration_ns // len(tool_uses)

            for idx, tool_use in enumerate(tool_uses):
                tool_start_ns = timestamp_ns + (idx * tool_duration_ns)
                tool_use_id = tool_use.get("id", "")
                tool_result = tool_results.get(tool_use_id, "No result found")

                tool_span = mlflow.start_span_no_context(
                    name=f"tool_{tool_use.get('name', 'unknown')}",
                    parent_span=parent_span,
                    span_type=SpanType.TOOL,
                    start_time_ns=tool_start_ns,
                    inputs=tool_use.get("input", {}),
                    attributes={
                        "tool_name": tool_use.get("name", "unknown"),
                        "tool_id": tool_use_id,
                    },
                )

                tool_span.set_outputs({"result": tool_result})
                tool_span.end(end_time_ns=tool_start_ns + tool_duration_ns)
503  
504  
def _finalize_trace(
    parent_span,
    user_prompt: str,
    final_response: str | None,
    session_id: str | None,
    end_time_ns: int | None = None,
    usage: dict[str, Any] | None = None,
    claude_code_version: str | None = None,
) -> mlflow.entities.Trace:
    """Attach previews/metadata to the trace, end the root span, and return the trace.

    Args:
        parent_span: Root span of the conversation trace.
        user_prompt: User prompt text, truncated for the request preview.
        final_response: Final assistant text, truncated for the response preview.
        session_id: Session identifier recorded in trace metadata.
        end_time_ns: Root-span end time in nanoseconds, if known.
        usage: Raw token usage payload stored as trace-level metadata.
        claude_code_version: Claude Code CLI version recorded in metadata.

    Returns:
        The logged MLflow trace fetched by trace ID.
    """
    try:
        # Set trace previews and metadata for UI display
        with InMemoryTraceManager.get_instance().get_trace(parent_span.trace_id) as in_memory_trace:
            if user_prompt:
                in_memory_trace.info.request_preview = user_prompt[:MAX_PREVIEW_LENGTH]
            if final_response:
                in_memory_trace.info.response_preview = final_response[:MAX_PREVIEW_LENGTH]

            metadata = {
                TraceMetadataKey.TRACE_USER: os.environ.get("USER", ""),
                "mlflow.trace.working_directory": os.getcwd(),
            }
            if session_id:
                metadata[TraceMetadataKey.TRACE_SESSION] = session_id
            if claude_code_version:
                metadata[METADATA_KEY_CLAUDE_CODE_VERSION] = claude_code_version

            # Set token usage directly on trace metadata so it survives
            # even if span-level aggregation doesn't pick it up
            if usage:
                metadata[TraceMetadataKey.TOKEN_USAGE] = json.dumps(_build_usage_dict(usage))

            # Merge on top of any metadata already present on the trace
            in_memory_trace.info.trace_metadata = {
                **in_memory_trace.info.trace_metadata,
                **metadata,
            }
    except Exception as e:
        # Best-effort: a metadata failure should not prevent ending the span
        get_logger().warning("Failed to update trace metadata and previews: %s", e)

    outputs = {"status": "completed"}
    if final_response:
        outputs["response"] = final_response
    parent_span.set_outputs(outputs)
    parent_span.end(end_time_ns=end_time_ns)
    _flush_trace_async_logging()
    get_logger().log(CLAUDE_TRACING_LEVEL, "Created MLflow trace: %s", parent_span.trace_id)
    return mlflow.get_trace(parent_span.trace_id)
551  
552  
def _flush_trace_async_logging() -> None:
    """Flush queued trace exports when the active exporter logs asynchronously."""
    try:
        exporter = _get_trace_exporter()
        # Only async-capable exporters expose an _async_queue
        if hasattr(exporter, "_async_queue"):
            mlflow.flush_trace_async_logging()
    except Exception as e:
        get_logger().debug("Failed to flush trace async logging: %s", e)
559  
560  
def find_final_assistant_response(transcript: list[dict[str, Any]], start_idx: int) -> str | None:
    """Find the final text response from the assistant for trace preview.

    Args:
        transcript: List of conversation entries from Claude Code transcript
        start_idx: Index to start searching from (typically after last user message)

    Returns:
        Final assistant response text or None
    """
    final_response = None

    # Scan forward; the last non-blank text part seen wins
    for entry in transcript[start_idx:]:
        if entry.get(MESSAGE_FIELD_TYPE) != MESSAGE_TYPE_ASSISTANT:
            continue

        content = entry.get(MESSAGE_FIELD_MESSAGE, {}).get(MESSAGE_FIELD_CONTENT, [])
        if not isinstance(content, list):
            continue

        for part in content:
            if isinstance(part, dict) and part.get(MESSAGE_FIELD_TYPE) == CONTENT_TYPE_TEXT:
                if (text := part.get(CONTENT_TYPE_TEXT, "")).strip():
                    final_response = text

    return final_response
589  
590  
591  # ============================================================================
592  # MAIN TRANSCRIPT PROCESSING
593  # ============================================================================
594  
595  
def process_transcript(
    transcript_path: str, session_id: str | None = None
) -> mlflow.entities.Trace | None:
    """Process a Claude conversation transcript and create an MLflow trace with spans.

    Builds one AGENT root span for the conversation starting at the last real
    user message, child LLM/tool spans for subsequent assistant activity, then
    finalizes the trace with previews and metadata. All failures are logged
    and swallowed (hooks must not crash the CLI).

    Args:
        transcript_path: Path to the Claude Code transcript.jsonl file
        session_id: Optional session identifier, defaults to timestamp-based ID

    Returns:
        MLflow trace object if successful, None if processing fails
    """
    try:
        transcript = read_transcript(transcript_path)
        if not transcript:
            get_logger().warning("Empty transcript, skipping")
            return None

        last_user_idx = find_last_user_message_index(transcript)
        if last_user_idx is None:
            get_logger().warning("No user message found in transcript")
            return None

        last_user_entry = transcript[last_user_idx]
        last_user_prompt = last_user_entry.get(MESSAGE_FIELD_MESSAGE, {}).get(
            MESSAGE_FIELD_CONTENT, ""
        )

        # Fall back to a timestamp-based session ID when none was provided
        if not session_id:
            session_id = f"claude-{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        get_logger().log(CLAUDE_TRACING_LEVEL, "Creating MLflow trace for session: %s", session_id)

        # NOTE(review): conv_start_ns may be None if the user entry has no
        # parsable timestamp; the end-time fallback below assumes it is set —
        # confirm transcripts always carry timestamps.
        conv_start_ns = parse_timestamp_to_ns(last_user_entry.get(MESSAGE_FIELD_TIMESTAMP))

        parent_span = mlflow.start_span_no_context(
            name="claude_code_conversation",
            inputs={"prompt": extract_text_content(last_user_prompt)},
            start_time_ns=conv_start_ns,
            span_type=SpanType.AGENT,
        )

        # Create spans for all assistant responses and tool uses
        _create_llm_and_tool_spans(parent_span, transcript, last_user_idx + 1)

        # Update trace with preview content and end timing
        final_response = find_final_assistant_response(transcript, last_user_idx + 1)
        user_prompt_text = extract_text_content(last_user_prompt)

        # Calculate end time based on last entry or use default duration
        last_entry = transcript[-1] if transcript else last_user_entry
        conv_end_ns = parse_timestamp_to_ns(last_entry.get(MESSAGE_FIELD_TIMESTAMP))
        if not conv_end_ns or conv_end_ns <= conv_start_ns:
            conv_end_ns = conv_start_ns + int(10 * NANOSECONDS_PER_S)

        # Extract Claude Code version from transcript entries (CLI-only)
        claude_code_version = next(
            (ver for entry in transcript if (ver := entry.get("version"))), None
        )

        return _finalize_trace(
            parent_span,
            user_prompt_text,
            final_response,
            session_id,
            conv_end_ns,
            claude_code_version=claude_code_version,
        )

    except Exception as e:
        get_logger().error("Error processing transcript: %s", e, exc_info=True)
        return None
668  
669  
670  # ============================================================================
671  # SDK MESSAGE PROCESSING
672  # ============================================================================
673  
674  
def _find_sdk_user_prompt(messages: list[Any]) -> str | None:
    """Return the first non-empty plain user prompt (no tool result) from SDK messages."""
    from claude_agent_sdk.types import TextBlock, UserMessage

    for msg in messages:
        # Only plain user messages count; tool-result carriers are skipped
        if not isinstance(msg, UserMessage) or msg.tool_use_result is not None:
            continue

        content = msg.content
        if isinstance(content, list):
            text = "\n".join(block.text for block in content if isinstance(block, TextBlock))
        elif isinstance(content, str):
            text = content
        else:
            continue

        if text.strip():
            return text
    return None
691  
692  
def _build_tool_result_map(messages: list[Any]) -> dict[str, str]:
    """Map tool_use_id to its result content so tool spans can show outputs."""
    from claude_agent_sdk.types import ToolResultBlock, UserMessage

    results: dict[str, str] = {}
    for msg in messages:
        if not isinstance(msg, UserMessage) or not isinstance(msg.content, list):
            continue
        for block in msg.content:
            if not isinstance(block, ToolResultBlock):
                continue
            value = block.content
            if isinstance(value, list):
                # List-shaped results are flattened to their repr
                value = str(value)
            results[block.tool_use_id] = value or ""
    return results
707  
708  
# Maps SDK dataclass names to Anthropic API "type" discriminators.
# dataclasses.asdict() gives us the fields but not the type tag that
# the Anthropic message format requires on every content block.
# Block classes not listed here are dropped by _serialize_content_block.
_CONTENT_BLOCK_TYPES = {
    "TextBlock": "text",
    "ToolUseBlock": "tool_use",
    "ToolResultBlock": "tool_result",
}
717  
718  
def _serialize_content_block(block) -> dict[str, Any] | None:
    """Serialize an SDK content block to an Anthropic-format dict, or None if unrecognized."""
    block_type = _CONTENT_BLOCK_TYPES.get(type(block).__name__)
    if block_type is None:
        return None

    # Drop None-valued fields, then stamp on the required "type" discriminator
    serialized = {k: v for k, v in dataclasses.asdict(block).items() if v is not None}
    serialized["type"] = block_type
    return serialized
726  
727  
def _serialize_sdk_message(msg) -> dict[str, Any] | None:
    """Convert an SDK message into an Anthropic-format dict, or None if empty/unknown."""
    from claude_agent_sdk.types import AssistantMessage, UserMessage

    if isinstance(msg, UserMessage):
        content = msg.content
        if isinstance(content, str):
            # Blank string content is dropped entirely
            return {"role": "user", "content": content} if content.strip() else None
        if isinstance(content, list):
            parts = [s for block in content if (s := _serialize_content_block(block))]
            if parts:
                return {"role": "user", "content": parts}
        return None

    if isinstance(msg, AssistantMessage) and msg.content:
        parts = [s for block in msg.content if (s := _serialize_content_block(block))]
        if parts:
            return {"role": "assistant", "content": parts}
    return None
746  
747  
def _create_sdk_child_spans(
    messages: list[Any],
    parent_span,
    tool_result_map: dict[str, str],
) -> str | None:
    """Create LLM and tool child spans under ``parent_span`` from SDK messages.

    Messages are accumulated (in Anthropic format) until a text-only assistant
    message is seen; that message becomes an LLM span whose inputs are the
    accumulated conversation so far. Tool-use blocks each become a tool span,
    with outputs looked up by tool-use id in ``tool_result_map``.

    Returns the last non-empty text-only assistant response, or ``None``.
    """
    from claude_agent_sdk.types import AssistantMessage, TextBlock, ToolUseBlock

    last_text_response: str | None = None
    conversation_so_far: list[dict[str, Any]] = []

    for message in messages:
        if isinstance(message, AssistantMessage) and message.content:
            texts = [block for block in message.content if isinstance(block, TextBlock)]
            tools = [block for block in message.content if isinstance(block, ToolUseBlock)]

            if texts and not tools:
                model_name = getattr(message, "model", "unknown")
                combined = "\n".join(block.text for block in texts)
                if combined.strip():
                    last_text_response = combined

                llm_span = mlflow.start_span_no_context(
                    name="llm",
                    parent_span=parent_span,
                    span_type=SpanType.LLM,
                    inputs={"model": model_name, "messages": conversation_so_far},
                    attributes={
                        "model": model_name,
                        SpanAttributeKey.MESSAGE_FORMAT: "anthropic",
                    },
                )
                llm_span.set_outputs(
                    {
                        "type": "message",
                        "role": "assistant",
                        "content": [{"type": "text", "text": block.text} for block in texts],
                    }
                )
                llm_span.end()
                # Rebind (don't mutate) so the LLM span keeps its snapshot.
                conversation_so_far = []
                # A text-only assistant turn is not re-accumulated as input.
                continue

            for tool in tools:
                tool_span = mlflow.start_span_no_context(
                    name=f"tool_{tool.name}",
                    parent_span=parent_span,
                    span_type=SpanType.TOOL,
                    inputs=tool.input,
                    attributes={"tool_name": tool.name, "tool_id": tool.id},
                )
                tool_span.set_outputs({"result": tool_result_map.get(tool.id, "")})
                tool_span.end()

        if serialized := _serialize_sdk_message(message):
            conversation_so_far.append(serialized)

    return last_text_response
806  
807  
def process_sdk_messages(
    messages: list[Any], session_id: str | None = None
) -> mlflow.entities.Trace | None:
    """
    Build an MLflow trace from Claude Agent SDK message objects.

    Args:
        messages: List of SDK message objects (UserMessage, AssistantMessage,
            ResultMessage, etc.) captured during a conversation.
        session_id: Optional session identifier for grouping traces.

    Returns:
        MLflow Trace if successful, None if no user prompt is found or processing fails.
    """
    try:
        # Import inside the try block so a missing or broken claude_agent_sdk
        # is handled like any other processing failure (logged, return None)
        # instead of raising out of a function documented as fail-safe.
        from claude_agent_sdk.types import ResultMessage

        if not messages:
            get_logger().warning("Empty messages list, skipping")
            return None

        user_prompt = _find_sdk_user_prompt(messages)
        if user_prompt is None:
            get_logger().warning("No user prompt found in SDK messages")
            return None

        result_msg = next((msg for msg in messages if isinstance(msg, ResultMessage)), None)

        # Prefer the SDK's own session_id, fall back to caller arg
        session_id = (result_msg.session_id if result_msg else None) or session_id

        get_logger().log(
            CLAUDE_TRACING_LEVEL,
            "Creating MLflow trace for session: %s",
            session_id,
        )

        tool_result_map = _build_tool_result_map(messages)

        # Reconstruct wall-clock start/end from the SDK-reported duration.
        # A falsy duration (None or 0) leaves both timestamps to MLflow defaults.
        if duration_ms := (getattr(result_msg, "duration_ms", None) if result_msg else None):
            duration_ns = int(duration_ms * NANOSECONDS_PER_MS)
            now_ns = int(datetime.now().timestamp() * NANOSECONDS_PER_S)
            start_time_ns = now_ns - duration_ns
            end_time_ns = now_ns
        else:
            start_time_ns = None
            end_time_ns = None

        parent_span = mlflow.start_span_no_context(
            name="claude_code_conversation",
            inputs={"prompt": user_prompt},
            span_type=SpanType.AGENT,
            start_time_ns=start_time_ns,
        )

        final_response = _create_sdk_child_spans(messages, parent_span, tool_result_map)

        # Set token usage on the root span so it aggregates into trace-level usage
        usage = getattr(result_msg, "usage", None) if result_msg else None
        if usage:
            _set_token_usage_attribute(parent_span, usage)

        return _finalize_trace(
            parent_span,
            user_prompt,
            final_response,
            session_id,
            end_time_ns=end_time_ns,
            usage=usage,
        )

    except Exception as e:
        get_logger().error("Error processing SDK messages: %s", e, exc_info=True)
        return None