/ mlflow / pyfunc / model.py
model.py
   1  """
   2  The ``mlflow.pyfunc.model`` module defines logic for saving and loading custom "python_function"
   3  models with a user-defined ``PythonModel`` subclass.
   4  """
   5  
   6  import bz2
   7  import gzip
   8  import inspect
   9  import logging
  10  import lzma
  11  import os
  12  import shutil
  13  from abc import ABCMeta, abstractmethod
  14  from collections.abc import Sequence
  15  from pathlib import Path
  16  from typing import Any, Generator, Iterator
  17  
  18  import cloudpickle
  19  import pandas as pd
  20  import yaml
  21  
  22  import mlflow.pyfunc
  23  from mlflow.entities.span import SpanType
  24  from mlflow.environment_variables import (
  25      MLFLOW_ALLOW_PICKLE_DESERIALIZATION,
  26      MLFLOW_LOG_MODEL_COMPRESSION,
  27      MLFLOW_UV_AUTO_DETECT,
  28  )
  29  from mlflow.exceptions import MlflowException
  30  from mlflow.models import Model
  31  from mlflow.models.model import MLMODEL_FILE_NAME, MODEL_CODE_PATH
  32  from mlflow.models.rag_signatures import ChatCompletionRequest, SplitChatMessagesRequest
  33  from mlflow.models.signature import (
  34      _extract_type_hints,
  35      _is_context_in_predict_function_signature,
  36      _TypeHints,
  37  )
  38  from mlflow.models.utils import _load_model_code_path
  39  from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
  40  from mlflow.pyfunc.utils import pyfunc
  41  from mlflow.pyfunc.utils.data_validation import (
  42      _check_func_signature,
  43      _get_func_info_if_type_hint_supported,
  44      _wrap_predict_with_pyfunc,
  45      wrap_non_list_predict_pydantic,
  46  )
  47  from mlflow.pyfunc.utils.input_converter import _hydrate_dataclass
  48  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
  49  from mlflow.types.agent import (
  50      ChatAgentChunk,
  51      ChatAgentMessage,
  52      ChatAgentRequest,
  53      ChatAgentResponse,
  54      ChatContext,
  55  )
  56  from mlflow.types.llm import (
  57      ChatCompletionChunk,
  58      ChatCompletionResponse,
  59      ChatMessage,
  60      ChatParams,
  61  )
  62  from mlflow.types.responses import (
  63      Message,
  64      OutputItem,
  65      ResponsesAgentRequest,
  66      ResponsesAgentResponse,
  67      ResponsesAgentStreamEvent,
  68      create_annotation_added,
  69      create_function_call_item,
  70      create_function_call_output_item,
  71      create_reasoning_item,
  72      create_text_delta,
  73      create_text_output_item,
  74      output_to_responses_items_stream,
  75      responses_agent_output_reducer,
  76      responses_to_cc,
  77      to_chat_completions_input,
  78  )
  79  from mlflow.types.utils import _is_list_dict_str, _is_list_str
  80  from mlflow.utils.annotations import deprecated
  81  from mlflow.utils.databricks_utils import (
  82      _get_databricks_serverless_env_vars,
  83      is_in_databricks_model_serving_environment,
  84      is_in_databricks_runtime,
  85      is_in_databricks_serverless_runtime,
  86  )
  87  from mlflow.utils.environment import (
  88      _CONDA_ENV_FILE_NAME,
  89      _CONSTRAINTS_FILE_NAME,
  90      _PYTHON_ENV_FILE_NAME,
  91      _REQUIREMENTS_FILE_NAME,
  92      _mlflow_conda_env,
  93      _process_conda_env,
  94      _process_pip_requirements,
  95      _PythonEnv,
  96  )
  97  from mlflow.utils.file_utils import TempDir, get_total_file_size, write_to
  98  from mlflow.utils.model_utils import _get_flavor_configuration, _validate_infer_and_copy_code_paths
  99  from mlflow.utils.requirements_utils import _get_pinned_requirement
 100  from mlflow.utils.uv_utils import copy_uv_project_files
 101  
 102  CONFIG_KEY_ARTIFACTS = "artifacts"
 103  CONFIG_KEY_ARTIFACT_RELATIVE_PATH = "path"
 104  CONFIG_KEY_ARTIFACT_URI = "uri"
 105  CONFIG_KEY_PYTHON_MODEL = "python_model"
 106  CONFIG_KEY_CLOUDPICKLE_VERSION = "cloudpickle_version"
 107  CONFIG_KEY_COMPRESSION = "python_model_compression"
 108  _SAVED_PYTHON_MODEL_SUBPATH = "python_model.pkl"
 109  _DEFAULT_CHAT_MODEL_METADATA_TASK = "agent/v1/chat"
 110  _DEFAULT_CHAT_AGENT_METADATA_TASK = "agent/v2/chat"
 111  _COMPRESSION_INFO = {
 112      "lzma": {"ext": ".xz", "open": lzma.open},
 113      "bzip2": {"ext": ".bz2", "open": bz2.open},
 114      "gzip": {"ext": ".gz", "open": gzip.open},
 115  }
 116  _DEFAULT_RESPONSES_AGENT_METADATA_TASK = "agent/v1/responses"
 117  
 118  _logger = logging.getLogger(__name__)
 119  
 120  
 121  def get_default_pip_requirements():
 122      """
 123      Returns:
 124          A list of default pip requirements for MLflow Models produced by this flavor. Calls to
 125          :func:`save_model()` and :func:`log_model()` produce a pip environment that, at minimum,
 126          contains these requirements.
 127      """
 128      return [_get_pinned_requirement("cloudpickle")]
 129  
 130  
 131  def get_default_conda_env():
 132      """
 133      Returns:
 134          The default Conda environment for MLflow Models produced by calls to
 135          :func:`save_model() <mlflow.pyfunc.save_model>`
 136          and :func:`log_model() <mlflow.pyfunc.log_model>` when a user-defined subclass of
 137          :class:`PythonModel` is provided.
 138      """
 139      return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements())
 140  
 141  
 142  def _log_warning_if_params_not_in_predict_signature(logger, params):
 143      if params:
 144          logger.warning(
 145              "The underlying model does not support passing additional parameters to the predict"
 146              f" function. `params` {params} will be ignored."
 147          )
 148  
 149  
 150  class PythonModel:
 151      """
 152      Represents a generic Python model that evaluates inputs and produces API-compatible outputs.
 153      By subclassing :class:`~PythonModel`, users can create customized MLflow models with the
 154      "python_function" ("pyfunc") flavor, leveraging custom inference logic and artifact
 155      dependencies.
 156      """
 157  
 158      __metaclass__ = ABCMeta
 159  
 160      def load_context(self, context):
 161          """
 162          Loads artifacts from the specified :class:`~PythonModelContext` that can be used by
 163          :func:`~PythonModel.predict` when evaluating inputs. When loading an MLflow model with
 164          :func:`~load_model`, this method is called as soon as the :class:`~PythonModel` is
 165          constructed.
 166  
 167          The same :class:`~PythonModelContext` will also be available during calls to
 168          :func:`~PythonModel.predict`, but it may be more efficient to override this method
 169          and load artifacts from the context at model load time.
 170  
 171          Args:
 172              context: A :class:`~PythonModelContext` instance containing artifacts that the model
 173                       can use to perform inference.
 174          """
 175  
 176      @deprecated("predict_type_hints", "2.20.0")
 177      def _get_type_hints(self):
 178          return self.predict_type_hints
 179  
 180      @property
 181      def predict_type_hints(self) -> _TypeHints:
 182          """
 183          Internal method to get type hints from the predict function signature.
 184          """
 185          if hasattr(self, "_predict_type_hints"):
 186              return self._predict_type_hints
 187          if _is_context_in_predict_function_signature(func=self.predict):
 188              self._predict_type_hints = _extract_type_hints(self.predict, input_arg_index=1)
 189          else:
 190              self._predict_type_hints = _extract_type_hints(self.predict, input_arg_index=0)
 191          return self._predict_type_hints
 192  
 193      def __init_subclass__(cls, **kwargs) -> None:
 194          super().__init_subclass__(**kwargs)
 195  
 196          # automatically wrap the predict method with pyfunc to ensure data validation
 197          # NB: skip wrapping for built-in classes defined in MLflow e.g. ChatModel
 198          if not cls.__module__.startswith("mlflow."):
 199              #  TODO: ChatModel uses dataclass type hints which are not supported now, hence
 200              #    we need to skip type hint based validation for user-defined subclasses
 201              #    of ChatModel. Once we either (1) support dataclass type hints or (2) migrate
 202              #    ChatModel to pydantic, we can remove this exclusion.
 203              #    NB: issubclass(cls, ChatModel) does not work so we use a hacky attribute check
 204              if getattr(cls, "_skip_type_hint_validation", False):
 205                  return
 206  
 207              predict_attr = cls.__dict__.get("predict")
 208              if predict_attr is not None and callable(predict_attr):
 209                  func_info = _get_func_info_if_type_hint_supported(predict_attr)
 210                  setattr(cls, "predict", _wrap_predict_with_pyfunc(predict_attr, func_info))
 211              predict_stream_attr = cls.__dict__.get("predict_stream")
 212              if predict_stream_attr is not None and callable(predict_stream_attr):
 213                  _check_func_signature(predict_stream_attr, "predict_stream")
 214          else:
 215              cls.predict._is_pyfunc = True
 216  
 217      @abstractmethod
 218      def predict(self, context, model_input, params: dict[str, Any] | None = None):
 219          """
 220          Evaluates a pyfunc-compatible input and produces a pyfunc-compatible output.
 221          For more information about the pyfunc input/output API, see the :ref:`pyfunc-inference-api`.
 222  
 223          Args:
 224              context: A :class:`~PythonModelContext` instance containing artifacts that the model
 225                       can use to perform inference.
 226              model_input: A pyfunc-compatible input for the model to evaluate.
 227              params: Additional parameters to pass to the model for inference.
 228  
 229          .. tip::
 230              Since MLflow 2.20.0, `context` parameter can be removed from `predict` function
 231              signature if it's not used. `def predict(self, model_input, params=None)` is valid.
 232          """
 233  
 234      def predict_stream(self, context, model_input, params: dict[str, Any] | None = None):
 235          """
 236          Evaluates a pyfunc-compatible input and produces an iterator of output.
 237          For more information about the pyfunc input API, see the :ref:`pyfunc-inference-api`.
 238  
 239          Args:
 240              context: A :class:`~PythonModelContext` instance containing artifacts that the model
 241                       can use to perform inference.
 242              model_input: A pyfunc-compatible input for the model to evaluate.
 243              params: Additional parameters to pass to the model for inference.
 244  
 245          .. tip::
 246              Since MLflow 2.20.0, `context` parameter can be removed from `predict_stream` function
 247              signature if it's not used.
 248              `def predict_stream(self, model_input, params=None)` is valid.
 249          """
 250          raise NotImplementedError()
 251  
 252  
 253  class _FunctionPythonModel(PythonModel):
 254      """
 255      When a user specifies a ``python_model`` argument that is a function, we wrap the function
 256      in an instance of this class.
 257      """
 258  
 259      def __init__(self, func, signature=None):
 260          self.signature = signature
 261          # only wrap `func` if @pyfunc is not already applied
 262          if not getattr(func, "_is_pyfunc", False):
 263              self.func = pyfunc(func)
 264          else:
 265              self.func = func
 266  
 267      @property
 268      def predict_type_hints(self):
 269          if hasattr(self, "_predict_type_hints"):
 270              return self._predict_type_hints
 271          self._predict_type_hints = _extract_type_hints(self.func, input_arg_index=0)
 272          return self._predict_type_hints
 273  
 274      def predict(
 275          self,
 276          model_input,
 277          params: dict[str, Any] | None = None,
 278      ):
 279          """
 280          Args:
 281              model_input: A pyfunc-compatible input for the model to evaluate.
 282              params: Additional parameters to pass to the model for inference.
 283  
 284          Returns:
 285              Model predictions.
 286          """
 287          # callable only supports one input argument for now
 288          return self.func(model_input)
 289  
 290  
 291  class PythonModelContext:
 292      """
 293      A collection of artifacts that a :class:`~PythonModel` can use when performing inference.
 294      :class:`~PythonModelContext` objects are created *implicitly* by the
 295      :func:`save_model() <mlflow.pyfunc.save_model>` and
 296      :func:`log_model() <mlflow.pyfunc.log_model>` persistence methods, using the contents specified
 297      by the ``artifacts`` parameter of these methods.
 298      """
 299  
 300      def __init__(self, artifacts, model_config):
 301          """
 302          Args:
 303              artifacts: A dictionary of ``<name, artifact_path>`` entries, where ``artifact_path``
 304                  is an absolute filesystem path to a given artifact.
 305              model_config: The model configuration to make available to the model at
 306                  loading time.
 307          """
 308          self._artifacts = artifacts
 309          self._model_config = model_config
 310  
 311      @property
 312      def artifacts(self):
 313          """
 314          A dictionary containing ``<name, artifact_path>`` entries, where ``artifact_path`` is an
 315          absolute filesystem path to the artifact.
 316          """
 317          return self._artifacts
 318  
 319      @property
 320      def model_config(self):
 321          """
 322          A dictionary containing ``<config, value>`` entries, where ``config`` is the name
 323          of the model configuration keys and ``value`` is the value of the given configuration.
 324          """
 325  
 326          return self._model_config
 327  
 328  
 329  @deprecated("ResponsesAgent", "3.0.0")
 330  class ChatModel(PythonModel, metaclass=ABCMeta):
 331      """
 332      .. tip::
 333          Since MLflow 3.0.0, we recommend using
 334          :py:class:`ResponsesAgent <mlflow.pyfunc.ResponsesAgent>`
 335          instead of :py:class:`ChatModel <mlflow.pyfunc.ChatModel>` unless you need strict
 336          compatibility with the OpenAI ChatCompletion API.
 337  
 338      A subclass of :class:`~PythonModel` that makes it more convenient to implement models
 339      that are compatible with popular LLM chat APIs. By subclassing :class:`~ChatModel`,
 340      users can create MLflow models with a ``predict()`` method that is more convenient
 341      for chat tasks than the generic :class:`~PythonModel` API. ChatModels automatically
 342      define input/output signatures and an input example, so manually specifying these values
 343      when calling :func:`mlflow.pyfunc.save_model() <mlflow.pyfunc.save_model>` is not necessary.
 344  
 345      See the documentation of the ``predict()`` method below for details on that parameters and
 346      outputs that are expected by the ``ChatModel`` API.
 347  
 348      .. list-table::
 349          :header-rows: 1
 350          :widths: 20 40 40
 351  
 352          * -
 353            - ChatModel
 354            - PythonModel
 355          * - When to use
 356            - Use when you want to develop and deploy a conversational model with **standard** chat
 357              schema compatible with OpenAI spec.
 358            - Use when you want **full control** over the model's interface or customize every aspect
 359              of your model's behavior.
 360          * - Interface
 361            - **Fixed** to OpenAI's chat schema.
 362            - **Full control** over the model's input and output schema.
 363          * - Setup
 364            - **Quick**. Works out of the box for conversational applications, with pre-defined
 365                model signature and input example.
 366            - **Custom**. You need to define model signature or input example yourself.
 367          * - Complexity
 368            - **Low**. Standardized interface simplified model deployment and integration.
 369            - **High**. Deploying and integrating the custom PythonModel may not be straightforward.
 370                E.g., The model needs to handle Pandas DataFrames as MLflow converts input data to
 371                DataFrames before passing it to PythonModel.
 372  
 373      """
 374  
 375      _skip_type_hint_validation = True
 376  
 377      @abstractmethod
 378      def predict(
 379          self, context, messages: list[ChatMessage], params: ChatParams
 380      ) -> ChatCompletionResponse:
 381          """
 382          Evaluates a chat input and produces a chat output.
 383  
 384          Args:
 385              context: A :class:`~PythonModelContext` instance containing artifacts that the model
 386                  can use to perform inference.
 387              messages (List[:py:class:`ChatMessage <mlflow.types.llm.ChatMessage>`]):
 388                  A list of :py:class:`ChatMessage <mlflow.types.llm.ChatMessage>`
 389                  objects representing chat history.
 390              params (:py:class:`ChatParams <mlflow.types.llm.ChatParams>`):
 391                  A :py:class:`ChatParams <mlflow.types.llm.ChatParams>` object
 392                  containing various parameters used to modify model behavior during
 393                  inference.
 394  
 395          .. tip::
 396              Since MLflow 2.20.0, `context` parameter can be removed from `predict` function
 397              signature if it's not used.
 398              `def predict(self, messages: list[ChatMessage], params: ChatParams)` is valid.
 399  
 400          Returns:
 401              A :py:class:`ChatCompletionResponse <mlflow.types.llm.ChatCompletionResponse>`
 402              object containing the model's response(s), as well as other metadata.
 403          """
 404  
 405      def predict_stream(
 406          self, context, messages: list[ChatMessage], params: ChatParams
 407      ) -> Generator[ChatCompletionChunk, None, None]:
 408          """
 409          Evaluates a chat input and produces a chat output.
 410          Override this function to implement a real stream prediction.
 411  
 412          Args:
 413              context: A :class:`~PythonModelContext` instance containing artifacts that the model
 414                  can use to perform inference.
 415              messages (List[:py:class:`ChatMessage <mlflow.types.llm.ChatMessage>`]):
 416                  A list of :py:class:`ChatMessage <mlflow.types.llm.ChatMessage>`
 417                  objects representing chat history.
 418              params (:py:class:`ChatParams <mlflow.types.llm.ChatParams>`):
 419                  A :py:class:`ChatParams <mlflow.types.llm.ChatParams>` object
 420                  containing various parameters used to modify model behavior during
 421                  inference.
 422  
 423          .. tip::
 424              Since MLflow 2.20.0, `context` parameter can be removed from `predict_stream` function
 425              signature if it's not used.
 426              `def predict_stream(self, messages: list[ChatMessage], params: ChatParams)` is valid.
 427  
 428          Returns:
 429              A generator over :py:class:`ChatCompletionChunk <mlflow.types.llm.ChatCompletionChunk>`
 430              object containing the model's response(s), as well as other metadata.
 431          """
 432          raise NotImplementedError(
 433              "Streaming implementation not provided. Please override the "
 434              "`predict_stream` method on your model to generate streaming "
 435              "predictions"
 436          )
 437  
 438  
 439  class ChatAgent(PythonModel, metaclass=ABCMeta):
 440      """
 441      .. tip::
 442          Since MLflow 3.0.0, we recommend using
 443          :py:class:`ResponsesAgent <mlflow.pyfunc.ResponsesAgent>`
 444          instead of :py:class:`ChatAgent <mlflow.pyfunc.ChatAgent>`.
 445  
 446      **What is the ChatAgent Interface?**
 447  
 448      The ChatAgent interface is a chat schema specification that has been designed for authoring
 449      conversational agents. ChatAgent allows your agent to do the following:
 450  
 451      - Return multiple messages
 452      - Return intermediate steps for tool calling agents
 453      - Confirm tool calls
 454      - Support multi-agent scenarios
 455  
 456      ``ChatAgent`` should always be used when authoring an agent. We also recommend using
 457      ``ChatAgent`` instead of :py:class:`ChatModel <mlflow.pyfunc.ChatModel>` even for use cases
 458      like simple chat models (e.g. prompt-engineered LLMs), to give you the flexibility to support
 459      more agentic functionality in the future.
 460  
 461      The :py:class:`ChatAgentRequest <mlflow.types.agent.ChatAgentRequest>` schema is similar to,
 462      but not strictly compatible with the OpenAI ChatCompletion schema. ChatAgent adds additional
 463      functionality and diverges from OpenAI
 464      :py:class:`ChatCompletionRequest <mlflow.types.llm.ChatCompletionRequest>` in the following
 465      ways:
 466  
 467      - Adds an optional ``attachments`` attribute to every input/output message for tools and
 468        internal agent calls so they can return additional outputs such as visualizations and progress
 469        indicators
 470      - Adds a ``context`` attribute with a ``conversation_id`` and ``user_id`` attributes to enable
 471        modifying the behavior of the agent depending on the user querying the agent
 472      - Adds the ``custom_inputs`` attribute, an arbitrary ``dict[str, Any]`` to pass in any
 473        additional information to modify the agent's behavior
 474  
 475      The :py:class:`ChatAgentResponse <mlflow.types.agent.ChatAgentResponse>` schema diverges from
 476      :py:class:`ChatCompletionResponse <mlflow.types.llm.ChatCompletionResponse>` schema in the
 477      following ways:
 478  
 479      - Adds the ``custom_outputs`` key, an arbitrary ``dict[str, Any]`` to return any additional
 480        information
 481      - Allows multiple messages in the output, to improve the  display and evaluation of internal
 482        tool calls and inter-agent communication that led to the final answer.
 483  
 484      Here's an example of a :py:class:`ChatAgentResponse <mlflow.types.agent.ChatAgentResponse>`
 485      detailing a tool call:
 486  
 487      .. code-block:: python
 488  
 489          {
 490              "messages": [
 491                  {
 492                      "role": "assistant",
 493                      "content": "",
 494                      "id": "run-04b46401-c569-4a4a-933e-62e38d8f9647-0",
 495                      "tool_calls": [
 496                          {
 497                              "id": "call_15ca4fcc-ffa1-419a-8748-3bea34b9c043",
 498                              "type": "function",
 499                              "function": {
 500                                  "name": "generate_random_ints",
 501                                  "arguments": '{"min": 1, "max": 100, "size": 5}',
 502                              },
 503                          }
 504                      ],
 505                  },
 506                  {
 507                      "role": "tool",
 508                      "content": '{"content": "Generated array of 2 random ints in [1, 100]."',
 509                      "name": "generate_random_ints",
 510                      "id": "call_15ca4fcc-ffa1-419a-8748-3bea34b9c043",
 511                      "tool_call_id": "call_15ca4fcc-ffa1-419a-8748-3bea34b9c043",
 512                  },
 513                  {
 514                      "role": "assistant",
 515                      "content": "The new set of generated random numbers are: 93, 51, 12, 7, and 25",
 516                      "name": "llm",
 517                      "id": "run-70c7c738-739f-4ecd-ad18-0ae232df24e8-0",
 518                  },
 519              ],
 520              "custom_outputs": {"random_nums": [93, 51, 12, 7, 25]},
 521          }
 522  
 523      **Streaming Agent Output with ChatAgent**
 524  
 525      Please read the docstring of
 526      :py:func:`ChatAgent.predict_stream <mlflow.pyfunc.ChatAgent.predict_stream>`
 527      for more details on how to stream the output of your agent.
 528  
 529  
 530      **Authoring a ChatAgent**
 531  
 532      Authoring an agent using the ChatAgent  interface is a framework-agnostic way to create a model
 533      with a  standardized interface that is loggable with the MLflow pyfunc flavor, can be reused
 534      across clients, and is ready for serving workloads.
 535  
 536      To write your own agent, subclass ``ChatAgent``, implementing the ``predict`` and optionally
 537      ``predict_stream`` methods to define the non-streaming and streaming behavior of your agent. You
 538      can use any agent authoring framework - the only hard requirement is to implement the
 539      ``predict`` interface.
 540  
 541      .. code-block:: python
 542  
 543          def predict(
 544              self,
 545              messages: list[ChatAgentMessage],
 546              context: Optional[ChatContext] = None,
 547              custom_inputs: Optional[dict[str, Any]] = None,
 548          ) -> ChatAgentResponse: ...
 549  
 550      In addition to calling predict and predict_stream methods with an input matching their type
 551      hints, you can also pass a single input dict that matches the
 552      :py:class:`ChatAgentRequest <mlflow.types.agent.ChatAgentRequest>` schema for ease of testing.
 553  
 554      .. code-block:: python
 555  
 556          chat_agent = MyChatAgent()
 557          chat_agent.predict({
 558              "messages": [{"role": "user", "content": "What is 10 + 10?"}],
 559              "context": {"conversation_id": "123", "user_id": "456"},
 560          })
 561  
 562      See an example implementation of ``predict`` and ``predict_stream`` for a LangGraph agent in
 563      the :py:class:`ChatAgentState <mlflow.langchain.chat_agent_langgraph.ChatAgentState>`
 564      docstring.
 565  
 566      **Logging the ChatAgent**
 567  
 568      Since the landscape of LLM frameworks is constantly evolving and not every flavor can be
 569      natively supported by MLflow, we recommend the
 570      `Models-from-Code <https://mlflow.org/docs/latest/ml/model/models-from-code.html>`_ logging
 571      approach.
 572  
 573      .. code-block:: python
 574  
 575          with mlflow.start_run():
 576              logged_agent_info = mlflow.pyfunc.log_model(
 577                  name="agent",
 578                  python_model=os.path.join(os.getcwd(), "agent"),
 579                  # Add serving endpoints, tools, and vector search indexes here
 580                  resources=[],
 581              )
 582  
 583      After logging the model, you can query the model with a single dictionary with the
 584      :py:class:`ChatAgentRequest <mlflow.types.agent.ChatAgentRequest>` schema. Under the hood, it
 585      will be converted into the python objects expected by your ``predict`` and ``predict_stream``
 586      methods.
 587  
 588      .. code-block:: python
 589  
 590          loaded_model = mlflow.pyfunc.load_model(tmp_path)
 591          loaded_model.predict({
 592              "messages": [{"role": "user", "content": "What is 10 + 10?"}],
 593              "context": {"conversation_id": "123", "user_id": "456"},
 594          })
 595  
 596      To make logging ChatAgent models as easy as possible, MLflow has built in the following
 597      features:
 598  
 599      - Automatic Model Signature Inference
 600          - You do not need to set a signature when logging a ChatAgent
 601          - An input and output signature will be automatically set that adheres to the
 602            :py:class:`ChatAgentRequest <mlflow.types.agent.ChatAgentRequest>` and
 603            :py:class:`ChatAgentResponse <mlflow.types.agent.ChatAgentResponse>` schemas
 604      - Metadata
 605          - ``{"task": "agent/v2/chat"}`` will be automatically appended to any metadata that you may
 606            pass in when logging the model
 607      - Input Example
 608          - Providing an input example is optional, ``mlflow.types.agent.CHAT_AGENT_INPUT_EXAMPLE``
 609            will be provided by default
 610          - If you do provide an input example, ensure it's a dict with the
 611            :py:class:`ChatAgentRequest <mlflow.types.agent.ChatAgentRequest>` schema
 612  
 613          - .. code-block:: python
 614  
 615              input_example = {
 616                  "messages": [{"role": "user", "content": "What is MLflow?"}],
 617                  "context": {"conversation_id": "123", "user_id": "456"},
 618              }
 619  
 620      **Migrating from ChatModel to ChatAgent**
 621  
 622      To convert an existing ChatModel that takes in
 623      :py:class:`List[ChatMessage] <mlflow.types.llm.ChatMessage>` and
 624      :py:class:`ChatParams <mlflow.types.llm.ChatParams>` and outputs a
 625      :py:class:`ChatCompletionResponse <mlflow.types.llm.ChatCompletionResponse>`, do the following:
 626  
 627      - Subclass ``ChatAgent`` instead of ``ChatModel``
 628      - Move any functionality from your ``ChatModel``'s ``load_context`` implementation into the
 629        ``__init__`` method of your new ``ChatAgent``.
 630      - Use ``.model_dump()`` instead of ``.to_dict()`` when converting your model's inputs to
 631        dictionaries. Ex. ``[msg.model_dump() for msg in messages]`` instead of
 632        ``[msg.to_dict() for msg in messages]``
 633      - Return a :py:class:`ChatAgentResponse <mlflow.types.agent.ChatAgentResponse>` instead of a
 634        :py:class:`ChatCompletionResponse <mlflow.types.llm.ChatCompletionResponse>`
 635  
 636      For example, we can convert the ChatModel from the
 637      `Chat Model Intro <https://mlflow.org/docs/latest/llms/chat-model-intro/index.html#building-your-first-chatmodel>`_
 638      to a ChatAgent:
 639  
 640      .. code-block:: python
 641  
 642          class SimpleOllamaModel(ChatModel):
 643              def __init__(self):
 644                  self.model_name = "llama3.2:1b"
 645                  self.client = None
 646  
 647              def load_context(self, context):
 648                  self.client = ollama.Client()
 649  
 650              def predict(
 651                  self, context, messages: list[ChatMessage], params: ChatParams = None
 652              ) -> ChatCompletionResponse:
 653                  ollama_messages = [msg.to_dict() for msg in messages]
 654                  response = self.client.chat(model=self.model_name, messages=ollama_messages)
 655                  return ChatCompletionResponse(
 656                      choices=[{"index": 0, "message": response["message"]}],
 657                      model=self.model_name,
 658                  )
 659  
 660      .. code-block:: python
 661  
 662          class SimpleOllamaModel(ChatAgent):
 663              def __init__(self):
 664                  self.model_name = "llama3.2:1b"
 665                  self.client = None
 666                  self.client = ollama.Client()
 667  
 668              def predict(
 669                  self,
 670                  messages: list[ChatAgentMessage],
 671                  context: Optional[ChatContext] = None,
 672                  custom_inputs: Optional[dict[str, Any]] = None,
 673              ) -> ChatAgentResponse:
 674                  ollama_messages = self._convert_messages_to_dict(messages)
 675                  response = self.client.chat(model=self.model_name, messages=ollama_messages)
 676                  return ChatAgentResponse(**{"messages": [response["message"]]})
 677  
 678      **ChatAgent Connectors**
 679  
 680      MLflow provides convenience APIs for wrapping agents written in popular authoring frameworks
 681      with ChatAgent. See examples for:
 682  
 683      - LangGraph in the
 684        :py:class:`ChatAgentState <mlflow.langchain.chat_agent_langgraph.ChatAgentState>` docstring
 685      """
 686  
 687      _skip_type_hint_validation = True
 688  
 689      def __init_subclass__(cls, **kwargs) -> None:
 690          super().__init_subclass__(**kwargs)
 691          for attr_name in ("predict", "predict_stream"):
 692              attr = cls.__dict__.get(attr_name)
 693              if callable(attr):
 694                  setattr(
 695                      cls,
 696                      attr_name,
 697                      wrap_non_list_predict_pydantic(
 698                          attr,
 699                          ChatAgentRequest,
 700                          "Invalid dictionary input for a ChatAgent. Expected a dictionary with the "
 701                          "ChatAgentRequest schema.",
 702                          unpack=True,
 703                      ),
 704                  )
 705  
 706      def _convert_messages_to_dict(self, messages: list[ChatAgentMessage]):
 707          return [m.model_dump(exclude_none=True) for m in messages]
 708  
 709      # nb: We use `messages` instead of `model_input` so that the trace generated by default is
 710      # compatible with mlflow evaluate. We also want `custom_inputs` to be a top level key for
 711      # ease of use.
 712      @abstractmethod
 713      def predict(
 714          self,
 715          messages: list[ChatAgentMessage],
 716          context: ChatContext | None = None,
 717          custom_inputs: dict[str, Any] | None = None,
 718      ) -> ChatAgentResponse:
 719          """
 720          Given a ChatAgent input, returns a ChatAgent output. In addition to calling ``predict``
 721          with an input matching the type hints, you can also pass a single input dict that matches
 722          the :py:class:`ChatAgentRequest <mlflow.types.agent.ChatAgentRequest>` schema for ease
 723          of testing.
 724  
 725          .. code-block:: python
 726  
 727              chat_agent = ChatAgent()
 728              chat_agent.predict({
 729                  "messages": [{"role": "user", "content": "What is 10 + 10?"}],
 730                  "context": {"conversation_id": "123", "user_id": "456"},
 731              })
 732  
 733          Args:
 734              messages (List[:py:class:`ChatAgentMessage <mlflow.types.agent.ChatAgentMessage>`]):
 735                  A list of :py:class:`ChatAgentMessage <mlflow.types.agent.ChatAgentMessage>`
 736                  objects representing the chat history.
 737              context (:py:class:`ChatContext <mlflow.types.agent.ChatContext>`):
 738                  A :py:class:`ChatContext <mlflow.types.agent.ChatContext>` object
 739                  containing conversation_id and user_id. **Optional** Defaults to None.
 740              custom_inputs (Dict[str, Any]):
 741                  An optional param to provide arbitrary additional inputs
 742                  to the model. The dictionary values must be JSON-serializable. **Optional**
 743                  Defaults to None.
 744  
 745          Returns:
 746              A :py:class:`ChatAgentResponse <mlflow.types.agent.ChatAgentResponse>` object containing
 747              the model's response, as well as other metadata.
 748          """
 749  
 750      # nb: We use `messages` instead of `model_input` so that the trace generated by default is
 751      # compatible with mlflow evaluate. We also want `custom_inputs` to be a top level key for
 752      # ease of use.
 753      def predict_stream(
 754          self,
 755          messages: list[ChatAgentMessage],
 756          context: ChatContext | None = None,
 757          custom_inputs: dict[str, Any] | None = None,
 758      ) -> Generator[ChatAgentChunk, None, None]:
 759          """
 760          Given a ChatAgent input, returns a generator containing streaming ChatAgent output chunks.
 761          In addition to calling ``predict_stream`` with an input matching the type hints, you can
 762          also pass a single input dict that matches the
 763          :py:class:`ChatAgentRequest <mlflow.types.agent.ChatAgentRequest>`
 764          schema for ease of testing.
 765  
 766          .. code-block:: python
 767  
 768              chat_agent = ChatAgent()
 769              for event in chat_agent.predict_stream({
 770                  "messages": [{"role": "user", "content": "What is 10 + 10?"}],
 771                  "context": {"conversation_id": "123", "user_id": "456"},
 772              }):
 773                  print(event)
 774  
 775          To support streaming the output of your agent, override this method in your subclass of
 776          ``ChatAgent``. When implementing ``predict_stream``, keep in mind the following
 777          requirements:
 778  
 779          - Ensure your implementation adheres to the ``predict_stream`` type signature. For example,
 780            streamed messages must be of the type
 781            :py:class:`ChatAgentChunk <mlflow.types.agent.ChatAgentChunk>`, where each chunk contains
 782            partial output from a single response message.
 783          - At most one chunk in a particular response can contain the ``custom_outputs`` key.
 784          - Chunks containing partial content of a single response message must have the same ``id``.
 785            The content field of the message and usage stats of the
 786            :py:class:`ChatAgentChunk <mlflow.types.agent.ChatAgentChunk>` should be aggregated by
 787            the consuming client. See the example below.
 788  
 789          .. code-block:: python
 790  
 791              {"delta": {"role": "assistant", "content": "Born", "id": "123"}}
 792              {"delta": {"role": "assistant", "content": " in", "id": "123"}}
 793              {"delta": {"role": "assistant", "content": " data", "id": "123"}}
 794  
 795  
 796          Args:
 797              messages (List[:py:class:`ChatAgentMessage <mlflow.types.agent.ChatAgentMessage>`]):
 798                  A list of :py:class:`ChatAgentMessage <mlflow.types.agent.ChatAgentMessage>`
 799                  objects representing the chat history.
 800              context (:py:class:`ChatContext <mlflow.types.agent.ChatContext>`):
 801                  A :py:class:`ChatContext <mlflow.types.agent.ChatContext>` object
 802                  containing conversation_id and user_id. **Optional** Defaults to None.
 803              custom_inputs (Dict[str, Any]):
 804                  An optional param to provide arbitrary additional inputs
 805                  to the model. The dictionary values must be JSON-serializable. **Optional**
 806                  Defaults to None.
 807  
 808          Returns:
 809              A generator over :py:class:`ChatAgentChunk <mlflow.types.agent.ChatAgentChunk>`
 810              objects containing the model's response(s), as well as other metadata.
 811          """
 812          raise NotImplementedError(
 813              "Streaming implementation not provided. Please override the "
 814              "`predict_stream` method on your model to generate streaming predictions"
 815          )
 816  
 817  
 818  def _check_compression_supported(compression):
 819      if compression in _COMPRESSION_INFO:
 820          return True
 821      if compression:
 822          supported = ", ".join(sorted(_COMPRESSION_INFO))
 823          mlflow.pyfunc._logger.warning(
 824              f"Unrecognized compression method '{compression}'"
 825              f"Please select one of: {supported}. Falling back to uncompressed storage/loading."
 826          )
 827      return False
 828  
 829  
 830  def _maybe_compress_cloudpickle_dump(python_model, path, compression):
 831      file_open = _COMPRESSION_INFO.get(compression, {}).get("open", open)
 832      with file_open(path, "wb") as out:
 833          cloudpickle.dump(python_model, out)
 834  
 835  
 836  def _maybe_decompress_cloudpickle_load(path, compression):
 837      _check_compression_supported(compression)
 838      file_open = _COMPRESSION_INFO.get(compression, {}).get("open", open)
 839      with file_open(path, "rb") as f:
 840          return cloudpickle.load(f)
 841  
 842  
 843  class ResponsesAgent(PythonModel, metaclass=ABCMeta):
 844      """
 845      A base class for creating ResponsesAgent models. It can be used as a wrapper around any
 846      agent framework to create an agent model that can be deployed to MLflow. Has a few helper
 847      methods to help create output items that can be a part of a ResponsesAgentResponse or
 848      ResponsesAgentStreamEvent.
 849  
 850      See https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro for more details.
 851      """
 852  
 853      _skip_type_hint_validation = True
 854  
 855      @staticmethod
 856      def responses_agent_output_reducer(
 857          chunks: list[ResponsesAgentStreamEvent | dict[str, Any]],
 858      ):
 859          return responses_agent_output_reducer(chunks)
 860  
 861      def __init_subclass__(cls, **kwargs) -> None:
 862          super().__init_subclass__(**kwargs)
 863          for attr_name in ("predict", "predict_stream"):
 864              attr = cls.__dict__.get(attr_name)
 865              if callable(attr):
 866                  # Only apply trace decorator if it is not already traced with mlflow.trace
 867                  if getattr(attr, "__mlflow_traced__", False):
 868                      mlflow.pyfunc._logger.warning(
 869                          f"You have manually traced {attr_name} with @mlflow.trace, but this is "
 870                          "unnecessary with ResponsesAgent subclasses. You can remove the "
 871                          "@mlflow.trace decorator and it will be automatically traced."
 872                      )
 873                      traced_attr = attr
 874                  else:
 875                      # Apply trace decorator first
 876                      if attr_name == "predict_stream":
 877                          traced_attr = mlflow.trace(
 878                              span_type=SpanType.AGENT,
 879                              output_reducer=cls.responses_agent_output_reducer,
 880                          )(attr)
 881                      else:
 882                          traced_attr = mlflow.trace(span_type=SpanType.AGENT)(attr)
 883  
 884                  # Then wrap with pydantic wrapper
 885                  wrapped_attr = wrap_non_list_predict_pydantic(
 886                      traced_attr,
 887                      ResponsesAgentRequest,
 888                      "Invalid dictionary input for a ResponsesAgent. "
 889                      "Expected a dictionary with the ResponsesRequest schema.",
 890                  )
 891                  setattr(cls, attr_name, wrapped_attr)
 892  
 893      @abstractmethod
 894      def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
 895          """
 896          Given a ResponsesAgentRequest, returns a ResponsesAgentResponse.
 897  
 898          You can see example implementations at
 899          https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#simple-chat-example
 900          and
 901          https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#tool-calling-example.
 902          """
 903  
 904      def predict_stream(
 905          self, request: ResponsesAgentRequest
 906      ) -> Generator[ResponsesAgentStreamEvent, None, None]:
 907          """
 908          Given a ResponsesAgentRequest, returns a generator of ResponsesAgentStreamEvent objects.
 909  
 910          See more details at
 911          https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#streaming-agent-output.
 912  
 913          You can see example implementations at
 914          https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#simple-chat-example
 915          and
 916          https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#tool-calling-example.
 917          """
 918          raise NotImplementedError(
 919              "Streaming implementation not provided. Please override the "
 920              "`predict_stream` method on your model to generate streaming predictions"
 921          )
 922  
 923      @staticmethod
 924      def create_text_delta(delta: str, item_id: str) -> dict[str, Any]:
 925          """Helper method to create a dictionary conforming to the text delta schema for
 926          streaming.
 927  
 928          Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#streaming-agent-output.
 929          """
 930          return create_text_delta(delta, item_id)
 931  
 932      @staticmethod
 933      def create_annotation_added(
 934          item_id: str, annotation: dict[str, Any], annotation_index: int | None = 0
 935      ) -> dict[str, Any]:
 936          return create_annotation_added(item_id, annotation, annotation_index)
 937  
 938      @staticmethod
 939      def create_text_output_item(
 940          text: str, id: str, annotations: list[dict[str, Any]] | None = None
 941      ) -> dict[str, Any]:
 942          """Helper method to create a dictionary conforming to the text output item schema.
 943  
 944          Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output.
 945  
 946          Args:
 947              text (str): The text to be outputted.
 948              id (str): The id of the output item.
 949              annotations (Optional[list[dict]]): The annotations of the output item.
 950          """
 951          return create_text_output_item(text, id, annotations)
 952  
 953      @staticmethod
 954      def create_reasoning_item(id: str, reasoning_text: str) -> dict[str, Any]:
 955          """Helper method to create a dictionary conforming to the reasoning item schema.
 956  
 957          Read more at https://www.mlflow.org/docs/latest/llms/responses-agent-intro/#creating-agent-output.
 958          """
 959          return create_reasoning_item(id, reasoning_text)
 960  
 961      @staticmethod
 962      def create_function_call_item(
 963          id: str, call_id: str, name: str, arguments: str
 964      ) -> dict[str, Any]:
 965          """Helper method to create a dictionary conforming to the function call item schema.
 966  
 967          Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output.
 968  
 969          Args:
 970              id (str): The id of the output item.
 971              call_id (str): The id of the function call.
 972              name (str): The name of the function to be called.
 973              arguments (str): The arguments to be passed to the function.
 974          """
 975          return create_function_call_item(id, call_id, name, arguments)
 976  
 977      @staticmethod
 978      def create_function_call_output_item(call_id: str, output: str) -> dict[str, Any]:
 979          """Helper method to create a dictionary conforming to the function call output item
 980          schema.
 981  
 982          Read more at https://mlflow.org/docs/latest/genai/flavors/responses-agent-intro#creating-agent-output.
 983  
 984          Args:
 985              call_id (str): The id of the function call.
 986              output (str): The output of the function call.
 987          """
 988          return create_function_call_output_item(call_id, output)
 989  
 990      @staticmethod
 991      def _responses_to_cc(message: dict[str, Any]) -> list[dict[str, Any]]:
 992          """Convert from a Responses API output item to  a list of ChatCompletion messages."""
 993          return responses_to_cc(message)
 994  
 995      @staticmethod
 996      def prep_msgs_for_cc_llm(
 997          responses_input: Sequence[dict[str, Any] | Message | OutputItem],
 998      ) -> list[dict[str, Any]]:
 999          "Convert from Responses input items to ChatCompletion dictionaries"
