/ mlflow / pyfunc / __init__.py
__init__.py
   1  """
   2  The ``python_function`` model flavor serves as a default model interface for MLflow Python models.
   3  Any MLflow Python model is expected to be loadable as a ``python_function`` model.
   4  
   5  In addition, the ``mlflow.pyfunc`` module defines a generic :ref:`filesystem format
   6  <pyfunc-filesystem-format>` for Python models and provides utilities for saving to and loading from
   7  this format. The format is self contained in the sense that it includes all necessary information
   8  for anyone to load it and use it. Dependencies are either stored directly with the model or
   9  referenced via a Conda environment.
  10  
  11  The ``mlflow.pyfunc`` module also defines utilities for creating custom ``pyfunc`` models
  12  using frameworks and inference logic that may not be natively included in MLflow. See
  13  :ref:`pyfunc-create-custom`.
  14  
  15  .. _pyfunc-inference-api:
  16  
  17  *************
  18  Inference API
  19  *************
  20  
  21  Python function models are loaded as an instance of :py:class:`PyFuncModel
  22  <mlflow.pyfunc.PyFuncModel>`, which is an MLflow wrapper around the model implementation and model
  23  metadata (MLmodel file). You can score the model by calling the :py:func:`predict()
  24  <mlflow.pyfunc.PyFuncModel.predict>` method, which has the following signature::
  25  
  26    predict(
  27      model_input: [pandas.DataFrame, numpy.ndarray, scipy.sparse.(csc_matrix | csr_matrix),
  28      List[Any], Dict[str, Any], pyspark.sql.DataFrame]
  29    ) -> [numpy.ndarray | pandas.(Series | DataFrame) | List | Dict | pyspark.sql.DataFrame]
  30  
  31  All PyFunc models will support `pandas.DataFrame` as input and PyFunc deep learning models will
  32  also support tensor inputs in the form of Dict[str, numpy.ndarray] (named tensors) and
  33  `numpy.ndarrays` (unnamed tensors).
  34  
  35  Here are some examples of supported inference types, assuming we have the correct ``model`` object
  36  loaded.
  37  
  38  .. list-table::
  39      :widths: 30 70
  40      :header-rows: 1
  41      :class: wrap-table
  42  
  43      * - Input Type
  44        - Example
  45      * - ``pandas.DataFrame``
  46        -
  47          .. code-block:: python
  48  
  49              import pandas as pd
  50  
  51              x_new = pd.DataFrame(dict(x1=[1, 2, 3], x2=[4, 5, 6]))
  52              model.predict(x_new)
  53  
  54      * - ``numpy.ndarray``
  55        -
  56          .. code-block:: python
  57  
  58              import numpy as np
  59  
  60              x_new = np.array([[1, 4][2, 5], [3, 6]])
  61              model.predict(x_new)
  62  
  63      * - ``scipy.sparse.csc_matrix`` or ``scipy.sparse.csr_matrix``
  64        -
  65          .. code-block:: python
  66  
  67              import scipy
  68  
  69              x_new = scipy.sparse.csc_matrix([[1, 2, 3], [4, 5, 6]])
  70              model.predict(x_new)
  71  
  72              x_new = scipy.sparse.csr_matrix([[1, 2, 3], [4, 5, 6]])
  73              model.predict(x_new)
  74  
  75      * - python ``List``
  76        -
  77          .. code-block:: python
  78  
  79              x_new = [[1, 4], [2, 5], [3, 6]]
  80              model.predict(x_new)
  81  
  82      * - python ``Dict``
  83        -
  84          .. code-block:: python
  85  
  86              x_new = dict(x1=[1, 2, 3], x2=[4, 5, 6])
  87              model.predict(x_new)
  88  
  89      * - ``pyspark.sql.DataFrame``
  90        -
  91          .. code-block:: python
  92  
  93              from pyspark.sql import SparkSession
  94  
  95              spark = SparkSession.builder.getOrCreate()
  96  
  97              data = [(1, 4), (2, 5), (3, 6)]  # List of tuples
  98              x_new = spark.createDataFrame(data, ["x1", "x2"])  # Specify column name
  99              model.predict(x_new)
 100  
 101  .. _pyfunc-filesystem-format:
 102  
 103  *****************
 104  Filesystem format
 105  *****************
 106  
 107  The Pyfunc format is defined as a directory structure containing all required data, code, and
 108  configuration::
 109  
 110      ./dst-path/
 111          ./MLmodel: configuration
 112          <code>: code packaged with the model (specified in the MLmodel file)
 113          <data>: data packaged with the model (specified in the MLmodel file)
 114          <env>: Conda environment definition (specified in the MLmodel file)
 115  
 116  The directory structure may contain additional contents that can be referenced by the ``MLmodel``
 117  configuration.
 118  
 119  .. _pyfunc-model-config:
 120  
 121  MLModel configuration
 122  #####################
 123  
 124  A Python model contains an ``MLmodel`` file in **python_function** format in its root with the
 125  following parameters:
 126  
 127  - loader_module [required]:
 128           Python module that can load the model. Expected as module identifier
 129           e.g. ``mlflow.sklearn``, it will be imported using ``importlib.import_module``.
 130           The imported module must contain a function with the following signature::
 131  
 132            _load_pyfunc(path: string) -> <pyfunc model implementation>
 133  
 134           The path argument is specified by the ``data`` parameter and may refer to a file or
 135           directory. The model implementation is expected to be an object with a
 136           ``predict`` method with the following signature::
 137  
 138            predict(
 139              model_input: [pandas.DataFrame, numpy.ndarray,
 140              scipy.sparse.(csc_matrix | csr_matrix), List[Any], Dict[str, Any]],
 141              pyspark.sql.DataFrame
 142            ) -> [numpy.ndarray | pandas.(Series | DataFrame) | List | Dict | pyspark.sql.DataFrame]
 143  
 144  - code [optional]:
 145          Relative path to a directory containing the code packaged with this model.
 146          All files and directories inside this directory are added to the Python path
 147          prior to importing the model loader.
 148  
 149  - data [optional]:
 150           Relative path to a file or directory containing model data.
 151           The path is passed to the model loader.
 152  
 153  - env [optional]:
 154           Relative path to an exported Conda environment. If present this environment
 155           should be activated prior to running the model.
 156  
 157  - Optionally, any additional parameters necessary for interpreting the serialized model in
 158    ``pyfunc`` format.
 159  
 160  .. rubric:: Example
 161  
 162  ::
 163  
 164      tree example/sklearn_iris/mlruns/run1/outputs/linear-lr
 165  
 166  ::
 167  
 168    ├── MLmodel
 169    ├── code
 170    │   ├── sklearn_iris.py
 171 172    ├── data
 173    │   └── model.pkl
 174    └── mlflow_env.yml
 175  
 176  ::
 177  
 178      cat example/sklearn_iris/mlruns/run1/outputs/linear-lr/MLmodel
 179  
 180  ::
 181  
 182    python_function:
 183      code: code
 184      data: data/model.pkl
 185      loader_module: mlflow.sklearn
 186      env: mlflow_env.yml
 187      main: sklearn_iris
 188  
 189  .. _pyfunc-create-custom:
 190  
 191  **********************************
 192  Models From Code for Custom Models
 193  **********************************
 194  
 195  .. tip::
 196  
 197      MLflow 2.12.2 introduced the feature "models from code", which greatly simplifies the process
 198      of serializing and deploying custom models through the use of script serialization. It is
 199      strongly recommended to migrate custom model implementations to this new paradigm to avoid the
 200      limitations and complexity of serializing with cloudpickle.
 201      You can learn more about models from code within the
 202      `Models From Code Guide <../model/models-from-code.html>`_.
 203  
 204  The section below illustrates the process of using the legacy serializer for custom Pyfunc models.
 205  Models from code will provide a far simpler experience for logging of your models.
 206  
 207  ******************************
 208  Creating custom Pyfunc models
 209  ******************************
 210  
 211  MLflow's persistence modules provide convenience functions for creating models with the
 212  ``pyfunc`` flavor in a variety of machine learning frameworks (scikit-learn, Keras, Pytorch, and
 213  more); however, they do not cover every use case. For example, you may want to create an MLflow
 214  model with the ``pyfunc`` flavor using a framework that MLflow does not natively support.
 215  Alternatively, you may want to build an MLflow model that executes custom logic when evaluating
 216  queries, such as preprocessing and postprocessing routines. Therefore, ``mlflow.pyfunc``
 217  provides utilities for creating ``pyfunc`` models from arbitrary code and model data.
 218  
 219  The :meth:`save_model()` and :meth:`log_model()` methods are designed to support multiple workflows
 220  for creating custom ``pyfunc`` models that incorporate custom inference logic and artifacts
 221  that the logic may require.
 222  
 223  An `artifact` is a file or directory, such as a serialized model or a CSV. For example, a
 224  serialized TensorFlow graph is an artifact. An MLflow model directory is also an artifact.
 225  
 226  .. _pyfunc-create-custom-workflows:
 227  
 228  Workflows
 229  #########
 230  
 231  :meth:`save_model()` and :meth:`log_model()` support the following workflows:
 232  
 233  1. Programmatically defining a new MLflow model, including its attributes and artifacts.
 234  
 235     Given a set of artifact URIs, :meth:`save_model()` and :meth:`log_model()` can
 236     automatically download artifacts from their URIs and create an MLflow model directory.
 237  
 238     In this case, you must define a Python class which inherits from :class:`~PythonModel`,
 239     defining ``predict()`` and, optionally, ``load_context()``. An instance of this class is
 240     specified via the ``python_model`` parameter; it is automatically serialized and deserialized
 241     as a Python class, including all of its attributes.
 242  
 243  2. Interpreting pre-existing data as an MLflow model.
 244  
 245     If you already have a directory containing model data, :meth:`save_model()` and
 246     :meth:`log_model()` can import the data as an MLflow model. The ``data_path`` parameter
 247     specifies the local filesystem path to the directory containing model data.
 248  
 249     In this case, you must provide a Python module, called a `loader module`. The
 250     loader module defines a ``_load_pyfunc()`` method that performs the following tasks:
 251  
 252     - Load data from the specified ``data_path``. For example, this process may include
 253       deserializing pickled Python objects or models or parsing CSV files.
 254  
 255     - Construct and return a pyfunc-compatible model wrapper. As in the first
 256       use case, this wrapper must define a ``predict()`` method that is used to evaluate
 257       queries. ``predict()`` must adhere to the :ref:`pyfunc-inference-api`.
 258  
 259     The ``loader_module`` parameter specifies the name of your loader module.
 260  
 261     For an example loader module implementation, refer to the `loader module
 262     implementation in mlflow.sklearn <https://github.com/mlflow/mlflow/blob/
 263     74d75109aaf2975f5026104d6125bb30f4e3f744/mlflow/sklearn.py#L200-L205>`_.
 264  
 265  .. _pyfunc-create-custom-selecting-workflow:
 266  
 267  Which workflow is right for my use case?
 268  ########################################
 269  
 270  We consider the first workflow to be more user-friendly and generally recommend it for the
 271  following reasons:
 272  
 273  - It automatically resolves and collects specified model artifacts.
 274  
 275  - It automatically serializes and deserializes the ``python_model`` instance and all of
 276    its attributes, reducing the amount of user logic that is required to load the model
 277  
 278  - You can create Models using logic that is defined in the ``__main__`` scope. This allows
 279    custom models to be constructed in interactive environments, such as notebooks and the Python
 280    REPL.
 281  
 282  You may prefer the second, lower-level workflow for the following reasons:
 283  
 284  - Inference logic is always persisted as code, rather than a Python object. This makes logic
 285    easier to inspect and modify later.
 286  
 287  - If you have already collected all of your model data in a single location, the second
 288    workflow allows it to be saved in MLflow format directly, without enumerating constituent
 289    artifacts.
 290  
 291  ******************************************
 292  Function-based Model vs Class-based Model
 293  ******************************************
 294  
 295  When creating custom PyFunc models, you can choose between two different interfaces:
 296  a function-based model and a class-based model. In short, a function-based model is simply a
 297  python function that does not take additional params. The class-based model, on the other hand,
 298  is subclass of ``PythonModel`` that supports several required and optional
 299  methods. If your use case is simple and fits within a single predict function, a function-based
 300  approach is recommended. If you need more power, such as custom serialization, custom data
 301  processing, or to override additional methods, you should use the class-based implementation.
 302  
 303  Before looking at code examples, it's important to note that both methods are serialized via
 304  `cloudpickle <https://github.com/cloudpipe/cloudpickle>`_. cloudpickle can serialize Python
 305  functions, lambda functions, and locally defined classes and functions inside other functions. This
 306  makes cloudpickle especially useful for parallel and distributed computing where code objects need
 307  to be sent over network to execute on remote workers, which is a common deployment paradigm for
 308  MLflow.
 309  
 310  That said, cloudpickle has some limitations.
 311  
 312  - **Environment Dependency**: cloudpickle does not capture the full execution environment, so in
 313    MLflow we must pass ``pip_requirements``, ``extra_pip_requirements``, or an ``input_example``,
 314    the latter of which is used to infer environment dependencies. For more, refer to
 315    `the model dependency docs <https://mlflow.org/docs/latest/model/dependencies.html>`_.
 316  
 317  - **Object Support**: cloudpickle does not serialize objects outside of the Python data model.
 318    Some relevant examples include raw files and database connections. If your program depends on
 319    these, be sure to log ways to reference these objects along with your model.
 320  
 321  Function-based Model
 322  ####################
 323  If you're looking to serialize a simple python function without additional dependent methods, you
 324  can simply log a predict method via the keyword argument ``python_model``.
 325  
 326  .. note::
 327  
 328      Function-based model only supports a function with a single input argument. If you would like
 329      to pass more arguments or additional inference parameters, please use the class-based model
 330      below.
 331  
 332  .. code-block:: python
 333  
 334      import mlflow
 335      import pandas as pd
 336  
 337  
 338      # Define a simple function to log
 339      def predict(model_input):
 340          return model_input.apply(lambda x: x * 2)
 341  
 342  
 343      # Save the function as a model
 344      with mlflow.start_run():
 345          mlflow.pyfunc.log_model(name="model", python_model=predict, pip_requirements=["pandas"])
 346          run_id = mlflow.active_run().info.run_id
 347  
 348      # Load the model from the tracking server and perform inference
 349      model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model")
 350      x_new = pd.Series([1, 2, 3])
 351  
 352      prediction = model.predict(x_new)
 353      print(prediction)
 354  
 355  Class-based Model
 356  #################
 357  If you're looking to serialize a more complex object, for instance a class that handles
 358  preprocessing, complex prediction logic, or custom serialization, you should subclass the
 359  ``PythonModel`` class. MLflow has tutorials on building custom PyFunc models, as shown
 360  `here <https://mlflow.org/docs/latest/traditional-ml/creating-custom-pyfunc/index.html>`_,
 361  so instead of duplicating that information, in this example we'll recreate the above functionality
 362  to highlight the differences. Note that this PythonModel implementation is overly complex and
 363  we would recommend using the functional-based Model instead for this simple case.
 364  
 365  .. code-block:: python
 366  
 367      import mlflow
 368      import pandas as pd
 369  
 370  
 371      class MyModel(mlflow.pyfunc.PythonModel):
 372          def predict(self, context, model_input, params=None):
 373              return [x * 2 for x in model_input]
 374  
 375  
 376      # Save the function as a model
 377      with mlflow.start_run():
 378          mlflow.pyfunc.log_model(name="model", python_model=MyModel(), pip_requirements=["pandas"])
 379          run_id = mlflow.active_run().info.run_id
 380  
 381      # Load the model from the tracking server and perform inference
 382      model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model")
 383      x_new = pd.Series([1, 2, 3])
 384  
 385      print(f"Prediction:\n\t{model.predict(x_new)}")
 386  
 387  The primary difference between the this implementation and the function-based implementation above
 388  is that the predict method is wrapped with a class, has the ``self`` parameter,
 389  and has the ``params`` parameter that defaults to None. Note that function-based models don't
 390  support additional params.
 391  
 392  In summary, use the function-based Model when you have a simple function to serialize.
 393  If you need more power, use  the class-based model.
 394  """
 395  
 396  import collections
 397  import functools
 398  import hashlib
 399  import importlib
 400  import inspect
 401  import json
 402  import logging
 403  import os
 404  import shutil
 405  import signal
 406  import subprocess
 407  import sys
 408  import tempfile
 409  import threading
 410  import uuid
 411  from copy import deepcopy
 412  from pathlib import Path
 413  from typing import Any, Iterator, Tuple, Union
 414  from urllib.parse import urlparse
 415  
 416  import numpy as np
 417  import pandas
 418  import pydantic
 419  import yaml
 420  from packaging.version import Version
 421  
 422  import mlflow
 423  import mlflow.models.signature
 424  import mlflow.pyfunc.loaders
 425  import mlflow.pyfunc.model
 426  from mlflow.entities.model_registry.prompt import Prompt
 427  from mlflow.environment_variables import (
 428      _MLFLOW_IN_CAPTURE_MODULE_PROCESS,
 429      _MLFLOW_SPARK_UDF_SERVERLESS_SKIP_DBCONNECT_ARTIFACT,
 430      _MLFLOW_TESTING,
 431      MLFLOW_DISABLE_SCHEMA_DETAILS,
 432      MLFLOW_ENFORCE_STDIN_SCORING_SERVER_FOR_SPARK_UDF,
 433      MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR,
 434      MLFLOW_SCORING_SERVER_REQUEST_TIMEOUT,
 435      MLFLOW_UV_AUTO_DETECT,
 436  )
 437  from mlflow.exceptions import MlflowException
 438  from mlflow.models import Model, ModelInputExample, ModelSignature
 439  from mlflow.models.auth_policy import AuthPolicy
 440  from mlflow.models.dependencies_schemas import (
 441      _clear_dependencies_schemas,
 442      _get_dependencies_schema_from_model,
 443      _get_dependencies_schemas,
 444  )
 445  from mlflow.models.flavor_backend_registry import get_flavor_backend
 446  from mlflow.models.model import (
 447      _DATABRICKS_FS_LOADER_MODULE,
 448      MLMODEL_FILE_NAME,
 449      MODEL_CODE_PATH,
 450      MODEL_CONFIG,
 451  )
 452  from mlflow.models.resources import Resource, _ResourceBuilder
 453  from mlflow.models.signature import (
 454      _extract_type_hints,
 455      _infer_signature_from_input_example,
 456      _infer_signature_from_type_hints,
 457  )
 458  from mlflow.models.utils import (
 459      PyFuncInput,
 460      PyFuncLLMOutputChunk,
 461      PyFuncLLMSingleInput,
 462      PyFuncOutput,
 463      _convert_llm_input_data,
 464      _enforce_params_schema,
 465      _enforce_schema,
 466      _load_model_code_path,
 467      _save_example,
 468      _split_input_data_and_params,
 469      _validate_and_get_model_code_path,
 470  )
 471  from mlflow.protos.databricks_pb2 import (
 472      BAD_REQUEST,
 473      INTERNAL_ERROR,
 474      INVALID_PARAMETER_VALUE,
 475      RESOURCE_DOES_NOT_EXIST,
 476  )
 477  from mlflow.protos.databricks_uc_registry_messages_pb2 import (
 478      Entity,
 479      Job,
 480      LineageHeaderInfo,
 481      Notebook,
 482  )
 483  from mlflow.pyfunc.context import Context, set_prediction_context
 484  from mlflow.pyfunc.dbconnect_artifact_cache import (
 485      DBConnectArtifactCache,
 486      archive_directory,
 487      extract_archive_to_dir,
 488  )
 489  from mlflow.pyfunc.model import (
 490      _DEFAULT_CHAT_AGENT_METADATA_TASK,
 491      _DEFAULT_CHAT_MODEL_METADATA_TASK,
 492      _DEFAULT_RESPONSES_AGENT_METADATA_TASK,
 493      ChatAgent,
 494      ChatModel,
 495      PythonModel,
 496      PythonModelContext,
 497      _FunctionPythonModel,
 498      _log_warning_if_params_not_in_predict_signature,
 499      _PythonModelPyfuncWrapper,
 500      get_default_conda_env,  # noqa: F401
 501      get_default_pip_requirements,
 502  )
 503  
 504  try:
 505      from mlflow.pyfunc.model import ResponsesAgent
 506  
 507      IS_RESPONSES_AGENT_AVAILABLE = True
 508  except ImportError:
 509      IS_RESPONSES_AGENT_AVAILABLE = False
 510  from mlflow.tracing.provider import trace_disabled
 511  from mlflow.tracing.utils import _try_get_prediction_context
 512  from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
 513  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
 514  from mlflow.types.agent import (
 515      CHAT_AGENT_INPUT_EXAMPLE,
 516      CHAT_AGENT_INPUT_SCHEMA,
 517      CHAT_AGENT_OUTPUT_SCHEMA,
 518      ChatAgentRequest,
 519      ChatAgentResponse,
 520  )
 521  from mlflow.types.llm import (
 522      CHAT_MODEL_INPUT_EXAMPLE,
 523      CHAT_MODEL_INPUT_SCHEMA,
 524      CHAT_MODEL_OUTPUT_SCHEMA,
 525      ChatCompletionResponse,
 526      ChatMessage,
 527      ChatParams,
 528  )
 529  from mlflow.types.type_hints import (
 530      _convert_dataframe_to_example_format,
 531      _is_example_valid_for_type_from_example,
 532      _is_type_hint_from_example,
 533      _signature_cannot_be_inferred_from_type_hint,
 534      model_validate,
 535  )
 536  from mlflow.utils import (
 537      PYTHON_VERSION,
 538      _is_in_ipython_notebook,
 539      check_port_connectivity,
 540      databricks_utils,
 541      find_free_port,
 542      get_major_minor_py_version,
 543  )
 544  from mlflow.utils import env_manager as _EnvManager
 545  from mlflow.utils._spark_utils import modified_environ
 546  from mlflow.utils.annotations import deprecated, developer_stable
 547  from mlflow.utils.databricks_utils import (
 548      _get_databricks_serverless_env_vars,
 549      get_dbconnect_udf_sandbox_info,
 550      is_databricks_connect,
 551      is_in_databricks_runtime,
 552      is_in_databricks_serverless_runtime,
 553      is_in_databricks_shared_cluster_runtime,
 554  )
 555  from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring
 556  from mlflow.utils.environment import (
 557      _CONDA_ENV_FILE_NAME,
 558      _CONSTRAINTS_FILE_NAME,
 559      _PYTHON_ENV_FILE_NAME,
 560      _REQUIREMENTS_FILE_NAME,
 561      _process_conda_env,
 562      _process_pip_requirements,
 563      _PythonEnv,
 564      _validate_env_arguments,
 565  )
 566  from mlflow.utils.file_utils import (
 567      _copy_file_or_tree,
 568      get_or_create_nfs_tmp_dir,
 569      get_or_create_tmp_dir,
 570      get_total_file_size,
 571      write_to,
 572  )
 573  from mlflow.utils.mlflow_tags import MLFLOW_MODEL_IS_EXTERNAL
 574  from mlflow.utils.model_utils import (
 575      _add_code_from_conf_to_system_path,
 576      _get_flavor_configuration,
 577      _get_flavor_configuration_from_ml_model_file,
 578      _get_overridden_pyfunc_model_config,
 579      _validate_and_copy_file_to_directory,
 580      _validate_and_get_model_config_from_file,
 581      _validate_and_prepare_target_save_path,
 582      _validate_infer_and_copy_code_paths,
 583      _validate_pyfunc_model_config,
 584  )
 585  from mlflow.utils.nfs_on_spark import get_nfs_cache_root_dir
 586  from mlflow.utils.requirements_utils import (
 587      _parse_requirements,
 588      warn_dependency_requirement_mismatches,
 589  )
 590  from mlflow.utils.spark_utils import is_spark_connect_mode
 591  from mlflow.utils.uv_utils import copy_uv_project_files
 592  from mlflow.utils.virtualenv import _get_python_env, _get_virtualenv_name
 593  from mlflow.utils.warnings_utils import color_warning
 594  
 595  try:
 596      from pyspark.sql import DataFrame as SparkDataFrame
 597  
 598      HAS_PYSPARK = True
 599  except ImportError:
 600      HAS_PYSPARK = False
 601  FLAVOR_NAME = "python_function"
 602  MAIN = "loader_module"
 603  CODE = "code"
 604  DATA = "data"
 605  ENV = "env"
 606  TASK = "task"
 607  
 608  _MODEL_DATA_SUBPATH = "data"
 609  _CHAT_PARAMS_WARNING_MESSAGE = (
 610      "Default values for temperature, n and stream in ChatParams will be removed in the "
 611      "next release. Specify them in the input example explicitly if needed."
 612  )
 613  _TYPE_FROM_EXAMPLE_ERROR_MESSAGE = (
 614      "Input example must be provided when using TypeFromExample as type hint. "
 615      "Fix this by passing `input_example` when logging your model. Check "
 616      "https://mlflow.org/docs/latest/model/python_model.html#typefromexample-type-hint-usage "
 617      "for more details."
 618  )
 619  
 620  
 621  class EnvType:
 622      CONDA = "conda"
 623      VIRTUALENV = "virtualenv"
 624  
 625      def __init__(self):
 626          raise NotImplementedError("This class is not meant to be instantiated.")
 627  
 628  
 629  PY_VERSION = "python_version"
 630  
 631  _logger = logging.getLogger(__name__)
 632  
 633  
 634  def add_to_model(
 635      model,
 636      loader_module,
 637      data=None,
 638      code=None,
 639      conda_env=None,
 640      python_env=None,
 641      model_config=None,
 642      model_code_path=None,
 643      **kwargs,
 644  ):
 645      """
 646      Add a ``pyfunc`` spec to the model configuration.
 647  
 648      Defines ``pyfunc`` configuration schema. Caller can use this to create a valid ``pyfunc`` model
 649      flavor out of an existing directory structure. For example, other model flavors can use this to
 650      specify how to use their output as a ``pyfunc``.
 651  
 652      NOTE:
 653  
 654          All paths are relative to the exported model root directory.
 655  
 656      Args:
 657          model: Existing model.
 658          loader_module: The module to be used to load the model.
 659          data: Path to the model data.
 660          code: Path to the code dependencies.
 661          conda_env: Conda environment.
 662          python_env: Python environment.
 663          model_config: The model configuration to apply to the model. This configuration
 664              is available during model loading.
 665  
 666              .. Note:: Experimental: This parameter may change or be removed in a future
 667                  release without warning.
 668  
 669          model_code_path: Path to the model code.
 670          kwargs: Additional key-value pairs to include in the ``pyfunc`` flavor specification.
 671                  Values must be YAML-serializable.
 672  
 673      Returns:
 674          Updated model configuration.
 675      """
 676      params = deepcopy(kwargs)
 677      params[MAIN] = loader_module
 678      params[PY_VERSION] = PYTHON_VERSION
 679      if code:
 680          params[CODE] = code
 681      if data:
 682          params[DATA] = data
 683      if conda_env or python_env:
 684          params[ENV] = {}
 685          if conda_env:
 686              params[ENV][EnvType.CONDA] = conda_env
 687          if python_env:
 688              params[ENV][EnvType.VIRTUALENV] = python_env
 689      if model_config:
 690          params[MODEL_CONFIG] = model_config
 691      if model_code_path:
 692          params[MODEL_CODE_PATH] = model_code_path
 693      return model.add_flavor(FLAVOR_NAME, **params)
 694  
 695  
 696  def _extract_conda_env(env):
 697      # In MLflow < 2.0.0, the 'env' field in a pyfunc configuration is a string containing the path
 698      # to a conda.yaml file.
 699      return env if isinstance(env, str) else env[EnvType.CONDA]
 700  
 701  
 702  def _load_model_env(path):
 703      """
 704      Get ENV file string from a model configuration stored in Python Function format.
 705      Returned value is a model-relative path to a Conda Environment file,
 706      or None if none was specified at model save time
 707      """
 708      return _get_flavor_configuration(model_path=path, flavor_name=FLAVOR_NAME).get(ENV, None)
 709  
 710  
 711  def _validate_params(params, model_metadata):
 712      if hasattr(model_metadata, "get_params_schema"):
 713          params_schema = model_metadata.get_params_schema()
 714          return _enforce_params_schema(params, params_schema)
 715      if params:
 716          raise MlflowException.invalid_parameter_value(
 717              "This model was not logged with a params schema and does not support "
 718              "providing the params argument."
 719              "Please log the model with mlflow >= 2.6.0 and specify a params schema.",
 720          )
 721      return
 722  
 723  
 724  def _validate_prediction_input(data: PyFuncInput, params, input_schema, params_schema, flavor=None):
 725      """
 726      Internal helper function to transform and validate input data and params for prediction.
 727      Any additional transformation logics related to input data and params should be added here.
 728      """
 729      if input_schema is not None:
 730          try:
 731              data = _enforce_schema(data, input_schema, flavor)
 732          except Exception as e:
 733              if MLFLOW_DISABLE_SCHEMA_DETAILS.get():
 734                  message = "Failed to enforce model input schema. Please check your input data."
 735              else:
 736                  # Include error in message for backwards compatibility
 737                  message = (
 738                      f"Failed to enforce schema of data '{data}' "
 739                      f"with schema '{input_schema}'. "
 740                      f"Error: {e}"
 741                  )
 742              # error_code is INVALID_PARAMETER_VALUE but this is a schema enforcement failure
 743              raise MlflowException.invalid_parameter_value(
 744                  message, error_class="SCHEMA_ENFORCEMENT_FAILED"
 745              )
 746      params = _enforce_params_schema(params, params_schema)
 747      if HAS_PYSPARK and isinstance(data, SparkDataFrame):
 748          _logger.warning(
 749              "Input data is a Spark DataFrame. Note that behaviour for "
 750              "Spark DataFrames is model dependent."
 751          )
 752      return data, params
 753  
 754  
 755  class PyFuncModel:
 756      """
 757      MLflow 'python function' model.
 758  
 759      Wrapper around model implementation and metadata. This class is not meant to be constructed
 760      directly. Instead, instances of this class are constructed and returned from
 761      :py:func:`load_model() <mlflow.pyfunc.load_model>`.
 762  
 763      ``model_impl`` can be any Python object that implements the `Pyfunc interface
 764      <https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#pyfunc-inference-api>`_, and is
 765      returned by invoking the model's ``loader_module``.
 766  
 767      ``model_meta`` contains model metadata loaded from the MLmodel file.
 768      """
 769  
 770      def __init__(
 771          self,
 772          model_meta: Model,
 773          model_impl: Any,
 774          predict_fn: str = "predict",
 775          predict_stream_fn: str | None = None,
 776          model_id: str | None = None,
 777      ):
 778          if not hasattr(model_impl, predict_fn):
 779              raise MlflowException(f"Model implementation is missing required {predict_fn} method.")
 780          if not model_meta:
 781              raise MlflowException("Model is missing metadata.")
 782          self._model_meta = model_meta
 783          self.__model_impl = model_impl
 784          self._predict_fn = getattr(model_impl, predict_fn)
 785          if predict_stream_fn:
 786              if not hasattr(model_impl, predict_stream_fn):
 787                  raise MlflowException(
 788                      f"Model implementation is missing required {predict_stream_fn} method."
 789                  )
 790              self._predict_stream_fn = getattr(model_impl, predict_stream_fn)
 791          else:
 792              self._predict_stream_fn = None
 793          self._model_id = model_id
 794          self._input_example = None
 795  
 796      @property
 797      @developer_stable
 798      def _model_impl(self) -> Any:
 799          """
 800          The underlying model implementation object.
 801  
 802          NOTE: This is a stable developer API.
 803          """
 804          return self.__model_impl
 805  
 806      @property
 807      def model_id(self) -> str | None:
 808          """
 809          The model ID of the model.
 810  
 811          Returns:
 812              The model ID of the model.
 813          """
 814          return self._model_id
 815  
 816      def _update_dependencies_schemas_in_prediction_context(self, context: Context):
 817          if self._model_meta and self._model_meta.metadata:
 818              dependencies_schemas = self._model_meta.metadata.get("dependencies_schemas", {})
 819              context.update(
 820                  dependencies_schemas={
 821                      dependency: json.dumps(schema)
 822                      for dependency, schema in dependencies_schemas.items()
 823                  }
 824              )
 825  
 826      @property
 827      def input_example(self) -> Any | None:
 828          """
 829          The input example provided when the model was saved.
 830          """
 831          return self._input_example
 832  
 833      @input_example.setter
 834      def input_example(self, value: Any) -> None:
 835          self._input_example = value
 836  
 837      def predict(self, data: PyFuncInput, params: dict[str, Any] | None = None) -> PyFuncOutput:
 838          context = _try_get_prediction_context() or Context()
 839          with set_prediction_context(context):
 840              if schema := _get_dependencies_schema_from_model(self._model_meta):
 841                  context.update(**schema)
 842  
 843              if self.model_id:
 844                  context.update(model_id=self.model_id)
 845              return self._predict(data, params)
 846  
 847      def _predict(self, data: PyFuncInput, params: dict[str, Any] | None = None) -> PyFuncOutput:
 848          """
 849          Generates model predictions.
 850  
 851          If the model contains signature, enforce the input schema first before calling the model
 852          implementation with the sanitized input. If the pyfunc model does not include model schema,
 853          the input is passed to the model implementation as is. See `Model Signature Enforcement
 854          <https://www.mlflow.org/docs/latest/models.html#signature-enforcement>`_ for more details.
 855  
 856          Args:
 857              data: LLM Model single input as one of pandas.DataFrame, numpy.ndarray,
 858                  scipy.sparse.(csc_matrix | csr_matrix), List[Any], or
 859                  Dict[str, numpy.ndarray].
 860                  For model signatures with tensor spec inputs
 861                  (e.g. the Tensorflow core / Keras model), the input data type must be one of
 862                  `numpy.ndarray`, `List[numpy.ndarray]`, `Dict[str, numpy.ndarray]` or
 863                  `pandas.DataFrame`. If data is of `pandas.DataFrame` type and the model
 864                  contains a signature with tensor spec inputs, the corresponding column values
 865                  in the pandas DataFrame will be reshaped to the required shape with 'C' order
 866                  (i.e. read / write the elements using C-like index order), and DataFrame
 867                  column values will be cast as the required tensor spec type. For Pyspark
 868                  DataFrame inputs, MLflow will only enforce the schema on a subset
 869                  of the data rows.
 870              params: Additional parameters to pass to the model for inference.
 871  
 872          Returns:
 873              Model predictions as one of pandas.DataFrame, pandas.Series, numpy.ndarray or list.
 874          """
 875          # fetch the schema from metadata to avoid signature change after model is loaded
 876          self.input_schema = self.metadata.get_input_schema()
 877          self.params_schema = self.metadata.get_params_schema()
 878          # signature can only be inferred from type hints if the model is PythonModel
 879          if self.metadata._is_signature_from_type_hint():
 880              # we don't need to validate on data as data validation
 881              # will be done during PythonModel's predict call
 882              params = _enforce_params_schema(params, self.params_schema)
 883          else:
 884              data, params = _validate_prediction_input(
 885                  data, params, self.input_schema, self.params_schema, self.loader_module
 886              )
 887              if (
 888                  isinstance(data, pandas.DataFrame)
 889                  and self.metadata._is_type_hint_from_example()
 890                  and self.input_example is not None
 891              ):
 892                  data = _convert_dataframe_to_example_format(data, self.input_example)
 893          params_arg = inspect.signature(self._predict_fn).parameters.get("params")
 894          if params_arg and params_arg.kind != inspect.Parameter.VAR_KEYWORD:
 895              return self._predict_fn(data, params=params)
 896  
 897          _log_warning_if_params_not_in_predict_signature(_logger, params)
 898          return self._predict_fn(data)
 899  
 900      def predict_stream(
 901          self, data: PyFuncLLMSingleInput, params: dict[str, Any] | None = None
 902      ) -> Iterator[PyFuncLLMOutputChunk]:
 903          context = _try_get_prediction_context() or Context()
 904  
 905          if schema := _get_dependencies_schema_from_model(self._model_meta):
 906              context.update(**schema)
 907  
 908          if self.model_id:
 909              context.update(model_id=self.model_id)
 910  
 911          # NB: The prediction context must be applied during iterating over the stream,
 912          # hence, simply wrapping the self._predict_stream call with the context manager
 913          # is not sufficient.
 914          def _gen_with_context(*args, **kwargs):
 915              with set_prediction_context(context):
 916                  yield from self._predict_stream(*args, **kwargs)
 917  
 918          return _gen_with_context(data, params)
 919  
 920      def _predict_stream(
 921          self, data: PyFuncLLMSingleInput, params: dict[str, Any] | None = None
 922      ) -> Iterator[PyFuncLLMOutputChunk]:
 923          """
 924          Generates streaming model predictions. Only LLM supports this method.
 925  
 926          If the model contains signature, enforce the input schema first before calling the model
 927          implementation with the sanitized input. If the pyfunc model does not include model schema,
 928          the input is passed to the model implementation as is. See `Model Signature Enforcement
 929          <https://www.mlflow.org/docs/latest/models.html#signature-enforcement>`_ for more details.
 930  
 931          Args:
 932              data: LLM Model single input as one of dict, str, bool, bytes, float, int, str type.
 933              params: Additional parameters to pass to the model for inference.
 934  
 935          Returns:
 936              Model predictions as an iterator of chunks. The chunks in the iterator must be type of
 937              dict or string. Chunk dict fields are determined by the model implementation.
 938          """
 939  
 940          if self._predict_stream_fn is None:
 941              raise MlflowException("This model does not support predict_stream method.")
 942  
 943          self.input_schema = self.metadata.get_input_schema()
 944          self.params_schema = self.metadata.get_params_schema()
 945          data, params = _validate_prediction_input(
 946              data, params, self.input_schema, self.params_schema, self.loader_module
 947          )
 948          data = _convert_llm_input_data(data)
 949          if isinstance(data, list):
 950              # `predict_stream` only accepts single input.
 951              # but `enforce_schema` might convert single input into a list like `[single_input]`
 952              # so extract the first element in the list.
 953              if len(data) != 1:
 954                  raise MlflowException(
 955                      f"'predict_stream' requires single input, but it got input data {data}"
 956                  )
 957              data = data[0]
 958  
 959          if "params" in inspect.signature(self._predict_stream_fn).parameters:
 960              return self._predict_stream_fn(data, params=params)
 961  
 962          _log_warning_if_params_not_in_predict_signature(_logger, params)
 963          return self._predict_stream_fn(data)
 964  
 965      def unwrap_python_model(self):
 966          """
 967          Unwrap the underlying Python model object.
 968  
 969          This method is useful for accessing custom model functions, while still being able to
 970          leverage the MLflow designed workflow through the `predict()` method.
 971  
 972          Returns:
 973              The underlying wrapped model object
 974  
 975          .. code-block:: python
 976              :test:
 977              :caption: Example
 978  
 979              import mlflow
 980  
 981  
 982              # define a custom model
 983              class MyModel(mlflow.pyfunc.PythonModel):
 984                  def predict(self, context, model_input, params=None):
 985                      return self.my_custom_function(model_input, params)
 986  
 987                  def my_custom_function(self, model_input, params=None):
 988                      # do something with the model input
 989                      return 0
 990  
 991  
 992              some_input = 1
 993              # save the model
 994              with mlflow.start_run():
 995                  model_info = mlflow.pyfunc.log_model(name="model", python_model=MyModel())
 996  
 997              # load the model
 998              loaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
 999              print(type(loaded_model))  # <class 'mlflow.pyfunc.model.PyFuncModel'>
1000              unwrapped_model = loaded_model.unwrap_python_model()
1001              print(type(unwrapped_model))  # <class '__main__.MyModel'>
1002  
1003              # does not work, only predict() is exposed
1004              # print(loaded_model.my_custom_function(some_input))
1005              print(unwrapped_model.my_custom_function(some_input))  # works
1006              print(loaded_model.predict(some_input))  # works
1007  
1008              # works, but None is needed for context arg
1009              print(unwrapped_model.predict(None, some_input))
1010          """
1011          try:
1012              python_model = self._model_impl.python_model
1013              if python_model is None:
1014                  raise AttributeError("Expected python_model attribute not to be None.")
1015          except AttributeError as e:
1016              raise MlflowException("Unable to retrieve base model object from pyfunc.") from e
1017          return python_model
1018  
1019      def __eq__(self, other):
1020          if not isinstance(other, PyFuncModel):
1021              return False
1022          return self._model_meta == other._model_meta
1023  
1024      @property
1025      def metadata(self) -> Model:
1026          """Model metadata."""
1027          if self._model_meta is None:
1028              raise MlflowException("Model is missing metadata.")
1029          return self._model_meta
1030  
1031      @property
1032      def model_config(self):
1033          """Model's flavor configuration"""
1034          return self._model_meta.flavors[FLAVOR_NAME].get(MODEL_CONFIG, {})
1035  
1036      @property
1037      def loader_module(self):
1038          """Model's flavor configuration"""
1039          if self._model_meta.flavors.get(FLAVOR_NAME) is None:
1040              return None
1041          return self._model_meta.flavors[FLAVOR_NAME].get(MAIN)
1042  
1043      def __repr__(self):
1044          info = {}
1045          if self._model_meta is not None:
1046              if hasattr(self._model_meta, "run_id") and self._model_meta.run_id is not None:
1047                  info["run_id"] = self._model_meta.run_id
1048              if (
1049                  hasattr(self._model_meta, "artifact_path")
1050                  and self._model_meta.artifact_path is not None
1051              ):
1052                  info["artifact_path"] = self._model_meta.artifact_path
1053              info["flavor"] = self._model_meta.flavors[FLAVOR_NAME]["loader_module"]
1054          return yaml.safe_dump({"mlflow.pyfunc.loaded_model": info}, default_flow_style=False)
1055  
1056      def get_raw_model(self):
1057          """
1058          Get the underlying raw model if the model wrapper implemented `get_raw_model` function.
1059          """
1060          if hasattr(self._model_impl, "get_raw_model"):
1061              return self._model_impl.get_raw_model()
1062          raise NotImplementedError("`get_raw_model` is not implemented by the underlying model")
1063  
1064  
1065  def _get_pip_requirements_from_model_path(model_path: str):
1066      req_file_path = os.path.join(model_path, _REQUIREMENTS_FILE_NAME)
1067      if not os.path.exists(req_file_path):
1068          return []
1069  
1070      return [req.req_str for req in _parse_requirements(req_file_path, is_constraint=False)]
1071  
1072  
1073  @trace_disabled  # Suppress traces while loading model
1074  def load_model(
1075      model_uri: str,
1076      suppress_warnings: bool = False,
1077      dst_path: str | None = None,
1078      model_config: str | Path | dict[str, Any] | None = None,
1079  ) -> PyFuncModel:
1080      """
1081      Load a model stored in Python function format.
1082  
1083      Args:
1084          model_uri: The location, in URI format, of the MLflow model. For example:
1085  
1086              - ``/Users/me/path/to/local/model``
1087              - ``relative/path/to/local/model``
1088              - ``s3://my_bucket/path/to/model``
1089              - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
1090              - ``models:/<model_name>/<model_version>``
1091              - ``models:/<model_name>/<stage>``
1092              - ``mlflow-artifacts:/path/to/model``
1093  
1094              For more information about supported URI schemes, see
1095              `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
1096              artifact-locations>`_.
1097          suppress_warnings: If ``True``, non-fatal warning messages associated with the model
1098              loading process will be suppressed. If ``False``, these warning messages will be
1099              emitted.
1100          dst_path: The local filesystem path to which to download the model artifact.
1101              This directory must already exist. If unspecified, a local output
1102              path will be created.
1103          model_config: The model configuration to apply to the model. The configuration will
1104              be available as the ``model_config`` property of the ``context`` parameter
1105              in :func:`PythonModel.load_context() <mlflow.pyfunc.PythonModel.load_context>`
1106              and :func:`PythonModel.predict() <mlflow.pyfunc.PythonModel.predict>`.
1107              The configuration can be passed as a file path, or a dict with string keys.
1108  
1109              .. Note:: Experimental: This parameter may change or be removed in a future
1110                  release without warning.
1111      """
1112  
1113      lineage_header_info = None
1114      if (
1115          not _MLFLOW_IN_CAPTURE_MODULE_PROCESS.get()
1116      ) and databricks_utils.is_in_databricks_runtime():
1117          entity_list = []
1118          # Get notebook id and job id, pack them into lineage_header_info
1119          if databricks_utils.is_in_databricks_notebook() and (
1120              notebook_id := databricks_utils.get_notebook_id()
1121          ):
1122              notebook_entity = Notebook(id=notebook_id)
1123              entity_list.append(Entity(notebook=notebook_entity))
1124  
1125          if databricks_utils.is_in_databricks_job() and (job_id := databricks_utils.get_job_id()):
1126              job_entity = Job(id=job_id)
1127              entity_list.append(Entity(job=job_entity))
1128  
1129          lineage_header_info = LineageHeaderInfo(entities=entity_list) if entity_list else None
1130  
1131      local_path = _download_artifact_from_uri(
1132          artifact_uri=model_uri, output_path=dst_path, lineage_header_info=lineage_header_info
1133      )
1134  
1135      if not suppress_warnings:
1136          model_requirements = _get_pip_requirements_from_model_path(local_path)
1137          warn_dependency_requirement_mismatches(model_requirements)
1138  
1139      model_meta = Model.load(os.path.join(local_path, MLMODEL_FILE_NAME))
1140  
1141      if model_meta.metadata and model_meta.metadata.get(MLFLOW_MODEL_IS_EXTERNAL, False) is True:
1142          raise MlflowException(
1143              "This model's artifacts are external and are not stored in the model directory."
1144              " This model cannot be loaded with MLflow.",
1145              BAD_REQUEST,
1146          )
1147  
1148      conf = model_meta.flavors.get(FLAVOR_NAME)
1149      if conf is None:
1150          raise MlflowException(
1151              f'Model does not have the "{FLAVOR_NAME}" flavor',
1152              RESOURCE_DOES_NOT_EXIST,
1153          )
1154      model_py_version = conf.get(PY_VERSION)
1155      if not suppress_warnings:
1156          _warn_potentially_incompatible_py_version_if_necessary(model_py_version=model_py_version)
1157  
1158      _add_code_from_conf_to_system_path(local_path, conf, code_key=CODE)
1159      data_path = os.path.join(local_path, conf[DATA]) if (DATA in conf) else local_path
1160  
1161      if isinstance(model_config, str):
1162          model_config = _validate_and_get_model_config_from_file(model_config)
1163  
1164      model_config = _get_overridden_pyfunc_model_config(
1165          conf.get(MODEL_CONFIG, None), model_config, _logger
1166      )
1167  
1168      try:
1169          if model_config:
1170              model_impl = importlib.import_module(conf[MAIN])._load_pyfunc(data_path, model_config)
1171          else:
1172              model_impl = importlib.import_module(conf[MAIN])._load_pyfunc(data_path)
1173      except ModuleNotFoundError as e:
1174          # This error message is particularly for the case when the error is caused by module
1175          # "databricks.feature_store.mlflow_model". But depending on the environment, the offending
1176          # module might be "databricks", "databricks.feature_store" or full package. So we will
1177          # raise the error with the following note if "databricks" presents in the error. All non-
1178          # databricks module errors will just be re-raised.
1179          if conf[MAIN] == _DATABRICKS_FS_LOADER_MODULE and e.name.startswith("databricks"):
1180              raise MlflowException(
1181                  f"{e.msg}; "
1182                  "Note: mlflow.pyfunc.load_model is not supported for Feature Store models. "
1183                  "spark_udf() and predict() will not work as expected. Use "
1184                  "score_batch for offline predictions.",
1185                  BAD_REQUEST,
1186              ) from None
1187          raise e
1188      finally:
1189          # clean up the dependencies schema which is set to global state after loading the model.
1190          # This avoids the schema being used by other models loaded in the same process.
1191          _clear_dependencies_schemas()
1192      predict_fn = conf.get("predict_fn", "predict")
1193      streamable = conf.get("streamable", False)
1194      predict_stream_fn = conf.get("predict_stream_fn", "predict_stream") if streamable else None
1195  
1196      pyfunc_model = PyFuncModel(
1197          model_meta=model_meta,
1198          model_impl=model_impl,
1199          predict_fn=predict_fn,
1200          predict_stream_fn=predict_stream_fn,
1201          model_id=model_meta.model_id,
1202      )
1203  
1204      try:
1205          model_input_example = model_meta.load_input_example(path=local_path)
1206          pyfunc_model.input_example = model_input_example
1207      except Exception as e:
1208          _logger.debug(f"Failed to load input example from model metadata: {e}.")
1209  
1210      return pyfunc_model
1211  
1212  
1213  class _ServedPyFuncModel(PyFuncModel):
1214      def __init__(self, model_meta: Model, client: Any, server_pid: int, env_manager="local"):
1215          super().__init__(model_meta=model_meta, model_impl=client, predict_fn="invoke")
1216          self._client = client
1217          self._server_pid = server_pid
1218          # We need to set `env_manager` attribute because it is used by Databricks runtime
1219          # evaluate usage logging to log 'env_manager' tag in `_evaluate` function patching.
1220          self._env_manager = env_manager
1221  
1222      def predict(self, data, params=None):
1223          """
1224          Args:
1225              data: Model input data.
1226              params: Additional parameters to pass to the model for inference.
1227  
1228          Returns:
1229              Model predictions.
1230          """
1231          if "params" in inspect.signature(self._client.invoke).parameters:
1232              result = self._client.invoke(data, params=params).get_predictions()
1233          else:
1234              _log_warning_if_params_not_in_predict_signature(_logger, params)
1235              result = self._client.invoke(data).get_predictions()
1236          if isinstance(result, pandas.DataFrame):
1237              result = result[result.columns[0]]
1238          return result
1239  
1240      @property
1241      def pid(self):
1242          if self._server_pid is None:
1243              raise MlflowException("Served PyFunc Model is missing server process ID.")
1244          return self._server_pid
1245  
1246      @property
1247      def env_manager(self):
1248          return self._env_manager
1249  
1250      @env_manager.setter
1251      def env_manager(self, value):
1252          self._env_manager = value
1253  
1254  
1255  def _load_model_or_server(
1256      model_uri: str, env_manager: str, model_config: dict[str, Any] | None = None
1257  ):
1258      """
1259      Load a model with env restoration. If a non-local ``env_manager`` is specified, prepare an
1260      independent Python environment with the training time dependencies of the specified model
1261      installed and start a MLflow Model Scoring Server process with that model in that environment.
1262      Return a _ServedPyFuncModel that invokes the scoring server for prediction. Otherwise, load and
1263      return the model locally as a PyFuncModel using :py:func:`mlflow.pyfunc.load_model`.
1264  
1265      Args:
1266          model_uri: The uri of the model.
1267          env_manager: The environment manager to load the model.
1268          model_config: The model configuration to use by the model, only if the model
1269                        accepts it.
1270  
1271      Returns:
1272          A _ServedPyFuncModel for non-local ``env_manager``s or a PyFuncModel otherwise.
1273      """
1274      from mlflow.pyfunc.scoring_server.client import (
1275          ScoringServerClient,
1276          StdinScoringServerClient,
1277      )
1278  
1279      if env_manager == _EnvManager.LOCAL:
1280          return load_model(model_uri, model_config=model_config)
1281  
1282      _logger.info("Starting model server for model environment restoration.")
1283  
1284      local_path = _download_artifact_from_uri(artifact_uri=model_uri)
1285      model_meta = Model.load(os.path.join(local_path, MLMODEL_FILE_NAME))
1286  
1287      is_port_connectable = check_port_connectivity()
1288      pyfunc_backend = get_flavor_backend(
1289          local_path,
1290          env_manager=env_manager,
1291          install_mlflow=os.environ.get("MLFLOW_HOME") is not None,
1292          create_env_root_dir=not is_port_connectable,
1293      )
1294      _logger.info("Restoring model environment. This can take a few minutes.")
1295      # Set capture_output to True in Databricks so that when environment preparation fails, the
1296      # exception message of the notebook cell output will include child process command execution
1297      # stdout/stderr output.
1298      pyfunc_backend.prepare_env(model_uri=local_path, capture_output=is_in_databricks_runtime())
1299      if is_port_connectable:
1300          server_port = find_free_port()
1301          scoring_server_proc = pyfunc_backend.serve(
1302              model_uri=local_path,
1303              port=server_port,
1304              host="127.0.0.1",
1305              timeout=MLFLOW_SCORING_SERVER_REQUEST_TIMEOUT.get(),
1306              enable_mlserver=False,
1307              synchronous=False,
1308              stdout=subprocess.PIPE,
1309              stderr=subprocess.STDOUT,
1310              model_config=model_config,
1311          )
1312          client = ScoringServerClient("127.0.0.1", server_port)
1313      else:
1314          scoring_server_proc = pyfunc_backend.serve_stdin(local_path, model_config=model_config)
1315          client = StdinScoringServerClient(scoring_server_proc)
1316  
1317      _logger.info(f"Scoring server process started at PID: {scoring_server_proc.pid}")
1318      try:
1319          client.wait_server_ready(timeout=90, scoring_server_proc=scoring_server_proc)
1320      except Exception as e:
1321          if scoring_server_proc.poll() is None:
1322              # the scoring server is still running but client can't connect to it.
1323              # kill the server.
1324              scoring_server_proc.kill()
1325          server_output, _ = scoring_server_proc.communicate(timeout=15)
1326          if isinstance(server_output, bytes):
1327              server_output = server_output.decode("UTF-8")
1328          raise MlflowException(
1329              "MLflow model server failed to launch, server process stdout and stderr are:\n"
1330              + server_output
1331          ) from e
1332  
1333      return _ServedPyFuncModel(
1334          model_meta=model_meta,
1335          client=client,
1336          server_pid=scoring_server_proc.pid,
1337          env_manager=env_manager,
1338      )
1339  
1340  
1341  def _get_model_dependencies(model_uri, format="pip"):
1342      model_dir = _download_artifact_from_uri(model_uri)
1343  
1344      def get_conda_yaml_path():
1345          model_config = _get_flavor_configuration_from_ml_model_file(
1346              os.path.join(model_dir, MLMODEL_FILE_NAME), flavor_name=FLAVOR_NAME
1347          )
1348          return os.path.join(model_dir, _extract_conda_env(model_config[ENV]))
1349  
1350      if format == "pip":
1351          requirements_file = os.path.join(model_dir, _REQUIREMENTS_FILE_NAME)
1352          if os.path.exists(requirements_file):
1353              return requirements_file
1354  
1355          _logger.info(
1356              f"{_REQUIREMENTS_FILE_NAME} is not found in the model directory. Falling back to"
1357              f" extracting pip requirements from the model's 'conda.yaml' file. Conda"
1358              " dependencies will be ignored."
1359          )
1360  
1361          with open(get_conda_yaml_path()) as yf:
1362              conda_yaml = yaml.safe_load(yf)
1363  
1364          conda_deps = conda_yaml.get("dependencies", [])
1365          for index, dep in enumerate(conda_deps):
1366              if isinstance(dep, dict) and "pip" in dep:
1367                  pip_deps_index = index
1368                  break
1369          else:
1370              raise MlflowException(
1371                  "No pip section found in conda.yaml file in the model directory.",
1372                  error_code=RESOURCE_DOES_NOT_EXIST,
1373              )
1374  
1375          pip_deps = conda_deps.pop(pip_deps_index)["pip"]
1376          tmp_dir = tempfile.mkdtemp()
1377          pip_file_path = os.path.join(tmp_dir, _REQUIREMENTS_FILE_NAME)
1378          with open(pip_file_path, "w") as f:
1379              f.write("\n".join(pip_deps) + "\n")
1380  
1381          if len(conda_deps) > 0:
1382              _logger.warning(
1383                  "The following conda dependencies have been excluded from the environment file:"
1384                  f" {', '.join(conda_deps)}."
1385              )
1386  
1387          return pip_file_path
1388  
1389      elif format == "conda":
1390          return get_conda_yaml_path()
1391      else:
1392          raise MlflowException(
1393              f"Illegal format argument '{format}'.", error_code=INVALID_PARAMETER_VALUE
1394          )
1395  
1396  
1397  def get_model_dependencies(model_uri, format="pip"):
1398      """
1399      Downloads the model dependencies and returns the path to requirements.txt or conda.yaml file.
1400  
1401      .. warning::
1402          This API downloads all the model artifacts to the local filesystem. This may take
1403          a long time for large models. To avoid this overhead, use
1404          ``mlflow.artifacts.download_artifacts("<model_uri>/requirements.txt")`` or
1405          ``mlflow.artifacts.download_artifacts("<model_uri>/conda.yaml")`` instead.
1406  
1407      Args:
1408          model_uri: The uri of the model to get dependencies from.
1409          format: The format of the returned dependency file. If the ``"pip"`` format is
1410              specified, the path to a pip ``requirements.txt`` file is returned.
1411              If the ``"conda"`` format is specified, the path to a ``"conda.yaml"``
1412              file is returned . If the ``"pip"`` format is specified but the model
1413              was not saved with a ``requirements.txt`` file, the ``pip`` section
1414              of the model's ``conda.yaml`` file is extracted instead, and any
1415              additional conda dependencies are ignored. Default value is ``"pip"``.
1416  
1417      Returns:
1418          The local filesystem path to either a pip ``requirements.txt`` file
1419          (if ``format="pip"``) or a ``conda.yaml`` file (if ``format="conda"``)
1420          specifying the model's dependencies.
1421      """
1422      dep_file = _get_model_dependencies(model_uri, format)
1423  
1424      if format == "pip":
1425          prefix = "%" if _is_in_ipython_notebook() else ""
1426          _logger.info(
1427              "To install the dependencies that were used to train the model, run the "
1428              f"following command: '{prefix}pip install -r {dep_file}'."
1429          )
1430      return dep_file
1431  
1432  
1433  @deprecated("mlflow.pyfunc.load_model", 1.0)
1434  def load_pyfunc(model_uri, suppress_warnings=False):
1435      """
1436      Load a model stored in Python function format.
1437  
1438      Args:
1439          model_uri: The location, in URI format, of the MLflow model. For example:
1440  
1441              - ``/Users/me/path/to/local/model``
1442              - ``relative/path/to/local/model``
1443              - ``s3://my_bucket/path/to/model``
1444              - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
1445              - ``models:/<model_name>/<model_version>``
1446              - ``models:/<model_name>/<stage>``
1447              - ``mlflow-artifacts:/path/to/model``
1448  
1449              For more information about supported URI schemes, see
1450              `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
1451              artifact-locations>`_.
1452  
1453          suppress_warnings: If ``True``, non-fatal warning messages associated with the model
1454              loading process will be suppressed. If ``False``, these warning messages will be
1455              emitted.
1456      """
1457      return load_model(model_uri, suppress_warnings)
1458  
1459  
1460  def _warn_potentially_incompatible_py_version_if_necessary(model_py_version=None):
1461      """
1462      Compares the version of Python that was used to save a given model with the version
1463      of Python that is currently running. If a major or minor version difference is detected,
1464      logs an appropriate warning.
1465      """
1466      if model_py_version is None:
1467          _logger.warning(
1468              "The specified model does not have a specified Python version. It may be"
1469              " incompatible with the version of Python that is currently running: Python %s",
1470              PYTHON_VERSION,
1471          )
1472      elif get_major_minor_py_version(model_py_version) != get_major_minor_py_version(PYTHON_VERSION):
1473          _logger.warning(
1474              "The version of Python that the model was saved in, `Python %s`, differs"
1475              " from the version of Python that is currently running, `Python %s`,"
1476              " and may be incompatible",
1477              model_py_version,
1478              PYTHON_VERSION,
1479          )
1480  
1481  
1482  def _create_model_downloading_tmp_dir(should_use_nfs):
1483      root_tmp_dir = get_or_create_nfs_tmp_dir() if should_use_nfs else get_or_create_tmp_dir()
1484  
1485      root_model_cache_dir = os.path.join(root_tmp_dir, "models")
1486      os.makedirs(root_model_cache_dir, exist_ok=True)
1487  
1488      tmp_model_dir = tempfile.mkdtemp(dir=root_model_cache_dir)
1489      # mkdtemp creates a directory with permission 0o700
1490      # For Spark UDFs, we need to make it accessible to other processes
1491      # Use 0o750 (owner: rwx, group: r-x, others: None) instead of 0o770
1492      os.chmod(tmp_model_dir, 0o750)
1493      return tmp_model_dir
1494  
1495  
1496  _MLFLOW_SERVER_OUTPUT_TAIL_LINES_TO_KEEP = 200
1497  
1498  
1499  def _is_variant_type(spark_type):
1500      try:
1501          from pyspark.sql.types import VariantType
1502  
1503          return isinstance(spark_type, VariantType)
1504      except ImportError:
1505          return False
1506  
1507  
1508  def _convert_spec_type_to_spark_type(spec_type):
1509      from pyspark.sql.types import ArrayType, MapType, StringType, StructField, StructType
1510  
1511      from mlflow.types.schema import AnyType, Array, DataType, Map, Object
1512  
1513      if isinstance(spec_type, DataType):
1514          return spec_type.to_spark()
1515  
1516      if isinstance(spec_type, AnyType):
1517          try:
1518              from pyspark.sql.types import VariantType
1519  
1520              return VariantType()
1521          except ImportError:
1522              raise MlflowException.invalid_parameter_value(
1523                  "`AnyType` is not supported in PySpark versions older than 4.0.0. "
1524                  "Upgrade your PySpark version to use this feature.",
1525              )
1526  
1527      if isinstance(spec_type, Array):
1528          return ArrayType(_convert_spec_type_to_spark_type(spec_type.dtype))
1529  
1530      if isinstance(spec_type, Object):
1531          return StructType([
1532              StructField(
1533                  property.name,
1534                  _convert_spec_type_to_spark_type(property.dtype),
1535                  # we set nullable to True for all properties
1536                  # to avoid some errors like java.lang.NullPointerException
1537                  # when the signature is not inferred based on correct data.
1538              )
1539              for property in spec_type.properties
1540          ])
1541  
1542      # Map only supports string as key
1543      if isinstance(spec_type, Map):
1544          return MapType(
1545              keyType=StringType(), valueType=_convert_spec_type_to_spark_type(spec_type.value_type)
1546          )
1547  
1548      raise MlflowException(f"Failed to convert schema type `{spec_type}` to spark type.")
1549  
1550  
1551  def _cast_output_spec_to_spark_type(spec):
1552      from pyspark.sql.types import ArrayType
1553  
1554      from mlflow.types.schema import ColSpec, DataType, TensorSpec
1555  
1556      # TODO: handle optional output columns.
1557      if isinstance(spec, ColSpec):
1558          return _convert_spec_type_to_spark_type(spec.type)
1559      elif isinstance(spec, TensorSpec):
1560          data_type = DataType.from_numpy_type(spec.type)
1561          if data_type is None:
1562              raise MlflowException(
1563                  f"Model output tensor spec type {spec.type} is not supported in spark_udf.",
1564                  error_code=INVALID_PARAMETER_VALUE,
1565              )
1566  
1567          if len(spec.shape) == 1:
1568              return ArrayType(data_type.to_spark())
1569          elif len(spec.shape) == 2:
1570              return ArrayType(ArrayType(data_type.to_spark()))
1571          else:
1572              raise MlflowException(
1573                  "Only 1D or 2D tensors are supported as spark_udf "
1574                  f"return value, but model output '{spec.name}' has shape {spec.shape}.",
1575                  error_code=INVALID_PARAMETER_VALUE,
1576              )
1577      else:
1578          raise MlflowException(
1579              f"Unknown schema output spec {spec}.", error_code=INVALID_PARAMETER_VALUE
1580          )
1581  
1582  
1583  def _infer_spark_udf_return_type(model_output_schema):
1584      from pyspark.sql.types import StructField, StructType
1585  
1586      if len(model_output_schema.inputs) == 1:
1587          return _cast_output_spec_to_spark_type(model_output_schema.inputs[0])
1588  
1589      return StructType([
1590          StructField(name=spec.name or str(i), dataType=_cast_output_spec_to_spark_type(spec))
1591          for i, spec in enumerate(model_output_schema.inputs)
1592      ])
1593  
1594  
1595  def _parse_spark_datatype(datatype: str):
1596      from pyspark.sql.functions import udf
1597      from pyspark.sql.session import SparkSession
1598  
1599      return_type = "boolean" if datatype == "bool" else datatype
1600      parsed_datatype = udf(lambda x: x, returnType=return_type).returnType
1601  
1602      if parsed_datatype.typeName() == "unparseddata":
1603          # For spark 3.5.x, `udf(lambda x: x, returnType=return_type).returnType`
1604          # returns UnparsedDataType, which is not compatible with signature inference.
1605          # Note: SparkSession.active only exists for spark >= 3.5.0
1606          schema = (
1607              SparkSession
1608              .active()
1609              .range(0)
1610              .select(udf(lambda x: x, returnType=return_type)("id"))
1611              .schema
1612          )
1613          return schema[0].dataType
1614  
1615      return parsed_datatype
1616  
1617  
1618  def _is_none_or_nan(value):
1619      # The condition `isinstance(value, float)` is needed to avoid error
1620      # from `np.isnan(value)` if value is a non-numeric type.
1621      return value is None or isinstance(value, float) and np.isnan(value)
1622  
1623  
1624  def _convert_array_values(values, result_type):
1625      """
1626      Convert list or numpy array values to spark dataframe column values.
1627      """
1628      from pyspark.sql.types import ArrayType, StructType
1629  
1630      if not isinstance(result_type, ArrayType):
1631          raise MlflowException.invalid_parameter_value(
1632              f"result_type must be ArrayType, got {result_type.simpleString()}",
1633          )
1634  
1635      spark_primitive_type_to_np_type = _get_spark_primitive_type_to_np_type()
1636  
1637      if type(result_type.elementType) in spark_primitive_type_to_np_type:
1638          np_type = spark_primitive_type_to_np_type[type(result_type.elementType)]
1639          # For array type result values, if provided value is None or NaN, regard it as a null array.
1640          # see https://github.com/mlflow/mlflow/issues/8986
1641          return None if _is_none_or_nan(values) else np.array(values, dtype=np_type)
1642      if isinstance(result_type.elementType, ArrayType):
1643          return [_convert_array_values(v, result_type.elementType) for v in values]
1644      if isinstance(result_type.elementType, StructType):
1645          return [_convert_struct_values(v, result_type.elementType) for v in values]
1646      if _is_variant_type(result_type.elementType):
1647          return values
1648  
1649      raise MlflowException.invalid_parameter_value(
1650          "Unsupported array type field with element type "
1651          f"{result_type.elementType.simpleString()} in Array type.",
1652      )
1653  
1654  
1655  def _get_spark_primitive_types():
1656      from pyspark.sql import types
1657  
1658      return (
1659          types.IntegerType,
1660          types.LongType,
1661          types.FloatType,
1662          types.DoubleType,
1663          types.StringType,
1664          types.BooleanType,
1665      )
1666  
1667  
1668  def _get_spark_primitive_type_to_np_type():
1669      from pyspark.sql import types
1670  
1671      return {
1672          types.IntegerType: np.int32,
1673          types.LongType: np.int64,
1674          types.FloatType: np.float32,
1675          types.DoubleType: np.float64,
1676          types.BooleanType: np.bool_,
1677          types.StringType: np.str_,
1678      }
1679  
1680  
1681  def _get_spark_primitive_type_to_python_type():
1682      from pyspark.sql import types
1683  
1684      return {
1685          types.IntegerType: int,
1686          types.LongType: int,
1687          types.FloatType: float,
1688          types.DoubleType: float,
1689          types.BooleanType: bool,
1690          types.StringType: str,
1691      }
1692  
1693  
1694  def _check_udf_return_type(data_type):
1695      from pyspark.sql.types import ArrayType, MapType, StringType, StructType
1696  
1697      primitive_types = _get_spark_primitive_types()
1698      if isinstance(data_type, primitive_types):
1699          return True
1700  
1701      if isinstance(data_type, ArrayType):
1702          return _check_udf_return_type(data_type.elementType)
1703  
1704      if isinstance(data_type, StructType):
1705          return all(_check_udf_return_type(field.dataType) for field in data_type.fields)
1706  
1707      if isinstance(data_type, MapType):
1708          return isinstance(data_type.keyType, StringType) and _check_udf_return_type(
1709              data_type.valueType
1710          )
1711  
1712      return False
1713  
1714  
1715  def _convert_struct_values(
1716      result: pandas.DataFrame | dict[str, Any],
1717      result_type,
1718  ):
1719      """
1720      Convert spark StructType values to spark dataframe column values.
1721      """
1722  
1723      from pyspark.sql.types import ArrayType, MapType, StructType
1724  
1725      if not isinstance(result_type, StructType):
1726          raise MlflowException.invalid_parameter_value(
1727              f"result_type must be StructType, got {result_type.simpleString()}",
1728          )
1729  
1730      if not isinstance(result, (dict, pandas.DataFrame)):
1731          raise MlflowException.invalid_parameter_value(
1732              f"Unsupported result type {type(result)}, expected dict or pandas DataFrame",
1733          )
1734  
1735      spark_primitive_type_to_np_type = _get_spark_primitive_type_to_np_type()
1736      is_pandas_df = isinstance(result, pandas.DataFrame)
1737      result_dict = {}
1738      for field_name in result_type.fieldNames():
1739          field_type = result_type[field_name].dataType
1740          field_values = result[field_name]
1741  
1742          if type(field_type) in spark_primitive_type_to_np_type:
1743              np_type = spark_primitive_type_to_np_type[type(field_type)]
1744              if is_pandas_df:
1745                  # it's possible that field_values contain only Nones
1746                  # in this case, we don't need to cast the type
1747                  if not all(_is_none_or_nan(field_value) for field_value in field_values):
1748                      field_values = field_values.astype(np_type)
1749              else:
1750                  field_values = (
1751                      None
1752                      if _is_none_or_nan(field_values)
1753                      else np.array(field_values, dtype=np_type).item()
1754                  )
1755          elif isinstance(field_type, ArrayType):
1756              if is_pandas_df:
1757                  field_values = pandas.Series(
1758                      _convert_array_values(field_value, field_type) for field_value in field_values
1759                  )
1760              else:
1761                  field_values = _convert_array_values(field_values, field_type)
1762          elif isinstance(field_type, StructType):
1763              if is_pandas_df:
1764                  field_values = pandas.Series([
1765                      _convert_struct_values(field_value, field_type) for field_value in field_values
1766                  ])
1767              else:
1768                  if isinstance(field_values, pydantic.BaseModel):
1769                      field_values = field_values.model_dump()
1770                  field_values = _convert_struct_values(field_values, field_type)
1771          elif isinstance(field_type, MapType):
1772              if is_pandas_df:
1773                  field_values = pandas.Series([
1774                      {
1775                          key: _convert_value_based_on_spark_type(value, field_type.valueType)
1776                          for key, value in field_value.items()
1777                      }
1778                      for field_value in field_values
1779                  ]).astype(object)
1780              else:
1781                  field_values = {
1782                      key: _convert_value_based_on_spark_type(value, field_type.valueType)
1783                      for key, value in field_values.items()
1784                  }
1785          elif _is_variant_type(field_type):
1786              return field_values
1787          else:
1788              raise MlflowException.invalid_parameter_value(
1789                  f"Unsupported field type {field_type.simpleString()} in struct type.",
1790              )
1791          result_dict[field_name] = field_values
1792  
1793      if is_pandas_df:
1794          return pandas.DataFrame(result_dict)
1795      return result_dict
1796  
1797  
1798  def _convert_value_based_on_spark_type(value, spark_type):
1799      """
1800      Convert value to python types based on the given spark type.
1801      """
1802  
1803      from pyspark.sql.types import ArrayType, MapType, StructType
1804  
1805      spark_primitive_type_to_python_type = _get_spark_primitive_type_to_python_type()
1806  
1807      if type(spark_type) in spark_primitive_type_to_python_type:
1808          python_type = spark_primitive_type_to_python_type[type(spark_type)]
1809          return None if _is_none_or_nan(value) else python_type(value)
1810      if isinstance(spark_type, StructType):
1811          return _convert_struct_values(value, spark_type)
1812      if isinstance(spark_type, ArrayType):
1813          return [_convert_value_based_on_spark_type(v, spark_type.elementType) for v in value]
1814      if isinstance(spark_type, MapType):
1815          return {
1816              key: _convert_value_based_on_spark_type(value[key], spark_type.valueType)
1817              for key in value
1818          }
1819      if _is_variant_type(spark_type):
1820          return value
1821      raise MlflowException.invalid_parameter_value(
1822          f"Unsupported type {spark_type} for value {value}"
1823      )
1824  
1825  
1826  # This location is used to prebuild python environment in Databricks runtime.
1827  # The location for prebuilding env should be located under /local_disk0
1828  # because the python env will be uploaded to NFS and mounted to Serverless UDF sandbox,
1829  # for serverless client image case, it doesn't have "/local_disk0" directory
1830  _PREBUILD_ENV_ROOT_LOCATION = "/tmp"
1831  
1832  
1833  def _gen_prebuilt_env_archive_name(spark, local_model_path):
1834      """
1835      Generate prebuilt env archive file name.
1836      The format is:
1837      'mlflow-{sha of python env config and dependencies}-{runtime version}-{platform machine}'
1838      Note: The runtime version and platform machine information are included in the
1839       archive name because the prebuilt env might not be compatible across different
1840       runtime versions or platform machines.
1841      """
1842      python_env = _get_python_env(Path(local_model_path))
1843      env_name = _get_virtualenv_name(python_env, local_model_path)
1844      dbconnect_udf_sandbox_info = get_dbconnect_udf_sandbox_info(spark)
1845      return (
1846          f"{env_name}-{dbconnect_udf_sandbox_info.image_version}-"
1847          f"{dbconnect_udf_sandbox_info.platform_machine}"
1848      )
1849  
1850  
1851  def _verify_prebuilt_env(spark, local_model_path, env_archive_path):
1852      # Use `[:-7]` to truncate ".tar.gz" in the end
1853      archive_name = os.path.basename(env_archive_path)[:-7]
1854      prebuilt_env_sha, prebuilt_runtime_version, prebuilt_platform_machine = archive_name.split("-")[
1855          -3:
1856      ]
1857  
1858      python_env = _get_python_env(Path(local_model_path))
1859      env_sha = _get_virtualenv_name(python_env, local_model_path).split("-")[-1]
1860      dbconnect_udf_sandbox_info = get_dbconnect_udf_sandbox_info(spark)
1861      runtime_version = dbconnect_udf_sandbox_info.image_version
1862      platform_machine = dbconnect_udf_sandbox_info.platform_machine
1863  
1864      if prebuilt_env_sha != env_sha:
1865          raise MlflowException(
1866              f"The prebuilt env '{env_archive_path}' does not match the model required environment."
1867          )
1868      if prebuilt_runtime_version != runtime_version:
1869          raise MlflowException(
1870              f"The prebuilt env '{env_archive_path}' runtime version '{prebuilt_runtime_version}' "
1871              f"does not match UDF sandbox runtime version {runtime_version}."
1872          )
1873      if prebuilt_platform_machine != platform_machine:
1874          raise MlflowException(
1875              f"The prebuilt env '{env_archive_path}' platform machine '{prebuilt_platform_machine}' "
1876              f"does not match UDF sandbox platform machine {platform_machine}."
1877          )
1878  
1879  
1880  def _prebuild_env_internal(local_model_path, archive_name, save_path, env_manager):
1881      env_root_dir = os.path.join(_PREBUILD_ENV_ROOT_LOCATION, archive_name)
1882      archive_path = os.path.join(save_path, archive_name + ".tar.gz")
1883      if os.path.exists(env_root_dir):
1884          shutil.rmtree(env_root_dir)
1885      if os.path.exists(archive_path):
1886          os.remove(archive_path)
1887  
1888      try:
1889          pyfunc_backend = get_flavor_backend(
1890              local_model_path,
1891              env_manager=env_manager,
1892              install_mlflow=False,
1893              create_env_root_dir=False,
1894              env_root_dir=env_root_dir,
1895          )
1896  
1897          pyfunc_backend.prepare_env(model_uri=local_model_path, capture_output=False)
1898          # exclude pip cache from the archive file.
1899          cache_path = os.path.join(env_root_dir, "pip_cache_pkgs")
1900          if os.path.exists(cache_path):
1901              shutil.rmtree(cache_path)
1902  
1903          return archive_directory(env_root_dir, archive_path)
1904      finally:
1905          shutil.rmtree(env_root_dir, ignore_errors=True)
1906  
1907  
1908  def _download_prebuilt_env_if_needed(prebuilt_env_uri):
1909      from mlflow.utils.file_utils import get_or_create_tmp_dir
1910  
1911      parsed_url = urlparse(prebuilt_env_uri)
1912      if parsed_url.scheme in {"", "file"}:
1913          # local path
1914          return parsed_url.path
1915      if parsed_url.scheme == "dbfs":
1916          tmp_dir = MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR.get() or get_or_create_tmp_dir()
1917          model_env_uc_path = parsed_url.path
1918  
1919          # download file from DBFS.
1920          local_model_env_path = os.path.join(tmp_dir, os.path.basename(model_env_uc_path))
1921          if os.path.exists(local_model_env_path):
1922              # file is already downloaded.
1923              return local_model_env_path
1924  
1925          try:
1926              from databricks.sdk import WorkspaceClient
1927  
1928              ws = WorkspaceClient()
1929              # Download model env file from UC volume.
1930              with (
1931                  ws.files.download(model_env_uc_path).contents as rf,
1932                  open(local_model_env_path, "wb") as wf,
1933              ):
1934                  while chunk := rf.read(4096 * 1024):
1935                      wf.write(chunk)
1936              return local_model_env_path
1937          except (Exception, KeyboardInterrupt):
1938              if os.path.exists(local_model_env_path):
1939                  # clean the partially saved file if downloading fails.
1940                  os.remove(local_model_env_path)
1941              raise
1942  
1943      raise MlflowException(
1944          f"Unsupported prebuilt env file path '{prebuilt_env_uri}', "
1945          f"invalid scheme: '{parsed_url.scheme}'."
1946      )
1947  
1948  
1949  def build_model_env(model_uri, save_path, env_manager=_EnvManager.VIRTUALENV):
1950      """
1951      Prebuild model python environment and generate an archive file saved to provided
1952      `save_path`.
1953  
1954      Typical usages:
1955       - Pre-build a model's environment in Databricks Runtime and then download the prebuilt
1956         python environment archive file. This pre-built environment archive can then be used
1957         in `mlflow.pyfunc.spark_udf` for remote inference execution when using Databricks Connect
1958         to remotely connect to a Databricks environment for code execution.
1959  
1960      .. note::
1961          The `build_model_env` API is intended to only work when executed within Databricks runtime,
1962          serving the purpose of capturing the required execution environment that is needed for
1963          remote code execution when using DBConnect. The environment archive is designed to be used
1964          when performing remote execution using `mlflow.pyfunc.spark_udf` in
1965          Databricks runtime or Databricks Connect client and has no other purpose.
1966          The prebuilt env archive file cannot be used across different Databricks runtime
1967          versions or different platform machines. As such, if you connect to a different cluster
1968          that is running a different runtime version on Databricks, you will need to execute this
1969          API in a notebook and retrieve the generated archive to your local machine. Each
1970          environment snapshot is unique to the the model, the runtime version of your remote
1971          Databricks cluster, and the specification of the udf execution environment.
1972          When using the prebuilt env in `mlflow.pyfunc.spark_udf`, MLflow will verify
1973          whether the spark UDF sandbox environment matches the prebuilt env requirements and will
1974          raise Exceptions if there are compatibility issues. If these occur, simply re-run this API
1975          in the cluster that you are attempting to attach to.
1976  
1977      .. code-block:: python
1978          :caption: Example
1979  
1980          from mlflow.pyfunc import build_model_env
1981  
1982          # Create a python environment archive file at the path `prebuilt_env_uri`
1983          prebuilt_env_uri = build_model_env(f"runs:/{run_id}/model", "/path/to/save_directory")
1984  
1985      Args:
1986          model_uri: URI to the model that is used to build the python environment.
1987          save_path: The directory path that is used to save the prebuilt model environment
1988              archive file path.
1989              The path can be either local directory path or
1990              mounted DBFS path such as '/dbfs/...' or
1991              mounted UC volume path such as '/Volumes/...'.
1992          env_manager: The environment manager to use in order to create the python environment
1993              for model inference, the value can be either 'virtualenv' or 'uv', the default
1994              value is 'virtualenv'.
1995  
1996      Returns:
1997          Return the path of an archive file containing the python environment data.
1998      """
1999      from mlflow.utils._spark_utils import _get_active_spark_session
2000  
2001      if not is_in_databricks_runtime():
2002          raise RuntimeError("'build_model_env' only support running in Databricks runtime.")
2003  
2004      if os.path.isfile(save_path):
2005          raise RuntimeError(f"The saving path '{save_path}' must be a directory.")
2006      os.makedirs(save_path, exist_ok=True)
2007  
2008      local_model_path = _download_artifact_from_uri(
2009          artifact_uri=model_uri, output_path=_create_model_downloading_tmp_dir(should_use_nfs=False)
2010      )
2011      archive_name = _gen_prebuilt_env_archive_name(_get_active_spark_session(), local_model_path)
2012      dest_path = os.path.join(save_path, archive_name + ".tar.gz")
2013      if os.path.exists(dest_path):
2014          raise RuntimeError(
2015              "A pre-built model python environment already exists "
2016              f"in '{dest_path}'. To rebuild it, please remove "
2017              "the existing one first."
2018          )
2019  
2020      # Archive the environment directory as a `tar.gz` format archive file,
2021      # and then move the archive file to the destination directory.
2022      # Note:
2023      # - all symlink files in the input directory are kept as it is in the
2024      #  archive file.
2025      # - the destination directory could be UC-volume fuse mounted directory
2026      #  which only supports limited filesystem operations, so to ensure it works,
2027      #  we generate the archive file under /tmp and then move it into the
2028      #  destination directory.
2029      tmp_archive_path = None
2030      try:
2031          tmp_archive_path = _prebuild_env_internal(
2032              local_model_path, archive_name, _PREBUILD_ENV_ROOT_LOCATION, env_manager
2033          )
2034          shutil.move(tmp_archive_path, save_path)
2035          return dest_path
2036      finally:
2037          shutil.rmtree(local_model_path, ignore_errors=True)
2038          if tmp_archive_path and os.path.exists(tmp_archive_path):
2039              os.remove(tmp_archive_path)
2040  
2041  
2042  def spark_udf(
2043      spark,
2044      model_uri,
2045      result_type=None,
2046      env_manager=None,
2047      params: dict[str, Any] | None = None,
2048      extra_env: dict[str, str] | None = None,
2049      prebuilt_env_uri: str | None = None,
2050      model_config: str | Path | dict[str, Any] | None = None,
2051  ):
2052      """
2053      A Spark UDF that can be used to invoke the Python function formatted model.
2054  
2055      Parameters passed to the UDF are forwarded to the model as a DataFrame where the column names
2056      are ordinals (0, 1, ...). On some versions of Spark (3.0 and above), it is also possible to
2057      wrap the input in a struct. In that case, the data will be passed as a DataFrame with column
2058      names given by the struct definition (e.g. when invoked as my_udf(struct('x', 'y')), the model
2059      will get the data as a pandas DataFrame with 2 columns 'x' and 'y').
2060  
2061      If a model contains a signature with tensor spec inputs, you will need to pass a column of
2062      array type as a corresponding UDF argument. The column values of which must be one dimensional
2063      arrays. The UDF will reshape the column values to the required shape with 'C' order
2064      (i.e. read / write the elements using C-like index order) and cast the values as the required
2065      tensor spec type.
2066  
2067      If a model contains a signature, the UDF can be called without specifying column name
2068      arguments. In this case, the UDF will be called with column names from signature, so the
2069      evaluation dataframe's column names must match the model signature's column names.
2070  
2071      The predictions are filtered to contain only the columns that can be represented as the
2072      ``result_type``. If the ``result_type`` is string or array of strings, all predictions are
2073      converted to string. If the result type is not an array type, the left most column with
2074      matching type is returned.
2075  
2076      .. note::
2077          Inputs of type ``pyspark.sql.types.DateType`` are not supported on earlier versions of
2078          Spark (2.4 and below).
2079  
2080      .. note::
2081          When using Databricks Connect to connect to a remote Databricks cluster,
2082          the Databricks cluster must use runtime version >= 16, and if the 'prebuilt_env_uri'
2083          parameter is set, 'env_manager' parameter should not be set.
2084          the Databricks cluster must use runtime version >= 15.4,and if the 'prebuilt_env_uri'
2085          parameter is set, 'env_manager' parameter should not be set,
2086          if the runtime version is 15.4 and the cluster is
2087          standard access mode, the cluster need to configure
2088          "spark.databricks.safespark.archive.artifact.unpack.disabled" to "false".
2089  
2090      .. note::
2091          Please be aware that when operating in Databricks Serverless,
2092          spark tasks run within the confines of the Databricks Serverless UDF sandbox.
2093          This environment has a total capacity limit of 1GB, combining both available
2094          memory and local disk capacity. Furthermore, there are no GPU devices available
2095          in this setup. Therefore, any deep-learning models that contain large weights
2096          or require a GPU are not suitable for deployment on Databricks Serverless.
2097  
2098      .. code-block:: python
2099          :caption: Example
2100  
2101          from pyspark.sql.functions import struct
2102  
2103          predict = mlflow.pyfunc.spark_udf(spark, "/my/local/model")
2104          df.withColumn("prediction", predict(struct("name", "age"))).show()
2105  
2106      Args:
2107          spark: A SparkSession object.
2108          model_uri: The location, in URI format, of the MLflow model with the
2109              :py:mod:`mlflow.pyfunc` flavor. For example:
2110  
2111              - ``/Users/me/path/to/local/model``
2112              - ``relative/path/to/local/model``
2113              - ``s3://my_bucket/path/to/model``
2114              - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
2115              - ``models:/<model_name>/<model_version>``
2116              - ``models:/<model_name>/<stage>``
2117              - ``mlflow-artifacts:/path/to/model``
2118  
2119              For more information about supported URI schemes, see
2120              `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
2121              artifact-locations>`_.
2122  
2123          result_type: the return type of the user-defined function. The value can be either a
2124              ``pyspark.sql.types.DataType`` object or a DDL-formatted type string. Only a primitive
2125              type, an array ``pyspark.sql.types.ArrayType`` of primitive type, or a struct type
2126              containing fields of above 2 kinds of types are allowed.
2127              If unspecified, it tries to infer result type from model signature
2128              output schema, if model output schema is not available, it fallbacks to use ``double``
2129              type.
2130  
2131              The following classes of result type are supported:
2132  
2133              - "int" or ``pyspark.sql.types.IntegerType``: The leftmost integer that can fit in an
2134                ``int32`` or an exception if there is none.
2135  
2136              - "long" or ``pyspark.sql.types.LongType``: The leftmost long integer that can fit in an
2137                ``int64`` or an exception if there is none.
2138  
2139              - ``ArrayType(IntegerType|LongType)``: All integer columns that can fit into the
2140                requested size.
2141  
2142              - "float" or ``pyspark.sql.types.FloatType``: The leftmost numeric result cast to
2143                ``float32`` or an exception if there is none.
2144  
2145              - "double" or ``pyspark.sql.types.DoubleType``: The leftmost numeric result cast to
2146                ``double`` or an exception if there is none.
2147  
2148              - ``ArrayType(FloatType|DoubleType)``: All numeric columns cast to the requested type or
2149                an exception if there are no numeric columns.
2150  
2151              - "string" or ``pyspark.sql.types.StringType``: The leftmost column converted to
2152                ``string``.
2153  
2154              - "boolean" or "bool" or ``pyspark.sql.types.BooleanType``: The leftmost column
2155                converted to ``bool`` or an exception if there is none.
2156  
2157              - ``ArrayType(StringType)``: All columns converted to ``string``.
2158  
2159              - "field1 FIELD1_TYPE, field2 FIELD2_TYPE, ...": A struct type containing multiple
2160                fields separated by comma, each field type must be one of types listed above.
2161  
2162          env_manager: The environment manager to use in order to create the python environment
2163              for model inference. Note that environment is only restored in the context
2164              of the PySpark UDF; the software environment outside of the UDF is
2165              unaffected. If `prebuilt_env_uri` parameter is not set, the default value
2166              is ``local``, and the following values are supported:
2167  
2168              - ``virtualenv``: Use virtualenv to restore the python environment that
2169                was used to train the model. This is the default option if ``env_manager``
2170                is not set.
2171              - ``uv`` : Use uv to restore the python environment that
2172                was used to train the model.
2173              - ``conda``: Use Conda to restore the software environment
2174                that was used to train the model.
2175              - ``local``: Use the current Python environment for model inference, which
2176                may differ from the environment used to train the model and may lead to
2177                errors or invalid predictions.
2178  
2179              If the `prebuilt_env_uri` parameter is set, `env_manager` parameter should not
2180              be set.
2181  
2182          params: Additional parameters to pass to the model for inference.
2183  
2184          extra_env: Extra environment variables to pass to the UDF executors.
2185              For overrides that need to propagate to the Spark workers (i.e.,
2186              overriding the scoring server timeout via `MLFLOW_SCORING_SERVER_REQUEST_TIMEOUT`).
2187  
2188          prebuilt_env_uri: The path of the prebuilt env archive file created by
2189              `mlflow.pyfunc.build_model_env` API.
2190              This parameter can only be used in Databricks Serverless notebook REPL,
2191              Databricks Shared cluster notebook REPL, and Databricks Connect client
2192              environment.
2193              The path can be either local file path or DBFS path such as
2194              'dbfs:/Volumes/...', in this case, MLflow automatically downloads it
2195              to local temporary directory, "MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR"
2196              environmental variable can be set to specify the temporary directory
2197              to use.
2198  
2199              If this parameter is set, `env_manger` parameter must not be set.
2200  
2201          model_config: The model configuration to set when loading the model.
2202              See 'model_config' argument in `mlflow.pyfunc.load_model` API for details.
2203  
2204      Returns:
2205          Spark UDF that applies the model's ``predict`` method to the data and returns a
2206          type specified by ``result_type``, which by default is a double.
2207      """
2208  
2209      # Scope Spark import to this method so users don't need pyspark to use non-Spark-related
2210      # functionality.
2211      from pyspark.sql.functions import pandas_udf
2212      from pyspark.sql.types import (
2213          ArrayType,
2214          BooleanType,
2215          DoubleType,
2216          FloatType,
2217          IntegerType,
2218          LongType,
2219          MapType,
2220          StringType,
2221      )
2222      from pyspark.sql.types import StructType as SparkStructType
2223  
2224      from mlflow.pyfunc.spark_model_cache import SparkModelCache
2225      from mlflow.utils._spark_utils import _SparkDirectoryDistributor
2226  
2227      is_spark_connect = is_spark_connect_mode()
2228      # Used in test to force install local version of mlflow when starting a model server
2229      mlflow_home = os.environ.get("MLFLOW_HOME")
2230      openai_env_vars = mlflow.openai.model._OpenAIEnvVar.read_environ()
2231      mlflow_testing = _MLFLOW_TESTING.get_raw()
2232  
2233      if prebuilt_env_uri:
2234          if env_manager not in (None, _EnvManager.VIRTUALENV, _EnvManager.UV):
2235              raise MlflowException(
2236                  "If 'prebuilt_env_uri' parameter is set, 'env_manager' parameter must "
2237                  "be either None, 'virtualenv', or 'uv'."
2238              )
2239          env_manager = _EnvManager.VIRTUALENV
2240      else:
2241          env_manager = env_manager or _EnvManager.LOCAL
2242  
2243      _EnvManager.validate(env_manager)
2244  
2245      if is_spark_connect:
2246          is_spark_in_local_mode = False
2247      else:
2248          # Check whether spark is in local or local-cluster mode
2249          # this case all executors and driver share the same filesystem
2250          is_spark_in_local_mode = spark.conf.get("spark.master").startswith("local")
2251  
2252      is_dbconnect_mode = is_databricks_connect(spark)
2253      if prebuilt_env_uri is not None and not is_dbconnect_mode:
2254          raise RuntimeError(
2255              "'prebuilt_env' parameter can only be used in Databricks Serverless "
2256              "notebook REPL, atabricks Shared cluster notebook REPL, and Databricks Connect client "
2257              "environment."
2258          )
2259  
2260      if prebuilt_env_uri is None and is_dbconnect_mode and not is_in_databricks_runtime():
2261          raise RuntimeError(
2262              "'prebuilt_env_uri' param is required if using Databricks Connect to connect "
2263              "to Databricks cluster from your own machine."
2264          )
2265  
2266      # Databricks connect can use `spark.addArtifact` to upload artifact to NFS.
2267      # But for Databricks shared cluster runtime, it can directly write to NFS, so exclude it.
2268      # For Databricks Serverless runtime (notebook REPL), spark.addArtifact may not reliably
2269      # make archives available to UDF executor sandboxes. As a temporary workaround, set
2270      # _MLFLOW_SPARK_UDF_SERVERLESS_SKIP_DBCONNECT_ARTIFACT=true to skip the addArtifact path
2271      # and let each executor fetch the model directly from the MLflow artifact store instead.
2272      use_dbconnect_artifact = (
2273          is_dbconnect_mode
2274          and not is_in_databricks_shared_cluster_runtime()
2275          and not (
2276              is_in_databricks_serverless_runtime()
2277              and _MLFLOW_SPARK_UDF_SERVERLESS_SKIP_DBCONNECT_ARTIFACT.get()
2278          )
2279      )
2280  
2281      if use_dbconnect_artifact:
2282          udf_sandbox_info = get_dbconnect_udf_sandbox_info(spark)
2283          if Version(udf_sandbox_info.mlflow_version) < Version("2.18.0"):
2284              raise MlflowException(
2285                  "Using 'mlflow.pyfunc.spark_udf' in Databricks Serverless or in remote "
2286                  "Databricks Connect requires UDF sandbox image installed with MLflow "
2287                  "of version >= 2.18.0"
2288              )
2289          # `udf_sandbox_info.runtime_version` format is like '<major_version>.<minor_version>'.
2290          # It's safe to apply `Version`.
2291          dbr_runtime_version = Version(udf_sandbox_info.runtime_version)
2292          if dbr_runtime_version < Version("15.4"):
2293              raise MlflowException(
2294                  "Using 'mlflow.pyfunc.spark_udf' in Databricks Serverless or in remote "
2295                  "Databricks Connect requires Databricks runtime version >= 15.4."
2296              )
2297          if dbr_runtime_version == Version("15.4"):
2298              if spark.conf.get("spark.databricks.pyspark.udf.isolation.enabled").lower() == "true":
2299                  # The connected cluster is standard (shared) mode.
2300                  if (
2301                      spark.conf.get(
2302                          "spark.databricks.safespark.archive.artifact.unpack.disabled"
2303                      ).lower()
2304                      != "false"
2305                  ):
2306                      raise MlflowException(
2307                          "Using 'mlflow.pyfunc.spark_udf' in remote Databricks Connect requires "
2308                          "Databricks cluster setting "
2309                          "'spark.databricks.safespark.archive.artifact.unpack.disabled' to 'false' "
2310                          "if Databricks runtime version is 15.4"
2311                      )
2312  
2313      nfs_root_dir = get_nfs_cache_root_dir()
2314      should_use_nfs = nfs_root_dir is not None
2315  
2316      should_use_spark_to_broadcast_file = not (
2317          is_spark_in_local_mode or should_use_nfs or is_spark_connect or use_dbconnect_artifact
2318      )
2319  
2320      # For spark connect mode,
2321      # If client code is executed in databricks runtime and NFS is available,
2322      # we save model to NFS temp directory in the driver
2323      # and load the model in the executor.
2324      should_spark_connect_use_nfs = is_in_databricks_runtime() and should_use_nfs
2325  
2326      if (
2327          is_spark_connect
2328          and not is_dbconnect_mode
2329          and env_manager in (_EnvManager.VIRTUALENV, _EnvManager.CONDA, _EnvManager.UV)
2330      ):
2331          raise MlflowException.invalid_parameter_value(
2332              f"Environment manager {env_manager!r} is not supported in Spark Connect "
2333              "client environment if it connects to non-Databricks Spark cluster.",
2334          )
2335  
2336      local_model_path = _download_artifact_from_uri(
2337          artifact_uri=model_uri,
2338          output_path=_create_model_downloading_tmp_dir(should_use_nfs),
2339      )
2340  
2341      if prebuilt_env_uri:
2342          prebuilt_env_uri = _download_prebuilt_env_if_needed(prebuilt_env_uri)
2343          _verify_prebuilt_env(spark, local_model_path, prebuilt_env_uri)
2344      if use_dbconnect_artifact and env_manager == _EnvManager.CONDA:
2345          raise MlflowException(
2346              "Databricks connect mode or Databricks Serverless python REPL doesn't "
2347              "support env_manager 'conda'."
2348          )
2349  
2350      if env_manager == _EnvManager.LOCAL:
2351          # Assume spark executor python environment is the same with spark driver side.
2352          model_requirements = _get_pip_requirements_from_model_path(local_model_path)
2353          warn_dependency_requirement_mismatches(model_requirements)
2354          _logger.warning(
2355              'Calling `spark_udf()` with `env_manager="local"` does not recreate the same '
2356              "environment that was used during training, which may lead to errors or inaccurate "
2357              'predictions. We recommend specifying `env_manager="conda"`, which automatically '
2358              "recreates the environment that was used to train the model and performs inference "
2359              "in the recreated environment."
2360          )
2361      else:
2362          _logger.info(
2363              f"This UDF will use {env_manager} to recreate the model's software environment for "
2364              "inference. This may take extra time during execution."
2365          )
2366          if not sys.platform.startswith("linux"):
2367              # TODO: support killing mlflow server launched in UDF task when spark job canceled
2368              #  for non-linux system.
2369              #  https://stackoverflow.com/questions/53208/how-do-i-automatically-destroy-child-processes-in-windows
2370              _logger.warning(
2371                  "In order to run inference code in restored python environment, PySpark UDF "
2372                  "processes spawn MLflow Model servers as child processes. Due to system "
2373                  "limitations with handling SIGKILL signals, these MLflow Model server child "
2374                  "processes cannot be cleaned up if the Spark Job is canceled."
2375              )
2376  
2377      if prebuilt_env_uri:
2378          env_cache_key = os.path.basename(prebuilt_env_uri)[:-7]
2379      elif use_dbconnect_artifact:
2380          env_cache_key = _gen_prebuilt_env_archive_name(spark, local_model_path)
2381      else:
2382          env_cache_key = None
2383  
2384      if use_dbconnect_artifact or prebuilt_env_uri is not None:
2385          prebuilt_env_root_dir = os.path.join(_PREBUILD_ENV_ROOT_LOCATION, env_cache_key)
2386          pyfunc_backend_env_root_config = {
2387              "create_env_root_dir": False,
2388              "env_root_dir": prebuilt_env_root_dir,
2389          }
2390      else:
2391          pyfunc_backend_env_root_config = {"create_env_root_dir": True}
2392      pyfunc_backend = get_flavor_backend(
2393          local_model_path,
2394          env_manager=env_manager,
2395          install_mlflow=os.environ.get("MLFLOW_HOME") is not None,
2396          **pyfunc_backend_env_root_config,
2397      )
2398      dbconnect_artifact_cache = DBConnectArtifactCache.get_or_create(spark)
2399  
2400      if use_dbconnect_artifact:
2401          # Upload model artifacts and python environment to NFS as DBConnect artifacts.
2402          if env_manager in (_EnvManager.VIRTUALENV, _EnvManager.UV):
2403              if not dbconnect_artifact_cache.has_cache_key(env_cache_key):
2404                  if prebuilt_env_uri:
2405                      env_archive_path = prebuilt_env_uri
2406                  else:
2407                      env_archive_path = _prebuild_env_internal(
2408                          local_model_path, env_cache_key, get_or_create_tmp_dir(), env_manager
2409                      )
2410                  dbconnect_artifact_cache.add_artifact_archive(env_cache_key, env_archive_path)
2411  
2412          if not dbconnect_artifact_cache.has_cache_key(model_uri):
2413              model_archive_path = os.path.join(
2414                  os.path.dirname(local_model_path), f"model-{uuid.uuid4()}.tar.gz"
2415              )
2416              archive_directory(local_model_path, model_archive_path)
2417              dbconnect_artifact_cache.add_artifact_archive(model_uri, model_archive_path)
2418  
2419      elif not should_use_spark_to_broadcast_file:
2420          if prebuilt_env_uri:
2421              # Extract prebuilt env archive file to NFS directory.
2422              prebuilt_env_nfs_dir = os.path.join(
2423                  get_or_create_nfs_tmp_dir(), "prebuilt_env", env_cache_key
2424              )
2425              if not os.path.exists(prebuilt_env_nfs_dir):
2426                  extract_archive_to_dir(prebuilt_env_uri, prebuilt_env_nfs_dir)
2427          else:
2428              # Prepare restored environment in driver side if possible.
2429              # Note: In databricks runtime, because databricks notebook cell output cannot capture
2430              # child process output, so that set capture_output to be True so that when `conda
2431              # prepare env` command failed, the exception message will include command stdout/stderr
2432              # output. Otherwise user have to check cluster driver log to find command stdout/stderr
2433              # output.
2434              # In non-databricks runtime, set capture_output to be False, because the benefit of
2435              # "capture_output=False" is the output will be printed immediately, otherwise you have
2436              # to wait conda command fail and suddenly get all output printed (included in error
2437              # message).
2438              if env_manager != _EnvManager.LOCAL:
2439                  pyfunc_backend.prepare_env(
2440                      model_uri=local_model_path, capture_output=is_in_databricks_runtime()
2441                  )
2442      else:
2443          # Broadcast local model directory to remote worker if needed.
2444          archive_path = SparkModelCache.add_local_model(spark, local_model_path)
2445  
2446      model_metadata = Model.load(os.path.join(local_model_path, MLMODEL_FILE_NAME))
2447  
2448      if result_type is None:
2449          if model_output_schema := model_metadata.get_output_schema():
2450              result_type = _infer_spark_udf_return_type(model_output_schema)
2451          else:
2452              _logger.warning(
2453                  "No 'result_type' provided for spark_udf and the model does not "
2454                  "have an output schema. 'result_type' is set to 'double' type."
2455              )
2456              result_type = DoubleType()
2457      else:
2458          if isinstance(result_type, str):
2459              result_type = _parse_spark_datatype(result_type)
2460  
2461          # if result type is inferred by MLflow, we don't need to check it
2462          if not _check_udf_return_type(result_type):
2463              raise MlflowException.invalid_parameter_value(
2464                  f"""Invalid 'spark_udf' result type: {result_type}.
2465  It must be one of the following types:
2466  Primitive types:
2467  - int
2468  - long
2469  - float
2470  - double
2471  - string
2472  - boolean
2473  Compound types:
2474  - ND array of primitives / structs.
2475  - struct<field: primitive | array<primitive> | array<array<primitive>>, ...>:
2476  A struct with primitive, ND array<primitive/structs>,
2477  e.g., struct<a:int, b:array<int>>.
2478  """
2479              )
2480      params = _validate_params(params, model_metadata)
2481  
2482      def _predict_row_batch(predict_fn, args):
2483          input_schema = model_metadata.get_input_schema()
2484          args = list(args)
2485          if len(args) == 1 and isinstance(args[0], pandas.DataFrame):
2486              pdf = args[0]
2487          else:
2488              if input_schema is None:
2489                  names = [str(i) for i in range(len(args))]
2490              else:
2491                  names = input_schema.input_names()
2492                  required_names = input_schema.required_input_names()
2493                  if len(args) > len(names):
2494                      args = args[: len(names)]
2495                  if len(args) < len(required_names):
2496                      raise MlflowException(
2497                          f"Model input is missing required columns. Expected {len(names)} required"
2498                          f" input columns {names}, but the model received only {len(args)} "
2499                          "unnamed input columns (Since the columns were passed unnamed they are"
2500                          " expected to be in the order specified by the schema)."
2501                      )
2502              pdf = pandas.DataFrame(
2503                  data={
2504                      names[i]: arg
2505                      if isinstance(arg, pandas.Series)
2506                      # pandas_udf receives a StructType column as a pandas DataFrame.
2507                      # We need to convert it back to a dict of pandas Series.
2508                      else arg.apply(lambda row: row.to_dict(), axis=1)
2509                      for i, arg in enumerate(args)
2510                  },
2511                  columns=names,
2512              )
2513  
2514          result = predict_fn(pdf, params)
2515  
2516          if isinstance(result, dict):
2517              result = {k: list(v) for k, v in result.items()}
2518  
2519          if isinstance(result_type, ArrayType) and isinstance(result_type.elementType, ArrayType):
2520              result_values = _convert_array_values(result, result_type)
2521              return pandas.Series(result_values)
2522  
2523          if isinstance(result_type, SparkStructType):
2524              if (
2525                  isinstance(result, list)
2526                  and len(result) > 0
2527                  and isinstance(result[0], pydantic.BaseModel)
2528              ):
2529                  result = pandas.DataFrame([r.model_dump() for r in result])
2530              else:
2531                  result = pandas.DataFrame(result)
2532              return _convert_struct_values(result, result_type)
2533  
2534          if not isinstance(result, pandas.DataFrame):
2535              if isinstance(result_type, MapType):
2536                  # list of dicts should be converted into a single column
2537                  result = pandas.DataFrame([result])
2538              else:
2539                  result = (
2540                      pandas.DataFrame([result]) if np.isscalar(result) else pandas.DataFrame(result)
2541                  )
2542  
2543          elem_type = result_type.elementType if isinstance(result_type, ArrayType) else result_type
2544          if type(elem_type) == IntegerType:
2545              result = result.select_dtypes([
2546                  np.byte,
2547                  np.ubyte,
2548                  np.short,
2549                  np.ushort,
2550                  np.int32,
2551              ]).astype(np.int32)
2552  
2553          elif type(elem_type) == LongType:
2554              result = result.select_dtypes([np.byte, np.ubyte, np.short, np.ushort, int]).astype(
2555                  np.int64
2556              )
2557  
2558          elif type(elem_type) == FloatType:
2559              result = result.select_dtypes(include=(np.number,)).astype(np.float32)
2560  
2561          elif type(elem_type) == DoubleType:
2562              result = result.select_dtypes(include=(np.number,)).astype(np.float64)
2563  
2564          elif type(elem_type) == BooleanType:
2565              result = result.select_dtypes([bool, np.bool_]).astype(bool)
2566  
2567          if len(result.columns) == 0:
2568              raise MlflowException(
2569                  message="The model did not produce any values compatible with the requested "
2570                  f"type '{elem_type}'. Consider requesting udf with StringType or "
2571                  "Arraytype(StringType).",
2572                  error_code=INVALID_PARAMETER_VALUE,
2573              )
2574  
2575          if type(elem_type) == StringType:
2576              if Version(pandas.__version__) >= Version("2.1.0"):
2577                  result = result.map(str)
2578              else:
2579                  result = result.applymap(str)
2580  
2581          if type(result_type) == ArrayType:
2582              return pandas.Series(result.to_numpy().tolist())
2583          else:
2584              return result[result.columns[0]]
2585  
2586      result_type_hint = (
2587          pandas.DataFrame if isinstance(result_type, SparkStructType) else pandas.Series
2588      )
2589  
2590      tracking_uri = mlflow.get_tracking_uri()
2591  
2592      enforce_stdin_scoring_server = MLFLOW_ENFORCE_STDIN_SCORING_SERVER_FOR_SPARK_UDF.get()
2593  
2594      @pandas_udf(result_type)
2595      def udf(
2596          # `pandas_udf` does not support modern type annotations
2597          iterator: Iterator[Tuple[Union[pandas.Series, pandas.DataFrame], ...]],  # noqa: UP006,UP007
2598      ) -> Iterator[result_type_hint]:
2599          # importing here to prevent circular import
2600          from mlflow.pyfunc.scoring_server.client import (
2601              ScoringServerClient,
2602              StdinScoringServerClient,
2603          )
2604  
2605          # Note: this is a pandas udf function in iteration style, which takes an iterator of
2606          # tuple of pandas.Series and outputs an iterator of pandas.Series.
2607          update_envs = {}
2608          if mlflow_home is not None:
2609              update_envs["MLFLOW_HOME"] = mlflow_home
2610          if openai_env_vars:
2611              update_envs.update(openai_env_vars)
2612          if mlflow_testing:
2613              update_envs[_MLFLOW_TESTING.name] = mlflow_testing
2614          if extra_env:
2615              update_envs.update(extra_env)
2616  
2617          #  use `modified_environ` to temporarily set the envs and restore them finally
2618          with modified_environ(update=update_envs):
2619              scoring_server_proc = None
2620              # set tracking_uri inside udf so that with spark_connect
2621              # we can load the model from correct path
2622              mlflow.set_tracking_uri(tracking_uri)
2623  
2624              if env_manager != _EnvManager.LOCAL:
2625                  if use_dbconnect_artifact:
2626                      local_model_path_on_executor = (
2627                          dbconnect_artifact_cache.get_unpacked_artifact_dir(model_uri)
2628                      )
2629                      env_src_dir = dbconnect_artifact_cache.get_unpacked_artifact_dir(env_cache_key)
2630  
2631                      # Create symlink if it does not exist
2632                      if not os.path.exists(prebuilt_env_root_dir):
2633                          os.symlink(env_src_dir, prebuilt_env_root_dir)
2634                  elif prebuilt_env_uri is not None:
2635                      # prebuilt env is extracted to `prebuilt_env_nfs_dir` directory,
2636                      # and model is downloaded to `local_model_path` which points to an NFS
2637                      # directory too.
2638                      local_model_path_on_executor = None
2639  
2640                      # Create symlink if it does not exist
2641                      if not os.path.exists(prebuilt_env_root_dir):
2642                          os.symlink(prebuilt_env_nfs_dir, prebuilt_env_root_dir)
2643                  elif should_use_spark_to_broadcast_file:
2644                      local_model_path_on_executor = _SparkDirectoryDistributor.get_or_extract(
2645                          archive_path
2646                      )
2647                      # Call "prepare_env" in advance in order to reduce scoring server launch time.
2648                      # So that we can use a shorter timeout when call `client.wait_server_ready`,
2649                      # otherwise we have to set a long timeout for `client.wait_server_ready` time,
2650                      # this prevents spark UDF task failing fast if other exception raised
2651                      # when scoring server launching.
2652                      # Set "capture_output" so that if "conda env create" command failed, the command
2653                      # stdout/stderr output will be attached to the exception message and included in
2654                      # driver side exception.
2655                      pyfunc_backend.prepare_env(
2656                          model_uri=local_model_path_on_executor, capture_output=True
2657                      )
2658                  else:
2659                      local_model_path_on_executor = None
2660  
2661                  if not enforce_stdin_scoring_server and check_port_connectivity():
2662                      # launch scoring server
2663                      server_port = find_free_port()
2664                      host = "127.0.0.1"
2665                      scoring_server_proc = pyfunc_backend.serve(
2666                          model_uri=local_model_path_on_executor or local_model_path,
2667                          port=server_port,
2668                          host=host,
2669                          timeout=MLFLOW_SCORING_SERVER_REQUEST_TIMEOUT.get(),
2670                          enable_mlserver=False,
2671                          synchronous=False,
2672                          stdout=subprocess.PIPE,
2673                          stderr=subprocess.STDOUT,
2674                          model_config=model_config,
2675                      )
2676  
2677                      client = ScoringServerClient(host, server_port)
2678                  else:
2679                      scoring_server_proc = pyfunc_backend.serve_stdin(
2680                          model_uri=local_model_path_on_executor or local_model_path,
2681                          stdout=subprocess.PIPE,
2682                          stderr=subprocess.STDOUT,
2683                          model_config=model_config,
2684                      )
2685                      client = StdinScoringServerClient(scoring_server_proc)
2686  
2687                  _logger.info("Using %s", client.__class__.__name__)
2688  
2689                  server_tail_logs = collections.deque(
2690                      maxlen=_MLFLOW_SERVER_OUTPUT_TAIL_LINES_TO_KEEP
2691                  )
2692  
2693                  def server_redirect_log_thread_func(child_stdout):
2694                      for line in child_stdout:
2695                          decoded = line.decode() if isinstance(line, bytes) else line
2696                          server_tail_logs.append(decoded)
2697                          sys.stdout.write("[model server] " + decoded)
2698  
2699                  server_redirect_log_thread = threading.Thread(
2700                      target=server_redirect_log_thread_func,
2701                      args=(scoring_server_proc.stdout,),
2702                      daemon=True,
2703                      name=f"mlflow_pyfunc_model_server_log_redirector_{uuid.uuid4().hex[:8]}",
2704                  )
2705                  server_redirect_log_thread.start()
2706  
2707                  try:
2708                      client.wait_server_ready(timeout=90, scoring_server_proc=scoring_server_proc)
2709                  except Exception as e:
2710                      err_msg = (
2711                          "During spark UDF task execution, mlflow model server failed to launch. "
2712                      )
2713                      if len(server_tail_logs) == _MLFLOW_SERVER_OUTPUT_TAIL_LINES_TO_KEEP:
2714                          err_msg += (
2715                              f"Last {_MLFLOW_SERVER_OUTPUT_TAIL_LINES_TO_KEEP} "
2716                              "lines of MLflow model server output:\n"
2717                          )
2718                      else:
2719                          err_msg += "MLflow model server output:\n"
2720                      err_msg += "".join(server_tail_logs)
2721                      raise MlflowException(err_msg) from e
2722  
2723                  def batch_predict_fn(pdf, params=None):
2724                      if "params" in inspect.signature(client.invoke).parameters:
2725                          return client.invoke(pdf, params=params).get_predictions()
2726                      _log_warning_if_params_not_in_predict_signature(_logger, params)
2727                      return client.invoke(pdf).get_predictions()
2728  
2729              elif env_manager == _EnvManager.LOCAL:
2730                  if use_dbconnect_artifact:
2731                      model_path = dbconnect_artifact_cache.get_unpacked_artifact_dir(model_uri)
2732                      loaded_model = mlflow.pyfunc.load_model(model_path, model_config=model_config)
2733                  elif is_spark_connect and not should_spark_connect_use_nfs:
2734                      model_path = os.path.join(
2735                          tempfile.gettempdir(),
2736                          "mlflow",
2737                          hashlib.sha1(model_uri.encode(), usedforsecurity=False).hexdigest(),
2738                          # Use pid to avoid conflict when multiple spark UDF tasks
2739                          str(os.getpid()),
2740                      )
2741                      try:
2742                          loaded_model = mlflow.pyfunc.load_model(
2743                              model_path, model_config=model_config
2744                          )
2745                      except Exception:
2746                          os.makedirs(model_path, exist_ok=True)
2747                          loaded_model = mlflow.pyfunc.load_model(
2748                              model_uri, dst_path=model_path, model_config=model_config
2749                          )
2750                  elif should_use_spark_to_broadcast_file:
2751                      loaded_model, _ = SparkModelCache.get_or_load(archive_path)
2752                  else:
2753                      loaded_model = mlflow.pyfunc.load_model(
2754                          local_model_path, model_config=model_config
2755                      )
2756  
2757                  def batch_predict_fn(pdf, params=None):
2758                      if "params" in inspect.signature(loaded_model.predict).parameters:
2759                          return loaded_model.predict(pdf, params=params)
2760                      _log_warning_if_params_not_in_predict_signature(_logger, params)
2761                      return loaded_model.predict(pdf)
2762  
2763              try:
2764                  for input_batch in iterator:
2765                      # If the UDF is called with only multiple arguments,
2766                      # the `input_batch` is a tuple which composes of several pd.Series/pd.DataFrame
2767                      # objects.
2768                      # If the UDF is called with only one argument,
2769                      # the `input_batch` instance will be an instance of `pd.Series`/`pd.DataFrame`,
2770                      if isinstance(input_batch, (pandas.Series, pandas.DataFrame)):
2771                          # UDF is called with only one argument
2772                          row_batch_args = (input_batch,)
2773                      else:
2774                          row_batch_args = input_batch
2775  
2776                      if len(row_batch_args[0]) > 0:
2777                          yield _predict_row_batch(batch_predict_fn, row_batch_args)
2778              except SystemError as e:
2779                  if "error return without exception set" in str(e):
2780                      raise MlflowException(
2781                          "A system error related to the Python C extension has occurred. "
2782                          "This is usually caused by an incompatible Python library that uses the "
2783                          "C extension. To address this, we recommend you to log the model "
2784                          "with fixed version python libraries that use the C extension "
2785                          "(such as 'numpy' library), and set spark_udf `env_manager` argument "
2786                          "to 'virtualenv' or 'uv' so that spark_udf can restore the original "
2787                          "python library version before running model inference."
2788                      ) from e
2789              finally:
2790                  if scoring_server_proc is not None:
2791                      os.kill(scoring_server_proc.pid, signal.SIGTERM)
2792  
2793      udf.metadata = model_metadata
2794  
2795      @functools.wraps(udf)
2796      def udf_with_default_cols(*args):
2797          if len(args) == 0:
2798              input_schema = model_metadata.get_input_schema()
2799              if input_schema and len(input_schema.optional_input_names()) > 0:
2800                  raise MlflowException(
2801                      message="Cannot apply UDF without column names specified when"
2802                      " model signature contains optional columns.",
2803                      error_code=INVALID_PARAMETER_VALUE,
2804                  )
2805              if input_schema and len(input_schema.inputs) > 0:
2806                  if input_schema.has_input_names():
2807                      input_names = input_schema.input_names()
2808                      return udf(*input_names)
2809                  else:
2810                      raise MlflowException(
2811                          message="Cannot apply udf because no column names specified. The udf "
2812                          f"expects {len(input_schema.inputs)} columns with types: "
2813                          "{input_schema.inputs}. Input column names could not be inferred from the"
2814                          " model signature (column names not found).",
2815                          error_code=INVALID_PARAMETER_VALUE,
2816                      )
2817              else:
2818                  raise MlflowException(
2819                      "Attempting to apply udf on zero columns because no column names were "
2820                      "specified as arguments or inferred from the model signature.",
2821                      error_code=INVALID_PARAMETER_VALUE,
2822                  )
2823          else:
2824              return udf(*args)
2825  
2826      return udf_with_default_cols
2827  
2828  
2829  def _validate_function_python_model(python_model):
2830      if not (isinstance(python_model, PythonModel) or callable(python_model)):
2831          raise MlflowException(
2832              "`python_model` must be a PythonModel instance, callable object, or path to a script "
2833              "that uses set_model() to set a PythonModel instance or callable object.",
2834              error_code=INVALID_PARAMETER_VALUE,
2835          )
2836  
2837      if callable(python_model):
2838          num_args = len(inspect.signature(python_model).parameters)
2839          if num_args != 1:
2840              raise MlflowException(
2841                  "When `python_model` is a callable object, it must accept exactly one argument. "
2842                  f"Found {num_args} arguments.",
2843                  error_code=INVALID_PARAMETER_VALUE,
2844              )
2845  
2846  
2847  @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn"))
2848  @trace_disabled  # Suppress traces for internal predict calls while saving model
2849  def save_model(
2850      path,
2851      loader_module=None,
2852      data_path=None,
2853      code_paths=None,
2854      infer_code_paths=False,
2855      conda_env=None,
2856      mlflow_model=None,
2857      python_model=None,
2858      artifacts=None,
2859      signature: ModelSignature = None,
2860      input_example: ModelInputExample = None,
2861      pip_requirements=None,
2862      extra_pip_requirements=None,
2863      metadata=None,
2864      model_config=None,
2865      streamable=None,
2866      resources: str | list[Resource] | None = None,
2867      auth_policy: AuthPolicy | None = None,
2868      uv_project_path: str | Path | None = None,
2869      uv_groups: list[str] | None = None,
2870      uv_extras: list[str] | None = None,
2871      **kwargs,
2872  ):
2873      """
2874      Save a Pyfunc model with custom inference logic and optional data dependencies to a path on the
2875      local filesystem.
2876  
2877      For information about the workflows that this method supports, please see :ref:`"workflows for
2878      creating custom pyfunc models" <pyfunc-create-custom-workflows>` and
2879      :ref:`"which workflow is right for my use case?" <pyfunc-create-custom-selecting-workflow>`.
2880      Note that the parameters for the second workflow: ``loader_module``, ``data_path`` and the
2881      parameters for the first workflow: ``python_model``, ``artifacts``, cannot be
2882      specified together.
2883  
2884      Args:
2885          path: The path to which to save the Python model.
2886          loader_module: The name of the Python module that is used to load the model
2887              from ``data_path``. This module must define a method with the prototype
2888              ``_load_pyfunc(data_path)``. If not ``None``, this module and its
2889              dependencies must be included in one of the following locations:
2890  
2891              - The MLflow library.
2892              - Package(s) listed in the model's Conda environment, specified by
2893                the ``conda_env`` parameter.
2894              - One or more of the files specified by the ``code_paths`` parameter.
2895  
2896          data_path: Path to a file or directory containing model data.
2897          code_paths: {{ code_paths_pyfunc }}
2898          infer_code_paths: {{ infer_code_paths }}
2899          conda_env: {{ conda_env }}
2900          mlflow_model: :py:mod:`mlflow.models.Model` configuration to which to add the
2901              **python_function** flavor.
2902          python_model:
2903              A file path to the PythonModel
2904              which defines the model from code artifact,
2905              (recommended), see https://mlflow.org/docs/latest/ml/model/models-from-code/
2906              for details;
2907              or an instance of a subclass of :class:`~PythonModel` or a callable object with a single
2908              argument (see the examples below), the passed-in object is serialized using the
2909              CloudPickle library, it requires exercising caution because these formats rely on
2910              Python's object serialization mechanism, which can execute arbitrary code during
2911              deserialization.
2912              Any dependencies of the class should be included in one of the
2913              following locations:
2914  
2915              - The MLflow library.
2916              - Package(s) listed in the model's Conda environment, specified by the ``conda_env``
2917                parameter.
2918              - One or more of the files specified by the ``code_paths`` parameter.
2919  
2920              Note: If the class is imported from another module, as opposed to being defined in the
2921              ``__main__`` scope, the defining module should also be included in one of the listed
2922              locations.
2923  
2924              **Examples**
2925  
2926              Class model
2927  
2928              .. code-block:: python
2929  
2930                  from typing import List, Dict
2931                  import mlflow
2932  
2933  
2934                  class MyModel(mlflow.pyfunc.PythonModel):
2935                      def predict(self, context, model_input: List[str], params=None) -> List[str]:
2936                          return [i.upper() for i in model_input]
2937  
2938  
2939                  mlflow.pyfunc.save_model("model", python_model=MyModel(), input_example=["a"])
2940                  model = mlflow.pyfunc.load_model("model")
2941                  print(model.predict(["a", "b", "c"]))  # -> ["A", "B", "C"]
2942  
2943              Functional model
2944  
2945              .. note::
2946                  Experimental: Functional model support is experimental and may change or be removed
2947                  in a future release without warning.
2948  
2949              .. code-block:: python
2950  
2951                  from typing import List
2952                  import mlflow
2953  
2954  
2955                  def predict(model_input: List[str]) -> List[str]:
2956                      return [i.upper() for i in model_input]
2957  
2958  
2959                  mlflow.pyfunc.save_model("model", python_model=predict, input_example=["a"])
2960                  model = mlflow.pyfunc.load_model("model")
2961                  print(model.predict(["a", "b", "c"]))  # -> ["A", "B", "C"]
2962  
2963              Model from code
2964  
2965              .. note::
2966                  Experimental: Model from code model support is experimental and may change or
2967                  be removed in a future release without warning.
2968  
2969              .. code-block:: python
2970  
2971                  # code.py
2972                  from typing import List
2973                  import mlflow
2974  
2975  
2976                  class MyModel(mlflow.pyfunc.PythonModel):
2977                      def predict(self, context, model_input: List[str], params=None) -> List[str]:
2978                          return [i.upper() for i in model_input]
2979  
2980  
2981                  mlflow.models.set_model(MyModel())
2982  
2983                  # log_model.py
2984                  import mlflow
2985  
2986                  with mlflow.start_run():
2987                      model_info = mlflow.pyfunc.log_model(
2988                          name="model",
2989                          python_model="code.py",
2990                      )
2991  
2992              If the `predict` method or function has type annotations, MLflow automatically
2993              constructs a model signature based on the type annotations (unless the ``signature``
2994              argument is explicitly specified), and converts the input value to the specified type
2995              before passing it to the function. Currently, the following type annotations are
2996              supported:
2997  
2998                  - ``List[str]``
2999                  - ``List[Dict[str, str]]``
3000  
3001          artifacts: A dictionary containing ``<name, artifact_uri>`` entries. Remote artifact URIs
3002              are resolved to absolute filesystem paths, producing a dictionary of
3003              ``<name, absolute_path>`` entries. ``python_model`` can reference these
3004              resolved entries as the ``artifacts`` property of the ``context`` parameter
3005              in :func:`PythonModel.load_context() <mlflow.pyfunc.PythonModel.load_context>`
3006              and :func:`PythonModel.predict() <mlflow.pyfunc.PythonModel.predict>`.
3007              For example, consider the following ``artifacts`` dictionary::
3008  
3009                  {"my_file": "s3://my-bucket/path/to/my/file"}
3010  
3011              In this case, the ``"my_file"`` artifact is downloaded from S3. The
3012              ``python_model`` can then refer to ``"my_file"`` as an absolute filesystem
3013              path via ``context.artifacts["my_file"]``.
3014  
3015              If ``None``, no artifacts are added to the model.
3016  
3017          signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
3018              describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
3019              The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
3020              from datasets with valid model input (e.g. the training dataset with target
3021              column omitted) and valid model output (e.g. model predictions generated on
3022              the training dataset), for example:
3023  
3024              .. code-block:: python
3025  
3026                  from mlflow.models import infer_signature
3027  
3028                  train = df.drop_column("target_label")
3029                  predictions = ...  # compute model predictions
3030                  signature = infer_signature(train, predictions)
3031          input_example: {{ input_example }}
3032          pip_requirements: {{ pip_requirements }}
3033          extra_pip_requirements: {{ extra_pip_requirements }}
3034          metadata: {{ metadata }}
3035          model_config: The model configuration to apply to the model. The configuration will
3036              be available as the ``model_config`` property of the ``context`` parameter
3037              in :func:`PythonModel.load_context() <mlflow.pyfunc.PythonModel.load_context>`
3038              and :func:`PythonModel.predict() <mlflow.pyfunc.PythonModel.predict>`.
3039              The configuration can be passed as a file path, or a dict with string keys.
3040  
3041              .. Note:: Experimental: This parameter may change or be removed in a future
3042                                      release without warning.
3043          streamable: A boolean value indicating if the model supports streaming prediction,
3044                      If None, MLflow will try to inspect if the model supports streaming
3045                      by checking if `predict_stream` method exists. Default None.
3046          resources: A list of model resources or a resources.yaml file containing a list of
3047                      resources required to serve the model.
3048  
3049              .. Note:: Experimental: This parameter may change or be removed in a future
3050                                      release without warning.
3051          auth_policy: {{ auth_policy }}
3052          uv_project_path: Explicit path to the uv project directory containing uv.lock,
3053              pyproject.toml, and optionally .python-version. This is useful for monorepos
3054              or non-standard project layouts where the uv project is not in the current
3055              working directory. If ``None``, MLflow will auto-detect uv.lock, pyproject.toml,
3056              and .python-version files in the current working directory.
3057  
3058              When a uv project is detected (either via this parameter or auto-detection),
3059              pip requirements are generated by running ``uv export`` against the lockfile
3060              instead of inferring dependencies by capturing imported packages during model
3061              inference.
3062  
3063              Auto-detection can be disabled by setting the environment variable
3064              ``MLFLOW_UV_AUTO_DETECT=false``.
3065  
3066              .. Note:: Experimental: This parameter may change or be removed in a future
3067                                      release without warning.
3068          uv_groups: Optional list of uv dependency groups to include when exporting
3069              requirements from the uv lockfile. Maps to ``uv export --group <name>``.
3070              These are additive with the project's default dependencies.
3071  
3072              .. Note:: Experimental: This parameter may change or be removed in a future
3073                                      release without warning.
3074          uv_extras: Optional list of uv extras (optional dependency sets) to include
3075              when exporting requirements from the uv lockfile. Maps to
3076              ``uv export --extra <name>``.
3077  
3078              .. Note:: Experimental: This parameter may change or be removed in a future
3079                                      release without warning.
3080          kwargs: Extra keyword arguments.
3081      """
3082      if (
3083          python_model is not None
3084          and not isinstance(python_model, (Path, str))
3085          and not is_in_databricks_runtime()
3086      ):
3087          _logger.warning(
3088              "Passing a Python object as `python_model` causes it to be serialized "
3089              "using CloudPickle, "
3090              "it requires exercising caution as Python object serialization mechanisms may "
3091              "execute arbitrary code during deserialization."
3092              "Consider using a file path (str or Path) instead. See "
3093              "https://mlflow.org/docs/latest/ml/model/models-from-code/ for details."
3094          )
3095  
3096      _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements)
3097      _validate_pyfunc_model_config(model_config)
3098      _validate_and_prepare_target_save_path(path)
3099  
3100      with tempfile.TemporaryDirectory() as temp_dir:
3101          model_code_path = None
3102          if python_model:
3103              if isinstance(model_config, Path):
3104                  model_config = os.fspath(model_config)
3105  
3106              if isinstance(model_config, str):
3107                  model_config = _validate_and_get_model_config_from_file(model_config)
3108  
3109              if isinstance(python_model, Path):
3110                  python_model = os.fspath(python_model)
3111  
3112              if isinstance(python_model, str):
3113                  model_code_path = _validate_and_get_model_code_path(python_model, temp_dir)
3114                  _validate_and_copy_file_to_directory(model_code_path, path, "code")
3115                  python_model = _load_model_code_path(model_code_path, model_config)
3116  
3117              _validate_function_python_model(python_model)
3118              if callable(python_model) and all(
3119                  a is None for a in (input_example, pip_requirements, extra_pip_requirements)
3120              ):
3121                  raise MlflowException(
3122                      "If `python_model` is a callable object, at least one of `input_example`, "
3123                      "`pip_requirements`, or `extra_pip_requirements` must be specified."
3124                  )
3125  
3126      mlflow_model = kwargs.pop("model", mlflow_model)
3127      if len(kwargs) > 0:
3128          raise TypeError(f"save_model() got unexpected keyword arguments: {kwargs}")
3129  
3130      if code_paths is not None:
3131          if not isinstance(code_paths, list):
3132              raise TypeError(f"Argument code_paths should be a list, not {type(code_paths)}")
3133  
3134      first_argument_set = {
3135          "loader_module": loader_module,
3136          "data_path": data_path,
3137      }
3138      second_argument_set = {
3139          "artifacts": artifacts,
3140          "python_model": python_model,
3141      }
3142      first_argument_set_specified = any(item is not None for item in first_argument_set.values())
3143      second_argument_set_specified = any(item is not None for item in second_argument_set.values())
3144      if first_argument_set_specified and second_argument_set_specified:
3145          raise MlflowException(
3146              message=(
3147                  f"The following sets of parameters cannot be specified together:"
3148                  f" {first_argument_set.keys()}  and {second_argument_set.keys()}."
3149                  " All parameters in one set must be `None`. Instead, found"
3150                  f" the following values: {first_argument_set} and {second_argument_set}"
3151              ),
3152              error_code=INVALID_PARAMETER_VALUE,
3153          )
3154      elif (loader_module is None) and (python_model is None):
3155          msg = (
3156              "Either `loader_module` or `python_model` must be specified. A `loader_module` "
3157              "should be a python module. A `python_model` should be a subclass of PythonModel"
3158          )
3159          raise MlflowException(message=msg, error_code=INVALID_PARAMETER_VALUE)
3160      if mlflow_model is None:
3161          mlflow_model = Model()
3162      saved_example = None
3163      signature_from_type_hints = None
3164      type_hint_from_example = None
3165      if isinstance(python_model, ChatModel):
3166          if signature is not None:
3167              raise MlflowException(
3168                  "ChatModel subclasses have a standard signature that is set "
3169                  "automatically. Please remove the `signature` parameter from "
3170                  "the call to log_model() or save_model().",
3171                  error_code=INVALID_PARAMETER_VALUE,
3172              )
3173          mlflow_model.signature = ModelSignature(
3174              CHAT_MODEL_INPUT_SCHEMA,
3175              CHAT_MODEL_OUTPUT_SCHEMA,
3176          )
3177          # For ChatModel we set default metadata to indicate its task
3178          default_metadata = {TASK: _DEFAULT_CHAT_MODEL_METADATA_TASK}
3179          mlflow_model.metadata = default_metadata | (mlflow_model.metadata or {})
3180  
3181          if input_example:
3182              input_example, input_params = _split_input_data_and_params(input_example)
3183              valid_params = {}
3184              if isinstance(input_example, list):
3185                  messages = [
3186                      message if isinstance(message, ChatMessage) else ChatMessage.from_dict(message)
3187                      for message in input_example
3188                  ]
3189              else:
3190                  # If the input example is a dictionary, convert it to ChatMessage format
3191                  messages = [
3192                      ChatMessage.from_dict(m) if isinstance(m, dict) else m
3193                      for m in input_example["messages"]
3194                  ]
3195                  valid_params = {
3196                      k: v
3197                      for k, v in input_example.items()
3198                      if k != "messages" and k in ChatParams.keys()
3199                  }
3200              if valid_params or input_params:
3201                  _logger.warning(_CHAT_PARAMS_WARNING_MESSAGE)
3202              input_example = {
3203                  "messages": [m.to_dict() for m in messages],
3204                  **valid_params,
3205                  **(input_params or {}),
3206              }
3207          else:
3208              input_example = CHAT_MODEL_INPUT_EXAMPLE
3209              _logger.warning(_CHAT_PARAMS_WARNING_MESSAGE)
3210              messages = [ChatMessage.from_dict(m) for m in input_example["messages"]]
3211          # extra params introduced by ChatParams will not be included in the
3212          # logged input example file to avoid confusion
3213          _save_example(mlflow_model, input_example, path)
3214          params = ChatParams.from_dict(input_example)
3215  
3216          # call load_context() first, as predict may depend on it
3217          _logger.info("Predicting on input example to validate output")
3218          context = PythonModelContext(artifacts, model_config)
3219          python_model.load_context(context)
3220          if "context" in inspect.signature(python_model.predict).parameters:
3221              output = python_model.predict(context, messages, params)
3222          else:
3223              output = python_model.predict(messages, params)
3224          if not isinstance(output, ChatCompletionResponse):
3225              raise MlflowException(
3226                  "Failed to save ChatModel. Please ensure that the model's predict() method "
3227                  "returns a ChatCompletionResponse object. If your predict() method currently "
3228                  "returns a dict, you can instantiate a ChatCompletionResponse using "
3229                  "`from_dict()`, e.g. `ChatCompletionResponse.from_dict(output)`",
3230              )
3231      elif isinstance(python_model, ChatAgent):
3232          input_example = _save_model_chat_agent_helper(
3233              python_model, mlflow_model, signature, input_example, artifacts, model_config
3234          )
3235      elif IS_RESPONSES_AGENT_AVAILABLE and isinstance(python_model, ResponsesAgent):
3236          input_example = _save_model_responses_agent_helper(
3237              python_model, mlflow_model, signature, input_example, artifacts, model_config
3238          )
3239      elif callable(python_model) or isinstance(python_model, PythonModel):
3240          model_for_signature_inference = None
3241          if callable(python_model):
3242              # first argument is the model input
3243              type_hints = _extract_type_hints(python_model, input_arg_index=0)
3244              pyfunc_decorator_used = getattr(python_model, "_is_pyfunc", False)
3245              # only show the warning here if @pyfunc is not applied on the function
3246              # since @pyfunc will trigger the warning instead
3247              if type_hints.input is None and not pyfunc_decorator_used:
3248                  color_warning(
3249                      "Add type hints to the `predict` method to enable "
3250                      "data validation and automatic signature inference. Check "
3251                      "https://mlflow.org/docs/latest/model/python_model.html#type-hint-usage-in-pythonmodel"
3252                      " for more details.",
3253                      stacklevel=1,
3254                      color="yellow",
3255                  )
3256              model_for_signature_inference = _FunctionPythonModel(python_model)
3257          elif isinstance(python_model, PythonModel):
3258              type_hints = python_model.predict_type_hints
3259              model_for_signature_inference = python_model
3260          context = PythonModelContext(artifacts, model_config)
3261          type_hint_from_example = _is_type_hint_from_example(type_hints.input)
3262          if type_hint_from_example:
3263              should_infer_signature_from_type_hints = False
3264          else:
3265              should_infer_signature_from_type_hints = (
3266                  not _signature_cannot_be_inferred_from_type_hint(type_hints.input)
3267              )
3268              if should_infer_signature_from_type_hints:
3269                  # context is only loaded when input_example exists
3270                  signature_from_type_hints = _infer_signature_from_type_hints(
3271                      python_model=python_model,
3272                      context=context,
3273                      type_hints=type_hints,
3274                      input_example=input_example,
3275                  )
3276          # only infer signature based on input example when signature
3277          # and type hints are not provided
3278          if signature is None and signature_from_type_hints is None:
3279              saved_example = _save_example(mlflow_model, input_example, path)
3280              if saved_example is not None:
3281                  _logger.info("Inferring model signature from input example")
3282                  try:
3283                      model_for_signature_inference.load_context(context)
3284                      mlflow_model.signature = _infer_signature_from_input_example(
3285                          saved_example,
3286                          _PythonModelPyfuncWrapper(model_for_signature_inference, context, None),
3287                      )
3288                  except Exception as e:
3289                      _logger.warning(
3290                          f"Failed to infer model signature from input example, error: {e}",
3291                      )
3292                  else:
3293                      if type_hint_from_example and mlflow_model.signature:
3294                          update_signature_for_type_hint_from_example(
3295                              input_example, mlflow_model.signature
3296                          )
3297              else:
3298                  if type_hint_from_example:
3299                      _logger.warning(
3300                          _TYPE_FROM_EXAMPLE_ERROR_MESSAGE,
3301                          extra={"color": "red"},
3302                      )
3303                  # if signature is inferred from type hints, warnings are emitted
3304                  # in _infer_signature_from_type_hints
3305                  elif not should_infer_signature_from_type_hints:
3306                      _logger.warning(
3307                          "Failed to infer model signature: "
3308                          f"Type hint {type_hints} cannot be used to infer model signature and "
3309                          "input example is not provided, model signature cannot be inferred."
3310                      )
3311  
3312      if metadata is not None:
3313          mlflow_model.metadata = metadata
3314      if saved_example is None:
3315          saved_example = _save_example(mlflow_model, input_example, path)
3316  
3317      if signature_from_type_hints:
3318          if signature and signature_from_type_hints != signature:
3319              # TODO: drop this support and raise exception in the next minor release since this
3320              # is a behavior change
3321              _logger.warning(
3322                  "Provided signature does not match the signature inferred from the Python model's "
3323                  "`predict` function type hint. Signature inferred from type hint will be used:\n"
3324                  f"{signature_from_type_hints}\nRemove the `signature` parameter or ensure it "
3325                  "matches the inferred signature. In a future release, this warning will become an "
3326                  "exception, and the signature must align with the type hint.",
3327                  extra={"color": "red"},
3328              )
3329          mlflow_model.signature = signature_from_type_hints
3330      elif signature:
3331          mlflow_model.signature = signature
3332          if type_hint_from_example:
3333              if saved_example is None:
3334                  _logger.warning(
3335                      _TYPE_FROM_EXAMPLE_ERROR_MESSAGE,
3336                      extra={"color": "red"},
3337                  )
3338              else:
3339                  # TODO: validate input example against signature
3340                  update_signature_for_type_hint_from_example(input_example, mlflow_model.signature)
3341          else:
3342              if saved_example is None:
3343                  color_warning(
3344                      message="An input example was not provided when logging the model. To ensure "
3345                      "the model signature functions correctly, specify the `input_example` "
3346                      "parameter. See "
3347                      "https://mlflow.org/docs/latest/model/signatures.html#model-input-example "
3348                      "for more details about the benefits of using input_example.",
3349                      stacklevel=1,
3350                      color="yellow_bold",
3351                  )
3352              else:
3353                  _logger.info("Validating input example against model signature")
3354                  try:
3355                      _validate_prediction_input(
3356                          data=saved_example.inference_data,
3357                          params=saved_example.inference_params,
3358                          input_schema=signature.inputs,
3359                          params_schema=signature.params,
3360                      )
3361                  except Exception as e:
3362                      raise MlflowException.invalid_parameter_value(
3363                          f"Input example does not match the model signature. {e}"
3364                      )
3365  
3366      with _get_dependencies_schemas() as dependencies_schemas:
3367          schema = dependencies_schemas.to_dict()
3368          if schema is not None:
3369              if mlflow_model.metadata is None:
3370                  mlflow_model.metadata = {}
3371              mlflow_model.metadata.update(schema)
3372  
3373      if resources is not None:
3374          if isinstance(resources, (Path, str)):
3375              serialized_resource = _ResourceBuilder.from_yaml_file(resources)
3376          else:
3377              serialized_resource = _ResourceBuilder.from_resources(resources)
3378  
3379          mlflow_model.resources = serialized_resource
3380  
3381      if auth_policy is not None:
3382          mlflow_model.auth_policy = auth_policy
3383  
3384      if first_argument_set_specified:
3385          return _save_model_with_loader_module_and_data_path(
3386              path=path,
3387              loader_module=loader_module,
3388              data_path=data_path,
3389              code_paths=code_paths,
3390              conda_env=conda_env,
3391              mlflow_model=mlflow_model,
3392              pip_requirements=pip_requirements,
3393              extra_pip_requirements=extra_pip_requirements,
3394              model_config=model_config,
3395              streamable=streamable,
3396              infer_code_paths=infer_code_paths,
3397              uv_project_path=uv_project_path,
3398              uv_groups=uv_groups,
3399              uv_extras=uv_extras,
3400          )
3401      elif second_argument_set_specified:
3402          return mlflow.pyfunc.model._save_model_with_class_artifacts_params(
3403              path=path,
3404              signature=signature,
3405              python_model=python_model,
3406              artifacts=artifacts,
3407              conda_env=conda_env,
3408              code_paths=code_paths,
3409              mlflow_model=mlflow_model,
3410              pip_requirements=pip_requirements,
3411              extra_pip_requirements=extra_pip_requirements,
3412              model_config=model_config,
3413              streamable=streamable,
3414              model_code_path=model_code_path,
3415              infer_code_paths=infer_code_paths,
3416              uv_project_path=uv_project_path,
3417              uv_groups=uv_groups,
3418              uv_extras=uv_extras,
3419          )
3420  
3421  
3422  def update_signature_for_type_hint_from_example(input_example: Any, signature: ModelSignature):
3423      if _is_example_valid_for_type_from_example(input_example):
3424          signature._is_type_hint_from_example = True
3425      else:
3426          _logger.warning(
3427              "Input example must be one of pandas.DataFrame, pandas.Series "
3428              f"or list when using TypeFromExample as type hint, got {type(input_example)}. "
3429              "Check https://mlflow.org/docs/latest/model/python_model.html#typefromexample-type-hint-usage"
3430              " for more details.",
3431          )
3432  
3433  
3434  @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn"))
3435  @trace_disabled  # Suppress traces for internal predict calls while logging model
3436  def log_model(
3437      artifact_path=None,
3438      loader_module=None,
3439      data_path=None,
3440      code_paths=None,
3441      infer_code_paths=False,
3442      conda_env=None,
3443      python_model=None,
3444      artifacts=None,
3445      registered_model_name=None,
3446      signature: ModelSignature = None,
3447      input_example: ModelInputExample = None,
3448      await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
3449      pip_requirements=None,
3450      extra_pip_requirements=None,
3451      metadata=None,
3452      model_config=None,
3453      streamable=None,
3454      resources: str | list[Resource] | None = None,
3455      auth_policy: AuthPolicy | None = None,
3456      uv_project_path: str | Path | None = None,
3457      uv_groups: list[str] | None = None,
3458      uv_extras: list[str] | None = None,
3459      prompts: list[str | Prompt] | None = None,
3460      name=None,
3461      params: dict[str, Any] | None = None,
3462      tags: dict[str, Any] | None = None,
3463      model_type: str | None = None,
3464      step: int = 0,
3465      model_id: str | None = None,
3466  ):
3467      """
3468      Log a Pyfunc model with custom inference logic and optional data dependencies as an MLflow
3469      artifact for the current run.
3470  
3471      For information about the workflows that this method supports, see :ref:`Workflows for
3472      creating custom pyfunc models <pyfunc-create-custom-workflows>` and
3473      :ref:`Which workflow is right for my use case? <pyfunc-create-custom-selecting-workflow>`.
3474      You cannot specify the parameters for the second workflow: ``loader_module``, ``data_path``
3475      and the parameters for the first workflow: ``python_model``, ``artifacts`` together.
3476  
3477      Args:
3478          artifact_path: Deprecated. Use `name` instead.
3479          loader_module: The name of the Python module that is used to load the model
3480              from ``data_path``. This module must define a method with the prototype
3481              ``_load_pyfunc(data_path)``. If not ``None``, this module and its
3482              dependencies must be included in one of the following locations:
3483  
3484              - The MLflow library.
3485              - Package(s) listed in the model's Conda environment, specified by
3486                the ``conda_env`` parameter.
3487              - One or more of the files specified by the ``code_paths`` parameter.
3488  
3489          data_path: Path to a file or directory containing model data.
3490          code_paths: {{ code_paths_pyfunc }}
3491          infer_code_paths: {{ infer_code_paths }}
3492          conda_env: {{ conda_env }}
3493          python_model:
3494              A file path to the PythonModel
3495              which defines the model from code artifact,
3496              (recommended), see https://mlflow.org/docs/latest/ml/model/models-from-code/
3497              for details;
3498              or an instance of a subclass of :class:`~PythonModel` or a callable object with a single
3499              argument (see the examples below), the passed-in object is serialized using the
3500              CloudPickle library, it requires exercising caution because these formats rely on
3501              Python's object serialization mechanism, which can execute arbitrary code during
3502              deserialization.
3503              Any dependencies of the class should be included in one of the
3504              following locations:
3505  
3506              - The MLflow library.
3507              - Package(s) listed in the model's Conda environment, specified by the ``conda_env``
3508                parameter.
3509              - One or more of the files specified by the ``code_paths`` parameter.
3510  
3511              Note: If the class is imported from another module, as opposed to being defined in the
3512              ``__main__`` scope, the defining module should also be included in one of the listed
3513              locations.
3514  
3515              **Examples**
3516  
3517              Class model
3518  
3519              .. code-block:: python
3520  
3521                  from typing import List
3522                  import mlflow
3523  
3524  
3525                  class MyModel(mlflow.pyfunc.PythonModel):
3526                      def predict(self, context, model_input: List[str], params=None) -> List[str]:
3527                          return [i.upper() for i in model_input]
3528  
3529  
3530                  with mlflow.start_run():
3531                      model_info = mlflow.pyfunc.log_model(
3532                          name="model",
3533                          python_model=MyModel(),
3534                      )
3535  
3536                  loaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
3537                  print(loaded_model.predict(["a", "b", "c"]))  # -> ["A", "B", "C"]
3538  
3539              Functional model
3540  
3541              .. note::
3542                  Experimental: Functional model support is experimental and may change or be removed
3543                  in a future release without warning.
3544  
3545              .. code-block:: python
3546  
3547                  from typing import List
3548                  import mlflow
3549  
3550  
3551                  def predict(model_input: List[str]) -> List[str]:
3552                      return [i.upper() for i in model_input]
3553  
3554  
3555                  with mlflow.start_run():
3556                      model_info = mlflow.pyfunc.log_model(
3557                          name="model", python_model=predict, input_example=["a"]
3558                      )
3559  
3560                  loaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
3561                  print(loaded_model.predict(["a", "b", "c"]))  # -> ["A", "B", "C"]
3562  
3563              Model from code
3564  
3565              .. note::
3566                  Experimental: Model from code model support is experimental and may change or
3567                  be removed in a future release without warning.
3568  
3569              .. code-block:: python
3570  
3571                  # code.py
3572                  from typing import List
3573                  import mlflow
3574  
3575  
3576                  class MyModel(mlflow.pyfunc.PythonModel):
3577                      def predict(self, context, model_input: List[str], params=None) -> List[str]:
3578                          return [i.upper() for i in model_input]
3579  
3580  
3581                  mlflow.models.set_model(MyModel())
3582  
3583                  # log_model.py
3584                  import mlflow
3585  
3586                  with mlflow.start_run():
3587                      model_info = mlflow.pyfunc.log_model(
3588                          name="model",
3589                          python_model="code.py",
3590                      )
3591  
3592              If the `predict` method or function has type annotations, MLflow automatically
3593              constructs a model signature based on the type annotations (unless the ``signature``
3594              argument is explicitly specified), and converts the input value to the specified type
3595              before passing it to the function. Currently, the following type annotations are
3596              supported:
3597  
3598                  - ``List[str]``
3599                  - ``List[Dict[str, str]]``
3600  
3601          artifacts: A dictionary containing ``<name, artifact_uri>`` entries. Remote artifact URIs
3602              are resolved to absolute filesystem paths, producing a dictionary of
3603              ``<name, absolute_path>`` entries. ``python_model`` can reference these
3604              resolved entries as the ``artifacts`` property of the ``context`` parameter
3605              in :func:`PythonModel.load_context() <mlflow.pyfunc.PythonModel.load_context>`
3606              and :func:`PythonModel.predict() <mlflow.pyfunc.PythonModel.predict>`.
3607              For example, consider the following ``artifacts`` dictionary::
3608  
3609                  {"my_file": "s3://my-bucket/path/to/my/file"}
3610  
3611              In this case, the ``"my_file"`` artifact is downloaded from S3. The
3612              ``python_model`` can then refer to ``"my_file"`` as an absolute filesystem
3613              path via ``context.artifacts["my_file"]``.
3614  
3615              If ``None``, no artifacts are added to the model.
3616          registered_model_name: If given, create a model
3617              version under ``registered_model_name``, also creating a
3618              registered model if one with the given name does not exist.
3619  
3620          signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
3621              describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
3622              The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
3623              from datasets with valid model input (e.g. the training dataset with target
3624              column omitted) and valid model output (e.g. model predictions generated on
3625              the training dataset), for example:
3626  
3627              .. code-block:: python
3628  
3629                  from mlflow.models import infer_signature
3630  
3631                  train = df.drop_column("target_label")
3632                  predictions = ...  # compute model predictions
3633                  signature = infer_signature(train, predictions)
3634  
3635          input_example: {{ input_example }}
3636          await_registration_for: Number of seconds to wait for the model version to finish
3637              being created and is in ``READY`` status. By default, the function
3638              waits for five minutes. Specify 0 or None to skip waiting.
3639          pip_requirements: {{ pip_requirements }}
3640          extra_pip_requirements: {{ extra_pip_requirements }}
3641          metadata: {{ metadata }}
3642          model_config: The model configuration to apply to the model. The configuration will
3643              be available as the ``model_config`` property of the ``context`` parameter
3644              in :func:`PythonModel.load_context() <mlflow.pyfunc.PythonModel.load_context>`
3645              and :func:`PythonModel.predict() <mlflow.pyfunc.PythonModel.predict>`.
3646              The configuration can be passed as a file path, or a dict with string keys.
3647  
3648              .. Note:: Experimental: This parameter may change or be removed in a future
3649                                      release without warning.
3650          streamable: A boolean value indicating if the model supports streaming prediction,
3651                      If None, MLflow will try to inspect if the model supports streaming
3652                      by checking if `predict_stream` method exists. Default None.
3653          resources: A list of model resources or a resources.yaml file containing a list of
3654                      resources required to serve the model.
3655  
3656              .. Note:: Experimental: This parameter may change or be removed in a future
3657                                      release without warning.
3658          auth_policy: {{ auth_policy }}
3659          uv_project_path: Explicit path to the uv project directory containing uv.lock,
3660              pyproject.toml, and optionally .python-version. This is useful for monorepos
3661              or non-standard project layouts where the uv project is not in the current
3662              working directory. If ``None``, MLflow will auto-detect uv.lock, pyproject.toml,
3663              and .python-version files in the current working directory.
3664  
3665              When a uv project is detected (either via this parameter or auto-detection),
3666              pip requirements are generated by running ``uv export`` against the lockfile
3667              instead of inferring dependencies by capturing imported packages during model
3668              inference.
3669  
3670              Auto-detection can be disabled by setting the environment variable
3671              ``MLFLOW_UV_AUTO_DETECT=false``.
3672  
3673              .. Note:: Experimental: This parameter may change or be removed in a future
3674                                      release without warning.
3675          uv_groups: Optional list of uv dependency groups to include when exporting
3676              requirements from the uv lockfile. Maps to ``uv export --group <name>``.
3677              These are additive with the project's default dependencies.
3678  
3679              .. Note:: Experimental: This parameter may change or be removed in a future
3680                                      release without warning.
3681          uv_extras: Optional list of uv extras (optional dependency sets) to include
3682              when exporting requirements from the uv lockfile. Maps to
3683              ``uv export --extra <name>``.
3684  
3685              .. Note:: Experimental: This parameter may change or be removed in a future
3686                                      release without warning.
3687          prompts: {{ prompts }}
3688          name: {{ name }}
3689          params: {{ params }}
3690          tags: {{ tags }}
3691          model_type: {{ model_type }}
3692          step: {{ step }}
3693          model_id: {{ model_id }}
3694  
3695      Returns:
3696          A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
3697          metadata of the logged model.
3698      """
3699      flavor_name = _get_pyfunc_model_flavor_name(python_model)
3700      return Model.log(
3701          artifact_path=artifact_path,
3702          name=name,
3703          flavor=mlflow.pyfunc,
3704          loader_module=loader_module,
3705          data_path=data_path,
3706          code_paths=code_paths,
3707          python_model=python_model,
3708          artifacts=artifacts,
3709          conda_env=conda_env,
3710          registered_model_name=registered_model_name,
3711          signature=signature,
3712          input_example=input_example,
3713          await_registration_for=await_registration_for,
3714          pip_requirements=pip_requirements,
3715          extra_pip_requirements=extra_pip_requirements,
3716          metadata=metadata,
3717          prompts=prompts,
3718          model_config=model_config,
3719          streamable=streamable,
3720          resources=resources,
3721          infer_code_paths=infer_code_paths,
3722          auth_policy=auth_policy,
3723          uv_project_path=uv_project_path,
3724          uv_groups=uv_groups,
3725          uv_extras=uv_extras,
3726          params=params,
3727          tags=tags,
3728          model_type=model_type,
3729          step=step,
3730          model_id=model_id,
3731          # only used for checking python model type
3732          flavor_name=flavor_name,
3733      )
3734  
3735  
3736  def _get_pyfunc_model_flavor_name(python_model: Any) -> str:
3737      if python_model is None:
3738          return "pyfunc"
3739      if isinstance(python_model, str):
3740          return "pyfunc.ModelFromCode"
3741      if IS_RESPONSES_AGENT_AVAILABLE and isinstance(python_model, ResponsesAgent):
3742          return "pyfunc.ResponsesAgent"
3743      if isinstance(python_model, ChatAgent):
3744          return "pyfunc.ChatAgent"
3745      if isinstance(python_model, ChatModel):
3746          return "pyfunc.ChatModel"
3747      if isinstance(python_model, PythonModel):
3748          return "pyfunc.CustomPythonModel"
3749      return "pyfunc"
3750  
3751  
3752  def _save_model_with_loader_module_and_data_path(
3753      path,
3754      loader_module,
3755      data_path=None,
3756      code_paths=None,
3757      conda_env=None,
3758      mlflow_model=None,
3759      pip_requirements=None,
3760      extra_pip_requirements=None,
3761      model_config=None,
3762      streamable=None,
3763      infer_code_paths=False,
3764      uv_project_path=None,
3765      uv_groups=None,
3766      uv_extras=None,
3767  ):
3768      """
3769      Export model as a generic Python function model.
3770  
3771      Args:
3772          path: The path to which to save the Python model.
3773          loader_module: The name of the Python module that is used to load the model
3774              from ``data_path``. This module must define a method with the prototype
3775              ``_load_pyfunc(data_path)``.
3776          data_path: Path to a file or directory containing model data.
3777          code_paths: A list of local filesystem paths to Python file dependencies (or directories
3778              containing file dependencies). These files are *prepended* to the system
3779              path before the model is loaded.
3780          conda_env: Either a dictionary representation of a Conda environment or the path to a
3781              Conda environment yaml file. If provided, this describes the environment
3782              this model should be run in.
3783          streamable: A boolean value indicating if the model supports streaming prediction,
3784                      None value also means not streamable.
3785  
3786      Returns:
3787          Model configuration containing model info.
3788      """
3789      # Capture original working directory for uv project detection
3790      # This must be done before any operations that might change cwd
3791      original_cwd = Path.cwd()
3792  
3793      data = None
3794  
3795      if data_path is not None:
3796          model_file = _copy_file_or_tree(src=data_path, dst=path, dst_dir="data")
3797          data = model_file
3798  
3799      if mlflow_model is None:
3800          mlflow_model = Model()
3801  
3802      streamable = streamable or False
3803      mlflow.pyfunc.add_to_model(
3804          mlflow_model,
3805          loader_module=loader_module,
3806          code=None,
3807          data=data,
3808          conda_env=_CONDA_ENV_FILE_NAME,
3809          python_env=_PYTHON_ENV_FILE_NAME,
3810          model_config=model_config,
3811          streamable=streamable,
3812      )
3813      if size := get_total_file_size(path):
3814          mlflow_model.model_size_bytes = size
3815      mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))
3816  
3817      code_dir_subpath = _validate_infer_and_copy_code_paths(
3818          code_paths, path, infer_code_paths, FLAVOR_NAME
3819      )
3820      mlflow_model.flavors[FLAVOR_NAME][CODE] = code_dir_subpath
3821  
3822      # `mlflow_model.code` is updated, re-generate `MLmodel` file.
3823      mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))
3824  
3825      if uv_project_path is not None:
3826          uv_source_dir = uv_project_path
3827      elif MLFLOW_UV_AUTO_DETECT.get():
3828          uv_source_dir = original_cwd
3829      else:
3830          uv_source_dir = None
3831  
3832      if conda_env is None:
3833          if pip_requirements is None:
3834              default_reqs = get_default_pip_requirements()
3835              extra_env_vars = (
3836                  _get_databricks_serverless_env_vars()
3837                  if is_in_databricks_serverless_runtime()
3838                  else None
3839              )
3840              # To ensure `_load_pyfunc` can successfully load the model during the dependency
3841              # inference, `mlflow_model.save` must be called beforehand to save an MLmodel file.
3842              inferred_reqs = mlflow.models.infer_pip_requirements(
3843                  path,
3844                  FLAVOR_NAME,
3845                  fallback=default_reqs,
3846                  extra_env_vars=extra_env_vars,
3847                  uv_project_dir=uv_source_dir,
3848                  uv_groups=uv_groups,
3849                  uv_extras=uv_extras,
3850              )
3851              default_reqs = sorted(set(inferred_reqs).union(default_reqs))
3852          else:
3853              default_reqs = None
3854          conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
3855              default_reqs,
3856              pip_requirements,
3857              extra_pip_requirements,
3858          )
3859      else:
3860          conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)
3861  
3862      with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f:
3863          yaml.safe_dump(conda_env, stream=f, default_flow_style=False)
3864  
3865      # Save `constraints.txt` if necessary
3866      if pip_constraints:
3867          write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))
3868  
3869      # Save `requirements.txt`
3870      write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))
3871  
3872      # Copy uv project files (uv.lock and pyproject.toml) if detected
3873      if uv_source_dir is not None:
3874          copy_uv_project_files(dest_dir=path, source_dir=uv_source_dir)
3875  
3876      _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))
3877      return mlflow_model
3878  
3879  
3880  def _save_model_chat_agent_helper(
3881      python_model, mlflow_model, signature, input_example, artifacts, model_config
3882  ):
3883      """Helper method for save_model for ChatAgent models
3884  
3885      Returns: a dict input_example
3886      """
3887      if signature is not None:
3888          raise MlflowException(
3889              "ChatAgent subclasses have a standard signature that is set "
3890              "automatically. Please remove the `signature` parameter from "
3891              "the call to log_model() or save_model().",
3892              error_code=INVALID_PARAMETER_VALUE,
3893          )
3894      mlflow_model.signature = ModelSignature(
3895          inputs=CHAT_AGENT_INPUT_SCHEMA,
3896          outputs=CHAT_AGENT_OUTPUT_SCHEMA,
3897      )
3898      # For ChatAgent we set default metadata to indicate its task
3899      default_metadata = {TASK: _DEFAULT_CHAT_AGENT_METADATA_TASK}
3900      mlflow_model.metadata = default_metadata | (mlflow_model.metadata or {})
3901  
3902      # We accept a dict with ChatAgentRequest schema
3903      if input_example:
3904          try:
3905              model_validate(ChatAgentRequest, input_example)
3906          except pydantic.ValidationError as e:
3907              raise MlflowException(
3908                  message=(
3909                      f"Invalid input example. Expected a ChatAgentRequest object or dictionary with"
3910                      f" its schema. Pydantic validation error: {e}"
3911                  ),
3912                  error_code=INTERNAL_ERROR,
3913              ) from e
3914          if isinstance(input_example, ChatAgentRequest):
3915              input_example = input_example.model_dump(exclude_none=True)
3916      else:
3917          input_example = CHAT_AGENT_INPUT_EXAMPLE
3918  
3919      _logger.info("Predicting on input example to validate output")
3920      context = PythonModelContext(artifacts, model_config)
3921      python_model.load_context(context)
3922      request = ChatAgentRequest(**input_example)
3923      output = python_model.predict(request.messages, request.context, request.custom_inputs)
3924      try:
3925          model_validate(ChatAgentResponse, output)
3926      except Exception as e:
3927          raise MlflowException(
3928              "Failed to save ChatAgent. Ensure your model's predict() method returns a "
3929              "ChatAgentResponse object or a dict with the same schema."
3930              f"Pydantic validation error: {e}"
3931          ) from e
3932      return input_example
3933  
3934  
3935  def _save_model_responses_agent_helper(
3936      python_model, mlflow_model, signature, input_example, artifacts, model_config
3937  ):
3938      """Helper method for save_model for ResponsesAgent models
3939  
3940      Returns: a dictionary input example
3941      """
3942      from mlflow.types.responses import (
3943          RESPONSES_AGENT_INPUT_EXAMPLE,
3944          RESPONSES_AGENT_INPUT_SCHEMA,
3945          RESPONSES_AGENT_OUTPUT_SCHEMA,
3946          ResponsesAgentRequest,
3947          ResponsesAgentResponse,
3948      )
3949  
3950      if signature is not None:
3951          raise MlflowException(
3952              "ResponsesAgent subclasses have a standard signature that is set "
3953              "automatically. Please remove the `signature` parameter from "
3954              "the call to log_model() or save_model().",
3955              error_code=INVALID_PARAMETER_VALUE,
3956          )
3957      mlflow_model.signature = ModelSignature(
3958          inputs=RESPONSES_AGENT_INPUT_SCHEMA,
3959          outputs=RESPONSES_AGENT_OUTPUT_SCHEMA,
3960      )
3961  
3962      # For ResponsesAgent we set default metadata to indicate its task
3963      default_metadata = {TASK: _DEFAULT_RESPONSES_AGENT_METADATA_TASK}
3964      mlflow_model.metadata = default_metadata | (mlflow_model.metadata or {})
3965  
3966      # We accept either a dict or a ResponsesRequest object as input
3967      if input_example:
3968          try:
3969              model_validate(ResponsesAgentRequest, input_example)
3970          except pydantic.ValidationError as e:
3971              raise MlflowException(
3972                  message=(
3973                      f"Invalid input example. Expected a ResponsesRequest object or dictionary with"
3974                      f" its schema. Pydantic validation error: {e}"
3975                  ),
3976                  error_code=INTERNAL_ERROR,
3977              ) from e
3978          if isinstance(input_example, ResponsesAgentRequest):
3979              input_example = input_example.model_dump(exclude_none=True)
3980      else:
3981          input_example = RESPONSES_AGENT_INPUT_EXAMPLE
3982      _logger.info("Predicting on input example to validate output")
3983      context = PythonModelContext(artifacts, model_config)
3984      python_model.load_context(context)
3985      request = ResponsesAgentRequest(**input_example)
3986      output = python_model.predict(request)
3987      try:
3988          model_validate(ResponsesAgentResponse, output)
3989      except Exception as e:
3990          raise MlflowException(
3991              "Failed to save ResponsesAgent. Ensure your model's predict() method returns a "
3992              "ResponsesResponse object or a dict with the same schema."
3993              f"Pydantic validation error: {e}"
3994          ) from e
3995      return input_example