__init__.py
1 """ 2 The ``python_function`` model flavor serves as a default model interface for MLflow Python models. 3 Any MLflow Python model is expected to be loadable as a ``python_function`` model. 4 5 In addition, the ``mlflow.pyfunc`` module defines a generic :ref:`filesystem format 6 <pyfunc-filesystem-format>` for Python models and provides utilities for saving to and loading from 7 this format. The format is self contained in the sense that it includes all necessary information 8 for anyone to load it and use it. Dependencies are either stored directly with the model or 9 referenced via a Conda environment. 10 11 The ``mlflow.pyfunc`` module also defines utilities for creating custom ``pyfunc`` models 12 using frameworks and inference logic that may not be natively included in MLflow. See 13 :ref:`pyfunc-create-custom`. 14 15 .. _pyfunc-inference-api: 16 17 ************* 18 Inference API 19 ************* 20 21 Python function models are loaded as an instance of :py:class:`PyFuncModel 22 <mlflow.pyfunc.PyFuncModel>`, which is an MLflow wrapper around the model implementation and model 23 metadata (MLmodel file). You can score the model by calling the :py:func:`predict() 24 <mlflow.pyfunc.PyFuncModel.predict>` method, which has the following signature:: 25 26 predict( 27 model_input: [pandas.DataFrame, numpy.ndarray, scipy.sparse.(csc_matrix | csr_matrix), 28 List[Any], Dict[str, Any], pyspark.sql.DataFrame] 29 ) -> [numpy.ndarray | pandas.(Series | DataFrame) | List | Dict | pyspark.sql.DataFrame] 30 31 All PyFunc models will support `pandas.DataFrame` as input and PyFunc deep learning models will 32 also support tensor inputs in the form of Dict[str, numpy.ndarray] (named tensors) and 33 `numpy.ndarrays` (unnamed tensors). 34 35 Here are some examples of supported inference types, assuming we have the correct ``model`` object 36 loaded. 37 38 .. list-table:: 39 :widths: 30 70 40 :header-rows: 1 41 :class: wrap-table 42 43 * - Input Type 44 - Example 45 * - ``pandas.DataFrame`` 46 - 47 .. code-block:: python 48 49 import pandas as pd 50 51 x_new = pd.DataFrame(dict(x1=[1, 2, 3], x2=[4, 5, 6])) 52 model.predict(x_new) 53 54 * - ``numpy.ndarray`` 55 - 56 .. code-block:: python 57 58 import numpy as np 59 60 x_new = np.array([[1, 4][2, 5], [3, 6]]) 61 model.predict(x_new) 62 63 * - ``scipy.sparse.csc_matrix`` or ``scipy.sparse.csr_matrix`` 64 - 65 .. code-block:: python 66 67 import scipy 68 69 x_new = scipy.sparse.csc_matrix([[1, 2, 3], [4, 5, 6]]) 70 model.predict(x_new) 71 72 x_new = scipy.sparse.csr_matrix([[1, 2, 3], [4, 5, 6]]) 73 model.predict(x_new) 74 75 * - python ``List`` 76 - 77 .. code-block:: python 78 79 x_new = [[1, 4], [2, 5], [3, 6]] 80 model.predict(x_new) 81 82 * - python ``Dict`` 83 - 84 .. code-block:: python 85 86 x_new = dict(x1=[1, 2, 3], x2=[4, 5, 6]) 87 model.predict(x_new) 88 89 * - ``pyspark.sql.DataFrame`` 90 - 91 .. code-block:: python 92 93 from pyspark.sql import SparkSession 94 95 spark = SparkSession.builder.getOrCreate() 96 97 data = [(1, 4), (2, 5), (3, 6)] # List of tuples 98 x_new = spark.createDataFrame(data, ["x1", "x2"]) # Specify column name 99 model.predict(x_new) 100 101 .. _pyfunc-filesystem-format: 102 103 ***************** 104 Filesystem format 105 ***************** 106 107 The Pyfunc format is defined as a directory structure containing all required data, code, and 108 configuration:: 109 110 ./dst-path/ 111 ./MLmodel: configuration 112 <code>: code packaged with the model (specified in the MLmodel file) 113 <data>: data packaged with the model (specified in the MLmodel file) 114 <env>: Conda environment definition (specified in the MLmodel file) 115 116 The directory structure may contain additional contents that can be referenced by the ``MLmodel`` 117 configuration. 118 119 .. _pyfunc-model-config: 120 121 MLModel configuration 122 ##################### 123 124 A Python model contains an ``MLmodel`` file in **python_function** format in its root with the 125 following parameters: 126 127 - loader_module [required]: 128 Python module that can load the model. Expected as module identifier 129 e.g. ``mlflow.sklearn``, it will be imported using ``importlib.import_module``. 130 The imported module must contain a function with the following signature:: 131 132 _load_pyfunc(path: string) -> <pyfunc model implementation> 133 134 The path argument is specified by the ``data`` parameter and may refer to a file or 135 directory. The model implementation is expected to be an object with a 136 ``predict`` method with the following signature:: 137 138 predict( 139 model_input: [pandas.DataFrame, numpy.ndarray, 140 scipy.sparse.(csc_matrix | csr_matrix), List[Any], Dict[str, Any]], 141 pyspark.sql.DataFrame 142 ) -> [numpy.ndarray | pandas.(Series | DataFrame) | List | Dict | pyspark.sql.DataFrame] 143 144 - code [optional]: 145 Relative path to a directory containing the code packaged with this model. 146 All files and directories inside this directory are added to the Python path 147 prior to importing the model loader. 148 149 - data [optional]: 150 Relative path to a file or directory containing model data. 151 The path is passed to the model loader. 152 153 - env [optional]: 154 Relative path to an exported Conda environment. If present this environment 155 should be activated prior to running the model. 156 157 - Optionally, any additional parameters necessary for interpreting the serialized model in 158 ``pyfunc`` format. 159 160 .. rubric:: Example 161 162 :: 163 164 tree example/sklearn_iris/mlruns/run1/outputs/linear-lr 165 166 :: 167 168 ├── MLmodel 169 ├── code 170 │ ├── sklearn_iris.py 171 │ 172 ├── data 173 │ └── model.pkl 174 └── mlflow_env.yml 175 176 :: 177 178 cat example/sklearn_iris/mlruns/run1/outputs/linear-lr/MLmodel 179 180 :: 181 182 python_function: 183 code: code 184 data: data/model.pkl 185 loader_module: mlflow.sklearn 186 env: mlflow_env.yml 187 main: sklearn_iris 188 189 .. _pyfunc-create-custom: 190 191 ********************************** 192 Models From Code for Custom Models 193 ********************************** 194 195 .. tip:: 196 197 MLflow 2.12.2 introduced the feature "models from code", which greatly simplifies the process 198 of serializing and deploying custom models through the use of script serialization. It is 199 strongly recommended to migrate custom model implementations to this new paradigm to avoid the 200 limitations and complexity of serializing with cloudpickle. 201 You can learn more about models from code within the 202 `Models From Code Guide <../model/models-from-code.html>`_. 203 204 The section below illustrates the process of using the legacy serializer for custom Pyfunc models. 205 Models from code will provide a far simpler experience for logging of your models. 206 207 ****************************** 208 Creating custom Pyfunc models 209 ****************************** 210 211 MLflow's persistence modules provide convenience functions for creating models with the 212 ``pyfunc`` flavor in a variety of machine learning frameworks (scikit-learn, Keras, Pytorch, and 213 more); however, they do not cover every use case. For example, you may want to create an MLflow 214 model with the ``pyfunc`` flavor using a framework that MLflow does not natively support. 215 Alternatively, you may want to build an MLflow model that executes custom logic when evaluating 216 queries, such as preprocessing and postprocessing routines. Therefore, ``mlflow.pyfunc`` 217 provides utilities for creating ``pyfunc`` models from arbitrary code and model data. 218 219 The :meth:`save_model()` and :meth:`log_model()` methods are designed to support multiple workflows 220 for creating custom ``pyfunc`` models that incorporate custom inference logic and artifacts 221 that the logic may require. 222 223 An `artifact` is a file or directory, such as a serialized model or a CSV. For example, a 224 serialized TensorFlow graph is an artifact. An MLflow model directory is also an artifact. 225 226 .. _pyfunc-create-custom-workflows: 227 228 Workflows 229 ######### 230 231 :meth:`save_model()` and :meth:`log_model()` support the following workflows: 232 233 1. Programmatically defining a new MLflow model, including its attributes and artifacts. 234 235 Given a set of artifact URIs, :meth:`save_model()` and :meth:`log_model()` can 236 automatically download artifacts from their URIs and create an MLflow model directory. 237 238 In this case, you must define a Python class which inherits from :class:`~PythonModel`, 239 defining ``predict()`` and, optionally, ``load_context()``. An instance of this class is 240 specified via the ``python_model`` parameter; it is automatically serialized and deserialized 241 as a Python class, including all of its attributes. 242 243 2. Interpreting pre-existing data as an MLflow model. 244 245 If you already have a directory containing model data, :meth:`save_model()` and 246 :meth:`log_model()` can import the data as an MLflow model. The ``data_path`` parameter 247 specifies the local filesystem path to the directory containing model data. 248 249 In this case, you must provide a Python module, called a `loader module`. The 250 loader module defines a ``_load_pyfunc()`` method that performs the following tasks: 251 252 - Load data from the specified ``data_path``. For example, this process may include 253 deserializing pickled Python objects or models or parsing CSV files. 254 255 - Construct and return a pyfunc-compatible model wrapper. As in the first 256 use case, this wrapper must define a ``predict()`` method that is used to evaluate 257 queries. ``predict()`` must adhere to the :ref:`pyfunc-inference-api`. 258 259 The ``loader_module`` parameter specifies the name of your loader module. 260 261 For an example loader module implementation, refer to the `loader module 262 implementation in mlflow.sklearn <https://github.com/mlflow/mlflow/blob/ 263 74d75109aaf2975f5026104d6125bb30f4e3f744/mlflow/sklearn.py#L200-L205>`_. 264 265 .. _pyfunc-create-custom-selecting-workflow: 266 267 Which workflow is right for my use case? 268 ######################################## 269 270 We consider the first workflow to be more user-friendly and generally recommend it for the 271 following reasons: 272 273 - It automatically resolves and collects specified model artifacts. 274 275 - It automatically serializes and deserializes the ``python_model`` instance and all of 276 its attributes, reducing the amount of user logic that is required to load the model 277 278 - You can create Models using logic that is defined in the ``__main__`` scope. This allows 279 custom models to be constructed in interactive environments, such as notebooks and the Python 280 REPL. 281 282 You may prefer the second, lower-level workflow for the following reasons: 283 284 - Inference logic is always persisted as code, rather than a Python object. This makes logic 285 easier to inspect and modify later. 286 287 - If you have already collected all of your model data in a single location, the second 288 workflow allows it to be saved in MLflow format directly, without enumerating constituent 289 artifacts. 290 291 ****************************************** 292 Function-based Model vs Class-based Model 293 ****************************************** 294 295 When creating custom PyFunc models, you can choose between two different interfaces: 296 a function-based model and a class-based model. In short, a function-based model is simply a 297 python function that does not take additional params. The class-based model, on the other hand, 298 is subclass of ``PythonModel`` that supports several required and optional 299 methods. If your use case is simple and fits within a single predict function, a function-based 300 approach is recommended. If you need more power, such as custom serialization, custom data 301 processing, or to override additional methods, you should use the class-based implementation. 302 303 Before looking at code examples, it's important to note that both methods are serialized via 304 `cloudpickle <https://github.com/cloudpipe/cloudpickle>`_. cloudpickle can serialize Python 305 functions, lambda functions, and locally defined classes and functions inside other functions. This 306 makes cloudpickle especially useful for parallel and distributed computing where code objects need 307 to be sent over network to execute on remote workers, which is a common deployment paradigm for 308 MLflow. 309 310 That said, cloudpickle has some limitations. 311 312 - **Environment Dependency**: cloudpickle does not capture the full execution environment, so in 313 MLflow we must pass ``pip_requirements``, ``extra_pip_requirements``, or an ``input_example``, 314 the latter of which is used to infer environment dependencies. For more, refer to 315 `the model dependency docs <https://mlflow.org/docs/latest/model/dependencies.html>`_. 316 317 - **Object Support**: cloudpickle does not serialize objects outside of the Python data model. 318 Some relevant examples include raw files and database connections. If your program depends on 319 these, be sure to log ways to reference these objects along with your model. 320 321 Function-based Model 322 #################### 323 If you're looking to serialize a simple python function without additional dependent methods, you 324 can simply log a predict method via the keyword argument ``python_model``. 325 326 .. note:: 327 328 Function-based model only supports a function with a single input argument. If you would like 329 to pass more arguments or additional inference parameters, please use the class-based model 330 below. 331 332 .. code-block:: python 333 334 import mlflow 335 import pandas as pd 336 337 338 # Define a simple function to log 339 def predict(model_input): 340 return model_input.apply(lambda x: x * 2) 341 342 343 # Save the function as a model 344 with mlflow.start_run(): 345 mlflow.pyfunc.log_model(name="model", python_model=predict, pip_requirements=["pandas"]) 346 run_id = mlflow.active_run().info.run_id 347 348 # Load the model from the tracking server and perform inference 349 model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model") 350 x_new = pd.Series([1, 2, 3]) 351 352 prediction = model.predict(x_new) 353 print(prediction) 354 355 Class-based Model 356 ################# 357 If you're looking to serialize a more complex object, for instance a class that handles 358 preprocessing, complex prediction logic, or custom serialization, you should subclass the 359 ``PythonModel`` class. MLflow has tutorials on building custom PyFunc models, as shown 360 `here <https://mlflow.org/docs/latest/traditional-ml/creating-custom-pyfunc/index.html>`_, 361 so instead of duplicating that information, in this example we'll recreate the above functionality 362 to highlight the differences. Note that this PythonModel implementation is overly complex and 363 we would recommend using the functional-based Model instead for this simple case. 364 365 .. code-block:: python 366 367 import mlflow 368 import pandas as pd 369 370 371 class MyModel(mlflow.pyfunc.PythonModel): 372 def predict(self, context, model_input, params=None): 373 return [x * 2 for x in model_input] 374 375 376 # Save the function as a model 377 with mlflow.start_run(): 378 mlflow.pyfunc.log_model(name="model", python_model=MyModel(), pip_requirements=["pandas"]) 379 run_id = mlflow.active_run().info.run_id 380 381 # Load the model from the tracking server and perform inference 382 model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model") 383 x_new = pd.Series([1, 2, 3]) 384 385 print(f"Prediction:\n\t{model.predict(x_new)}") 386 387 The primary difference between the this implementation and the function-based implementation above 388 is that the predict method is wrapped with a class, has the ``self`` parameter, 389 and has the ``params`` parameter that defaults to None. Note that function-based models don't 390 support additional params. 391 392 In summary, use the function-based Model when you have a simple function to serialize. 393 If you need more power, use the class-based model. 394 """ 395 396 import collections 397 import functools 398 import hashlib 399 import importlib 400 import inspect 401 import json 402 import logging 403 import os 404 import shutil 405 import signal 406 import subprocess 407 import sys 408 import tempfile 409 import threading 410 import uuid 411 from copy import deepcopy 412 from pathlib import Path 413 from typing import Any, Iterator, Tuple, Union 414 from urllib.parse import urlparse 415 416 import numpy as np 417 import pandas 418 import pydantic 419 import yaml 420 from packaging.version import Version 421 422 import mlflow 423 import mlflow.models.signature 424 import mlflow.pyfunc.loaders 425 import mlflow.pyfunc.model 426 from mlflow.entities.model_registry.prompt import Prompt 427 from mlflow.environment_variables import ( 428 _MLFLOW_IN_CAPTURE_MODULE_PROCESS, 429 _MLFLOW_SPARK_UDF_SERVERLESS_SKIP_DBCONNECT_ARTIFACT, 430 _MLFLOW_TESTING, 431 MLFLOW_DISABLE_SCHEMA_DETAILS, 432 MLFLOW_ENFORCE_STDIN_SCORING_SERVER_FOR_SPARK_UDF, 433 MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR, 434 MLFLOW_SCORING_SERVER_REQUEST_TIMEOUT, 435 MLFLOW_UV_AUTO_DETECT, 436 ) 437 from mlflow.exceptions import MlflowException 438 from mlflow.models import Model, ModelInputExample, ModelSignature 439 from mlflow.models.auth_policy import AuthPolicy 440 from mlflow.models.dependencies_schemas import ( 441 _clear_dependencies_schemas, 442 _get_dependencies_schema_from_model, 443 _get_dependencies_schemas, 444 ) 445 from mlflow.models.flavor_backend_registry import get_flavor_backend 446 from mlflow.models.model import ( 447 _DATABRICKS_FS_LOADER_MODULE, 448 MLMODEL_FILE_NAME, 449 MODEL_CODE_PATH, 450 MODEL_CONFIG, 451 ) 452 from mlflow.models.resources import Resource, _ResourceBuilder 453 from mlflow.models.signature import ( 454 _extract_type_hints, 455 _infer_signature_from_input_example, 456 _infer_signature_from_type_hints, 457 ) 458 from mlflow.models.utils import ( 459 PyFuncInput, 460 PyFuncLLMOutputChunk, 461 PyFuncLLMSingleInput, 462 PyFuncOutput, 463 _convert_llm_input_data, 464 _enforce_params_schema, 465 _enforce_schema, 466 _load_model_code_path, 467 _save_example, 468 _split_input_data_and_params, 469 _validate_and_get_model_code_path, 470 ) 471 from mlflow.protos.databricks_pb2 import ( 472 BAD_REQUEST, 473 INTERNAL_ERROR, 474 INVALID_PARAMETER_VALUE, 475 RESOURCE_DOES_NOT_EXIST, 476 ) 477 from mlflow.protos.databricks_uc_registry_messages_pb2 import ( 478 Entity, 479 Job, 480 LineageHeaderInfo, 481 Notebook, 482 ) 483 from mlflow.pyfunc.context import Context, set_prediction_context 484 from mlflow.pyfunc.dbconnect_artifact_cache import ( 485 DBConnectArtifactCache, 486 archive_directory, 487 extract_archive_to_dir, 488 ) 489 from mlflow.pyfunc.model import ( 490 _DEFAULT_CHAT_AGENT_METADATA_TASK, 491 _DEFAULT_CHAT_MODEL_METADATA_TASK, 492 _DEFAULT_RESPONSES_AGENT_METADATA_TASK, 493 ChatAgent, 494 ChatModel, 495 PythonModel, 496 PythonModelContext, 497 _FunctionPythonModel, 498 _log_warning_if_params_not_in_predict_signature, 499 _PythonModelPyfuncWrapper, 500 get_default_conda_env, # noqa: F401 501 get_default_pip_requirements, 502 ) 503 504 try: 505 from mlflow.pyfunc.model import ResponsesAgent 506 507 IS_RESPONSES_AGENT_AVAILABLE = True 508 except ImportError: 509 IS_RESPONSES_AGENT_AVAILABLE = False 510 from mlflow.tracing.provider import trace_disabled 511 from mlflow.tracing.utils import _try_get_prediction_context 512 from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS 513 from mlflow.tracking.artifact_utils import _download_artifact_from_uri 514 from mlflow.types.agent import ( 515 CHAT_AGENT_INPUT_EXAMPLE, 516 CHAT_AGENT_INPUT_SCHEMA, 517 CHAT_AGENT_OUTPUT_SCHEMA, 518 ChatAgentRequest, 519 ChatAgentResponse, 520 ) 521 from mlflow.types.llm import ( 522 CHAT_MODEL_INPUT_EXAMPLE, 523 CHAT_MODEL_INPUT_SCHEMA, 524 CHAT_MODEL_OUTPUT_SCHEMA, 525 ChatCompletionResponse, 526 ChatMessage, 527 ChatParams, 528 ) 529 from mlflow.types.type_hints import ( 530 _convert_dataframe_to_example_format, 531 _is_example_valid_for_type_from_example, 532 _is_type_hint_from_example, 533 _signature_cannot_be_inferred_from_type_hint, 534 model_validate, 535 ) 536 from mlflow.utils import ( 537 PYTHON_VERSION, 538 _is_in_ipython_notebook, 539 check_port_connectivity, 540 databricks_utils, 541 find_free_port, 542 get_major_minor_py_version, 543 ) 544 from mlflow.utils import env_manager as _EnvManager 545 from mlflow.utils._spark_utils import modified_environ 546 from mlflow.utils.annotations import deprecated, developer_stable 547 from mlflow.utils.databricks_utils import ( 548 _get_databricks_serverless_env_vars, 549 get_dbconnect_udf_sandbox_info, 550 is_databricks_connect, 551 is_in_databricks_runtime, 552 is_in_databricks_serverless_runtime, 553 is_in_databricks_shared_cluster_runtime, 554 ) 555 from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring 556 from mlflow.utils.environment import ( 557 _CONDA_ENV_FILE_NAME, 558 _CONSTRAINTS_FILE_NAME, 559 _PYTHON_ENV_FILE_NAME, 560 _REQUIREMENTS_FILE_NAME, 561 _process_conda_env, 562 _process_pip_requirements, 563 _PythonEnv, 564 _validate_env_arguments, 565 ) 566 from mlflow.utils.file_utils import ( 567 _copy_file_or_tree, 568 get_or_create_nfs_tmp_dir, 569 get_or_create_tmp_dir, 570 get_total_file_size, 571 write_to, 572 ) 573 from mlflow.utils.mlflow_tags import MLFLOW_MODEL_IS_EXTERNAL 574 from mlflow.utils.model_utils import ( 575 _add_code_from_conf_to_system_path, 576 _get_flavor_configuration, 577 _get_flavor_configuration_from_ml_model_file, 578 _get_overridden_pyfunc_model_config, 579 _validate_and_copy_file_to_directory, 580 _validate_and_get_model_config_from_file, 581 _validate_and_prepare_target_save_path, 582 _validate_infer_and_copy_code_paths, 583 _validate_pyfunc_model_config, 584 ) 585 from mlflow.utils.nfs_on_spark import get_nfs_cache_root_dir 586 from mlflow.utils.requirements_utils import ( 587 _parse_requirements, 588 warn_dependency_requirement_mismatches, 589 ) 590 from mlflow.utils.spark_utils import is_spark_connect_mode 591 from mlflow.utils.uv_utils import copy_uv_project_files 592 from mlflow.utils.virtualenv import _get_python_env, _get_virtualenv_name 593 from mlflow.utils.warnings_utils import color_warning 594 595 try: 596 from pyspark.sql import DataFrame as SparkDataFrame 597 598 HAS_PYSPARK = True 599 except ImportError: 600 HAS_PYSPARK = False 601 FLAVOR_NAME = "python_function" 602 MAIN = "loader_module" 603 CODE = "code" 604 DATA = "data" 605 ENV = "env" 606 TASK = "task" 607 608 _MODEL_DATA_SUBPATH = "data" 609 _CHAT_PARAMS_WARNING_MESSAGE = ( 610 "Default values for temperature, n and stream in ChatParams will be removed in the " 611 "next release. Specify them in the input example explicitly if needed." 612 ) 613 _TYPE_FROM_EXAMPLE_ERROR_MESSAGE = ( 614 "Input example must be provided when using TypeFromExample as type hint. " 615 "Fix this by passing `input_example` when logging your model. Check " 616 "https://mlflow.org/docs/latest/model/python_model.html#typefromexample-type-hint-usage " 617 "for more details." 618 ) 619 620 621 class EnvType: 622 CONDA = "conda" 623 VIRTUALENV = "virtualenv" 624 625 def __init__(self): 626 raise NotImplementedError("This class is not meant to be instantiated.") 627 628 629 PY_VERSION = "python_version" 630 631 _logger = logging.getLogger(__name__) 632 633 634 def add_to_model( 635 model, 636 loader_module, 637 data=None, 638 code=None, 639 conda_env=None, 640 python_env=None, 641 model_config=None, 642 model_code_path=None, 643 **kwargs, 644 ): 645 """ 646 Add a ``pyfunc`` spec to the model configuration. 647 648 Defines ``pyfunc`` configuration schema. Caller can use this to create a valid ``pyfunc`` model 649 flavor out of an existing directory structure. For example, other model flavors can use this to 650 specify how to use their output as a ``pyfunc``. 651 652 NOTE: 653 654 All paths are relative to the exported model root directory. 655 656 Args: 657 model: Existing model. 658 loader_module: The module to be used to load the model. 659 data: Path to the model data. 660 code: Path to the code dependencies. 661 conda_env: Conda environment. 662 python_env: Python environment. 663 model_config: The model configuration to apply to the model. This configuration 664 is available during model loading. 665 666 .. Note:: Experimental: This parameter may change or be removed in a future 667 release without warning. 668 669 model_code_path: Path to the model code. 670 kwargs: Additional key-value pairs to include in the ``pyfunc`` flavor specification. 671 Values must be YAML-serializable. 672 673 Returns: 674 Updated model configuration. 675 """ 676 params = deepcopy(kwargs) 677 params[MAIN] = loader_module 678 params[PY_VERSION] = PYTHON_VERSION 679 if code: 680 params[CODE] = code 681 if data: 682 params[DATA] = data 683 if conda_env or python_env: 684 params[ENV] = {} 685 if conda_env: 686 params[ENV][EnvType.CONDA] = conda_env 687 if python_env: 688 params[ENV][EnvType.VIRTUALENV] = python_env 689 if model_config: 690 params[MODEL_CONFIG] = model_config 691 if model_code_path: 692 params[MODEL_CODE_PATH] = model_code_path 693 return model.add_flavor(FLAVOR_NAME, **params) 694 695 696 def _extract_conda_env(env): 697 # In MLflow < 2.0.0, the 'env' field in a pyfunc configuration is a string containing the path 698 # to a conda.yaml file. 699 return env if isinstance(env, str) else env[EnvType.CONDA] 700 701 702 def _load_model_env(path): 703 """ 704 Get ENV file string from a model configuration stored in Python Function format. 705 Returned value is a model-relative path to a Conda Environment file, 706 or None if none was specified at model save time 707 """ 708 return _get_flavor_configuration(model_path=path, flavor_name=FLAVOR_NAME).get(ENV, None) 709 710 711 def _validate_params(params, model_metadata): 712 if hasattr(model_metadata, "get_params_schema"): 713 params_schema = model_metadata.get_params_schema() 714 return _enforce_params_schema(params, params_schema) 715 if params: 716 raise MlflowException.invalid_parameter_value( 717 "This model was not logged with a params schema and does not support " 718 "providing the params argument." 719 "Please log the model with mlflow >= 2.6.0 and specify a params schema.", 720 ) 721 return 722 723 724 def _validate_prediction_input(data: PyFuncInput, params, input_schema, params_schema, flavor=None): 725 """ 726 Internal helper function to transform and validate input data and params for prediction. 727 Any additional transformation logics related to input data and params should be added here. 728 """ 729 if input_schema is not None: 730 try: 731 data = _enforce_schema(data, input_schema, flavor) 732 except Exception as e: 733 if MLFLOW_DISABLE_SCHEMA_DETAILS.get(): 734 message = "Failed to enforce model input schema. Please check your input data." 735 else: 736 # Include error in message for backwards compatibility 737 message = ( 738 f"Failed to enforce schema of data '{data}' " 739 f"with schema '{input_schema}'. " 740 f"Error: {e}" 741 ) 742 # error_code is INVALID_PARAMETER_VALUE but this is a schema enforcement failure 743 raise MlflowException.invalid_parameter_value( 744 message, error_class="SCHEMA_ENFORCEMENT_FAILED" 745 ) 746 params = _enforce_params_schema(params, params_schema) 747 if HAS_PYSPARK and isinstance(data, SparkDataFrame): 748 _logger.warning( 749 "Input data is a Spark DataFrame. Note that behaviour for " 750 "Spark DataFrames is model dependent." 751 ) 752 return data, params 753 754 755 class PyFuncModel: 756 """ 757 MLflow 'python function' model. 758 759 Wrapper around model implementation and metadata. This class is not meant to be constructed 760 directly. Instead, instances of this class are constructed and returned from 761 :py:func:`load_model() <mlflow.pyfunc.load_model>`. 762 763 ``model_impl`` can be any Python object that implements the `Pyfunc interface 764 <https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#pyfunc-inference-api>`_, and is 765 returned by invoking the model's ``loader_module``. 766 767 ``model_meta`` contains model metadata loaded from the MLmodel file. 768 """ 769 770 def __init__( 771 self, 772 model_meta: Model, 773 model_impl: Any, 774 predict_fn: str = "predict", 775 predict_stream_fn: str | None = None, 776 model_id: str | None = None, 777 ): 778 if not hasattr(model_impl, predict_fn): 779 raise MlflowException(f"Model implementation is missing required {predict_fn} method.") 780 if not model_meta: 781 raise MlflowException("Model is missing metadata.") 782 self._model_meta = model_meta 783 self.__model_impl = model_impl 784 self._predict_fn = getattr(model_impl, predict_fn) 785 if predict_stream_fn: 786 if not hasattr(model_impl, predict_stream_fn): 787 raise MlflowException( 788 f"Model implementation is missing required {predict_stream_fn} method." 789 ) 790 self._predict_stream_fn = getattr(model_impl, predict_stream_fn) 791 else: 792 self._predict_stream_fn = None 793 self._model_id = model_id 794 self._input_example = None 795 796 @property 797 @developer_stable 798 def _model_impl(self) -> Any: 799 """ 800 The underlying model implementation object. 801 802 NOTE: This is a stable developer API. 803 """ 804 return self.__model_impl 805 806 @property 807 def model_id(self) -> str | None: 808 """ 809 The model ID of the model. 810 811 Returns: 812 The model ID of the model. 813 """ 814 return self._model_id 815 816 def _update_dependencies_schemas_in_prediction_context(self, context: Context): 817 if self._model_meta and self._model_meta.metadata: 818 dependencies_schemas = self._model_meta.metadata.get("dependencies_schemas", {}) 819 context.update( 820 dependencies_schemas={ 821 dependency: json.dumps(schema) 822 for dependency, schema in dependencies_schemas.items() 823 } 824 ) 825 826 @property 827 def input_example(self) -> Any | None: 828 """ 829 The input example provided when the model was saved. 830 """ 831 return self._input_example 832 833 @input_example.setter 834 def input_example(self, value: Any) -> None: 835 self._input_example = value 836 837 def predict(self, data: PyFuncInput, params: dict[str, Any] | None = None) -> PyFuncOutput: 838 context = _try_get_prediction_context() or Context() 839 with set_prediction_context(context): 840 if schema := _get_dependencies_schema_from_model(self._model_meta): 841 context.update(**schema) 842 843 if self.model_id: 844 context.update(model_id=self.model_id) 845 return self._predict(data, params) 846 847 def _predict(self, data: PyFuncInput, params: dict[str, Any] | None = None) -> PyFuncOutput: 848 """ 849 Generates model predictions. 850 851 If the model contains signature, enforce the input schema first before calling the model 852 implementation with the sanitized input. If the pyfunc model does not include model schema, 853 the input is passed to the model implementation as is. See `Model Signature Enforcement 854 <https://www.mlflow.org/docs/latest/models.html#signature-enforcement>`_ for more details. 855 856 Args: 857 data: LLM Model single input as one of pandas.DataFrame, numpy.ndarray, 858 scipy.sparse.(csc_matrix | csr_matrix), List[Any], or 859 Dict[str, numpy.ndarray]. 860 For model signatures with tensor spec inputs 861 (e.g. the Tensorflow core / Keras model), the input data type must be one of 862 `numpy.ndarray`, `List[numpy.ndarray]`, `Dict[str, numpy.ndarray]` or 863 `pandas.DataFrame`. If data is of `pandas.DataFrame` type and the model 864 contains a signature with tensor spec inputs, the corresponding column values 865 in the pandas DataFrame will be reshaped to the required shape with 'C' order 866 (i.e. read / write the elements using C-like index order), and DataFrame 867 column values will be cast as the required tensor spec type. For Pyspark 868 DataFrame inputs, MLflow will only enforce the schema on a subset 869 of the data rows. 870 params: Additional parameters to pass to the model for inference. 871 872 Returns: 873 Model predictions as one of pandas.DataFrame, pandas.Series, numpy.ndarray or list. 874 """ 875 # fetch the schema from metadata to avoid signature change after model is loaded 876 self.input_schema = self.metadata.get_input_schema() 877 self.params_schema = self.metadata.get_params_schema() 878 # signature can only be inferred from type hints if the model is PythonModel 879 if self.metadata._is_signature_from_type_hint(): 880 # we don't need to validate on data as data validation 881 # will be done during PythonModel's predict call 882 params = _enforce_params_schema(params, self.params_schema) 883 else: 884 data, params = _validate_prediction_input( 885 data, params, self.input_schema, self.params_schema, self.loader_module 886 ) 887 if ( 888 isinstance(data, pandas.DataFrame) 889 and self.metadata._is_type_hint_from_example() 890 and self.input_example is not None 891 ): 892 data = _convert_dataframe_to_example_format(data, self.input_example) 893 params_arg = inspect.signature(self._predict_fn).parameters.get("params") 894 if params_arg and params_arg.kind != inspect.Parameter.VAR_KEYWORD: 895 return self._predict_fn(data, params=params) 896 897 _log_warning_if_params_not_in_predict_signature(_logger, params) 898 return self._predict_fn(data) 899 900 def predict_stream( 901 self, data: PyFuncLLMSingleInput, params: dict[str, Any] | None = None 902 ) -> Iterator[PyFuncLLMOutputChunk]: 903 context = _try_get_prediction_context() or Context() 904 905 if schema := _get_dependencies_schema_from_model(self._model_meta): 906 context.update(**schema) 907 908 if self.model_id: 909 context.update(model_id=self.model_id) 910 911 # NB: The prediction context must be applied during iterating over the stream, 912 # hence, simply wrapping the self._predict_stream call with the context manager 913 # is not sufficient. 914 def _gen_with_context(*args, **kwargs): 915 with set_prediction_context(context): 916 yield from self._predict_stream(*args, **kwargs) 917 918 return _gen_with_context(data, params) 919 920 def _predict_stream( 921 self, data: PyFuncLLMSingleInput, params: dict[str, Any] | None = None 922 ) -> Iterator[PyFuncLLMOutputChunk]: 923 """ 924 Generates streaming model predictions. Only LLM supports this method. 925 926 If the model contains signature, enforce the input schema first before calling the model 927 implementation with the sanitized input. If the pyfunc model does not include model schema, 928 the input is passed to the model implementation as is. See `Model Signature Enforcement 929 <https://www.mlflow.org/docs/latest/models.html#signature-enforcement>`_ for more details. 930 931 Args: 932 data: LLM Model single input as one of dict, str, bool, bytes, float, int, str type. 933 params: Additional parameters to pass to the model for inference. 934 935 Returns: 936 Model predictions as an iterator of chunks. The chunks in the iterator must be type of 937 dict or string. Chunk dict fields are determined by the model implementation. 938 """ 939 940 if self._predict_stream_fn is None: 941 raise MlflowException("This model does not support predict_stream method.") 942 943 self.input_schema = self.metadata.get_input_schema() 944 self.params_schema = self.metadata.get_params_schema() 945 data, params = _validate_prediction_input( 946 data, params, self.input_schema, self.params_schema, self.loader_module 947 ) 948 data = _convert_llm_input_data(data) 949 if isinstance(data, list): 950 # `predict_stream` only accepts single input. 951 # but `enforce_schema` might convert single input into a list like `[single_input]` 952 # so extract the first element in the list. 953 if len(data) != 1: 954 raise MlflowException( 955 f"'predict_stream' requires single input, but it got input data {data}" 956 ) 957 data = data[0] 958 959 if "params" in inspect.signature(self._predict_stream_fn).parameters: 960 return self._predict_stream_fn(data, params=params) 961 962 _log_warning_if_params_not_in_predict_signature(_logger, params) 963 return self._predict_stream_fn(data) 964 965 def unwrap_python_model(self): 966 """ 967 Unwrap the underlying Python model object. 968 969 This method is useful for accessing custom model functions, while still being able to 970 leverage the MLflow designed workflow through the `predict()` method. 971 972 Returns: 973 The underlying wrapped model object 974 975 .. code-block:: python 976 :test: 977 :caption: Example 978 979 import mlflow 980 981 982 # define a custom model 983 class MyModel(mlflow.pyfunc.PythonModel): 984 def predict(self, context, model_input, params=None): 985 return self.my_custom_function(model_input, params) 986 987 def my_custom_function(self, model_input, params=None): 988 # do something with the model input 989 return 0 990 991 992 some_input = 1 993 # save the model 994 with mlflow.start_run(): 995 model_info = mlflow.pyfunc.log_model(name="model", python_model=MyModel()) 996 997 # load the model 998 loaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri) 999 print(type(loaded_model)) # <class 'mlflow.pyfunc.model.PyFuncModel'> 1000 unwrapped_model = loaded_model.unwrap_python_model() 1001 print(type(unwrapped_model)) # <class '__main__.MyModel'> 1002 1003 # does not work, only predict() is exposed 1004 # print(loaded_model.my_custom_function(some_input)) 1005 print(unwrapped_model.my_custom_function(some_input)) # works 1006 print(loaded_model.predict(some_input)) # works 1007 1008 # works, but None is needed for context arg 1009 print(unwrapped_model.predict(None, some_input)) 1010 """ 1011 try: 1012 python_model = self._model_impl.python_model 1013 if python_model is None: 1014 raise AttributeError("Expected python_model attribute not to be None.") 1015 except AttributeError as e: 1016 raise MlflowException("Unable to retrieve base model object from pyfunc.") from e 1017 return python_model 1018 1019 def __eq__(self, other): 1020 if not isinstance(other, PyFuncModel): 1021 return False 1022 return self._model_meta == other._model_meta 1023 1024 @property 1025 def metadata(self) -> Model: 1026 """Model metadata.""" 1027 if self._model_meta is None: 1028 raise MlflowException("Model is missing metadata.") 1029 return self._model_meta 1030 1031 @property 1032 def model_config(self): 1033 """Model's flavor configuration""" 1034 return self._model_meta.flavors[FLAVOR_NAME].get(MODEL_CONFIG, {}) 1035 1036 @property 1037 def loader_module(self): 1038 """Model's flavor configuration""" 1039 if self._model_meta.flavors.get(FLAVOR_NAME) is None: 1040 return None 1041 return self._model_meta.flavors[FLAVOR_NAME].get(MAIN) 1042 1043 def __repr__(self): 1044 info = {} 1045 if self._model_meta is not None: 1046 if hasattr(self._model_meta, "run_id") and self._model_meta.run_id is not None: 1047 info["run_id"] = self._model_meta.run_id 1048 if ( 1049 hasattr(self._model_meta, "artifact_path") 1050 and self._model_meta.artifact_path is not None 1051 ): 1052 info["artifact_path"] = self._model_meta.artifact_path 1053 info["flavor"] = self._model_meta.flavors[FLAVOR_NAME]["loader_module"] 1054 return yaml.safe_dump({"mlflow.pyfunc.loaded_model": info}, default_flow_style=False) 1055 1056 def get_raw_model(self): 1057 """ 1058 Get the underlying raw model if the model wrapper implemented `get_raw_model` function. 1059 """ 1060 if hasattr(self._model_impl, "get_raw_model"): 1061 return self._model_impl.get_raw_model() 1062 raise NotImplementedError("`get_raw_model` is not implemented by the underlying model") 1063 1064 1065 def _get_pip_requirements_from_model_path(model_path: str): 1066 req_file_path = os.path.join(model_path, _REQUIREMENTS_FILE_NAME) 1067 if not os.path.exists(req_file_path): 1068 return [] 1069 1070 return [req.req_str for req in _parse_requirements(req_file_path, is_constraint=False)] 1071 1072 1073 @trace_disabled # Suppress traces while loading model 1074 def load_model( 1075 model_uri: str, 1076 suppress_warnings: bool = False, 1077 dst_path: str | None = None, 1078 model_config: str | Path | dict[str, Any] | None = None, 1079 ) -> PyFuncModel: 1080 """ 1081 Load a model stored in Python function format. 1082 1083 Args: 1084 model_uri: The location, in URI format, of the MLflow model. For example: 1085 1086 - ``/Users/me/path/to/local/model`` 1087 - ``relative/path/to/local/model`` 1088 - ``s3://my_bucket/path/to/model`` 1089 - ``runs:/<mlflow_run_id>/run-relative/path/to/model`` 1090 - ``models:/<model_name>/<model_version>`` 1091 - ``models:/<model_name>/<stage>`` 1092 - ``mlflow-artifacts:/path/to/model`` 1093 1094 For more information about supported URI schemes, see 1095 `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html# 1096 artifact-locations>`_. 1097 suppress_warnings: If ``True``, non-fatal warning messages associated with the model 1098 loading process will be suppressed. If ``False``, these warning messages will be 1099 emitted. 1100 dst_path: The local filesystem path to which to download the model artifact. 1101 This directory must already exist. If unspecified, a local output 1102 path will be created. 1103 model_config: The model configuration to apply to the model. The configuration will 1104 be available as the ``model_config`` property of the ``context`` parameter 1105 in :func:`PythonModel.load_context() <mlflow.pyfunc.PythonModel.load_context>` 1106 and :func:`PythonModel.predict() <mlflow.pyfunc.PythonModel.predict>`. 1107 The configuration can be passed as a file path, or a dict with string keys. 1108 1109 .. Note:: Experimental: This parameter may change or be removed in a future 1110 release without warning. 1111 """ 1112 1113 lineage_header_info = None 1114 if ( 1115 not _MLFLOW_IN_CAPTURE_MODULE_PROCESS.get() 1116 ) and databricks_utils.is_in_databricks_runtime(): 1117 entity_list = [] 1118 # Get notebook id and job id, pack them into lineage_header_info 1119 if databricks_utils.is_in_databricks_notebook() and ( 1120 notebook_id := databricks_utils.get_notebook_id() 1121 ): 1122 notebook_entity = Notebook(id=notebook_id) 1123 entity_list.append(Entity(notebook=notebook_entity)) 1124 1125 if databricks_utils.is_in_databricks_job() and (job_id := databricks_utils.get_job_id()): 1126 job_entity = Job(id=job_id) 1127 entity_list.append(Entity(job=job_entity)) 1128 1129 lineage_header_info = LineageHeaderInfo(entities=entity_list) if entity_list else None 1130 1131 local_path = _download_artifact_from_uri( 1132 artifact_uri=model_uri, output_path=dst_path, lineage_header_info=lineage_header_info 1133 ) 1134 1135 if not suppress_warnings: 1136 model_requirements = _get_pip_requirements_from_model_path(local_path) 1137 warn_dependency_requirement_mismatches(model_requirements) 1138 1139 model_meta = Model.load(os.path.join(local_path, MLMODEL_FILE_NAME)) 1140 1141 if model_meta.metadata and model_meta.metadata.get(MLFLOW_MODEL_IS_EXTERNAL, False) is True: 1142 raise MlflowException( 1143 "This model's artifacts are external and are not stored in the model directory." 1144 " This model cannot be loaded with MLflow.", 1145 BAD_REQUEST, 1146 ) 1147 1148 conf = model_meta.flavors.get(FLAVOR_NAME) 1149 if conf is None: 1150 raise MlflowException( 1151 f'Model does not have the "{FLAVOR_NAME}" flavor', 1152 RESOURCE_DOES_NOT_EXIST, 1153 ) 1154 model_py_version = conf.get(PY_VERSION) 1155 if not suppress_warnings: 1156 _warn_potentially_incompatible_py_version_if_necessary(model_py_version=model_py_version) 1157 1158 _add_code_from_conf_to_system_path(local_path, conf, code_key=CODE) 1159 data_path = os.path.join(local_path, conf[DATA]) if (DATA in conf) else local_path 1160 1161 if isinstance(model_config, str): 1162 model_config = _validate_and_get_model_config_from_file(model_config) 1163 1164 model_config = _get_overridden_pyfunc_model_config( 1165 conf.get(MODEL_CONFIG, None), model_config, _logger 1166 ) 1167 1168 try: 1169 if model_config: 1170 model_impl = importlib.import_module(conf[MAIN])._load_pyfunc(data_path, model_config) 1171 else: 1172 model_impl = importlib.import_module(conf[MAIN])._load_pyfunc(data_path) 1173 except ModuleNotFoundError as e: 1174 # This error message is particularly for the case when the error is caused by module 1175 # "databricks.feature_store.mlflow_model". But depending on the environment, the offending 1176 # module might be "databricks", "databricks.feature_store" or full package. So we will 1177 # raise the error with the following note if "databricks" presents in the error. All non- 1178 # databricks module errors will just be re-raised. 1179 if conf[MAIN] == _DATABRICKS_FS_LOADER_MODULE and e.name.startswith("databricks"): 1180 raise MlflowException( 1181 f"{e.msg}; " 1182 "Note: mlflow.pyfunc.load_model is not supported for Feature Store models. " 1183 "spark_udf() and predict() will not work as expected. Use " 1184 "score_batch for offline predictions.", 1185 BAD_REQUEST, 1186 ) from None 1187 raise e 1188 finally: 1189 # clean up the dependencies schema which is set to global state after loading the model. 1190 # This avoids the schema being used by other models loaded in the same process. 1191 _clear_dependencies_schemas() 1192 predict_fn = conf.get("predict_fn", "predict") 1193 streamable = conf.get("streamable", False) 1194 predict_stream_fn = conf.get("predict_stream_fn", "predict_stream") if streamable else None 1195 1196 pyfunc_model = PyFuncModel( 1197 model_meta=model_meta, 1198 model_impl=model_impl, 1199 predict_fn=predict_fn, 1200 predict_stream_fn=predict_stream_fn, 1201 model_id=model_meta.model_id, 1202 ) 1203 1204 try: 1205 model_input_example = model_meta.load_input_example(path=local_path) 1206 pyfunc_model.input_example = model_input_example 1207 except Exception as e: 1208 _logger.debug(f"Failed to load input example from model metadata: {e}.") 1209 1210 return pyfunc_model 1211 1212 1213 class _ServedPyFuncModel(PyFuncModel): 1214 def __init__(self, model_meta: Model, client: Any, server_pid: int, env_manager="local"): 1215 super().__init__(model_meta=model_meta, model_impl=client, predict_fn="invoke") 1216 self._client = client 1217 self._server_pid = server_pid 1218 # We need to set `env_manager` attribute because it is used by Databricks runtime 1219 # evaluate usage logging to log 'env_manager' tag in `_evaluate` function patching. 1220 self._env_manager = env_manager 1221 1222 def predict(self, data, params=None): 1223 """ 1224 Args: 1225 data: Model input data. 1226 params: Additional parameters to pass to the model for inference. 1227 1228 Returns: 1229 Model predictions. 1230 """ 1231 if "params" in inspect.signature(self._client.invoke).parameters: 1232 result = self._client.invoke(data, params=params).get_predictions() 1233 else: 1234 _log_warning_if_params_not_in_predict_signature(_logger, params) 1235 result = self._client.invoke(data).get_predictions() 1236 if isinstance(result, pandas.DataFrame): 1237 result = result[result.columns[0]] 1238 return result 1239 1240 @property 1241 def pid(self): 1242 if self._server_pid is None: 1243 raise MlflowException("Served PyFunc Model is missing server process ID.") 1244 return self._server_pid 1245 1246 @property 1247 def env_manager(self): 1248 return self._env_manager 1249 1250 @env_manager.setter 1251 def env_manager(self, value): 1252 self._env_manager = value 1253 1254 1255 def _load_model_or_server( 1256 model_uri: str, env_manager: str, model_config: dict[str, Any] | None = None 1257 ): 1258 """ 1259 Load a model with env restoration. If a non-local ``env_manager`` is specified, prepare an 1260 independent Python environment with the training time dependencies of the specified model 1261 installed and start a MLflow Model Scoring Server process with that model in that environment. 1262 Return a _ServedPyFuncModel that invokes the scoring server for prediction. Otherwise, load and 1263 return the model locally as a PyFuncModel using :py:func:`mlflow.pyfunc.load_model`. 1264 1265 Args: 1266 model_uri: The uri of the model. 1267 env_manager: The environment manager to load the model. 1268 model_config: The model configuration to use by the model, only if the model 1269 accepts it. 1270 1271 Returns: 1272 A _ServedPyFuncModel for non-local ``env_manager``s or a PyFuncModel otherwise. 1273 """ 1274 from mlflow.pyfunc.scoring_server.client import ( 1275 ScoringServerClient, 1276 StdinScoringServerClient, 1277 ) 1278 1279 if env_manager == _EnvManager.LOCAL: 1280 return load_model(model_uri, model_config=model_config) 1281 1282 _logger.info("Starting model server for model environment restoration.") 1283 1284 local_path = _download_artifact_from_uri(artifact_uri=model_uri) 1285 model_meta = Model.load(os.path.join(local_path, MLMODEL_FILE_NAME)) 1286 1287 is_port_connectable = check_port_connectivity() 1288 pyfunc_backend = get_flavor_backend( 1289 local_path, 1290 env_manager=env_manager, 1291 install_mlflow=os.environ.get("MLFLOW_HOME") is not None, 1292 create_env_root_dir=not is_port_connectable, 1293 ) 1294 _logger.info("Restoring model environment. This can take a few minutes.") 1295 # Set capture_output to True in Databricks so that when environment preparation fails, the 1296 # exception message of the notebook cell output will include child process command execution 1297 # stdout/stderr output. 1298 pyfunc_backend.prepare_env(model_uri=local_path, capture_output=is_in_databricks_runtime()) 1299 if is_port_connectable: 1300 server_port = find_free_port() 1301 scoring_server_proc = pyfunc_backend.serve( 1302 model_uri=local_path, 1303 port=server_port, 1304 host="127.0.0.1", 1305 timeout=MLFLOW_SCORING_SERVER_REQUEST_TIMEOUT.get(), 1306 enable_mlserver=False, 1307 synchronous=False, 1308 stdout=subprocess.PIPE, 1309 stderr=subprocess.STDOUT, 1310 model_config=model_config, 1311 ) 1312 client = ScoringServerClient("127.0.0.1", server_port) 1313 else: 1314 scoring_server_proc = pyfunc_backend.serve_stdin(local_path, model_config=model_config) 1315 client = StdinScoringServerClient(scoring_server_proc) 1316 1317 _logger.info(f"Scoring server process started at PID: {scoring_server_proc.pid}") 1318 try: 1319 client.wait_server_ready(timeout=90, scoring_server_proc=scoring_server_proc) 1320 except Exception as e: 1321 if scoring_server_proc.poll() is None: 1322 # the scoring server is still running but client can't connect to it. 1323 # kill the server. 1324 scoring_server_proc.kill() 1325 server_output, _ = scoring_server_proc.communicate(timeout=15) 1326 if isinstance(server_output, bytes): 1327 server_output = server_output.decode("UTF-8") 1328 raise MlflowException( 1329 "MLflow model server failed to launch, server process stdout and stderr are:\n" 1330 + server_output 1331 ) from e 1332 1333 return _ServedPyFuncModel( 1334 model_meta=model_meta, 1335 client=client, 1336 server_pid=scoring_server_proc.pid, 1337 env_manager=env_manager, 1338 ) 1339 1340 1341 def _get_model_dependencies(model_uri, format="pip"): 1342 model_dir = _download_artifact_from_uri(model_uri) 1343 1344 def get_conda_yaml_path(): 1345 model_config = _get_flavor_configuration_from_ml_model_file( 1346 os.path.join(model_dir, MLMODEL_FILE_NAME), flavor_name=FLAVOR_NAME 1347 ) 1348 return os.path.join(model_dir, _extract_conda_env(model_config[ENV])) 1349 1350 if format == "pip": 1351 requirements_file = os.path.join(model_dir, _REQUIREMENTS_FILE_NAME) 1352 if os.path.exists(requirements_file): 1353 return requirements_file 1354 1355 _logger.info( 1356 f"{_REQUIREMENTS_FILE_NAME} is not found in the model directory. Falling back to" 1357 f" extracting pip requirements from the model's 'conda.yaml' file. Conda" 1358 " dependencies will be ignored." 1359 ) 1360 1361 with open(get_conda_yaml_path()) as yf: 1362 conda_yaml = yaml.safe_load(yf) 1363 1364 conda_deps = conda_yaml.get("dependencies", []) 1365 for index, dep in enumerate(conda_deps): 1366 if isinstance(dep, dict) and "pip" in dep: 1367 pip_deps_index = index 1368 break 1369 else: 1370 raise MlflowException( 1371 "No pip section found in conda.yaml file in the model directory.", 1372 error_code=RESOURCE_DOES_NOT_EXIST, 1373 ) 1374 1375 pip_deps = conda_deps.pop(pip_deps_index)["pip"] 1376 tmp_dir = tempfile.mkdtemp() 1377 pip_file_path = os.path.join(tmp_dir, _REQUIREMENTS_FILE_NAME) 1378 with open(pip_file_path, "w") as f: 1379 f.write("\n".join(pip_deps) + "\n") 1380 1381 if len(conda_deps) > 0: 1382 _logger.warning( 1383 "The following conda dependencies have been excluded from the environment file:" 1384 f" {', '.join(conda_deps)}." 1385 ) 1386 1387 return pip_file_path 1388 1389 elif format == "conda": 1390 return get_conda_yaml_path() 1391 else: 1392 raise MlflowException( 1393 f"Illegal format argument '{format}'.", error_code=INVALID_PARAMETER_VALUE 1394 ) 1395 1396 1397 def get_model_dependencies(model_uri, format="pip"): 1398 """ 1399 Downloads the model dependencies and returns the path to requirements.txt or conda.yaml file. 1400 1401 .. warning:: 1402 This API downloads all the model artifacts to the local filesystem. This may take 1403 a long time for large models. To avoid this overhead, use 1404 ``mlflow.artifacts.download_artifacts("<model_uri>/requirements.txt")`` or 1405 ``mlflow.artifacts.download_artifacts("<model_uri>/conda.yaml")`` instead. 1406 1407 Args: 1408 model_uri: The uri of the model to get dependencies from. 1409 format: The format of the returned dependency file. If the ``"pip"`` format is 1410 specified, the path to a pip ``requirements.txt`` file is returned. 1411 If the ``"conda"`` format is specified, the path to a ``"conda.yaml"`` 1412 file is returned . If the ``"pip"`` format is specified but the model 1413 was not saved with a ``requirements.txt`` file, the ``pip`` section 1414 of the model's ``conda.yaml`` file is extracted instead, and any 1415 additional conda dependencies are ignored. Default value is ``"pip"``. 1416 1417 Returns: 1418 The local filesystem path to either a pip ``requirements.txt`` file 1419 (if ``format="pip"``) or a ``conda.yaml`` file (if ``format="conda"``) 1420 specifying the model's dependencies. 1421 """ 1422 dep_file = _get_model_dependencies(model_uri, format) 1423 1424 if format == "pip": 1425 prefix = "%" if _is_in_ipython_notebook() else "" 1426 _logger.info( 1427 "To install the dependencies that were used to train the model, run the " 1428 f"following command: '{prefix}pip install -r {dep_file}'." 1429 ) 1430 return dep_file 1431 1432 1433 @deprecated("mlflow.pyfunc.load_model", 1.0) 1434 def load_pyfunc(model_uri, suppress_warnings=False): 1435 """ 1436 Load a model stored in Python function format. 1437 1438 Args: 1439 model_uri: The location, in URI format, of the MLflow model. For example: 1440 1441 - ``/Users/me/path/to/local/model`` 1442 - ``relative/path/to/local/model`` 1443 - ``s3://my_bucket/path/to/model`` 1444 - ``runs:/<mlflow_run_id>/run-relative/path/to/model`` 1445 - ``models:/<model_name>/<model_version>`` 1446 - ``models:/<model_name>/<stage>`` 1447 - ``mlflow-artifacts:/path/to/model`` 1448 1449 For more information about supported URI schemes, see 1450 `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html# 1451 artifact-locations>`_. 1452 1453 suppress_warnings: If ``True``, non-fatal warning messages associated with the model 1454 loading process will be suppressed. If ``False``, these warning messages will be 1455 emitted. 1456 """ 1457 return load_model(model_uri, suppress_warnings) 1458 1459 1460 def _warn_potentially_incompatible_py_version_if_necessary(model_py_version=None): 1461 """ 1462 Compares the version of Python that was used to save a given model with the version 1463 of Python that is currently running. If a major or minor version difference is detected, 1464 logs an appropriate warning. 1465 """ 1466 if model_py_version is None: 1467 _logger.warning( 1468 "The specified model does not have a specified Python version. It may be" 1469 " incompatible with the version of Python that is currently running: Python %s", 1470 PYTHON_VERSION, 1471 ) 1472 elif get_major_minor_py_version(model_py_version) != get_major_minor_py_version(PYTHON_VERSION): 1473 _logger.warning( 1474 "The version of Python that the model was saved in, `Python %s`, differs" 1475 " from the version of Python that is currently running, `Python %s`," 1476 " and may be incompatible", 1477 model_py_version, 1478 PYTHON_VERSION, 1479 ) 1480 1481 1482 def _create_model_downloading_tmp_dir(should_use_nfs): 1483 root_tmp_dir = get_or_create_nfs_tmp_dir() if should_use_nfs else get_or_create_tmp_dir() 1484 1485 root_model_cache_dir = os.path.join(root_tmp_dir, "models") 1486 os.makedirs(root_model_cache_dir, exist_ok=True) 1487 1488 tmp_model_dir = tempfile.mkdtemp(dir=root_model_cache_dir) 1489 # mkdtemp creates a directory with permission 0o700 1490 # For Spark UDFs, we need to make it accessible to other processes 1491 # Use 0o750 (owner: rwx, group: r-x, others: None) instead of 0o770 1492 os.chmod(tmp_model_dir, 0o750) 1493 return tmp_model_dir 1494 1495 1496 _MLFLOW_SERVER_OUTPUT_TAIL_LINES_TO_KEEP = 200 1497 1498 1499 def _is_variant_type(spark_type): 1500 try: 1501 from pyspark.sql.types import VariantType 1502 1503 return isinstance(spark_type, VariantType) 1504 except ImportError: 1505 return False 1506 1507 1508 def _convert_spec_type_to_spark_type(spec_type): 1509 from pyspark.sql.types import ArrayType, MapType, StringType, StructField, StructType 1510 1511 from mlflow.types.schema import AnyType, Array, DataType, Map, Object 1512 1513 if isinstance(spec_type, DataType): 1514 return spec_type.to_spark() 1515 1516 if isinstance(spec_type, AnyType): 1517 try: 1518 from pyspark.sql.types import VariantType 1519 1520 return VariantType() 1521 except ImportError: 1522 raise MlflowException.invalid_parameter_value( 1523 "`AnyType` is not supported in PySpark versions older than 4.0.0. " 1524 "Upgrade your PySpark version to use this feature.", 1525 ) 1526 1527 if isinstance(spec_type, Array): 1528 return ArrayType(_convert_spec_type_to_spark_type(spec_type.dtype)) 1529 1530 if isinstance(spec_type, Object): 1531 return StructType([ 1532 StructField( 1533 property.name, 1534 _convert_spec_type_to_spark_type(property.dtype), 1535 # we set nullable to True for all properties 1536 # to avoid some errors like java.lang.NullPointerException 1537 # when the signature is not inferred based on correct data. 1538 ) 1539 for property in spec_type.properties 1540 ]) 1541 1542 # Map only supports string as key 1543 if isinstance(spec_type, Map): 1544 return MapType( 1545 keyType=StringType(), valueType=_convert_spec_type_to_spark_type(spec_type.value_type) 1546 ) 1547 1548 raise MlflowException(f"Failed to convert schema type `{spec_type}` to spark type.") 1549 1550 1551 def _cast_output_spec_to_spark_type(spec): 1552 from pyspark.sql.types import ArrayType 1553 1554 from mlflow.types.schema import ColSpec, DataType, TensorSpec 1555 1556 # TODO: handle optional output columns. 1557 if isinstance(spec, ColSpec): 1558 return _convert_spec_type_to_spark_type(spec.type) 1559 elif isinstance(spec, TensorSpec): 1560 data_type = DataType.from_numpy_type(spec.type) 1561 if data_type is None: 1562 raise MlflowException( 1563 f"Model output tensor spec type {spec.type} is not supported in spark_udf.", 1564 error_code=INVALID_PARAMETER_VALUE, 1565 ) 1566 1567 if len(spec.shape) == 1: 1568 return ArrayType(data_type.to_spark()) 1569 elif len(spec.shape) == 2: 1570 return ArrayType(ArrayType(data_type.to_spark())) 1571 else: 1572 raise MlflowException( 1573 "Only 1D or 2D tensors are supported as spark_udf " 1574 f"return value, but model output '{spec.name}' has shape {spec.shape}.", 1575 error_code=INVALID_PARAMETER_VALUE, 1576 ) 1577 else: 1578 raise MlflowException( 1579 f"Unknown schema output spec {spec}.", error_code=INVALID_PARAMETER_VALUE 1580 ) 1581 1582 1583 def _infer_spark_udf_return_type(model_output_schema): 1584 from pyspark.sql.types import StructField, StructType 1585 1586 if len(model_output_schema.inputs) == 1: 1587 return _cast_output_spec_to_spark_type(model_output_schema.inputs[0]) 1588 1589 return StructType([ 1590 StructField(name=spec.name or str(i), dataType=_cast_output_spec_to_spark_type(spec)) 1591 for i, spec in enumerate(model_output_schema.inputs) 1592 ]) 1593 1594 1595 def _parse_spark_datatype(datatype: str): 1596 from pyspark.sql.functions import udf 1597 from pyspark.sql.session import SparkSession 1598 1599 return_type = "boolean" if datatype == "bool" else datatype 1600 parsed_datatype = udf(lambda x: x, returnType=return_type).returnType 1601 1602 if parsed_datatype.typeName() == "unparseddata": 1603 # For spark 3.5.x, `udf(lambda x: x, returnType=return_type).returnType` 1604 # returns UnparsedDataType, which is not compatible with signature inference. 1605 # Note: SparkSession.active only exists for spark >= 3.5.0 1606 schema = ( 1607 SparkSession 1608 .active() 1609 .range(0) 1610 .select(udf(lambda x: x, returnType=return_type)("id")) 1611 .schema 1612 ) 1613 return schema[0].dataType 1614 1615 return parsed_datatype 1616 1617 1618 def _is_none_or_nan(value): 1619 # The condition `isinstance(value, float)` is needed to avoid error 1620 # from `np.isnan(value)` if value is a non-numeric type. 1621 return value is None or isinstance(value, float) and np.isnan(value) 1622 1623 1624 def _convert_array_values(values, result_type): 1625 """ 1626 Convert list or numpy array values to spark dataframe column values. 1627 """ 1628 from pyspark.sql.types import ArrayType, StructType 1629 1630 if not isinstance(result_type, ArrayType): 1631 raise MlflowException.invalid_parameter_value( 1632 f"result_type must be ArrayType, got {result_type.simpleString()}", 1633 ) 1634 1635 spark_primitive_type_to_np_type = _get_spark_primitive_type_to_np_type() 1636 1637 if type(result_type.elementType) in spark_primitive_type_to_np_type: 1638 np_type = spark_primitive_type_to_np_type[type(result_type.elementType)] 1639 # For array type result values, if provided value is None or NaN, regard it as a null array. 1640 # see https://github.com/mlflow/mlflow/issues/8986 1641 return None if _is_none_or_nan(values) else np.array(values, dtype=np_type) 1642 if isinstance(result_type.elementType, ArrayType): 1643 return [_convert_array_values(v, result_type.elementType) for v in values] 1644 if isinstance(result_type.elementType, StructType): 1645 return [_convert_struct_values(v, result_type.elementType) for v in values] 1646 if _is_variant_type(result_type.elementType): 1647 return values 1648 1649 raise MlflowException.invalid_parameter_value( 1650 "Unsupported array type field with element type " 1651 f"{result_type.elementType.simpleString()} in Array type.", 1652 ) 1653 1654 1655 def _get_spark_primitive_types(): 1656 from pyspark.sql import types 1657 1658 return ( 1659 types.IntegerType, 1660 types.LongType, 1661 types.FloatType, 1662 types.DoubleType, 1663 types.StringType, 1664 types.BooleanType, 1665 ) 1666 1667 1668 def _get_spark_primitive_type_to_np_type(): 1669 from pyspark.sql import types 1670 1671 return { 1672 types.IntegerType: np.int32, 1673 types.LongType: np.int64, 1674 types.FloatType: np.float32, 1675 types.DoubleType: np.float64, 1676 types.BooleanType: np.bool_, 1677 types.StringType: np.str_, 1678 } 1679 1680 1681 def _get_spark_primitive_type_to_python_type(): 1682 from pyspark.sql import types 1683 1684 return { 1685 types.IntegerType: int, 1686 types.LongType: int, 1687 types.FloatType: float, 1688 types.DoubleType: float, 1689 types.BooleanType: bool, 1690 types.StringType: str, 1691 } 1692 1693 1694 def _check_udf_return_type(data_type): 1695 from pyspark.sql.types import ArrayType, MapType, StringType, StructType 1696 1697 primitive_types = _get_spark_primitive_types() 1698 if isinstance(data_type, primitive_types): 1699 return True 1700 1701 if isinstance(data_type, ArrayType): 1702 return _check_udf_return_type(data_type.elementType) 1703 1704 if isinstance(data_type, StructType): 1705 return all(_check_udf_return_type(field.dataType) for field in data_type.fields) 1706 1707 if isinstance(data_type, MapType): 1708 return isinstance(data_type.keyType, StringType) and _check_udf_return_type( 1709 data_type.valueType 1710 ) 1711 1712 return False 1713 1714 1715 def _convert_struct_values( 1716 result: pandas.DataFrame | dict[str, Any], 1717 result_type, 1718 ): 1719 """ 1720 Convert spark StructType values to spark dataframe column values. 1721 """ 1722 1723 from pyspark.sql.types import ArrayType, MapType, StructType 1724 1725 if not isinstance(result_type, StructType): 1726 raise MlflowException.invalid_parameter_value( 1727 f"result_type must be StructType, got {result_type.simpleString()}", 1728 ) 1729 1730 if not isinstance(result, (dict, pandas.DataFrame)): 1731 raise MlflowException.invalid_parameter_value( 1732 f"Unsupported result type {type(result)}, expected dict or pandas DataFrame", 1733 ) 1734 1735 spark_primitive_type_to_np_type = _get_spark_primitive_type_to_np_type() 1736 is_pandas_df = isinstance(result, pandas.DataFrame) 1737 result_dict = {} 1738 for field_name in result_type.fieldNames(): 1739 field_type = result_type[field_name].dataType 1740 field_values = result[field_name] 1741 1742 if type(field_type) in spark_primitive_type_to_np_type: 1743 np_type = spark_primitive_type_to_np_type[type(field_type)] 1744 if is_pandas_df: 1745 # it's possible that field_values contain only Nones 1746 # in this case, we don't need to cast the type 1747 if not all(_is_none_or_nan(field_value) for field_value in field_values): 1748 field_values = field_values.astype(np_type) 1749 else: 1750 field_values = ( 1751 None 1752 if _is_none_or_nan(field_values) 1753 else np.array(field_values, dtype=np_type).item() 1754 ) 1755 elif isinstance(field_type, ArrayType): 1756 if is_pandas_df: 1757 field_values = pandas.Series( 1758 _convert_array_values(field_value, field_type) for field_value in field_values 1759 ) 1760 else: 1761 field_values = _convert_array_values(field_values, field_type) 1762 elif isinstance(field_type, StructType): 1763 if is_pandas_df: 1764 field_values = pandas.Series([ 1765 _convert_struct_values(field_value, field_type) for field_value in field_values 1766 ]) 1767 else: 1768 if isinstance(field_values, pydantic.BaseModel): 1769 field_values = field_values.model_dump() 1770 field_values = _convert_struct_values(field_values, field_type) 1771 elif isinstance(field_type, MapType): 1772 if is_pandas_df: 1773 field_values = pandas.Series([ 1774 { 1775 key: _convert_value_based_on_spark_type(value, field_type.valueType) 1776 for key, value in field_value.items() 1777 } 1778 for field_value in field_values 1779 ]).astype(object) 1780 else: 1781 field_values = { 1782 key: _convert_value_based_on_spark_type(value, field_type.valueType) 1783 for key, value in field_values.items() 1784 } 1785 elif _is_variant_type(field_type): 1786 return field_values 1787 else: 1788 raise MlflowException.invalid_parameter_value( 1789 f"Unsupported field type {field_type.simpleString()} in struct type.", 1790 ) 1791 result_dict[field_name] = field_values 1792 1793 if is_pandas_df: 1794 return pandas.DataFrame(result_dict) 1795 return result_dict 1796 1797 1798 def _convert_value_based_on_spark_type(value, spark_type): 1799 """ 1800 Convert value to python types based on the given spark type. 1801 """ 1802 1803 from pyspark.sql.types import ArrayType, MapType, StructType 1804 1805 spark_primitive_type_to_python_type = _get_spark_primitive_type_to_python_type() 1806 1807 if type(spark_type) in spark_primitive_type_to_python_type: 1808 python_type = spark_primitive_type_to_python_type[type(spark_type)] 1809 return None if _is_none_or_nan(value) else python_type(value) 1810 if isinstance(spark_type, StructType): 1811 return _convert_struct_values(value, spark_type) 1812 if isinstance(spark_type, ArrayType): 1813 return [_convert_value_based_on_spark_type(v, spark_type.elementType) for v in value] 1814 if isinstance(spark_type, MapType): 1815 return { 1816 key: _convert_value_based_on_spark_type(value[key], spark_type.valueType) 1817 for key in value 1818 } 1819 if _is_variant_type(spark_type): 1820 return value 1821 raise MlflowException.invalid_parameter_value( 1822 f"Unsupported type {spark_type} for value {value}" 1823 ) 1824 1825 1826 # This location is used to prebuild python environment in Databricks runtime. 1827 # The location for prebuilding env should be located under /local_disk0 1828 # because the python env will be uploaded to NFS and mounted to Serverless UDF sandbox, 1829 # for serverless client image case, it doesn't have "/local_disk0" directory 1830 _PREBUILD_ENV_ROOT_LOCATION = "/tmp" 1831 1832 1833 def _gen_prebuilt_env_archive_name(spark, local_model_path): 1834 """ 1835 Generate prebuilt env archive file name. 1836 The format is: 1837 'mlflow-{sha of python env config and dependencies}-{runtime version}-{platform machine}' 1838 Note: The runtime version and platform machine information are included in the 1839 archive name because the prebuilt env might not be compatible across different 1840 runtime versions or platform machines. 1841 """ 1842 python_env = _get_python_env(Path(local_model_path)) 1843 env_name = _get_virtualenv_name(python_env, local_model_path) 1844 dbconnect_udf_sandbox_info = get_dbconnect_udf_sandbox_info(spark) 1845 return ( 1846 f"{env_name}-{dbconnect_udf_sandbox_info.image_version}-" 1847 f"{dbconnect_udf_sandbox_info.platform_machine}" 1848 ) 1849 1850 1851 def _verify_prebuilt_env(spark, local_model_path, env_archive_path): 1852 # Use `[:-7]` to truncate ".tar.gz" in the end 1853 archive_name = os.path.basename(env_archive_path)[:-7] 1854 prebuilt_env_sha, prebuilt_runtime_version, prebuilt_platform_machine = archive_name.split("-")[ 1855 -3: 1856 ] 1857 1858 python_env = _get_python_env(Path(local_model_path)) 1859 env_sha = _get_virtualenv_name(python_env, local_model_path).split("-")[-1] 1860 dbconnect_udf_sandbox_info = get_dbconnect_udf_sandbox_info(spark) 1861 runtime_version = dbconnect_udf_sandbox_info.image_version 1862 platform_machine = dbconnect_udf_sandbox_info.platform_machine 1863 1864 if prebuilt_env_sha != env_sha: 1865 raise MlflowException( 1866 f"The prebuilt env '{env_archive_path}' does not match the model required environment." 1867 ) 1868 if prebuilt_runtime_version != runtime_version: 1869 raise MlflowException( 1870 f"The prebuilt env '{env_archive_path}' runtime version '{prebuilt_runtime_version}' " 1871 f"does not match UDF sandbox runtime version {runtime_version}." 1872 ) 1873 if prebuilt_platform_machine != platform_machine: 1874 raise MlflowException( 1875 f"The prebuilt env '{env_archive_path}' platform machine '{prebuilt_platform_machine}' " 1876 f"does not match UDF sandbox platform machine {platform_machine}." 1877 ) 1878 1879 1880 def _prebuild_env_internal(local_model_path, archive_name, save_path, env_manager): 1881 env_root_dir = os.path.join(_PREBUILD_ENV_ROOT_LOCATION, archive_name) 1882 archive_path = os.path.join(save_path, archive_name + ".tar.gz") 1883 if os.path.exists(env_root_dir): 1884 shutil.rmtree(env_root_dir) 1885 if os.path.exists(archive_path): 1886 os.remove(archive_path) 1887 1888 try: 1889 pyfunc_backend = get_flavor_backend( 1890 local_model_path, 1891 env_manager=env_manager, 1892 install_mlflow=False, 1893 create_env_root_dir=False, 1894 env_root_dir=env_root_dir, 1895 ) 1896 1897 pyfunc_backend.prepare_env(model_uri=local_model_path, capture_output=False) 1898 # exclude pip cache from the archive file. 1899 cache_path = os.path.join(env_root_dir, "pip_cache_pkgs") 1900 if os.path.exists(cache_path): 1901 shutil.rmtree(cache_path) 1902 1903 return archive_directory(env_root_dir, archive_path) 1904 finally: 1905 shutil.rmtree(env_root_dir, ignore_errors=True) 1906 1907 1908 def _download_prebuilt_env_if_needed(prebuilt_env_uri): 1909 from mlflow.utils.file_utils import get_or_create_tmp_dir 1910 1911 parsed_url = urlparse(prebuilt_env_uri) 1912 if parsed_url.scheme in {"", "file"}: 1913 # local path 1914 return parsed_url.path 1915 if parsed_url.scheme == "dbfs": 1916 tmp_dir = MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR.get() or get_or_create_tmp_dir() 1917 model_env_uc_path = parsed_url.path 1918 1919 # download file from DBFS. 1920 local_model_env_path = os.path.join(tmp_dir, os.path.basename(model_env_uc_path)) 1921 if os.path.exists(local_model_env_path): 1922 # file is already downloaded. 1923 return local_model_env_path 1924 1925 try: 1926 from databricks.sdk import WorkspaceClient 1927 1928 ws = WorkspaceClient() 1929 # Download model env file from UC volume. 1930 with ( 1931 ws.files.download(model_env_uc_path).contents as rf, 1932 open(local_model_env_path, "wb") as wf, 1933 ): 1934 while chunk := rf.read(4096 * 1024): 1935 wf.write(chunk) 1936 return local_model_env_path 1937 except (Exception, KeyboardInterrupt): 1938 if os.path.exists(local_model_env_path): 1939 # clean the partially saved file if downloading fails. 1940 os.remove(local_model_env_path) 1941 raise 1942 1943 raise MlflowException( 1944 f"Unsupported prebuilt env file path '{prebuilt_env_uri}', " 1945 f"invalid scheme: '{parsed_url.scheme}'." 1946 ) 1947 1948 1949 def build_model_env(model_uri, save_path, env_manager=_EnvManager.VIRTUALENV): 1950 """ 1951 Prebuild model python environment and generate an archive file saved to provided 1952 `save_path`. 1953 1954 Typical usages: 1955 - Pre-build a model's environment in Databricks Runtime and then download the prebuilt 1956 python environment archive file. This pre-built environment archive can then be used 1957 in `mlflow.pyfunc.spark_udf` for remote inference execution when using Databricks Connect 1958 to remotely connect to a Databricks environment for code execution. 1959 1960 .. note:: 1961 The `build_model_env` API is intended to only work when executed within Databricks runtime, 1962 serving the purpose of capturing the required execution environment that is needed for 1963 remote code execution when using DBConnect. The environment archive is designed to be used 1964 when performing remote execution using `mlflow.pyfunc.spark_udf` in 1965 Databricks runtime or Databricks Connect client and has no other purpose. 1966 The prebuilt env archive file cannot be used across different Databricks runtime 1967 versions or different platform machines. As such, if you connect to a different cluster 1968 that is running a different runtime version on Databricks, you will need to execute this 1969 API in a notebook and retrieve the generated archive to your local machine. Each 1970 environment snapshot is unique to the the model, the runtime version of your remote 1971 Databricks cluster, and the specification of the udf execution environment. 1972 When using the prebuilt env in `mlflow.pyfunc.spark_udf`, MLflow will verify 1973 whether the spark UDF sandbox environment matches the prebuilt env requirements and will 1974 raise Exceptions if there are compatibility issues. If these occur, simply re-run this API 1975 in the cluster that you are attempting to attach to. 1976 1977 .. code-block:: python 1978 :caption: Example 1979 1980 from mlflow.pyfunc import build_model_env 1981 1982 # Create a python environment archive file at the path `prebuilt_env_uri` 1983 prebuilt_env_uri = build_model_env(f"runs:/{run_id}/model", "/path/to/save_directory") 1984 1985 Args: 1986 model_uri: URI to the model that is used to build the python environment. 1987 save_path: The directory path that is used to save the prebuilt model environment 1988 archive file path. 1989 The path can be either local directory path or 1990 mounted DBFS path such as '/dbfs/...' or 1991 mounted UC volume path such as '/Volumes/...'. 1992 env_manager: The environment manager to use in order to create the python environment 1993 for model inference, the value can be either 'virtualenv' or 'uv', the default 1994 value is 'virtualenv'. 1995 1996 Returns: 1997 Return the path of an archive file containing the python environment data. 1998 """ 1999 from mlflow.utils._spark_utils import _get_active_spark_session 2000 2001 if not is_in_databricks_runtime(): 2002 raise RuntimeError("'build_model_env' only support running in Databricks runtime.") 2003 2004 if os.path.isfile(save_path): 2005 raise RuntimeError(f"The saving path '{save_path}' must be a directory.") 2006 os.makedirs(save_path, exist_ok=True) 2007 2008 local_model_path = _download_artifact_from_uri( 2009 artifact_uri=model_uri, output_path=_create_model_downloading_tmp_dir(should_use_nfs=False) 2010 ) 2011 archive_name = _gen_prebuilt_env_archive_name(_get_active_spark_session(), local_model_path) 2012 dest_path = os.path.join(save_path, archive_name + ".tar.gz") 2013 if os.path.exists(dest_path): 2014 raise RuntimeError( 2015 "A pre-built model python environment already exists " 2016 f"in '{dest_path}'. To rebuild it, please remove " 2017 "the existing one first." 2018 ) 2019 2020 # Archive the environment directory as a `tar.gz` format archive file, 2021 # and then move the archive file to the destination directory. 2022 # Note: 2023 # - all symlink files in the input directory are kept as it is in the 2024 # archive file. 2025 # - the destination directory could be UC-volume fuse mounted directory 2026 # which only supports limited filesystem operations, so to ensure it works, 2027 # we generate the archive file under /tmp and then move it into the 2028 # destination directory. 2029 tmp_archive_path = None 2030 try: 2031 tmp_archive_path = _prebuild_env_internal( 2032 local_model_path, archive_name, _PREBUILD_ENV_ROOT_LOCATION, env_manager 2033 ) 2034 shutil.move(tmp_archive_path, save_path) 2035 return dest_path 2036 finally: 2037 shutil.rmtree(local_model_path, ignore_errors=True) 2038 if tmp_archive_path and os.path.exists(tmp_archive_path): 2039 os.remove(tmp_archive_path) 2040 2041 2042 def spark_udf( 2043 spark, 2044 model_uri, 2045 result_type=None, 2046 env_manager=None, 2047 params: dict[str, Any] | None = None, 2048 extra_env: dict[str, str] | None = None, 2049 prebuilt_env_uri: str | None = None, 2050 model_config: str | Path | dict[str, Any] | None = None, 2051 ): 2052 """ 2053 A Spark UDF that can be used to invoke the Python function formatted model. 2054 2055 Parameters passed to the UDF are forwarded to the model as a DataFrame where the column names 2056 are ordinals (0, 1, ...). On some versions of Spark (3.0 and above), it is also possible to 2057 wrap the input in a struct. In that case, the data will be passed as a DataFrame with column 2058 names given by the struct definition (e.g. when invoked as my_udf(struct('x', 'y')), the model 2059 will get the data as a pandas DataFrame with 2 columns 'x' and 'y'). 2060 2061 If a model contains a signature with tensor spec inputs, you will need to pass a column of 2062 array type as a corresponding UDF argument. The column values of which must be one dimensional 2063 arrays. The UDF will reshape the column values to the required shape with 'C' order 2064 (i.e. read / write the elements using C-like index order) and cast the values as the required 2065 tensor spec type. 2066 2067 If a model contains a signature, the UDF can be called without specifying column name 2068 arguments. In this case, the UDF will be called with column names from signature, so the 2069 evaluation dataframe's column names must match the model signature's column names. 2070 2071 The predictions are filtered to contain only the columns that can be represented as the 2072 ``result_type``. If the ``result_type`` is string or array of strings, all predictions are 2073 converted to string. If the result type is not an array type, the left most column with 2074 matching type is returned. 2075 2076 .. note:: 2077 Inputs of type ``pyspark.sql.types.DateType`` are not supported on earlier versions of 2078 Spark (2.4 and below). 2079 2080 .. note:: 2081 When using Databricks Connect to connect to a remote Databricks cluster, 2082 the Databricks cluster must use runtime version >= 16, and if the 'prebuilt_env_uri' 2083 parameter is set, 'env_manager' parameter should not be set. 2084 the Databricks cluster must use runtime version >= 15.4,and if the 'prebuilt_env_uri' 2085 parameter is set, 'env_manager' parameter should not be set, 2086 if the runtime version is 15.4 and the cluster is 2087 standard access mode, the cluster need to configure 2088 "spark.databricks.safespark.archive.artifact.unpack.disabled" to "false". 2089 2090 .. note:: 2091 Please be aware that when operating in Databricks Serverless, 2092 spark tasks run within the confines of the Databricks Serverless UDF sandbox. 2093 This environment has a total capacity limit of 1GB, combining both available 2094 memory and local disk capacity. Furthermore, there are no GPU devices available 2095 in this setup. Therefore, any deep-learning models that contain large weights 2096 or require a GPU are not suitable for deployment on Databricks Serverless. 2097 2098 .. code-block:: python 2099 :caption: Example 2100 2101 from pyspark.sql.functions import struct 2102 2103 predict = mlflow.pyfunc.spark_udf(spark, "/my/local/model") 2104 df.withColumn("prediction", predict(struct("name", "age"))).show() 2105 2106 Args: 2107 spark: A SparkSession object. 2108 model_uri: The location, in URI format, of the MLflow model with the 2109 :py:mod:`mlflow.pyfunc` flavor. For example: 2110 2111 - ``/Users/me/path/to/local/model`` 2112 - ``relative/path/to/local/model`` 2113 - ``s3://my_bucket/path/to/model`` 2114 - ``runs:/<mlflow_run_id>/run-relative/path/to/model`` 2115 - ``models:/<model_name>/<model_version>`` 2116 - ``models:/<model_name>/<stage>`` 2117 - ``mlflow-artifacts:/path/to/model`` 2118 2119 For more information about supported URI schemes, see 2120 `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html# 2121 artifact-locations>`_. 2122 2123 result_type: the return type of the user-defined function. The value can be either a 2124 ``pyspark.sql.types.DataType`` object or a DDL-formatted type string. Only a primitive 2125 type, an array ``pyspark.sql.types.ArrayType`` of primitive type, or a struct type 2126 containing fields of above 2 kinds of types are allowed. 2127 If unspecified, it tries to infer result type from model signature 2128 output schema, if model output schema is not available, it fallbacks to use ``double`` 2129 type. 2130 2131 The following classes of result type are supported: 2132 2133 - "int" or ``pyspark.sql.types.IntegerType``: The leftmost integer that can fit in an 2134 ``int32`` or an exception if there is none. 2135 2136 - "long" or ``pyspark.sql.types.LongType``: The leftmost long integer that can fit in an 2137 ``int64`` or an exception if there is none. 2138 2139 - ``ArrayType(IntegerType|LongType)``: All integer columns that can fit into the 2140 requested size. 2141 2142 - "float" or ``pyspark.sql.types.FloatType``: The leftmost numeric result cast to 2143 ``float32`` or an exception if there is none. 2144 2145 - "double" or ``pyspark.sql.types.DoubleType``: The leftmost numeric result cast to 2146 ``double`` or an exception if there is none. 2147 2148 - ``ArrayType(FloatType|DoubleType)``: All numeric columns cast to the requested type or 2149 an exception if there are no numeric columns. 2150 2151 - "string" or ``pyspark.sql.types.StringType``: The leftmost column converted to 2152 ``string``. 2153 2154 - "boolean" or "bool" or ``pyspark.sql.types.BooleanType``: The leftmost column 2155 converted to ``bool`` or an exception if there is none. 2156 2157 - ``ArrayType(StringType)``: All columns converted to ``string``. 2158 2159 - "field1 FIELD1_TYPE, field2 FIELD2_TYPE, ...": A struct type containing multiple 2160 fields separated by comma, each field type must be one of types listed above. 2161 2162 env_manager: The environment manager to use in order to create the python environment 2163 for model inference. Note that environment is only restored in the context 2164 of the PySpark UDF; the software environment outside of the UDF is 2165 unaffected. If `prebuilt_env_uri` parameter is not set, the default value 2166 is ``local``, and the following values are supported: 2167 2168 - ``virtualenv``: Use virtualenv to restore the python environment that 2169 was used to train the model. This is the default option if ``env_manager`` 2170 is not set. 2171 - ``uv`` : Use uv to restore the python environment that 2172 was used to train the model. 2173 - ``conda``: Use Conda to restore the software environment 2174 that was used to train the model. 2175 - ``local``: Use the current Python environment for model inference, which 2176 may differ from the environment used to train the model and may lead to 2177 errors or invalid predictions. 2178 2179 If the `prebuilt_env_uri` parameter is set, `env_manager` parameter should not 2180 be set. 2181 2182 params: Additional parameters to pass to the model for inference. 2183 2184 extra_env: Extra environment variables to pass to the UDF executors. 2185 For overrides that need to propagate to the Spark workers (i.e., 2186 overriding the scoring server timeout via `MLFLOW_SCORING_SERVER_REQUEST_TIMEOUT`). 2187 2188 prebuilt_env_uri: The path of the prebuilt env archive file created by 2189 `mlflow.pyfunc.build_model_env` API. 2190 This parameter can only be used in Databricks Serverless notebook REPL, 2191 Databricks Shared cluster notebook REPL, and Databricks Connect client 2192 environment. 2193 The path can be either local file path or DBFS path such as 2194 'dbfs:/Volumes/...', in this case, MLflow automatically downloads it 2195 to local temporary directory, "MLFLOW_MODEL_ENV_DOWNLOADING_TEMP_DIR" 2196 environmental variable can be set to specify the temporary directory 2197 to use. 2198 2199 If this parameter is set, `env_manger` parameter must not be set. 2200 2201 model_config: The model configuration to set when loading the model. 2202 See 'model_config' argument in `mlflow.pyfunc.load_model` API for details. 2203 2204 Returns: 2205 Spark UDF that applies the model's ``predict`` method to the data and returns a 2206 type specified by ``result_type``, which by default is a double. 2207 """ 2208 2209 # Scope Spark import to this method so users don't need pyspark to use non-Spark-related 2210 # functionality. 2211 from pyspark.sql.functions import pandas_udf 2212 from pyspark.sql.types import ( 2213 ArrayType, 2214 BooleanType, 2215 DoubleType, 2216 FloatType, 2217 IntegerType, 2218 LongType, 2219 MapType, 2220 StringType, 2221 ) 2222 from pyspark.sql.types import StructType as SparkStructType 2223 2224 from mlflow.pyfunc.spark_model_cache import SparkModelCache 2225 from mlflow.utils._spark_utils import _SparkDirectoryDistributor 2226 2227 is_spark_connect = is_spark_connect_mode() 2228 # Used in test to force install local version of mlflow when starting a model server 2229 mlflow_home = os.environ.get("MLFLOW_HOME") 2230 openai_env_vars = mlflow.openai.model._OpenAIEnvVar.read_environ() 2231 mlflow_testing = _MLFLOW_TESTING.get_raw() 2232 2233 if prebuilt_env_uri: 2234 if env_manager not in (None, _EnvManager.VIRTUALENV, _EnvManager.UV): 2235 raise MlflowException( 2236 "If 'prebuilt_env_uri' parameter is set, 'env_manager' parameter must " 2237 "be either None, 'virtualenv', or 'uv'." 2238 ) 2239 env_manager = _EnvManager.VIRTUALENV 2240 else: 2241 env_manager = env_manager or _EnvManager.LOCAL 2242 2243 _EnvManager.validate(env_manager) 2244 2245 if is_spark_connect: 2246 is_spark_in_local_mode = False 2247 else: 2248 # Check whether spark is in local or local-cluster mode 2249 # this case all executors and driver share the same filesystem 2250 is_spark_in_local_mode = spark.conf.get("spark.master").startswith("local") 2251 2252 is_dbconnect_mode = is_databricks_connect(spark) 2253 if prebuilt_env_uri is not None and not is_dbconnect_mode: 2254 raise RuntimeError( 2255 "'prebuilt_env' parameter can only be used in Databricks Serverless " 2256 "notebook REPL, atabricks Shared cluster notebook REPL, and Databricks Connect client " 2257 "environment." 2258 ) 2259 2260 if prebuilt_env_uri is None and is_dbconnect_mode and not is_in_databricks_runtime(): 2261 raise RuntimeError( 2262 "'prebuilt_env_uri' param is required if using Databricks Connect to connect " 2263 "to Databricks cluster from your own machine." 2264 ) 2265 2266 # Databricks connect can use `spark.addArtifact` to upload artifact to NFS. 2267 # But for Databricks shared cluster runtime, it can directly write to NFS, so exclude it. 2268 # For Databricks Serverless runtime (notebook REPL), spark.addArtifact may not reliably 2269 # make archives available to UDF executor sandboxes. As a temporary workaround, set 2270 # _MLFLOW_SPARK_UDF_SERVERLESS_SKIP_DBCONNECT_ARTIFACT=true to skip the addArtifact path 2271 # and let each executor fetch the model directly from the MLflow artifact store instead. 2272 use_dbconnect_artifact = ( 2273 is_dbconnect_mode 2274 and not is_in_databricks_shared_cluster_runtime() 2275 and not ( 2276 is_in_databricks_serverless_runtime() 2277 and _MLFLOW_SPARK_UDF_SERVERLESS_SKIP_DBCONNECT_ARTIFACT.get() 2278 ) 2279 ) 2280 2281 if use_dbconnect_artifact: 2282 udf_sandbox_info = get_dbconnect_udf_sandbox_info(spark) 2283 if Version(udf_sandbox_info.mlflow_version) < Version("2.18.0"): 2284 raise MlflowException( 2285 "Using 'mlflow.pyfunc.spark_udf' in Databricks Serverless or in remote " 2286 "Databricks Connect requires UDF sandbox image installed with MLflow " 2287 "of version >= 2.18.0" 2288 ) 2289 # `udf_sandbox_info.runtime_version` format is like '<major_version>.<minor_version>'. 2290 # It's safe to apply `Version`. 2291 dbr_runtime_version = Version(udf_sandbox_info.runtime_version) 2292 if dbr_runtime_version < Version("15.4"): 2293 raise MlflowException( 2294 "Using 'mlflow.pyfunc.spark_udf' in Databricks Serverless or in remote " 2295 "Databricks Connect requires Databricks runtime version >= 15.4." 2296 ) 2297 if dbr_runtime_version == Version("15.4"): 2298 if spark.conf.get("spark.databricks.pyspark.udf.isolation.enabled").lower() == "true": 2299 # The connected cluster is standard (shared) mode. 2300 if ( 2301 spark.conf.get( 2302 "spark.databricks.safespark.archive.artifact.unpack.disabled" 2303 ).lower() 2304 != "false" 2305 ): 2306 raise MlflowException( 2307 "Using 'mlflow.pyfunc.spark_udf' in remote Databricks Connect requires " 2308 "Databricks cluster setting " 2309 "'spark.databricks.safespark.archive.artifact.unpack.disabled' to 'false' " 2310 "if Databricks runtime version is 15.4" 2311 ) 2312 2313 nfs_root_dir = get_nfs_cache_root_dir() 2314 should_use_nfs = nfs_root_dir is not None 2315 2316 should_use_spark_to_broadcast_file = not ( 2317 is_spark_in_local_mode or should_use_nfs or is_spark_connect or use_dbconnect_artifact 2318 ) 2319 2320 # For spark connect mode, 2321 # If client code is executed in databricks runtime and NFS is available, 2322 # we save model to NFS temp directory in the driver 2323 # and load the model in the executor. 2324 should_spark_connect_use_nfs = is_in_databricks_runtime() and should_use_nfs 2325 2326 if ( 2327 is_spark_connect 2328 and not is_dbconnect_mode 2329 and env_manager in (_EnvManager.VIRTUALENV, _EnvManager.CONDA, _EnvManager.UV) 2330 ): 2331 raise MlflowException.invalid_parameter_value( 2332 f"Environment manager {env_manager!r} is not supported in Spark Connect " 2333 "client environment if it connects to non-Databricks Spark cluster.", 2334 ) 2335 2336 local_model_path = _download_artifact_from_uri( 2337 artifact_uri=model_uri, 2338 output_path=_create_model_downloading_tmp_dir(should_use_nfs), 2339 ) 2340 2341 if prebuilt_env_uri: 2342 prebuilt_env_uri = _download_prebuilt_env_if_needed(prebuilt_env_uri) 2343 _verify_prebuilt_env(spark, local_model_path, prebuilt_env_uri) 2344 if use_dbconnect_artifact and env_manager == _EnvManager.CONDA: 2345 raise MlflowException( 2346 "Databricks connect mode or Databricks Serverless python REPL doesn't " 2347 "support env_manager 'conda'." 2348 ) 2349 2350 if env_manager == _EnvManager.LOCAL: 2351 # Assume spark executor python environment is the same with spark driver side. 2352 model_requirements = _get_pip_requirements_from_model_path(local_model_path) 2353 warn_dependency_requirement_mismatches(model_requirements) 2354 _logger.warning( 2355 'Calling `spark_udf()` with `env_manager="local"` does not recreate the same ' 2356 "environment that was used during training, which may lead to errors or inaccurate " 2357 'predictions. We recommend specifying `env_manager="conda"`, which automatically ' 2358 "recreates the environment that was used to train the model and performs inference " 2359 "in the recreated environment." 2360 ) 2361 else: 2362 _logger.info( 2363 f"This UDF will use {env_manager} to recreate the model's software environment for " 2364 "inference. This may take extra time during execution." 2365 ) 2366 if not sys.platform.startswith("linux"): 2367 # TODO: support killing mlflow server launched in UDF task when spark job canceled 2368 # for non-linux system. 2369 # https://stackoverflow.com/questions/53208/how-do-i-automatically-destroy-child-processes-in-windows 2370 _logger.warning( 2371 "In order to run inference code in restored python environment, PySpark UDF " 2372 "processes spawn MLflow Model servers as child processes. Due to system " 2373 "limitations with handling SIGKILL signals, these MLflow Model server child " 2374 "processes cannot be cleaned up if the Spark Job is canceled." 2375 ) 2376 2377 if prebuilt_env_uri: 2378 env_cache_key = os.path.basename(prebuilt_env_uri)[:-7] 2379 elif use_dbconnect_artifact: 2380 env_cache_key = _gen_prebuilt_env_archive_name(spark, local_model_path) 2381 else: 2382 env_cache_key = None 2383 2384 if use_dbconnect_artifact or prebuilt_env_uri is not None: 2385 prebuilt_env_root_dir = os.path.join(_PREBUILD_ENV_ROOT_LOCATION, env_cache_key) 2386 pyfunc_backend_env_root_config = { 2387 "create_env_root_dir": False, 2388 "env_root_dir": prebuilt_env_root_dir, 2389 } 2390 else: 2391 pyfunc_backend_env_root_config = {"create_env_root_dir": True} 2392 pyfunc_backend = get_flavor_backend( 2393 local_model_path, 2394 env_manager=env_manager, 2395 install_mlflow=os.environ.get("MLFLOW_HOME") is not None, 2396 **pyfunc_backend_env_root_config, 2397 ) 2398 dbconnect_artifact_cache = DBConnectArtifactCache.get_or_create(spark) 2399 2400 if use_dbconnect_artifact: 2401 # Upload model artifacts and python environment to NFS as DBConnect artifacts. 2402 if env_manager in (_EnvManager.VIRTUALENV, _EnvManager.UV): 2403 if not dbconnect_artifact_cache.has_cache_key(env_cache_key): 2404 if prebuilt_env_uri: 2405 env_archive_path = prebuilt_env_uri 2406 else: 2407 env_archive_path = _prebuild_env_internal( 2408 local_model_path, env_cache_key, get_or_create_tmp_dir(), env_manager 2409 ) 2410 dbconnect_artifact_cache.add_artifact_archive(env_cache_key, env_archive_path) 2411 2412 if not dbconnect_artifact_cache.has_cache_key(model_uri): 2413 model_archive_path = os.path.join( 2414 os.path.dirname(local_model_path), f"model-{uuid.uuid4()}.tar.gz" 2415 ) 2416 archive_directory(local_model_path, model_archive_path) 2417 dbconnect_artifact_cache.add_artifact_archive(model_uri, model_archive_path) 2418 2419 elif not should_use_spark_to_broadcast_file: 2420 if prebuilt_env_uri: 2421 # Extract prebuilt env archive file to NFS directory. 2422 prebuilt_env_nfs_dir = os.path.join( 2423 get_or_create_nfs_tmp_dir(), "prebuilt_env", env_cache_key 2424 ) 2425 if not os.path.exists(prebuilt_env_nfs_dir): 2426 extract_archive_to_dir(prebuilt_env_uri, prebuilt_env_nfs_dir) 2427 else: 2428 # Prepare restored environment in driver side if possible. 2429 # Note: In databricks runtime, because databricks notebook cell output cannot capture 2430 # child process output, so that set capture_output to be True so that when `conda 2431 # prepare env` command failed, the exception message will include command stdout/stderr 2432 # output. Otherwise user have to check cluster driver log to find command stdout/stderr 2433 # output. 2434 # In non-databricks runtime, set capture_output to be False, because the benefit of 2435 # "capture_output=False" is the output will be printed immediately, otherwise you have 2436 # to wait conda command fail and suddenly get all output printed (included in error 2437 # message). 2438 if env_manager != _EnvManager.LOCAL: 2439 pyfunc_backend.prepare_env( 2440 model_uri=local_model_path, capture_output=is_in_databricks_runtime() 2441 ) 2442 else: 2443 # Broadcast local model directory to remote worker if needed. 2444 archive_path = SparkModelCache.add_local_model(spark, local_model_path) 2445 2446 model_metadata = Model.load(os.path.join(local_model_path, MLMODEL_FILE_NAME)) 2447 2448 if result_type is None: 2449 if model_output_schema := model_metadata.get_output_schema(): 2450 result_type = _infer_spark_udf_return_type(model_output_schema) 2451 else: 2452 _logger.warning( 2453 "No 'result_type' provided for spark_udf and the model does not " 2454 "have an output schema. 'result_type' is set to 'double' type." 2455 ) 2456 result_type = DoubleType() 2457 else: 2458 if isinstance(result_type, str): 2459 result_type = _parse_spark_datatype(result_type) 2460 2461 # if result type is inferred by MLflow, we don't need to check it 2462 if not _check_udf_return_type(result_type): 2463 raise MlflowException.invalid_parameter_value( 2464 f"""Invalid 'spark_udf' result type: {result_type}. 2465 It must be one of the following types: 2466 Primitive types: 2467 - int 2468 - long 2469 - float 2470 - double 2471 - string 2472 - boolean 2473 Compound types: 2474 - ND array of primitives / structs. 2475 - struct<field: primitive | array<primitive> | array<array<primitive>>, ...>: 2476 A struct with primitive, ND array<primitive/structs>, 2477 e.g., struct<a:int, b:array<int>>. 2478 """ 2479 ) 2480 params = _validate_params(params, model_metadata) 2481 2482 def _predict_row_batch(predict_fn, args): 2483 input_schema = model_metadata.get_input_schema() 2484 args = list(args) 2485 if len(args) == 1 and isinstance(args[0], pandas.DataFrame): 2486 pdf = args[0] 2487 else: 2488 if input_schema is None: 2489 names = [str(i) for i in range(len(args))] 2490 else: 2491 names = input_schema.input_names() 2492 required_names = input_schema.required_input_names() 2493 if len(args) > len(names): 2494 args = args[: len(names)] 2495 if len(args) < len(required_names): 2496 raise MlflowException( 2497 f"Model input is missing required columns. Expected {len(names)} required" 2498 f" input columns {names}, but the model received only {len(args)} " 2499 "unnamed input columns (Since the columns were passed unnamed they are" 2500 " expected to be in the order specified by the schema)." 2501 ) 2502 pdf = pandas.DataFrame( 2503 data={ 2504 names[i]: arg 2505 if isinstance(arg, pandas.Series) 2506 # pandas_udf receives a StructType column as a pandas DataFrame. 2507 # We need to convert it back to a dict of pandas Series. 2508 else arg.apply(lambda row: row.to_dict(), axis=1) 2509 for i, arg in enumerate(args) 2510 }, 2511 columns=names, 2512 ) 2513 2514 result = predict_fn(pdf, params) 2515 2516 if isinstance(result, dict): 2517 result = {k: list(v) for k, v in result.items()} 2518 2519 if isinstance(result_type, ArrayType) and isinstance(result_type.elementType, ArrayType): 2520 result_values = _convert_array_values(result, result_type) 2521 return pandas.Series(result_values) 2522 2523 if isinstance(result_type, SparkStructType): 2524 if ( 2525 isinstance(result, list) 2526 and len(result) > 0 2527 and isinstance(result[0], pydantic.BaseModel) 2528 ): 2529 result = pandas.DataFrame([r.model_dump() for r in result]) 2530 else: 2531 result = pandas.DataFrame(result) 2532 return _convert_struct_values(result, result_type) 2533 2534 if not isinstance(result, pandas.DataFrame): 2535 if isinstance(result_type, MapType): 2536 # list of dicts should be converted into a single column 2537 result = pandas.DataFrame([result]) 2538 else: 2539 result = ( 2540 pandas.DataFrame([result]) if np.isscalar(result) else pandas.DataFrame(result) 2541 ) 2542 2543 elem_type = result_type.elementType if isinstance(result_type, ArrayType) else result_type 2544 if type(elem_type) == IntegerType: 2545 result = result.select_dtypes([ 2546 np.byte, 2547 np.ubyte, 2548 np.short, 2549 np.ushort, 2550 np.int32, 2551 ]).astype(np.int32) 2552 2553 elif type(elem_type) == LongType: 2554 result = result.select_dtypes([np.byte, np.ubyte, np.short, np.ushort, int]).astype( 2555 np.int64 2556 ) 2557 2558 elif type(elem_type) == FloatType: 2559 result = result.select_dtypes(include=(np.number,)).astype(np.float32) 2560 2561 elif type(elem_type) == DoubleType: 2562 result = result.select_dtypes(include=(np.number,)).astype(np.float64) 2563 2564 elif type(elem_type) == BooleanType: 2565 result = result.select_dtypes([bool, np.bool_]).astype(bool) 2566 2567 if len(result.columns) == 0: 2568 raise MlflowException( 2569 message="The model did not produce any values compatible with the requested " 2570 f"type '{elem_type}'. Consider requesting udf with StringType or " 2571 "Arraytype(StringType).", 2572 error_code=INVALID_PARAMETER_VALUE, 2573 ) 2574 2575 if type(elem_type) == StringType: 2576 if Version(pandas.__version__) >= Version("2.1.0"): 2577 result = result.map(str) 2578 else: 2579 result = result.applymap(str) 2580 2581 if type(result_type) == ArrayType: 2582 return pandas.Series(result.to_numpy().tolist()) 2583 else: 2584 return result[result.columns[0]] 2585 2586 result_type_hint = ( 2587 pandas.DataFrame if isinstance(result_type, SparkStructType) else pandas.Series 2588 ) 2589 2590 tracking_uri = mlflow.get_tracking_uri() 2591 2592 enforce_stdin_scoring_server = MLFLOW_ENFORCE_STDIN_SCORING_SERVER_FOR_SPARK_UDF.get() 2593 2594 @pandas_udf(result_type) 2595 def udf( 2596 # `pandas_udf` does not support modern type annotations 2597 iterator: Iterator[Tuple[Union[pandas.Series, pandas.DataFrame], ...]], # noqa: UP006,UP007 2598 ) -> Iterator[result_type_hint]: 2599 # importing here to prevent circular import 2600 from mlflow.pyfunc.scoring_server.client import ( 2601 ScoringServerClient, 2602 StdinScoringServerClient, 2603 ) 2604 2605 # Note: this is a pandas udf function in iteration style, which takes an iterator of 2606 # tuple of pandas.Series and outputs an iterator of pandas.Series. 2607 update_envs = {} 2608 if mlflow_home is not None: 2609 update_envs["MLFLOW_HOME"] = mlflow_home 2610 if openai_env_vars: 2611 update_envs.update(openai_env_vars) 2612 if mlflow_testing: 2613 update_envs[_MLFLOW_TESTING.name] = mlflow_testing 2614 if extra_env: 2615 update_envs.update(extra_env) 2616 2617 # use `modified_environ` to temporarily set the envs and restore them finally 2618 with modified_environ(update=update_envs): 2619 scoring_server_proc = None 2620 # set tracking_uri inside udf so that with spark_connect 2621 # we can load the model from correct path 2622 mlflow.set_tracking_uri(tracking_uri) 2623 2624 if env_manager != _EnvManager.LOCAL: 2625 if use_dbconnect_artifact: 2626 local_model_path_on_executor = ( 2627 dbconnect_artifact_cache.get_unpacked_artifact_dir(model_uri) 2628 ) 2629 env_src_dir = dbconnect_artifact_cache.get_unpacked_artifact_dir(env_cache_key) 2630 2631 # Create symlink if it does not exist 2632 if not os.path.exists(prebuilt_env_root_dir): 2633 os.symlink(env_src_dir, prebuilt_env_root_dir) 2634 elif prebuilt_env_uri is not None: 2635 # prebuilt env is extracted to `prebuilt_env_nfs_dir` directory, 2636 # and model is downloaded to `local_model_path` which points to an NFS 2637 # directory too. 2638 local_model_path_on_executor = None 2639 2640 # Create symlink if it does not exist 2641 if not os.path.exists(prebuilt_env_root_dir): 2642 os.symlink(prebuilt_env_nfs_dir, prebuilt_env_root_dir) 2643 elif should_use_spark_to_broadcast_file: 2644 local_model_path_on_executor = _SparkDirectoryDistributor.get_or_extract( 2645 archive_path 2646 ) 2647 # Call "prepare_env" in advance in order to reduce scoring server launch time. 2648 # So that we can use a shorter timeout when call `client.wait_server_ready`, 2649 # otherwise we have to set a long timeout for `client.wait_server_ready` time, 2650 # this prevents spark UDF task failing fast if other exception raised 2651 # when scoring server launching. 2652 # Set "capture_output" so that if "conda env create" command failed, the command 2653 # stdout/stderr output will be attached to the exception message and included in 2654 # driver side exception. 2655 pyfunc_backend.prepare_env( 2656 model_uri=local_model_path_on_executor, capture_output=True 2657 ) 2658 else: 2659 local_model_path_on_executor = None 2660 2661 if not enforce_stdin_scoring_server and check_port_connectivity(): 2662 # launch scoring server 2663 server_port = find_free_port() 2664 host = "127.0.0.1" 2665 scoring_server_proc = pyfunc_backend.serve( 2666 model_uri=local_model_path_on_executor or local_model_path, 2667 port=server_port, 2668 host=host, 2669 timeout=MLFLOW_SCORING_SERVER_REQUEST_TIMEOUT.get(), 2670 enable_mlserver=False, 2671 synchronous=False, 2672 stdout=subprocess.PIPE, 2673 stderr=subprocess.STDOUT, 2674 model_config=model_config, 2675 ) 2676 2677 client = ScoringServerClient(host, server_port) 2678 else: 2679 scoring_server_proc = pyfunc_backend.serve_stdin( 2680 model_uri=local_model_path_on_executor or local_model_path, 2681 stdout=subprocess.PIPE, 2682 stderr=subprocess.STDOUT, 2683 model_config=model_config, 2684 ) 2685 client = StdinScoringServerClient(scoring_server_proc) 2686 2687 _logger.info("Using %s", client.__class__.__name__) 2688 2689 server_tail_logs = collections.deque( 2690 maxlen=_MLFLOW_SERVER_OUTPUT_TAIL_LINES_TO_KEEP 2691 ) 2692 2693 def server_redirect_log_thread_func(child_stdout): 2694 for line in child_stdout: 2695 decoded = line.decode() if isinstance(line, bytes) else line 2696 server_tail_logs.append(decoded) 2697 sys.stdout.write("[model server] " + decoded) 2698 2699 server_redirect_log_thread = threading.Thread( 2700 target=server_redirect_log_thread_func, 2701 args=(scoring_server_proc.stdout,), 2702 daemon=True, 2703 name=f"mlflow_pyfunc_model_server_log_redirector_{uuid.uuid4().hex[:8]}", 2704 ) 2705 server_redirect_log_thread.start() 2706 2707 try: 2708 client.wait_server_ready(timeout=90, scoring_server_proc=scoring_server_proc) 2709 except Exception as e: 2710 err_msg = ( 2711 "During spark UDF task execution, mlflow model server failed to launch. " 2712 ) 2713 if len(server_tail_logs) == _MLFLOW_SERVER_OUTPUT_TAIL_LINES_TO_KEEP: 2714 err_msg += ( 2715 f"Last {_MLFLOW_SERVER_OUTPUT_TAIL_LINES_TO_KEEP} " 2716 "lines of MLflow model server output:\n" 2717 ) 2718 else: 2719 err_msg += "MLflow model server output:\n" 2720 err_msg += "".join(server_tail_logs) 2721 raise MlflowException(err_msg) from e 2722 2723 def batch_predict_fn(pdf, params=None): 2724 if "params" in inspect.signature(client.invoke).parameters: 2725 return client.invoke(pdf, params=params).get_predictions() 2726 _log_warning_if_params_not_in_predict_signature(_logger, params) 2727 return client.invoke(pdf).get_predictions() 2728 2729 elif env_manager == _EnvManager.LOCAL: 2730 if use_dbconnect_artifact: 2731 model_path = dbconnect_artifact_cache.get_unpacked_artifact_dir(model_uri) 2732 loaded_model = mlflow.pyfunc.load_model(model_path, model_config=model_config) 2733 elif is_spark_connect and not should_spark_connect_use_nfs: 2734 model_path = os.path.join( 2735 tempfile.gettempdir(), 2736 "mlflow", 2737 hashlib.sha1(model_uri.encode(), usedforsecurity=False).hexdigest(), 2738 # Use pid to avoid conflict when multiple spark UDF tasks 2739 str(os.getpid()), 2740 ) 2741 try: 2742 loaded_model = mlflow.pyfunc.load_model( 2743 model_path, model_config=model_config 2744 ) 2745 except Exception: 2746 os.makedirs(model_path, exist_ok=True) 2747 loaded_model = mlflow.pyfunc.load_model( 2748 model_uri, dst_path=model_path, model_config=model_config 2749 ) 2750 elif should_use_spark_to_broadcast_file: 2751 loaded_model, _ = SparkModelCache.get_or_load(archive_path) 2752 else: 2753 loaded_model = mlflow.pyfunc.load_model( 2754 local_model_path, model_config=model_config 2755 ) 2756 2757 def batch_predict_fn(pdf, params=None): 2758 if "params" in inspect.signature(loaded_model.predict).parameters: 2759 return loaded_model.predict(pdf, params=params) 2760 _log_warning_if_params_not_in_predict_signature(_logger, params) 2761 return loaded_model.predict(pdf) 2762 2763 try: 2764 for input_batch in iterator: 2765 # If the UDF is called with only multiple arguments, 2766 # the `input_batch` is a tuple which composes of several pd.Series/pd.DataFrame 2767 # objects. 2768 # If the UDF is called with only one argument, 2769 # the `input_batch` instance will be an instance of `pd.Series`/`pd.DataFrame`, 2770 if isinstance(input_batch, (pandas.Series, pandas.DataFrame)): 2771 # UDF is called with only one argument 2772 row_batch_args = (input_batch,) 2773 else: 2774 row_batch_args = input_batch 2775 2776 if len(row_batch_args[0]) > 0: 2777 yield _predict_row_batch(batch_predict_fn, row_batch_args) 2778 except SystemError as e: 2779 if "error return without exception set" in str(e): 2780 raise MlflowException( 2781 "A system error related to the Python C extension has occurred. " 2782 "This is usually caused by an incompatible Python library that uses the " 2783 "C extension. To address this, we recommend you to log the model " 2784 "with fixed version python libraries that use the C extension " 2785 "(such as 'numpy' library), and set spark_udf `env_manager` argument " 2786 "to 'virtualenv' or 'uv' so that spark_udf can restore the original " 2787 "python library version before running model inference." 2788 ) from e 2789 finally: 2790 if scoring_server_proc is not None: 2791 os.kill(scoring_server_proc.pid, signal.SIGTERM) 2792 2793 udf.metadata = model_metadata 2794 2795 @functools.wraps(udf) 2796 def udf_with_default_cols(*args): 2797 if len(args) == 0: 2798 input_schema = model_metadata.get_input_schema() 2799 if input_schema and len(input_schema.optional_input_names()) > 0: 2800 raise MlflowException( 2801 message="Cannot apply UDF without column names specified when" 2802 " model signature contains optional columns.", 2803 error_code=INVALID_PARAMETER_VALUE, 2804 ) 2805 if input_schema and len(input_schema.inputs) > 0: 2806 if input_schema.has_input_names(): 2807 input_names = input_schema.input_names() 2808 return udf(*input_names) 2809 else: 2810 raise MlflowException( 2811 message="Cannot apply udf because no column names specified. The udf " 2812 f"expects {len(input_schema.inputs)} columns with types: " 2813 "{input_schema.inputs}. Input column names could not be inferred from the" 2814 " model signature (column names not found).", 2815 error_code=INVALID_PARAMETER_VALUE, 2816 ) 2817 else: 2818 raise MlflowException( 2819 "Attempting to apply udf on zero columns because no column names were " 2820 "specified as arguments or inferred from the model signature.", 2821 error_code=INVALID_PARAMETER_VALUE, 2822 ) 2823 else: 2824 return udf(*args) 2825 2826 return udf_with_default_cols 2827 2828 2829 def _validate_function_python_model(python_model): 2830 if not (isinstance(python_model, PythonModel) or callable(python_model)): 2831 raise MlflowException( 2832 "`python_model` must be a PythonModel instance, callable object, or path to a script " 2833 "that uses set_model() to set a PythonModel instance or callable object.", 2834 error_code=INVALID_PARAMETER_VALUE, 2835 ) 2836 2837 if callable(python_model): 2838 num_args = len(inspect.signature(python_model).parameters) 2839 if num_args != 1: 2840 raise MlflowException( 2841 "When `python_model` is a callable object, it must accept exactly one argument. " 2842 f"Found {num_args} arguments.", 2843 error_code=INVALID_PARAMETER_VALUE, 2844 ) 2845 2846 2847 @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn")) 2848 @trace_disabled # Suppress traces for internal predict calls while saving model 2849 def save_model( 2850 path, 2851 loader_module=None, 2852 data_path=None, 2853 code_paths=None, 2854 infer_code_paths=False, 2855 conda_env=None, 2856 mlflow_model=None, 2857 python_model=None, 2858 artifacts=None, 2859 signature: ModelSignature = None, 2860 input_example: ModelInputExample = None, 2861 pip_requirements=None, 2862 extra_pip_requirements=None, 2863 metadata=None, 2864 model_config=None, 2865 streamable=None, 2866 resources: str | list[Resource] | None = None, 2867 auth_policy: AuthPolicy | None = None, 2868 uv_project_path: str | Path | None = None, 2869 uv_groups: list[str] | None = None, 2870 uv_extras: list[str] | None = None, 2871 **kwargs, 2872 ): 2873 """ 2874 Save a Pyfunc model with custom inference logic and optional data dependencies to a path on the 2875 local filesystem. 2876 2877 For information about the workflows that this method supports, please see :ref:`"workflows for 2878 creating custom pyfunc models" <pyfunc-create-custom-workflows>` and 2879 :ref:`"which workflow is right for my use case?" <pyfunc-create-custom-selecting-workflow>`. 2880 Note that the parameters for the second workflow: ``loader_module``, ``data_path`` and the 2881 parameters for the first workflow: ``python_model``, ``artifacts``, cannot be 2882 specified together. 2883 2884 Args: 2885 path: The path to which to save the Python model. 2886 loader_module: The name of the Python module that is used to load the model 2887 from ``data_path``. This module must define a method with the prototype 2888 ``_load_pyfunc(data_path)``. If not ``None``, this module and its 2889 dependencies must be included in one of the following locations: 2890 2891 - The MLflow library. 2892 - Package(s) listed in the model's Conda environment, specified by 2893 the ``conda_env`` parameter. 2894 - One or more of the files specified by the ``code_paths`` parameter. 2895 2896 data_path: Path to a file or directory containing model data. 2897 code_paths: {{ code_paths_pyfunc }} 2898 infer_code_paths: {{ infer_code_paths }} 2899 conda_env: {{ conda_env }} 2900 mlflow_model: :py:mod:`mlflow.models.Model` configuration to which to add the 2901 **python_function** flavor. 2902 python_model: 2903 A file path to the PythonModel 2904 which defines the model from code artifact, 2905 (recommended), see https://mlflow.org/docs/latest/ml/model/models-from-code/ 2906 for details; 2907 or an instance of a subclass of :class:`~PythonModel` or a callable object with a single 2908 argument (see the examples below), the passed-in object is serialized using the 2909 CloudPickle library, it requires exercising caution because these formats rely on 2910 Python's object serialization mechanism, which can execute arbitrary code during 2911 deserialization. 2912 Any dependencies of the class should be included in one of the 2913 following locations: 2914 2915 - The MLflow library. 2916 - Package(s) listed in the model's Conda environment, specified by the ``conda_env`` 2917 parameter. 2918 - One or more of the files specified by the ``code_paths`` parameter. 2919 2920 Note: If the class is imported from another module, as opposed to being defined in the 2921 ``__main__`` scope, the defining module should also be included in one of the listed 2922 locations. 2923 2924 **Examples** 2925 2926 Class model 2927 2928 .. code-block:: python 2929 2930 from typing import List, Dict 2931 import mlflow 2932 2933 2934 class MyModel(mlflow.pyfunc.PythonModel): 2935 def predict(self, context, model_input: List[str], params=None) -> List[str]: 2936 return [i.upper() for i in model_input] 2937 2938 2939 mlflow.pyfunc.save_model("model", python_model=MyModel(), input_example=["a"]) 2940 model = mlflow.pyfunc.load_model("model") 2941 print(model.predict(["a", "b", "c"])) # -> ["A", "B", "C"] 2942 2943 Functional model 2944 2945 .. note:: 2946 Experimental: Functional model support is experimental and may change or be removed 2947 in a future release without warning. 2948 2949 .. code-block:: python 2950 2951 from typing import List 2952 import mlflow 2953 2954 2955 def predict(model_input: List[str]) -> List[str]: 2956 return [i.upper() for i in model_input] 2957 2958 2959 mlflow.pyfunc.save_model("model", python_model=predict, input_example=["a"]) 2960 model = mlflow.pyfunc.load_model("model") 2961 print(model.predict(["a", "b", "c"])) # -> ["A", "B", "C"] 2962 2963 Model from code 2964 2965 .. note:: 2966 Experimental: Model from code model support is experimental and may change or 2967 be removed in a future release without warning. 2968 2969 .. code-block:: python 2970 2971 # code.py 2972 from typing import List 2973 import mlflow 2974 2975 2976 class MyModel(mlflow.pyfunc.PythonModel): 2977 def predict(self, context, model_input: List[str], params=None) -> List[str]: 2978 return [i.upper() for i in model_input] 2979 2980 2981 mlflow.models.set_model(MyModel()) 2982 2983 # log_model.py 2984 import mlflow 2985 2986 with mlflow.start_run(): 2987 model_info = mlflow.pyfunc.log_model( 2988 name="model", 2989 python_model="code.py", 2990 ) 2991 2992 If the `predict` method or function has type annotations, MLflow automatically 2993 constructs a model signature based on the type annotations (unless the ``signature`` 2994 argument is explicitly specified), and converts the input value to the specified type 2995 before passing it to the function. Currently, the following type annotations are 2996 supported: 2997 2998 - ``List[str]`` 2999 - ``List[Dict[str, str]]`` 3000 3001 artifacts: A dictionary containing ``<name, artifact_uri>`` entries. Remote artifact URIs 3002 are resolved to absolute filesystem paths, producing a dictionary of 3003 ``<name, absolute_path>`` entries. ``python_model`` can reference these 3004 resolved entries as the ``artifacts`` property of the ``context`` parameter 3005 in :func:`PythonModel.load_context() <mlflow.pyfunc.PythonModel.load_context>` 3006 and :func:`PythonModel.predict() <mlflow.pyfunc.PythonModel.predict>`. 3007 For example, consider the following ``artifacts`` dictionary:: 3008 3009 {"my_file": "s3://my-bucket/path/to/my/file"} 3010 3011 In this case, the ``"my_file"`` artifact is downloaded from S3. The 3012 ``python_model`` can then refer to ``"my_file"`` as an absolute filesystem 3013 path via ``context.artifacts["my_file"]``. 3014 3015 If ``None``, no artifacts are added to the model. 3016 3017 signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>` 3018 describes model input and output :py:class:`Schema <mlflow.types.Schema>`. 3019 The model signature can be :py:func:`inferred <mlflow.models.infer_signature>` 3020 from datasets with valid model input (e.g. the training dataset with target 3021 column omitted) and valid model output (e.g. model predictions generated on 3022 the training dataset), for example: 3023 3024 .. code-block:: python 3025 3026 from mlflow.models import infer_signature 3027 3028 train = df.drop_column("target_label") 3029 predictions = ... # compute model predictions 3030 signature = infer_signature(train, predictions) 3031 input_example: {{ input_example }} 3032 pip_requirements: {{ pip_requirements }} 3033 extra_pip_requirements: {{ extra_pip_requirements }} 3034 metadata: {{ metadata }} 3035 model_config: The model configuration to apply to the model. The configuration will 3036 be available as the ``model_config`` property of the ``context`` parameter 3037 in :func:`PythonModel.load_context() <mlflow.pyfunc.PythonModel.load_context>` 3038 and :func:`PythonModel.predict() <mlflow.pyfunc.PythonModel.predict>`. 3039 The configuration can be passed as a file path, or a dict with string keys. 3040 3041 .. Note:: Experimental: This parameter may change or be removed in a future 3042 release without warning. 3043 streamable: A boolean value indicating if the model supports streaming prediction, 3044 If None, MLflow will try to inspect if the model supports streaming 3045 by checking if `predict_stream` method exists. Default None. 3046 resources: A list of model resources or a resources.yaml file containing a list of 3047 resources required to serve the model. 3048 3049 .. Note:: Experimental: This parameter may change or be removed in a future 3050 release without warning. 3051 auth_policy: {{ auth_policy }} 3052 uv_project_path: Explicit path to the uv project directory containing uv.lock, 3053 pyproject.toml, and optionally .python-version. This is useful for monorepos 3054 or non-standard project layouts where the uv project is not in the current 3055 working directory. If ``None``, MLflow will auto-detect uv.lock, pyproject.toml, 3056 and .python-version files in the current working directory. 3057 3058 When a uv project is detected (either via this parameter or auto-detection), 3059 pip requirements are generated by running ``uv export`` against the lockfile 3060 instead of inferring dependencies by capturing imported packages during model 3061 inference. 3062 3063 Auto-detection can be disabled by setting the environment variable 3064 ``MLFLOW_UV_AUTO_DETECT=false``. 3065 3066 .. Note:: Experimental: This parameter may change or be removed in a future 3067 release without warning. 3068 uv_groups: Optional list of uv dependency groups to include when exporting 3069 requirements from the uv lockfile. Maps to ``uv export --group <name>``. 3070 These are additive with the project's default dependencies. 3071 3072 .. Note:: Experimental: This parameter may change or be removed in a future 3073 release without warning. 3074 uv_extras: Optional list of uv extras (optional dependency sets) to include 3075 when exporting requirements from the uv lockfile. Maps to 3076 ``uv export --extra <name>``. 3077 3078 .. Note:: Experimental: This parameter may change or be removed in a future 3079 release without warning. 3080 kwargs: Extra keyword arguments. 3081 """ 3082 if ( 3083 python_model is not None 3084 and not isinstance(python_model, (Path, str)) 3085 and not is_in_databricks_runtime() 3086 ): 3087 _logger.warning( 3088 "Passing a Python object as `python_model` causes it to be serialized " 3089 "using CloudPickle, " 3090 "it requires exercising caution as Python object serialization mechanisms may " 3091 "execute arbitrary code during deserialization." 3092 "Consider using a file path (str or Path) instead. See " 3093 "https://mlflow.org/docs/latest/ml/model/models-from-code/ for details." 3094 ) 3095 3096 _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements) 3097 _validate_pyfunc_model_config(model_config) 3098 _validate_and_prepare_target_save_path(path) 3099 3100 with tempfile.TemporaryDirectory() as temp_dir: 3101 model_code_path = None 3102 if python_model: 3103 if isinstance(model_config, Path): 3104 model_config = os.fspath(model_config) 3105 3106 if isinstance(model_config, str): 3107 model_config = _validate_and_get_model_config_from_file(model_config) 3108 3109 if isinstance(python_model, Path): 3110 python_model = os.fspath(python_model) 3111 3112 if isinstance(python_model, str): 3113 model_code_path = _validate_and_get_model_code_path(python_model, temp_dir) 3114 _validate_and_copy_file_to_directory(model_code_path, path, "code") 3115 python_model = _load_model_code_path(model_code_path, model_config) 3116 3117 _validate_function_python_model(python_model) 3118 if callable(python_model) and all( 3119 a is None for a in (input_example, pip_requirements, extra_pip_requirements) 3120 ): 3121 raise MlflowException( 3122 "If `python_model` is a callable object, at least one of `input_example`, " 3123 "`pip_requirements`, or `extra_pip_requirements` must be specified." 3124 ) 3125 3126 mlflow_model = kwargs.pop("model", mlflow_model) 3127 if len(kwargs) > 0: 3128 raise TypeError(f"save_model() got unexpected keyword arguments: {kwargs}") 3129 3130 if code_paths is not None: 3131 if not isinstance(code_paths, list): 3132 raise TypeError(f"Argument code_paths should be a list, not {type(code_paths)}") 3133 3134 first_argument_set = { 3135 "loader_module": loader_module, 3136 "data_path": data_path, 3137 } 3138 second_argument_set = { 3139 "artifacts": artifacts, 3140 "python_model": python_model, 3141 } 3142 first_argument_set_specified = any(item is not None for item in first_argument_set.values()) 3143 second_argument_set_specified = any(item is not None for item in second_argument_set.values()) 3144 if first_argument_set_specified and second_argument_set_specified: 3145 raise MlflowException( 3146 message=( 3147 f"The following sets of parameters cannot be specified together:" 3148 f" {first_argument_set.keys()} and {second_argument_set.keys()}." 3149 " All parameters in one set must be `None`. Instead, found" 3150 f" the following values: {first_argument_set} and {second_argument_set}" 3151 ), 3152 error_code=INVALID_PARAMETER_VALUE, 3153 ) 3154 elif (loader_module is None) and (python_model is None): 3155 msg = ( 3156 "Either `loader_module` or `python_model` must be specified. A `loader_module` " 3157 "should be a python module. A `python_model` should be a subclass of PythonModel" 3158 ) 3159 raise MlflowException(message=msg, error_code=INVALID_PARAMETER_VALUE) 3160 if mlflow_model is None: 3161 mlflow_model = Model() 3162 saved_example = None 3163 signature_from_type_hints = None 3164 type_hint_from_example = None 3165 if isinstance(python_model, ChatModel): 3166 if signature is not None: 3167 raise MlflowException( 3168 "ChatModel subclasses have a standard signature that is set " 3169 "automatically. Please remove the `signature` parameter from " 3170 "the call to log_model() or save_model().", 3171 error_code=INVALID_PARAMETER_VALUE, 3172 ) 3173 mlflow_model.signature = ModelSignature( 3174 CHAT_MODEL_INPUT_SCHEMA, 3175 CHAT_MODEL_OUTPUT_SCHEMA, 3176 ) 3177 # For ChatModel we set default metadata to indicate its task 3178 default_metadata = {TASK: _DEFAULT_CHAT_MODEL_METADATA_TASK} 3179 mlflow_model.metadata = default_metadata | (mlflow_model.metadata or {}) 3180 3181 if input_example: 3182 input_example, input_params = _split_input_data_and_params(input_example) 3183 valid_params = {} 3184 if isinstance(input_example, list): 3185 messages = [ 3186 message if isinstance(message, ChatMessage) else ChatMessage.from_dict(message) 3187 for message in input_example 3188 ] 3189 else: 3190 # If the input example is a dictionary, convert it to ChatMessage format 3191 messages = [ 3192 ChatMessage.from_dict(m) if isinstance(m, dict) else m 3193 for m in input_example["messages"] 3194 ] 3195 valid_params = { 3196 k: v 3197 for k, v in input_example.items() 3198 if k != "messages" and k in ChatParams.keys() 3199 } 3200 if valid_params or input_params: 3201 _logger.warning(_CHAT_PARAMS_WARNING_MESSAGE) 3202 input_example = { 3203 "messages": [m.to_dict() for m in messages], 3204 **valid_params, 3205 **(input_params or {}), 3206 } 3207 else: 3208 input_example = CHAT_MODEL_INPUT_EXAMPLE 3209 _logger.warning(_CHAT_PARAMS_WARNING_MESSAGE) 3210 messages = [ChatMessage.from_dict(m) for m in input_example["messages"]] 3211 # extra params introduced by ChatParams will not be included in the 3212 # logged input example file to avoid confusion 3213 _save_example(mlflow_model, input_example, path) 3214 params = ChatParams.from_dict(input_example) 3215 3216 # call load_context() first, as predict may depend on it 3217 _logger.info("Predicting on input example to validate output") 3218 context = PythonModelContext(artifacts, model_config) 3219 python_model.load_context(context) 3220 if "context" in inspect.signature(python_model.predict).parameters: 3221 output = python_model.predict(context, messages, params) 3222 else: 3223 output = python_model.predict(messages, params) 3224 if not isinstance(output, ChatCompletionResponse): 3225 raise MlflowException( 3226 "Failed to save ChatModel. Please ensure that the model's predict() method " 3227 "returns a ChatCompletionResponse object. If your predict() method currently " 3228 "returns a dict, you can instantiate a ChatCompletionResponse using " 3229 "`from_dict()`, e.g. `ChatCompletionResponse.from_dict(output)`", 3230 ) 3231 elif isinstance(python_model, ChatAgent): 3232 input_example = _save_model_chat_agent_helper( 3233 python_model, mlflow_model, signature, input_example, artifacts, model_config 3234 ) 3235 elif IS_RESPONSES_AGENT_AVAILABLE and isinstance(python_model, ResponsesAgent): 3236 input_example = _save_model_responses_agent_helper( 3237 python_model, mlflow_model, signature, input_example, artifacts, model_config 3238 ) 3239 elif callable(python_model) or isinstance(python_model, PythonModel): 3240 model_for_signature_inference = None 3241 if callable(python_model): 3242 # first argument is the model input 3243 type_hints = _extract_type_hints(python_model, input_arg_index=0) 3244 pyfunc_decorator_used = getattr(python_model, "_is_pyfunc", False) 3245 # only show the warning here if @pyfunc is not applied on the function 3246 # since @pyfunc will trigger the warning instead 3247 if type_hints.input is None and not pyfunc_decorator_used: 3248 color_warning( 3249 "Add type hints to the `predict` method to enable " 3250 "data validation and automatic signature inference. Check " 3251 "https://mlflow.org/docs/latest/model/python_model.html#type-hint-usage-in-pythonmodel" 3252 " for more details.", 3253 stacklevel=1, 3254 color="yellow", 3255 ) 3256 model_for_signature_inference = _FunctionPythonModel(python_model) 3257 elif isinstance(python_model, PythonModel): 3258 type_hints = python_model.predict_type_hints 3259 model_for_signature_inference = python_model 3260 context = PythonModelContext(artifacts, model_config) 3261 type_hint_from_example = _is_type_hint_from_example(type_hints.input) 3262 if type_hint_from_example: 3263 should_infer_signature_from_type_hints = False 3264 else: 3265 should_infer_signature_from_type_hints = ( 3266 not _signature_cannot_be_inferred_from_type_hint(type_hints.input) 3267 ) 3268 if should_infer_signature_from_type_hints: 3269 # context is only loaded when input_example exists 3270 signature_from_type_hints = _infer_signature_from_type_hints( 3271 python_model=python_model, 3272 context=context, 3273 type_hints=type_hints, 3274 input_example=input_example, 3275 ) 3276 # only infer signature based on input example when signature 3277 # and type hints are not provided 3278 if signature is None and signature_from_type_hints is None: 3279 saved_example = _save_example(mlflow_model, input_example, path) 3280 if saved_example is not None: 3281 _logger.info("Inferring model signature from input example") 3282 try: 3283 model_for_signature_inference.load_context(context) 3284 mlflow_model.signature = _infer_signature_from_input_example( 3285 saved_example, 3286 _PythonModelPyfuncWrapper(model_for_signature_inference, context, None), 3287 ) 3288 except Exception as e: 3289 _logger.warning( 3290 f"Failed to infer model signature from input example, error: {e}", 3291 ) 3292 else: 3293 if type_hint_from_example and mlflow_model.signature: 3294 update_signature_for_type_hint_from_example( 3295 input_example, mlflow_model.signature 3296 ) 3297 else: 3298 if type_hint_from_example: 3299 _logger.warning( 3300 _TYPE_FROM_EXAMPLE_ERROR_MESSAGE, 3301 extra={"color": "red"}, 3302 ) 3303 # if signature is inferred from type hints, warnings are emitted 3304 # in _infer_signature_from_type_hints 3305 elif not should_infer_signature_from_type_hints: 3306 _logger.warning( 3307 "Failed to infer model signature: " 3308 f"Type hint {type_hints} cannot be used to infer model signature and " 3309 "input example is not provided, model signature cannot be inferred." 3310 ) 3311 3312 if metadata is not None: 3313 mlflow_model.metadata = metadata 3314 if saved_example is None: 3315 saved_example = _save_example(mlflow_model, input_example, path) 3316 3317 if signature_from_type_hints: 3318 if signature and signature_from_type_hints != signature: 3319 # TODO: drop this support and raise exception in the next minor release since this 3320 # is a behavior change 3321 _logger.warning( 3322 "Provided signature does not match the signature inferred from the Python model's " 3323 "`predict` function type hint. Signature inferred from type hint will be used:\n" 3324 f"{signature_from_type_hints}\nRemove the `signature` parameter or ensure it " 3325 "matches the inferred signature. In a future release, this warning will become an " 3326 "exception, and the signature must align with the type hint.", 3327 extra={"color": "red"}, 3328 ) 3329 mlflow_model.signature = signature_from_type_hints 3330 elif signature: 3331 mlflow_model.signature = signature 3332 if type_hint_from_example: 3333 if saved_example is None: 3334 _logger.warning( 3335 _TYPE_FROM_EXAMPLE_ERROR_MESSAGE, 3336 extra={"color": "red"}, 3337 ) 3338 else: 3339 # TODO: validate input example against signature 3340 update_signature_for_type_hint_from_example(input_example, mlflow_model.signature) 3341 else: 3342 if saved_example is None: 3343 color_warning( 3344 message="An input example was not provided when logging the model. To ensure " 3345 "the model signature functions correctly, specify the `input_example` " 3346 "parameter. See " 3347 "https://mlflow.org/docs/latest/model/signatures.html#model-input-example " 3348 "for more details about the benefits of using input_example.", 3349 stacklevel=1, 3350 color="yellow_bold", 3351 ) 3352 else: 3353 _logger.info("Validating input example against model signature") 3354 try: 3355 _validate_prediction_input( 3356 data=saved_example.inference_data, 3357 params=saved_example.inference_params, 3358 input_schema=signature.inputs, 3359 params_schema=signature.params, 3360 ) 3361 except Exception as e: 3362 raise MlflowException.invalid_parameter_value( 3363 f"Input example does not match the model signature. {e}" 3364 ) 3365 3366 with _get_dependencies_schemas() as dependencies_schemas: 3367 schema = dependencies_schemas.to_dict() 3368 if schema is not None: 3369 if mlflow_model.metadata is None: 3370 mlflow_model.metadata = {} 3371 mlflow_model.metadata.update(schema) 3372 3373 if resources is not None: 3374 if isinstance(resources, (Path, str)): 3375 serialized_resource = _ResourceBuilder.from_yaml_file(resources) 3376 else: 3377 serialized_resource = _ResourceBuilder.from_resources(resources) 3378 3379 mlflow_model.resources = serialized_resource 3380 3381 if auth_policy is not None: 3382 mlflow_model.auth_policy = auth_policy 3383 3384 if first_argument_set_specified: 3385 return _save_model_with_loader_module_and_data_path( 3386 path=path, 3387 loader_module=loader_module, 3388 data_path=data_path, 3389 code_paths=code_paths, 3390 conda_env=conda_env, 3391 mlflow_model=mlflow_model, 3392 pip_requirements=pip_requirements, 3393 extra_pip_requirements=extra_pip_requirements, 3394 model_config=model_config, 3395 streamable=streamable, 3396 infer_code_paths=infer_code_paths, 3397 uv_project_path=uv_project_path, 3398 uv_groups=uv_groups, 3399 uv_extras=uv_extras, 3400 ) 3401 elif second_argument_set_specified: 3402 return mlflow.pyfunc.model._save_model_with_class_artifacts_params( 3403 path=path, 3404 signature=signature, 3405 python_model=python_model, 3406 artifacts=artifacts, 3407 conda_env=conda_env, 3408 code_paths=code_paths, 3409 mlflow_model=mlflow_model, 3410 pip_requirements=pip_requirements, 3411 extra_pip_requirements=extra_pip_requirements, 3412 model_config=model_config, 3413 streamable=streamable, 3414 model_code_path=model_code_path, 3415 infer_code_paths=infer_code_paths, 3416 uv_project_path=uv_project_path, 3417 uv_groups=uv_groups, 3418 uv_extras=uv_extras, 3419 ) 3420 3421 3422 def update_signature_for_type_hint_from_example(input_example: Any, signature: ModelSignature): 3423 if _is_example_valid_for_type_from_example(input_example): 3424 signature._is_type_hint_from_example = True 3425 else: 3426 _logger.warning( 3427 "Input example must be one of pandas.DataFrame, pandas.Series " 3428 f"or list when using TypeFromExample as type hint, got {type(input_example)}. " 3429 "Check https://mlflow.org/docs/latest/model/python_model.html#typefromexample-type-hint-usage" 3430 " for more details.", 3431 ) 3432 3433 3434 @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn")) 3435 @trace_disabled # Suppress traces for internal predict calls while logging model 3436 def log_model( 3437 artifact_path=None, 3438 loader_module=None, 3439 data_path=None, 3440 code_paths=None, 3441 infer_code_paths=False, 3442 conda_env=None, 3443 python_model=None, 3444 artifacts=None, 3445 registered_model_name=None, 3446 signature: ModelSignature = None, 3447 input_example: ModelInputExample = None, 3448 await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, 3449 pip_requirements=None, 3450 extra_pip_requirements=None, 3451 metadata=None, 3452 model_config=None, 3453 streamable=None, 3454 resources: str | list[Resource] | None = None, 3455 auth_policy: AuthPolicy | None = None, 3456 uv_project_path: str | Path | None = None, 3457 uv_groups: list[str] | None = None, 3458 uv_extras: list[str] | None = None, 3459 prompts: list[str | Prompt] | None = None, 3460 name=None, 3461 params: dict[str, Any] | None = None, 3462 tags: dict[str, Any] | None = None, 3463 model_type: str | None = None, 3464 step: int = 0, 3465 model_id: str | None = None, 3466 ): 3467 """ 3468 Log a Pyfunc model with custom inference logic and optional data dependencies as an MLflow 3469 artifact for the current run. 3470 3471 For information about the workflows that this method supports, see :ref:`Workflows for 3472 creating custom pyfunc models <pyfunc-create-custom-workflows>` and 3473 :ref:`Which workflow is right for my use case? <pyfunc-create-custom-selecting-workflow>`. 3474 You cannot specify the parameters for the second workflow: ``loader_module``, ``data_path`` 3475 and the parameters for the first workflow: ``python_model``, ``artifacts`` together. 3476 3477 Args: 3478 artifact_path: Deprecated. Use `name` instead. 3479 loader_module: The name of the Python module that is used to load the model 3480 from ``data_path``. This module must define a method with the prototype 3481 ``_load_pyfunc(data_path)``. If not ``None``, this module and its 3482 dependencies must be included in one of the following locations: 3483 3484 - The MLflow library. 3485 - Package(s) listed in the model's Conda environment, specified by 3486 the ``conda_env`` parameter. 3487 - One or more of the files specified by the ``code_paths`` parameter. 3488 3489 data_path: Path to a file or directory containing model data. 3490 code_paths: {{ code_paths_pyfunc }} 3491 infer_code_paths: {{ infer_code_paths }} 3492 conda_env: {{ conda_env }} 3493 python_model: 3494 A file path to the PythonModel 3495 which defines the model from code artifact, 3496 (recommended), see https://mlflow.org/docs/latest/ml/model/models-from-code/ 3497 for details; 3498 or an instance of a subclass of :class:`~PythonModel` or a callable object with a single 3499 argument (see the examples below), the passed-in object is serialized using the 3500 CloudPickle library, it requires exercising caution because these formats rely on 3501 Python's object serialization mechanism, which can execute arbitrary code during 3502 deserialization. 3503 Any dependencies of the class should be included in one of the 3504 following locations: 3505 3506 - The MLflow library. 3507 - Package(s) listed in the model's Conda environment, specified by the ``conda_env`` 3508 parameter. 3509 - One or more of the files specified by the ``code_paths`` parameter. 3510 3511 Note: If the class is imported from another module, as opposed to being defined in the 3512 ``__main__`` scope, the defining module should also be included in one of the listed 3513 locations. 3514 3515 **Examples** 3516 3517 Class model 3518 3519 .. code-block:: python 3520 3521 from typing import List 3522 import mlflow 3523 3524 3525 class MyModel(mlflow.pyfunc.PythonModel): 3526 def predict(self, context, model_input: List[str], params=None) -> List[str]: 3527 return [i.upper() for i in model_input] 3528 3529 3530 with mlflow.start_run(): 3531 model_info = mlflow.pyfunc.log_model( 3532 name="model", 3533 python_model=MyModel(), 3534 ) 3535 3536 loaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri) 3537 print(loaded_model.predict(["a", "b", "c"])) # -> ["A", "B", "C"] 3538 3539 Functional model 3540 3541 .. note:: 3542 Experimental: Functional model support is experimental and may change or be removed 3543 in a future release without warning. 3544 3545 .. code-block:: python 3546 3547 from typing import List 3548 import mlflow 3549 3550 3551 def predict(model_input: List[str]) -> List[str]: 3552 return [i.upper() for i in model_input] 3553 3554 3555 with mlflow.start_run(): 3556 model_info = mlflow.pyfunc.log_model( 3557 name="model", python_model=predict, input_example=["a"] 3558 ) 3559 3560 loaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri) 3561 print(loaded_model.predict(["a", "b", "c"])) # -> ["A", "B", "C"] 3562 3563 Model from code 3564 3565 .. note:: 3566 Experimental: Model from code model support is experimental and may change or 3567 be removed in a future release without warning. 3568 3569 .. code-block:: python 3570 3571 # code.py 3572 from typing import List 3573 import mlflow 3574 3575 3576 class MyModel(mlflow.pyfunc.PythonModel): 3577 def predict(self, context, model_input: List[str], params=None) -> List[str]: 3578 return [i.upper() for i in model_input] 3579 3580 3581 mlflow.models.set_model(MyModel()) 3582 3583 # log_model.py 3584 import mlflow 3585 3586 with mlflow.start_run(): 3587 model_info = mlflow.pyfunc.log_model( 3588 name="model", 3589 python_model="code.py", 3590 ) 3591 3592 If the `predict` method or function has type annotations, MLflow automatically 3593 constructs a model signature based on the type annotations (unless the ``signature`` 3594 argument is explicitly specified), and converts the input value to the specified type 3595 before passing it to the function. Currently, the following type annotations are 3596 supported: 3597 3598 - ``List[str]`` 3599 - ``List[Dict[str, str]]`` 3600 3601 artifacts: A dictionary containing ``<name, artifact_uri>`` entries. Remote artifact URIs 3602 are resolved to absolute filesystem paths, producing a dictionary of 3603 ``<name, absolute_path>`` entries. ``python_model`` can reference these 3604 resolved entries as the ``artifacts`` property of the ``context`` parameter 3605 in :func:`PythonModel.load_context() <mlflow.pyfunc.PythonModel.load_context>` 3606 and :func:`PythonModel.predict() <mlflow.pyfunc.PythonModel.predict>`. 3607 For example, consider the following ``artifacts`` dictionary:: 3608 3609 {"my_file": "s3://my-bucket/path/to/my/file"} 3610 3611 In this case, the ``"my_file"`` artifact is downloaded from S3. The 3612 ``python_model`` can then refer to ``"my_file"`` as an absolute filesystem 3613 path via ``context.artifacts["my_file"]``. 3614 3615 If ``None``, no artifacts are added to the model. 3616 registered_model_name: If given, create a model 3617 version under ``registered_model_name``, also creating a 3618 registered model if one with the given name does not exist. 3619 3620 signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>` 3621 describes model input and output :py:class:`Schema <mlflow.types.Schema>`. 3622 The model signature can be :py:func:`inferred <mlflow.models.infer_signature>` 3623 from datasets with valid model input (e.g. the training dataset with target 3624 column omitted) and valid model output (e.g. model predictions generated on 3625 the training dataset), for example: 3626 3627 .. code-block:: python 3628 3629 from mlflow.models import infer_signature 3630 3631 train = df.drop_column("target_label") 3632 predictions = ... # compute model predictions 3633 signature = infer_signature(train, predictions) 3634 3635 input_example: {{ input_example }} 3636 await_registration_for: Number of seconds to wait for the model version to finish 3637 being created and is in ``READY`` status. By default, the function 3638 waits for five minutes. Specify 0 or None to skip waiting. 3639 pip_requirements: {{ pip_requirements }} 3640 extra_pip_requirements: {{ extra_pip_requirements }} 3641 metadata: {{ metadata }} 3642 model_config: The model configuration to apply to the model. The configuration will 3643 be available as the ``model_config`` property of the ``context`` parameter 3644 in :func:`PythonModel.load_context() <mlflow.pyfunc.PythonModel.load_context>` 3645 and :func:`PythonModel.predict() <mlflow.pyfunc.PythonModel.predict>`. 3646 The configuration can be passed as a file path, or a dict with string keys. 3647 3648 .. Note:: Experimental: This parameter may change or be removed in a future 3649 release without warning. 3650 streamable: A boolean value indicating if the model supports streaming prediction, 3651 If None, MLflow will try to inspect if the model supports streaming 3652 by checking if `predict_stream` method exists. Default None. 3653 resources: A list of model resources or a resources.yaml file containing a list of 3654 resources required to serve the model. 3655 3656 .. Note:: Experimental: This parameter may change or be removed in a future 3657 release without warning. 3658 auth_policy: {{ auth_policy }} 3659 uv_project_path: Explicit path to the uv project directory containing uv.lock, 3660 pyproject.toml, and optionally .python-version. This is useful for monorepos 3661 or non-standard project layouts where the uv project is not in the current 3662 working directory. If ``None``, MLflow will auto-detect uv.lock, pyproject.toml, 3663 and .python-version files in the current working directory. 3664 3665 When a uv project is detected (either via this parameter or auto-detection), 3666 pip requirements are generated by running ``uv export`` against the lockfile 3667 instead of inferring dependencies by capturing imported packages during model 3668 inference. 3669 3670 Auto-detection can be disabled by setting the environment variable 3671 ``MLFLOW_UV_AUTO_DETECT=false``. 3672 3673 .. Note:: Experimental: This parameter may change or be removed in a future 3674 release without warning. 3675 uv_groups: Optional list of uv dependency groups to include when exporting 3676 requirements from the uv lockfile. Maps to ``uv export --group <name>``. 3677 These are additive with the project's default dependencies. 3678 3679 .. Note:: Experimental: This parameter may change or be removed in a future 3680 release without warning. 3681 uv_extras: Optional list of uv extras (optional dependency sets) to include 3682 when exporting requirements from the uv lockfile. Maps to 3683 ``uv export --extra <name>``. 3684 3685 .. Note:: Experimental: This parameter may change or be removed in a future 3686 release without warning. 3687 prompts: {{ prompts }} 3688 name: {{ name }} 3689 params: {{ params }} 3690 tags: {{ tags }} 3691 model_type: {{ model_type }} 3692 step: {{ step }} 3693 model_id: {{ model_id }} 3694 3695 Returns: 3696 A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the 3697 metadata of the logged model. 3698 """ 3699 flavor_name = _get_pyfunc_model_flavor_name(python_model) 3700 return Model.log( 3701 artifact_path=artifact_path, 3702 name=name, 3703 flavor=mlflow.pyfunc, 3704 loader_module=loader_module, 3705 data_path=data_path, 3706 code_paths=code_paths, 3707 python_model=python_model, 3708 artifacts=artifacts, 3709 conda_env=conda_env, 3710 registered_model_name=registered_model_name, 3711 signature=signature, 3712 input_example=input_example, 3713 await_registration_for=await_registration_for, 3714 pip_requirements=pip_requirements, 3715 extra_pip_requirements=extra_pip_requirements, 3716 metadata=metadata, 3717 prompts=prompts, 3718 model_config=model_config, 3719 streamable=streamable, 3720 resources=resources, 3721 infer_code_paths=infer_code_paths, 3722 auth_policy=auth_policy, 3723 uv_project_path=uv_project_path, 3724 uv_groups=uv_groups, 3725 uv_extras=uv_extras, 3726 params=params, 3727 tags=tags, 3728 model_type=model_type, 3729 step=step, 3730 model_id=model_id, 3731 # only used for checking python model type 3732 flavor_name=flavor_name, 3733 ) 3734 3735 3736 def _get_pyfunc_model_flavor_name(python_model: Any) -> str: 3737 if python_model is None: 3738 return "pyfunc" 3739 if isinstance(python_model, str): 3740 return "pyfunc.ModelFromCode" 3741 if IS_RESPONSES_AGENT_AVAILABLE and isinstance(python_model, ResponsesAgent): 3742 return "pyfunc.ResponsesAgent" 3743 if isinstance(python_model, ChatAgent): 3744 return "pyfunc.ChatAgent" 3745 if isinstance(python_model, ChatModel): 3746 return "pyfunc.ChatModel" 3747 if isinstance(python_model, PythonModel): 3748 return "pyfunc.CustomPythonModel" 3749 return "pyfunc" 3750 3751 3752 def _save_model_with_loader_module_and_data_path( 3753 path, 3754 loader_module, 3755 data_path=None, 3756 code_paths=None, 3757 conda_env=None, 3758 mlflow_model=None, 3759 pip_requirements=None, 3760 extra_pip_requirements=None, 3761 model_config=None, 3762 streamable=None, 3763 infer_code_paths=False, 3764 uv_project_path=None, 3765 uv_groups=None, 3766 uv_extras=None, 3767 ): 3768 """ 3769 Export model as a generic Python function model. 3770 3771 Args: 3772 path: The path to which to save the Python model. 3773 loader_module: The name of the Python module that is used to load the model 3774 from ``data_path``. This module must define a method with the prototype 3775 ``_load_pyfunc(data_path)``. 3776 data_path: Path to a file or directory containing model data. 3777 code_paths: A list of local filesystem paths to Python file dependencies (or directories 3778 containing file dependencies). These files are *prepended* to the system 3779 path before the model is loaded. 3780 conda_env: Either a dictionary representation of a Conda environment or the path to a 3781 Conda environment yaml file. If provided, this describes the environment 3782 this model should be run in. 3783 streamable: A boolean value indicating if the model supports streaming prediction, 3784 None value also means not streamable. 3785 3786 Returns: 3787 Model configuration containing model info. 3788 """ 3789 # Capture original working directory for uv project detection 3790 # This must be done before any operations that might change cwd 3791 original_cwd = Path.cwd() 3792 3793 data = None 3794 3795 if data_path is not None: 3796 model_file = _copy_file_or_tree(src=data_path, dst=path, dst_dir="data") 3797 data = model_file 3798 3799 if mlflow_model is None: 3800 mlflow_model = Model() 3801 3802 streamable = streamable or False 3803 mlflow.pyfunc.add_to_model( 3804 mlflow_model, 3805 loader_module=loader_module, 3806 code=None, 3807 data=data, 3808 conda_env=_CONDA_ENV_FILE_NAME, 3809 python_env=_PYTHON_ENV_FILE_NAME, 3810 model_config=model_config, 3811 streamable=streamable, 3812 ) 3813 if size := get_total_file_size(path): 3814 mlflow_model.model_size_bytes = size 3815 mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME)) 3816 3817 code_dir_subpath = _validate_infer_and_copy_code_paths( 3818 code_paths, path, infer_code_paths, FLAVOR_NAME 3819 ) 3820 mlflow_model.flavors[FLAVOR_NAME][CODE] = code_dir_subpath 3821 3822 # `mlflow_model.code` is updated, re-generate `MLmodel` file. 3823 mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME)) 3824 3825 if uv_project_path is not None: 3826 uv_source_dir = uv_project_path 3827 elif MLFLOW_UV_AUTO_DETECT.get(): 3828 uv_source_dir = original_cwd 3829 else: 3830 uv_source_dir = None 3831 3832 if conda_env is None: 3833 if pip_requirements is None: 3834 default_reqs = get_default_pip_requirements() 3835 extra_env_vars = ( 3836 _get_databricks_serverless_env_vars() 3837 if is_in_databricks_serverless_runtime() 3838 else None 3839 ) 3840 # To ensure `_load_pyfunc` can successfully load the model during the dependency 3841 # inference, `mlflow_model.save` must be called beforehand to save an MLmodel file. 3842 inferred_reqs = mlflow.models.infer_pip_requirements( 3843 path, 3844 FLAVOR_NAME, 3845 fallback=default_reqs, 3846 extra_env_vars=extra_env_vars, 3847 uv_project_dir=uv_source_dir, 3848 uv_groups=uv_groups, 3849 uv_extras=uv_extras, 3850 ) 3851 default_reqs = sorted(set(inferred_reqs).union(default_reqs)) 3852 else: 3853 default_reqs = None 3854 conda_env, pip_requirements, pip_constraints = _process_pip_requirements( 3855 default_reqs, 3856 pip_requirements, 3857 extra_pip_requirements, 3858 ) 3859 else: 3860 conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env) 3861 3862 with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f: 3863 yaml.safe_dump(conda_env, stream=f, default_flow_style=False) 3864 3865 # Save `constraints.txt` if necessary 3866 if pip_constraints: 3867 write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints)) 3868 3869 # Save `requirements.txt` 3870 write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements)) 3871 3872 # Copy uv project files (uv.lock and pyproject.toml) if detected 3873 if uv_source_dir is not None: 3874 copy_uv_project_files(dest_dir=path, source_dir=uv_source_dir) 3875 3876 _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME)) 3877 return mlflow_model 3878 3879 3880 def _save_model_chat_agent_helper( 3881 python_model, mlflow_model, signature, input_example, artifacts, model_config 3882 ): 3883 """Helper method for save_model for ChatAgent models 3884 3885 Returns: a dict input_example 3886 """ 3887 if signature is not None: 3888 raise MlflowException( 3889 "ChatAgent subclasses have a standard signature that is set " 3890 "automatically. Please remove the `signature` parameter from " 3891 "the call to log_model() or save_model().", 3892 error_code=INVALID_PARAMETER_VALUE, 3893 ) 3894 mlflow_model.signature = ModelSignature( 3895 inputs=CHAT_AGENT_INPUT_SCHEMA, 3896 outputs=CHAT_AGENT_OUTPUT_SCHEMA, 3897 ) 3898 # For ChatAgent we set default metadata to indicate its task 3899 default_metadata = {TASK: _DEFAULT_CHAT_AGENT_METADATA_TASK} 3900 mlflow_model.metadata = default_metadata | (mlflow_model.metadata or {}) 3901 3902 # We accept a dict with ChatAgentRequest schema 3903 if input_example: 3904 try: 3905 model_validate(ChatAgentRequest, input_example) 3906 except pydantic.ValidationError as e: 3907 raise MlflowException( 3908 message=( 3909 f"Invalid input example. Expected a ChatAgentRequest object or dictionary with" 3910 f" its schema. Pydantic validation error: {e}" 3911 ), 3912 error_code=INTERNAL_ERROR, 3913 ) from e 3914 if isinstance(input_example, ChatAgentRequest): 3915 input_example = input_example.model_dump(exclude_none=True) 3916 else: 3917 input_example = CHAT_AGENT_INPUT_EXAMPLE 3918 3919 _logger.info("Predicting on input example to validate output") 3920 context = PythonModelContext(artifacts, model_config) 3921 python_model.load_context(context) 3922 request = ChatAgentRequest(**input_example) 3923 output = python_model.predict(request.messages, request.context, request.custom_inputs) 3924 try: 3925 model_validate(ChatAgentResponse, output) 3926 except Exception as e: 3927 raise MlflowException( 3928 "Failed to save ChatAgent. Ensure your model's predict() method returns a " 3929 "ChatAgentResponse object or a dict with the same schema." 3930 f"Pydantic validation error: {e}" 3931 ) from e 3932 return input_example 3933 3934 3935 def _save_model_responses_agent_helper( 3936 python_model, mlflow_model, signature, input_example, artifacts, model_config 3937 ): 3938 """Helper method for save_model for ResponsesAgent models 3939 3940 Returns: a dictionary input example 3941 """ 3942 from mlflow.types.responses import ( 3943 RESPONSES_AGENT_INPUT_EXAMPLE, 3944 RESPONSES_AGENT_INPUT_SCHEMA, 3945 RESPONSES_AGENT_OUTPUT_SCHEMA, 3946 ResponsesAgentRequest, 3947 ResponsesAgentResponse, 3948 ) 3949 3950 if signature is not None: 3951 raise MlflowException( 3952 "ResponsesAgent subclasses have a standard signature that is set " 3953 "automatically. Please remove the `signature` parameter from " 3954 "the call to log_model() or save_model().", 3955 error_code=INVALID_PARAMETER_VALUE, 3956 ) 3957 mlflow_model.signature = ModelSignature( 3958 inputs=RESPONSES_AGENT_INPUT_SCHEMA, 3959 outputs=RESPONSES_AGENT_OUTPUT_SCHEMA, 3960 ) 3961 3962 # For ResponsesAgent we set default metadata to indicate its task 3963 default_metadata = {TASK: _DEFAULT_RESPONSES_AGENT_METADATA_TASK} 3964 mlflow_model.metadata = default_metadata | (mlflow_model.metadata or {}) 3965 3966 # We accept either a dict or a ResponsesRequest object as input 3967 if input_example: 3968 try: 3969 model_validate(ResponsesAgentRequest, input_example) 3970 except pydantic.ValidationError as e: 3971 raise MlflowException( 3972 message=( 3973 f"Invalid input example. Expected a ResponsesRequest object or dictionary with" 3974 f" its schema. Pydantic validation error: {e}" 3975 ), 3976 error_code=INTERNAL_ERROR, 3977 ) from e 3978 if isinstance(input_example, ResponsesAgentRequest): 3979 input_example = input_example.model_dump(exclude_none=True) 3980 else: 3981 input_example = RESPONSES_AGENT_INPUT_EXAMPLE 3982 _logger.info("Predicting on input example to validate output") 3983 context = PythonModelContext(artifacts, model_config) 3984 python_model.load_context(context) 3985 request = ResponsesAgentRequest(**input_example) 3986 output = python_model.predict(request) 3987 try: 3988 model_validate(ResponsesAgentResponse, output) 3989 except Exception as e: 3990 raise MlflowException( 3991 "Failed to save ResponsesAgent. Ensure your model's predict() method returns a " 3992 "ResponsesResponse object or a dict with the same schema." 3993 f"Pydantic validation error: {e}" 3994 ) from e 3995 return input_example