tracing.py
"""MLflow tracing integration for Claude Code interactions."""

import dataclasses
import json
import logging
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Any

import dateutil.parser

import mlflow
from mlflow.claude_code.config import (
    MLFLOW_TRACING_ENABLED,
    get_env_var,
)
from mlflow.entities import SpanType
from mlflow.environment_variables import (
    MLFLOW_EXPERIMENT_ID,
    MLFLOW_EXPERIMENT_NAME,
    MLFLOW_TRACKING_URI,
)
from mlflow.telemetry.events import AutologgingEvent
from mlflow.telemetry.track import _record_event
from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey
from mlflow.tracing.provider import _get_trace_exporter
from mlflow.tracing.trace_manager import InMemoryTraceManager

# ============================================================================
# CONSTANTS
# ============================================================================

# Time-unit conversion factors, used multiple times across the module
NANOSECONDS_PER_MS = 1e6
NANOSECONDS_PER_S = 1e9
# Max characters stored in trace request/response previews for UI display
MAX_PREVIEW_LENGTH = 1000

# Field names and discriminator values of the Claude Code transcript schema
MESSAGE_TYPE_USER = "user"
MESSAGE_TYPE_ASSISTANT = "assistant"
CONTENT_TYPE_TEXT = "text"
CONTENT_TYPE_TOOL_USE = "tool_use"
CONTENT_TYPE_TOOL_RESULT = "tool_result"
MESSAGE_FIELD_CONTENT = "content"
MESSAGE_FIELD_TYPE = "type"
MESSAGE_FIELD_MESSAGE = "message"
MESSAGE_FIELD_TIMESTAMP = "timestamp"
MESSAGE_FIELD_TOOL_USE_RESULT = "toolUseResult"
MESSAGE_FIELD_COMMAND_NAME = "commandName"
MESSAGE_TYPE_QUEUE_OPERATION = "queue-operation"
QUEUE_OPERATION_ENQUEUE = "enqueue"
METADATA_KEY_CLAUDE_CODE_VERSION = "mlflow.claude_code_version"

# Custom logging level for Claude tracing (sits between INFO and WARNING)
CLAUDE_TRACING_LEVEL = logging.WARNING - 5


# ============================================================================
# LOGGING AND SETUP
# ============================================================================


def setup_logging() -> logging.Logger:
    """Set up logging directory and return configured logger.

    Creates the .claude/mlflow directory structure and configures file-based
    logging at the custom CLAUDE_TRACING level. Prevents log propagation to
    avoid duplicate messages.
    """
    # Create logging directory structure
    log_dir = Path(os.getcwd()) / ".claude" / "mlflow"
    log_dir.mkdir(parents=True, exist_ok=True)

    logger = logging.getLogger(__name__)
    logger.handlers.clear()  # Remove any existing handlers

    # Configure file handler with timestamp formatting
    log_file = log_dir / "claude_tracing.log"
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    logger.addHandler(file_handler)
    logging.addLevelName(CLAUDE_TRACING_LEVEL, "CLAUDE_TRACING")
    logger.setLevel(CLAUDE_TRACING_LEVEL)
    logger.propagate = False  # Prevent duplicate log messages

    return logger


# Lazily-initialized module logger; use get_logger() instead of reading this.
_MODULE_LOGGER: logging.Logger | None = None


def get_logger() -> logging.Logger:
    """Get the configured module logger (created on first use)."""
    global _MODULE_LOGGER

    if _MODULE_LOGGER is None:
        _MODULE_LOGGER = setup_logging()
    return _MODULE_LOGGER


def setup_mlflow() -> None:
    """Configure MLflow tracking URI and experiment.

    No-op when tracing is disabled. Experiment-setting failures are logged
    and swallowed so hook execution is never blocked. Also records an
    autologging telemetry event for the claude_code flavor.
    """
    if not is_tracing_enabled():
        return

    # Get tracking URI from environment/settings
    mlflow.set_tracking_uri(get_env_var(MLFLOW_TRACKING_URI.name))

    # Set experiment if specified via environment variables
    experiment_id = get_env_var(MLFLOW_EXPERIMENT_ID.name)
    experiment_name = get_env_var(MLFLOW_EXPERIMENT_NAME.name)

    try:
        # Experiment id takes precedence over name when both are set
        if experiment_id:
            mlflow.set_experiment(experiment_id=experiment_id)
        elif experiment_name:
            mlflow.set_experiment(experiment_name)
    except Exception as e:
        get_logger().warning("Failed to set experiment: %s", e)

    _record_event(AutologgingEvent, {"flavor": "claude_code"})


