/ mlflow / mcp / server.py
server.py
  1  import contextlib
  2  import io
  3  import os
  4  from typing import TYPE_CHECKING, Any, Callable
  5  
  6  import click
  7  from click.types import BOOL, FLOAT, INT, STRING, UUID
  8  
  9  import mlflow.deployments.cli as deployments_cli
 10  import mlflow.experiments
 11  import mlflow.models.cli as models_cli
 12  import mlflow.runs
 13  from mlflow.ai_commands.ai_command_utils import get_command_body, list_commands
 14  from mlflow.cli.scorers import commands as scorers_cli
 15  from mlflow.cli.traces import commands as traces_cli
 16  from mlflow.mcp.decorator import get_mcp_tool_name
 17  
 18  # Environment variable to control which tool categories are enabled
 19  # Supported values:
 20  #   - "genai": traces, scorers, experiments, and runs tools (default)
 21  #   - "ml": experiments, runs, models and deployments tools
 22  #   - "all": all available tools
 23  #   - Comma-separated list: "traces,scorers,experiments,runs,models,deployments"
 24  MLFLOW_MCP_TOOLS = os.environ.get("MLFLOW_MCP_TOOLS", "genai")
 25  
 26  # Tool category mappings
 27  _GENAI_TOOLS = {"traces", "scorers", "experiments", "runs"}
 28  _ML_TOOLS = {"models", "deployments", "experiments", "runs"}
 29  _ALL_TOOLS = _GENAI_TOOLS | _ML_TOOLS
 30  
 31  if TYPE_CHECKING:
 32      from fastmcp import FastMCP
 33      from fastmcp.tools import FunctionTool
 34  
 35  
 36  def param_type_to_json_schema_type(pt: click.ParamType) -> str:
 37      """
 38      Converts a Click ParamType to a JSON schema type.
 39      """
 40      if pt is STRING:
 41          return "string"
 42      if pt is BOOL:
 43          return "boolean"
 44      if pt is INT:
 45          return "integer"
 46      if pt is FLOAT:
 47          return "number"
 48      if pt is UUID:
 49          return "string"
 50      return "string"
 51  
 52  
 53  def get_input_schema(params: list[click.Parameter]) -> dict[str, Any]:
 54      """
 55      Converts click params to JSON schema
 56      """
 57      properties: dict[str, Any] = {}
 58      required: list[str] = []
 59      for p in params:
 60          is_array_param = p.multiple or p.nargs == -1
 61          item_schema = {"type": param_type_to_json_schema_type(p.type)}
 62          if isinstance(p.type, click.Choice):
 63              item_schema["enum"] = [str(choice) for choice in p.type.choices]
 64  
 65          schema = {"type": "array", "items": item_schema} if is_array_param else item_schema
 66          if (
 67              p.default is not None
 68              and (
 69                  # In click >= 8.3.0, the default value is set to `Sentinel.UNSET` when no default is
 70                  # provided. Skip setting the default in this case.
 71                  # See https://github.com/pallets/click/pull/3030 for more details.
 72                  not isinstance(p.default, str) and repr(p.default) != "Sentinel.UNSET"
 73              )
 74              and not (is_array_param and p.required)
 75          ):
 76              schema["default"] = list(p.default) if is_array_param else p.default
 77          if isinstance(p, click.Option):
 78              schema["description"] = (p.help or "").strip()
 79          if p.required:
 80              required.append(p.name)
 81              if is_array_param:
 82                  schema["minItems"] = 1
 83          properties[p.name] = schema
 84  
 85      return {
 86          "type": "object",
 87          "properties": properties,
 88          "required": required,
 89      }
 90  
 91  
 92  def fn_wrapper(command: click.Command) -> Callable[..., str]:
 93      def wrapper(**kwargs: Any) -> str:
 94          click_unset = getattr(click.core, "UNSET", object())
 95  
 96          # Capture stdout and stderr
 97          string_io = io.StringIO()
 98          with (
 99              contextlib.redirect_stdout(string_io),
100              contextlib.redirect_stderr(string_io),
101          ):
102              # Fill in defaults for missing optional arguments
103              for param in command.params:
104                  if param.name not in kwargs:
105                      if param.multiple or param.nargs == -1:
106                          if param.default in (None, click_unset):
107                              kwargs[param.name] = ()
108                          else:
109                              kwargs[param.name] = tuple(param.default)
110                      elif param.default is click_unset:
111                          kwargs[param.name] = None
112                      else:
113                          kwargs[param.name] = param.default
114  
115              # Convert array parameters to the types expected by each command's callback
116              for param in command.params:
117                  if (
118                      param.name in kwargs
119                      and (param.multiple or param.nargs == -1)
120                      and isinstance(kwargs[param.name], list)
121                  ):
122                      kwargs[param.name] = tuple(
123                          param.type.convert(value, param, None) for value in kwargs[param.name]
124                      )
125  
126              command.callback(**kwargs)  # type: ignore[misc]
127          return string_io.getvalue().strip()
128  
129      return wrapper
130  
131  
def cmd_to_function_tool(cmd: click.Command) -> "FunctionTool | None":
    """
    Build an MCP ``FunctionTool`` from a Click command.

    Only commands decorated with ``@mlflow_mcp`` are exposed, which keeps the set of MCP tools
    curated. Any other command is skipped.

    Args:
        cmd: The Click command to expose.

    Returns:
        A ``FunctionTool`` wrapping ``cmd``, or ``None`` when ``cmd`` has no MCP tool name
        (i.e. it was not decorated for MCP exposure).
    """
    from fastmcp.tools import FunctionTool

    name = get_mcp_tool_name(cmd)
    if name is None:
        # Not opted in via @mlflow_mcp.
        return None

    wrapped = fn_wrapper(cmd)
    description = (cmd.help or "").strip()
    schema = get_input_schema(cmd.params)
    return FunctionTool(fn=wrapped, name=name, description=description, parameters=schema)
