# tests/pyfunc/test_model_export_with_class_and_artifacts.py
   1  from __future__ import annotations
   2  
   3  import importlib.metadata
   4  import json
   5  import os
   6  import subprocess
   7  import sys
   8  import types
   9  import uuid
  10  from pathlib import Path
  11  from subprocess import PIPE, Popen
  12  from typing import Any, Dict, List
  13  from unittest import mock
  14  
  15  import cloudpickle
  16  import numpy as np
  17  import pandas as pd
  18  import pandas.testing
  19  import pytest
  20  import sklearn
  21  import sklearn.datasets
  22  import sklearn.linear_model
  23  import sklearn.neighbors
  24  import yaml
  25  
  26  import mlflow
  27  import mlflow.pyfunc
  28  import mlflow.pyfunc.model
  29  import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
  30  import mlflow.sklearn
  31  from mlflow.entities import Trace
  32  from mlflow.environment_variables import (
  33      MLFLOW_ALLOW_PICKLE_DESERIALIZATION,
  34      MLFLOW_LOG_MODEL_COMPRESSION,
  35      MLFLOW_RECORD_ENV_VARS_IN_MODEL_LOGGING,
  36  )
  37  from mlflow.exceptions import MlflowException
  38  from mlflow.models import Model, infer_signature
  39  from mlflow.models.auth_policy import AuthPolicy, SystemAuthPolicy, UserAuthPolicy
  40  from mlflow.models.dependencies_schemas import DependenciesSchemasType
  41  from mlflow.models.model import _DATABRICKS_FS_LOADER_MODULE
  42  from mlflow.models.resources import (
  43      DatabricksApp,
  44      DatabricksFunction,
  45      DatabricksGenieSpace,
  46      DatabricksLakebase,
  47      DatabricksServingEndpoint,
  48      DatabricksSQLWarehouse,
  49      DatabricksTable,
  50      DatabricksUCConnection,
  51      DatabricksVectorSearchIndex,
  52  )
  53  from mlflow.models.utils import _read_example
  54  from mlflow.pyfunc.context import Context, set_prediction_context
  55  from mlflow.pyfunc.model import _load_pyfunc
  56  from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
  57  from mlflow.tracing.constant import TraceMetadataKey
  58  from mlflow.tracing.export.inference_table import pop_trace
  59  from mlflow.tracking.artifact_utils import (
  60      _download_artifact_from_uri,
  61  )
  62  from mlflow.types.schema import ColSpec, Map, Schema
  63  from mlflow.types.type_hints import _infer_schema_from_list_type_hint
  64  from mlflow.utils.environment import _mlflow_conda_env
  65  from mlflow.utils.file_utils import TempDir
  66  from mlflow.utils.model_utils import _get_flavor_configuration
  67  from mlflow.utils.requirements_utils import _get_installed_version
  68  
  69  import tests
  70  from tests.helper_functions import (
  71      _assert_pip_requirements,
  72      _compare_conda_env_requirements,
  73      _mlflow_major_version_string,
  74      assert_register_model_called_with_local_model_path,
  75      pyfunc_serve_and_score_model,
  76  )
  77  from tests.tracing.helper import get_traces
  78  
  79  
  80  def get_model_class():
  81      """
  82      Defines a custom Python model class that wraps a scikit-learn estimator.
  83      This can be invoked within a pytest fixture to define the class in the ``__main__`` scope.
  84      Alternatively, it can be invoked within a module to define the class in the module's scope.
  85      """
  86  
  87      class CustomSklearnModel(mlflow.pyfunc.PythonModel):
  88          def __init__(self, predict_fn):
  89              self.predict_fn = predict_fn
  90  
  91          def load_context(self, context):
  92              super().load_context(context)
  93  
  94              self.model = (
  95                  mlflow.sklearn.load_model(model_uri=context.artifacts["sk_model"])
  96                  if context.artifacts and "sk_model" in context.artifacts
  97                  else None
  98              )
  99  
 100          def predict(self, context, model_input, params=None):
 101              return self.predict_fn(self.model, model_input)
 102  
 103      return CustomSklearnModel
 104  
 105  
class ModuleScopedSklearnModel(get_model_class()):
    """
    A custom Python model class defined in the test module scope.

    NOTE(review): presumably exists (alongside the ``main_scoped_model_class``
    fixture) to exercise serialization of classes anchored in a module scope
    versus ``__main__`` — confirm against the serving tests below.
    """
 110  
 111  
 112  @pytest.fixture(scope="module")
 113  def main_scoped_model_class():
 114      """
 115      A custom Python model class defined in the ``__main__`` scope.
 116      """
 117      return get_model_class()
 118  
 119  
 120  @pytest.fixture(scope="module")
 121  def iris_data():
 122      iris = sklearn.datasets.load_iris()
 123      x = iris.data[:, :2]
 124      y = iris.target
 125      return x, y
 126  
 127  
 128  @pytest.fixture(scope="module")
 129  def sklearn_knn_model(iris_data):
 130      x, y = iris_data
 131      knn_model = sklearn.neighbors.KNeighborsClassifier()
 132      knn_model.fit(x, y)
 133      return knn_model
 134  
 135  
 136  @pytest.fixture(scope="module")
 137  def sklearn_logreg_model(iris_data):
 138      x, y = iris_data
 139      linear_lr = sklearn.linear_model.LogisticRegression()
 140      linear_lr.fit(x, y)
 141      return linear_lr
 142  
 143  
 144  @pytest.fixture
 145  def model_path(tmp_path):
 146      return os.path.join(tmp_path, "model")
 147  
 148  
 149  @pytest.fixture
 150  def pyfunc_custom_env(tmp_path):
 151      conda_env = os.path.join(tmp_path, "conda_env.yml")
 152      _mlflow_conda_env(
 153          conda_env,
 154          additional_pip_deps=["scikit-learn", "pytest", "cloudpickle"],
 155      )
 156      return conda_env
 157  
 158  
 159  def _conda_env():
 160      # NB: We need mlflow as a dependency in the environment.
 161      return _mlflow_conda_env(
 162          additional_pip_deps=[
 163              f"cloudpickle=={cloudpickle.__version__}",
 164              f"scikit-learn=={sklearn.__version__}",
 165          ],
 166      )
 167  
 168  
 169  def test_model_save_load(sklearn_knn_model, main_scoped_model_class, iris_data, tmp_path):
 170      sklearn_model_path = os.path.join(tmp_path, "sklearn_model")
 171      mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path)
 172  
 173      def test_predict(sk_model, model_input):
 174          return sk_model.predict(model_input) * 2
 175  
 176      pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model")
 177  
 178      mlflow.pyfunc.save_model(
 179          path=pyfunc_model_path,
 180          artifacts={"sk_model": sklearn_model_path},
 181          conda_env=_conda_env(),
 182          python_model=main_scoped_model_class(test_predict),
 183      )
 184  
 185      loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_path)
 186      np.testing.assert_array_equal(
 187          loaded_pyfunc_model.predict(iris_data[0]),
 188          test_predict(sk_model=sklearn_knn_model, model_input=iris_data[0]),
 189      )
 190  
 191  
@pytest.mark.skip(
    reason="In MLflow 3.0, `log_model` does not start a run. Consider removing this test."
)
def test_pyfunc_model_log_load_no_active_run(sklearn_knn_model, main_scoped_model_class, iris_data):
    # Log the underlying sklearn model inside a run so the pyfunc model can
    # reference it via a runs:/ URI.
    sklearn_artifact_path = "sk_model_no_run"
    with mlflow.start_run():
        mlflow.sklearn.log_model(sklearn_knn_model, name=sklearn_artifact_path)
        sklearn_model_uri = f"runs:/{mlflow.active_run().info.run_id}/{sklearn_artifact_path}"

    def test_predict(sk_model, model_input):
        return sk_model.predict(model_input) * 2

    pyfunc_artifact_path = "pyfunc_model"
    assert mlflow.active_run() is None
    mlflow.pyfunc.log_model(
        name=pyfunc_artifact_path,
        artifacts={"sk_model": sklearn_model_uri},
        python_model=main_scoped_model_class(test_predict),
    )
    # NOTE(review): `mlflow.active_run()` was asserted to be None above, so this
    # line would raise AttributeError if the test were not skipped — it relied on
    # the pre-3.0 behavior where `log_model` implicitly started a run.
    pyfunc_model_uri = f"runs:/{mlflow.active_run().info.run_id}/{pyfunc_artifact_path}"
    loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_uri)
    np.testing.assert_array_equal(
        loaded_pyfunc_model.predict(iris_data[0]),
        test_predict(sk_model=sklearn_knn_model, model_input=iris_data[0]),
    )
    mlflow.end_run()
 218  
 219  
def test_model_log_load(sklearn_knn_model, main_scoped_model_class, iris_data):
    """Round-trip: log a pyfunc model wrapping a logged sklearn model, reload, predict."""
    # Log the underlying sklearn model first so the pyfunc model can reference it
    # through its model URI.
    sklearn_artifact_path = "sk_model"
    with mlflow.start_run():
        sklearn_model_info = mlflow.sklearn.log_model(sklearn_knn_model, name=sklearn_artifact_path)

    def test_predict(sk_model, model_input):
        return sk_model.predict(model_input) * 2

    pyfunc_artifact_path = "pyfunc_model"
    with (
        mlflow.start_run(),
        mock.patch("mlflow.pyfunc._logger.warning") as mock_warning,
    ):
        pyfunc_model_info = mlflow.pyfunc.log_model(
            name=pyfunc_artifact_path,
            artifacts={"sk_model": sklearn_model_info.model_uri},
            python_model=main_scoped_model_class(test_predict),
        )
        pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_info.model_uri)
        model_config = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
        # NOTE(review): presumably emitted because a PythonModel *instance* (not a
        # file path) was passed as `python_model` — confirm against mlflow.pyfunc.
        assert "Consider using a file path (str or Path) instead" in mock_warning.call_args[0][0]

    # The persisted MLmodel config must match the reloaded model's metadata, and
    # predictions must equal the raw predict function applied to the sklearn model.
    loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_info.model_uri)
    assert model_config.to_yaml() == loaded_pyfunc_model.metadata.to_yaml()
    np.testing.assert_array_equal(
        loaded_pyfunc_model.predict(iris_data[0]),
        test_predict(sk_model=sklearn_knn_model, model_input=iris_data[0]),
    )
 248  
 249  
def test_python_model_predict_compatible_without_params(sklearn_knn_model, iris_data):
    """A PythonModel whose ``predict`` lacks the ``params`` argument still works.

    Exercises backward compatibility with the older 3-argument predict signature.
    """

    class CustomSklearnModelWithoutParams(mlflow.pyfunc.PythonModel):
        def __init__(self, predict_fn):
            self.predict_fn = predict_fn

        def load_context(self, context):
            super().load_context(context)

            self.model = mlflow.sklearn.load_model(model_uri=context.artifacts["sk_model"])

        # Note: no `params=None` — the legacy signature under test.
        def predict(self, context, model_input):
            return self.predict_fn(self.model, model_input)

    sklearn_artifact_path = "sk_model"
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(sklearn_knn_model, name=sklearn_artifact_path)
        sklearn_model_uri = model_info.model_uri

    def test_predict(sk_model, model_input):
        return sk_model.predict(model_input) * 2

    pyfunc_artifact_path = "pyfunc_model"
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name=pyfunc_artifact_path,
            artifacts={"sk_model": sklearn_model_uri},
            python_model=CustomSklearnModelWithoutParams(test_predict),
        )
        pyfunc_model_path = _download_artifact_from_uri(model_info.model_uri)
        model_config = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))

    # Reloaded metadata and predictions must match the persisted config and the
    # raw predict function applied to the sklearn model.
    loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
    assert model_config.to_yaml() == loaded_pyfunc_model.metadata.to_yaml()
    np.testing.assert_array_equal(
        loaded_pyfunc_model.predict(iris_data[0]),
        test_predict(sk_model=sklearn_knn_model, model_input=iris_data[0]),
    )
 287  
 288  
 289  def test_signature_and_examples_are_saved_correctly(iris_data, main_scoped_model_class, tmp_path):
 290      sklearn_model_path = str(tmp_path.joinpath("sklearn_model"))
 291      mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path)
 292  
 293      def test_predict(sk_model, model_input):
 294          return sk_model.predict(model_input) * 2
 295  
 296      data = iris_data
 297      signature_ = infer_signature(*data)
 298      example_ = data[0][:3]
 299      for signature in (None, signature_):
 300          for example in (None, example_):
 301              with TempDir() as tmp:
 302                  path = tmp.path("model")
 303                  mlflow.pyfunc.save_model(
 304                      path=path,
 305                      artifacts={"sk_model": sklearn_model_path},
 306                      python_model=main_scoped_model_class(test_predict),
 307                      signature=signature,
 308                      input_example=example,
 309                  )
 310                  mlflow_model = Model.load(path)
 311                  assert signature == mlflow_model.signature
 312                  if example is None:
 313                      assert mlflow_model.saved_input_example_info is None
 314                  else:
 315                      np.testing.assert_array_equal(_read_example(mlflow_model, path), example)
 316  
 317  
