# server.py
"""MLflow MCP server.

Exposes MLflow CLI commands decorated with ``@mlflow_mcp`` as MCP tools and MLflow AI
commands as MCP prompts. Which tool categories are exposed is selected by the
``MLFLOW_MCP_TOOLS`` environment variable, read once at import time.
"""

import contextlib
import io
import os
from typing import TYPE_CHECKING, Any, Callable

import click
from click.types import BOOL, FLOAT, INT, STRING, UUID

import mlflow.deployments.cli as deployments_cli
import mlflow.experiments
import mlflow.models.cli as models_cli
import mlflow.runs
from mlflow.ai_commands.ai_command_utils import get_command_body, list_commands
from mlflow.cli.scorers import commands as scorers_cli
from mlflow.cli.traces import commands as traces_cli
from mlflow.mcp.decorator import get_mcp_tool_name

# Environment variable to control which tool categories are enabled.
# Supported values:
#   - "genai": traces, scorers, experiments, and runs tools (default)
#   - "ml": experiments, runs, models, and deployments tools
#   - "all": all available tools
#   - Comma-separated list of categories and/or the presets above, e.g.
#     "traces,scorers,experiments,runs,models,deployments" or "genai,models"
MLFLOW_MCP_TOOLS = os.environ.get("MLFLOW_MCP_TOOLS", "genai")

# Tool category mappings
_GENAI_TOOLS = {"traces", "scorers", "experiments", "runs"}
_ML_TOOLS = {"models", "deployments", "experiments", "runs"}
_ALL_TOOLS = _GENAI_TOOLS | _ML_TOOLS

# Preset names accepted in MLFLOW_MCP_TOOLS, mapped to the categories they enable.
# "all" is handled separately in `_is_tool_enabled` (it enables every category).
_TOOL_PRESETS = {
    "genai": _GENAI_TOOLS,
    "ml": _ML_TOOLS,
}

# Click's built-in parameter-type singletons and their JSON schema equivalents.
_CLICK_TYPE_TO_JSON_TYPE: tuple[tuple[click.ParamType, str], ...] = (
    (STRING, "string"),
    (BOOL, "boolean"),
    (INT, "integer"),
    (FLOAT, "number"),
    (UUID, "string"),
)

# In click >= 8.3.0 a parameter without an explicit default has `Sentinel.UNSET` as its
# default instead of None. Older click versions have no such attribute, so fall back to
# a fresh object that nothing can be identical to.
# See https://github.com/pallets/click/pull/3030 for more details.
_CLICK_UNSET = getattr(click.core, "UNSET", object())

if TYPE_CHECKING:
    from fastmcp import FastMCP
    from fastmcp.tools import FunctionTool


def _is_unset(value: Any) -> bool:
    """Return True if *value* is click's "no default provided" sentinel."""
    # The repr comparison is a fallback in case the sentinel in use is not the object
    # exported from `click.core`.
    return value is _CLICK_UNSET or repr(value) == "Sentinel.UNSET"


def _is_array_param(param: click.Parameter) -> bool:
    """Return True if *param* accepts several values (``multiple=True`` or ``nargs=-1``)."""
    return param.multiple or param.nargs == -1


def param_type_to_json_schema_type(pt: click.ParamType) -> str:
    """
    Converts a Click ParamType to a JSON schema type.

    Args:
        pt: The click parameter type.

    Returns:
        "boolean", "integer" or "number" for click's bool/int/float types (including
        subclasses such as ``click.IntRange`` and ``click.FloatRange``), otherwise
        "string".
    """
    for click_type, json_type in _CLICK_TYPE_TO_JSON_TYPE:
        # isinstance rather than identity so that range types (which are separate
        # instances of IntParamType/FloatParamType subclasses) get a numeric schema.
        if isinstance(pt, type(click_type)):
            return json_type
    # Choice, Path, and any other type are exchanged as strings.
    return "string"


def _has_schema_default(p: click.Parameter, is_array_param: bool) -> bool:
    """Return True if *p*'s default should be advertised in its JSON schema."""
    default = p.default
    if default is None or _is_unset(default):
        return False
    # click allows callable defaults that are evaluated lazily; a function object is
    # not a JSON value, so it cannot be put into the schema.
    if callable(default):
        return False
    # A required array param must receive at least one item, so it has no default.
    return not (is_array_param and p.required)


def get_input_schema(params: list[click.Parameter]) -> dict[str, Any]:
    """
    Converts click params to JSON schema.

    Args:
        params: The parameters of a click command.

    Returns:
        A JSON schema object with one property per parameter (keyed by ``p.name``)
        and the names of required parameters under "required". Multi-value
        parameters become arrays of their item type.
    """
    properties: dict[str, Any] = {}
    required: list[str] = []
    for p in params:
        is_array_param = _is_array_param(p)
        item_schema: dict[str, Any] = {"type": param_type_to_json_schema_type(p.type)}
        if isinstance(p.type, click.Choice):
            item_schema["enum"] = [str(choice) for choice in p.type.choices]

        schema = {"type": "array", "items": item_schema} if is_array_param else item_schema
        if _has_schema_default(p, is_array_param):
            schema["default"] = list(p.default) if is_array_param else p.default
        if isinstance(p, click.Option):
            schema["description"] = (p.help or "").strip()
        if p.required:
            required.append(p.name)
            if is_array_param:
                schema["minItems"] = 1
        properties[p.name] = schema

    return {
        "type": "object",
        "properties": properties,
        "required": required,
    }


