/ mlflow / claude_code / hooks.py
hooks.py
  1  """Hook management for Claude Code integration with MLflow."""
  2  
  3  import json
  4  import os
  5  import sys
  6  from pathlib import Path
  7  from typing import Any
  8  
  9  from mlflow.claude_code.config import (
 10      ENVIRONMENT_FIELD,
 11      HOOK_FIELD_COMMAND,
 12      HOOK_FIELD_HOOKS,
 13      MLFLOW_EXPERIMENT_ID,
 14      MLFLOW_EXPERIMENT_NAME,
 15      MLFLOW_HOOK_IDENTIFIER,
 16      MLFLOW_LEGACY_HOOK_IDENTIFIER,
 17      MLFLOW_TRACING_ENABLED,
 18      MLFLOW_TRACKING_URI,
 19      load_claude_config,
 20      save_claude_config,
 21  )
 22  from mlflow.claude_code.tracing import (
 23      CLAUDE_TRACING_LEVEL,
 24      get_hook_response,
 25      get_logger,
 26      is_tracing_enabled,
 27      process_transcript,
 28      read_hook_input,
 29      setup_mlflow,
 30  )
 31  
 32  # ============================================================================
 33  # HOOK CONFIGURATION UTILITIES
 34  # ============================================================================
 35  
 36  
 37  def upsert_hook(config: dict[str, Any], hook_type: str, subcommand: str) -> None:
 38      """Insert or update a single MLflow hook in the configuration.
 39  
 40      Args:
 41          config: The hooks configuration dictionary to modify
 42          hook_type: The hook type (e.g., 'PostToolUse', 'Stop')
 43          subcommand: The CLI subcommand name (e.g., 'stop-hook')
 44      """
 45      if hook_type not in config[HOOK_FIELD_HOOKS]:
 46          config[HOOK_FIELD_HOOKS][hook_type] = []
 47  
 48      mlflow_cmd = "uv run mlflow" if "UV" in os.environ else "mlflow"
 49      hook_command = f"{mlflow_cmd} autolog claude {subcommand}"
 50  
 51      mlflow_hook = {"type": "command", HOOK_FIELD_COMMAND: hook_command}
 52  
 53      # Check if MLflow hook already exists and update it
 54      hook_exists = False
 55      for hook_group in config[HOOK_FIELD_HOOKS][hook_type]:
 56          if HOOK_FIELD_HOOKS in hook_group:
 57              for hook in hook_group[HOOK_FIELD_HOOKS]:
 58                  cmd = hook.get(HOOK_FIELD_COMMAND, "")
 59                  if MLFLOW_HOOK_IDENTIFIER in cmd or MLFLOW_LEGACY_HOOK_IDENTIFIER in cmd:
 60                      hook.update(mlflow_hook)
 61                      hook_exists = True
 62                      break
 63  
 64      # Add new hook if it doesn't exist
 65      if not hook_exists:
 66          config[HOOK_FIELD_HOOKS][hook_type].append({HOOK_FIELD_HOOKS: [mlflow_hook]})
 67  
 68  
 69  def setup_hooks_config(settings_path: Path) -> None:
 70      """Set up Claude Code hooks for MLflow tracing.
 71  
 72      Creates or updates Stop hook that calls MLflow tracing handler.
 73      Updates existing MLflow hooks if found, otherwise adds new ones.
 74  
 75      Args:
 76          settings_path: Path to Claude settings.json file
 77      """
 78      config = load_claude_config(settings_path)
 79  
 80      if HOOK_FIELD_HOOKS not in config:
 81          config[HOOK_FIELD_HOOKS] = {}
 82  
 83      upsert_hook(config, "Stop", "stop-hook")
 84  
 85      save_claude_config(settings_path, config)
 86  
 87  
 88  # ============================================================================
 89  # HOOK REMOVAL AND CLEANUP
 90  # ============================================================================
 91  
 92  
 93  def disable_tracing_hooks(settings_path: Path) -> bool:
 94      """Remove MLflow hooks and environment variables from Claude settings.
 95  
 96      Args:
 97          settings_path: Path to Claude settings file
 98  
 99      Returns:
100          True if hooks/config were removed, False if no configuration was found
101      """
102      if not settings_path.exists():
103          return False
104  
105      config = load_claude_config(settings_path)
106      hooks_removed = False
107      env_removed = False
108  
109      # Remove MLflow hooks
110      if "Stop" in config.get(HOOK_FIELD_HOOKS, {}):
111          hook_groups = config[HOOK_FIELD_HOOKS]["Stop"]
112          filtered_groups = []
113  
114          for group in hook_groups:
115              if HOOK_FIELD_HOOKS in group:
116                  filtered_hooks = [
117                      hook
118                      for hook in group[HOOK_FIELD_HOOKS]
119                      if MLFLOW_HOOK_IDENTIFIER not in hook.get(HOOK_FIELD_COMMAND, "")
120                      and MLFLOW_LEGACY_HOOK_IDENTIFIER not in hook.get(HOOK_FIELD_COMMAND, "")
121                  ]
122  
123                  if filtered_hooks:
124                      filtered_groups.append({HOOK_FIELD_HOOKS: filtered_hooks})
125                  else:
126                      hooks_removed = True
127              else:
128                  filtered_groups.append(group)
129  
130          if filtered_groups:
131              config[HOOK_FIELD_HOOKS]["Stop"] = filtered_groups
132          else:
133              del config[HOOK_FIELD_HOOKS]["Stop"]
134              hooks_removed = True
135  
136      # Remove config variables
137      if ENVIRONMENT_FIELD in config:
138          mlflow_vars = [
139              MLFLOW_TRACING_ENABLED,
140              MLFLOW_TRACKING_URI,
141              MLFLOW_EXPERIMENT_ID,
142              MLFLOW_EXPERIMENT_NAME,
143          ]
144          for var in mlflow_vars:
145              if var in config[ENVIRONMENT_FIELD]:
146                  del config[ENVIRONMENT_FIELD][var]
147                  env_removed = True
148  
149          if not config[ENVIRONMENT_FIELD]:
150              del config[ENVIRONMENT_FIELD]
151  
152      # Clean up empty hooks section
153      if HOOK_FIELD_HOOKS in config and not config[HOOK_FIELD_HOOKS]:
154          del config[HOOK_FIELD_HOOKS]
155  
156      # Save updated config or remove file if empty
157      if config:
158          save_claude_config(settings_path, config)
159      else:
160          settings_path.unlink()
161  
162      return hooks_removed or env_removed
163  
164  
165  # ============================================================================
166  # CLAUDE CODE HOOK HANDLERS
167  # ============================================================================
168  
169  
def _process_stop_hook(session_id: str | None, transcript_path: str | None) -> dict[str, Any]:
    """Turn a finished session's transcript into an MLflow trace.

    Shared by the CLI and SDK Stop hook handlers.

    Args:
        session_id: Session identifier
        transcript_path: Path to transcript file

    Returns:
        A plain hook response when ``process_transcript`` returns a trace, or
        a hook response carrying an error message when it returns ``None``.
    """
    logger = get_logger()
    logger.log(
        CLAUDE_TRACING_LEVEL, "Stop hook: session=%s, transcript=%s", session_id, transcript_path
    )

    # A None result means the transcript could not be turned into a trace.
    if process_transcript(transcript_path, session_id) is None:
        failure_message = (
            "Failed to process transcript, please check .claude/mlflow/claude_tracing.log"
            " for more details"
        )
        return get_hook_response(error=failure_message)

    return get_hook_response()