class DummyModel(mlflow.pyfunc.PythonModel):
    """Minimal identity model used by the registration tests below."""

    def predict(self, context, model_input, params=None):
        return model_input
 321  
 322  
 323  def test_log_model_calls_register_model(sklearn_knn_model, main_scoped_model_class):
 324      with mlflow.start_run():
 325          with mock.patch(
 326              "mlflow.tracking._model_registry.fluent._register_model"
 327          ) as register_model_mock:
 328              registered_model_name = "AdsModel1"
 329              pyfunc_model_info = mlflow.pyfunc.log_model(
 330                  name="pyfunc_model",
 331                  python_model=DummyModel(),
 332                  registered_model_name=registered_model_name,
 333              )
 334          assert_register_model_called_with_local_model_path(
 335              register_model_mock, pyfunc_model_info.model_uri, registered_model_name
 336          )
 337  
 338  
 339  def test_log_model_no_registered_model_name(sklearn_knn_model, main_scoped_model_class):
 340      with mlflow.start_run():
 341          with mock.patch(
 342              "mlflow.tracking._model_registry.fluent._register_model"
 343          ) as register_model_mock:
 344              mlflow.pyfunc.log_model(
 345                  name="pyfunc_model",
 346                  python_model=DummyModel(),
 347              )
 348          register_model_mock.assert_not_called()
 349  
 350  
 351  def test_model_load_from_remote_uri_succeeds(
 352      sklearn_knn_model, main_scoped_model_class, tmp_path, mock_s3_bucket, iris_data
 353  ):
 354      artifact_root = f"s3://{mock_s3_bucket}"
 355      artifact_repo = S3ArtifactRepository(artifact_root)
 356  
 357      sklearn_model_path = os.path.join(tmp_path, "sklearn_model")
 358      mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path)
 359      sklearn_artifact_path = "sk_model"
 360      artifact_repo.log_artifacts(sklearn_model_path, artifact_path=sklearn_artifact_path)
 361  
 362      def test_predict(sk_model, model_input):
 363          return sk_model.predict(model_input) * 2
 364  
 365      pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model")
 366      mlflow.pyfunc.save_model(
 367          path=pyfunc_model_path,
 368          artifacts={"sk_model": sklearn_model_path},
 369          python_model=main_scoped_model_class(test_predict),
 370          conda_env=_conda_env(),
 371      )
 372  
 373      pyfunc_artifact_path = "pyfunc_model"
 374      artifact_repo.log_artifacts(pyfunc_model_path, artifact_path=pyfunc_artifact_path)
 375  
 376      model_uri = artifact_root + "/" + pyfunc_artifact_path
 377      loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=model_uri)
 378      np.testing.assert_array_equal(
 379          loaded_pyfunc_model.predict(iris_data[0]),
 380          test_predict(sk_model=sklearn_knn_model, model_input=iris_data[0]),
 381      )
 382  
 383  
 384  def test_add_to_model_adds_specified_kwargs_to_mlmodel_configuration():
 385      custom_kwargs = {
 386          "key1": "value1",
 387          "key2": 20,
 388          "key3": range(10),
 389      }
 390      model_config = Model()
 391      mlflow.pyfunc.add_to_model(
 392          model=model_config,
 393          loader_module=os.path.basename(__file__)[:-3],
 394          data="data",
 395          code="code",
 396          env=None,
 397          **custom_kwargs,
 398      )
 399  
 400      assert mlflow.pyfunc.FLAVOR_NAME in model_config.flavors
 401      assert all(item in model_config.flavors[mlflow.pyfunc.FLAVOR_NAME] for item in custom_kwargs)
 402  
 403  
 404  def test_pyfunc_model_serving_without_conda_env_activation_succeeds_with_main_scoped_class(
 405      sklearn_knn_model, main_scoped_model_class, iris_data, tmp_path
 406  ):
 407      sklearn_model_path = os.path.join(tmp_path, "sklearn_model")
 408      mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path)
 409  
 410      def test_predict(sk_model, model_input):
 411          return sk_model.predict(model_input) * 2
 412  
 413      pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model")
 414      mlflow.pyfunc.save_model(
 415          path=pyfunc_model_path,
 416          artifacts={"sk_model": sklearn_model_path},
 417          python_model=main_scoped_model_class(test_predict),
 418          conda_env=_conda_env(),
 419      )
 420      loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_path)
 421  
 422      sample_input = pd.DataFrame(iris_data[0])
 423      scoring_response = pyfunc_serve_and_score_model(
 424          model_uri=pyfunc_model_path,
 425          data=sample_input,
 426          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
 427          extra_args=["--env-manager", "local"],
 428      )
 429      assert scoring_response.status_code == 200
 430      np.testing.assert_array_equal(
 431          np.array(json.loads(scoring_response.text)["predictions"]),
 432          loaded_pyfunc_model.predict(sample_input),
 433      )
 434  
 435  
 436  def test_pyfunc_model_serving_with_conda_env_activation_succeeds_with_main_scoped_class(
 437      sklearn_knn_model, main_scoped_model_class, iris_data, tmp_path
 438  ):
 439      sklearn_model_path = os.path.join(tmp_path, "sklearn_model")
 440      mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path)
 441  
 442      def test_predict(sk_model, model_input):
 443          return sk_model.predict(model_input) * 2
 444  
 445      pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model")
 446      mlflow.pyfunc.save_model(
 447          path=pyfunc_model_path,
 448          artifacts={"sk_model": sklearn_model_path},
 449          python_model=main_scoped_model_class(test_predict),
 450          conda_env=_conda_env(),
 451      )
 452      loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_path)
 453  
 454      sample_input = pd.DataFrame(iris_data[0])
 455      scoring_response = pyfunc_serve_and_score_model(
 456          model_uri=pyfunc_model_path,
 457          data=sample_input,
 458          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
 459      )
 460      assert scoring_response.status_code == 200
 461      np.testing.assert_array_equal(
 462          np.array(json.loads(scoring_response.text)["predictions"]),
 463          loaded_pyfunc_model.predict(sample_input),
 464      )
 465  
 466  
 467  def test_pyfunc_model_serving_without_conda_env_activation_succeeds_with_module_scoped_class(
 468      sklearn_knn_model, iris_data, tmp_path
 469  ):
 470      sklearn_model_path = os.path.join(tmp_path, "sklearn_model")
 471      mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path)
 472  
 473      def test_predict(sk_model, model_input):
 474          return sk_model.predict(model_input) * 2
 475  
 476      pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model")
 477      mlflow.pyfunc.save_model(
 478          path=pyfunc_model_path,
 479          artifacts={"sk_model": sklearn_model_path},
 480          python_model=ModuleScopedSklearnModel(test_predict),
 481          code_paths=[os.path.dirname(tests.__file__)],
 482          conda_env=_conda_env(),
 483      )
 484      loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_path)
 485  
 486      sample_input = pd.DataFrame(iris_data[0])
 487      scoring_response = pyfunc_serve_and_score_model(
 488          model_uri=pyfunc_model_path,
 489          data=sample_input,
 490          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
 491          extra_args=["--env-manager", "local"],
 492      )
 493      assert scoring_response.status_code == 200
 494      np.testing.assert_array_equal(
 495          np.array(json.loads(scoring_response.text)["predictions"]),
 496          loaded_pyfunc_model.predict(sample_input),
 497      )
 498  
 499  
def test_pyfunc_cli_predict_command_without_conda_env_activation_succeeds(
    sklearn_knn_model, main_scoped_model_class, iris_data, tmp_path
):
    """``mlflow models predict --env-manager local`` matches in-process predictions."""
    sklearn_model_path = os.path.join(tmp_path, "sklearn_model")
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path)

    def test_predict(sk_model, model_input):
        return sk_model.predict(model_input) * 2

    pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model")
    mlflow.pyfunc.save_model(
        path=pyfunc_model_path,
        artifacts={"sk_model": sklearn_model_path},
        python_model=main_scoped_model_class(test_predict),
        conda_env=_conda_env(),
    )
    loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_path)

    sample_input = pd.DataFrame(iris_data[0])
    # The filename deliberately contains spaces to exercise CLI path quoting.
    input_csv_path = os.path.join(tmp_path, "input with spaces.csv")
    sample_input.to_csv(input_csv_path, header=True, index=False)
    output_json_path = os.path.join(tmp_path, "output.json")
    # Invoke the CLI in a fresh session (preexec_fn=os.setsid) so it can be
    # cleaned up independently of the test process.
    process = Popen(
        [
            sys.executable,
            "-m",
            "mlflow",
            "models",
            "predict",
            "--model-uri",
            pyfunc_model_path,
            "-i",
            input_csv_path,
            "--content-type",
            "csv",
            "-o",
            output_json_path,
            "--env-manager",
            "local",
        ],
        stdout=PIPE,
        stderr=PIPE,
        preexec_fn=os.setsid,
    )
    _, stderr = process.communicate()
    assert process.wait() == 0, f"stderr = \n\n{stderr}\n\n"
    # The CLI output must match what the in-process pyfunc model predicts.
    with open(output_json_path) as f:
        result_df = pd.DataFrame(data=json.load(f)["predictions"])
    np.testing.assert_array_equal(
        result_df.values.transpose()[0], loaded_pyfunc_model.predict(sample_input)
    )
 551  
 552  
 553  def test_pyfunc_cli_predict_command_with_conda_env_activation_succeeds(
 554      sklearn_knn_model, main_scoped_model_class, iris_data, tmp_path
 555  ):
 556      sklearn_model_path = os.path.join(tmp_path, "sklearn_model")
 557      mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path)
 558  
 559      def test_predict(sk_model, model_input):
 560          return sk_model.predict(model_input) * 2
 561  
 562      pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model")
 563      mlflow.pyfunc.save_model(
 564          path=pyfunc_model_path,
 565          artifacts={"sk_model": sklearn_model_path},
 566          python_model=main_scoped_model_class(test_predict),
 567          conda_env=_conda_env(),
 568      )
 569      loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_path)
 570  
 571      sample_input = pd.DataFrame(iris_data[0])
 572      input_csv_path = os.path.join(tmp_path, "input with spaces.csv")
 573      sample_input.to_csv(input_csv_path, header=True, index=False)
 574      output_json_path = os.path.join(tmp_path, "output.json")
 575      process = Popen(
 576          [
 577              sys.executable,
 578              "-m",
 579              "mlflow",
 580              "models",
 581              "predict",
 582              "--model-uri",
 583              pyfunc_model_path,
 584              "-i",
 585              input_csv_path,
 586              "--content-type",
 587              "csv",
 588              "-o",
 589              output_json_path,
 590          ],
 591          stderr=PIPE,
 592          stdout=PIPE,
 593          preexec_fn=os.setsid,
 594      )
 595      stdout, stderr = process.communicate()
 596      assert process.wait() == 0, f"stdout = \n\n{stdout}\n\n stderr = \n\n{stderr}\n\n"
 597      with open(output_json_path) as f:
 598          result_df = pandas.DataFrame(json.load(f)["predictions"])
 599      np.testing.assert_array_equal(
 600          result_df.values.transpose()[0], loaded_pyfunc_model.predict(sample_input)
 601      )
 602  
 603  
def test_save_model_persists_specified_conda_env_in_mlflow_model_directory(
    sklearn_knn_model, main_scoped_model_class, pyfunc_custom_env, tmp_path
):
    """A user-supplied conda env file is copied (not referenced) into the model dir."""
    sklearn_model_path = os.path.join(tmp_path, "sklearn_model")
    mlflow.sklearn.save_model(
        sk_model=sklearn_knn_model,
        path=sklearn_model_path,
        serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE,
    )

    pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model")
    mlflow.pyfunc.save_model(
        path=pyfunc_model_path,
        artifacts={"sk_model": sklearn_model_path},
        python_model=main_scoped_model_class(predict_fn=None),
        conda_env=pyfunc_custom_env,
    )

    # The flavor config points at a conda env file *inside* the model directory,
    # distinct from the original user-provided file.
    pyfunc_conf = _get_flavor_configuration(
        model_path=pyfunc_model_path, flavor_name=mlflow.pyfunc.FLAVOR_NAME
    )
    saved_conda_env_path = os.path.join(pyfunc_model_path, pyfunc_conf[mlflow.pyfunc.ENV]["conda"])
    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != pyfunc_custom_env

    # The copied file's contents must equal the original env definition.
    with open(pyfunc_custom_env) as f:
        pyfunc_custom_env_parsed = yaml.safe_load(f)
    with open(saved_conda_env_path) as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == pyfunc_custom_env_parsed
 634  
 635  
 636  def test_save_model_persists_requirements_in_mlflow_model_directory(
 637      sklearn_knn_model, main_scoped_model_class, pyfunc_custom_env, tmp_path
 638  ):
 639      sklearn_model_path = os.path.join(tmp_path, "sklearn_model")
 640      mlflow.sklearn.save_model(
 641          sk_model=sklearn_knn_model,
 642          path=sklearn_model_path,
 643          serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE,
 644      )
 645  
 646      pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model")
 647      mlflow.pyfunc.save_model(
 648          path=pyfunc_model_path,
 649          artifacts={"sk_model": sklearn_model_path},
 650          python_model=main_scoped_model_class(predict_fn=None),
 651          conda_env=pyfunc_custom_env,
 652      )
 653  
 654      saved_pip_req_path = os.path.join(pyfunc_model_path, "requirements.txt")
 655      _compare_conda_env_requirements(pyfunc_custom_env, saved_pip_req_path)
 656  
 657  