159  
160  
def register_prompts(mcp: "FastMCP") -> None:
    """
    Register every MLflow AI command as an MCP prompt on ``mcp``.

    Prompt names are the command keys with ``/`` replaced by ``_``. When a prompt function is
    invoked, it records an ``AiCommandRunEvent`` telemetry event and returns the command body.
    """
    from mlflow.telemetry.events import AiCommandRunEvent
    from mlflow.telemetry.track import _record_event

    def _register(key: str, name: str, description: str) -> None:
        # The key, name and description are passed in explicitly, so each prompt is bound to
        # its own command rather than to the loop variables below.
        @mcp.prompt(name=name, description=description)
        def ai_command_prompt() -> str:
            """Execute an MLflow AI command prompt."""
            _record_event(AiCommandRunEvent, {"command_key": key, "context": "mcp"})
            return get_command_body(key)

    for command in list_commands():
        key = command["key"]
        # MCP prompt names use underscores where command keys use slashes.
        _register(key, key.replace("/", "_"), command["description"])
183  
def _is_tool_enabled(category: str) -> bool:
    """
    Check if a tool category is enabled based on the MLFLOW_MCP_TOOLS env var.

    The setting is a comma-separated list of entries, compared case-insensitively and with
    surrounding whitespace ignored. Each entry is either a category name (e.g. "traces",
    "models") or one of the presets:

    - "genai": every category in ``_GENAI_TOOLS``
    - "ml": every category in ``_ML_TOOLS``
    - "all": every category

    Presets may be combined with each other and with individual categories, e.g.
    "genai,models". A single preset or a plain list of categories behaves as before.

    Args:
        category: The tool category name to check.

    Returns:
        True if ``category`` is enabled by the current configuration.
    """
    presets = {"genai": _GENAI_TOOLS, "ml": _ML_TOOLS}
    wanted = category.lower()
    for entry in MLFLOW_MCP_TOOLS.lower().split(","):
        entry = entry.strip()
        if entry == "all":
            return True
        # A preset expands to its categories; any other entry names a single category.
        if wanted in presets.get(entry, (entry,)):
            return True
    return False
200  
def _collect_tools(commands: dict[str, click.Command]) -> list["FunctionTool"]:
    """
    Turn each command in ``commands`` into an MCP tool, preserving iteration order.

    Commands for which ``cmd_to_function_tool`` returns ``None`` (those not decorated for MCP
    exposure) are left out.
    """
    converted = map(cmd_to_function_tool, commands.values())
    return [tool for tool in converted if tool is not None]
209  
210  
def create_mcp() -> "FastMCP":
    """
    Build the MLflow MCP server.

    Tools are collected from every CLI command group whose category is enabled by
    ``MLFLOW_MCP_TOOLS`` (checked with ``_is_tool_enabled``). The MLflow AI commands are then
    registered as prompts.

    Returns:
        The configured ``FastMCP`` instance.
    """
    from fastmcp import FastMCP

    # (category, click commands), in the order their tools are added to the server.
    command_groups: list[tuple[str, dict[str, click.Command]]] = [
        ("traces", traces_cli.commands),  # genai
        ("scorers", scorers_cli.commands),  # genai
        ("experiments", mlflow.experiments.commands.commands),  # genai and ml
        ("runs", mlflow.runs.commands.commands),  # genai and ml
        ("models", models_cli.commands.commands),  # ml
        ("deployments", deployments_cli.commands.commands),  # ml
    ]

    tools: list["FunctionTool"] = []
    for category, group_commands in command_groups:
        if _is_tool_enabled(category):
            tools.extend(_collect_tools(group_commands))

    server = FastMCP(name="Mlflow MCP", tools=tools)
    register_prompts(server)
    return server
247  
248  
def run_server() -> None:
    """Build the MLflow MCP server and run it with ``show_banner=False``."""
    create_mcp().run(show_banner=False)


# Allow launching the server directly, e.g. `python -m mlflow.mcp.server`.
if __name__ == "__main__":
    run_server()