def _resolve_default(param: click.Parameter) -> Any:
    """Return the value passed to the callback when the MCP client omitted *param*."""
    default = param.default
    if callable(default):
        # Mirror click, which calls callable defaults to obtain the value.
        default = default()
    if _is_array_param(param):
        return () if default is None or _is_unset(default) else tuple(default)
    return None if _is_unset(default) else default


def fn_wrapper(command: click.Command) -> Callable[..., str]:
    """
    Build a tool function that runs *command*'s callback with MCP-supplied arguments.

    The returned function fills in defaults for omitted parameters, converts JSON
    arrays into the tuples click callbacks expect, and returns everything the
    callback wrote to stdout/stderr, stripped of surrounding whitespace.
    """

    def wrapper(**kwargs: Any) -> str:
        for param in command.params:
            if param.name not in kwargs:
                # Fill in defaults for missing optional arguments.
                kwargs[param.name] = _resolve_default(param)
            elif _is_array_param(param) and isinstance(kwargs[param.name], list):
                # Convert array parameters to the types expected by the callback.
                kwargs[param.name] = tuple(
                    param.type.convert(value, param, None) for value in kwargs[param.name]
                )

        # Capture stdout and stderr so the command's printed output is the tool result.
        # NOTE(review): redirect_stdout/redirect_stderr swap process-global streams, so
        # concurrent tool calls in threads would interleave output — confirm the server
        # runs tools one at a time.
        string_io = io.StringIO()
        with contextlib.redirect_stdout(string_io), contextlib.redirect_stderr(string_io):
            command.callback(**kwargs)  # type: ignore[misc]
        return string_io.getvalue().strip()

    return wrapper


def cmd_to_function_tool(cmd: click.Command) -> "FunctionTool | None":
    """
    Converts a Click command to a FunctionTool.

    Args:
        cmd: The Click command to convert.

    Returns:
        FunctionTool if the command has been decorated with @mlflow_mcp,
        None if the command should be skipped (not decorated for MCP exposure).
    """
    from fastmcp.tools import FunctionTool

    # Get the MCP tool name from the decorator
    tool_name = get_mcp_tool_name(cmd)

    # Skip commands that don't have the @mlflow_mcp decorator.
    # This allows us to curate which commands are exposed as MCP tools.
    if tool_name is None:
        return None

    return FunctionTool(
        fn=fn_wrapper(cmd),
        name=tool_name,
        description=(cmd.help or "").strip(),
        parameters=get_input_schema(cmd.params),
    )


def register_prompts(mcp: "FastMCP") -> None:
    """Register every MLflow AI command as an MCP prompt on *mcp*."""
    from mlflow.telemetry.events import AiCommandRunEvent
    from mlflow.telemetry.track import _record_event

    def register(cmd_key: str, name: str, description: str) -> None:
        # A separate function gives each prompt its own `cmd_key` binding.
        @mcp.prompt(name=name, description=description)
        def ai_command_prompt() -> str:
            """Execute an MLflow AI command prompt."""
            _record_event(AiCommandRunEvent, {"command_key": cmd_key, "context": "mcp"})
            return get_command_body(cmd_key)

    for command in list_commands():
        key = command["key"]
        # Convert slash-separated keys to underscores for MCP names.
        register(key, key.replace("/", "_"), command["description"])


def _is_tool_enabled(category: str) -> bool:
    """
    Check if a tool category is enabled based on the MLFLOW_MCP_TOOLS env var.

    The value is a comma-separated, case-insensitive list whose entries are category
    names or the presets "genai", "ml" and "all".
    """
    requested = {token.strip() for token in MLFLOW_MCP_TOOLS.lower().split(",")}
    if "all" in requested:
        return True

    enabled: set[str] = set()
    for token in requested:
        # Presets expand to their categories; anything else names a single category.
        enabled |= _TOOL_PRESETS.get(token, {token})
    return category.lower() in enabled


def _collect_tools(commands: dict[str, click.Command]) -> list["FunctionTool"]:
    """Collect MCP tools from commands, filtering out undecorated commands."""
    return [tool for cmd in commands.values() if (tool := cmd_to_function_tool(cmd)) is not None]


def create_mcp() -> "FastMCP":
    """Build the MLflow FastMCP server with the enabled tools and the AI-command prompts."""
    from fastmcp import FastMCP

    # (category, click commands) in registration order.
    # genai: traces, scorers, experiments, runs; ml: experiments, runs, models, deployments.
    tool_sources: list[tuple[str, dict[str, click.Command]]] = [
        ("traces", traces_cli.commands),
        ("scorers", scorers_cli.commands),
        ("experiments", mlflow.experiments.commands.commands),
        ("runs", mlflow.runs.commands.commands),
        ("models", models_cli.commands.commands),
        ("deployments", deployments_cli.commands.commands),
    ]

    tools: list["FunctionTool"] = []
    for category, commands in tool_sources:
        if _is_tool_enabled(category):
            tools.extend(_collect_tools(commands))

    mcp = FastMCP(
        name="Mlflow MCP",
        tools=tools,
    )

    register_prompts(mcp)
    return mcp


def run_server() -> None:
    """Create the MCP server and run it without the startup banner."""
    mcp = create_mcp()
    mcp.run(show_banner=False)


if __name__ == "__main__":
    run_server()