hooks.py
"""Hook management for Claude Code integration with MLflow."""

import json
import os
import sys
from pathlib import Path
from typing import Any

from mlflow.claude_code.config import (
    ENVIRONMENT_FIELD,
    HOOK_FIELD_COMMAND,
    HOOK_FIELD_HOOKS,
    MLFLOW_EXPERIMENT_ID,
    MLFLOW_EXPERIMENT_NAME,
    MLFLOW_HOOK_IDENTIFIER,
    MLFLOW_LEGACY_HOOK_IDENTIFIER,
    MLFLOW_TRACING_ENABLED,
    MLFLOW_TRACKING_URI,
    load_claude_config,
    save_claude_config,
)
from mlflow.claude_code.tracing import (
    CLAUDE_TRACING_LEVEL,
    get_hook_response,
    get_logger,
    is_tracing_enabled,
    process_transcript,
    read_hook_input,
    setup_mlflow,
)

# ============================================================================
# HOOK CONFIGURATION UTILITIES
# ============================================================================


def _is_mlflow_hook(hook: dict[str, Any]) -> bool:
    """Return True if the hook's command contains an MLflow hook identifier.

    Both the current and the legacy identifier are checked. A hook without a
    command field is treated as not belonging to MLflow.
    """
    cmd = hook.get(HOOK_FIELD_COMMAND, "")
    return MLFLOW_HOOK_IDENTIFIER in cmd or MLFLOW_LEGACY_HOOK_IDENTIFIER in cmd


def upsert_hook(config: dict[str, Any], hook_type: str, subcommand: str) -> None:
    """Insert or update a single MLflow hook in the configuration.

    The configuration is modified in place. ``config[HOOK_FIELD_HOOKS]`` must
    already exist.

    Args:
        config: The hooks configuration dictionary to modify
        hook_type: The hook type (e.g., 'PostToolUse', 'Stop')
        subcommand: The CLI subcommand name (e.g., 'stop-hook')
    """
    hook_groups = config[HOOK_FIELD_HOOKS].setdefault(hook_type, [])

    # Prefix the command with "uv run" when the UV environment variable is set.
    mlflow_cmd = "uv run mlflow" if "UV" in os.environ else "mlflow"
    hook_command = f"{mlflow_cmd} autolog claude {subcommand}"

    mlflow_hook = {"type": "command", HOOK_FIELD_COMMAND: hook_command}

    # Update the first MLflow hook found in each group; other hooks are untouched.
    hook_exists = False
    for hook_group in hook_groups:
        if HOOK_FIELD_HOOKS not in hook_group:
            continue
        for hook in hook_group[HOOK_FIELD_HOOKS]:
            if _is_mlflow_hook(hook):
                hook.update(mlflow_hook)
                hook_exists = True
                break

    # Add a new hook group if no existing MLflow hook was found
    if not hook_exists:
        hook_groups.append({HOOK_FIELD_HOOKS: [mlflow_hook]})


def setup_hooks_config(settings_path: Path) -> None:
    """Set up Claude Code hooks for MLflow tracing.

    Creates or updates the Stop hook that calls the MLflow tracing handler.
    An existing MLflow hook is updated if found, otherwise a new one is added.
    The resulting configuration is written back to ``settings_path``.

    Args:
        settings_path: Path to Claude settings.json file
    """
    config = load_claude_config(settings_path)
    config.setdefault(HOOK_FIELD_HOOKS, {})

    upsert_hook(config, "Stop", "stop-hook")

    save_claude_config(settings_path, config)


# ============================================================================
# HOOK REMOVAL AND CLEANUP
# ============================================================================