def test_log_model_with_pip_requirements(sklearn_knn_model, main_scoped_model_class, tmp_path):
    """``pip_requirements`` accepts a file path, a list, and ``-r``/``-c`` directives."""
    expected_mlflow_version = _mlflow_major_version_string()
    python_model = main_scoped_model_class(predict_fn=None)
    sklearn_model_path = os.path.join(tmp_path, "sklearn_model")
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path)
    # Path to a requirements file
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=python_model,
            pip_requirements=str(req_file),
            artifacts={"sk_model": sklearn_model_path},
        )
        # strict=True: the logged requirements must equal this list exactly.
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, "a"],
            strict=True,
        )

    # List of requirements
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=python_model,
            pip_requirements=[f"-r {req_file}", "b"],
            artifacts={"sk_model": sklearn_model_path},
        )
        # "-r file" is expanded inline, so "a" (the file's content) appears.
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, "a", "b"],
            strict=True,
        )

    # Constraints file
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=python_model,
            pip_requirements=[f"-c {req_file}", "b"],
            artifacts={"sk_model": sklearn_model_path},
        )
        # "-c file" becomes a constraints.txt reference; "a" lives in constraints,
        # not in the requirements list itself.
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, "b", "-c constraints.txt"],
            ["a"],
            strict=True,
        )
 707  
 708  
def test_log_model_with_extra_pip_requirements(
    sklearn_knn_model, main_scoped_model_class, tmp_path
):
    """``extra_pip_requirements`` appends to (not replaces) the default requirements."""
    expected_mlflow_version = _mlflow_major_version_string()
    sklearn_model_path = str(tmp_path.joinpath("sklearn_model"))
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path)

    python_model = main_scoped_model_class(predict_fn=None)
    default_reqs = mlflow.pyfunc.get_default_pip_requirements()

    # Path to a requirements file
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=python_model,
            artifacts={"sk_model": sklearn_model_path},
            extra_pip_requirements=str(req_file),
        )
        # Defaults are retained and the file's content ("a") is appended.
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, *default_reqs, "a"],
        )

    # List of requirements
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            artifacts={"sk_model": sklearn_model_path},
            python_model=python_model,
            extra_pip_requirements=[f"-r {req_file}", "b"],
        )
        # "-r file" is expanded inline alongside the extra entry "b".
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, *default_reqs, "a", "b"],
        )

    # Constraints file
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            artifacts={"sk_model": sklearn_model_path},
            python_model=python_model,
            extra_pip_requirements=[f"-c {req_file}", "b"],
        )
        # "-c file" becomes a constraints.txt reference; "a" lives in constraints.
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, *default_reqs, "b", "-c constraints.txt"],
            ["a"],
        )
 760  
 761  
def test_log_model_persists_specified_conda_env_in_mlflow_model_directory(
    sklearn_knn_model, main_scoped_model_class, pyfunc_custom_env
):
    """Same as the save_model variant above, but via ``log_model``: the custom
    conda env file is copied into the logged model's artifact directory."""
    sklearn_artifact_path = "sk_model"
    with mlflow.start_run():
        sklearn_model_info = mlflow.sklearn.log_model(sklearn_knn_model, name=sklearn_artifact_path)

    pyfunc_artifact_path = "pyfunc_model"
    with mlflow.start_run():
        pyfunc_model_info = mlflow.pyfunc.log_model(
            name=pyfunc_artifact_path,
            artifacts={"sk_model": sklearn_model_info.model_uri},
            python_model=main_scoped_model_class(predict_fn=None),
            conda_env=pyfunc_custom_env,
        )
        pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_info.model_uri)

    # The flavor config must reference a conda env file inside the downloaded
    # model directory, distinct from the original user-provided file.
    pyfunc_conf = _get_flavor_configuration(
        model_path=pyfunc_model_path, flavor_name=mlflow.pyfunc.FLAVOR_NAME
    )
    saved_conda_env_path = os.path.join(pyfunc_model_path, pyfunc_conf[mlflow.pyfunc.ENV]["conda"])
    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != pyfunc_custom_env

    # The copied file's contents must equal the original env definition.
    with open(pyfunc_custom_env) as f:
        pyfunc_custom_env_parsed = yaml.safe_load(f)
    with open(saved_conda_env_path) as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == pyfunc_custom_env_parsed
 791  
 792  
 793  def test_model_log_persists_requirements_in_mlflow_model_directory(
 794      sklearn_knn_model, main_scoped_model_class, pyfunc_custom_env
 795  ):
 796      sklearn_artifact_path = "sk_model"
 797      with mlflow.start_run():
 798          sklearn_model_info = mlflow.sklearn.log_model(sklearn_knn_model, name=sklearn_artifact_path)
 799  
 800      pyfunc_artifact_path = "pyfunc_model"
 801      with mlflow.start_run():
 802          pyfunc_model_info = mlflow.pyfunc.log_model(
 803              name=pyfunc_artifact_path,
 804              artifacts={"sk_model": sklearn_model_info.model_uri},
 805              python_model=main_scoped_model_class(predict_fn=None),
 806              conda_env=pyfunc_custom_env,
 807          )
 808          pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_info.model_uri)
 809  
 810      saved_pip_req_path = os.path.join(pyfunc_model_path, "requirements.txt")
 811      _compare_conda_env_requirements(pyfunc_custom_env, saved_pip_req_path)
 812  
 813  
def test_save_model_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    sklearn_logreg_model, main_scoped_model_class, tmp_path
):
    """The saved model's pip requirements must equal pyfunc's defaults."""
    sklearn_model_path = os.path.join(tmp_path, "sklearn_model")
    mlflow.sklearn.save_model(sk_model=sklearn_logreg_model, path=sklearn_model_path)

    pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model")
    # NOTE(review): the test name says "without specified conda env", yet
    # conda_env=_conda_env() IS passed below — presumably _conda_env() mirrors the
    # default environment; confirm, or drop the argument to match the name.
    mlflow.pyfunc.save_model(
        path=pyfunc_model_path,
        artifacts={"sk_model": sklearn_model_path},
        python_model=main_scoped_model_class(predict_fn=None),
        conda_env=_conda_env(),
    )
    _assert_pip_requirements(pyfunc_model_path, mlflow.pyfunc.get_default_pip_requirements())
 828  
 829  
 830  def test_log_model_without_specified_conda_env_uses_default_env_with_expected_dependencies(
 831      sklearn_knn_model, main_scoped_model_class
 832  ):
 833      sklearn_artifact_path = "sk_model"
 834      with mlflow.start_run():
 835          sklearn_model_info = mlflow.sklearn.log_model(sklearn_knn_model, name=sklearn_artifact_path)
 836  
 837      pyfunc_artifact_path = "pyfunc_model"
 838      with mlflow.start_run():
 839          pyfunc_model_info = mlflow.pyfunc.log_model(
 840              name=pyfunc_artifact_path,
 841              artifacts={
 842                  "sk_model": sklearn_model_info.model_uri,
 843              },
 844              python_model=main_scoped_model_class(predict_fn=None),
 845          )
 846      _assert_pip_requirements(
 847          pyfunc_model_info.model_uri, mlflow.pyfunc.get_default_pip_requirements()
 848      )
 849  
 850  
 851  def test_save_model_correctly_resolves_directory_artifact_with_nested_contents(
 852      tmp_path, model_path, iris_data
 853  ):
 854      directory_artifact_path = os.path.join(tmp_path, "directory_artifact")
 855      nested_file_relative_path = os.path.join(
 856          "my", "somewhat", "heavily", "nested", "directory", "myfile.txt"
 857      )
 858      nested_file_path = os.path.join(directory_artifact_path, nested_file_relative_path)
 859      os.makedirs(os.path.dirname(nested_file_path))
 860      nested_file_text = "some sample file text"
 861      with open(nested_file_path, "w") as f:
 862          f.write(nested_file_text)
 863  
 864      class ArtifactValidationModel(mlflow.pyfunc.PythonModel):
 865          def predict(self, context, model_input, params=None):
 866              expected_file_path = os.path.join(
 867                  context.artifacts["testdir"], nested_file_relative_path
 868              )
 869              if not os.path.exists(expected_file_path):
 870                  return False
 871              else:
 872                  with open(expected_file_path) as f:
 873                      return f.read() == nested_file_text
 874  
 875      mlflow.pyfunc.save_model(
 876          path=model_path,
 877          artifacts={"testdir": directory_artifact_path},
 878          python_model=ArtifactValidationModel(),
 879          conda_env=_conda_env(),
 880      )
 881  
 882      loaded_model = mlflow.pyfunc.load_model(model_uri=model_path)
 883      assert loaded_model.predict(iris_data[0])
 884  
 885  
 886  def test_save_model_with_no_artifacts_does_not_produce_artifacts_dir(model_path):
 887      mlflow.pyfunc.save_model(
 888          path=model_path,
 889          python_model=ModuleScopedSklearnModel(predict_fn=None),
 890          artifacts=None,
 891          conda_env=_conda_env(),
 892      )
 893  
 894      assert os.path.exists(model_path)
 895      assert "artifacts" not in os.listdir(model_path)
 896      pyfunc_conf = _get_flavor_configuration(
 897          model_path=model_path, flavor_name=mlflow.pyfunc.FLAVOR_NAME
 898      )
 899      assert mlflow.pyfunc.model.CONFIG_KEY_ARTIFACTS not in pyfunc_conf
 900  
 901  
 902  def test_save_model_with_python_model_argument_of_invalid_type_raises_exception(
 903      tmp_path,
 904  ):
 905      with pytest.raises(
 906          MlflowException,
 907          match="must be a PythonModel instance, callable object, or path to a",
 908      ):
 909          mlflow.pyfunc.save_model(path=os.path.join(tmp_path, "model1"), python_model=5)
 910  
 911      with pytest.raises(
 912          MlflowException,
 913          match="must be a PythonModel instance, callable object, or path to a",
 914      ):
 915          mlflow.pyfunc.save_model(
 916              path=os.path.join(tmp_path, "model2"), python_model=["not a python model"]
 917          )
 918      with pytest.raises(MlflowException, match="The provided model path"):
 919          mlflow.pyfunc.save_model(
 920              path=os.path.join(tmp_path, "model3"), python_model="not a valid filepath"
 921          )
 922  
 923  
 924  def test_save_model_with_unsupported_argument_combinations_throws_exception(model_path):
 925      with pytest.raises(
 926          MlflowException,
 927          match="Either `loader_module` or `python_model` must be specified",
 928      ) as exc_info:
 929          mlflow.pyfunc.save_model(
 930              path=model_path,
 931              artifacts={"artifact": "/path/to/artifact"},
 932              python_model=None,
 933          )
 934  
 935      python_model = ModuleScopedSklearnModel(predict_fn=None)
 936      loader_module = __name__
 937      with pytest.raises(
 938          MlflowException,
 939          match="The following sets of parameters cannot be specified together",
 940      ) as exc_info:
 941          mlflow.pyfunc.save_model(
 942              path=model_path, python_model=python_model, loader_module=loader_module
 943          )
 944      assert str(python_model) in str(exc_info)
 945      assert str(loader_module) in str(exc_info)
 946  
 947      with pytest.raises(
 948          MlflowException,
 949          match="The following sets of parameters cannot be specified together",
 950      ) as exc_info:
 951          mlflow.pyfunc.save_model(
 952              path=model_path,
 953              python_model=python_model,
 954              data_path="/path/to/data",
 955              artifacts={"artifact": "/path/to/artifact"},
 956          )
 957  
 958      with pytest.raises(
 959          MlflowException,
 960          match="Either `loader_module` or `python_model` must be specified",
 961      ):
 962          mlflow.pyfunc.save_model(path=model_path, python_model=None, loader_module=None)
 963  
 964  
 965  def test_log_model_with_unsupported_argument_combinations_throws_exception():
 966      match = (
 967          "Either `loader_module` or `python_model` must be specified. A `loader_module` "
 968          "should be a python module. A `python_model` should be a subclass of "
 969          "PythonModel"
 970      )
 971      with mlflow.start_run(), pytest.raises(MlflowException, match=match):
 972          mlflow.pyfunc.log_model(
 973              name="pyfunc_model",
 974              artifacts={"artifact": "/path/to/artifact"},
 975              python_model=None,
 976          )
 977  
 978      python_model = ModuleScopedSklearnModel(predict_fn=None)
 979      loader_module = __name__
 980      with (
 981          mlflow.start_run(),
 982          pytest.raises(
 983              MlflowException,
 984              match="The following sets of parameters cannot be specified together",
 985          ) as exc_info,
 986      ):
 987          mlflow.pyfunc.log_model(
 988              name="pyfunc_model",
 989              python_model=python_model,
 990              loader_module=loader_module,
 991          )
 992      assert str(python_model) in str(exc_info)
 993      assert str(loader_module) in str(exc_info)
 994  
 995      with (
 996          mlflow.start_run(),
 997          pytest.raises(
 998              MlflowException,
 999              match="The following sets of parameters cannot be specified together",
1000          ) as exc_info,
1001      ):
1002          mlflow.pyfunc.log_model(
1003              name="pyfunc_model",
1004              python_model=python_model,
1005              data_path="/path/to/data",
1006              artifacts={"artifact1": "/path/to/artifact"},
1007          )
1008  
1009      with (
1010          mlflow.start_run(),
1011          pytest.raises(
1012              MlflowException,
1013              match="Either `loader_module` or `python_model` must be specified",
1014          ),
1015      ):
1016          mlflow.pyfunc.log_model(name="pyfunc_model", python_model=None, loader_module=None)
1017  
1018  
def test_repr_can_be_called_without_run_id_or_artifact_path():
    """PyFuncModel.__repr__ must work when run_id and artifact_path are absent."""
    meta = Model(
        artifact_path=None,
        run_id=None,
        flavors={"python_function": {"loader_module": "someFlavour"}},
    )

    class TestModel:
        def predict(self, model_input, params=None):
            return model_input

    # repr() must still surface the flavor's loader module.
    assert "flavor: someFlavour" in repr(mlflow.pyfunc.PyFuncModel(meta, TestModel()))