def is_tracing_enabled() -> bool:
    """Check if MLflow Claude tracing is enabled via environment variable."""
    # NOTE(review): assumes get_env_var returns a string (never None) for
    # MLFLOW_TRACING_ENABLED — confirm the config default.
    return get_env_var(MLFLOW_TRACING_ENABLED).lower() in ("true", "1", "yes")


# ============================================================================
# INPUT/OUTPUT UTILITIES
# ============================================================================


def read_hook_input() -> dict[str, Any]:
    """Read JSON input from stdin for Claude Code hook processing.

    Raises:
        json.JSONDecodeError: If stdin does not contain valid JSON.
    """
    try:
        input_data = sys.stdin.read()
        return json.loads(input_data)
    except json.JSONDecodeError as e:
        # Re-raise with a clearer message while preserving the original cause
        raise json.JSONDecodeError(f"Failed to parse hook input: {e}", input_data, 0) from e


def read_transcript(transcript_path: str) -> list[dict[str, Any]]:
    """Read and parse a Claude Code conversation transcript from JSONL file.

    Blank lines are skipped; each remaining line must be a JSON object.
    """
    with open(transcript_path, encoding="utf-8") as f:
        lines = f.readlines()
    return [json.loads(line) for line in lines if line.strip()]
154 155 Args: 156 error: Error message if hook failed, None if successful 157 kwargs: Additional fields to include in response 158 159 Returns: 160 Hook response dictionary 161 """ 162 if error is not None: 163 return {"continue": False, "stopReason": error, **kwargs} 164 return {"continue": True, **kwargs} 165 166 167 # ============================================================================ 168 # TIMESTAMP AND CONTENT PARSING UTILITIES 169 # ============================================================================ 170 171 172 def parse_timestamp_to_ns(timestamp: str | int | float | None) -> int | None: 173 """Convert various timestamp formats to nanoseconds since Unix epoch. 174 175 Args: 176 timestamp: Can be ISO string, Unix timestamp (seconds/ms), or nanoseconds 177 178 Returns: 179 Nanoseconds since Unix epoch, or None if parsing fails 180 """ 181 if not timestamp: 182 return None 183 184 if isinstance(timestamp, str): 185 try: 186 dt = dateutil.parser.parse(timestamp) 187 return int(dt.timestamp() * NANOSECONDS_PER_S) 188 except Exception: 189 get_logger().warning("Could not parse timestamp: %s", timestamp) 190 return None 191 if isinstance(timestamp, (int, float)): 192 if timestamp < 1e10: 193 return int(timestamp * NANOSECONDS_PER_S) 194 if timestamp < 1e13: 195 return int(timestamp * NANOSECONDS_PER_MS) 196 return int(timestamp) 197 198 return None 199 200 201 def extract_text_content(content: str | list[dict[str, Any]] | Any) -> str: 202 """Extract text content from Claude message content (handles both string and list formats). 
203 204 Args: 205 content: Either a string or list of content parts from Claude API 206 207 Returns: 208 Extracted text content, empty string if none found 209 """ 210 if isinstance(content, list): 211 text_parts = [ 212 part.get(CONTENT_TYPE_TEXT, "") 213 for part in content 214 if isinstance(part, dict) and part.get(MESSAGE_FIELD_TYPE) == CONTENT_TYPE_TEXT 215 ] 216 return "\n".join(text_parts) 217 if isinstance(content, str): 218 return content 219 return str(content) 220 221 222 def find_last_user_message_index(transcript: list[dict[str, Any]]) -> int | None: 223 """Find the index of the last actual user message (ignoring tool results and empty messages). 224 225 Args: 226 transcript: List of conversation entries from Claude Code transcript 227 228 Returns: 229 Index of last user message, or None if not found 230 """ 231 for i in range(len(transcript) - 1, -1, -1): 232 entry = transcript[i] 233 if entry.get(MESSAGE_FIELD_TYPE) == MESSAGE_TYPE_USER and not entry.get( 234 MESSAGE_FIELD_TOOL_USE_RESULT 235 ): 236 # Skip skill content injections: a user message immediately following 237 # a Skill tool result (which has toolUseResult with commandName) 238 if ( 239 i > 0 240 and isinstance( 241 prev_tool_result := transcript[i - 1].get(MESSAGE_FIELD_TOOL_USE_RESULT), dict 242 ) 243 and prev_tool_result.get(MESSAGE_FIELD_COMMAND_NAME) 244 ): 245 continue 246 247 msg = entry.get(MESSAGE_FIELD_MESSAGE, {}) 248 content = msg.get(MESSAGE_FIELD_CONTENT, "") 249 250 if isinstance(content, list) and len(content) > 0: 251 if ( 252 isinstance(content[0], dict) 253 and content[0].get(MESSAGE_FIELD_TYPE) == CONTENT_TYPE_TOOL_RESULT 254 ): 255 continue 256 257 if isinstance(content, str) and "<local-command-stdout>" in content: 258 continue 259 260 if not content or (isinstance(content, str) and content.strip() == ""): 261 continue 262 263 return i 264 return None 265 266 267 # ============================================================================ 268 # TRANSCRIPT 
PROCESSING HELPERS 269 # ============================================================================ 270 271 272 def _get_next_timestamp_ns(transcript: list[dict[str, Any]], current_idx: int) -> int | None: 273 """Get the timestamp of the next entry for duration calculation.""" 274 for i in range(current_idx + 1, len(transcript)): 275 if timestamp := transcript[i].get(MESSAGE_FIELD_TIMESTAMP): 276 return parse_timestamp_to_ns(timestamp) 277 return None 278 279 280 def _extract_content_and_tools(content: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]: 281 """Extract text content and tool uses from assistant response content.""" 282 text_content = "" 283 tool_uses = [] 284 285 if isinstance(content, list): 286 for part in content: 287 if isinstance(part, dict): 288 if part.get(MESSAGE_FIELD_TYPE) == CONTENT_TYPE_TEXT: 289 text_content += part.get(CONTENT_TYPE_TEXT, "") 290 elif part.get(MESSAGE_FIELD_TYPE) == CONTENT_TYPE_TOOL_USE: 291 tool_uses.append(part) 292 293 return text_content, tool_uses 294 295 296 def _find_tool_results(transcript: list[dict[str, Any]], start_idx: int) -> dict[str, Any]: 297 """Find tool results following the current assistant response. 298 299 Returns a mapping from tool_use_id to tool result content. 
300 """ 301 tool_results = {} 302 303 # Look for tool results in subsequent entries 304 for i in range(start_idx + 1, len(transcript)): 305 entry = transcript[i] 306 if entry.get(MESSAGE_FIELD_TYPE) != MESSAGE_TYPE_USER: 307 continue 308 309 msg = entry.get(MESSAGE_FIELD_MESSAGE, {}) 310 content = msg.get(MESSAGE_FIELD_CONTENT, []) 311 312 if isinstance(content, list): 313 for part in content: 314 if ( 315 isinstance(part, dict) 316 and part.get(MESSAGE_FIELD_TYPE) == CONTENT_TYPE_TOOL_RESULT 317 ): 318 tool_use_id = part.get("tool_use_id") 319 result_content = part.get("content", "") 320 if tool_use_id: 321 tool_results[tool_use_id] = result_content 322 323 # Stop looking once we hit the next assistant response 324 if entry.get(MESSAGE_FIELD_TYPE) == MESSAGE_TYPE_ASSISTANT: 325 break 326 327 return tool_results 328 329 330 def _get_input_messages(transcript: list[dict[str, Any]], current_idx: int) -> list[dict[str, Any]]: 331 """Get all messages between the previous text-bearing assistant response and the current one. 332 333 Claude Code emits separate transcript entries for text and tool_use content. 334 A typical sequence looks like: 335 assistant [text] ← previous LLM boundary (stop here) 336 assistant [tool_use] ← include 337 user [tool_result] ← include 338 assistant [tool_use] ← include 339 user [tool_result] ← include 340 assistant [text] ← current (the span we're building inputs for) 341 342 We walk backward and collect everything, only stopping when we hit an 343 assistant entry that contains text content (which marks the previous LLM span). 
def _entry_has_text(msg: dict[str, Any]) -> bool:
    """Return True if a transcript message payload carries any text content."""
    content = msg.get(MESSAGE_FIELD_CONTENT, [])
    if isinstance(content, str):
        return bool(content.strip())
    if isinstance(content, list):
        return any(
            isinstance(part, dict) and part.get(MESSAGE_FIELD_TYPE) == CONTENT_TYPE_TEXT
            for part in content
        )
    return False


def _get_input_messages(transcript: list[dict[str, Any]], current_idx: int) -> list[dict[str, Any]]:
    """Get all messages between the previous text-bearing assistant response and the current one.

    Claude Code emits separate transcript entries for text and tool_use content.
    A typical sequence looks like:
        assistant [text]        <- previous LLM boundary (stop here)
        assistant [tool_use]    <- include
        user      [tool_result] <- include
        assistant [tool_use]    <- include
        user      [tool_result] <- include
        assistant [text]        <- current (the span we're building inputs for)

    We walk backward and collect everything, only stopping when we hit an
    assistant entry that contains text content (which marks the previous LLM span).

    Args:
        transcript: List of conversation entries from Claude Code transcript
        current_idx: Index of the current assistant response

    Returns:
        List of messages in Anthropic format
    """
    collected: list[dict[str, Any]] = []

    for idx in range(current_idx - 1, -1, -1):
        entry = transcript[idx]
        entry_type = entry.get(MESSAGE_FIELD_TYPE)
        msg = entry.get(MESSAGE_FIELD_MESSAGE, {})

        # A previous assistant entry with text content marks the previous
        # LLM span boundary — stop there.
        if entry_type == MESSAGE_TYPE_ASSISTANT and _entry_has_text(msg):
            break

        # Steer messages (queue-operation enqueue) count as user messages.
        if (
            entry_type == MESSAGE_TYPE_QUEUE_OPERATION
            and entry.get("operation") == QUEUE_OPERATION_ENQUEUE
        ):
            steer_content = entry.get(MESSAGE_FIELD_CONTENT)
            if steer_content:
                collected.append({"role": "user", "content": steer_content})
                continue

        if msg.get("role") and msg.get(MESSAGE_FIELD_CONTENT):
            collected.append(msg)

    # Collected newest-first while walking backward; restore chronological order.
    collected.reverse()
    return collected


def _build_usage_dict(usage: dict[str, Any]) -> dict[str, int]:
    """Normalize a Claude Code usage payload into the CHAT_USAGE schema.

    Stores fields as the Anthropic API reports them, matching
    ``mlflow.anthropic.autolog``: ``input_tokens`` is the non-cached input,
    cache tokens are exposed as separate optional keys so consumers can
    compute cache hit rate, and ``total_tokens`` follows the
    ``mlflow.anthropic`` convention of ``input_tokens + output_tokens``
    (cache tokens excluded).
    """
    in_tokens = usage.get("input_tokens", 0)
    out_tokens = usage.get("output_tokens", 0)
    cache_read = usage.get("cache_read_input_tokens")
    cache_created = usage.get("cache_creation_input_tokens")

    normalized: dict[str, int] = {
        TokenUsageKey.INPUT_TOKENS: in_tokens,
        TokenUsageKey.OUTPUT_TOKENS: out_tokens,
        TokenUsageKey.TOTAL_TOKENS: in_tokens + out_tokens,
    }
    if cache_read is not None:
        normalized[TokenUsageKey.CACHE_READ_INPUT_TOKENS] = cache_read
    if cache_created is not None:
        normalized[TokenUsageKey.CACHE_CREATION_INPUT_TOKENS] = cache_created
    return normalized


def _set_token_usage_attribute(span, usage: dict[str, Any]) -> None:
    """Set token usage on a span using the standardized CHAT_USAGE attribute.

    Args:
        span: The MLflow span to set token usage on
        usage: Dictionary containing token usage info from Claude Code transcript;
            empty/None payloads are a no-op
    """
    if usage:
        span.set_attribute(SpanAttributeKey.CHAT_USAGE, _build_usage_dict(usage))
def _create_llm_and_tool_spans(
    parent_span, transcript: list[dict[str, Any]], start_idx: int
) -> None:
    """Create LLM and tool spans for assistant responses with proper timing.

    Walks the transcript from ``start_idx``, emitting an LLM child span for
    each text-only assistant entry and TOOL child spans (with proportional
    timing) for each tool_use entry.
    """
    for i in range(start_idx, len(transcript)):
        entry = transcript[i]
        if entry.get(MESSAGE_FIELD_TYPE) != MESSAGE_TYPE_ASSISTANT:
            continue

        # NOTE(review): may be None if the entry lacks a timestamp; the
        # arithmetic below would then raise and be caught by the caller.
        timestamp_ns = parse_timestamp_to_ns(entry.get(MESSAGE_FIELD_TIMESTAMP))

        # Calculate duration based on next timestamp or use default
        if next_timestamp_ns := _get_next_timestamp_ns(transcript, i):
            duration_ns = next_timestamp_ns - timestamp_ns
        else:
            duration_ns = int(1000 * NANOSECONDS_PER_MS)  # 1 second default

        msg = entry.get(MESSAGE_FIELD_MESSAGE, {})
        content = msg.get(MESSAGE_FIELD_CONTENT, [])
        usage = msg.get("usage", {})

        # First check if we have meaningful content to create a span for
        text_content, tool_uses = _extract_content_and_tools(content)

        # Only create LLM span if there's text content (no tools)
        llm_span = None
        if text_content and text_content.strip() and not tool_uses:
            messages = _get_input_messages(transcript, i)

            llm_span = mlflow.start_span_no_context(
                name="llm",
                parent_span=parent_span,
                span_type=SpanType.LLM,
                start_time_ns=timestamp_ns,
                inputs={
                    "model": msg.get("model", "unknown"),
                    "messages": messages,
                },
                attributes={
                    "model": msg.get("model", "unknown"),
                    SpanAttributeKey.MESSAGE_FORMAT: "anthropic",
                },
            )

            # Set token usage using the standardized CHAT_USAGE attribute
            _set_token_usage_attribute(llm_span, usage)

            # Output in Anthropic response format for Chat UI rendering
            llm_span.set_outputs({
                "type": "message",
                "role": "assistant",
                "content": content,
            })
            llm_span.end(end_time_ns=timestamp_ns + duration_ns)

        # Create tool spans with proportional timing and actual results
        if tool_uses:
            tool_results = _find_tool_results(transcript, i)
            # Split the entry's duration evenly across its tool calls
            tool_duration_ns = duration_ns // len(tool_uses)

            for idx, tool_use in enumerate(tool_uses):
                tool_start_ns = timestamp_ns + (idx * tool_duration_ns)
                tool_use_id = tool_use.get("id", "")
                tool_result = tool_results.get(tool_use_id, "No result found")

                tool_span = mlflow.start_span_no_context(
                    name=f"tool_{tool_use.get('name', 'unknown')}",
                    parent_span=parent_span,
                    span_type=SpanType.TOOL,
                    start_time_ns=tool_start_ns,
                    inputs=tool_use.get("input", {}),
                    attributes={
                        "tool_name": tool_use.get("name", "unknown"),
                        "tool_id": tool_use_id,
                    },
                )

                tool_span.set_outputs({"result": tool_result})
                tool_span.end(end_time_ns=tool_start_ns + tool_duration_ns)


def _finalize_trace(
    parent_span,
    user_prompt: str,
    final_response: str | None,
    session_id: str | None,
    end_time_ns: int | None = None,
    usage: dict[str, Any] | None = None,
    claude_code_version: str | None = None,
) -> mlflow.entities.Trace:
    """Attach previews/metadata to the in-memory trace, close the root span,
    flush async logging, and return the completed trace.

    Args:
        parent_span: Root span of the conversation trace.
        user_prompt: Text used for the trace request preview.
        final_response: Text used for the response preview, if any.
        session_id: Session identifier stored as trace session metadata.
        end_time_ns: Explicit end time for the root span, or None for "now".
        usage: Optional token usage payload stored on trace metadata.
        claude_code_version: Optional CLI version stored on trace metadata.
    """
    try:
        # Set trace previews and metadata for UI display
        with InMemoryTraceManager.get_instance().get_trace(parent_span.trace_id) as in_memory_trace:
            if user_prompt:
                in_memory_trace.info.request_preview = user_prompt[:MAX_PREVIEW_LENGTH]
            if final_response:
                in_memory_trace.info.response_preview = final_response[:MAX_PREVIEW_LENGTH]

            metadata = {
                TraceMetadataKey.TRACE_USER: os.environ.get("USER", ""),
                "mlflow.trace.working_directory": os.getcwd(),
            }
            if session_id:
                metadata[TraceMetadataKey.TRACE_SESSION] = session_id
            if claude_code_version:
                metadata[METADATA_KEY_CLAUDE_CODE_VERSION] = claude_code_version

            # Set token usage directly on trace metadata so it survives
            # even if span-level aggregation doesn't pick it up
            if usage:
                metadata[TraceMetadataKey.TOKEN_USAGE] = json.dumps(_build_usage_dict(usage))

            in_memory_trace.info.trace_metadata = {
                **in_memory_trace.info.trace_metadata,
                **metadata,
            }
    except Exception as e:
        # Metadata is best-effort; never fail the trace because of it
        get_logger().warning("Failed to update trace metadata and previews: %s", e)

    outputs = {"status": "completed"}
    if final_response:
        outputs["response"] = final_response
    parent_span.set_outputs(outputs)
    parent_span.end(end_time_ns=end_time_ns)
    _flush_trace_async_logging()
    get_logger().log(CLAUDE_TRACING_LEVEL, "Created MLflow trace: %s", parent_span.trace_id)
    return mlflow.get_trace(parent_span.trace_id)


def _flush_trace_async_logging() -> None:
    """Best-effort flush of the async trace export queue, if the exporter has one."""
    try:
        if hasattr(_get_trace_exporter(), "_async_queue"):
            mlflow.flush_trace_async_logging()
    except Exception as e:
        get_logger().debug("Failed to flush trace async logging: %s", e)
def find_final_assistant_response(transcript: list[dict[str, Any]], start_idx: int) -> str | None:
    """Find the final text response from the assistant for trace preview.

    Args:
        transcript: List of conversation entries from Claude Code transcript
        start_idx: Index to start searching from (typically after last user message)

    Returns:
        Final assistant response text or None
    """
    final_response = None

    # Keep overwriting so the LAST non-empty text part wins
    for i in range(start_idx, len(transcript)):
        entry = transcript[i]
        if entry.get(MESSAGE_FIELD_TYPE) != MESSAGE_TYPE_ASSISTANT:
            continue

        msg = entry.get(MESSAGE_FIELD_MESSAGE, {})
        content = msg.get(MESSAGE_FIELD_CONTENT, [])

        if isinstance(content, list):
            for part in content:
                if isinstance(part, dict) and part.get(MESSAGE_FIELD_TYPE) == CONTENT_TYPE_TEXT:
                    text = part.get(CONTENT_TYPE_TEXT, "")
                    if text.strip():
                        final_response = text

    return final_response


# ============================================================================
# MAIN TRANSCRIPT PROCESSING
# ============================================================================


def process_transcript(
    transcript_path: str, session_id: str | None = None
) -> mlflow.entities.Trace | None:
    """Process a Claude conversation transcript and create an MLflow trace with spans.

    Args:
        transcript_path: Path to the Claude Code transcript.jsonl file
        session_id: Optional session identifier, defaults to timestamp-based ID

    Returns:
        MLflow trace object if successful, None if processing fails
    """
    try:
        transcript = read_transcript(transcript_path)
        if not transcript:
            get_logger().warning("Empty transcript, skipping")
            return None

        last_user_idx = find_last_user_message_index(transcript)
        if last_user_idx is None:
            get_logger().warning("No user message found in transcript")
            return None

        last_user_entry = transcript[last_user_idx]
        last_user_prompt = last_user_entry.get(MESSAGE_FIELD_MESSAGE, {}).get(
            MESSAGE_FIELD_CONTENT, ""
        )

        if not session_id:
            session_id = f"claude-{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        get_logger().log(CLAUDE_TRACING_LEVEL, "Creating MLflow trace for session: %s", session_id)

        # NOTE(review): may be None when the entry lacks a timestamp; the
        # end-time comparison below would then raise into the outer except.
        conv_start_ns = parse_timestamp_to_ns(last_user_entry.get(MESSAGE_FIELD_TIMESTAMP))

        parent_span = mlflow.start_span_no_context(
            name="claude_code_conversation",
            inputs={"prompt": extract_text_content(last_user_prompt)},
            start_time_ns=conv_start_ns,
            span_type=SpanType.AGENT,
        )

        # Create spans for all assistant responses and tool uses
        _create_llm_and_tool_spans(parent_span, transcript, last_user_idx + 1)

        # Update trace with preview content and end timing
        final_response = find_final_assistant_response(transcript, last_user_idx + 1)
        user_prompt_text = extract_text_content(last_user_prompt)

        # Calculate end time based on last entry or use default duration
        last_entry = transcript[-1] if transcript else last_user_entry
        conv_end_ns = parse_timestamp_to_ns(last_entry.get(MESSAGE_FIELD_TIMESTAMP))
        if not conv_end_ns or conv_end_ns <= conv_start_ns:
            conv_end_ns = conv_start_ns + int(10 * NANOSECONDS_PER_S)

        # Extract Claude Code version from transcript entries (CLI-only)
        claude_code_version = next(
            (ver for entry in transcript if (ver := entry.get("version"))), None
        )

        return _finalize_trace(
            parent_span,
            user_prompt_text,
            final_response,
            session_id,
            conv_end_ns,
            claude_code_version=claude_code_version,
        )

    except Exception as e:
        get_logger().error("Error processing transcript: %s", e, exc_info=True)
        return None


# ============================================================================
# SDK MESSAGE PROCESSING
# ============================================================================


def _find_sdk_user_prompt(messages: list[Any]) -> str | None:
    """Return the first non-empty user prompt text from SDK messages.

    Skips UserMessages that carry a tool_use_result (tool-result echoes).
    """
    from claude_agent_sdk.types import TextBlock, UserMessage

    for msg in messages:
        if not isinstance(msg, UserMessage) or msg.tool_use_result is not None:
            continue
        content = msg.content
        if isinstance(content, str):
            text = content
        elif isinstance(content, list):
            text = "\n".join(block.text for block in content if isinstance(block, TextBlock))
        else:
            continue
        if text and text.strip():
            return text
    return None


def _build_tool_result_map(messages: list[Any]) -> dict[str, str]:
    """Map tool_use_id to its result content so tool spans can show outputs."""
    from claude_agent_sdk.types import ToolResultBlock, UserMessage

    tool_result_map: dict[str, str] = {}
    for msg in messages:
        if isinstance(msg, UserMessage) and isinstance(msg.content, list):
            for block in msg.content:
                if isinstance(block, ToolResultBlock):
                    result = block.content
                    # Structured (list) results are stringified for display
                    if isinstance(result, list):
                        result = str(result)
                    tool_result_map[block.tool_use_id] = result or ""
    return tool_result_map
# Maps SDK dataclass names to Anthropic API "type" discriminators.
# dataclasses.asdict() gives us the fields but not the type tag that
# the Anthropic message format requires on every content block.
_CONTENT_BLOCK_TYPES = {
    "TextBlock": "text",
    "ToolUseBlock": "tool_use",
    "ToolResultBlock": "tool_result",
}


def _serialize_content_block(block) -> dict[str, Any] | None:
    """Convert an SDK content-block dataclass into an Anthropic-format dict.

    Returns None for unknown block types; None-valued fields are dropped.
    """
    block_type = _CONTENT_BLOCK_TYPES.get(type(block).__name__)
    if not block_type:
        return None
    fields = {key: value for key, value in dataclasses.asdict(block).items() if value is not None}
    fields["type"] = block_type
    return fields


def _serialize_sdk_message(msg) -> dict[str, Any] | None:
    """Serialize an SDK User/Assistant message into Anthropic message format.

    Returns None for empty messages or message types other than
    UserMessage/AssistantMessage.
    """
    from claude_agent_sdk.types import AssistantMessage, UserMessage

    if isinstance(msg, UserMessage):
        content = msg.content
        if isinstance(content, str):
            return {"role": "user", "content": content} if content.strip() else None
        elif isinstance(content, list):
            if parts := [
                serialized for block in content if (serialized := _serialize_content_block(block))
            ]:
                return {"role": "user", "content": parts}
    elif isinstance(msg, AssistantMessage) and msg.content:
        if parts := [
            serialized for block in msg.content if (serialized := _serialize_content_block(block))
        ]:
            return {"role": "assistant", "content": parts}
    return None


def _create_sdk_child_spans(
    messages: list[Any],
    parent_span,
    tool_result_map: dict[str, str],
) -> str | None:
    """Create LLM and tool child spans under ``parent_span`` from SDK messages."""
    from claude_agent_sdk.types import AssistantMessage, TextBlock, ToolUseBlock

    final_response = None
    # Messages accumulated since the last text-bearing assistant message;
    # they become the inputs of the next LLM span.
    pending_messages: list[dict[str, Any]] = []

    for msg in messages:
        if isinstance(msg, AssistantMessage) and msg.content:
            text_blocks = [block for block in msg.content if isinstance(block, TextBlock)]
            tool_blocks = [block for block in msg.content if isinstance(block, ToolUseBlock)]

            # Text-only assistant message -> emit an LLM span
            if text_blocks and not tool_blocks:
                text = "\n".join(block.text for block in text_blocks)
                if text.strip():
                    final_response = text

                llm_span = mlflow.start_span_no_context(
                    name="llm",
                    parent_span=parent_span,
                    span_type=SpanType.LLM,
                    inputs={
                        "model": getattr(msg, "model", "unknown"),
                        "messages": pending_messages,
                    },
                    attributes={
                        "model": getattr(msg, "model", "unknown"),
                        SpanAttributeKey.MESSAGE_FORMAT: "anthropic",
                    },
                )
                llm_span.set_outputs({
                    "type": "message",
                    "role": "assistant",
                    "content": [{"type": "text", "text": block.text} for block in text_blocks],
                })
                llm_span.end()
                pending_messages = []
                continue

            # Tool-use assistant message -> emit a TOOL span per tool call
            for tool_block in tool_blocks:
                tool_span = mlflow.start_span_no_context(
                    name=f"tool_{tool_block.name}",
                    parent_span=parent_span,
                    span_type=SpanType.TOOL,
                    inputs=tool_block.input,
                    attributes={"tool_name": tool_block.name, "tool_id": tool_block.id},
                )
                tool_span.set_outputs({"result": tool_result_map.get(tool_block.id, "")})
                tool_span.end()

        # Accumulate the message for the next LLM span's inputs
        if anthropic_msg := _serialize_sdk_message(msg):
            pending_messages.append(anthropic_msg)

    return final_response


def process_sdk_messages(
    messages: list[Any], session_id: str | None = None
) -> mlflow.entities.Trace | None:
    """
    Build an MLflow trace from Claude Agent SDK message objects.

    Args:
        messages: List of SDK message objects (UserMessage, AssistantMessage,
            ResultMessage, etc.) captured during a conversation.
        session_id: Optional session identifier for grouping traces.

    Returns:
        MLflow Trace if successful, None if no user prompt is found or processing fails.
    """
    from claude_agent_sdk.types import ResultMessage

    try:
        if not messages:
            get_logger().warning("Empty messages list, skipping")
            return None

        user_prompt = _find_sdk_user_prompt(messages)
        if user_prompt is None:
            get_logger().warning("No user prompt found in SDK messages")
            return None

        result_msg = next((msg for msg in messages if isinstance(msg, ResultMessage)), None)

        # Prefer the SDK's own session_id, fall back to caller arg
        session_id = (result_msg.session_id if result_msg else None) or session_id

        get_logger().log(
            CLAUDE_TRACING_LEVEL,
            "Creating MLflow trace for session: %s",
            session_id,
        )

        tool_result_map = _build_tool_result_map(messages)

        # Reconstruct wall-clock timing from the SDK's reported duration;
        # without it, let MLflow assign start/end times.
        if duration_ms := (getattr(result_msg, "duration_ms", None) if result_msg else None):
            duration_ns = int(duration_ms * NANOSECONDS_PER_MS)
            now_ns = int(datetime.now().timestamp() * NANOSECONDS_PER_S)
            start_time_ns = now_ns - duration_ns
            end_time_ns = now_ns
        else:
            start_time_ns = None
            end_time_ns = None

        parent_span = mlflow.start_span_no_context(
            name="claude_code_conversation",
            inputs={"prompt": user_prompt},
            span_type=SpanType.AGENT,
            start_time_ns=start_time_ns,
        )

        final_response = _create_sdk_child_spans(messages, parent_span, tool_result_map)

        # Set token usage on the root span so it aggregates into trace-level usage
        usage = getattr(result_msg, "usage", None) if result_msg else None
        if usage:
            _set_token_usage_attribute(parent_span, usage)

        return _finalize_trace(
            parent_span,
            user_prompt,
            final_response,
            session_id,
            end_time_ns=end_time_ns,
            usage=usage,
        )

    except Exception as e:
        get_logger().error("Error processing SDK messages: %s", e, exc_info=True)
        return None