195  
196  
def stop_hook_handler() -> None:
    """CLI entry point for the Stop hook: build a trace from the transcript.

    Always prints a JSON hook response to stdout. When tracing is disabled a
    plain response is printed and nothing else happens. If anything raises
    while reading the hook input, setting up MLflow, processing the
    transcript, or printing the result, the error is logged, an error
    response is printed, and the process exits with status 1.
    """
    if is_tracing_enabled():
        try:
            payload = read_hook_input()
            # Pull both fields out before touching MLflow so malformed input
            # fails before any setup work is done.
            session = payload.get("session_id")
            transcript = payload.get("transcript_path")

            setup_mlflow()
            print(json.dumps(_process_stop_hook(session, transcript)))  # noqa: T201
        except Exception as exc:
            get_logger().error("Error in Stop hook: %s", exc, exc_info=True)
            print(json.dumps(get_hook_response(error=str(exc))))  # noqa: T201
            sys.exit(1)
    else:
        print(json.dumps(get_hook_response()))  # noqa: T201
218  
219  
async def sdk_stop_hook_handler(
    input_data: dict[str, Any],
    tool_use_id: str | None,
    context: Any,
) -> dict[str, Any]:
    """SDK callback for the Stop event: build a trace from the transcript.

    Args:
        input_data: Dictionary containing session_id and transcript_path
        tool_use_id: Tool use identifier (not used by this handler)
        context: HookContext from the SDK (not used by this handler)

    Returns:
        A plain hook response when "anthropic" autologging is disabled, the
        result of processing the transcript otherwise, or an error response
        if processing raises.
    """
    from mlflow.utils.autologging_utils import autologging_is_disabled

    # Nothing to trace when the user has turned autologging off.
    if autologging_is_disabled("anthropic"):
        return get_hook_response()

    try:
        return _process_stop_hook(
            input_data.get("session_id"),
            input_data.get("transcript_path"),
        )
    except Exception as exc:
        get_logger().error("Error in SDK Stop hook: %s", exc, exc_info=True)
        return get_hook_response(error=str(exc))