1033  
1034  
def test_load_model_with_differing_cloudpickle_version_at_micro_granularity_logs_warning(
    model_path,
):
    """Loading warns when the saving and loading cloudpickle versions differ, even at patch level."""

    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            return model_input

    mlflow.pyfunc.save_model(path=model_path, python_model=TestModel())
    # Rewrite the persisted MLmodel config to pretend the model was saved with 0.5.8.
    saver_cloudpickle_version = "0.5.8"
    model_config_path = os.path.join(model_path, "MLmodel")
    model_config = Model.load(model_config_path)
    model_config.flavors[mlflow.pyfunc.FLAVOR_NAME][
        mlflow.pyfunc.model.CONFIG_KEY_CLOUDPICKLE_VERSION
    ] = saver_cloudpickle_version
    model_config.save(model_config_path)

    log_messages = []

    def custom_warn(message_text, *args, **kwargs):
        # Capture rendered warning text; chained %-formatting applies positional
        # args, then keyword args (kwargs is typically empty here).
        log_messages.append(message_text % args % kwargs)

    loader_cloudpickle_version = "0.5.7"
    with (
        mock.patch("mlflow.pyfunc._logger.warning") as warn_mock,
        mock.patch("cloudpickle.__version__") as cloudpickle_version_mock,
    ):
        # Make str(cloudpickle.__version__) report the fake loader-side version.
        cloudpickle_version_mock.__str__ = lambda *args, **kwargs: loader_cloudpickle_version
        warn_mock.side_effect = custom_warn
        mlflow.pyfunc.load_model(model_uri=model_path)

    # The warning must mention the mismatch and both version strings.
    assert any(
        "differs from the version of CloudPickle that is currently running" in log_message
        and saver_cloudpickle_version in log_message
        and loader_cloudpickle_version in log_message
        for log_message in log_messages
    )
1071  
1072  
def test_load_model_with_missing_cloudpickle_version_logs_warning(model_path):
    """Loading warns when the MLmodel config lacks the saved cloudpickle version."""

    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            return model_input

    mlflow.pyfunc.save_model(path=model_path, python_model=TestModel())
    # Strip the recorded cloudpickle version from the persisted MLmodel config.
    model_config_path = os.path.join(model_path, "MLmodel")
    model_config = Model.load(model_config_path)
    del model_config.flavors[mlflow.pyfunc.FLAVOR_NAME][
        mlflow.pyfunc.model.CONFIG_KEY_CLOUDPICKLE_VERSION
    ]
    model_config.save(model_config_path)

    log_messages = []

    def custom_warn(message_text, *args, **kwargs):
        # Capture rendered warning text (chained %-formatting, kwargs usually empty).
        log_messages.append(message_text % args % kwargs)

    with mock.patch("mlflow.pyfunc._logger.warning") as warn_mock:
        warn_mock.side_effect = custom_warn
        mlflow.pyfunc.load_model(model_uri=model_path)

    assert any(
        (
            "The version of CloudPickle used to save the model could not be found"
            " in the MLmodel configuration"
        )
        in log_message
        for log_message in log_messages
    )
1103  
1104  
def test_load_cloudpickle_model_raises_when_pickle_deserialization_disallowed(
    model_path, monkeypatch
):
    """Loading a cloudpickle-based model fails fast when pickle loading is disabled."""

    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            return model_input

    mlflow.pyfunc.save_model(path=model_path, python_model=TestModel())
    # Disable pickle deserialization only after saving so the save itself succeeds.
    monkeypatch.setenv(MLFLOW_ALLOW_PICKLE_DESERIALIZATION.name, "false")

    with pytest.raises(MlflowException, match="Deserializing model using pickle is disallowed"):
        mlflow.pyfunc.load_model(model_uri=model_path)
1117  
1118  
def test_save_and_load_model_with_special_chars(
    sklearn_knn_model, main_scoped_model_class, iris_data, tmp_path
):
    """Models survive a save/load round trip through paths with spaces and ':%'."""
    sk_dir = os.path.join(tmp_path, "sklearn_  model")
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sk_dir)

    def doubled_predict(sk_model, model_input):
        return sk_model.predict(model_input) * 2

    # Deliberately use characters that are not URL-safe in the destination path.
    pyfunc_dir = os.path.join(tmp_path, "pyfunc_ :% model")
    mlflow.pyfunc.save_model(
        path=pyfunc_dir,
        artifacts={"sk_model": sk_dir},
        conda_env=_conda_env(),
        python_model=main_scoped_model_class(doubled_predict),
    )

    reloaded = mlflow.pyfunc.load_model(model_uri=pyfunc_dir)
    np.testing.assert_array_equal(
        reloaded.predict(iris_data[0]),
        doubled_predict(sk_model=sklearn_knn_model, model_input=iris_data[0]),
    )
1143  
1144  
def test_model_with_code_path_containing_main(tmp_path):
    """Logging code_paths containing a __main__.py must not clobber sys.modules['__main__']."""
    code_dir = tmp_path.joinpath("model_with_main")
    code_dir.mkdir()
    code_dir.joinpath("__main__.py").write_text("# empty main")
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=mlflow.pyfunc.model.PythonModel(),
            code_paths=[str(code_dir)],
        )

    # "__main__" must be present both before and after loading the model.
    assert "__main__" in sys.modules
    mlflow.pyfunc.load_model(model_info.model_uri)
    assert "__main__" in sys.modules
1160  
1161  
def test_model_save_load_with_metadata(tmp_path):
    """Metadata passed at save time is available again after loading."""
    target = os.path.join(tmp_path, "pyfunc_model")

    mlflow.pyfunc.save_model(
        path=target,
        conda_env=_conda_env(),
        python_model=mlflow.pyfunc.model.PythonModel(),
        metadata={"metadata_key": "metadata_value"},
    )

    reloaded = mlflow.pyfunc.load_model(model_uri=target)
    assert reloaded.metadata.metadata["metadata_key"] == "metadata_value"
1174  
1175  
def test_model_log_with_metadata():
    """Metadata passed at log time is available again after loading from the run."""
    with mlflow.start_run():
        mlflow.pyfunc.log_model(
            name="pyfunc_model",
            python_model=mlflow.pyfunc.model.PythonModel(),
            metadata={"metadata_key": "metadata_value"},
        )
        # Build the runs:/ URI while the run is still active.
        model_uri = f"runs:/{mlflow.active_run().info.run_id}/pyfunc_model"

    reloaded = mlflow.pyfunc.load_model(model_uri=model_uri)
    assert reloaded.metadata.metadata["metadata_key"] == "metadata_value"
1188  
1189  
class SklearnModel(mlflow.pyfunc.PythonModel):
    """Minimal PythonModel holding a scikit-learn estimator for dependency-inference tests."""

    def __init__(self) -> None:
        # Imported locally so that sklearn only becomes a dependency when the
        # class is actually instantiated.
        from sklearn.linear_model import LinearRegression

        self.model = LinearRegression()

    def predict(self, context, model_input, params=None):
        # NOTE(review): the estimator is never fitted here — presumably tests
        # use this class only for logging / dependency inference; confirm.
        return self.model.predict(model_input)
1198  
1199  
def test_dependency_inference_does_not_exclude_mlflow_dependencies(tmp_path):
    """scikit-learn, pulled in via the model class, must appear in requirements.txt."""
    mlflow.pyfunc.save_model(
        path=tmp_path,
        python_model=SklearnModel(),
    )
    requirements = tmp_path.joinpath("requirements.txt").read_text()
    assert f"scikit-learn=={sklearn.__version__}" in requirements
1207  
1208  
def test_functional_python_model_no_type_hints(tmp_path):
    """Without type hints, the signature is inferred from the input example."""

    def identity(x):
        return x

    mlflow.pyfunc.save_model(path=tmp_path, python_model=identity, input_example=[{"a": "b"}])
    model = Model.load(tmp_path)
    expected_schema = Schema([ColSpec("string", name="a")])
    assert model.signature.inputs == expected_schema
    assert model.signature.outputs == expected_schema
1217  
1218  
def list_to_list(x: List[str]) -> List[str]:  # noqa: UP006
    # Identity function; typing.List is used deliberately (hence the noqa) to
    # exercise signature inference for legacy typing generics.
    return x
1221  
1222  
def list_dict_to_list(x: List[Dict[str, str]]) -> List[str]:  # noqa: UP006
    # For each dict, concatenate every key followed by every value, in order.
    return ["".join(list(d.keys()) + list(d.values())) for d in x]
1225  
1226  
def test_functional_python_model_list_dict_to_list_without_example(tmp_path):
    """Type hints alone (no input example) drive signature inference."""
    mlflow.pyfunc.save_model(
        path=tmp_path, python_model=list_dict_to_list, pip_requirements=["pandas"]
    )
    model = Model.load(tmp_path)
    assert model.signature.inputs == Schema([ColSpec(Map("string"))])
    assert model.signature.outputs == Schema([ColSpec("string")])
    reloaded = mlflow.pyfunc.load_model(tmp_path)
    assert reloaded.predict([{"a": "x"}, {"a": "y"}]) == ["ax", "ay"]
1236  
1237  
@pytest.mark.parametrize(
    ("input_example"),
    [
        [0],
        [{"a": "b"}],
    ],
)
def test_functional_python_model_list_invalid_example(tmp_path, input_example):
    """Examples clashing with List[str] hints produce a warning, not an error."""
    with mock.patch("mlflow.models.signature._logger.warning") as mock_warning:
        mlflow.pyfunc.save_model(
            path=tmp_path, python_model=list_to_list, input_example=input_example
        )
        warned = [call[0][0] for call in mock_warning.call_args_list]
        assert any(
            "Input example is not compatible with the type hint" in msg for msg in warned
        )
1254  
1255  
@pytest.mark.parametrize(
    "input_example",
    [
        ["a"],
        [{0: "a"}],
        [{"a": 0}],
    ],
)
def test_functional_python_model_list_dict_invalid_example(tmp_path, input_example):
    """Examples clashing with List[Dict[str, str]] hints produce a warning, not an error."""
    with mock.patch("mlflow.models.signature._logger.warning") as mock_warning:
        mlflow.pyfunc.save_model(
            path=tmp_path, python_model=list_dict_to_list, input_example=input_example
        )
        warned = [call[0][0] for call in mock_warning.call_args_list]
        assert any(
            "Input example is not compatible with the type hint" in msg for msg in warned
        )
1273  
1274  
def test_functional_python_model_list_dict_to_list(tmp_path):
    """A List[Dict[str, str]] -> List[str] function round-trips through pyfunc."""
    mlflow.pyfunc.save_model(
        path=tmp_path,
        python_model=list_dict_to_list,
        input_example=[{"a": "x", "b": "y"}],
    )
    model = Model.load(tmp_path)
    assert model.signature.inputs == Schema([ColSpec(Map("string"))])
    assert model.signature.outputs == Schema([ColSpec("string")])
    reloaded = mlflow.pyfunc.load_model(tmp_path)
    # Keys first, then values: "ab" + "xy".
    assert reloaded.predict([{"a": "x", "b": "y"}]) == ["abxy"]
1286  
1287  
def list_dict_to_list_dict(x: list[dict[str, str]]) -> list[dict[str, str]]:
    # Invert each mapping: values become keys and keys become values.
    return [dict((value, key) for key, value in d.items()) for d in x]
1290  
1291  
def test_functional_python_model_list_dict_to_list_dict():
    """Map-to-map type hints produce map schemas on both sides of the signature."""
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="test_model",
            python_model=list_dict_to_list_dict,
            input_example=[{"a": "x", "b": "y"}],
        )

    expected_schema = [{"type": "map", "values": {"type": "string"}, "required": True}]
    assert model_info.signature.inputs.to_dict() == expected_schema
    assert model_info.signature.outputs.to_dict() == expected_schema

    loaded = mlflow.pyfunc.load_model(model_info.model_uri)
    assert loaded.predict([{"a": "x", "b": "y"}]) == [{"x": "a", "y": "b"}]
1309  
1310  
def test_list_dict_with_signature_override():
    """An explicitly provided signature is reconciled with the type-hint schema."""

    class CustomModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input: list[dict[str, str]], params=None):
            return model_input

    explicit_signature = infer_signature([{"a": "x", "b": "y"}, {"a": "z"}])
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="test_model",
            python_model=CustomModel(),
            signature=explicit_signature,
        )
    # The logged input schema must match what the type hint implies.
    assert model_info.signature.inputs == _infer_schema_from_list_type_hint(list[dict[str, str]])
    loaded = mlflow.pyfunc.load_model(model_info.model_uri)
    assert loaded.predict([{"a": "z"}]) == [{"a": "z"}]
1326  
1327  
def list_dict_to_list_dict_pep585(x: list[dict[str, str]]) -> list[dict[str, str]]:
    # Invert each mapping (PEP 585 builtin-generic annotations variant).
    inverted = []
    for d in x:
        inverted.append({value: key for key, value in d.items()})
    return inverted
1330  
1331  
def test_functional_python_model_list_dict_to_list_dict_with_example_pep585(tmp_path):
    """PEP 585 hints plus an input example still yield map schemas and working predict."""
    mlflow.pyfunc.save_model(
        path=tmp_path,
        python_model=list_dict_to_list_dict_pep585,
        input_example=[{"a": "x", "b": "y"}],
    )
    model = Model.load(tmp_path)
    expected_schema = [{"type": "map", "values": {"type": "string"}, "required": True}]
    assert model.signature.inputs.to_dict() == expected_schema
    assert model.signature.outputs.to_dict() == expected_schema
    reloaded = mlflow.pyfunc.load_model(tmp_path)
    assert reloaded.predict([{"a": "x", "b": "y"}]) == [{"x": "a", "y": "b"}]