def disable_tracing_hooks(settings_path: Path) -> bool:
    """Remove MLflow hooks and environment variables from Claude settings.

    MLflow hooks are stripped from the "Stop" hook groups; any other hooks in
    the same group, and any other keys on the group, are kept. Sections that
    end up empty are removed, and the settings file itself is deleted when the
    whole configuration becomes empty.

    Args:
        settings_path: Path to Claude settings file

    Returns:
        True if hooks/config were removed, False if no configuration was found
    """
    if not settings_path.exists():
        return False

    config = load_claude_config(settings_path)
    hooks_removed = False
    env_removed = False

    # Remove MLflow hooks
    if "Stop" in config.get(HOOK_FIELD_HOOKS, {}):
        filtered_groups = []

        for group in config[HOOK_FIELD_HOOKS]["Stop"]:
            if HOOK_FIELD_HOOKS not in group:
                # Not a hook group we understand; keep it as-is.
                filtered_groups.append(group)
                continue

            original_hooks = group[HOOK_FIELD_HOOKS]
            filtered_hooks = [hook for hook in original_hooks if not _is_mlflow_hook(hook)]

            # Report a removal even when the group keeps other (non-MLflow) hooks.
            if len(filtered_hooks) != len(original_hooks):
                hooks_removed = True

            if filtered_hooks:
                # Keep the group's other keys instead of rebuilding it from scratch.
                filtered_groups.append({**group, HOOK_FIELD_HOOKS: filtered_hooks})
            else:
                hooks_removed = True

        if filtered_groups:
            config[HOOK_FIELD_HOOKS]["Stop"] = filtered_groups
        else:
            del config[HOOK_FIELD_HOOKS]["Stop"]
            hooks_removed = True

    # Remove config variables
    if ENVIRONMENT_FIELD in config:
        env = config[ENVIRONMENT_FIELD]
        mlflow_vars = (
            MLFLOW_TRACING_ENABLED,
            MLFLOW_TRACKING_URI,
            MLFLOW_EXPERIMENT_ID,
            MLFLOW_EXPERIMENT_NAME,
        )
        for var in mlflow_vars:
            if var in env:
                del env[var]
                env_removed = True

        if not env:
            del config[ENVIRONMENT_FIELD]

    # Clean up empty hooks section
    if HOOK_FIELD_HOOKS in config and not config[HOOK_FIELD_HOOKS]:
        del config[HOOK_FIELD_HOOKS]

    # Save updated config or remove file if empty
    if config:
        save_claude_config(settings_path, config)
    else:
        settings_path.unlink()

    return hooks_removed or env_removed


# ============================================================================
# CLAUDE CODE HOOK HANDLERS
# ============================================================================


def _process_stop_hook(session_id: str | None, transcript_path: str | None) -> dict[str, Any]:
    """Common logic for processing stop hooks.

    Logs the invocation, hands the transcript to ``process_transcript`` and
    turns its result into a hook response: a plain response when a trace was
    returned, an error response when the result is None.

    Args:
        session_id: Session identifier
        transcript_path: Path to transcript file

    Returns:
        Hook response dictionary
    """
    get_logger().log(
        CLAUDE_TRACING_LEVEL, "Stop hook: session=%s, transcript=%s", session_id, transcript_path
    )

    # Process the transcript and create MLflow trace
    trace = process_transcript(transcript_path, session_id)

    if trace is not None:
        return get_hook_response()
    return get_hook_response(
        error=(
            "Failed to process transcript, please check .claude/mlflow/claude_tracing.log"
            " for more details"
        ),
    )


def stop_hook_handler() -> None:
    """CLI hook handler for conversation end - processes transcript and creates trace.

    The hook response is always printed to stdout as JSON. When tracing is not
    enabled a plain response is printed and nothing else happens. If processing
    raises, the error is logged, an error response is printed and the process
    exits with status 1.
    """
    if not is_tracing_enabled():
        response = get_hook_response()
        print(json.dumps(response))  # noqa: T201
        return

    try:
        hook_data = read_hook_input()
        session_id = hook_data.get("session_id")
        transcript_path = hook_data.get("transcript_path")

        setup_mlflow()
        response = _process_stop_hook(session_id, transcript_path)
        print(json.dumps(response))  # noqa: T201

    except Exception as e:
        get_logger().error("Error in Stop hook: %s", e, exc_info=True)
        response = get_hook_response(error=str(e))
        print(json.dumps(response))  # noqa: T201
        sys.exit(1)


async def sdk_stop_hook_handler(
    input_data: dict[str, Any],
    tool_use_id: str | None,
    context: Any,
) -> dict[str, Any]:
    """SDK hook handler for Stop event - processes transcript and creates trace.

    Returns a plain hook response without processing anything when autologging
    for "anthropic" is disabled. Exceptions raised during processing are logged
    and reported through an error response rather than propagated.

    Args:
        input_data: Dictionary containing session_id and transcript_path
        tool_use_id: Tool use identifier (not used by this handler)
        context: HookContext from the SDK (not used by this handler)

    Returns:
        Hook response dictionary
    """
    from mlflow.utils.autologging_utils import autologging_is_disabled

    # Check if autologging is disabled
    if autologging_is_disabled("anthropic"):
        return get_hook_response()

    try:
        session_id = input_data.get("session_id")
        transcript_path = input_data.get("transcript_path")

        return _process_stop_hook(session_id, transcript_path)

    except Exception as e:
        get_logger().error("Error in SDK Stop hook: %s", e, exc_info=True)
        return get_hook_response(error=str(e))