1000          return to_chat_completions_input(responses_input)
1001  
1002      @staticmethod
1003      def output_to_responses_items_stream(
1004          chunks: Iterator[dict[str, Any]], aggregator: list[dict[str, Any]] | None = None
1005      ) -> Generator[ResponsesAgentStreamEvent, None, None]:
1006          """
1007          For streaming, convert from various message format dicts to Responses output items,
1008          returning a generator of ResponsesAgentStreamEvent objects.
1009  
1010          If `aggregator` is provided, it will be extended with the aggregated output item dicts.
1011  
1012          For now, only handle a stream of Chat Completion chunks.
1013          """
1014          yield from output_to_responses_items_stream(chunks, aggregator)
1015  
1016  
1017  def _save_model_with_class_artifacts_params(
1018      path,
1019      python_model,
1020      signature=None,
1021      artifacts=None,
1022      conda_env=None,
1023      code_paths=None,
1024      mlflow_model=None,
1025      pip_requirements=None,
1026      extra_pip_requirements=None,
1027      model_config=None,
1028      streamable=None,
1029      model_code_path=None,
1030      infer_code_paths=False,
1031      uv_project_path=None,
1032      uv_groups=None,
1033      uv_extras=None,
1034  ):
1035      """
1036      Args:
1037          path: The path to which to save the Python model.
1038          python_model: An instance of a subclass of :class:`~PythonModel`. ``python_model``
1039              defines how the model loads artifacts and how it performs inference.
1040          artifacts: A dictionary containing ``<name, artifact_uri>`` entries. Remote artifact URIs
1041              are resolved to absolute filesystem paths, producing a dictionary of
1042              ``<name, absolute_path>`` entries, (e.g. {"file": "absolute_path"}).
1043              ``python_model`` can reference these resolved entries as the ``artifacts`` property
1044              of the ``context`` attribute. If ``<artifact_name, 'hf:/repo_id'>``(e.g.
1045              {"bert-tiny-model": "hf:/prajjwal1/bert-tiny"}) is provided, then the model can be
1046              fetched from huggingface hub using repo_id `prajjwal1/bert-tiny` directly. If ``None``,
1047              no artifacts are added to the model.
1048          conda_env: Either a dictionary representation of a Conda environment or the path to a Conda
1049              environment yaml file. If provided, this describes the environment this model should be
1050              run in. At minimum, it should specify the dependencies contained in
1051              :func:`get_default_conda_env()`. If ``None``, the default
1052              :func:`get_default_conda_env()` environment is added to the model.
1053          code_paths: A list of local filesystem paths to Python file dependencies (or directories
1054              containing file dependencies). These files are *prepended* to the system path before the
1055              model is loaded.
1056          mlflow_model: The model to which to add the ``mlflow.pyfunc`` flavor.
1057          model_config: The model configuration for the flavor. Model configuration is available
1058              during model loading time.
1059  
1060              .. Note:: Experimental: This parameter may change or be removed in a future release
1061                  without warning.
1062  
1063          model_code_path: The path to the code that is being logged as a PyFunc model. Can be used
1064              to load python_model when python_model is None.
1065  
1066              .. Note:: Experimental: This parameter may change or be removed in a future release
1067                  without warning.
1068  
1069          streamable: A boolean value indicating if the model supports streaming prediction,
1070                      If None, MLflow will try to inspect if the model supports streaming
1071                      by checking if `predict_stream` method exists. Default None.
1072      """
1073      # Capture original working directory for uv project detection
1074      # This must be done before any operations that might change cwd
1075      original_cwd = Path.cwd()
1076  
1077      if mlflow_model is None:
1078          mlflow_model = Model()
1079  
1080      custom_model_config_kwargs = {
1081          CONFIG_KEY_CLOUDPICKLE_VERSION: cloudpickle.__version__,
1082      }
1083      if callable(python_model):
1084          python_model = _FunctionPythonModel(func=python_model, signature=signature)
1085  
1086      saved_python_model_subpath = _SAVED_PYTHON_MODEL_SUBPATH
1087  
1088      compression = MLFLOW_LOG_MODEL_COMPRESSION.get()
1089      if compression:
1090          if _check_compression_supported(compression):
1091              custom_model_config_kwargs[CONFIG_KEY_COMPRESSION] = compression
1092              saved_python_model_subpath += _COMPRESSION_INFO[compression]["ext"]
1093          else:
1094              compression = None
1095  
1096      # If model_code_path is defined, we load the model into python_model, but we don't want to
1097      # pickle/save the python_model since the module won't be able to be imported.
1098      if not model_code_path:
1099          try:
1100              _maybe_compress_cloudpickle_dump(
1101                  python_model, os.path.join(path, saved_python_model_subpath), compression
1102              )
1103          except Exception as e:
1104              # error_code is INVALID_PARAMETER_VALUE but this is a model serialization failure
1105              raise MlflowException(
1106                  "Failed to serialize Python model. Please save the model into a python file "
1107                  "and use code-based logging method instead. See"
1108                  "https://mlflow.org/docs/latest/models.html#models-from-code for more information.",
1109                  error_code=INVALID_PARAMETER_VALUE,
1110                  error_class="MODEL_SERIALIZATION_FAILED",
1111              ) from e
1112  
1113          custom_model_config_kwargs[CONFIG_KEY_PYTHON_MODEL] = saved_python_model_subpath
1114  
1115      if artifacts:
1116          saved_artifacts_config = {}
1117          with TempDir() as tmp_artifacts_dir:
1118              saved_artifacts_dir_subpath = "artifacts"
1119              hf_prefix = "hf:/"
1120              for artifact_name, artifact_uri in artifacts.items():
1121                  if artifact_uri.startswith(hf_prefix):
1122                      try:
1123                          from huggingface_hub import snapshot_download
1124                      except ImportError as e:
1125                          raise MlflowException(
1126                              "Failed to import huggingface_hub. Please install huggingface_hub "
1127                              f"to log the model with artifact_uri {artifact_uri}. Error: {e}"
1128                          )
1129  
1130                      repo_id = artifact_uri[len(hf_prefix) :]
1131                      try:
1132                          snapshot_location = snapshot_download(
1133                              repo_id=repo_id,
1134                              local_dir=os.path.join(
1135                                  path, saved_artifacts_dir_subpath, artifact_name
1136                              ),
1137                              local_dir_use_symlinks=False,
1138                          )
1139                      except Exception as e:
1140                          raise MlflowException.invalid_parameter_value(
1141                              "Failed to download snapshot from Hugging Face Hub with artifact_uri: "
1142                              f"{artifact_uri}. Error: {e}"
1143                          )
1144                      saved_artifact_subpath = (
1145                          Path(snapshot_location).relative_to(Path(os.path.realpath(path))).as_posix()
1146                      )
1147                  else:
1148                      tmp_artifact_path = _download_artifact_from_uri(
1149                          artifact_uri=artifact_uri, output_path=tmp_artifacts_dir.path()
1150                      )
1151  
1152                      relative_path = (
1153                          Path(tmp_artifact_path)
1154                          .relative_to(Path(tmp_artifacts_dir.path()))
1155                          .as_posix()
1156                      )
1157  
1158                      saved_artifact_subpath = os.path.join(
1159                          saved_artifacts_dir_subpath, relative_path
1160                      )
1161  
1162                  saved_artifacts_config[artifact_name] = {
1163                      CONFIG_KEY_ARTIFACT_RELATIVE_PATH: saved_artifact_subpath,
1164                      CONFIG_KEY_ARTIFACT_URI: artifact_uri,
1165                  }
1166  
1167              shutil.move(tmp_artifacts_dir.path(), os.path.join(path, saved_artifacts_dir_subpath))
1168          custom_model_config_kwargs[CONFIG_KEY_ARTIFACTS] = saved_artifacts_config
1169  
1170      if streamable is None:
1171          streamable = python_model.__class__.predict_stream != PythonModel.predict_stream
1172  
1173      if model_code_path:
1174          loader_module = mlflow.pyfunc.loaders.code_model.__name__
1175      elif python_model:
1176          loader_module = _get_pyfunc_loader_module(python_model)
1177      else:
1178          raise MlflowException(
1179              "Either `python_model` or `model_code_path` must be provided to save the model.",
1180              error_code=INVALID_PARAMETER_VALUE,
1181          )
1182  
1183      mlflow.pyfunc.add_to_model(
1184          model=mlflow_model,
1185          loader_module=loader_module,
1186          code=None,
1187          conda_env=_CONDA_ENV_FILE_NAME,
1188          python_env=_PYTHON_ENV_FILE_NAME,
1189          model_config=model_config,
1190          streamable=streamable,
1191          model_code_path=model_code_path,
1192          **custom_model_config_kwargs,
1193      )
1194      if size := get_total_file_size(path):
1195          mlflow_model.model_size_bytes = size
1196      # `mlflow_model.save` must be called before _validate_infer_and_copy_code_paths as it
1197      # internally infers dependency, and MLmodel file is required to successfully load the model
1198      mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))
1199  
1200      saved_code_subpath = _validate_infer_and_copy_code_paths(
1201          code_paths,
1202          path,
1203          infer_code_paths,
1204          mlflow.pyfunc.FLAVOR_NAME,
1205      )
1206      mlflow_model.flavors[mlflow.pyfunc.FLAVOR_NAME][mlflow.pyfunc.CODE] = saved_code_subpath
1207  
1208      # `mlflow_model.code` is updated, re-generate `MLmodel` file.
1209      mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))
1210  
1211      if uv_project_path is not None:
1212          uv_source_dir = uv_project_path
1213      elif MLFLOW_UV_AUTO_DETECT.get():
1214          uv_source_dir = original_cwd
1215      else:
1216          uv_source_dir = None
1217  
1218      if conda_env is None:
1219          if pip_requirements is None:
1220              default_reqs = get_default_pip_requirements()
1221              extra_env_vars = (
1222                  _get_databricks_serverless_env_vars()
1223                  if is_in_databricks_serverless_runtime()
1224                  else None
1225              )
1226              # To ensure `_load_pyfunc` can successfully load the model during the dependency
1227              # inference, `mlflow_model.save` must be called beforehand to save an MLmodel file.
1228              inferred_reqs = mlflow.models.infer_pip_requirements(
1229                  path,
1230                  mlflow.pyfunc.FLAVOR_NAME,
1231                  fallback=default_reqs,
1232                  extra_env_vars=extra_env_vars,
1233                  uv_project_dir=uv_source_dir,
1234                  uv_groups=uv_groups,
1235                  uv_extras=uv_extras,
1236              )
1237              default_reqs = sorted(set(inferred_reqs).union(default_reqs))
1238          else:
1239              default_reqs = None
1240          conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
1241              default_reqs,
1242              pip_requirements,
1243              extra_pip_requirements,
1244          )
1245      else:
1246          conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)
1247  
1248      with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f:
1249          yaml.safe_dump(conda_env, stream=f, default_flow_style=False)
1250  
1251      # Save `constraints.txt` if necessary
1252      if pip_constraints:
1253          write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))
1254  
1255      # Save `requirements.txt`
1256      write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))
1257  
1258      # Copy uv project files (uv.lock and pyproject.toml) if detected
1259      if uv_source_dir is not None:
1260          copy_uv_project_files(dest_dir=path, source_dir=uv_source_dir)
1261  
1262      _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))
1263  
1264  
1265  def _load_context_model_and_signature(model_path: str, model_config: dict[str, Any] | None = None):
1266      pyfunc_config = _get_flavor_configuration(
1267          model_path=model_path, flavor_name=mlflow.pyfunc.FLAVOR_NAME
1268      )
1269      signature = mlflow.models.Model.load(model_path).signature
1270  
1271      if MODEL_CODE_PATH in pyfunc_config:
1272          conf_model_code_path = pyfunc_config.get(MODEL_CODE_PATH)
1273          model_code_path = os.path.join(model_path, os.path.basename(conf_model_code_path))
1274          python_model = _load_model_code_path(model_code_path, model_config)
1275  
1276          if callable(python_model):
1277              python_model = _FunctionPythonModel(python_model, signature=signature)
1278      else:
1279          if (
1280              not MLFLOW_ALLOW_PICKLE_DESERIALIZATION.get()
1281              and not is_in_databricks_runtime()
1282              and not is_in_databricks_model_serving_environment()
1283          ):
1284              raise MlflowException(
1285                  "Deserializing model using pickle is disallowed, but this model is saved "
1286                  "in cloudpickle format. The recommended way is to save the model as "
1287                  "models-from-code artifacts, see "
1288                  "https://mlflow.org/docs/latest/ml/model/models-from-code/ for details. "
1289                  "Another workaround is to set environment "
1290                  "variable 'MLFLOW_ALLOW_PICKLE_DESERIALIZATION' to 'true' to allow "
1291                  "deserializing model using pickle."
1292              )
1293          python_model_cloudpickle_version = pyfunc_config.get(CONFIG_KEY_CLOUDPICKLE_VERSION, None)
1294          if python_model_cloudpickle_version is None:
1295              mlflow.pyfunc._logger.warning(
1296                  "The version of CloudPickle used to save the model could not be found in the "
1297                  "MLmodel configuration"
1298              )
1299          elif python_model_cloudpickle_version != cloudpickle.__version__:
1300              # CloudPickle does not have a well-defined cross-version compatibility policy. Micro
1301              # version releases have been known to cause incompatibilities. Therefore, we match on
1302              # the full library version
1303              mlflow.pyfunc._logger.warning(
1304                  "The version of CloudPickle that was used to save the model, `CloudPickle %s`, "
1305                  "differs from the version of CloudPickle that is currently running, `CloudPickle "
1306                  "%s`, and may be incompatible",
1307                  python_model_cloudpickle_version,
1308                  cloudpickle.__version__,
1309              )
1310          python_model_compression = pyfunc_config.get(CONFIG_KEY_COMPRESSION, None)
1311  
1312          python_model_subpath = pyfunc_config.get(CONFIG_KEY_PYTHON_MODEL, None)
1313          if python_model_subpath is None:
1314              raise MlflowException("Python model path was not specified in the model configuration")
1315          python_model = _maybe_decompress_cloudpickle_load(
1316              os.path.join(model_path, python_model_subpath), python_model_compression
1317          )
1318  
1319      artifacts = {}
1320      for saved_artifact_name, saved_artifact_info in pyfunc_config.get(
1321          CONFIG_KEY_ARTIFACTS, {}
1322      ).items():
1323          artifacts[saved_artifact_name] = os.path.join(
1324              model_path, saved_artifact_info[CONFIG_KEY_ARTIFACT_RELATIVE_PATH]
1325          )
1326  
1327      context = PythonModelContext(artifacts=artifacts, model_config=model_config)
1328      python_model.load_context(context=context)
1329  
1330      return context, python_model, signature
1331  
1332  
1333  def _load_pyfunc(model_path: str, model_config: dict[str, Any] | None = None):
1334      context, python_model, signature = _load_context_model_and_signature(model_path, model_config)
1335      return _PythonModelPyfuncWrapper(
1336          python_model=python_model,
1337          context=context,
1338          signature=signature,
1339      )
1340  
1341  
1342  def _get_first_string_column(pdf):
1343      iter_string_columns = (col for col, val in pdf.iloc[0].items() if isinstance(val, str))
1344      return next(iter_string_columns, None)
1345  
1346  
1347  class _PythonModelPyfuncWrapper:
1348      """
1349      Wrapper class that creates a predict function such that
1350      predict(model_input: pd.DataFrame) -> model's output as pd.DataFrame (pandas DataFrame)
1351      """
1352  
1353      def __init__(self, python_model: PythonModel, context, signature):
1354          """
1355          Args:
1356              python_model: An instance of a subclass of :class:`~PythonModel`.
1357              context: A :class:`~PythonModelContext` instance containing artifacts that
1358                       ``python_model`` may use when performing inference.
1359              signature: :class:`~ModelSignature` instance describing model input and output.
1360          """
1361          self.python_model = python_model
1362          self.context = context
1363          self.signature = signature
1364  
1365      def _convert_input(self, model_input):
1366          hints = self.python_model.predict_type_hints
1367          # we still need this for backwards compatibility
1368          if isinstance(model_input, pd.DataFrame):
1369              if _is_list_str(hints.input):
1370                  first_string_column = _get_first_string_column(model_input)
1371                  if first_string_column is None:
1372                      raise MlflowException.invalid_parameter_value(
1373                          "Expected model input to contain at least one string column"
1374                      )
1375                  return model_input[first_string_column].tolist()
1376              elif _is_list_dict_str(hints.input):
1377                  if (
1378                      len(self.signature.inputs) == 1
1379                      and next(iter(self.signature.inputs)).name is None
1380                  ):
1381                      if first_string_column := _get_first_string_column(model_input):
1382                          return model_input[[first_string_column]].to_dict(orient="records")
1383                      if len(model_input.columns) == 1:
1384                          return model_input.to_dict("list")[0]
1385                  return model_input.to_dict(orient="records")
1386              elif isinstance(hints.input, type) and (
1387                  issubclass(hints.input, ChatCompletionRequest)
1388                  or issubclass(hints.input, SplitChatMessagesRequest)
1389              ):
1390                  # If the type hint is a RAG dataclass, we hydrate it
1391                  # If there are multiple rows, we should throw
1392                  if len(model_input) > 1:
1393                      raise MlflowException(
1394                          "Expected a single input for dataclass type hint, but got multiple rows"
1395                      )
1396                  # Since single input is expected, we take the first row
1397                  return _hydrate_dataclass(hints.input, model_input.iloc[0])
1398          return model_input
1399  
1400      def predict(self, model_input, params: dict[str, Any] | None = None):
1401          """
1402          Args:
1403              model_input: Model input data as one of dict, str, bool, bytes, float, int, str type.
1404              params: Additional parameters to pass to the model for inference.
1405  
1406          Returns:
1407              Model predictions as an iterator of chunks. The chunks in the iterator must be type of
1408              dict or string. Chunk dict fields are determined by the model implementation.
1409          """
1410          parameters = inspect.signature(self.python_model.predict).parameters
1411          kwargs = {}
1412          if "params" in parameters:
1413              kwargs["params"] = params
1414          else:
1415              _log_warning_if_params_not_in_predict_signature(_logger, params)
1416          if _is_context_in_predict_function_signature(parameters=parameters):
1417              return self.python_model.predict(
1418                  self.context, self._convert_input(model_input), **kwargs
1419              )
1420          else:
1421              return self.python_model.predict(self._convert_input(model_input), **kwargs)
1422  
1423      def predict_stream(self, model_input, params: dict[str, Any] | None = None):
1424          """
1425          Args:
1426              model_input: LLM Model single input.
1427              params: Additional parameters to pass to the model for inference.
1428  
1429          Returns:
1430              Streaming predictions.
1431          """
1432          parameters = inspect.signature(self.python_model.predict_stream).parameters
1433          kwargs = {}
1434          if "params" in parameters:
1435              kwargs["params"] = params
1436          else:
1437              _log_warning_if_params_not_in_predict_signature(_logger, params)
1438          if _is_context_in_predict_function_signature(parameters=parameters):
1439              return self.python_model.predict_stream(
1440                  self.context, self._convert_input(model_input), **kwargs
1441              )
1442          else:
1443              return self.python_model.predict_stream(self._convert_input(model_input), **kwargs)
1444  
1445  
1446  def _get_pyfunc_loader_module(python_model):
1447      if isinstance(python_model, ChatModel):
1448          return mlflow.pyfunc.loaders.chat_model.__name__
1449      elif isinstance(python_model, ChatAgent):
1450          return mlflow.pyfunc.loaders.chat_agent.__name__
1451      elif isinstance(python_model, ResponsesAgent):
1452          return mlflow.pyfunc.loaders.responses_agent.__name__
1453      return __name__
1454  
1455  
1456  class ModelFromDeploymentEndpoint(PythonModel):
1457      """
1458      A PythonModel wrapper for invoking an MLflow Deployments endpoint.
1459      This class is particularly used for running evaluation against an MLflow Deployments endpoint.
1460      """
1461  
1462      def __init__(self, endpoint, params):
1463          self.endpoint = endpoint
1464          self.params = params
1465  
1466      def predict(self, context, model_input: pd.DataFrame | dict[str, Any] | list[dict[str, Any]]):
1467          """
1468          Run prediction on the input data.
1469  
1470          Args:
1471              context: A :class:`~PythonModelContext` instance containing artifacts that the model
1472                  can use to perform inference.
1473              model_input: The input data for prediction, either of the following:
1474                  - Pandas DataFrame: If the default evaluator is used, input is a DF
1475                      that contains the multiple request payloads in a single column.
1476                  - A dictionary: If the model_type is "databricks-agents" and the
1477                      Databricks RAG evaluator is used, this PythonModel can be invoked
1478                      with a single dict corresponding to the ChatCompletionsRequest schema.
1479                  - A list of dictionaries: Currently we don't have any evaluator that
1480                      gives this input format, but we keep this for future use cases and
1481                      compatibility with normal pyfunc models.
1482  
1483          Return:
1484              The prediction result. The return type will be consistent with the model input type,
1485              e.g., if the input is a Pandas DataFrame, the return will be a Pandas Series.
1486          """
1487          if isinstance(model_input, dict):
1488              return self._predict_single(model_input)
1489          elif isinstance(model_input, list) and all(isinstance(data, dict) for data in model_input):
1490              return [self._predict_single(data) for data in model_input]
1491          elif isinstance(model_input, pd.DataFrame):
1492              if len(model_input.columns) != 1:
1493                  raise MlflowException(
1494                      f"The number of input columns must be 1, but got {model_input.columns}. "
1495                      "Multi-column input is not supported for evaluating an MLflow Deployments "
1496                      "endpoint. Please include the input text or payload in a single column.",
1497                      error_code=INVALID_PARAMETER_VALUE,
1498                  )
1499              input_column = model_input.columns[0]
1500  
1501              predictions = [self._predict_single(data) for data in model_input[input_column]]
1502              return pd.Series(predictions)
1503          else:
1504              raise MlflowException(
1505                  f"Invalid input data type: {type(model_input)}. The input data must be either "
1506                  "a Pandas DataFrame, a dictionary, or a list of dictionaries containing the "
1507                  "request payloads for evaluating an MLflow Deployments endpoint.",
1508                  error_code=INVALID_PARAMETER_VALUE,
1509              )
1510  
1511      def _predict_single(self, data: str | dict[str, Any]) -> dict[str, Any]:
1512          """
1513          Send a single prediction request to the MLflow Deployments endpoint.
1514  
1515          Args:
1516              data: The single input data for prediction. If the input data is a string, we will
1517                  construct the request payload from it. If the input data is a dictionary, we
1518                  will directly use it as the request payload.
1519  
1520          Returns:
1521              The prediction result from the MLflow Deployments endpoint as a dictionary.
1522          """
1523          from mlflow.metrics.genai.model_utils import call_deployments_api, get_endpoint_type
1524  
1525          endpoint_type = get_endpoint_type(f"endpoints:/{self.endpoint}")
1526  
1527          if isinstance(data, str):
1528              # If the input payload is string, MLflow needs to construct the JSON
1529              # payload based on the endpoint type. If the endpoint type is not
1530              # set on the endpoint, we will default to chat format.
1531              endpoint_type = endpoint_type or "llm/v1/chat"
1532              prediction = call_deployments_api(self.endpoint, data, self.params, endpoint_type)
1533          elif isinstance(data, dict):
1534              # If the input is dictionary, we assume the input is already in the
1535              # compatible format for the endpoint.
1536              prediction = call_deployments_api(self.endpoint, data, self.params, endpoint_type)
1537          else:
1538              raise MlflowException(
1539                  f"Invalid input data type: {type(data)}. The feature column of the evaluation "
1540                  "dataset must contain only strings or dictionaries containing the request "
1541                  "payload for evaluating an MLflow Deployments endpoint.",
1542                  error_code=INVALID_PARAMETER_VALUE,
1543              )
1544          return prediction