1347  
1348  
def multiple_arguments(x: list[str], y: list[str]) -> list[str]:
    # Concatenate the two lists; exists to exercise the one-argument arity check.
    return [*x, *y]
1351  
1352  
def test_functional_python_model_multiple_arguments(tmp_path):
    """Functional models must take exactly one argument; a two-argument function is rejected."""
    expected_error = r"must accept exactly one argument\. Found 2 arguments\."
    with pytest.raises(MlflowException, match=expected_error):
        mlflow.pyfunc.save_model(path=tmp_path, python_model=multiple_arguments)
1358  
1359  
def no_arguments() -> list[str]:
    # Zero-argument function; exists to exercise the one-argument arity check.
    return list()
1362  
1363  
def test_functional_python_model_no_arguments(tmp_path):
    """Functional models must take exactly one argument; a zero-argument function is rejected."""
    expected_error = r"must accept exactly one argument\. Found 0 arguments\."
    with pytest.raises(MlflowException, match=expected_error):
        mlflow.pyfunc.save_model(path=tmp_path, python_model=no_arguments)
1369  
1370  
def requires_sklearn(x: list[str]) -> list[str]:
    # The in-body sklearn import is the point of this fixture: dependency
    # inference must detect packages imported inside the model function.
    import sklearn  # noqa: F401

    return x
1375  
1376  
def test_functional_python_model_infer_requirements(tmp_path):
    """Dependencies imported inside the model function body are inferred."""
    mlflow.pyfunc.save_model(path=tmp_path, python_model=requires_sklearn, input_example=["a"])
    requirements = tmp_path.joinpath("requirements.txt").read_text()
    assert "scikit-learn==" in requirements
1380  
1381  
def test_functional_python_model_throws_when_required_arguments_are_missing(tmp_path):
    """save_model needs an input example or explicit requirements; providing none fails."""
    # Each of these alternatives satisfies the requirement on its own.
    satisfying_kwargs = (
        {"input_example": ["a"]},
        {"pip_requirements": ["scikit-learn"]},
        {"extra_pip_requirements": ["scikit-learn"]},
    )
    for extra_kwargs in satisfying_kwargs:
        mlflow.pyfunc.save_model(
            path=tmp_path / uuid.uuid4().hex,
            python_model=requires_sklearn,
            **extra_kwargs,
        )
    # None of them: rejected.
    with pytest.raises(MlflowException, match="at least one of"):
        mlflow.pyfunc.save_model(path=tmp_path / uuid.uuid4().hex, python_model=requires_sklearn)
1400  
1401  
class AnnotatedPythonModel(mlflow.pyfunc.PythonModel):
    """PythonModel whose predict() carries full type hints for signature inference."""

    def predict(self, context: dict[str, Any], model_input: list[str], params=None) -> list[str]:
        # Sanity-check that pyfunc delivered the input as a list of strings,
        # matching the declared hint, before echoing it back.
        assert isinstance(model_input, list)
        assert all(isinstance(x, str) for x in model_input)
        return model_input
1407  
1408  
def test_class_python_model_type_hints(tmp_path):
    """Signatures come from predict()'s type hints on PythonModel subclasses."""
    mlflow.pyfunc.save_model(path=tmp_path, python_model=AnnotatedPythonModel())
    meta = Model.load(tmp_path)
    expected_schema = [{"type": "string", "required": True}]
    assert meta.signature.inputs.to_dict() == expected_schema
    assert meta.signature.outputs.to_dict() == expected_schema
    loaded = mlflow.pyfunc.load_model(tmp_path)
    assert loaded.predict(["a", "b"]) == ["a", "b"]
1416  
1417  
def test_python_model_predict_with_params():
    """predict accepts params given as plain lists or numpy arrays."""
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="test_model",
            python_model=AnnotatedPythonModel(),
        )

    loaded = mlflow.pyfunc.load_model(model_info.model_uri)
    # Both list-valued and ndarray-valued params must be accepted unchanged.
    for params in ({"foo": [0, 1]}, {"foo": np.array([0, 1])}):
        assert loaded.predict(["a", "b"], params=params) == ["a", "b"]
1431  
1432  
def test_python_model_with_type_hint_errors_with_different_signature():
    """A signature disagreeing with the type hints logs a mismatch warning."""
    mismatched_signature = infer_signature(["input1", "input2"], params={"foo": [8]})

    with mlflow.start_run():
        with mock.patch("mlflow.pyfunc._logger.warning") as warn_mock:
            mlflow.pyfunc.log_model(
                name="test_model",
                python_model=AnnotatedPythonModel(),
                signature=mismatched_signature,
            )
        warning_message = warn_mock.call_args[0][0]
        assert "Provided signature does not match the signature inferred from" in warning_message
1447  
1448  
def test_artifact_path_posix(sklearn_knn_model, main_scoped_model_class, tmp_path):
    """Resolved artifact URIs must use forward slashes (no Windows backslashes)."""
    sk_dir = tmp_path.joinpath("sklearn_model")
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sk_dir)

    def doubled_predict(sk_model, model_input):
        return sk_model.predict(model_input) * 2

    pyfunc_dir = tmp_path.joinpath("pyfunc_model")
    mlflow.pyfunc.save_model(
        path=pyfunc_dir,
        artifacts={"sk_model": str(sk_dir)},
        conda_env=_conda_env(),
        python_model=main_scoped_model_class(doubled_predict),
    )

    resolved_artifacts = _load_pyfunc(pyfunc_dir).context.artifacts
    assert not any("\\" in uri for uri in resolved_artifacts.values())
1467  
1468  
def test_load_model_fails_for_feature_store_models(tmp_path):
    """mlflow.pyfunc.load_model rejects Databricks Feature Store models."""
    feature_store = os.path.join(tmp_path, "feature_store")
    os.mkdir(feature_store)
    # A feature_spec.yaml marks the directory as a Feature Store model payload.
    with open(os.path.join(feature_store, "feature_spec.yaml"), "w+") as f:
        f.write("contents")

    with mlflow.start_run() as run:
        mlflow.pyfunc.log_model(
            name="model",
            data_path=feature_store,
            loader_module=_DATABRICKS_FS_LOADER_MODULE,
            code_paths=[__file__],
        )
    with pytest.raises(
        MlflowException,
        match="Note: mlflow.pyfunc.load_model is not supported for Feature Store models",
    ):
        mlflow.pyfunc.load_model(f"runs:/{run.info.run_id}/model")
1488  
1489  
def test_pyfunc_model_infer_signature_from_type_hints():
    """Type hints on predict() are turned into the model signature at log time."""

    class EchoModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input: list[str], params=None) -> list[str]:
            return model_input

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="test_model",
            python_model=EchoModel(),
            input_example=["a"],
        )
    loaded = mlflow.pyfunc.load_model(model_info.model_uri)
    assert loaded.metadata.get_input_schema() == Schema([ColSpec("string")])
    assert loaded.predict(["a", "b"]) == ["a", "b"]
1504  
1505  
def test_streamable_model_save_load(iris_data, tmp_path):
    """predict_stream on a reloaded model yields a generator of chunks."""

    class StreamableModel(mlflow.pyfunc.PythonModel):
        def __init__(self):
            pass

        def predict(self, context, model_input, params=None):
            pass

        def predict_stream(self, context, model_input, params=None):
            yield "test1"
            yield "test2"

    save_dir = os.path.join(tmp_path, "pyfunc_model")
    mlflow.pyfunc.save_model(
        path=save_dir,
        python_model=StreamableModel(),
    )

    reloaded = mlflow.pyfunc.load_model(model_uri=save_dir)

    stream = reloaded.predict_stream("single-input")
    # The result must be a real generator, consumed lazily.
    assert isinstance(stream, types.GeneratorType)
    assert list(stream) == ["test1", "test2"]
1533  
1534  
def test_streamable_model_as_code_save_load(tmp_path):
    """Models-as-code variant of the streamable save/load test.

    Bug fix: this function was previously also named ``test_streamable_model_save_load``,
    which redefined (shadowed) the earlier test of the same name in this module, so
    pytest silently never collected or ran the first test. Renamed so both tests run.
    """
    pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model")

    # The model is defined in a standalone code file rather than a PythonModel instance.
    mlflow.pyfunc.save_model(
        path=pyfunc_model_path,
        python_model="tests/pyfunc/sample_code/streamable_model_code.py",
    )

    loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_path)

    stream_result = loaded_pyfunc_model.predict_stream("single-input")
    # Streaming predictions must be returned lazily as a generator.
    assert isinstance(stream_result, types.GeneratorType)

    assert list(stream_result) == ["test1", "test2"]
1549  
1550  
def test_model_save_load_with_resources(tmp_path):
    """Saving a pyfunc model with Databricks resources should serialize them into
    the MLmodel `resources` field, whether the resources are supplied as Python
    resource objects or as a YAML file — both paths must produce the same result.
    """
    pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model")
    pyfunc_model_path_2 = os.path.join(tmp_path, "pyfunc_model_2")

    # Expected serialized form of the resources in the saved MLmodel file.
    expected_resources = {
        "api_version": "1",
        "databricks": {
            "serving_endpoint": [
                {"name": "databricks-mixtral-8x7b-instruct"},
                {"name": "databricks-bge-large-en"},
                {"name": "azure-eastus-model-serving-2_vs_endpoint"},
            ],
            "vector_search_index": [{"name": "rag.studio_bugbash.databricks_docs_index"}],
            "sql_warehouse": [{"name": "testid"}],
            "function": [
                {"name": "rag.studio.test_function_a"},
                {"name": "rag.studio.test_function_b"},
            ],
            "genie_space": [{"name": "genie_space_id_1"}, {"name": "genie_space_id_2"}],
            "uc_connection": [{"name": "test_connection_1"}, {"name": "test_connection_2"}],
            "table": [{"name": "rag.studio.table_a"}, {"name": "rag.studio.table_b"}],
            "app": [{"name": "test_databricks_app"}],
            "lakebase": [{"name": "test_databricks_lakebase"}],
        },
    }
    # First save: resources supplied as Python resource objects.
    mlflow.pyfunc.save_model(
        path=pyfunc_model_path,
        conda_env=_conda_env(),
        python_model=mlflow.pyfunc.model.PythonModel(),
        resources=[
            DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"),
            DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
            DatabricksServingEndpoint(endpoint_name="azure-eastus-model-serving-2_vs_endpoint"),
            DatabricksVectorSearchIndex(index_name="rag.studio_bugbash.databricks_docs_index"),
            DatabricksSQLWarehouse(warehouse_id="testid"),
            DatabricksFunction(function_name="rag.studio.test_function_a"),
            DatabricksFunction(function_name="rag.studio.test_function_b"),
            DatabricksGenieSpace(genie_space_id="genie_space_id_1"),
            DatabricksGenieSpace(genie_space_id="genie_space_id_2"),
            DatabricksUCConnection(connection_name="test_connection_1"),
            DatabricksUCConnection(connection_name="test_connection_2"),
            DatabricksTable(table_name="rag.studio.table_a"),
            DatabricksTable(table_name="rag.studio.table_b"),
            DatabricksApp(app_name="test_databricks_app"),
            DatabricksLakebase(database_instance_name="test_databricks_lakebase"),
        ],
    )

    reloaded_model = Model.load(pyfunc_model_path)
    assert reloaded_model.resources == expected_resources

    # Second save: the same resources supplied via a YAML file; the serialized
    # MLmodel content must match the object-based save exactly.
    yaml_file = tmp_path.joinpath("resources.yaml")
    with open(yaml_file, "w") as f:
        f.write(
            """
            api_version: "1"
            databricks:
                vector_search_index:
                - name: rag.studio_bugbash.databricks_docs_index
                serving_endpoint:
                - name: databricks-mixtral-8x7b-instruct
                - name: databricks-bge-large-en
                - name: azure-eastus-model-serving-2_vs_endpoint
                sql_warehouse:
                - name: testid
                function:
                - name: rag.studio.test_function_a
                - name: rag.studio.test_function_b
                lakebase:
                - name: test_databricks_lakebase
                genie_space:
                - name: genie_space_id_1
                - name: genie_space_id_2
                uc_connection:
                - name: test_connection_1
                - name: test_connection_2
                table:
                - name: rag.studio.table_a
                - name: rag.studio.table_b
                app:
                - name: test_databricks_app
            """
        )

    mlflow.pyfunc.save_model(
        path=pyfunc_model_path_2,
        conda_env=_conda_env(),
        python_model=mlflow.pyfunc.model.PythonModel(),
        resources=yaml_file,
    )
    reloaded_model = Model.load(pyfunc_model_path_2)
    assert reloaded_model.resources == expected_resources
