# model.py
1 """ 2 The ``mlflow.pyfunc.model`` module defines logic for saving and loading custom "python_function" 3 models with a user-defined ``PythonModel`` subclass. 4 """ 5 6 import bz2 7 import gzip 8 import inspect 9 import logging 10 import lzma 11 import os 12 import shutil 13 from abc import ABCMeta, abstractmethod 14 from collections.abc import Sequence 15 from pathlib import Path 16 from typing import Any, Generator, Iterator 17 18 import cloudpickle 19 import pandas as pd 20 import yaml 21 22 import mlflow.pyfunc 23 from mlflow.entities.span import SpanType 24 from mlflow.environment_variables import ( 25 MLFLOW_ALLOW_PICKLE_DESERIALIZATION, 26 MLFLOW_LOG_MODEL_COMPRESSION, 27 MLFLOW_UV_AUTO_DETECT, 28 ) 29 from mlflow.exceptions import MlflowException 30 from mlflow.models import Model 31 from mlflow.models.model import MLMODEL_FILE_NAME, MODEL_CODE_PATH 32 from mlflow.models.rag_signatures import ChatCompletionRequest, SplitChatMessagesRequest 33 from mlflow.models.signature import ( 34 _extract_type_hints, 35 _is_context_in_predict_function_signature, 36 _TypeHints, 37 ) 38 from mlflow.models.utils import _load_model_code_path 39 from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE 40 from mlflow.pyfunc.utils import pyfunc 41 from mlflow.pyfunc.utils.data_validation import ( 42 _check_func_signature, 43 _get_func_info_if_type_hint_supported, 44 _wrap_predict_with_pyfunc, 45 wrap_non_list_predict_pydantic, 46 ) 47 from mlflow.pyfunc.utils.input_converter import _hydrate_dataclass 48 from mlflow.tracking.artifact_utils import _download_artifact_from_uri 49 from mlflow.types.agent import ( 50 ChatAgentChunk, 51 ChatAgentMessage, 52 ChatAgentRequest, 53 ChatAgentResponse, 54 ChatContext, 55 ) 56 from mlflow.types.llm import ( 57 ChatCompletionChunk, 58 ChatCompletionResponse, 59 ChatMessage, 60 ChatParams, 61 ) 62 from mlflow.types.responses import ( 63 Message, 64 OutputItem, 65 ResponsesAgentRequest, 66 ResponsesAgentResponse, 67 ResponsesAgentStreamEvent, 68 create_annotation_added, 69 create_function_call_item, 70 create_function_call_output_item, 71 create_reasoning_item, 72 create_text_delta, 73 create_text_output_item, 74 output_to_responses_items_stream, 75 responses_agent_output_reducer, 76 responses_to_cc, 77 to_chat_completions_input, 78 ) 79 from mlflow.types.utils import _is_list_dict_str, _is_list_str 80 from mlflow.utils.annotations import deprecated 81 from mlflow.utils.databricks_utils import ( 82 _get_databricks_serverless_env_vars, 83 is_in_databricks_model_serving_environment, 84 is_in_databricks_runtime, 85 is_in_databricks_serverless_runtime, 86 ) 87 from mlflow.utils.environment import ( 88 _CONDA_ENV_FILE_NAME, 89 _CONSTRAINTS_FILE_NAME, 90 _PYTHON_ENV_FILE_NAME, 91 _REQUIREMENTS_FILE_NAME, 92 _mlflow_conda_env, 93 _process_conda_env, 94 _process_pip_requirements, 95 _PythonEnv, 96 ) 97 from mlflow.utils.file_utils import TempDir, get_total_file_size, write_to 98 from mlflow.utils.model_utils import _get_flavor_configuration, _validate_infer_and_copy_code_paths 99 from mlflow.utils.requirements_utils import _get_pinned_requirement 100 from mlflow.utils.uv_utils import copy_uv_project_files 101 102 CONFIG_KEY_ARTIFACTS = "artifacts" 103 CONFIG_KEY_ARTIFACT_RELATIVE_PATH = "path" 104 CONFIG_KEY_ARTIFACT_URI = "uri" 105 CONFIG_KEY_PYTHON_MODEL = "python_model" 106 CONFIG_KEY_CLOUDPICKLE_VERSION = "cloudpickle_version" 107 CONFIG_KEY_COMPRESSION = "python_model_compression" 108 _SAVED_PYTHON_MODEL_SUBPATH = "python_model.pkl" 109 
_DEFAULT_CHAT_MODEL_METADATA_TASK = "agent/v1/chat"
_DEFAULT_CHAT_AGENT_METADATA_TASK = "agent/v2/chat"
_COMPRESSION_INFO = {
    "lzma": {"ext": ".xz", "open": lzma.open},
    "bzip2": {"ext": ".bz2", "open": bz2.open},
    "gzip": {"ext": ".gz", "open": gzip.open},
}
_DEFAULT_RESPONSES_AGENT_METADATA_TASK = "agent/v1/responses"

_logger = logging.getLogger(__name__)


def get_default_pip_requirements():
    """
    Returns:
        A list of default pip requirements for MLflow Models produced by this flavor. Calls to
        :func:`save_model()` and :func:`log_model()` produce a pip environment that, at minimum,
        contains these requirements.
    """
    return [_get_pinned_requirement("cloudpickle")]


def get_default_conda_env():
    """
    Returns:
        The default Conda environment for MLflow Models produced by calls to
        :func:`save_model() <mlflow.pyfunc.save_model>`
        and :func:`log_model() <mlflow.pyfunc.log_model>` when a user-defined subclass of
        :class:`PythonModel` is provided.
    """
    return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements())


def _log_warning_if_params_not_in_predict_signature(logger, params):
    if params:
        logger.warning(
            "The underlying model does not support passing additional parameters to the predict"
            f" function. `params` {params} will be ignored."
        )


class PythonModel:
    """
    Represents a generic Python model that evaluates inputs and produces API-compatible outputs.
    By subclassing :class:`~PythonModel`, users can create customized MLflow models with the
    "python_function" ("pyfunc") flavor, leveraging custom inference logic and artifact
    dependencies.
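
    For example, a minimal subclass might look like the following (illustrative sketch;
    ``UpperCaseModel`` is a hypothetical example, not part of MLflow):

    .. code-block:: python

        import mlflow


        class UpperCaseModel(mlflow.pyfunc.PythonModel):
            def predict(self, model_input: list[str], params=None) -> list[str]:
                return [text.upper() for text in model_input]


        with mlflow.start_run():
            model_info = mlflow.pyfunc.log_model(
                name="upper_case_model", python_model=UpperCaseModel()
            )

        loaded = mlflow.pyfunc.load_model(model_info.model_uri)
        loaded.predict(["hello"])  # -> ["HELLO"]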
    """

    __metaclass__ = ABCMeta

    def load_context(self, context):
        """
        Loads artifacts from the specified :class:`~PythonModelContext` that can be used by
        :func:`~PythonModel.predict` when evaluating inputs. When loading an MLflow model with
        :func:`~load_model`, this method is called as soon as the :class:`~PythonModel` is
        constructed.

        The same :class:`~PythonModelContext` will also be available during calls to
        :func:`~PythonModel.predict`, but it may be more efficient to override this method
        and load artifacts from the context at model load time.

        Args:
            context: A :class:`~PythonModelContext` instance containing artifacts that the model
                can use to perform inference.
        """

    @deprecated("predict_type_hints", "2.20.0")
    def _get_type_hints(self):
        return self.predict_type_hints

    @property
    def predict_type_hints(self) -> _TypeHints:
        """
        Internal method to get type hints from the predict function signature.
        """
        if hasattr(self, "_predict_type_hints"):
            return self._predict_type_hints
        if _is_context_in_predict_function_signature(func=self.predict):
            self._predict_type_hints = _extract_type_hints(self.predict, input_arg_index=1)
        else:
            self._predict_type_hints = _extract_type_hints(self.predict, input_arg_index=0)
        return self._predict_type_hints

    def __init_subclass__(cls, **kwargs) -> None:
        super().__init_subclass__(**kwargs)

        # automatically wrap the predict method with pyfunc to ensure data validation
        # NB: skip wrapping for built-in classes defined in MLflow, e.g. ChatModel
        if not cls.__module__.startswith("mlflow."):
            # TODO: ChatModel uses dataclass type hints, which are not supported now, hence
            # we need to skip type-hint-based validation for user-defined subclasses
            # of ChatModel. Once we either (1) support dataclass type hints or (2) migrate
            # ChatModel to pydantic, we can remove this exclusion.
            # NB: issubclass(cls, ChatModel) does not work, so we use a hacky attribute check
            if getattr(cls, "_skip_type_hint_validation", False):
                return

            predict_attr = cls.__dict__.get("predict")
            if predict_attr is not None and callable(predict_attr):
                func_info = _get_func_info_if_type_hint_supported(predict_attr)
                setattr(cls, "predict", _wrap_predict_with_pyfunc(predict_attr, func_info))
            predict_stream_attr = cls.__dict__.get("predict_stream")
            if predict_stream_attr is not None and callable(predict_stream_attr):
                _check_func_signature(predict_stream_attr, "predict_stream")
        else:
            cls.predict._is_pyfunc = True

    @abstractmethod
    def predict(self, context, model_input, params: dict[str, Any] | None = None):
        """
        Evaluates a pyfunc-compatible input and produces a pyfunc-compatible output.
        For more information about the pyfunc input/output API, see the
        :ref:`pyfunc-inference-api`.

        Args:
            context: A :class:`~PythonModelContext` instance containing artifacts that the model
                can use to perform inference.
            model_input: A pyfunc-compatible input for the model to evaluate.
            params: Additional parameters to pass to the model for inference.

        .. tip::
            Since MLflow 2.20.0, the ``context`` parameter can be removed from the ``predict``
            function signature if it's not used. ``def predict(self, model_input, params=None)``
            is valid.
        """

    def predict_stream(self, context, model_input, params: dict[str, Any] | None = None):
        """
        Evaluates a pyfunc-compatible input and produces an iterator of output.
        For more information about the pyfunc input API, see the :ref:`pyfunc-inference-api`.

        Args:
            context: A :class:`~PythonModelContext` instance containing artifacts that the model
                can use to perform inference.
            model_input: A pyfunc-compatible input for the model to evaluate.
            params: Additional parameters to pass to the model for inference.

        .. tip::
            Since MLflow 2.20.0, the ``context`` parameter can be removed from the
            ``predict_stream`` function signature if it's not used.
            ``def predict_stream(self, model_input, params=None)`` is valid.
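
        For example, a streaming implementation might yield chunks one at a time
        (illustrative sketch; ``StreamingModel`` is hypothetical, not part of MLflow):

        .. code-block:: python

            class StreamingModel(mlflow.pyfunc.PythonModel):
                def predict(self, model_input: list[str], params=None) -> list[str]:
                    return model_input

                def predict_stream(self, model_input: list[str], params=None):
                    # Yield one chunk per input element instead of returning the
                    # full result at once.
                    yield from model_input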
        """
        raise NotImplementedError()


class _FunctionPythonModel(PythonModel):
    """
    When a user specifies a ``python_model`` argument that is a function, we wrap the function
    in an instance of this class.
    """

    def __init__(self, func, signature=None):
        self.signature = signature
        # only wrap `func` if @pyfunc is not already applied
        if not getattr(func, "_is_pyfunc", False):
            self.func = pyfunc(func)
        else:
            self.func = func

    @property
    def predict_type_hints(self):
        if hasattr(self, "_predict_type_hints"):
            return self._predict_type_hints
        self._predict_type_hints = _extract_type_hints(self.func, input_arg_index=0)
        return self._predict_type_hints

    def predict(
        self,
        model_input,
        params: dict[str, Any] | None = None,
    ):
        """
        Args:
            model_input: A pyfunc-compatible input for the model to evaluate.
            params: Additional parameters to pass to the model for inference.

        Returns:
            Model predictions.
        """
        # callable only supports one input argument for now
        return self.func(model_input)


class PythonModelContext:
    """
    A collection of artifacts that a :class:`~PythonModel` can use when performing inference.
    :class:`~PythonModelContext` objects are created *implicitly* by the
    :func:`save_model() <mlflow.pyfunc.save_model>` and
    :func:`log_model() <mlflow.pyfunc.log_model>` persistence methods, using the contents
    specified by the ``artifacts`` parameter of these methods.
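
    For example, a model can read a logged artifact through the context as follows
    (illustrative sketch; the ``"weights"`` artifact name is hypothetical):

    .. code-block:: python

        class ModelWithArtifacts(mlflow.pyfunc.PythonModel):
            def load_context(self, context):
                # "weights" must match a key of the ``artifacts`` dict that was
                # passed to save_model()/log_model().
                with open(context.artifacts["weights"], "rb") as f:
                    self.weights = f.read()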
    """

    def __init__(self, artifacts, model_config):
        """
        Args:
            artifacts: A dictionary of ``<name, artifact_path>`` entries, where ``artifact_path``
                is an absolute filesystem path to a given artifact.
            model_config: The model configuration to make available to the model at
                loading time.
        """
        self._artifacts = artifacts
        self._model_config = model_config

    @property
    def artifacts(self):
        """
        A dictionary containing ``<name, artifact_path>`` entries, where ``artifact_path`` is an
        absolute filesystem path to the artifact.
        """
        return self._artifacts

    @property
    def model_config(self):
        """
        A dictionary containing ``<config, value>`` entries, where ``config`` is the name
        of a model configuration key and ``value`` is the value of that configuration.
        """
        return self._model_config


@deprecated("ResponsesAgent", "3.0.0")
class ChatModel(PythonModel, metaclass=ABCMeta):
    """
    .. tip::
        Since MLflow 3.0.0, we recommend using
        :py:class:`ResponsesAgent <mlflow.pyfunc.ResponsesAgent>`
        instead of :py:class:`ChatModel <mlflow.pyfunc.ChatModel>` unless you need strict
        compatibility with the OpenAI ChatCompletion API.

    A subclass of :class:`~PythonModel` that makes it more convenient to implement models
    that are compatible with popular LLM chat APIs. By subclassing :class:`~ChatModel`,
    users can create MLflow models with a ``predict()`` method that is more convenient
    for chat tasks than the generic :class:`~PythonModel` API. ChatModels automatically
    define input/output signatures and an input example, so manually specifying these values
    when calling :func:`mlflow.pyfunc.save_model() <mlflow.pyfunc.save_model>` is not necessary.

    See the documentation of the ``predict()`` method below for details on the parameters and
    outputs that are expected by the ``ChatModel`` API.

    .. list-table::
        :header-rows: 1
        :widths: 20 40 40

        * -
          - ChatModel
          - PythonModel
        * - When to use
          - Use when you want to develop and deploy a conversational model with a **standard**
            chat schema compatible with the OpenAI spec.
          - Use when you want **full control** over the model's interface or to customize every
            aspect of your model's behavior.
        * - Interface
          - **Fixed** to OpenAI's chat schema.
          - **Full control** over the model's input and output schema.
        * - Setup
          - **Quick**. Works out of the box for conversational applications, with a pre-defined
            model signature and input example.
          - **Custom**. You need to define the model signature or input example yourself.
        * - Complexity
          - **Low**. The standardized interface simplifies model deployment and integration.
          - **High**. Deploying and integrating the custom PythonModel may not be
            straightforward. E.g., the model needs to handle Pandas DataFrames, as MLflow
            converts input data to DataFrames before passing it to PythonModel.
    """

    _skip_type_hint_validation = True

    @abstractmethod
    def predict(
        self, context, messages: list[ChatMessage], params: ChatParams
    ) -> ChatCompletionResponse:
        """
        Evaluates a chat input and produces a chat output.

        Args:
            context: A :class:`~PythonModelContext` instance containing artifacts that the model
                can use to perform inference.
            messages (List[:py:class:`ChatMessage <mlflow.types.llm.ChatMessage>`]):
                A list of :py:class:`ChatMessage <mlflow.types.llm.ChatMessage>`
                objects representing chat history.
            params (:py:class:`ChatParams <mlflow.types.llm.ChatParams>`):
                A :py:class:`ChatParams <mlflow.types.llm.ChatParams>` object
                containing various parameters used to modify model behavior during
                inference.

        .. tip::
            Since MLflow 2.20.0, the ``context`` parameter can be removed from the ``predict``
            function signature if it's not used.
            ``def predict(self, messages: list[ChatMessage], params: ChatParams)`` is valid.

        Returns:
            A :py:class:`ChatCompletionResponse <mlflow.types.llm.ChatCompletionResponse>`
            object containing the model's response(s), as well as other metadata.
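
        For example, a minimal implementation might look like this (illustrative
        sketch; ``EchoChatModel`` and its echo logic are hypothetical):

        .. code-block:: python

            class EchoChatModel(mlflow.pyfunc.ChatModel):
                def predict(
                    self, messages: list[ChatMessage], params: ChatParams
                ) -> ChatCompletionResponse:
                    # Echo the last user message back as the assistant response.
                    last = messages[-1].content
                    return ChatCompletionResponse(
                        choices=[
                            {"index": 0, "message": {"role": "assistant", "content": last}}
                        ],
                    )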
        """

    def predict_stream(
        self, context, messages: list[ChatMessage], params: ChatParams
    ) -> Generator[ChatCompletionChunk, None, None]:
        """
        Evaluates a chat input and produces a chat output.
        Override this function to implement true streaming prediction.

        Args:
            context: A :class:`~PythonModelContext` instance containing artifacts that the model
                can use to perform inference.
            messages (List[:py:class:`ChatMessage <mlflow.types.llm.ChatMessage>`]):
                A list of :py:class:`ChatMessage <mlflow.types.llm.ChatMessage>`
                objects representing chat history.
            params (:py:class:`ChatParams <mlflow.types.llm.ChatParams>`):
                A :py:class:`ChatParams <mlflow.types.llm.ChatParams>` object
                containing various parameters used to modify model behavior during
                inference.

        .. tip::
            Since MLflow 2.20.0, the ``context`` parameter can be removed from the
            ``predict_stream`` function signature if it's not used.
            ``def predict_stream(self, messages: list[ChatMessage], params: ChatParams)``
            is valid.

        Returns:
            A generator over
            :py:class:`ChatCompletionChunk <mlflow.types.llm.ChatCompletionChunk>`
            objects containing the model's response(s), as well as other metadata.
        """
        raise NotImplementedError(
            "Streaming implementation not provided. Please override the "
            "`predict_stream` method on your model to generate streaming "
            "predictions"
        )


class ChatAgent(PythonModel, metaclass=ABCMeta):
    """
    .. tip::
        Since MLflow 3.0.0, we recommend using
        :py:class:`ResponsesAgent <mlflow.pyfunc.ResponsesAgent>`
        instead of :py:class:`ChatAgent <mlflow.pyfunc.ChatAgent>`.

    **What is the ChatAgent Interface?**

    The ChatAgent interface is a chat schema specification that has been designed for authoring
    conversational agents. ChatAgent allows your agent to do the following:

    - Return multiple messages
    - Return intermediate steps for tool-calling agents
    - Confirm tool calls
    - Support multi-agent scenarios

    ``ChatAgent`` should always be used when authoring an agent. We also recommend using
    ``ChatAgent`` instead of :py:class:`ChatModel <mlflow.pyfunc.ChatModel>` even for use cases
    like simple chat models (e.g. prompt-engineered LLMs), to give you the flexibility to support
    more agentic functionality in the future.

    The :py:class:`ChatAgentRequest <mlflow.types.agent.ChatAgentRequest>` schema is similar to,
    but not strictly compatible with, the OpenAI ChatCompletion schema. ChatAgent adds additional
    functionality and diverges from the OpenAI
    :py:class:`ChatCompletionRequest <mlflow.types.llm.ChatCompletionRequest>` in the following
    ways:

    - Adds an optional ``attachments`` attribute to every input/output message for tools and
      internal agent calls, so they can return additional outputs such as visualizations and
      progress indicators
    - Adds a ``context`` attribute with ``conversation_id`` and ``user_id`` attributes to enable
      modifying the behavior of the agent depending on the user querying the agent
    - Adds the ``custom_inputs`` attribute, an arbitrary ``dict[str, Any]``, to pass in any
      additional information to modify the agent's behavior

    The :py:class:`ChatAgentResponse <mlflow.types.agent.ChatAgentResponse>` schema diverges from
    the :py:class:`ChatCompletionResponse <mlflow.types.llm.ChatCompletionResponse>` schema in
    the following ways:

    - Adds the ``custom_outputs`` key, an arbitrary ``dict[str, Any]``, to return any additional
      information
    - Allows multiple messages in the output, to improve the display and evaluation of internal
      tool calls and inter-agent communication that led to the final answer.

    Here's an example of a :py:class:`ChatAgentResponse <mlflow.types.agent.ChatAgentResponse>`
    detailing a tool call:

    .. code-block:: python

        {
            "messages": [
                {
                    "role": "assistant",
                    "content": "",
                    "id": "run-04b46401-c569-4a4a-933e-62e38d8f9647-0",
                    "tool_calls": [
                        {
                            "id": "call_15ca4fcc-ffa1-419a-8748-3bea34b9c043",
                            "type": "function",
                            "function": {
                                "name": "generate_random_ints",
                                "arguments": '{"min": 1, "max": 100, "size": 5}',
                            },
                        }
                    ],
                },
                {
                    "role": "tool",
                    "content": '{"content": "Generated array of 5 random ints in [1, 100]."}',
                    "name": "generate_random_ints",
                    "id": "call_15ca4fcc-ffa1-419a-8748-3bea34b9c043",
                    "tool_call_id": "call_15ca4fcc-ffa1-419a-8748-3bea34b9c043",
                },
                {
                    "role": "assistant",
                    "content": "The new set of generated random numbers are: 93, 51, 12, 7, and 25",
                    "name": "llm",
                    "id": "run-70c7c738-739f-4ecd-ad18-0ae232df24e8-0",
                },
            ],
            "custom_outputs": {"random_nums": [93, 51, 12, 7, 25]},
        }

    **Streaming Agent Output with ChatAgent**

    Please read the docstring of
    :py:func:`ChatAgent.predict_stream <mlflow.pyfunc.ChatAgent.predict_stream>`
    for more details on how to stream the output of your agent.

    **Authoring a ChatAgent**

    Authoring an agent using the ChatAgent interface is a framework-agnostic way to create a
    model with a standardized interface that is loggable with the MLflow pyfunc flavor, can be
    reused across clients, and is ready for serving workloads.

    To write your own agent, subclass ``ChatAgent``, implementing the ``predict`` and optionally
    ``predict_stream`` methods to define the non-streaming and streaming behavior of your agent.
    You can use any agent authoring framework - the only hard requirement is to implement the
    ``predict`` interface.

    .. code-block:: python

        def predict(
            self,
            messages: list[ChatAgentMessage],
            context: Optional[ChatContext] = None,
            custom_inputs: Optional[dict[str, Any]] = None,
        ) -> ChatAgentResponse: ...

    In addition to calling the predict and predict_stream methods with an input matching their
    type hints, you can also pass a single input dict that matches the
    :py:class:`ChatAgentRequest <mlflow.types.agent.ChatAgentRequest>` schema for ease of
    testing.

    .. code-block:: python

        chat_agent = MyChatAgent()
        chat_agent.predict({
            "messages": [{"role": "user", "content": "What is 10 + 10?"}],
            "context": {"conversation_id": "123", "user_id": "456"},
        })

    See an example implementation of ``predict`` and ``predict_stream`` for a LangGraph agent in
    the :py:class:`ChatAgentState <mlflow.langchain.chat_agent_langgraph.ChatAgentState>`
    docstring.

    **Logging the ChatAgent**

    Since the landscape of LLM frameworks is constantly evolving and not every flavor can be
    natively supported by MLflow, we recommend the
    `Models-from-Code <https://mlflow.org/docs/latest/ml/model/models-from-code.html>`_ logging
    approach.

    .. code-block:: python

        with mlflow.start_run():
            logged_agent_info = mlflow.pyfunc.log_model(
                name="agent",
                python_model=os.path.join(os.getcwd(), "agent"),
                # Add serving endpoints, tools, and vector search indexes here
                resources=[],
            )

    After logging the model, you can query the model with a single dictionary with the
    :py:class:`ChatAgentRequest <mlflow.types.agent.ChatAgentRequest>` schema. Under the hood,
    it will be converted into the python objects expected by your ``predict`` and
    ``predict_stream`` methods.

    .. code-block:: python

        loaded_model = mlflow.pyfunc.load_model(tmp_path)
        loaded_model.predict({
            "messages": [{"role": "user", "content": "What is 10 + 10?"}],
            "context": {"conversation_id": "123", "user_id": "456"},
        })

    To make logging ChatAgent models as easy as possible, MLflow has built in the following
    features:

    - Automatic Model Signature Inference

      - You do not need to set a signature when logging a ChatAgent
      - An input and output signature will be automatically set that adheres to the
        :py:class:`ChatAgentRequest <mlflow.types.agent.ChatAgentRequest>` and
        :py:class:`ChatAgentResponse <mlflow.types.agent.ChatAgentResponse>` schemas

    - Metadata

      - ``{"task": "agent/v2/chat"}`` will be automatically appended to any metadata that you
        may pass in when logging the model

    - Input Example

      - Providing an input example is optional; ``mlflow.types.agent.CHAT_AGENT_INPUT_EXAMPLE``
        will be provided by default
      - If you do provide an input example, ensure it's a dict with the
        :py:class:`ChatAgentRequest <mlflow.types.agent.ChatAgentRequest>` schema

      - .. code-block:: python

            input_example = {
                "messages": [{"role": "user", "content": "What is MLflow?"}],
                "context": {"conversation_id": "123", "user_id": "456"},
            }

    **Migrating from ChatModel to ChatAgent**

    To convert an existing ChatModel that takes in
    :py:class:`List[ChatMessage] <mlflow.types.llm.ChatMessage>` and
    :py:class:`ChatParams <mlflow.types.llm.ChatParams>` and outputs a
    :py:class:`ChatCompletionResponse <mlflow.types.llm.ChatCompletionResponse>`, do the
    following:

    - Subclass ``ChatAgent`` instead of ``ChatModel``
    - Move any functionality from your ``ChatModel``'s ``load_context`` implementation into the
      ``__init__`` method of your new ``ChatAgent``.
    - Use ``.model_dump()`` instead of ``.to_dict()`` when converting your model's inputs to
      dictionaries, e.g. ``[msg.model_dump() for msg in messages]`` instead of
      ``[msg.to_dict() for msg in messages]``
    - Return a :py:class:`ChatAgentResponse <mlflow.types.agent.ChatAgentResponse>` instead of a
      :py:class:`ChatCompletionResponse <mlflow.types.llm.ChatCompletionResponse>`

    For example, we can convert the ChatModel from the
    `Chat Model Intro <https://mlflow.org/docs/latest/llms/chat-model-intro/index.html#building-your-first-chatmodel>`_
    to a ChatAgent:

    .. code-block:: python

        class SimpleOllamaModel(ChatModel):
            def __init__(self):
                self.model_name = "llama3.2:1b"
                self.client = None

            def load_context(self, context):
                self.client = ollama.Client()

            def predict(
                self, context, messages: list[ChatMessage], params: ChatParams = None
            ) -> ChatCompletionResponse:
                ollama_messages = [msg.to_dict() for msg in messages]
                response = self.client.chat(model=self.model_name, messages=ollama_messages)
                return ChatCompletionResponse(
                    choices=[{"index": 0, "message": response["message"]}],
                    model=self.model_name,
                )

    .. code-block:: python

        class SimpleOllamaModel(ChatAgent):
            def __init__(self):
                self.model_name = "llama3.2:1b"
                self.client = ollama.Client()

            def predict(
                self,
                messages: list[ChatAgentMessage],
                context: Optional[ChatContext] = None,
                custom_inputs: Optional[dict[str, Any]] = None,
            ) -> ChatAgentResponse:
                ollama_messages = self._convert_messages_to_dict(messages)
                response = self.client.chat(model=self.model_name, messages=ollama_messages)
                return ChatAgentResponse(**{"messages": [response["message"]]})

    **ChatAgent Connectors**

    MLflow provides convenience APIs for wrapping agents written in popular authoring frameworks
    with ChatAgent. See examples for:

    - LangGraph in the
      :py:class:`ChatAgentState <mlflow.langchain.chat_agent_langgraph.ChatAgentState>` docstring
    """

    _skip_type_hint_validation = True

    def __init_subclass__(cls, **kwargs) -> None:
        super().__init_subclass__(**kwargs)
        for attr_name in ("predict", "predict_stream"):
            attr = cls.__dict__.get(attr_name)
            if callable(attr):
                setattr(
                    cls,
                    attr_name,
                    wrap_non_list_predict_pydantic(
                        attr,
                        ChatAgentRequest,
                        "Invalid dictionary input for a ChatAgent. Expected a dictionary with "
                        "the ChatAgentRequest schema.",
                        unpack=True,
                    ),
                )

    def _convert_messages_to_dict(self, messages: list[ChatAgentMessage]):
        return [m.model_dump(exclude_none=True) for m in messages]

    # nb: We use `messages` instead of `model_input` so that the trace generated by default is
    # compatible with mlflow evaluate. We also want `custom_inputs` to be a top-level key for
    # ease of use.
    @abstractmethod
    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: ChatContext | None = None,
        custom_inputs: dict[str, Any] | None = None,
    ) -> ChatAgentResponse:
        """
        Given a ChatAgent input, returns a ChatAgent output. In addition to calling ``predict``
        with an input matching the type hints, you can also pass a single input dict that
        matches the :py:class:`ChatAgentRequest <mlflow.types.agent.ChatAgentRequest>` schema
        for ease of testing.

        .. code-block:: python

            chat_agent = ChatAgent()
            chat_agent.predict({
                "messages": [{"role": "user", "content": "What is 10 + 10?"}],
                "context": {"conversation_id": "123", "user_id": "456"},
            })

        Args:
            messages (List[:py:class:`ChatAgentMessage <mlflow.types.agent.ChatAgentMessage>`]):
                A list of :py:class:`ChatAgentMessage <mlflow.types.agent.ChatAgentMessage>`
                objects representing the chat history.
            context (:py:class:`ChatContext <mlflow.types.agent.ChatContext>`):
                A :py:class:`ChatContext <mlflow.types.agent.ChatContext>` object
                containing conversation_id and user_id. **Optional**. Defaults to None.
            custom_inputs (Dict[str, Any]):
                An optional param to provide arbitrary additional inputs
                to the model. The dictionary values must be JSON-serializable. **Optional**.
                Defaults to None.

        Returns:
            A :py:class:`ChatAgentResponse <mlflow.types.agent.ChatAgentResponse>` object
            containing the model's response, as well as other metadata.
        """

    # nb: We use `messages` instead of `model_input` so that the trace generated by default is
    # compatible with mlflow evaluate. We also want `custom_inputs` to be a top-level key for
    # ease of use.
    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: ChatContext | None = None,
        custom_inputs: dict[str, Any] | None = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        """
        Given a ChatAgent input, returns a generator containing streaming ChatAgent output
        chunks. In addition to calling ``predict_stream`` with an input matching the type hints,
        you can also pass a single input dict that matches the
        :py:class:`ChatAgentRequest <mlflow.types.agent.ChatAgentRequest>`
        schema for ease of testing.

        .. code-block:: python

            chat_agent = ChatAgent()
            for event in chat_agent.predict_stream({
                "messages": [{"role": "user", "content": "What is 10 + 10?"}],
                "context": {"conversation_id": "123", "user_id": "456"},
            }):
                print(event)

        To support streaming the output of your agent, override this method in your subclass of
        ``ChatAgent``. When implementing ``predict_stream``, keep in mind the following
        requirements:

        - Ensure your implementation adheres to the ``predict_stream`` type signature. For
          example, streamed messages must be of the type
          :py:class:`ChatAgentChunk <mlflow.types.agent.ChatAgentChunk>`, where each chunk
          contains partial output from a single response message.
        - At most one chunk in a particular response can contain the ``custom_outputs`` key.
        - Chunks containing partial content of a single response message must have the same
          ``id``. The content field of the message and usage stats of the
          :py:class:`ChatAgentChunk <mlflow.types.agent.ChatAgentChunk>` should be aggregated
          by the consuming client. See the example below.

        .. code-block:: python

            {"delta": {"role": "assistant", "content": "Born", "id": "123"}}
            {"delta": {"role": "assistant", "content": " in", "id": "123"}}
            {"delta": {"role": "assistant", "content": " data", "id": "123"}}

        Args:
            messages (List[:py:class:`ChatAgentMessage <mlflow.types.agent.ChatAgentMessage>`]):
                A list of :py:class:`ChatAgentMessage <mlflow.types.agent.ChatAgentMessage>`
                objects representing the chat history.
            context (:py:class:`ChatContext <mlflow.types.agent.ChatContext>`):
                A :py:class:`ChatContext <mlflow.types.agent.ChatContext>` object
                containing conversation_id and user_id. **Optional**. Defaults to None.
            custom_inputs (Dict[str, Any]):
                An optional param to provide arbitrary additional inputs
                to the model. The dictionary values must be JSON-serializable. **Optional**.
                Defaults to None.

        Returns:
            A generator over :py:class:`ChatAgentChunk <mlflow.types.agent.ChatAgentChunk>`
            objects containing the model's response(s), as well as other metadata.
        """
        raise NotImplementedError(
            "Streaming implementation not provided. Please override the "
            "`predict_stream` method on your model to generate streaming predictions"
        )


def _check_compression_supported(compression):
    if compression in _COMPRESSION_INFO:
        return True
    if compression:
        supported = ", ".join(sorted(_COMPRESSION_INFO))
        mlflow.pyfunc._logger.warning(
            f"Unrecognized compression method '{compression}'. "
            f"Please select one of: {supported}. Falling back to uncompressed storage/loading."
        )
    return False


def _maybe_compress_cloudpickle_dump(python_model, path, compression):
    file_open = _COMPRESSION_INFO.get(compression, {}).get("open", open)
    with file_open(path, "wb") as out:
        cloudpickle.dump(python_model, out)


def _maybe_decompress_cloudpickle_load(path, compression):
    _check_compression_supported(compression)
    file_open = _COMPRESSION_INFO.get(compression, {}).get("open", open)
    with file_open(path, "rb") as f:
        return cloudpickle.load(f)
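

# NB (illustrative note, not part of the original module): compression is opt-in via the
# MLFLOW_LOG_MODEL_COMPRESSION environment variable, which is read in
# _save_model_with_class_artifacts_params below. A usage sketch:
#
#     import os
#
#     os.environ["MLFLOW_LOG_MODEL_COMPRESSION"] = "lzma"
#     # Subsequent save_model()/log_model() calls serialize the PythonModel to
#     # "python_model.pkl.xz" (see _COMPRESSION_INFO), and loading decompresses
#     # it transparently via _maybe_decompress_cloudpickle_load.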
851 """ 852 853 _skip_type_hint_validation = True 854 855 @staticmethod 856 def responses_agent_output_reducer( 857 chunks: list[ResponsesAgentStreamEvent | dict[str, Any]], 858 ): 859 return responses_agent_output_reducer(chunks) 860 861 def __init_subclass__(cls, **kwargs) -> None: 862 super().__init_subclass__(**kwargs) 863 for attr_name in ("predict", "predict_stream"): 864 attr = cls.__dict__.get(attr_name) 865 if callable(attr): 866 # Only apply trace decorator if it is not already traced with mlflow.trace 867 if getattr(attr, "__mlflow_traced__", False): 868 mlflow.pyfunc._logger.warning( 869 f"You have manually traced {attr_name} with @mlflow.trace, but this is " 870 "unnecessary with ResponsesAgent subclasses. You can remove the " 871 "@mlflow.trace decorator and it will be automatically traced." 872 ) 873 traced_attr = attr 874 else: 875 # Apply trace decorator first 876 if attr_name == "predict_stream": 877 traced_attr = mlflow.trace( 878 span_type=SpanType.AGENT, 879 output_reducer=cls.responses_agent_output_reducer, 880 )(attr) 881 else: 882 traced_attr = mlflow.trace(span_type=SpanType.AGENT)(attr) 883 884 # Then wrap with pydantic wrapper 885 wrapped_attr = wrap_non_list_predict_pydantic( 886 traced_attr, 887 ResponsesAgentRequest, 888 "Invalid dictionary input for a ResponsesAgent. " 889 "Expected a dictionary with the ResponsesRequest schema.", 890 ) 891 setattr(cls, attr_name, wrapped_attr) 892 893 @abstractmethod 894 def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse: 895 """ 896 Given a ResponsesAgentRequest, returns a ResponsesAgentResponse. 897 898 You can see example implementations at 899 https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#simple-chat-example 900 and 901 https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#tool-calling-example. 902 """ 903 904 def predict_stream( 905 self, request: ResponsesAgentRequest 906 ) -> Generator[ResponsesAgentStreamEvent, None, None]: 907 """ 908 Given a ResponsesAgentRequest, returns a generator of ResponsesAgentStreamEvent objects. 909 910 See more details at 911 https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#streaming-agent-output. 912 913 You can see example implementations at 914 https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#simple-chat-example 915 and 916 https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#tool-calling-example. 917 """ 918 raise NotImplementedError( 919 "Streaming implementation not provided. Please override the " 920 "`predict_stream` method on your model to generate streaming predictions" 921 ) 922 923 @staticmethod 924 def create_text_delta(delta: str, item_id: str) -> dict[str, Any]: 925 """Helper method to create a dictionary conforming to the text delta schema for 926 streaming. 927 928 Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#streaming-agent-output. 929 """ 930 return create_text_delta(delta, item_id) 931 932 @staticmethod 933 def create_annotation_added( 934 item_id: str, annotation: dict[str, Any], annotation_index: int | None = 0 935 ) -> dict[str, Any]: 936 return create_annotation_added(item_id, annotation, annotation_index) 937 938 @staticmethod 939 def create_text_output_item( 940 text: str, id: str, annotations: list[dict[str, Any]] | None = None 941 ) -> dict[str, Any]: 942 """Helper method to create a dictionary conforming to the text output item schema. 
        """
        raise NotImplementedError(
            "Streaming implementation not provided. Please override the "
            "`predict_stream` method on your model to generate streaming predictions"
        )

    @staticmethod
    def create_text_delta(delta: str, item_id: str) -> dict[str, Any]:
        """Helper method to create a dictionary conforming to the text delta schema for
        streaming.

        Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#streaming-agent-output.
        """
        return create_text_delta(delta, item_id)

    @staticmethod
    def create_annotation_added(
        item_id: str, annotation: dict[str, Any], annotation_index: int | None = 0
    ) -> dict[str, Any]:
        return create_annotation_added(item_id, annotation, annotation_index)

    @staticmethod
    def create_text_output_item(
        text: str, id: str, annotations: list[dict[str, Any]] | None = None
    ) -> dict[str, Any]:
        """Helper method to create a dictionary conforming to the text output item schema.

        Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output.

        Args:
            text (str): The text to be output.
            id (str): The id of the output item.
            annotations (Optional[list[dict]]): The annotations of the output item.
        """
        return create_text_output_item(text, id, annotations)

    @staticmethod
    def create_reasoning_item(id: str, reasoning_text: str) -> dict[str, Any]:
        """Helper method to create a dictionary conforming to the reasoning item schema.

        Read more at https://www.mlflow.org/docs/latest/llms/responses-agent-intro/#creating-agent-output.
        """
        return create_reasoning_item(id, reasoning_text)

    @staticmethod
    def create_function_call_item(
        id: str, call_id: str, name: str, arguments: str
    ) -> dict[str, Any]:
        """Helper method to create a dictionary conforming to the function call item schema.

        Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output.

        Args:
            id (str): The id of the output item.
            call_id (str): The id of the function call.
            name (str): The name of the function to be called.
            arguments (str): The arguments to be passed to the function.
        """
        return create_function_call_item(id, call_id, name, arguments)

    @staticmethod
    def create_function_call_output_item(call_id: str, output: str) -> dict[str, Any]:
        """Helper method to create a dictionary conforming to the function call output item
        schema.

        Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output.

        Args:
            call_id (str): The id of the function call.
            output (str): The output of the function call.
        """
        return create_function_call_output_item(call_id, output)

    @staticmethod
    def _responses_to_cc(message: dict[str, Any]) -> list[dict[str, Any]]:
        """Convert from a Responses API output item to a list of ChatCompletion messages."""
        return responses_to_cc(message)

    @staticmethod
    def prep_msgs_for_cc_llm(
        responses_input: Sequence[dict[str, Any] | Message | OutputItem],
    ) -> list[dict[str, Any]]:
        """Convert from Responses input items to ChatCompletion dictionaries."""
        return to_chat_completions_input(responses_input)

    @staticmethod
    def output_to_responses_items_stream(
        chunks: Iterator[dict[str, Any]], aggregator: list[dict[str, Any]] | None = None
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        """
        For streaming, convert from various message format dicts to Responses output items,
        returning a generator of ResponsesAgentStreamEvent objects.

        If ``aggregator`` is provided, it will be extended with the aggregated output item
        dicts.

        For now, only a stream of Chat Completion chunks is handled.
        """
        yield from output_to_responses_items_stream(chunks, aggregator)


def _save_model_with_class_artifacts_params(
    path,
    python_model,
    signature=None,
    artifacts=None,
    conda_env=None,
    code_paths=None,
    mlflow_model=None,
    pip_requirements=None,
    extra_pip_requirements=None,
    model_config=None,
    streamable=None,
    model_code_path=None,
    infer_code_paths=False,
    uv_project_path=None,
    uv_groups=None,
    uv_extras=None,
):
    """
    Args:
        path: The path to which to save the Python model.
        python_model: An instance of a subclass of :class:`~PythonModel`.
            ``python_model`` defines how the model loads artifacts and how it performs inference.
        artifacts: A dictionary containing ``<name, artifact_uri>`` entries. Remote artifact
            URIs are resolved to absolute filesystem paths, producing a dictionary of
            ``<name, absolute_path>`` entries, e.g. ``{"file": "absolute_path"}``.
            ``python_model`` can reference these resolved entries as the ``artifacts`` property
            of the ``context`` attribute. If ``<artifact_name, 'hf:/repo_id'>`` (e.g.
            ``{"bert-tiny-model": "hf:/prajjwal1/bert-tiny"}``) is provided, then the model can
            be fetched from the Hugging Face Hub using repo_id ``prajjwal1/bert-tiny`` directly.
            If ``None``, no artifacts are added to the model.
        conda_env: Either a dictionary representation of a Conda environment or the path to a
            Conda environment yaml file. If provided, this describes the environment this model
            should be run in. At minimum, it should specify the dependencies contained in
            :func:`get_default_conda_env()`. If ``None``, the default
            :func:`get_default_conda_env()` environment is added to the model.
        code_paths: A list of local filesystem paths to Python file dependencies (or directories
            containing file dependencies). These files are *prepended* to the system path before
            the model is loaded.
        mlflow_model: The model to which to add the ``mlflow.pyfunc`` flavor.
        model_config: The model configuration for the flavor. Model configuration is available
            during model loading time.

            .. Note:: Experimental: This parameter may change or be removed in a future release
                without warning.

        model_code_path: The path to the code that is being logged as a PyFunc model. Can be
            used to load python_model when python_model is None.

            .. Note:: Experimental: This parameter may change or be removed in a future release
                without warning.

        streamable: A boolean value indicating whether the model supports streaming prediction.
            If None, MLflow will try to inspect whether the model supports streaming
            by checking if the `predict_stream` method exists. Default None.
    """
    # Capture the original working directory for uv project detection.
    # This must be done before any operations that might change cwd.
    original_cwd = Path.cwd()

    if mlflow_model is None:
        mlflow_model = Model()

    custom_model_config_kwargs = {
        CONFIG_KEY_CLOUDPICKLE_VERSION: cloudpickle.__version__,
    }
    if callable(python_model):
        python_model = _FunctionPythonModel(func=python_model, signature=signature)

    saved_python_model_subpath = _SAVED_PYTHON_MODEL_SUBPATH

    compression = MLFLOW_LOG_MODEL_COMPRESSION.get()
    if compression:
        if _check_compression_supported(compression):
            custom_model_config_kwargs[CONFIG_KEY_COMPRESSION] = compression
            saved_python_model_subpath += _COMPRESSION_INFO[compression]["ext"]
        else:
            compression = None

    # If model_code_path is defined, we load the model into python_model, but we don't want to
    # pickle/save the python_model since the module won't be able to be imported.
    if not model_code_path:
        try:
            _maybe_compress_cloudpickle_dump(
                python_model, os.path.join(path, saved_python_model_subpath), compression
            )
        except Exception as e:
            # error_code is INVALID_PARAMETER_VALUE but this is a model serialization failure
            raise MlflowException(
                "Failed to serialize Python model. Please save the model into a python file "
                "and use the code-based logging method instead. See "
                "https://mlflow.org/docs/latest/models.html#models-from-code for more "
                "information.",
                error_code=INVALID_PARAMETER_VALUE,
                error_class="MODEL_SERIALIZATION_FAILED",
            ) from e

        custom_model_config_kwargs[CONFIG_KEY_PYTHON_MODEL] = saved_python_model_subpath

    if artifacts:
        saved_artifacts_config = {}
        with TempDir() as tmp_artifacts_dir:
            saved_artifacts_dir_subpath = "artifacts"
            hf_prefix = "hf:/"
            for artifact_name, artifact_uri in artifacts.items():
                if artifact_uri.startswith(hf_prefix):
                    try:
                        from huggingface_hub import snapshot_download
                    except ImportError as e:
                        raise MlflowException(
                            "Failed to import huggingface_hub. Please install huggingface_hub "
                            f"to log the model with artifact_uri {artifact_uri}. Error: {e}"
                        )

                    repo_id = artifact_uri[len(hf_prefix) :]
                    try:
                        snapshot_location = snapshot_download(
                            repo_id=repo_id,
                            local_dir=os.path.join(
                                path, saved_artifacts_dir_subpath, artifact_name
                            ),
                            local_dir_use_symlinks=False,
                        )
                    except Exception as e:
                        raise MlflowException.invalid_parameter_value(
                            "Failed to download snapshot from Hugging Face Hub with "
                            f"artifact_uri: {artifact_uri}. Error: {e}"
                        )
                    saved_artifact_subpath = (
                        Path(snapshot_location)
                        .relative_to(Path(os.path.realpath(path)))
                        .as_posix()
                    )
                else:
                    tmp_artifact_path = _download_artifact_from_uri(
                        artifact_uri=artifact_uri, output_path=tmp_artifacts_dir.path()
                    )

                    relative_path = (
                        Path(tmp_artifact_path)
                        .relative_to(Path(tmp_artifacts_dir.path()))
                        .as_posix()
                    )

                    saved_artifact_subpath = os.path.join(
                        saved_artifacts_dir_subpath, relative_path
                    )

                saved_artifacts_config[artifact_name] = {
                    CONFIG_KEY_ARTIFACT_RELATIVE_PATH: saved_artifact_subpath,
                    CONFIG_KEY_ARTIFACT_URI: artifact_uri,
                }

            shutil.move(
                tmp_artifacts_dir.path(), os.path.join(path, saved_artifacts_dir_subpath)
            )
        custom_model_config_kwargs[CONFIG_KEY_ARTIFACTS] = saved_artifacts_config

    if streamable is None:
        streamable = python_model.__class__.predict_stream != PythonModel.predict_stream

    if model_code_path:
        loader_module = mlflow.pyfunc.loaders.code_model.__name__
    elif python_model:
        loader_module = _get_pyfunc_loader_module(python_model)
    else:
        raise MlflowException(
            "Either `python_model` or `model_code_path` must be provided to save the model.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    mlflow.pyfunc.add_to_model(
        model=mlflow_model,
        loader_module=loader_module,
        code=None,
        conda_env=_CONDA_ENV_FILE_NAME,
        python_env=_PYTHON_ENV_FILE_NAME,
        model_config=model_config,
        streamable=streamable,
        model_code_path=model_code_path,
        **custom_model_config_kwargs,
    )
    if size := get_total_file_size(path):
        mlflow_model.model_size_bytes = size
    # `mlflow_model.save` must be called before _validate_infer_and_copy_code_paths, as the
    # latter internally infers dependencies, and the MLmodel file is required to successfully
    # load the model.
    mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))

    saved_code_subpath = _validate_infer_and_copy_code_paths(
        code_paths,
        path,
        infer_code_paths,
        mlflow.pyfunc.FLAVOR_NAME,
    )
    mlflow_model.flavors[mlflow.pyfunc.FLAVOR_NAME][mlflow.pyfunc.CODE] = saved_code_subpath

    # `mlflow_model.code` is updated, so re-generate the `MLmodel` file.
    mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))

    if uv_project_path is not None:
        uv_source_dir = uv_project_path
    elif MLFLOW_UV_AUTO_DETECT.get():
        uv_source_dir = original_cwd
    else:
        uv_source_dir = None

    if conda_env is None:
        if pip_requirements is None:
            default_reqs = get_default_pip_requirements()
            extra_env_vars = (
                _get_databricks_serverless_env_vars()
                if is_in_databricks_serverless_runtime()
                else None
            )
            # To ensure `_load_pyfunc` can successfully load the model during the dependency
            # inference, `mlflow_model.save` must be called beforehand to save an MLmodel file.
            inferred_reqs = mlflow.models.infer_pip_requirements(
                path,
                mlflow.pyfunc.FLAVOR_NAME,
                fallback=default_reqs,
                extra_env_vars=extra_env_vars,
                uv_project_dir=uv_source_dir,
                uv_groups=uv_groups,
                uv_extras=uv_extras,
            )
            default_reqs = sorted(set(inferred_reqs).union(default_reqs))
        else:
            default_reqs = None
        conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
            default_reqs,
            pip_requirements,
            extra_pip_requirements,
        )
    else:
        conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)

    with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f:
        yaml.safe_dump(conda_env, stream=f, default_flow_style=False)

    # Save `constraints.txt` if necessary
    if pip_constraints:
        write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))

    # Save `requirements.txt`
    write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))

    # Copy uv project files (uv.lock and pyproject.toml) if detected
    if uv_source_dir is not None:
        copy_uv_project_files(dest_dir=path, source_dir=uv_source_dir)

    _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))


def _load_context_model_and_signature(
    model_path: str, model_config: dict[str, Any] | None = None
):
    pyfunc_config = _get_flavor_configuration(
        model_path=model_path, flavor_name=mlflow.pyfunc.FLAVOR_NAME
    )
    signature = mlflow.models.Model.load(model_path).signature

    if MODEL_CODE_PATH in pyfunc_config:
        conf_model_code_path = pyfunc_config.get(MODEL_CODE_PATH)
        model_code_path = os.path.join(model_path, os.path.basename(conf_model_code_path))
        python_model = _load_model_code_path(model_code_path, model_config)

        if callable(python_model):
            python_model = _FunctionPythonModel(python_model, signature=signature)
    else:
        if (
            not MLFLOW_ALLOW_PICKLE_DESERIALIZATION.get()
            and not is_in_databricks_runtime()
            and not is_in_databricks_model_serving_environment()
        ):
            raise MlflowException(
                "Deserializing models using pickle is disallowed, but this model is saved "
                "in cloudpickle format. The recommended way is to save the model as a "
                "models-from-code artifact; see "
                "https://mlflow.org/docs/latest/ml/model/models-from-code/ for details. "
                "Another workaround is to set the environment "
                "variable 'MLFLOW_ALLOW_PICKLE_DESERIALIZATION' to 'true' to allow "
                "deserializing the model using pickle."
            )
        python_model_cloudpickle_version = pyfunc_config.get(CONFIG_KEY_CLOUDPICKLE_VERSION, None)
        if python_model_cloudpickle_version is None:
            mlflow.pyfunc._logger.warning(
                "The version of CloudPickle used to save the model could not be found in the "
                "MLmodel configuration"
            )
        elif python_model_cloudpickle_version != cloudpickle.__version__:
            # CloudPickle does not have a well-defined cross-version compatibility policy. Micro
            # version releases have been known to cause incompatibilities. Therefore, we match on
            # the full library version
            mlflow.pyfunc._logger.warning(
                "The version of CloudPickle that was used to save the model, `CloudPickle %s`, "
                "differs from the version of CloudPickle that is currently running, `CloudPickle "
                "%s`, and may be incompatible",
                python_model_cloudpickle_version,
                cloudpickle.__version__,
            )
        python_model_compression = pyfunc_config.get(CONFIG_KEY_COMPRESSION, None)

        python_model_subpath = pyfunc_config.get(CONFIG_KEY_PYTHON_MODEL, None)
        if python_model_subpath is None:
            raise MlflowException(
                "Python model path was not specified in the model configuration"
            )
        python_model = _maybe_decompress_cloudpickle_load(
            os.path.join(model_path, python_model_subpath), python_model_compression
        )

    artifacts = {}
    for saved_artifact_name, saved_artifact_info in pyfunc_config.get(
        CONFIG_KEY_ARTIFACTS, {}
    ).items():
        artifacts[saved_artifact_name] = os.path.join(
            model_path, saved_artifact_info[CONFIG_KEY_ARTIFACT_RELATIVE_PATH]
        )

    context = PythonModelContext(artifacts=artifacts, model_config=model_config)
    python_model.load_context(context=context)

    return context, python_model, signature


def _load_pyfunc(model_path: str, model_config: dict[str, Any] | None = None):
    context, python_model, signature = _load_context_model_and_signature(
        model_path, model_config
    )
    return _PythonModelPyfuncWrapper(
        python_model=python_model,
        context=context,
        signature=signature,
    )


def _get_first_string_column(pdf):
    iter_string_columns = (col for col, val in pdf.iloc[0].items() if isinstance(val, str))
    return next(iter_string_columns, None)


class _PythonModelPyfuncWrapper:
    """
    Wrapper class that creates a predict function such that
    predict(model_input: pd.DataFrame) -> model's output as pd.DataFrame (pandas DataFrame)
    """

    def __init__(self, python_model: PythonModel, context, signature):
        """
        Args:
            python_model: An instance of a subclass of :class:`~PythonModel`.
            context: A :class:`~PythonModelContext` instance containing artifacts that
                ``python_model`` may use when performing inference.
            signature: :class:`~ModelSignature` instance describing model input and output.
        """
1360 """ 1361 self.python_model = python_model 1362 self.context = context 1363 self.signature = signature 1364 1365 def _convert_input(self, model_input): 1366 hints = self.python_model.predict_type_hints 1367 # we still need this for backwards compatibility 1368 if isinstance(model_input, pd.DataFrame): 1369 if _is_list_str(hints.input): 1370 first_string_column = _get_first_string_column(model_input) 1371 if first_string_column is None: 1372 raise MlflowException.invalid_parameter_value( 1373 "Expected model input to contain at least one string column" 1374 ) 1375 return model_input[first_string_column].tolist() 1376 elif _is_list_dict_str(hints.input): 1377 if ( 1378 len(self.signature.inputs) == 1 1379 and next(iter(self.signature.inputs)).name is None 1380 ): 1381 if first_string_column := _get_first_string_column(model_input): 1382 return model_input[[first_string_column]].to_dict(orient="records") 1383 if len(model_input.columns) == 1: 1384 return model_input.to_dict("list")[0] 1385 return model_input.to_dict(orient="records") 1386 elif isinstance(hints.input, type) and ( 1387 issubclass(hints.input, ChatCompletionRequest) 1388 or issubclass(hints.input, SplitChatMessagesRequest) 1389 ): 1390 # If the type hint is a RAG dataclass, we hydrate it 1391 # If there are multiple rows, we should throw 1392 if len(model_input) > 1: 1393 raise MlflowException( 1394 "Expected a single input for dataclass type hint, but got multiple rows" 1395 ) 1396 # Since single input is expected, we take the first row 1397 return _hydrate_dataclass(hints.input, model_input.iloc[0]) 1398 return model_input 1399 1400 def predict(self, model_input, params: dict[str, Any] | None = None): 1401 """ 1402 Args: 1403 model_input: Model input data as one of dict, str, bool, bytes, float, int, str type. 1404 params: Additional parameters to pass to the model for inference. 1405 1406 Returns: 1407 Model predictions as an iterator of chunks. The chunks in the iterator must be type of 1408 dict or string. Chunk dict fields are determined by the model implementation. 1409 """ 1410 parameters = inspect.signature(self.python_model.predict).parameters 1411 kwargs = {} 1412 if "params" in parameters: 1413 kwargs["params"] = params 1414 else: 1415 _log_warning_if_params_not_in_predict_signature(_logger, params) 1416 if _is_context_in_predict_function_signature(parameters=parameters): 1417 return self.python_model.predict( 1418 self.context, self._convert_input(model_input), **kwargs 1419 ) 1420 else: 1421 return self.python_model.predict(self._convert_input(model_input), **kwargs) 1422 1423 def predict_stream(self, model_input, params: dict[str, Any] | None = None): 1424 """ 1425 Args: 1426 model_input: LLM Model single input. 1427 params: Additional parameters to pass to the model for inference. 1428 1429 Returns: 1430 Streaming predictions. 
1431 """ 1432 parameters = inspect.signature(self.python_model.predict_stream).parameters 1433 kwargs = {} 1434 if "params" in parameters: 1435 kwargs["params"] = params 1436 else: 1437 _log_warning_if_params_not_in_predict_signature(_logger, params) 1438 if _is_context_in_predict_function_signature(parameters=parameters): 1439 return self.python_model.predict_stream( 1440 self.context, self._convert_input(model_input), **kwargs 1441 ) 1442 else: 1443 return self.python_model.predict_stream(self._convert_input(model_input), **kwargs) 1444 1445 1446 def _get_pyfunc_loader_module(python_model): 1447 if isinstance(python_model, ChatModel): 1448 return mlflow.pyfunc.loaders.chat_model.__name__ 1449 elif isinstance(python_model, ChatAgent): 1450 return mlflow.pyfunc.loaders.chat_agent.__name__ 1451 elif isinstance(python_model, ResponsesAgent): 1452 return mlflow.pyfunc.loaders.responses_agent.__name__ 1453 return __name__ 1454 1455 1456 class ModelFromDeploymentEndpoint(PythonModel): 1457 """ 1458 A PythonModel wrapper for invoking an MLflow Deployments endpoint. 1459 This class is particularly used for running evaluation against an MLflow Deployments endpoint. 1460 """ 1461 1462 def __init__(self, endpoint, params): 1463 self.endpoint = endpoint 1464 self.params = params 1465 1466 def predict(self, context, model_input: pd.DataFrame | dict[str, Any] | list[dict[str, Any]]): 1467 """ 1468 Run prediction on the input data. 1469 1470 Args: 1471 context: A :class:`~PythonModelContext` instance containing artifacts that the model 1472 can use to perform inference. 1473 model_input: The input data for prediction, either of the following: 1474 - Pandas DataFrame: If the default evaluator is used, input is a DF 1475 that contains the multiple request payloads in a single column. 1476 - A dictionary: If the model_type is "databricks-agents" and the 1477 Databricks RAG evaluator is used, this PythonModel can be invoked 1478 with a single dict corresponding to the ChatCompletionsRequest schema. 1479 - A list of dictionaries: Currently we don't have any evaluator that 1480 gives this input format, but we keep this for future use cases and 1481 compatibility with normal pyfunc models. 1482 1483 Return: 1484 The prediction result. The return type will be consistent with the model input type, 1485 e.g., if the input is a Pandas DataFrame, the return will be a Pandas Series. 1486 """ 1487 if isinstance(model_input, dict): 1488 return self._predict_single(model_input) 1489 elif isinstance(model_input, list) and all(isinstance(data, dict) for data in model_input): 1490 return [self._predict_single(data) for data in model_input] 1491 elif isinstance(model_input, pd.DataFrame): 1492 if len(model_input.columns) != 1: 1493 raise MlflowException( 1494 f"The number of input columns must be 1, but got {model_input.columns}. " 1495 "Multi-column input is not supported for evaluating an MLflow Deployments " 1496 "endpoint. Please include the input text or payload in a single column.", 1497 error_code=INVALID_PARAMETER_VALUE, 1498 ) 1499 input_column = model_input.columns[0] 1500 1501 predictions = [self._predict_single(data) for data in model_input[input_column]] 1502 return pd.Series(predictions) 1503 else: 1504 raise MlflowException( 1505 f"Invalid input data type: {type(model_input)}. 
                "a Pandas DataFrame, a dictionary, or a list of dictionaries containing the "
                "request payloads for evaluating an MLflow Deployments endpoint.",
                error_code=INVALID_PARAMETER_VALUE,
            )

    def _predict_single(self, data: str | dict[str, Any]) -> dict[str, Any]:
        """
        Send a single prediction request to the MLflow Deployments endpoint.

        Args:
            data: The single input data for prediction. If the input data is a string, we
                construct the request payload from it. If the input data is a dictionary, we
                use it directly as the request payload.

        Returns:
            The prediction result from the MLflow Deployments endpoint as a dictionary.
        """
        from mlflow.metrics.genai.model_utils import call_deployments_api, get_endpoint_type

        endpoint_type = get_endpoint_type(f"endpoints:/{self.endpoint}")

        if isinstance(data, str):
            # If the input payload is a string, MLflow needs to construct the JSON
            # payload based on the endpoint type. If the endpoint type is not
            # set on the endpoint, we default to the chat format.
            endpoint_type = endpoint_type or "llm/v1/chat"
            prediction = call_deployments_api(self.endpoint, data, self.params, endpoint_type)
        elif isinstance(data, dict):
            # If the input is a dictionary, we assume it is already in a
            # compatible format for the endpoint.
            prediction = call_deployments_api(self.endpoint, data, self.params, endpoint_type)
        else:
            raise MlflowException(
                f"Invalid input data type: {type(data)}. The feature column of the evaluation "
                "dataset must contain only strings or dictionaries containing the request "
                "payload for evaluating an MLflow Deployments endpoint.",
                error_code=INVALID_PARAMETER_VALUE,
            )
        return prediction