1643  
1644  
def test_model_save_load_with_invokers_resources(tmp_path):
    """Like test_model_save_load_with_resources, but some resources carry
    ``on_behalf_of_user=True`` (invoker auth). That flag must survive
    serialization from both resource objects and a YAML resources file.
    """
    pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model")
    pyfunc_model_path_2 = os.path.join(tmp_path, "pyfunc_model_2")

    # Expected serialized form; note the per-resource "on_behalf_of_user" keys.
    expected_resources = {
        "api_version": "1",
        "databricks": {
            "serving_endpoint": [
                {"name": "databricks-mixtral-8x7b-instruct", "on_behalf_of_user": True},
                {"name": "databricks-bge-large-en"},
                {"name": "azure-eastus-model-serving-2_vs_endpoint"},
            ],
            "vector_search_index": [
                {"name": "rag.studio_bugbash.databricks_docs_index", "on_behalf_of_user": True}
            ],
            "sql_warehouse": [{"name": "testid"}],
            "function": [
                {"name": "rag.studio.test_function_a", "on_behalf_of_user": True},
                {"name": "rag.studio.test_function_b"},
            ],
            "genie_space": [
                {"name": "genie_space_id_1", "on_behalf_of_user": True},
                {"name": "genie_space_id_2"},
            ],
            "uc_connection": [{"name": "test_connection_1"}, {"name": "test_connection_2"}],
            "table": [
                {"name": "rag.studio.table_a", "on_behalf_of_user": True},
                {"name": "rag.studio.table_b"},
            ],
            "app": [{"name": "test_databricks_app"}],
            "lakebase": [{"name": "test_databricks_lakebase"}],
        },
    }
    # First save: resources supplied as Python resource objects.
    mlflow.pyfunc.save_model(
        path=pyfunc_model_path,
        conda_env=_conda_env(),
        python_model=mlflow.pyfunc.model.PythonModel(),
        resources=[
            DatabricksServingEndpoint(
                endpoint_name="databricks-mixtral-8x7b-instruct", on_behalf_of_user=True
            ),
            DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
            DatabricksServingEndpoint(endpoint_name="azure-eastus-model-serving-2_vs_endpoint"),
            DatabricksVectorSearchIndex(
                index_name="rag.studio_bugbash.databricks_docs_index", on_behalf_of_user=True
            ),
            DatabricksSQLWarehouse(warehouse_id="testid"),
            DatabricksFunction(function_name="rag.studio.test_function_a", on_behalf_of_user=True),
            DatabricksFunction(function_name="rag.studio.test_function_b"),
            DatabricksGenieSpace(genie_space_id="genie_space_id_1", on_behalf_of_user=True),
            DatabricksGenieSpace(genie_space_id="genie_space_id_2"),
            DatabricksUCConnection(connection_name="test_connection_1"),
            DatabricksUCConnection(connection_name="test_connection_2"),
            DatabricksTable(table_name="rag.studio.table_a", on_behalf_of_user=True),
            DatabricksTable(table_name="rag.studio.table_b"),
            DatabricksApp(app_name="test_databricks_app"),
            DatabricksLakebase(database_instance_name="test_databricks_lakebase"),
        ],
    )

    reloaded_model = Model.load(pyfunc_model_path)
    assert reloaded_model.resources == expected_resources

    # Second save: same resources from a YAML file, including on_behalf_of_user
    # flags; the serialized result must be identical to the object-based save.
    yaml_file = tmp_path.joinpath("resources.yaml")
    with open(yaml_file, "w") as f:
        f.write(
            """
            api_version: "1"
            databricks:
                vector_search_index:
                - name: rag.studio_bugbash.databricks_docs_index
                  on_behalf_of_user: True
                serving_endpoint:
                - name: databricks-mixtral-8x7b-instruct
                  on_behalf_of_user: True
                - name: databricks-bge-large-en
                - name: azure-eastus-model-serving-2_vs_endpoint
                sql_warehouse:
                - name: testid
                function:
                - name: rag.studio.test_function_a
                  on_behalf_of_user: True
                - name: rag.studio.test_function_b
                lakebase:
                - name: test_databricks_lakebase
                genie_space:
                - name: genie_space_id_1
                  on_behalf_of_user: True
                - name: genie_space_id_2
                uc_connection:
                - name: test_connection_1
                - name: test_connection_2
                table:
                - name: rag.studio.table_a
                  on_behalf_of_user: True
                - name: rag.studio.table_b
                app:
                - name: test_databricks_app
            """
        )

    mlflow.pyfunc.save_model(
        path=pyfunc_model_path_2,
        conda_env=_conda_env(),
        python_model=mlflow.pyfunc.model.PythonModel(),
        resources=yaml_file,
    )

    reloaded_model = Model.load(pyfunc_model_path_2)
    assert reloaded_model.resources == expected_resources
1755  
1756  
def test_model_log_with_invokers_resources(tmp_path):
    """Logging (rather than saving) a pyfunc model with resources that carry
    ``on_behalf_of_user=True`` should persist those flags in the MLmodel file,
    for both resource-object and YAML-file inputs.
    """
    pyfunc_artifact_path = "pyfunc_model"

    # Expected serialized form; this variant exercises a different mix of
    # on_behalf_of_user flags than the save-based test above.
    expected_resources = {
        "api_version": "1",
        "databricks": {
            "serving_endpoint": [
                {"name": "databricks-mixtral-8x7b-instruct"},
                {"name": "databricks-bge-large-en", "on_behalf_of_user": True},
                {"name": "azure-eastus-model-serving-2_vs_endpoint"},
            ],
            "vector_search_index": [
                {"name": "rag.studio_bugbash.databricks_docs_index", "on_behalf_of_user": True}
            ],
            "sql_warehouse": [{"name": "testid", "on_behalf_of_user": True}],
            "function": [
                {"name": "rag.studio.test_function_a"},
                {"name": "rag.studio.test_function_b", "on_behalf_of_user": True},
            ],
            "genie_space": [
                {"name": "genie_space_id_1"},
                {"name": "genie_space_id_2", "on_behalf_of_user": True},
            ],
            "uc_connection": [
                {"name": "test_connection_1"},
                {"name": "test_connection_2", "on_behalf_of_user": True},
            ],
            "table": [
                {"name": "rag.studio.table_a"},
                {"name": "rag.studio.table_b", "on_behalf_of_user": True},
            ],
        },
    }
    # First log: resources supplied as Python resource objects.
    with mlflow.start_run() as run:
        mlflow.pyfunc.log_model(
            name=pyfunc_artifact_path,
            python_model=mlflow.pyfunc.model.PythonModel(),
            resources=[
                DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"),
                DatabricksServingEndpoint(
                    endpoint_name="databricks-bge-large-en", on_behalf_of_user=True
                ),
                DatabricksServingEndpoint(endpoint_name="azure-eastus-model-serving-2_vs_endpoint"),
                DatabricksVectorSearchIndex(
                    index_name="rag.studio_bugbash.databricks_docs_index", on_behalf_of_user=True
                ),
                DatabricksSQLWarehouse(warehouse_id="testid", on_behalf_of_user=True),
                DatabricksFunction(function_name="rag.studio.test_function_a"),
                DatabricksFunction(
                    function_name="rag.studio.test_function_b", on_behalf_of_user=True
                ),
                DatabricksGenieSpace(genie_space_id="genie_space_id_1"),
                DatabricksGenieSpace(genie_space_id="genie_space_id_2", on_behalf_of_user=True),
                DatabricksUCConnection(connection_name="test_connection_1"),
                DatabricksUCConnection(connection_name="test_connection_2", on_behalf_of_user=True),
                DatabricksTable(table_name="rag.studio.table_a"),
                DatabricksTable(table_name="rag.studio.table_b", on_behalf_of_user=True),
            ],
        )
    # Download the logged artifacts and verify the serialized resources.
    pyfunc_model_uri = f"runs:/{run.info.run_id}/{pyfunc_artifact_path}"
    pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_uri)
    reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
    assert reloaded_model.resources == expected_resources

    # Second log: the same resources via a YAML file must serialize identically.
    yaml_file = tmp_path.joinpath("resources.yaml")
    with open(yaml_file, "w") as f:
        f.write(
            """
            api_version: "1"
            databricks:
                vector_search_index:
                - name: rag.studio_bugbash.databricks_docs_index
                  on_behalf_of_user: True
                serving_endpoint:
                - name: databricks-mixtral-8x7b-instruct
                - name: databricks-bge-large-en
                  on_behalf_of_user: True
                - name: azure-eastus-model-serving-2_vs_endpoint
                sql_warehouse:
                - name: testid
                  on_behalf_of_user: True
                function:
                - name: rag.studio.test_function_a
                - name: rag.studio.test_function_b
                  on_behalf_of_user: True
                genie_space:
                - name: genie_space_id_1
                - name: genie_space_id_2
                  on_behalf_of_user: True
                uc_connection:
                - name: test_connection_1
                - name: test_connection_2
                  on_behalf_of_user: True
                table:
                - name: "rag.studio.table_a"
                - name: "rag.studio.table_b"
                  on_behalf_of_user: True
            """
        )

    with mlflow.start_run() as run:
        mlflow.pyfunc.log_model(
            name=pyfunc_artifact_path,
            python_model=mlflow.pyfunc.model.PythonModel(),
            resources=yaml_file,
        )
    pyfunc_model_uri = f"runs:/{run.info.run_id}/{pyfunc_artifact_path}"
    pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_uri)
    reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
    assert reloaded_model.resources == expected_resources
1867  
1868  
def test_model_log_with_resources(tmp_path):
    """Logging a pyfunc model with Databricks resources (objects or YAML file)
    should serialize them into the MLmodel `resources` field identically.
    """
    pyfunc_artifact_path = "pyfunc_model"

    # Expected serialized form of the resources in the logged MLmodel file.
    expected_resources = {
        "api_version": "1",
        "databricks": {
            "serving_endpoint": [
                {"name": "databricks-mixtral-8x7b-instruct"},
                {"name": "databricks-bge-large-en"},
                {"name": "azure-eastus-model-serving-2_vs_endpoint"},
            ],
            "vector_search_index": [{"name": "rag.studio_bugbash.databricks_docs_index"}],
            "sql_warehouse": [{"name": "testid"}],
            "function": [
                {"name": "rag.studio.test_function_a"},
                {"name": "rag.studio.test_function_b"},
            ],
            "genie_space": [
                {"name": "genie_space_id_1"},
                {"name": "genie_space_id_2"},
            ],
            "uc_connection": [{"name": "test_connection_1"}, {"name": "test_connection_2"}],
            "table": [{"name": "rag.studio.table_a"}, {"name": "rag.studio.table_b"}],
            "app": [{"name": "test_databricks_app"}],
            "lakebase": [{"name": "test_databricks_lakebase"}],
        },
    }
    # First log: resources supplied as Python resource objects.
    with mlflow.start_run() as run:
        mlflow.pyfunc.log_model(
            name=pyfunc_artifact_path,
            python_model=mlflow.pyfunc.model.PythonModel(),
            resources=[
                DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"),
                DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
                DatabricksServingEndpoint(endpoint_name="azure-eastus-model-serving-2_vs_endpoint"),
                DatabricksVectorSearchIndex(index_name="rag.studio_bugbash.databricks_docs_index"),
                DatabricksSQLWarehouse(warehouse_id="testid"),
                DatabricksFunction(function_name="rag.studio.test_function_a"),
                DatabricksFunction(function_name="rag.studio.test_function_b"),
                DatabricksGenieSpace(genie_space_id="genie_space_id_1"),
                DatabricksGenieSpace(genie_space_id="genie_space_id_2"),
                DatabricksUCConnection(connection_name="test_connection_1"),
                DatabricksUCConnection(connection_name="test_connection_2"),
                DatabricksTable(table_name="rag.studio.table_a"),
                DatabricksTable(table_name="rag.studio.table_b"),
                DatabricksApp(app_name="test_databricks_app"),
                DatabricksLakebase(database_instance_name="test_databricks_lakebase"),
            ],
        )
    # Download the logged artifacts and verify the serialized resources.
    pyfunc_model_uri = f"runs:/{run.info.run_id}/{pyfunc_artifact_path}"
    pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_uri)
    reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
    assert reloaded_model.resources == expected_resources

    # Second log: the same resources via a YAML file must serialize identically.
    yaml_file = tmp_path.joinpath("resources.yaml")
    with open(yaml_file, "w") as f:
        f.write(
            """
            api_version: "1"
            databricks:
                vector_search_index:
                - name: rag.studio_bugbash.databricks_docs_index
                serving_endpoint:
                - name: databricks-mixtral-8x7b-instruct
                - name: databricks-bge-large-en
                - name: azure-eastus-model-serving-2_vs_endpoint
                sql_warehouse:
                - name: testid
                function:
                - name: rag.studio.test_function_a
                - name: rag.studio.test_function_b
                lakebase:
                - name: test_databricks_lakebase
                genie_space:
                - name: genie_space_id_1
                - name: genie_space_id_2
                uc_connection:
                - name: test_connection_1
                - name: test_connection_2
                table:
                - name: "rag.studio.table_a"
                - name: "rag.studio.table_b"
                app:
                - name: test_databricks_app
            """
        )

    with mlflow.start_run() as run:
        mlflow.pyfunc.log_model(
            name=pyfunc_artifact_path,
            python_model=mlflow.pyfunc.model.PythonModel(),
            resources=yaml_file,
        )
    pyfunc_model_uri = f"runs:/{run.info.run_id}/{pyfunc_artifact_path}"
    pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_uri)
    reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
    assert reloaded_model.resources == expected_resources
1966  
1967  
def test_pyfunc_as_code_log_and_load():
    """Logging a model-as-code file and loading it back should run its predict."""
    with mlflow.start_run():
        info = mlflow.pyfunc.log_model(
            name="model",
            python_model="tests/pyfunc/sample_code/python_model.py",
        )

    model = mlflow.pyfunc.load_model(info.model_uri)
    payload = "asdf"
    assert model.predict(payload) == f"This was the input: {payload}"
1979  
1980  
def test_pyfunc_as_code_log_and_load_with_path():
    """`python_model` may be passed as a pathlib.Path, not only a str."""
    with mlflow.start_run():
        info = mlflow.pyfunc.log_model(
            name="model",
            python_model=Path("tests/pyfunc/sample_code/python_model.py"),
        )

    model = mlflow.pyfunc.load_model(info.model_uri)
    payload = "asdf"
    assert model.predict(payload) == f"This was the input: {payload}"
1992  
1993  
def test_pyfunc_as_code_with_config(tmp_path):
    """A model_config YAML path (given as str) should be visible to the model code."""
    config_path = tmp_path / "config.yml"
    config_path.write_text("timeout: 400")

    with mlflow.start_run():
        info = mlflow.pyfunc.log_model(
            name="model",
            python_model="tests/pyfunc/sample_code/python_model_with_config.py",
            model_config=str(config_path),
        )

    model = mlflow.pyfunc.load_model(info.model_uri)
    payload = "input"
    assert model.predict(payload) == f"Predict called with input {payload}, timeout 400"
2009  
2010  
def test_pyfunc_as_code_with_path_config(tmp_path):
    """model_config may also be given as a pathlib.Path to a YAML file."""
    config_path = tmp_path / "config.yml"
    config_path.write_text("timeout: 400")

    with mlflow.start_run():
        info = mlflow.pyfunc.log_model(
            name="model",
            python_model="tests/pyfunc/sample_code/python_model_with_config.py",
            model_config=config_path,
        )

    model = mlflow.pyfunc.load_model(info.model_uri)
    payload = "input"
    assert model.predict(payload) == f"Predict called with input {payload}, timeout 400"
2026  
2027  
def test_pyfunc_as_code_with_dict_config():
    """model_config may be passed inline as a dict."""
    with mlflow.start_run():
        info = mlflow.pyfunc.log_model(
            name="model",
            python_model="tests/pyfunc/sample_code/python_model_with_config.py",
            model_config={"timeout": 400},
        )

    model = mlflow.pyfunc.load_model(info.model_uri)
    payload = "input"
    assert model.predict(payload) == f"Predict called with input {payload}, timeout 400"
2040  
2041  
def test_pyfunc_as_code_log_and_load_with_code_paths():
    """Modules listed in code_paths should be importable from model-as-code."""
    with mlflow.start_run():
        info = mlflow.pyfunc.log_model(
            name="model",
            python_model="tests/pyfunc/sample_code/python_model_with_utils.py",
            code_paths=["tests/pyfunc/sample_code/utils.py"],
        )

    model = mlflow.pyfunc.load_model(info.model_uri)
    payload = "asdf"
    assert model.predict(payload) == f"My utils function received this input: {payload}"
2054  
2055  
def test_pyfunc_as_code_with_dependencies():
    """Dependencies schemas declared in model code should land in MLmodel metadata."""
    with mlflow.start_run():
        info = mlflow.pyfunc.log_model(
            name="model",
            python_model="tests/pyfunc/sample_code/code_with_dependencies.py",
            pip_requirements=["pandas"],
        )

    model = mlflow.pyfunc.load_model(info.model_uri)
    payload = "user_123"
    assert (
        model.predict(payload)
        == f"Input: {payload}. Retriever called with ID: {payload}. Output: 42."
    )

    local_path = _download_artifact_from_uri(info.model_uri)
    reloaded = Model.load(os.path.join(local_path, "MLmodel"))
    # The retriever schema declared inside the model code is persisted as metadata.
    assert reloaded.metadata["dependencies_schemas"] == {
        "retrievers": [
            {
                "doc_uri": "doc-uri",
                "name": "retriever",
                "other_columns": ["column1", "column2"],
                "primary_key": "primary-key",
                "text_column": "text-column",
            }
        ]
    }
2082  
2083  
@pytest.mark.parametrize("is_in_db_model_serving", ["true", "false"])
@pytest.mark.parametrize("stream", [True, False])
def test_pyfunc_as_code_with_dependencies_store_dependencies_schemas_in_trace(
    monkeypatch, is_in_db_model_serving, stream
):
    """Dependencies schemas should be attached to traces both in and out of
    Databricks model serving, for both predict and predict_stream."""
    monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", is_in_db_model_serving)
    monkeypatch.setenv("ENABLE_MLFLOW_TRACING", "true")
    in_serving = is_in_db_model_serving == "true"
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model="tests/pyfunc/sample_code/code_with_dependencies.py",
            pip_requirements=["pandas"],
        )

    loaded = mlflow.pyfunc.load_model(model_info.model_uri)
    payload = "user_123"
    expected = f"Input: {payload}. Retriever called with ID: {payload}. Output: 42."

    def _predict_once():
        # For the streaming variant, take the first (and only) yielded chunk.
        if stream:
            return list(loaded.predict_stream(payload))[0]
        return loaded.predict(payload)

    if in_serving:
        # Model serving associates the trace with an explicit request id.
        with set_prediction_context(Context(request_id="1234")):
            assert _predict_once() == expected
    else:
        assert _predict_once() == expected

    local_path = _download_artifact_from_uri(model_info.model_uri)
    reloaded = Model.load(os.path.join(local_path, "MLmodel"))
    expected_dependencies_schemas = {
        DependenciesSchemasType.RETRIEVERS.value: [
            {
                "doc_uri": "doc-uri",
                "name": "retriever",
                "other_columns": ["column1", "column2"],
                "primary_key": "primary-key",
                "text_column": "text-column",
            }
        ]
    }
    assert reloaded.metadata["dependencies_schemas"] == expected_dependencies_schemas

    if in_serving:
        # In serving, the trace is buffered per-request and popped by request id.
        trace = Trace.from_dict(pop_trace("1234"))
        assert trace.info.trace_id.startswith("tr-")
        assert trace.info.client_request_id == "1234"
    else:
        trace = get_traces()[0]
    assert trace.info.tags[DependenciesSchemasType.RETRIEVERS.value] == json.dumps(
        expected_dependencies_schemas[DependenciesSchemasType.RETRIEVERS.value]
    )
2138  
2139  
@pytest.mark.parametrize("stream", [True, False])
def test_no_traces_collected_for_pyfunc_as_code_with_dependencies_if_no_tracing_enabled(
    monkeypatch, stream
):
    """When the model code is built without tracing, predictions emit no traces."""
    # This sets model without trace inside code_with_dependencies.py file
    monkeypatch.setenv("TEST_TRACE", "false")
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model="tests/pyfunc/sample_code/code_with_dependencies.py",
            pip_requirements=["pandas"],
        )

    loaded = mlflow.pyfunc.load_model(model_info.model_uri)
    payload = "user_123"
    expected = f"Input: {payload}. Retriever called with ID: {payload}. Output: 42."
    result = next(loaded.predict_stream(payload)) if stream else loaded.predict(payload)
    assert result == expected

    local_path = _download_artifact_from_uri(model_info.model_uri)
    reloaded = Model.load(os.path.join(local_path, "MLmodel"))
    # The dependencies schemas are still recorded in model metadata...
    expected_dependencies_schemas = {
        DependenciesSchemasType.RETRIEVERS.value: [
            {
                "doc_uri": "doc-uri",
                "name": "retriever",
                "other_columns": ["column1", "column2"],
                "primary_key": "primary-key",
                "text_column": "text-column",
            }
        ]
    }
    assert reloaded.metadata["dependencies_schemas"] == expected_dependencies_schemas

    # ...but no traces are logged at all.
    assert len(get_traces()) == 0
2179  
2180  
def test_pyfunc_as_code_log_and_load_wrong_path():
    """Logging model-as-code with a nonexistent path should raise MlflowException."""
    with pytest.raises(MlflowException, match="The provided model path"):
        with mlflow.start_run():
            mlflow.pyfunc.log_model(name="model", python_model="asdf")
2191  
2192  
def test_predict_as_code():
    """A plain predict function logged as code should round-trip through pyfunc."""
    with mlflow.start_run():
        info = mlflow.pyfunc.log_model(
            name="model",
            python_model="tests/pyfunc/sample_code/func_code.py",
            input_example=["string"],
        )

    model = mlflow.pyfunc.load_model(info.model_uri)
    payload = "asdf"
    # Without type hints the output is returned as a DataFrame.
    pandas.testing.assert_frame_equal(model.predict([payload]), pd.DataFrame([payload]))
2205  
2206  
def test_predict_as_code_with_type_hint():
    """With type hints, predict-function output keeps its annotated Python type."""
    with mlflow.start_run():
        info = mlflow.pyfunc.log_model(
            name="model",
            python_model="tests/pyfunc/sample_code/func_code_with_type_hint.py",
            input_example=["string"],
        )

    model = mlflow.pyfunc.load_model(info.model_uri)
    payload = "asdf"
    assert model.predict([payload]) == [payload]
2219  
2220  
def test_predict_as_code_with_config():
    """Values from the `model_config` YAML file are available to a code-path model."""
    with mlflow.start_run():
        info = mlflow.pyfunc.log_model(
            name="model",
            python_model="tests/pyfunc/sample_code/func_code_with_config.py",
            input_example=["string"],
            model_config="tests/pyfunc/sample_code/config.yml",
        )

    loaded = mlflow.pyfunc.load_model(info.model_uri)
    payload = "asdf"
    # "timeout 300" comes from config.yml and proves the config was loaded.
    assert loaded.predict([payload]) == f"This was the input: {payload}, timeout 300"
2234  
2235  
def test_model_as_code_pycache_cleaned_up():
    """No __pycache__ directories leak into the logged model artifacts."""
    with mlflow.start_run():
        info = mlflow.pyfunc.log_model(
            name="model",
            python_model="tests/pyfunc/sample_code/python_model.py",
        )

    artifact_dir = _download_artifact_from_uri(info.model_uri)
    assert not list(Path(artifact_dir).rglob("__pycache__"))
2245  
2246  
def test_model_pip_requirements_pin_numpy_when_pandas_included():
    """numpy is pinned alongside pandas only when the installed pandas is < 2.1.2.

    Three scenarios are exercised: modern pandas (no numpy pin), mocked old
    pandas (numpy pin added), and no input example (pandas never inferred).
    """

    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            # Importing pandas inside predict is what causes requirement
            # inference to record pandas as a dependency when the input
            # example is evaluated during logging.
            import pandas as pd  # noqa: F401

            return model_input

    expected_mlflow_version = _mlflow_major_version_string()

    # no numpy when pandas > 2.1.2
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=TestModel(), input_example="abc"
        )

        # strict=True: the requirement list must match exactly — no numpy pin.
        _assert_pip_requirements(
            model_info.model_uri,
            [
                expected_mlflow_version,
                f"cloudpickle=={importlib.metadata.version('cloudpickle')}",
                f"pandas=={importlib.metadata.version('pandas')}",
            ],
            strict=True,
        )

    original_get_installed_version = _get_installed_version

    # Pretend pandas 2.1.0 is installed; delegate every other lookup so
    # the rest of requirement inference behaves normally.
    def mock_get_installed_version(package, module=None):
        if package == "pandas":
            return "2.1.0"
        return original_get_installed_version(package, module)

    # include numpy when pandas < 2.1.2
    with (
        mlflow.start_run(),
        mock.patch(
            "mlflow.utils.requirements_utils._get_installed_version",
            side_effect=mock_get_installed_version,
        ),
    ):
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=TestModel(), input_example="abc"
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [
                expected_mlflow_version,
                "pandas==2.1.0",
                f"numpy=={np.__version__}",
                f"cloudpickle=={cloudpickle.__version__}",
            ],
            strict=True,
        )

    # no input_example, so pandas not included in requirements
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(name="model", python_model=TestModel())
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, f"cloudpickle=={cloudpickle.__version__}"],
            strict=True,
        )
2309  
2310  
def test_environment_variables_used_during_model_logging(monkeypatch):
    """Env vars read during the input-example predict are recorded on the model.

    Recording only happens when predict runs (i.e. an input example is given),
    and can be disabled via MLFLOW_RECORD_ENV_VARS_IN_MODEL_LOGGING.
    """

    class MyModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            # Set the vars from inside predict so they exist exactly when
            # mlflow observes environment access during logging.
            monkeypatch.setenv("TEST_API_KEY", "test_env")
            monkeypatch.setenv("ANOTHER_API_KEY", "test_env")
            monkeypatch.setenv("INVALID_ENV_VAR", "var")
            # existing env var is tracked
            os.environ["TEST_API_KEY"]
            # existing env var fetched by getenv is tracked
            os.environ.get("ANOTHER_API_KEY")
            # existing env var not in allowlist is not tracked
            os.environ.get("INVALID_ENV_VAR")
            # non-existing env var is not tracked
            os.environ.get("INVALID_API_KEY")
            return model_input

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=MyModel(), input_example="data"
        )
    assert "TEST_API_KEY" in model_info.env_vars
    assert "ANOTHER_API_KEY" in model_info.env_vars
    assert "INVALID_ENV_VAR" not in model_info.env_vars
    assert "INVALID_API_KEY" not in model_info.env_vars
    # The recorded env vars survive a round trip through model metadata.
    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert pyfunc_model.metadata.env_vars == model_info.env_vars

    # if no input_example provided, we do not run predict, and no env vars are captured
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(name="model", python_model=MyModel())
    assert model_info.env_vars is None

    # disable logging by setting environment variable
    monkeypatch.setenv(MLFLOW_RECORD_ENV_VARS_IN_MODEL_LOGGING.name, "false")
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=MyModel(), input_example="data"
        )
    assert model_info.env_vars is None
2350  
2351  
def test_pyfunc_model_without_context_in_predict():
    """PythonModel subclasses may omit `context` in predict/predict_stream."""

    class Model(mlflow.pyfunc.PythonModel):
        def predict(self, model_input, params=None):
            return model_input

        def predict_stream(self, model_input, params=None):
            yield model_input

    # Direct calls work without a context argument.
    model = Model()
    assert model.predict("abc") == "abc"
    assert next(iter(model.predict_stream("abc"))) == "abc"

    # And so does the logged/loaded pyfunc wrapper.
    with mlflow.start_run():
        info = mlflow.pyfunc.log_model(name="model", python_model=model, input_example="abc")
    loaded = mlflow.pyfunc.load_model(info.model_uri)
    assert loaded.predict("abc") is not None
    assert next(iter(loaded.predict_stream("abc"))) is not None
2369  
2370  
def test_callable_python_model_without_context_in_predict():
    """A plain callable taking only `model_input` can be logged as a pyfunc model."""

    def predict(model_input):
        return model_input

    assert predict("abc") == "abc"
    with mlflow.start_run():
        info = mlflow.pyfunc.log_model(name="model", python_model=predict, input_example="abc")
    loaded = mlflow.pyfunc.load_model(info.model_uri)
    assert loaded.predict("abc") is not None
2382  
2383  
def test_pyfunc_model_with_wrong_predict_signature_warning():
    """Unexpected parameter names in predict/predict_stream raise a FutureWarning.

    The warning fires at class-definition time, when the method signatures are
    inspected — hence the class bodies live inside the pytest.warns blocks.
    """
    with pytest.warns(
        FutureWarning,
        match=r"Model's `predict` method contains invalid parameters: {'messages'}",
    ):

        class BadPredict(mlflow.pyfunc.PythonModel):
            def predict(self, context, messages, params=None):
                return messages

    with pytest.warns(
        FutureWarning,
        match=r"Model's `predict_stream` method contains invalid parameters: {'_'}",
    ):

        class BadPredictStream(mlflow.pyfunc.PythonModel):
            def predict(self, model_input, params=None):
                return model_input

            def predict_stream(self, _, model_input, params=None):
                yield model_input
2405  
2406  
def test_pyfunc_model_input_example_with_signature():
    """Signature-only logging warns; an input example mismatching the signature raises."""

    class Model(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            return model_input

    signature = infer_signature(["a", "b", "c"])

    with (
        mlflow.start_run(),
        pytest.warns(
            UserWarning, match=r"An input example was not provided when logging the model"
        ),
    ):
        mlflow.pyfunc.log_model(name="model", python_model=Model(), signature=signature)

    with (
        mlflow.start_run(),
        pytest.raises(MlflowException, match=r"Input example does not match the model signature"),
    ):
        mlflow.pyfunc.log_model(
            name="model", python_model=Model(), signature=signature, input_example=123
        )
2426  
2427  
@pytest.mark.parametrize("save_model", [True, False])
@pytest.mark.parametrize("use_user_auth_policy", [True, False])
@pytest.mark.parametrize("use_system_policy", [True, False])
def test_model_log_with_auth_policy(tmp_path, save_model, use_user_auth_policy, use_system_policy):
    """An AuthPolicy survives save/log and reload in its serialized dict form.

    All four combinations of system/user policy presence are exercised, via
    both save_model (local path) and log_model (run artifact).
    """
    pyfunc_save_artifact_path = os.path.join(tmp_path, "pyfunc_model_save")
    pyfunc_log_artifact_path = "pyfunc_model_log"

    # Empty sub-dicts are the serialized form of "no policy provided".
    expected_auth_policy = {"system_auth_policy": {}, "user_auth_policy": {}}

    system_auth_policy = None
    if use_system_policy:
        system_auth_policy = SystemAuthPolicy(
            resources=[
                DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"),
                DatabricksVectorSearchIndex(index_name="rag.studio_bugbash.databricks_docs_index"),
                DatabricksFunction(function_name="rag.studio.test_function_a"),
                DatabricksUCConnection(connection_name="test_connection_1"),
            ]
        )
        # Expected serialization: resources grouped by type under "databricks".
        expected_auth_policy["system_auth_policy"] = {
            "resources": {
                "api_version": "1",
                "databricks": {
                    "function": [{"name": "rag.studio.test_function_a"}],
                    "serving_endpoint": [{"name": "databricks-mixtral-8x7b-instruct"}],
                    "uc_connection": [{"name": "test_connection_1"}],
                    "vector_search_index": [{"name": "rag.studio_bugbash.databricks_docs_index"}],
                },
            }
        }

    user_auth_policy = None
    if use_user_auth_policy:
        user_auth_policy = UserAuthPolicy(
            api_scopes=[
                "catalog.catalogs",
                "vectorsearch.vector-search-indexes",
                "workspace.workspace",
            ]
        )
        expected_auth_policy["user_auth_policy"] = {
            "api_scopes": [
                "catalog.catalogs",
                "vectorsearch.vector-search-indexes",
                "workspace.workspace",
            ]
        }

    auth_policy = AuthPolicy(
        user_auth_policy=user_auth_policy, system_auth_policy=system_auth_policy
    )

    if save_model:
        # save_model path: write directly to disk, reload from the directory.
        mlflow.pyfunc.save_model(
            path=pyfunc_save_artifact_path,
            conda_env=_conda_env(),
            python_model=mlflow.pyfunc.model.PythonModel(),
            auth_policy=auth_policy,
        )
        reloaded_model = Model.load(pyfunc_save_artifact_path)
    else:
        # log_model path: log to a run, then download and reload the MLmodel file.
        with mlflow.start_run() as run:
            mlflow.pyfunc.log_model(
                name=pyfunc_log_artifact_path,
                python_model=mlflow.pyfunc.model.PythonModel(),
                auth_policy=auth_policy,
            )

        pyfunc_model_uri = f"runs:/{run.info.run_id}/{pyfunc_log_artifact_path}"
        pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_uri)
        reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))

    assert reloaded_model.auth_policy == expected_auth_policy
2501  
2502  
def test_both_resources_and_auth_policy():
    """Passing both `resources` and `auth_policy` to log_model is rejected."""
    auth_policy = AuthPolicy(
        user_auth_policy=UserAuthPolicy(api_scopes=["workspace.workspace"]),
        system_auth_policy=SystemAuthPolicy(
            resources=[
                DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct")
            ]
        ),
    )

    with (
        mlflow.start_run(),
        pytest.raises(
            ValueError, match="Only one of `resources`, and `auth_policy` can be specified."
        ),
    ):
        mlflow.pyfunc.log_model(
            name="pyfunc_model_log",
            python_model=mlflow.pyfunc.model.PythonModel(),
            auth_policy=auth_policy,
            resources=[
                DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct")
            ],
        )
2525  
2526  
@pytest.mark.parametrize("compression", ["lzma", "bzip2", "gzip"])
def test_model_save_load_compression(
    monkeypatch, sklearn_knn_model, main_scoped_model_class, iris_data, tmp_path, compression
):
    """Models saved under each supported compression codec load and predict correctly."""
    monkeypatch.setenv(MLFLOW_LOG_MODEL_COMPRESSION.name, compression)

    sklearn_model_path = os.path.join(tmp_path, "sklearn_model")
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path)

    def doubled_predict(sk_model, model_input):
        # Doubling lets us verify the custom predict_fn actually ran.
        return sk_model.predict(model_input) * 2

    pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model")
    mlflow.pyfunc.save_model(
        path=pyfunc_model_path,
        artifacts={"sk_model": sklearn_model_path},
        conda_env=_conda_env(),
        python_model=main_scoped_model_class(doubled_predict),
    )

    loaded = mlflow.pyfunc.load_model(model_uri=pyfunc_model_path)
    np.testing.assert_array_equal(
        loaded.predict(iris_data[0]),
        doubled_predict(sk_model=sklearn_knn_model, model_input=iris_data[0]),
    )
2552  
2553  
@pytest.mark.skip(reason="Enable once we re-enable the warning")
def test_load_model_warning():
    """Loading via a `runs:/` URI should emit a deprecation warning (currently disabled)."""

    class Model(mlflow.pyfunc.PythonModel):
        def predict(self, model_input: list[str]):
            return model_input

    with mlflow.start_run() as run:
        mlflow.pyfunc.log_model(
            name="model", python_model=Model(), input_example=["a", "b", "c"]
        )

    with pytest.warns(UserWarning, match=r"`runs:/<run_id>/artifact_path` is deprecated"):
        mlflow.pyfunc.load_model(f"runs:/{run.info.run_id}/model")
2569  
2570  
def test_pyfunc_model_traces_link_to_model_id():
    """Each trace produced by a loaded model carries that model's model_id."""

    class TestModel(mlflow.pyfunc.PythonModel):
        @mlflow.trace
        def predict(self, model_input: list[str]) -> list[str]:
            return model_input

    model_infos = []
    for _ in range(3):
        model_infos.append(
            mlflow.pyfunc.log_model(
                name="test_model",
                python_model=TestModel(),
                input_example=["a", "b", "c"],
            )
        )

    for info in model_infos:
        mlflow.pyfunc.load_model(info.model_uri).predict(["a", "b", "c"])

    # get_traces returns newest-first; reverse to align with logging order.
    traces = get_traces()[::-1]
    assert len(traces) == 3
    for trace, info in zip(traces, model_infos):
        assert trace.info.request_metadata[TraceMetadataKey.MODEL_ID] == info.model_id
2594  
2595  
class ExampleModel(mlflow.pyfunc.PythonModel):
    """Minimal identity model shared by the requirement-locking tests below."""

    def predict(self, model_input: list[str]) -> list[str]:
        """Return the input unchanged."""
        return model_input
2599  
2600  
def test_lock_model_requirements(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
    """MLFLOW_LOCK_MODEL_DEPENDENCIES produces pinned, installable dependency files.

    Checks that both requirements.txt and conda.yaml contain the locked header
    and pinned packages, and that pip / conda accept them in dry-run mode.
    """
    monkeypatch.setenv("MLFLOW_LOCK_MODEL_DEPENDENCIES", "true")

    model_info = mlflow.pyfunc.log_model(name="model", python_model=ExampleModel())
    pyfunc_model_path = _download_artifact_from_uri(model_info.model_uri, output_path=tmp_path)
    requirements_txt = next(Path(pyfunc_model_path).rglob("requirements.txt"))
    requirements_txt_contents = requirements_txt.read_text()
    assert "# Locked requirements" in requirements_txt_contents
    assert "mlflow==" in requirements_txt_contents
    assert "packaging==" in requirements_txt_contents
    # Check that pip can install the locked requirements
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "--ignore-installed",
            "--dry-run",
            "--requirement",
            requirements_txt,
        ],
    )
    # Check that conda environment can be created with the locked requirements
    conda_yaml = next(Path(pyfunc_model_path).rglob("conda.yaml"))
    conda_yaml_contents = conda_yaml.read_text()
    assert "# Locked requirements" in conda_yaml_contents
    # Fixed copy-paste bug: this previously re-asserted against
    # requirements_txt_contents instead of the conda.yaml contents.
    assert "mlflow==" in conda_yaml_contents
    assert "packaging==" in conda_yaml_contents
    subprocess.check_call(
        [
            "conda",
            "env",
            "create",
            "--file",
            conda_yaml,
            "--dry-run",
            "--yes",
        ],
    )
2641  
2642  
def test_lock_model_requirements_pip_requirements(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
    """Explicit pip_requirements are locked transitively (openai pulls in httpx)."""
    monkeypatch.setenv("MLFLOW_LOCK_MODEL_DEPENDENCIES", "true")
    model_info = mlflow.pyfunc.log_model(
        name="model",
        python_model=ExampleModel(),
        pip_requirements=["openai"],
    )
    model_dir = _download_artifact_from_uri(model_info.model_uri, output_path=tmp_path)
    locked = next(Path(model_dir).rglob("requirements.txt")).read_text()
    for expected in ("# Locked requirements", "mlflow==", "openai==", "httpx=="):
        assert expected in locked
2657  
2658  
def test_lock_model_requirements_extra_pip_requirements(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
):
    """extra_pip_requirements are locked transitively, just like pip_requirements."""
    monkeypatch.setenv("MLFLOW_LOCK_MODEL_DEPENDENCIES", "true")
    model_info = mlflow.pyfunc.log_model(
        name="model",
        python_model=ExampleModel(),
        extra_pip_requirements=["openai"],
    )
    model_dir = _download_artifact_from_uri(model_info.model_uri, output_path=tmp_path)
    locked = next(Path(model_dir).rglob("requirements.txt")).read_text()
    for expected in ("# Locked requirements", "mlflow==", "openai==", "httpx=="):
        assert expected in locked
2675  
2676  
def test_lock_model_requirements_constraints(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
    """A `-c` constraints file pins the locked version of a requirement."""
    constraints_file = tmp_path / "constraints.txt"
    constraints_file.write_text("openai==1.82.0")
    monkeypatch.setenv("MLFLOW_LOCK_MODEL_DEPENDENCIES", "true")

    model_info = mlflow.pyfunc.log_model(
        name="model",
        python_model=ExampleModel(),
        pip_requirements=["openai", f"-c {constraints_file}"],
    )
    model_dir = _download_artifact_from_uri(model_info.model_uri, output_path=tmp_path)
    locked = next(Path(model_dir).rglob("requirements.txt")).read_text()
    # openai must be locked at exactly the constrained version.
    for expected in ("# Locked requirements", "mlflow==", "openai==1.82.0", "httpx=="):
        assert expected in locked
2693  
2694  
@pytest.mark.parametrize(
    ("input_example", "expected_result"), [(["Hello", "World"], True), (None, False)]
)
def test_load_context_with_input_example(input_example, expected_result):
    """A load_context failure surfaces as a warning only when an input example runs predict."""

    class MyModel(mlflow.pyfunc.PythonModel):
        def load_context(self, context):
            raise Exception("load_context was called")

        def predict(self, model_input: list[str], params=None):
            return model_input

    expected_msg = "Failed to run the predict function on input example"

    with mock.patch("mlflow.models.signature._logger.warning") as mock_warning:
        mlflow.pyfunc.log_model(
            name="model",
            python_model=MyModel(),
            input_example=input_example,
        )
        warned = any(expected_msg in call.args[0] for call in mock_warning.call_args_list)
        assert warned == expected_result