# test_model_export_with_class_and_artifacts.py
1 from __future__ import annotations 2 3 import importlib.metadata 4 import json 5 import os 6 import subprocess 7 import sys 8 import types 9 import uuid 10 from pathlib import Path 11 from subprocess import PIPE, Popen 12 from typing import Any, Dict, List 13 from unittest import mock 14 15 import cloudpickle 16 import numpy as np 17 import pandas as pd 18 import pandas.testing 19 import pytest 20 import sklearn 21 import sklearn.datasets 22 import sklearn.linear_model 23 import sklearn.neighbors 24 import yaml 25 26 import mlflow 27 import mlflow.pyfunc 28 import mlflow.pyfunc.model 29 import mlflow.pyfunc.scoring_server as pyfunc_scoring_server 30 import mlflow.sklearn 31 from mlflow.entities import Trace 32 from mlflow.environment_variables import ( 33 MLFLOW_ALLOW_PICKLE_DESERIALIZATION, 34 MLFLOW_LOG_MODEL_COMPRESSION, 35 MLFLOW_RECORD_ENV_VARS_IN_MODEL_LOGGING, 36 ) 37 from mlflow.exceptions import MlflowException 38 from mlflow.models import Model, infer_signature 39 from mlflow.models.auth_policy import AuthPolicy, SystemAuthPolicy, UserAuthPolicy 40 from mlflow.models.dependencies_schemas import DependenciesSchemasType 41 from mlflow.models.model import _DATABRICKS_FS_LOADER_MODULE 42 from mlflow.models.resources import ( 43 DatabricksApp, 44 DatabricksFunction, 45 DatabricksGenieSpace, 46 DatabricksLakebase, 47 DatabricksServingEndpoint, 48 DatabricksSQLWarehouse, 49 DatabricksTable, 50 DatabricksUCConnection, 51 DatabricksVectorSearchIndex, 52 ) 53 from mlflow.models.utils import _read_example 54 from mlflow.pyfunc.context import Context, set_prediction_context 55 from mlflow.pyfunc.model import _load_pyfunc 56 from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository 57 from mlflow.tracing.constant import TraceMetadataKey 58 from mlflow.tracing.export.inference_table import pop_trace 59 from mlflow.tracking.artifact_utils import ( 60 _download_artifact_from_uri, 61 ) 62 from mlflow.types.schema import ColSpec, Map, Schema 63 from 
mlflow.types.type_hints import _infer_schema_from_list_type_hint 64 from mlflow.utils.environment import _mlflow_conda_env 65 from mlflow.utils.file_utils import TempDir 66 from mlflow.utils.model_utils import _get_flavor_configuration 67 from mlflow.utils.requirements_utils import _get_installed_version 68 69 import tests 70 from tests.helper_functions import ( 71 _assert_pip_requirements, 72 _compare_conda_env_requirements, 73 _mlflow_major_version_string, 74 assert_register_model_called_with_local_model_path, 75 pyfunc_serve_and_score_model, 76 ) 77 from tests.tracing.helper import get_traces 78 79 80 def get_model_class(): 81 """ 82 Defines a custom Python model class that wraps a scikit-learn estimator. 83 This can be invoked within a pytest fixture to define the class in the ``__main__`` scope. 84 Alternatively, it can be invoked within a module to define the class in the module's scope. 85 """ 86 87 class CustomSklearnModel(mlflow.pyfunc.PythonModel): 88 def __init__(self, predict_fn): 89 self.predict_fn = predict_fn 90 91 def load_context(self, context): 92 super().load_context(context) 93 94 self.model = ( 95 mlflow.sklearn.load_model(model_uri=context.artifacts["sk_model"]) 96 if context.artifacts and "sk_model" in context.artifacts 97 else None 98 ) 99 100 def predict(self, context, model_input, params=None): 101 return self.predict_fn(self.model, model_input) 102 103 return CustomSklearnModel 104 105 106 class ModuleScopedSklearnModel(get_model_class()): 107 """ 108 A custom Python model class defined in the test module scope. 109 """ 110 111 112 @pytest.fixture(scope="module") 113 def main_scoped_model_class(): 114 """ 115 A custom Python model class defined in the ``__main__`` scope. 
116 """ 117 return get_model_class() 118 119 120 @pytest.fixture(scope="module") 121 def iris_data(): 122 iris = sklearn.datasets.load_iris() 123 x = iris.data[:, :2] 124 y = iris.target 125 return x, y 126 127 128 @pytest.fixture(scope="module") 129 def sklearn_knn_model(iris_data): 130 x, y = iris_data 131 knn_model = sklearn.neighbors.KNeighborsClassifier() 132 knn_model.fit(x, y) 133 return knn_model 134 135 136 @pytest.fixture(scope="module") 137 def sklearn_logreg_model(iris_data): 138 x, y = iris_data 139 linear_lr = sklearn.linear_model.LogisticRegression() 140 linear_lr.fit(x, y) 141 return linear_lr 142 143 144 @pytest.fixture 145 def model_path(tmp_path): 146 return os.path.join(tmp_path, "model") 147 148 149 @pytest.fixture 150 def pyfunc_custom_env(tmp_path): 151 conda_env = os.path.join(tmp_path, "conda_env.yml") 152 _mlflow_conda_env( 153 conda_env, 154 additional_pip_deps=["scikit-learn", "pytest", "cloudpickle"], 155 ) 156 return conda_env 157 158 159 def _conda_env(): 160 # NB: We need mlflow as a dependency in the environment. 
161 return _mlflow_conda_env( 162 additional_pip_deps=[ 163 f"cloudpickle=={cloudpickle.__version__}", 164 f"scikit-learn=={sklearn.__version__}", 165 ], 166 ) 167 168 169 def test_model_save_load(sklearn_knn_model, main_scoped_model_class, iris_data, tmp_path): 170 sklearn_model_path = os.path.join(tmp_path, "sklearn_model") 171 mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path) 172 173 def test_predict(sk_model, model_input): 174 return sk_model.predict(model_input) * 2 175 176 pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model") 177 178 mlflow.pyfunc.save_model( 179 path=pyfunc_model_path, 180 artifacts={"sk_model": sklearn_model_path}, 181 conda_env=_conda_env(), 182 python_model=main_scoped_model_class(test_predict), 183 ) 184 185 loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_path) 186 np.testing.assert_array_equal( 187 loaded_pyfunc_model.predict(iris_data[0]), 188 test_predict(sk_model=sklearn_knn_model, model_input=iris_data[0]), 189 ) 190 191 192 @pytest.mark.skip( 193 reason="In MLflow 3.0, `log_model` does not start a run. Consider removing this test." 
194 ) 195 def test_pyfunc_model_log_load_no_active_run(sklearn_knn_model, main_scoped_model_class, iris_data): 196 sklearn_artifact_path = "sk_model_no_run" 197 with mlflow.start_run(): 198 mlflow.sklearn.log_model(sklearn_knn_model, name=sklearn_artifact_path) 199 sklearn_model_uri = f"runs:/{mlflow.active_run().info.run_id}/{sklearn_artifact_path}" 200 201 def test_predict(sk_model, model_input): 202 return sk_model.predict(model_input) * 2 203 204 pyfunc_artifact_path = "pyfunc_model" 205 assert mlflow.active_run() is None 206 mlflow.pyfunc.log_model( 207 name=pyfunc_artifact_path, 208 artifacts={"sk_model": sklearn_model_uri}, 209 python_model=main_scoped_model_class(test_predict), 210 ) 211 pyfunc_model_uri = f"runs:/{mlflow.active_run().info.run_id}/{pyfunc_artifact_path}" 212 loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_uri) 213 np.testing.assert_array_equal( 214 loaded_pyfunc_model.predict(iris_data[0]), 215 test_predict(sk_model=sklearn_knn_model, model_input=iris_data[0]), 216 ) 217 mlflow.end_run() 218 219 220 def test_model_log_load(sklearn_knn_model, main_scoped_model_class, iris_data): 221 sklearn_artifact_path = "sk_model" 222 with mlflow.start_run(): 223 sklearn_model_info = mlflow.sklearn.log_model(sklearn_knn_model, name=sklearn_artifact_path) 224 225 def test_predict(sk_model, model_input): 226 return sk_model.predict(model_input) * 2 227 228 pyfunc_artifact_path = "pyfunc_model" 229 with ( 230 mlflow.start_run(), 231 mock.patch("mlflow.pyfunc._logger.warning") as mock_warning, 232 ): 233 pyfunc_model_info = mlflow.pyfunc.log_model( 234 name=pyfunc_artifact_path, 235 artifacts={"sk_model": sklearn_model_info.model_uri}, 236 python_model=main_scoped_model_class(test_predict), 237 ) 238 pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_info.model_uri) 239 model_config = Model.load(os.path.join(pyfunc_model_path, "MLmodel")) 240 assert "Consider using a file path (str or Path) instead" in 
mock_warning.call_args[0][0] 241 242 loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_info.model_uri) 243 assert model_config.to_yaml() == loaded_pyfunc_model.metadata.to_yaml() 244 np.testing.assert_array_equal( 245 loaded_pyfunc_model.predict(iris_data[0]), 246 test_predict(sk_model=sklearn_knn_model, model_input=iris_data[0]), 247 ) 248 249 250 def test_python_model_predict_compatible_without_params(sklearn_knn_model, iris_data): 251 class CustomSklearnModelWithoutParams(mlflow.pyfunc.PythonModel): 252 def __init__(self, predict_fn): 253 self.predict_fn = predict_fn 254 255 def load_context(self, context): 256 super().load_context(context) 257 258 self.model = mlflow.sklearn.load_model(model_uri=context.artifacts["sk_model"]) 259 260 def predict(self, context, model_input): 261 return self.predict_fn(self.model, model_input) 262 263 sklearn_artifact_path = "sk_model" 264 with mlflow.start_run(): 265 model_info = mlflow.sklearn.log_model(sklearn_knn_model, name=sklearn_artifact_path) 266 sklearn_model_uri = model_info.model_uri 267 268 def test_predict(sk_model, model_input): 269 return sk_model.predict(model_input) * 2 270 271 pyfunc_artifact_path = "pyfunc_model" 272 with mlflow.start_run(): 273 model_info = mlflow.pyfunc.log_model( 274 name=pyfunc_artifact_path, 275 artifacts={"sk_model": sklearn_model_uri}, 276 python_model=CustomSklearnModelWithoutParams(test_predict), 277 ) 278 pyfunc_model_path = _download_artifact_from_uri(model_info.model_uri) 279 model_config = Model.load(os.path.join(pyfunc_model_path, "MLmodel")) 280 281 loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri) 282 assert model_config.to_yaml() == loaded_pyfunc_model.metadata.to_yaml() 283 np.testing.assert_array_equal( 284 loaded_pyfunc_model.predict(iris_data[0]), 285 test_predict(sk_model=sklearn_knn_model, model_input=iris_data[0]), 286 ) 287 288 289 def test_signature_and_examples_are_saved_correctly(iris_data, main_scoped_model_class, 
tmp_path): 290 sklearn_model_path = str(tmp_path.joinpath("sklearn_model")) 291 mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path) 292 293 def test_predict(sk_model, model_input): 294 return sk_model.predict(model_input) * 2 295 296 data = iris_data 297 signature_ = infer_signature(*data) 298 example_ = data[0][:3] 299 for signature in (None, signature_): 300 for example in (None, example_): 301 with TempDir() as tmp: 302 path = tmp.path("model") 303 mlflow.pyfunc.save_model( 304 path=path, 305 artifacts={"sk_model": sklearn_model_path}, 306 python_model=main_scoped_model_class(test_predict), 307 signature=signature, 308 input_example=example, 309 ) 310 mlflow_model = Model.load(path) 311 assert signature == mlflow_model.signature 312 if example is None: 313 assert mlflow_model.saved_input_example_info is None 314 else: 315 np.testing.assert_array_equal(_read_example(mlflow_model, path), example) 316 317 318 class DummyModel(mlflow.pyfunc.PythonModel): 319 def predict(self, context, model_input, params=None): 320 return model_input 321 322 323 def test_log_model_calls_register_model(sklearn_knn_model, main_scoped_model_class): 324 with mlflow.start_run(): 325 with mock.patch( 326 "mlflow.tracking._model_registry.fluent._register_model" 327 ) as register_model_mock: 328 registered_model_name = "AdsModel1" 329 pyfunc_model_info = mlflow.pyfunc.log_model( 330 name="pyfunc_model", 331 python_model=DummyModel(), 332 registered_model_name=registered_model_name, 333 ) 334 assert_register_model_called_with_local_model_path( 335 register_model_mock, pyfunc_model_info.model_uri, registered_model_name 336 ) 337 338 339 def test_log_model_no_registered_model_name(sklearn_knn_model, main_scoped_model_class): 340 with mlflow.start_run(): 341 with mock.patch( 342 "mlflow.tracking._model_registry.fluent._register_model" 343 ) as register_model_mock: 344 mlflow.pyfunc.log_model( 345 name="pyfunc_model", 346 python_model=DummyModel(), 347 ) 348 
register_model_mock.assert_not_called() 349 350 351 def test_model_load_from_remote_uri_succeeds( 352 sklearn_knn_model, main_scoped_model_class, tmp_path, mock_s3_bucket, iris_data 353 ): 354 artifact_root = f"s3://{mock_s3_bucket}" 355 artifact_repo = S3ArtifactRepository(artifact_root) 356 357 sklearn_model_path = os.path.join(tmp_path, "sklearn_model") 358 mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path) 359 sklearn_artifact_path = "sk_model" 360 artifact_repo.log_artifacts(sklearn_model_path, artifact_path=sklearn_artifact_path) 361 362 def test_predict(sk_model, model_input): 363 return sk_model.predict(model_input) * 2 364 365 pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model") 366 mlflow.pyfunc.save_model( 367 path=pyfunc_model_path, 368 artifacts={"sk_model": sklearn_model_path}, 369 python_model=main_scoped_model_class(test_predict), 370 conda_env=_conda_env(), 371 ) 372 373 pyfunc_artifact_path = "pyfunc_model" 374 artifact_repo.log_artifacts(pyfunc_model_path, artifact_path=pyfunc_artifact_path) 375 376 model_uri = artifact_root + "/" + pyfunc_artifact_path 377 loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=model_uri) 378 np.testing.assert_array_equal( 379 loaded_pyfunc_model.predict(iris_data[0]), 380 test_predict(sk_model=sklearn_knn_model, model_input=iris_data[0]), 381 ) 382 383 384 def test_add_to_model_adds_specified_kwargs_to_mlmodel_configuration(): 385 custom_kwargs = { 386 "key1": "value1", 387 "key2": 20, 388 "key3": range(10), 389 } 390 model_config = Model() 391 mlflow.pyfunc.add_to_model( 392 model=model_config, 393 loader_module=os.path.basename(__file__)[:-3], 394 data="data", 395 code="code", 396 env=None, 397 **custom_kwargs, 398 ) 399 400 assert mlflow.pyfunc.FLAVOR_NAME in model_config.flavors 401 assert all(item in model_config.flavors[mlflow.pyfunc.FLAVOR_NAME] for item in custom_kwargs) 402 403 404 def 
test_pyfunc_model_serving_without_conda_env_activation_succeeds_with_main_scoped_class( 405 sklearn_knn_model, main_scoped_model_class, iris_data, tmp_path 406 ): 407 sklearn_model_path = os.path.join(tmp_path, "sklearn_model") 408 mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path) 409 410 def test_predict(sk_model, model_input): 411 return sk_model.predict(model_input) * 2 412 413 pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model") 414 mlflow.pyfunc.save_model( 415 path=pyfunc_model_path, 416 artifacts={"sk_model": sklearn_model_path}, 417 python_model=main_scoped_model_class(test_predict), 418 conda_env=_conda_env(), 419 ) 420 loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_path) 421 422 sample_input = pd.DataFrame(iris_data[0]) 423 scoring_response = pyfunc_serve_and_score_model( 424 model_uri=pyfunc_model_path, 425 data=sample_input, 426 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 427 extra_args=["--env-manager", "local"], 428 ) 429 assert scoring_response.status_code == 200 430 np.testing.assert_array_equal( 431 np.array(json.loads(scoring_response.text)["predictions"]), 432 loaded_pyfunc_model.predict(sample_input), 433 ) 434 435 436 def test_pyfunc_model_serving_with_conda_env_activation_succeeds_with_main_scoped_class( 437 sklearn_knn_model, main_scoped_model_class, iris_data, tmp_path 438 ): 439 sklearn_model_path = os.path.join(tmp_path, "sklearn_model") 440 mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path) 441 442 def test_predict(sk_model, model_input): 443 return sk_model.predict(model_input) * 2 444 445 pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model") 446 mlflow.pyfunc.save_model( 447 path=pyfunc_model_path, 448 artifacts={"sk_model": sklearn_model_path}, 449 python_model=main_scoped_model_class(test_predict), 450 conda_env=_conda_env(), 451 ) 452 loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_path) 453 454 sample_input = 
pd.DataFrame(iris_data[0]) 455 scoring_response = pyfunc_serve_and_score_model( 456 model_uri=pyfunc_model_path, 457 data=sample_input, 458 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 459 ) 460 assert scoring_response.status_code == 200 461 np.testing.assert_array_equal( 462 np.array(json.loads(scoring_response.text)["predictions"]), 463 loaded_pyfunc_model.predict(sample_input), 464 ) 465 466 467 def test_pyfunc_model_serving_without_conda_env_activation_succeeds_with_module_scoped_class( 468 sklearn_knn_model, iris_data, tmp_path 469 ): 470 sklearn_model_path = os.path.join(tmp_path, "sklearn_model") 471 mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path) 472 473 def test_predict(sk_model, model_input): 474 return sk_model.predict(model_input) * 2 475 476 pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model") 477 mlflow.pyfunc.save_model( 478 path=pyfunc_model_path, 479 artifacts={"sk_model": sklearn_model_path}, 480 python_model=ModuleScopedSklearnModel(test_predict), 481 code_paths=[os.path.dirname(tests.__file__)], 482 conda_env=_conda_env(), 483 ) 484 loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_path) 485 486 sample_input = pd.DataFrame(iris_data[0]) 487 scoring_response = pyfunc_serve_and_score_model( 488 model_uri=pyfunc_model_path, 489 data=sample_input, 490 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 491 extra_args=["--env-manager", "local"], 492 ) 493 assert scoring_response.status_code == 200 494 np.testing.assert_array_equal( 495 np.array(json.loads(scoring_response.text)["predictions"]), 496 loaded_pyfunc_model.predict(sample_input), 497 ) 498 499 500 def test_pyfunc_cli_predict_command_without_conda_env_activation_succeeds( 501 sklearn_knn_model, main_scoped_model_class, iris_data, tmp_path 502 ): 503 sklearn_model_path = os.path.join(tmp_path, "sklearn_model") 504 mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path) 505 506 def 
test_predict(sk_model, model_input): 507 return sk_model.predict(model_input) * 2 508 509 pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model") 510 mlflow.pyfunc.save_model( 511 path=pyfunc_model_path, 512 artifacts={"sk_model": sklearn_model_path}, 513 python_model=main_scoped_model_class(test_predict), 514 conda_env=_conda_env(), 515 ) 516 loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_path) 517 518 sample_input = pd.DataFrame(iris_data[0]) 519 input_csv_path = os.path.join(tmp_path, "input with spaces.csv") 520 sample_input.to_csv(input_csv_path, header=True, index=False) 521 output_json_path = os.path.join(tmp_path, "output.json") 522 process = Popen( 523 [ 524 sys.executable, 525 "-m", 526 "mlflow", 527 "models", 528 "predict", 529 "--model-uri", 530 pyfunc_model_path, 531 "-i", 532 input_csv_path, 533 "--content-type", 534 "csv", 535 "-o", 536 output_json_path, 537 "--env-manager", 538 "local", 539 ], 540 stdout=PIPE, 541 stderr=PIPE, 542 preexec_fn=os.setsid, 543 ) 544 _, stderr = process.communicate() 545 assert process.wait() == 0, f"stderr = \n\n{stderr}\n\n" 546 with open(output_json_path) as f: 547 result_df = pd.DataFrame(data=json.load(f)["predictions"]) 548 np.testing.assert_array_equal( 549 result_df.values.transpose()[0], loaded_pyfunc_model.predict(sample_input) 550 ) 551 552 553 def test_pyfunc_cli_predict_command_with_conda_env_activation_succeeds( 554 sklearn_knn_model, main_scoped_model_class, iris_data, tmp_path 555 ): 556 sklearn_model_path = os.path.join(tmp_path, "sklearn_model") 557 mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path) 558 559 def test_predict(sk_model, model_input): 560 return sk_model.predict(model_input) * 2 561 562 pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model") 563 mlflow.pyfunc.save_model( 564 path=pyfunc_model_path, 565 artifacts={"sk_model": sklearn_model_path}, 566 python_model=main_scoped_model_class(test_predict), 567 conda_env=_conda_env(), 568 ) 
569 loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_path) 570 571 sample_input = pd.DataFrame(iris_data[0]) 572 input_csv_path = os.path.join(tmp_path, "input with spaces.csv") 573 sample_input.to_csv(input_csv_path, header=True, index=False) 574 output_json_path = os.path.join(tmp_path, "output.json") 575 process = Popen( 576 [ 577 sys.executable, 578 "-m", 579 "mlflow", 580 "models", 581 "predict", 582 "--model-uri", 583 pyfunc_model_path, 584 "-i", 585 input_csv_path, 586 "--content-type", 587 "csv", 588 "-o", 589 output_json_path, 590 ], 591 stderr=PIPE, 592 stdout=PIPE, 593 preexec_fn=os.setsid, 594 ) 595 stdout, stderr = process.communicate() 596 assert process.wait() == 0, f"stdout = \n\n{stdout}\n\n stderr = \n\n{stderr}\n\n" 597 with open(output_json_path) as f: 598 result_df = pandas.DataFrame(json.load(f)["predictions"]) 599 np.testing.assert_array_equal( 600 result_df.values.transpose()[0], loaded_pyfunc_model.predict(sample_input) 601 ) 602 603 604 def test_save_model_persists_specified_conda_env_in_mlflow_model_directory( 605 sklearn_knn_model, main_scoped_model_class, pyfunc_custom_env, tmp_path 606 ): 607 sklearn_model_path = os.path.join(tmp_path, "sklearn_model") 608 mlflow.sklearn.save_model( 609 sk_model=sklearn_knn_model, 610 path=sklearn_model_path, 611 serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE, 612 ) 613 614 pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model") 615 mlflow.pyfunc.save_model( 616 path=pyfunc_model_path, 617 artifacts={"sk_model": sklearn_model_path}, 618 python_model=main_scoped_model_class(predict_fn=None), 619 conda_env=pyfunc_custom_env, 620 ) 621 622 pyfunc_conf = _get_flavor_configuration( 623 model_path=pyfunc_model_path, flavor_name=mlflow.pyfunc.FLAVOR_NAME 624 ) 625 saved_conda_env_path = os.path.join(pyfunc_model_path, pyfunc_conf[mlflow.pyfunc.ENV]["conda"]) 626 assert os.path.exists(saved_conda_env_path) 627 assert saved_conda_env_path != pyfunc_custom_env 628 
629 with open(pyfunc_custom_env) as f: 630 pyfunc_custom_env_parsed = yaml.safe_load(f) 631 with open(saved_conda_env_path) as f: 632 saved_conda_env_parsed = yaml.safe_load(f) 633 assert saved_conda_env_parsed == pyfunc_custom_env_parsed 634 635 636 def test_save_model_persists_requirements_in_mlflow_model_directory( 637 sklearn_knn_model, main_scoped_model_class, pyfunc_custom_env, tmp_path 638 ): 639 sklearn_model_path = os.path.join(tmp_path, "sklearn_model") 640 mlflow.sklearn.save_model( 641 sk_model=sklearn_knn_model, 642 path=sklearn_model_path, 643 serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE, 644 ) 645 646 pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model") 647 mlflow.pyfunc.save_model( 648 path=pyfunc_model_path, 649 artifacts={"sk_model": sklearn_model_path}, 650 python_model=main_scoped_model_class(predict_fn=None), 651 conda_env=pyfunc_custom_env, 652 ) 653 654 saved_pip_req_path = os.path.join(pyfunc_model_path, "requirements.txt") 655 _compare_conda_env_requirements(pyfunc_custom_env, saved_pip_req_path) 656 657 658 def test_log_model_with_pip_requirements(sklearn_knn_model, main_scoped_model_class, tmp_path): 659 expected_mlflow_version = _mlflow_major_version_string() 660 python_model = main_scoped_model_class(predict_fn=None) 661 sklearn_model_path = os.path.join(tmp_path, "sklearn_model") 662 mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path) 663 # Path to a requirements file 664 req_file = tmp_path.joinpath("requirements.txt") 665 req_file.write_text("a") 666 with mlflow.start_run(): 667 model_info = mlflow.pyfunc.log_model( 668 name="model", 669 python_model=python_model, 670 pip_requirements=str(req_file), 671 artifacts={"sk_model": sklearn_model_path}, 672 ) 673 _assert_pip_requirements( 674 model_info.model_uri, 675 [expected_mlflow_version, "a"], 676 strict=True, 677 ) 678 679 # List of requirements 680 with mlflow.start_run(): 681 model_info = mlflow.pyfunc.log_model( 682 
name="model", 683 python_model=python_model, 684 pip_requirements=[f"-r {req_file}", "b"], 685 artifacts={"sk_model": sklearn_model_path}, 686 ) 687 _assert_pip_requirements( 688 model_info.model_uri, 689 [expected_mlflow_version, "a", "b"], 690 strict=True, 691 ) 692 693 # Constraints file 694 with mlflow.start_run(): 695 model_info = mlflow.pyfunc.log_model( 696 name="model", 697 python_model=python_model, 698 pip_requirements=[f"-c {req_file}", "b"], 699 artifacts={"sk_model": sklearn_model_path}, 700 ) 701 _assert_pip_requirements( 702 model_info.model_uri, 703 [expected_mlflow_version, "b", "-c constraints.txt"], 704 ["a"], 705 strict=True, 706 ) 707 708 709 def test_log_model_with_extra_pip_requirements( 710 sklearn_knn_model, main_scoped_model_class, tmp_path 711 ): 712 expected_mlflow_version = _mlflow_major_version_string() 713 sklearn_model_path = str(tmp_path.joinpath("sklearn_model")) 714 mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path) 715 716 python_model = main_scoped_model_class(predict_fn=None) 717 default_reqs = mlflow.pyfunc.get_default_pip_requirements() 718 719 # Path to a requirements file 720 req_file = tmp_path.joinpath("requirements.txt") 721 req_file.write_text("a") 722 with mlflow.start_run(): 723 model_info = mlflow.pyfunc.log_model( 724 name="model", 725 python_model=python_model, 726 artifacts={"sk_model": sklearn_model_path}, 727 extra_pip_requirements=str(req_file), 728 ) 729 _assert_pip_requirements( 730 model_info.model_uri, 731 [expected_mlflow_version, *default_reqs, "a"], 732 ) 733 734 # List of requirements 735 with mlflow.start_run(): 736 model_info = mlflow.pyfunc.log_model( 737 name="model", 738 artifacts={"sk_model": sklearn_model_path}, 739 python_model=python_model, 740 extra_pip_requirements=[f"-r {req_file}", "b"], 741 ) 742 _assert_pip_requirements( 743 model_info.model_uri, 744 [expected_mlflow_version, *default_reqs, "a", "b"], 745 ) 746 747 # Constraints file 748 with 
mlflow.start_run(): 749 model_info = mlflow.pyfunc.log_model( 750 name="model", 751 artifacts={"sk_model": sklearn_model_path}, 752 python_model=python_model, 753 extra_pip_requirements=[f"-c {req_file}", "b"], 754 ) 755 _assert_pip_requirements( 756 model_info.model_uri, 757 [expected_mlflow_version, *default_reqs, "b", "-c constraints.txt"], 758 ["a"], 759 ) 760 761 762 def test_log_model_persists_specified_conda_env_in_mlflow_model_directory( 763 sklearn_knn_model, main_scoped_model_class, pyfunc_custom_env 764 ): 765 sklearn_artifact_path = "sk_model" 766 with mlflow.start_run(): 767 sklearn_model_info = mlflow.sklearn.log_model(sklearn_knn_model, name=sklearn_artifact_path) 768 769 pyfunc_artifact_path = "pyfunc_model" 770 with mlflow.start_run(): 771 pyfunc_model_info = mlflow.pyfunc.log_model( 772 name=pyfunc_artifact_path, 773 artifacts={"sk_model": sklearn_model_info.model_uri}, 774 python_model=main_scoped_model_class(predict_fn=None), 775 conda_env=pyfunc_custom_env, 776 ) 777 pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_info.model_uri) 778 779 pyfunc_conf = _get_flavor_configuration( 780 model_path=pyfunc_model_path, flavor_name=mlflow.pyfunc.FLAVOR_NAME 781 ) 782 saved_conda_env_path = os.path.join(pyfunc_model_path, pyfunc_conf[mlflow.pyfunc.ENV]["conda"]) 783 assert os.path.exists(saved_conda_env_path) 784 assert saved_conda_env_path != pyfunc_custom_env 785 786 with open(pyfunc_custom_env) as f: 787 pyfunc_custom_env_parsed = yaml.safe_load(f) 788 with open(saved_conda_env_path) as f: 789 saved_conda_env_parsed = yaml.safe_load(f) 790 assert saved_conda_env_parsed == pyfunc_custom_env_parsed 791 792 793 def test_model_log_persists_requirements_in_mlflow_model_directory( 794 sklearn_knn_model, main_scoped_model_class, pyfunc_custom_env 795 ): 796 sklearn_artifact_path = "sk_model" 797 with mlflow.start_run(): 798 sklearn_model_info = mlflow.sklearn.log_model(sklearn_knn_model, name=sklearn_artifact_path) 799 800 pyfunc_artifact_path = 
"pyfunc_model" 801 with mlflow.start_run(): 802 pyfunc_model_info = mlflow.pyfunc.log_model( 803 name=pyfunc_artifact_path, 804 artifacts={"sk_model": sklearn_model_info.model_uri}, 805 python_model=main_scoped_model_class(predict_fn=None), 806 conda_env=pyfunc_custom_env, 807 ) 808 pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_info.model_uri) 809 810 saved_pip_req_path = os.path.join(pyfunc_model_path, "requirements.txt") 811 _compare_conda_env_requirements(pyfunc_custom_env, saved_pip_req_path) 812 813 814 def test_save_model_without_specified_conda_env_uses_default_env_with_expected_dependencies( 815 sklearn_logreg_model, main_scoped_model_class, tmp_path 816 ): 817 sklearn_model_path = os.path.join(tmp_path, "sklearn_model") 818 mlflow.sklearn.save_model(sk_model=sklearn_logreg_model, path=sklearn_model_path) 819 820 pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model") 821 mlflow.pyfunc.save_model( 822 path=pyfunc_model_path, 823 artifacts={"sk_model": sklearn_model_path}, 824 python_model=main_scoped_model_class(predict_fn=None), 825 conda_env=_conda_env(), 826 ) 827 _assert_pip_requirements(pyfunc_model_path, mlflow.pyfunc.get_default_pip_requirements()) 828 829 830 def test_log_model_without_specified_conda_env_uses_default_env_with_expected_dependencies( 831 sklearn_knn_model, main_scoped_model_class 832 ): 833 sklearn_artifact_path = "sk_model" 834 with mlflow.start_run(): 835 sklearn_model_info = mlflow.sklearn.log_model(sklearn_knn_model, name=sklearn_artifact_path) 836 837 pyfunc_artifact_path = "pyfunc_model" 838 with mlflow.start_run(): 839 pyfunc_model_info = mlflow.pyfunc.log_model( 840 name=pyfunc_artifact_path, 841 artifacts={ 842 "sk_model": sklearn_model_info.model_uri, 843 }, 844 python_model=main_scoped_model_class(predict_fn=None), 845 ) 846 _assert_pip_requirements( 847 pyfunc_model_info.model_uri, mlflow.pyfunc.get_default_pip_requirements() 848 ) 849 850 851 def 
test_save_model_correctly_resolves_directory_artifact_with_nested_contents( 852 tmp_path, model_path, iris_data 853 ): 854 directory_artifact_path = os.path.join(tmp_path, "directory_artifact") 855 nested_file_relative_path = os.path.join( 856 "my", "somewhat", "heavily", "nested", "directory", "myfile.txt" 857 ) 858 nested_file_path = os.path.join(directory_artifact_path, nested_file_relative_path) 859 os.makedirs(os.path.dirname(nested_file_path)) 860 nested_file_text = "some sample file text" 861 with open(nested_file_path, "w") as f: 862 f.write(nested_file_text) 863 864 class ArtifactValidationModel(mlflow.pyfunc.PythonModel): 865 def predict(self, context, model_input, params=None): 866 expected_file_path = os.path.join( 867 context.artifacts["testdir"], nested_file_relative_path 868 ) 869 if not os.path.exists(expected_file_path): 870 return False 871 else: 872 with open(expected_file_path) as f: 873 return f.read() == nested_file_text 874 875 mlflow.pyfunc.save_model( 876 path=model_path, 877 artifacts={"testdir": directory_artifact_path}, 878 python_model=ArtifactValidationModel(), 879 conda_env=_conda_env(), 880 ) 881 882 loaded_model = mlflow.pyfunc.load_model(model_uri=model_path) 883 assert loaded_model.predict(iris_data[0]) 884 885 886 def test_save_model_with_no_artifacts_does_not_produce_artifacts_dir(model_path): 887 mlflow.pyfunc.save_model( 888 path=model_path, 889 python_model=ModuleScopedSklearnModel(predict_fn=None), 890 artifacts=None, 891 conda_env=_conda_env(), 892 ) 893 894 assert os.path.exists(model_path) 895 assert "artifacts" not in os.listdir(model_path) 896 pyfunc_conf = _get_flavor_configuration( 897 model_path=model_path, flavor_name=mlflow.pyfunc.FLAVOR_NAME 898 ) 899 assert mlflow.pyfunc.model.CONFIG_KEY_ARTIFACTS not in pyfunc_conf 900 901 902 def test_save_model_with_python_model_argument_of_invalid_type_raises_exception( 903 tmp_path, 904 ): 905 with pytest.raises( 906 MlflowException, 907 match="must be a PythonModel 
instance, callable object, or path to a", 908 ): 909 mlflow.pyfunc.save_model(path=os.path.join(tmp_path, "model1"), python_model=5) 910 911 with pytest.raises( 912 MlflowException, 913 match="must be a PythonModel instance, callable object, or path to a", 914 ): 915 mlflow.pyfunc.save_model( 916 path=os.path.join(tmp_path, "model2"), python_model=["not a python model"] 917 ) 918 with pytest.raises(MlflowException, match="The provided model path"): 919 mlflow.pyfunc.save_model( 920 path=os.path.join(tmp_path, "model3"), python_model="not a valid filepath" 921 ) 922 923 924 def test_save_model_with_unsupported_argument_combinations_throws_exception(model_path): 925 with pytest.raises( 926 MlflowException, 927 match="Either `loader_module` or `python_model` must be specified", 928 ) as exc_info: 929 mlflow.pyfunc.save_model( 930 path=model_path, 931 artifacts={"artifact": "/path/to/artifact"}, 932 python_model=None, 933 ) 934 935 python_model = ModuleScopedSklearnModel(predict_fn=None) 936 loader_module = __name__ 937 with pytest.raises( 938 MlflowException, 939 match="The following sets of parameters cannot be specified together", 940 ) as exc_info: 941 mlflow.pyfunc.save_model( 942 path=model_path, python_model=python_model, loader_module=loader_module 943 ) 944 assert str(python_model) in str(exc_info) 945 assert str(loader_module) in str(exc_info) 946 947 with pytest.raises( 948 MlflowException, 949 match="The following sets of parameters cannot be specified together", 950 ) as exc_info: 951 mlflow.pyfunc.save_model( 952 path=model_path, 953 python_model=python_model, 954 data_path="/path/to/data", 955 artifacts={"artifact": "/path/to/artifact"}, 956 ) 957 958 with pytest.raises( 959 MlflowException, 960 match="Either `loader_module` or `python_model` must be specified", 961 ): 962 mlflow.pyfunc.save_model(path=model_path, python_model=None, loader_module=None) 963 964 965 def test_log_model_with_unsupported_argument_combinations_throws_exception(): 966 match = ( 
        "Either `loader_module` or `python_model` must be specified. A `loader_module` "
        "should be a python module. A `python_model` should be a subclass of "
        "PythonModel"
    )
    # Mirrors the save_model combination checks above, but via log_model
    # inside an active run.
    with mlflow.start_run(), pytest.raises(MlflowException, match=match):
        mlflow.pyfunc.log_model(
            name="pyfunc_model",
            artifacts={"artifact": "/path/to/artifact"},
            python_model=None,
        )

    # python_model together with loader_module is ambiguous; both values must
    # appear in the raised error.
    python_model = ModuleScopedSklearnModel(predict_fn=None)
    loader_module = __name__
    with (
        mlflow.start_run(),
        pytest.raises(
            MlflowException,
            match="The following sets of parameters cannot be specified together",
        ) as exc_info,
    ):
        mlflow.pyfunc.log_model(
            name="pyfunc_model",
            python_model=python_model,
            loader_module=loader_module,
        )
    assert str(python_model) in str(exc_info)
    assert str(loader_module) in str(exc_info)

    # data_path plus artifacts alongside a python_model is also rejected.
    with (
        mlflow.start_run(),
        pytest.raises(
            MlflowException,
            match="The following sets of parameters cannot be specified together",
        ) as exc_info,
    ):
        mlflow.pyfunc.log_model(
            name="pyfunc_model",
            python_model=python_model,
            data_path="/path/to/data",
            artifacts={"artifact1": "/path/to/artifact"},
        )

    # Neither loader_module nor python_model: invalid.
    with (
        mlflow.start_run(),
        pytest.raises(
            MlflowException,
            match="Either `loader_module` or `python_model` must be specified",
        ),
    ):
        mlflow.pyfunc.log_model(name="pyfunc_model", python_model=None, loader_module=None)


def test_repr_can_be_called_without_run_id_or_artifact_path():
    """PyFuncModel.__repr__ must work when the model metadata has no run id
    or artifact path, and should include the flavor's loader module."""
    model_meta = Model(
        artifact_path=None,
        run_id=None,
        flavors={"python_function": {"loader_module": "someFlavour"}},
    )

    class TestModel:
        def predict(self, model_input, params=None):
            return model_input

    model_impl = TestModel()

    assert "flavor: someFlavour" in mlflow.pyfunc.PyFuncModel(model_meta, model_impl).__repr__()


def
test_load_model_with_differing_cloudpickle_version_at_micro_granularity_logs_warning(
    model_path,
):
    """Loading under a cloudpickle version that differs from the saved one
    (even at patch level) must emit a warning naming both versions."""

    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            return model_input

    mlflow.pyfunc.save_model(path=model_path, python_model=TestModel())
    # Overwrite the recorded cloudpickle version in the MLmodel config so it
    # disagrees with the (mocked) runtime version below.
    saver_cloudpickle_version = "0.5.8"
    model_config_path = os.path.join(model_path, "MLmodel")
    model_config = Model.load(model_config_path)
    model_config.flavors[mlflow.pyfunc.FLAVOR_NAME][
        mlflow.pyfunc.model.CONFIG_KEY_CLOUDPICKLE_VERSION
    ] = saver_cloudpickle_version
    model_config.save(model_config_path)

    log_messages = []

    # Reproduce the logger's lazy %-formatting so the captured message
    # contains the interpolated version strings.
    def custom_warn(message_text, *args, **kwargs):
        log_messages.append(message_text % args % kwargs)

    loader_cloudpickle_version = "0.5.7"
    with (
        mock.patch("mlflow.pyfunc._logger.warning") as warn_mock,
        mock.patch("cloudpickle.__version__") as cloudpickle_version_mock,
    ):
        # Make str(cloudpickle.__version__) report the fake loader version.
        cloudpickle_version_mock.__str__ = lambda *args, **kwargs: loader_cloudpickle_version
        warn_mock.side_effect = custom_warn
        mlflow.pyfunc.load_model(model_uri=model_path)

    assert any(
        "differs from the version of CloudPickle that is currently running" in log_message
        and saver_cloudpickle_version in log_message
        and loader_cloudpickle_version in log_message
        for log_message in log_messages
    )


def test_load_model_with_missing_cloudpickle_version_logs_warning(model_path):
    """If the MLmodel config lacks the saved cloudpickle version, loading
    must warn that the version could not be found."""

    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            return model_input

    mlflow.pyfunc.save_model(path=model_path, python_model=TestModel())
    # Remove the recorded cloudpickle version from the saved configuration.
    model_config_path = os.path.join(model_path, "MLmodel")
    model_config = Model.load(model_config_path)
    del model_config.flavors[mlflow.pyfunc.FLAVOR_NAME][
        mlflow.pyfunc.model.CONFIG_KEY_CLOUDPICKLE_VERSION
    ]
    model_config.save(model_config_path)

    log_messages = []

    # Reproduce the logger's lazy %-formatting when capturing warnings.
    def custom_warn(message_text, *args, **kwargs):
        log_messages.append(message_text % args % kwargs)

    with mock.patch("mlflow.pyfunc._logger.warning") as warn_mock:
        warn_mock.side_effect = custom_warn
        mlflow.pyfunc.load_model(model_uri=model_path)

    assert any(
        (
            "The version of CloudPickle used to save the model could not be found"
            " in the MLmodel configuration"
        )
        in log_message
        for log_message in log_messages
    )


def test_load_cloudpickle_model_raises_when_pickle_deserialization_disallowed(
    model_path, monkeypatch
):
    """With MLFLOW_ALLOW_PICKLE_DESERIALIZATION=false, loading a
    cloudpickle-serialized model must raise instead of unpickling."""

    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input, params=None):
            return model_input

    mlflow.pyfunc.save_model(path=model_path, python_model=TestModel())
    monkeypatch.setenv(MLFLOW_ALLOW_PICKLE_DESERIALIZATION.name, "false")

    with pytest.raises(MlflowException, match="Deserializing model using pickle is disallowed"):
        mlflow.pyfunc.load_model(model_uri=model_path)


def test_save_and_load_model_with_special_chars(
    sklearn_knn_model, main_scoped_model_class, iris_data, tmp_path
):
    """Models saved under paths containing spaces and URL-unsafe characters
    must round-trip through save/load."""
    # Path with a space in it.
    sklearn_model_path = os.path.join(tmp_path, "sklearn_ model")
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path)

    def test_predict(sk_model, model_input):
        return sk_model.predict(model_input) * 2

    # Intentionally create a path that has non-url-compatible characters
    pyfunc_model_path = os.path.join(tmp_path, "pyfunc_ :% model")

    mlflow.pyfunc.save_model(
        path=pyfunc_model_path,
        artifacts={"sk_model": sklearn_model_path},
        conda_env=_conda_env(),
        python_model=main_scoped_model_class(test_predict),
    )

    loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_path)
    np.testing.assert_array_equal(
        loaded_pyfunc_model.predict(iris_data[0]),
        test_predict(sk_model=sklearn_knn_model, model_input=iris_data[0]),
    )


def test_model_with_code_path_containing_main(tmp_path):
    """A code_paths directory containing __main__.py must not clobber the
    interpreter's own __main__ module during load."""
    directory = tmp_path.joinpath("model_with_main")
    directory.mkdir()
    main = directory.joinpath("__main__.py")
    main.write_text("# empty main")
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model=mlflow.pyfunc.model.PythonModel(),
            code_paths=[str(directory)],
        )

    # __main__ must remain importable/present both before and after load.
    assert "__main__" in sys.modules
    mlflow.pyfunc.load_model(model_info.model_uri)
    assert "__main__" in sys.modules


def test_model_save_load_with_metadata(tmp_path):
    """Custom metadata passed to save_model must survive a load round-trip."""
    pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model")

    mlflow.pyfunc.save_model(
        path=pyfunc_model_path,
        conda_env=_conda_env(),
        python_model=mlflow.pyfunc.model.PythonModel(),
        metadata={"metadata_key": "metadata_value"},
    )

    reloaded_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_path)
    assert reloaded_model.metadata.metadata["metadata_key"] == "metadata_value"


def test_model_log_with_metadata():
    """Custom metadata passed to log_model must survive a load round-trip."""
    pyfunc_artifact_path = "pyfunc_model"
    with mlflow.start_run():
        mlflow.pyfunc.log_model(
            name=pyfunc_artifact_path,
            python_model=mlflow.pyfunc.model.PythonModel(),
            metadata={"metadata_key": "metadata_value"},
        )
        pyfunc_model_uri = f"runs:/{mlflow.active_run().info.run_id}/{pyfunc_artifact_path}"

    reloaded_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_uri)
    assert reloaded_model.metadata.metadata["metadata_key"] == "metadata_value"


class SklearnModel(mlflow.pyfunc.PythonModel):
    """PythonModel wrapping a LinearRegression; the function-scope sklearn
    import lets dependency inference detect scikit-learn."""

    def __init__(self) -> None:
        from sklearn.linear_model import LinearRegression

        self.model = LinearRegression()

    def predict(self, context, model_input, params=None):
        return
self.model.predict(model_input) 1198 1199 1200 def test_dependency_inference_does_not_exclude_mlflow_dependencies(tmp_path): 1201 mlflow.pyfunc.save_model( 1202 path=tmp_path, 1203 python_model=SklearnModel(), 1204 ) 1205 requiments = tmp_path.joinpath("requirements.txt").read_text() 1206 assert f"scikit-learn=={sklearn.__version__}" in requiments 1207 1208 1209 def test_functional_python_model_no_type_hints(tmp_path): 1210 def python_model(x): 1211 return x 1212 1213 mlflow.pyfunc.save_model(path=tmp_path, python_model=python_model, input_example=[{"a": "b"}]) 1214 model = Model.load(tmp_path) 1215 assert model.signature.inputs == Schema([ColSpec("string", name="a")]) 1216 assert model.signature.outputs == Schema([ColSpec("string", name="a")]) 1217 1218 1219 def list_to_list(x: List[str]) -> List[str]: # noqa: UP006 1220 return x 1221 1222 1223 def list_dict_to_list(x: List[Dict[str, str]]) -> List[str]: # noqa: UP006 1224 return ["".join((*d.keys(), *d.values())) for d in x] # join keys and values 1225 1226 1227 def test_functional_python_model_list_dict_to_list_without_example(tmp_path): 1228 mlflow.pyfunc.save_model( 1229 path=tmp_path, python_model=list_dict_to_list, pip_requirements=["pandas"] 1230 ) 1231 model = Model.load(tmp_path) 1232 assert model.signature.inputs == Schema([ColSpec(Map("string"))]) 1233 assert model.signature.outputs == Schema([ColSpec("string")]) 1234 loaded_model = mlflow.pyfunc.load_model(tmp_path) 1235 assert loaded_model.predict([{"a": "x"}, {"a": "y"}]) == ["ax", "ay"] 1236 1237 1238 @pytest.mark.parametrize( 1239 ("input_example"), 1240 [ 1241 [0], 1242 [{"a": "b"}], 1243 ], 1244 ) 1245 def test_functional_python_model_list_invalid_example(tmp_path, input_example): 1246 with mock.patch("mlflow.models.signature._logger.warning") as mock_warning: 1247 mlflow.pyfunc.save_model( 1248 path=tmp_path, python_model=list_to_list, input_example=input_example 1249 ) 1250 assert any( 1251 "Input example is not compatible with the type hint" 
in call[0][0]
        for call in mock_warning.call_args_list
    )


@pytest.mark.parametrize(
    "input_example",
    [
        ["a"],
        [{0: "a"}],
        [{"a": 0}],
    ],
)
def test_functional_python_model_list_dict_invalid_example(tmp_path, input_example):
    """An input example incompatible with List[Dict[str, str]] must log a
    warning rather than fail the save."""
    with mock.patch("mlflow.models.signature._logger.warning") as mock_warning:
        mlflow.pyfunc.save_model(
            path=tmp_path, python_model=list_dict_to_list, input_example=input_example
        )
    assert any(
        "Input example is not compatible with the type hint" in call[0][0]
        for call in mock_warning.call_args_list
    )


def test_functional_python_model_list_dict_to_list(tmp_path):
    """A typed function model with a matching input example must round-trip
    with a Map-of-string input schema and string output schema."""
    mlflow.pyfunc.save_model(
        path=tmp_path,
        python_model=list_dict_to_list,
        input_example=[{"a": "x", "b": "y"}],
    )
    model = Model.load(tmp_path)
    assert model.signature.inputs == Schema([ColSpec(Map("string"))])
    assert model.signature.outputs == Schema([ColSpec("string")])
    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    # list_dict_to_list joins keys then values of each dict.
    assert loaded_model.predict([{"a": "x", "b": "y"}]) == ["abxy"]


def list_dict_to_list_dict(x: list[dict[str, str]]) -> list[dict[str, str]]:
    return [{v: k for k, v in d.items()} for d in x]  # swap keys and values


def test_functional_python_model_list_dict_to_list_dict():
    """Map-valued inputs and outputs must both be inferred from hints."""
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="test_model",
            python_model=list_dict_to_list_dict,
            input_example=[{"a": "x", "b": "y"}],
        )

    assert model_info.signature.inputs.to_dict() == [
        {"type": "map", "values": {"type": "string"}, "required": True}
    ]
    assert model_info.signature.outputs.to_dict() == [
        {"type": "map", "values": {"type": "string"}, "required": True}
    ]

    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert pyfunc_model.predict([{"a": "x", "b": "y"}]) == [{"x": "a", "y": "b"}]

def test_list_dict_with_signature_override():
    """An explicit signature is overridden by the predict type hint; the
    stored input schema comes from list[dict[str, str]]."""

    class CustomModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input: list[dict[str, str]], params=None):
            return model_input

    signature = infer_signature([{"a": "x", "b": "y"}, {"a": "z"}])
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="test_model",
            python_model=CustomModel(),
            signature=signature,
        )
    assert model_info.signature.inputs == _infer_schema_from_list_type_hint(list[dict[str, str]])
    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert pyfunc_model.predict([{"a": "z"}]) == [{"a": "z"}]


def list_dict_to_list_dict_pep585(x: list[dict[str, str]]) -> list[dict[str, str]]:
    return [{v: k for k, v in d.items()} for d in x]  # swap keys and values


def test_functional_python_model_list_dict_to_list_dict_with_example_pep585(tmp_path):
    """Same as the typing-module variant above, but with PEP 585 builtin
    generics (list[...]/dict[...])."""
    mlflow.pyfunc.save_model(
        path=tmp_path,
        python_model=list_dict_to_list_dict_pep585,
        input_example=[{"a": "x", "b": "y"}],
    )
    model = Model.load(tmp_path)
    assert model.signature.inputs.to_dict() == [
        {"type": "map", "values": {"type": "string"}, "required": True},
    ]
    assert model.signature.outputs.to_dict() == [
        {"type": "map", "values": {"type": "string"}, "required": True},
    ]
    loaded_model = mlflow.pyfunc.load_model(tmp_path)
    assert loaded_model.predict([{"a": "x", "b": "y"}]) == [{"x": "a", "y": "b"}]


def multiple_arguments(x: list[str], y: list[str]) -> list[str]:
    return x + y


def test_functional_python_model_multiple_arguments(tmp_path):
    """A functional model taking two arguments must be rejected."""
    with pytest.raises(
        MlflowException, match=r"must accept exactly one argument\. Found 2 arguments\."
    ):
        mlflow.pyfunc.save_model(path=tmp_path, python_model=multiple_arguments)


def no_arguments() -> list[str]:
    return []


def test_functional_python_model_no_arguments(tmp_path):
    """A functional model taking zero arguments must be rejected."""
    with pytest.raises(
        MlflowException, match=r"must accept exactly one argument\. Found 0 arguments\."
    ):
        mlflow.pyfunc.save_model(path=tmp_path, python_model=no_arguments)


def requires_sklearn(x: list[str]) -> list[str]:
    import sklearn  # noqa: F401

    return x


def test_functional_python_model_infer_requirements(tmp_path):
    """scikit-learn used inside the model function must be inferred into
    requirements.txt."""
    mlflow.pyfunc.save_model(path=tmp_path, python_model=requires_sklearn, input_example=["a"])
    assert "scikit-learn==" in tmp_path.joinpath("requirements.txt").read_text()


def test_functional_python_model_throws_when_required_arguments_are_missing(tmp_path):
    """Saving succeeds with an input example OR explicit requirements, but
    fails when none of those sources is provided."""
    mlflow.pyfunc.save_model(
        path=tmp_path / uuid.uuid4().hex,
        python_model=requires_sklearn,
        input_example=["a"],
    )
    mlflow.pyfunc.save_model(
        path=tmp_path / uuid.uuid4().hex,
        python_model=requires_sklearn,
        pip_requirements=["scikit-learn"],
    )
    mlflow.pyfunc.save_model(
        path=tmp_path / uuid.uuid4().hex,
        python_model=requires_sklearn,
        extra_pip_requirements=["scikit-learn"],
    )
    with pytest.raises(MlflowException, match="at least one of"):
        mlflow.pyfunc.save_model(path=tmp_path / uuid.uuid4().hex, python_model=requires_sklearn)


class AnnotatedPythonModel(mlflow.pyfunc.PythonModel):
    """Identity model whose predict hints drive signature inference and
    assert that the runtime input is a list of strings."""

    def predict(self, context: dict[str, Any], model_input: list[str], params=None) -> list[str]:
        assert isinstance(model_input, list)
        assert all(isinstance(x, str) for x in model_input)
        return model_input


def test_class_python_model_type_hints(tmp_path):
    """Signature must be inferred from AnnotatedPythonModel's type hints."""
    mlflow.pyfunc.save_model(path=tmp_path, python_model=AnnotatedPythonModel())
    model = Model.load(tmp_path)
    assert
model.signature.inputs.to_dict() == [{"type": "string", "required": True}]
    assert model.signature.outputs.to_dict() == [{"type": "string", "required": True}]
    model = mlflow.pyfunc.load_model(tmp_path)
    assert model.predict(["a", "b"]) == ["a", "b"]


def test_python_model_predict_with_params():
    """predict must accept params both as plain lists and as numpy arrays."""
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="test_model",
            python_model=AnnotatedPythonModel(),
        )

    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert loaded_model.predict(["a", "b"], params={"foo": [0, 1]}) == ["a", "b"]
    assert loaded_model.predict(["a", "b"], params={"foo": np.array([0, 1])}) == [
        "a",
        "b",
    ]


def test_python_model_with_type_hint_errors_with_different_signature():
    """A provided signature that disagrees with the type-hint-inferred one
    must trigger a warning (hints win; see override test above)."""
    signature = infer_signature(["input1", "input2"], params={"foo": [8]})

    with mlflow.start_run():
        with mock.patch("mlflow.pyfunc._logger.warning") as warn_mock:
            mlflow.pyfunc.log_model(
                name="test_model",
                python_model=AnnotatedPythonModel(),
                signature=signature,
            )
            assert (
                "Provided signature does not match the signature inferred from"
                in warn_mock.call_args[0][0]
            )


def test_artifact_path_posix(sklearn_knn_model, main_scoped_model_class, tmp_path):
    """Artifact URIs recorded in the model context must use POSIX separators
    (no backslashes), even when saved from a Windows-style path."""
    sklearn_model_path = tmp_path.joinpath("sklearn_model")
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path)

    def test_predict(sk_model, model_input):
        return sk_model.predict(model_input) * 2

    pyfunc_model_path = tmp_path.joinpath("pyfunc_model")

    mlflow.pyfunc.save_model(
        path=pyfunc_model_path,
        artifacts={"sk_model": str(sklearn_model_path)},
        conda_env=_conda_env(),
        python_model=main_scoped_model_class(test_predict),
    )

    artifacts = _load_pyfunc(pyfunc_model_path).context.artifacts
    assert all("\\" not in artifact_uri for
artifact_uri in artifacts.values())


def test_load_model_fails_for_feature_store_models(tmp_path):
    """Models logged with the Databricks Feature Store loader module must be
    rejected by mlflow.pyfunc.load_model with a targeted error."""
    feature_store = os.path.join(tmp_path, "feature_store")
    os.mkdir(feature_store)
    feature_spec = os.path.join(feature_store, "feature_spec.yaml")
    with open(feature_spec, "w+") as f:
        f.write("contents")

    with mlflow.start_run() as run:
        mlflow.pyfunc.log_model(
            name="model",
            data_path=feature_store,
            loader_module=_DATABRICKS_FS_LOADER_MODULE,
            code_paths=[__file__],
        )
    with pytest.raises(
        MlflowException,
        match="Note: mlflow.pyfunc.load_model is not supported for Feature Store models",
    ):
        mlflow.pyfunc.load_model(f"runs:/{run.info.run_id}/model")


def test_pyfunc_model_infer_signature_from_type_hints():
    """Signature inferred from predict's list[str] hints must survive the
    log/load round-trip."""

    class TestModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input: list[str], params=None) -> list[str]:
            return model_input

    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="test_model",
            python_model=TestModel(),
            input_example=["a"],
        )
    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    assert pyfunc_model.metadata.get_input_schema() == Schema([ColSpec("string")])
    assert pyfunc_model.predict(["a", "b"]) == ["a", "b"]


def test_streamable_model_save_load(iris_data, tmp_path):
    """predict_stream on a reloaded model must return a generator yielding
    the model's streamed chunks."""
    # NOTE(review): the iris_data fixture appears unused here — confirm.

    class StreamableModel(mlflow.pyfunc.PythonModel):
        def __init__(self):
            pass

        def predict(self, context, model_input, params=None):
            pass

        def predict_stream(self, context, model_input, params=None):
            yield "test1"
            yield "test2"

    pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model")

    python_model = StreamableModel()

    mlflow.pyfunc.save_model(
        path=pyfunc_model_path,
        python_model=python_model,
    )

    loaded_pyfunc_model =
mlflow.pyfunc.load_model(model_uri=pyfunc_model_path) 1528 1529 stream_result = loaded_pyfunc_model.predict_stream("single-input") 1530 assert isinstance(stream_result, types.GeneratorType) 1531 1532 assert list(stream_result) == ["test1", "test2"] 1533 1534 1535 def test_streamable_model_save_load(tmp_path): 1536 pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model") 1537 1538 mlflow.pyfunc.save_model( 1539 path=pyfunc_model_path, 1540 python_model="tests/pyfunc/sample_code/streamable_model_code.py", 1541 ) 1542 1543 loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_path) 1544 1545 stream_result = loaded_pyfunc_model.predict_stream("single-input") 1546 assert isinstance(stream_result, types.GeneratorType) 1547 1548 assert list(stream_result) == ["test1", "test2"] 1549 1550 1551 def test_model_save_load_with_resources(tmp_path): 1552 pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model") 1553 pyfunc_model_path_2 = os.path.join(tmp_path, "pyfunc_model_2") 1554 1555 expected_resources = { 1556 "api_version": "1", 1557 "databricks": { 1558 "serving_endpoint": [ 1559 {"name": "databricks-mixtral-8x7b-instruct"}, 1560 {"name": "databricks-bge-large-en"}, 1561 {"name": "azure-eastus-model-serving-2_vs_endpoint"}, 1562 ], 1563 "vector_search_index": [{"name": "rag.studio_bugbash.databricks_docs_index"}], 1564 "sql_warehouse": [{"name": "testid"}], 1565 "function": [ 1566 {"name": "rag.studio.test_function_a"}, 1567 {"name": "rag.studio.test_function_b"}, 1568 ], 1569 "genie_space": [{"name": "genie_space_id_1"}, {"name": "genie_space_id_2"}], 1570 "uc_connection": [{"name": "test_connection_1"}, {"name": "test_connection_2"}], 1571 "table": [{"name": "rag.studio.table_a"}, {"name": "rag.studio.table_b"}], 1572 "app": [{"name": "test_databricks_app"}], 1573 "lakebase": [{"name": "test_databricks_lakebase"}], 1574 }, 1575 } 1576 mlflow.pyfunc.save_model( 1577 path=pyfunc_model_path, 1578 conda_env=_conda_env(), 1579 
        python_model=mlflow.pyfunc.model.PythonModel(),
        # One resource object of each supported Databricks resource type.
        resources=[
            DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"),
            DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
            DatabricksServingEndpoint(endpoint_name="azure-eastus-model-serving-2_vs_endpoint"),
            DatabricksVectorSearchIndex(index_name="rag.studio_bugbash.databricks_docs_index"),
            DatabricksSQLWarehouse(warehouse_id="testid"),
            DatabricksFunction(function_name="rag.studio.test_function_a"),
            DatabricksFunction(function_name="rag.studio.test_function_b"),
            DatabricksGenieSpace(genie_space_id="genie_space_id_1"),
            DatabricksGenieSpace(genie_space_id="genie_space_id_2"),
            DatabricksUCConnection(connection_name="test_connection_1"),
            DatabricksUCConnection(connection_name="test_connection_2"),
            DatabricksTable(table_name="rag.studio.table_a"),
            DatabricksTable(table_name="rag.studio.table_b"),
            DatabricksApp(app_name="test_databricks_app"),
            DatabricksLakebase(database_instance_name="test_databricks_lakebase"),
        ],
    )

    reloaded_model = Model.load(pyfunc_model_path)
    assert reloaded_model.resources == expected_resources

    # Same resources expressed as a YAML file must yield the same result.
    yaml_file = tmp_path.joinpath("resources.yaml")
    with open(yaml_file, "w") as f:
        f.write(
            """
            api_version: "1"
            databricks:
              vector_search_index:
                - name: rag.studio_bugbash.databricks_docs_index
              serving_endpoint:
                - name: databricks-mixtral-8x7b-instruct
                - name: databricks-bge-large-en
                - name: azure-eastus-model-serving-2_vs_endpoint
              sql_warehouse:
                - name: testid
              function:
                - name: rag.studio.test_function_a
                - name: rag.studio.test_function_b
              lakebase:
                - name: test_databricks_lakebase
              genie_space:
                - name: genie_space_id_1
                - name: genie_space_id_2
              uc_connection:
                - name: test_connection_1
                - name: test_connection_2
              table:
                - name:
rag.studio.table_a
                - name: rag.studio.table_b
              app:
                - name: test_databricks_app
            """
        )

    mlflow.pyfunc.save_model(
        path=pyfunc_model_path_2,
        conda_env=_conda_env(),
        python_model=mlflow.pyfunc.model.PythonModel(),
        resources=yaml_file,
    )
    reloaded_model = Model.load(pyfunc_model_path_2)
    assert reloaded_model.resources == expected_resources


def test_model_save_load_with_invokers_resources(tmp_path):
    """Like the plain resources test, but with on_behalf_of_user set on a
    subset of resources; the flag must be preserved in serialization."""
    pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model")
    pyfunc_model_path_2 = os.path.join(tmp_path, "pyfunc_model_2")

    expected_resources = {
        "api_version": "1",
        "databricks": {
            "serving_endpoint": [
                {"name": "databricks-mixtral-8x7b-instruct", "on_behalf_of_user": True},
                {"name": "databricks-bge-large-en"},
                {"name": "azure-eastus-model-serving-2_vs_endpoint"},
            ],
            "vector_search_index": [
                {"name": "rag.studio_bugbash.databricks_docs_index", "on_behalf_of_user": True}
            ],
            "sql_warehouse": [{"name": "testid"}],
            "function": [
                {"name": "rag.studio.test_function_a", "on_behalf_of_user": True},
                {"name": "rag.studio.test_function_b"},
            ],
            "genie_space": [
                {"name": "genie_space_id_1", "on_behalf_of_user": True},
                {"name": "genie_space_id_2"},
            ],
            "uc_connection": [{"name": "test_connection_1"}, {"name": "test_connection_2"}],
            "table": [
                {"name": "rag.studio.table_a", "on_behalf_of_user": True},
                {"name": "rag.studio.table_b"},
            ],
            "app": [{"name": "test_databricks_app"}],
            "lakebase": [{"name": "test_databricks_lakebase"}],
        },
    }
    mlflow.pyfunc.save_model(
        path=pyfunc_model_path,
        conda_env=_conda_env(),
        python_model=mlflow.pyfunc.model.PythonModel(),
        resources=[
            DatabricksServingEndpoint(
                endpoint_name="databricks-mixtral-8x7b-instruct", on_behalf_of_user=True
            ),
            DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
            DatabricksServingEndpoint(endpoint_name="azure-eastus-model-serving-2_vs_endpoint"),
            DatabricksVectorSearchIndex(
                index_name="rag.studio_bugbash.databricks_docs_index", on_behalf_of_user=True
            ),
            DatabricksSQLWarehouse(warehouse_id="testid"),
            DatabricksFunction(function_name="rag.studio.test_function_a", on_behalf_of_user=True),
            DatabricksFunction(function_name="rag.studio.test_function_b"),
            DatabricksGenieSpace(genie_space_id="genie_space_id_1", on_behalf_of_user=True),
            DatabricksGenieSpace(genie_space_id="genie_space_id_2"),
            DatabricksUCConnection(connection_name="test_connection_1"),
            DatabricksUCConnection(connection_name="test_connection_2"),
            DatabricksTable(table_name="rag.studio.table_a", on_behalf_of_user=True),
            DatabricksTable(table_name="rag.studio.table_b"),
            DatabricksApp(app_name="test_databricks_app"),
            DatabricksLakebase(database_instance_name="test_databricks_lakebase"),
        ],
    )

    reloaded_model = Model.load(pyfunc_model_path)
    assert reloaded_model.resources == expected_resources

    # Equivalent YAML form, including the on_behalf_of_user flags.
    yaml_file = tmp_path.joinpath("resources.yaml")
    with open(yaml_file, "w") as f:
        f.write(
            """
            api_version: "1"
            databricks:
              vector_search_index:
                - name: rag.studio_bugbash.databricks_docs_index
                  on_behalf_of_user: True
              serving_endpoint:
                - name: databricks-mixtral-8x7b-instruct
                  on_behalf_of_user: True
                - name: databricks-bge-large-en
                - name: azure-eastus-model-serving-2_vs_endpoint
              sql_warehouse:
                - name: testid
              function:
                - name: rag.studio.test_function_a
                  on_behalf_of_user: True
                - name: rag.studio.test_function_b
              lakebase:
                - name: test_databricks_lakebase
              genie_space:
                - name: genie_space_id_1
                  on_behalf_of_user: True
                - name: genie_space_id_2
              uc_connection:
                - name: test_connection_1
                - name: test_connection_2
                  on_behalf_of_user: True
              table:
                - name: rag.studio.table_a
                  on_behalf_of_user: True
                - name: rag.studio.table_b
              app:
                - name: test_databricks_app
            """
        )

    mlflow.pyfunc.save_model(
        path=pyfunc_model_path_2,
        conda_env=_conda_env(),
        python_model=mlflow.pyfunc.model.PythonModel(),
        resources=yaml_file,
    )

    reloaded_model = Model.load(pyfunc_model_path_2)
    assert reloaded_model.resources == expected_resources


def test_model_log_with_invokers_resources(tmp_path):
    """log_model counterpart of the invokers-resources test: resources with
    on_behalf_of_user flags must round-trip through a logged run."""
    pyfunc_artifact_path = "pyfunc_model"

    expected_resources = {
        "api_version": "1",
        "databricks": {
            "serving_endpoint": [
                {"name": "databricks-mixtral-8x7b-instruct"},
                {"name": "databricks-bge-large-en", "on_behalf_of_user": True},
                {"name": "azure-eastus-model-serving-2_vs_endpoint"},
            ],
            "vector_search_index": [
                {"name": "rag.studio_bugbash.databricks_docs_index", "on_behalf_of_user": True}
            ],
            "sql_warehouse": [{"name": "testid", "on_behalf_of_user": True}],
            "function": [
                {"name": "rag.studio.test_function_a"},
                {"name": "rag.studio.test_function_b", "on_behalf_of_user": True},
            ],
            "genie_space": [
                {"name": "genie_space_id_1"},
                {"name": "genie_space_id_2", "on_behalf_of_user": True},
            ],
            "uc_connection": [
                {"name": "test_connection_1"},
                {"name": "test_connection_2", "on_behalf_of_user": True},
            ],
            "table": [
                {"name": "rag.studio.table_a"},
                {"name": "rag.studio.table_b", "on_behalf_of_user": True},
            ],
        },
    }
    with mlflow.start_run() as run:
        mlflow.pyfunc.log_model(
            name=pyfunc_artifact_path,
            python_model=mlflow.pyfunc.model.PythonModel(),
            resources=[
                DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"),
                DatabricksServingEndpoint(
                    endpoint_name="databricks-bge-large-en", on_behalf_of_user=True
                ),
                DatabricksServingEndpoint(endpoint_name="azure-eastus-model-serving-2_vs_endpoint"),
                DatabricksVectorSearchIndex(
                    index_name="rag.studio_bugbash.databricks_docs_index", on_behalf_of_user=True
                ),
                DatabricksSQLWarehouse(warehouse_id="testid", on_behalf_of_user=True),
                DatabricksFunction(function_name="rag.studio.test_function_a"),
                DatabricksFunction(
                    function_name="rag.studio.test_function_b", on_behalf_of_user=True
                ),
                DatabricksGenieSpace(genie_space_id="genie_space_id_1"),
                DatabricksGenieSpace(genie_space_id="genie_space_id_2", on_behalf_of_user=True),
                DatabricksUCConnection(connection_name="test_connection_1"),
                DatabricksUCConnection(connection_name="test_connection_2", on_behalf_of_user=True),
                DatabricksTable(table_name="rag.studio.table_a"),
                DatabricksTable(table_name="rag.studio.table_b", on_behalf_of_user=True),
            ],
        )
        # Download the logged model and verify the serialized resources.
        pyfunc_model_uri = f"runs:/{run.info.run_id}/{pyfunc_artifact_path}"
        pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_uri)
        reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
        assert reloaded_model.resources == expected_resources

    # Same resources expressed as a YAML file.
    yaml_file = tmp_path.joinpath("resources.yaml")
    with open(yaml_file, "w") as f:
        f.write(
            """
            api_version: "1"
            databricks:
              vector_search_index:
                - name: rag.studio_bugbash.databricks_docs_index
                  on_behalf_of_user: True
              serving_endpoint:
                - name: databricks-mixtral-8x7b-instruct
                - name: databricks-bge-large-en
                  on_behalf_of_user: True
                - name: azure-eastus-model-serving-2_vs_endpoint
              sql_warehouse:
                - name: testid
                  on_behalf_of_user: True
              function:
                - name: rag.studio.test_function_a
                - name: rag.studio.test_function_b
                  on_behalf_of_user: True
              genie_space:
                - name: genie_space_id_1
                - name: genie_space_id_2
                  on_behalf_of_user: True
              uc_connection:
                - name:
test_connection_1
                - name: test_connection_2
                  on_behalf_of_user: True
              table:
                - name: "rag.studio.table_a"
                - name: "rag.studio.table_b"
                  on_behalf_of_user: True
            """
        )

    with mlflow.start_run() as run:
        mlflow.pyfunc.log_model(
            name=pyfunc_artifact_path,
            python_model=mlflow.pyfunc.model.PythonModel(),
            resources=yaml_file,
        )
        # YAML-specified resources must serialize identically.
        pyfunc_model_uri = f"runs:/{run.info.run_id}/{pyfunc_artifact_path}"
        pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_uri)
        reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
        assert reloaded_model.resources == expected_resources


def test_model_log_with_resources(tmp_path):
    """log_model counterpart of test_model_save_load_with_resources: plain
    resources (no on_behalf_of_user) must round-trip through a logged run."""
    pyfunc_artifact_path = "pyfunc_model"

    expected_resources = {
        "api_version": "1",
        "databricks": {
            "serving_endpoint": [
                {"name": "databricks-mixtral-8x7b-instruct"},
                {"name": "databricks-bge-large-en"},
                {"name": "azure-eastus-model-serving-2_vs_endpoint"},
            ],
            "vector_search_index": [{"name": "rag.studio_bugbash.databricks_docs_index"}],
            "sql_warehouse": [{"name": "testid"}],
            "function": [
                {"name": "rag.studio.test_function_a"},
                {"name": "rag.studio.test_function_b"},
            ],
            "genie_space": [
                {"name": "genie_space_id_1"},
                {"name": "genie_space_id_2"},
            ],
            "uc_connection": [{"name": "test_connection_1"}, {"name": "test_connection_2"}],
            "table": [{"name": "rag.studio.table_a"}, {"name": "rag.studio.table_b"}],
            "app": [{"name": "test_databricks_app"}],
            "lakebase": [{"name": "test_databricks_lakebase"}],
        },
    }
    with mlflow.start_run() as run:
        mlflow.pyfunc.log_model(
            name=pyfunc_artifact_path,
            python_model=mlflow.pyfunc.model.PythonModel(),
            resources=[
                DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"),
                DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
1903 DatabricksServingEndpoint(endpoint_name="azure-eastus-model-serving-2_vs_endpoint"), 1904 DatabricksVectorSearchIndex(index_name="rag.studio_bugbash.databricks_docs_index"), 1905 DatabricksSQLWarehouse(warehouse_id="testid"), 1906 DatabricksFunction(function_name="rag.studio.test_function_a"), 1907 DatabricksFunction(function_name="rag.studio.test_function_b"), 1908 DatabricksGenieSpace(genie_space_id="genie_space_id_1"), 1909 DatabricksGenieSpace(genie_space_id="genie_space_id_2"), 1910 DatabricksUCConnection(connection_name="test_connection_1"), 1911 DatabricksUCConnection(connection_name="test_connection_2"), 1912 DatabricksTable(table_name="rag.studio.table_a"), 1913 DatabricksTable(table_name="rag.studio.table_b"), 1914 DatabricksApp(app_name="test_databricks_app"), 1915 DatabricksLakebase(database_instance_name="test_databricks_lakebase"), 1916 ], 1917 ) 1918 pyfunc_model_uri = f"runs:/{run.info.run_id}/{pyfunc_artifact_path}" 1919 pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_uri) 1920 reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel")) 1921 assert reloaded_model.resources == expected_resources 1922 1923 yaml_file = tmp_path.joinpath("resources.yaml") 1924 with open(yaml_file, "w") as f: 1925 f.write( 1926 """ 1927 api_version: "1" 1928 databricks: 1929 vector_search_index: 1930 - name: rag.studio_bugbash.databricks_docs_index 1931 serving_endpoint: 1932 - name: databricks-mixtral-8x7b-instruct 1933 - name: databricks-bge-large-en 1934 - name: azure-eastus-model-serving-2_vs_endpoint 1935 sql_warehouse: 1936 - name: testid 1937 function: 1938 - name: rag.studio.test_function_a 1939 - name: rag.studio.test_function_b 1940 lakebase: 1941 - name: test_databricks_lakebase 1942 genie_space: 1943 - name: genie_space_id_1 1944 - name: genie_space_id_2 1945 uc_connection: 1946 - name: test_connection_1 1947 - name: test_connection_2 1948 table: 1949 - name: "rag.studio.table_a" 1950 - name: "rag.studio.table_b" 1951 app: 1952 
- name: test_databricks_app 1953 """ 1954 ) 1955 1956 with mlflow.start_run() as run: 1957 mlflow.pyfunc.log_model( 1958 name=pyfunc_artifact_path, 1959 python_model=mlflow.pyfunc.model.PythonModel(), 1960 resources=yaml_file, 1961 ) 1962 pyfunc_model_uri = f"runs:/{run.info.run_id}/{pyfunc_artifact_path}" 1963 pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_uri) 1964 reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel")) 1965 assert reloaded_model.resources == expected_resources 1966 1967 1968 def test_pyfunc_as_code_log_and_load(): 1969 with mlflow.start_run(): 1970 model_info = mlflow.pyfunc.log_model( 1971 name="model", 1972 python_model="tests/pyfunc/sample_code/python_model.py", 1973 ) 1974 1975 loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) 1976 model_input = "asdf" 1977 expected_output = f"This was the input: {model_input}" 1978 assert loaded_model.predict(model_input) == expected_output 1979 1980 1981 def test_pyfunc_as_code_log_and_load_with_path(): 1982 with mlflow.start_run(): 1983 model_info = mlflow.pyfunc.log_model( 1984 name="model", 1985 python_model=Path("tests/pyfunc/sample_code/python_model.py"), 1986 ) 1987 1988 loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) 1989 model_input = "asdf" 1990 expected_output = f"This was the input: {model_input}" 1991 assert loaded_model.predict(model_input) == expected_output 1992 1993 1994 def test_pyfunc_as_code_with_config(tmp_path): 1995 temp_file = tmp_path / "config.yml" 1996 temp_file.write_text("timeout: 400") 1997 1998 with mlflow.start_run(): 1999 model_info = mlflow.pyfunc.log_model( 2000 name="model", 2001 python_model="tests/pyfunc/sample_code/python_model_with_config.py", 2002 model_config=str(temp_file), 2003 ) 2004 2005 loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) 2006 model_input = "input" 2007 expected_output = f"Predict called with input {model_input}, timeout 400" 2008 assert loaded_model.predict(model_input) == 
expected_output 2009 2010 2011 def test_pyfunc_as_code_with_path_config(tmp_path): 2012 temp_file = tmp_path / "config.yml" 2013 temp_file.write_text("timeout: 400") 2014 2015 with mlflow.start_run(): 2016 model_info = mlflow.pyfunc.log_model( 2017 name="model", 2018 python_model="tests/pyfunc/sample_code/python_model_with_config.py", 2019 model_config=temp_file, 2020 ) 2021 2022 loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) 2023 model_input = "input" 2024 expected_output = f"Predict called with input {model_input}, timeout 400" 2025 assert loaded_model.predict(model_input) == expected_output 2026 2027 2028 def test_pyfunc_as_code_with_dict_config(): 2029 with mlflow.start_run(): 2030 model_info = mlflow.pyfunc.log_model( 2031 name="model", 2032 python_model="tests/pyfunc/sample_code/python_model_with_config.py", 2033 model_config={"timeout": 400}, 2034 ) 2035 2036 loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) 2037 model_input = "input" 2038 expected_output = f"Predict called with input {model_input}, timeout 400" 2039 assert loaded_model.predict(model_input) == expected_output 2040 2041 2042 def test_pyfunc_as_code_log_and_load_with_code_paths(): 2043 with mlflow.start_run(): 2044 model_info = mlflow.pyfunc.log_model( 2045 name="model", 2046 python_model="tests/pyfunc/sample_code/python_model_with_utils.py", 2047 code_paths=["tests/pyfunc/sample_code/utils.py"], 2048 ) 2049 2050 loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) 2051 model_input = "asdf" 2052 expected_output = f"My utils function received this input: {model_input}" 2053 assert loaded_model.predict(model_input) == expected_output 2054 2055 2056 def test_pyfunc_as_code_with_dependencies(): 2057 with mlflow.start_run(): 2058 model_info = mlflow.pyfunc.log_model( 2059 name="model", 2060 python_model="tests/pyfunc/sample_code/code_with_dependencies.py", 2061 pip_requirements=["pandas"], 2062 ) 2063 2064 loaded_model = 
mlflow.pyfunc.load_model(model_info.model_uri) 2065 model_input = "user_123" 2066 expected_output = f"Input: {model_input}. Retriever called with ID: {model_input}. Output: 42." 2067 assert loaded_model.predict(model_input) == expected_output 2068 2069 pyfunc_model_path = _download_artifact_from_uri(model_info.model_uri) 2070 reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel")) 2071 assert reloaded_model.metadata["dependencies_schemas"] == { 2072 "retrievers": [ 2073 { 2074 "doc_uri": "doc-uri", 2075 "name": "retriever", 2076 "other_columns": ["column1", "column2"], 2077 "primary_key": "primary-key", 2078 "text_column": "text-column", 2079 } 2080 ] 2081 } 2082 2083 2084 @pytest.mark.parametrize("is_in_db_model_serving", ["true", "false"]) 2085 @pytest.mark.parametrize("stream", [True, False]) 2086 def test_pyfunc_as_code_with_dependencies_store_dependencies_schemas_in_trace( 2087 monkeypatch, is_in_db_model_serving, stream 2088 ): 2089 monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", is_in_db_model_serving) 2090 monkeypatch.setenv("ENABLE_MLFLOW_TRACING", "true") 2091 is_in_db_model_serving = is_in_db_model_serving == "true" 2092 with mlflow.start_run(): 2093 model_info = mlflow.pyfunc.log_model( 2094 name="model", 2095 python_model="tests/pyfunc/sample_code/code_with_dependencies.py", 2096 pip_requirements=["pandas"], 2097 ) 2098 2099 loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) 2100 model_input = "user_123" 2101 expected_output = f"Input: {model_input}. Retriever called with ID: {model_input}. Output: 42." 
2102 func = loaded_model.predict_stream if stream else loaded_model.predict 2103 2104 def _get_result(output): 2105 return list(output)[0] if stream else output 2106 2107 if is_in_db_model_serving: 2108 with set_prediction_context(Context(request_id="1234")): 2109 assert _get_result(func(model_input)) == expected_output 2110 else: 2111 assert _get_result(func(model_input)) == expected_output 2112 2113 pyfunc_model_path = _download_artifact_from_uri(model_info.model_uri) 2114 reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel")) 2115 expected_dependencies_schemas = { 2116 DependenciesSchemasType.RETRIEVERS.value: [ 2117 { 2118 "doc_uri": "doc-uri", 2119 "name": "retriever", 2120 "other_columns": ["column1", "column2"], 2121 "primary_key": "primary-key", 2122 "text_column": "text-column", 2123 } 2124 ] 2125 } 2126 assert reloaded_model.metadata["dependencies_schemas"] == expected_dependencies_schemas 2127 2128 if is_in_db_model_serving: 2129 trace_dict = pop_trace("1234") 2130 trace = Trace.from_dict(trace_dict) 2131 assert trace.info.trace_id.startswith("tr-") 2132 assert trace.info.client_request_id == "1234" 2133 else: 2134 trace = get_traces()[0] 2135 assert trace.info.tags[DependenciesSchemasType.RETRIEVERS.value] == json.dumps( 2136 expected_dependencies_schemas[DependenciesSchemasType.RETRIEVERS.value] 2137 ) 2138 2139 2140 @pytest.mark.parametrize("stream", [True, False]) 2141 def test_no_traces_collected_for_pyfunc_as_code_with_dependencies_if_no_tracing_enabled( 2142 monkeypatch, stream 2143 ): 2144 # This sets model without trace inside code_with_dependencies.py file 2145 monkeypatch.setenv("TEST_TRACE", "false") 2146 with mlflow.start_run(): 2147 model_info = mlflow.pyfunc.log_model( 2148 name="model", 2149 python_model="tests/pyfunc/sample_code/code_with_dependencies.py", 2150 pip_requirements=["pandas"], 2151 ) 2152 2153 loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) 2154 model_input = "user_123" 2155 expected_output = 
f"Input: {model_input}. Retriever called with ID: {model_input}. Output: 42." 2156 if stream: 2157 assert next(loaded_model.predict_stream(model_input)) == expected_output 2158 else: 2159 assert loaded_model.predict(model_input) == expected_output 2160 2161 pyfunc_model_path = _download_artifact_from_uri(model_info.model_uri) 2162 reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel")) 2163 expected_dependencies_schemas = { 2164 DependenciesSchemasType.RETRIEVERS.value: [ 2165 { 2166 "doc_uri": "doc-uri", 2167 "name": "retriever", 2168 "other_columns": ["column1", "column2"], 2169 "primary_key": "primary-key", 2170 "text_column": "text-column", 2171 } 2172 ] 2173 } 2174 assert reloaded_model.metadata["dependencies_schemas"] == expected_dependencies_schemas 2175 2176 # no traces will be logged at all 2177 traces = get_traces() 2178 assert len(traces) == 0 2179 2180 2181 def test_pyfunc_as_code_log_and_load_wrong_path(): 2182 with pytest.raises( 2183 MlflowException, 2184 match="The provided model path", 2185 ): 2186 with mlflow.start_run(): 2187 mlflow.pyfunc.log_model( 2188 name="model", 2189 python_model="asdf", 2190 ) 2191 2192 2193 def test_predict_as_code(): 2194 with mlflow.start_run(): 2195 model_info = mlflow.pyfunc.log_model( 2196 name="model", 2197 python_model="tests/pyfunc/sample_code/func_code.py", 2198 input_example=["string"], 2199 ) 2200 2201 loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) 2202 model_input = "asdf" 2203 expected_output = pd.DataFrame([model_input]) 2204 pandas.testing.assert_frame_equal(loaded_model.predict([model_input]), expected_output) 2205 2206 2207 def test_predict_as_code_with_type_hint(): 2208 with mlflow.start_run(): 2209 model_info = mlflow.pyfunc.log_model( 2210 name="model", 2211 python_model="tests/pyfunc/sample_code/func_code_with_type_hint.py", 2212 input_example=["string"], 2213 ) 2214 2215 loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) 2216 model_input = "asdf" 2217 
expected_output = [model_input] 2218 assert loaded_model.predict([model_input]) == expected_output 2219 2220 2221 def test_predict_as_code_with_config(): 2222 with mlflow.start_run(): 2223 model_info = mlflow.pyfunc.log_model( 2224 name="model", 2225 python_model="tests/pyfunc/sample_code/func_code_with_config.py", 2226 input_example=["string"], 2227 model_config="tests/pyfunc/sample_code/config.yml", 2228 ) 2229 2230 loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) 2231 model_input = "asdf" 2232 expected_output = f"This was the input: {model_input}, timeout 300" 2233 assert loaded_model.predict([model_input]) == expected_output 2234 2235 2236 def test_model_as_code_pycache_cleaned_up(): 2237 with mlflow.start_run(): 2238 model_info = mlflow.pyfunc.log_model( 2239 name="model", 2240 python_model="tests/pyfunc/sample_code/python_model.py", 2241 ) 2242 2243 path = _download_artifact_from_uri(model_info.model_uri) 2244 assert list(Path(path).rglob("__pycache__")) == [] 2245 2246 2247 def test_model_pip_requirements_pin_numpy_when_pandas_included(): 2248 class TestModel(mlflow.pyfunc.PythonModel): 2249 def predict(self, context, model_input, params=None): 2250 import pandas as pd # noqa: F401 2251 2252 return model_input 2253 2254 expected_mlflow_version = _mlflow_major_version_string() 2255 2256 # no numpy when pandas > 2.1.2 2257 with mlflow.start_run(): 2258 model_info = mlflow.pyfunc.log_model( 2259 name="model", python_model=TestModel(), input_example="abc" 2260 ) 2261 2262 _assert_pip_requirements( 2263 model_info.model_uri, 2264 [ 2265 expected_mlflow_version, 2266 f"cloudpickle=={importlib.metadata.version('cloudpickle')}", 2267 f"pandas=={importlib.metadata.version('pandas')}", 2268 ], 2269 strict=True, 2270 ) 2271 2272 original_get_installed_version = _get_installed_version 2273 2274 def mock_get_installed_version(package, module=None): 2275 if package == "pandas": 2276 return "2.1.0" 2277 return original_get_installed_version(package, module) 
2278 2279 # include numpy when pandas < 2.1.2 2280 with ( 2281 mlflow.start_run(), 2282 mock.patch( 2283 "mlflow.utils.requirements_utils._get_installed_version", 2284 side_effect=mock_get_installed_version, 2285 ), 2286 ): 2287 model_info = mlflow.pyfunc.log_model( 2288 name="model", python_model=TestModel(), input_example="abc" 2289 ) 2290 _assert_pip_requirements( 2291 model_info.model_uri, 2292 [ 2293 expected_mlflow_version, 2294 "pandas==2.1.0", 2295 f"numpy=={np.__version__}", 2296 f"cloudpickle=={cloudpickle.__version__}", 2297 ], 2298 strict=True, 2299 ) 2300 2301 # no input_example, so pandas not included in requirements 2302 with mlflow.start_run(): 2303 model_info = mlflow.pyfunc.log_model(name="model", python_model=TestModel()) 2304 _assert_pip_requirements( 2305 model_info.model_uri, 2306 [expected_mlflow_version, f"cloudpickle=={cloudpickle.__version__}"], 2307 strict=True, 2308 ) 2309 2310 2311 def test_environment_variables_used_during_model_logging(monkeypatch): 2312 class MyModel(mlflow.pyfunc.PythonModel): 2313 def predict(self, context, model_input, params=None): 2314 monkeypatch.setenv("TEST_API_KEY", "test_env") 2315 monkeypatch.setenv("ANOTHER_API_KEY", "test_env") 2316 monkeypatch.setenv("INVALID_ENV_VAR", "var") 2317 # existing env var is tracked 2318 os.environ["TEST_API_KEY"] 2319 # existing env var fetched by getenv is tracked 2320 os.environ.get("ANOTHER_API_KEY") 2321 # existing env var not in allowlist is not tracked 2322 os.environ.get("INVALID_ENV_VAR") 2323 # non-existing env var is not tracked 2324 os.environ.get("INVALID_API_KEY") 2325 return model_input 2326 2327 with mlflow.start_run(): 2328 model_info = mlflow.pyfunc.log_model( 2329 name="model", python_model=MyModel(), input_example="data" 2330 ) 2331 assert "TEST_API_KEY" in model_info.env_vars 2332 assert "ANOTHER_API_KEY" in model_info.env_vars 2333 assert "INVALID_ENV_VAR" not in model_info.env_vars 2334 assert "INVALID_API_KEY" not in model_info.env_vars 2335 
pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri) 2336 assert pyfunc_model.metadata.env_vars == model_info.env_vars 2337 2338 # if no input_example provided, we do not run predict, and no env vars are captured 2339 with mlflow.start_run(): 2340 model_info = mlflow.pyfunc.log_model(name="model", python_model=MyModel()) 2341 assert model_info.env_vars is None 2342 2343 # disable logging by setting environment variable 2344 monkeypatch.setenv(MLFLOW_RECORD_ENV_VARS_IN_MODEL_LOGGING.name, "false") 2345 with mlflow.start_run(): 2346 model_info = mlflow.pyfunc.log_model( 2347 name="model", python_model=MyModel(), input_example="data" 2348 ) 2349 assert model_info.env_vars is None 2350 2351 2352 def test_pyfunc_model_without_context_in_predict(): 2353 class Model(mlflow.pyfunc.PythonModel): 2354 def predict(self, model_input, params=None): 2355 return model_input 2356 2357 def predict_stream(self, model_input, params=None): 2358 yield model_input 2359 2360 m = Model() 2361 assert m.predict("abc") == "abc" 2362 assert next(iter(m.predict_stream("abc"))) == "abc" 2363 2364 with mlflow.start_run(): 2365 model_info = mlflow.pyfunc.log_model(name="model", python_model=m, input_example="abc") 2366 pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri) 2367 assert pyfunc_model.predict("abc") is not None 2368 assert next(iter(pyfunc_model.predict_stream("abc"))) is not None 2369 2370 2371 def test_callable_python_model_without_context_in_predict(): 2372 def predict(model_input): 2373 return model_input 2374 2375 assert predict("abc") == "abc" 2376 with mlflow.start_run(): 2377 model_info = mlflow.pyfunc.log_model( 2378 name="model", python_model=predict, input_example="abc" 2379 ) 2380 pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri) 2381 assert pyfunc_model.predict("abc") is not None 2382 2383 2384 def test_pyfunc_model_with_wrong_predict_signature_warning(): 2385 with pytest.warns( 2386 FutureWarning, 2387 match=r"Model's `predict` method 
contains invalid parameters: {'messages'}", 2388 ): 2389 2390 class Model(mlflow.pyfunc.PythonModel): 2391 def predict(self, context, messages, params=None): 2392 return messages 2393 2394 with pytest.warns( 2395 FutureWarning, 2396 match=r"Model's `predict_stream` method contains invalid parameters: {'_'}", 2397 ): 2398 2399 class Model(mlflow.pyfunc.PythonModel): 2400 def predict(self, model_input, params=None): 2401 return model_input 2402 2403 def predict_stream(self, _, model_input, params=None): 2404 yield model_input 2405 2406 2407 def test_pyfunc_model_input_example_with_signature(): 2408 class Model(mlflow.pyfunc.PythonModel): 2409 def predict(self, context, model_input, params=None): 2410 return model_input 2411 2412 signature = infer_signature(["a", "b", "c"]) 2413 with mlflow.start_run(): 2414 with pytest.warns( 2415 UserWarning, match=r"An input example was not provided when logging the model" 2416 ): 2417 mlflow.pyfunc.log_model(name="model", python_model=Model(), signature=signature) 2418 2419 with mlflow.start_run(): 2420 with pytest.raises( 2421 MlflowException, match=r"Input example does not match the model signature" 2422 ): 2423 mlflow.pyfunc.log_model( 2424 name="model", python_model=Model(), signature=signature, input_example=123 2425 ) 2426 2427 2428 @pytest.mark.parametrize("save_model", [True, False]) 2429 @pytest.mark.parametrize("use_user_auth_policy", [True, False]) 2430 @pytest.mark.parametrize("use_system_policy", [True, False]) 2431 def test_model_log_with_auth_policy(tmp_path, save_model, use_user_auth_policy, use_system_policy): 2432 pyfunc_save_artifact_path = os.path.join(tmp_path, "pyfunc_model_save") 2433 pyfunc_log_artifact_path = "pyfunc_model_log" 2434 2435 expected_auth_policy = {"system_auth_policy": {}, "user_auth_policy": {}} 2436 2437 system_auth_policy = None 2438 if use_system_policy: 2439 system_auth_policy = SystemAuthPolicy( 2440 resources=[ 2441 
DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"), 2442 DatabricksVectorSearchIndex(index_name="rag.studio_bugbash.databricks_docs_index"), 2443 DatabricksFunction(function_name="rag.studio.test_function_a"), 2444 DatabricksUCConnection(connection_name="test_connection_1"), 2445 ] 2446 ) 2447 expected_auth_policy["system_auth_policy"] = { 2448 "resources": { 2449 "api_version": "1", 2450 "databricks": { 2451 "function": [{"name": "rag.studio.test_function_a"}], 2452 "serving_endpoint": [{"name": "databricks-mixtral-8x7b-instruct"}], 2453 "uc_connection": [{"name": "test_connection_1"}], 2454 "vector_search_index": [{"name": "rag.studio_bugbash.databricks_docs_index"}], 2455 }, 2456 } 2457 } 2458 2459 user_auth_policy = None 2460 if use_user_auth_policy: 2461 user_auth_policy = UserAuthPolicy( 2462 api_scopes=[ 2463 "catalog.catalogs", 2464 "vectorsearch.vector-search-indexes", 2465 "workspace.workspace", 2466 ] 2467 ) 2468 expected_auth_policy["user_auth_policy"] = { 2469 "api_scopes": [ 2470 "catalog.catalogs", 2471 "vectorsearch.vector-search-indexes", 2472 "workspace.workspace", 2473 ] 2474 } 2475 2476 auth_policy = AuthPolicy( 2477 user_auth_policy=user_auth_policy, system_auth_policy=system_auth_policy 2478 ) 2479 2480 if save_model: 2481 mlflow.pyfunc.save_model( 2482 path=pyfunc_save_artifact_path, 2483 conda_env=_conda_env(), 2484 python_model=mlflow.pyfunc.model.PythonModel(), 2485 auth_policy=auth_policy, 2486 ) 2487 reloaded_model = Model.load(pyfunc_save_artifact_path) 2488 else: 2489 with mlflow.start_run() as run: 2490 mlflow.pyfunc.log_model( 2491 name=pyfunc_log_artifact_path, 2492 python_model=mlflow.pyfunc.model.PythonModel(), 2493 auth_policy=auth_policy, 2494 ) 2495 2496 pyfunc_model_uri = f"runs:/{run.info.run_id}/{pyfunc_log_artifact_path}" 2497 pyfunc_model_path = _download_artifact_from_uri(pyfunc_model_uri) 2498 reloaded_model = Model.load(os.path.join(pyfunc_model_path, "MLmodel")) 2499 2500 assert 
reloaded_model.auth_policy == expected_auth_policy 2501 2502 2503 def test_both_resources_and_auth_policy(): 2504 pyfunc_log_artifact_path = "pyfunc_model_log" 2505 system_auth_policy = SystemAuthPolicy( 2506 resources=[DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct")] 2507 ) 2508 user_auth_policy = UserAuthPolicy(api_scopes=["workspace.workspace"]) 2509 auth_policy = AuthPolicy( 2510 user_auth_policy=user_auth_policy, system_auth_policy=system_auth_policy 2511 ) 2512 2513 with mlflow.start_run() as _: 2514 with pytest.raises( 2515 ValueError, match="Only one of `resources`, and `auth_policy` can be specified." 2516 ): 2517 mlflow.pyfunc.log_model( 2518 name=pyfunc_log_artifact_path, 2519 python_model=mlflow.pyfunc.model.PythonModel(), 2520 auth_policy=auth_policy, 2521 resources=[ 2522 DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct") 2523 ], 2524 ) 2525 2526 2527 @pytest.mark.parametrize("compression", ["lzma", "bzip2", "gzip"]) 2528 def test_model_save_load_compression( 2529 monkeypatch, sklearn_knn_model, main_scoped_model_class, iris_data, tmp_path, compression 2530 ): 2531 monkeypatch.setenv(MLFLOW_LOG_MODEL_COMPRESSION.name, compression) 2532 sklearn_model_path = os.path.join(tmp_path, "sklearn_model") 2533 mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path) 2534 2535 def test_predict(sk_model, model_input): 2536 return sk_model.predict(model_input) * 2 2537 2538 pyfunc_model_path = os.path.join(tmp_path, "pyfunc_model") 2539 2540 mlflow.pyfunc.save_model( 2541 path=pyfunc_model_path, 2542 artifacts={"sk_model": sklearn_model_path}, 2543 conda_env=_conda_env(), 2544 python_model=main_scoped_model_class(test_predict), 2545 ) 2546 2547 loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=pyfunc_model_path) 2548 np.testing.assert_array_equal( 2549 loaded_pyfunc_model.predict(iris_data[0]), 2550 test_predict(sk_model=sklearn_knn_model, model_input=iris_data[0]), 2551 ) 2552 2553 
@pytest.mark.skip(reason="Enable once we re-enable the warning")
def test_load_model_warning():
    """Loading via ``runs:/<run_id>/<path>`` should emit a deprecation warning."""

    class Model(mlflow.pyfunc.PythonModel):
        def predict(self, model_input: list[str]):
            return model_input

    with mlflow.start_run() as run:
        mlflow.pyfunc.log_model(
            python_model=Model(),
            name="model",
            input_example=["a", "b", "c"],
        )

    with pytest.warns(UserWarning, match=r"`runs:/<run_id>/artifact_path` is deprecated"):
        mlflow.pyfunc.load_model(f"runs:/{run.info.run_id}/model")


def test_pyfunc_model_traces_link_to_model_id():
    """Each trace produced by a loaded model must carry that model's model_id."""

    class TestModel(mlflow.pyfunc.PythonModel):
        @mlflow.trace
        def predict(self, model_input: list[str]) -> list[str]:
            return model_input

    # Log three independent copies of the model so each has a distinct model_id.
    model_infos = [
        mlflow.pyfunc.log_model(
            name="test_model",
            python_model=TestModel(),
            input_example=["a", "b", "c"],
        )
        for _ in range(3)
    ]

    for model_info in model_infos:
        pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
        pyfunc_model.predict(["a", "b", "c"])

    # get_traces() returns newest-first; reverse so traces[i] matches model_infos[i].
    traces = get_traces()[::-1]
    assert len(traces) == 3
    for i in range(3):
        assert traces[i].info.request_metadata[TraceMetadataKey.MODEL_ID] == model_infos[i].model_id


class ExampleModel(mlflow.pyfunc.PythonModel):
    """Minimal identity model shared by the requirements-locking tests below."""

    def predict(self, model_input: list[str]) -> list[str]:
        return model_input


def test_lock_model_requirements(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
    """With MLFLOW_LOCK_MODEL_DEPENDENCIES enabled, both requirements.txt and
    conda.yaml must contain pinned ("locked") dependency versions that pip and
    conda can actually resolve (verified via dry-run installs)."""
    monkeypatch.setenv("MLFLOW_LOCK_MODEL_DEPENDENCIES", "true")

    model_info = mlflow.pyfunc.log_model(name="model", python_model=ExampleModel())
    pyfunc_model_path = _download_artifact_from_uri(model_info.model_uri, output_path=tmp_path)
    requirements_txt = next(Path(pyfunc_model_path).rglob("requirements.txt"))
    requirements_txt_contents = requirements_txt.read_text()
    assert "# Locked requirements" in requirements_txt_contents
    assert "mlflow==" in requirements_txt_contents
    assert "packaging==" in requirements_txt_contents
    # Check that pip can install the locked requirements
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "--ignore-installed",
            "--dry-run",
            "--requirement",
            requirements_txt,
        ],
    )
    # Check that conda environment can be created with the locked requirements
    conda_yaml = next(Path(pyfunc_model_path).rglob("conda.yaml"))
    conda_yaml_contents = conda_yaml.read_text()
    assert "# Locked requirements" in conda_yaml_contents
    # Fixed copy-paste bug: this previously re-checked requirements_txt_contents,
    # so the conda.yaml was never verified to pin mlflow.
    assert "mlflow==" in conda_yaml_contents
    assert "packaging==" in conda_yaml_contents
    subprocess.check_call(
        [
            "conda",
            "env",
            "create",
            "--file",
            conda_yaml,
            "--dry-run",
            "--yes",
        ],
    )


def test_lock_model_requirements_pip_requirements(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
    """Explicit pip_requirements are locked, including transitive deps (httpx)."""
    monkeypatch.setenv("MLFLOW_LOCK_MODEL_DEPENDENCIES", "true")
    model_info = mlflow.pyfunc.log_model(
        name="model",
        python_model=ExampleModel(),
        pip_requirements=["openai"],
    )
    pyfunc_model_path = _download_artifact_from_uri(model_info.model_uri, output_path=tmp_path)
    requirements_txt = next(Path(pyfunc_model_path).rglob("requirements.txt"))
    contents = requirements_txt.read_text()
    assert "# Locked requirements" in contents
    assert "mlflow==" in contents
    assert "openai==" in contents
    assert "httpx==" in contents


def test_lock_model_requirements_extra_pip_requirements(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
):
    """extra_pip_requirements participate in locking the same way pip_requirements do."""
    monkeypatch.setenv("MLFLOW_LOCK_MODEL_DEPENDENCIES", "true")
    model_info = mlflow.pyfunc.log_model(
        name="model",
        python_model=ExampleModel(),
        extra_pip_requirements=["openai"],
    )
    pyfunc_model_path = _download_artifact_from_uri(model_info.model_uri, output_path=tmp_path)
    requirements_txt = next(Path(pyfunc_model_path).rglob("requirements.txt"))
    contents = requirements_txt.read_text()
    assert "# Locked requirements" in contents
    assert "mlflow==" in contents
    assert "openai==" in contents
    assert "httpx==" in contents


def test_lock_model_requirements_constraints(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
    """A ``-c`` constraints file must pin the locked version of the constrained package."""
    constraints_file = tmp_path / "constraints.txt"
    constraints_file.write_text("openai==1.82.0")
    monkeypatch.setenv("MLFLOW_LOCK_MODEL_DEPENDENCIES", "true")
    model_info = mlflow.pyfunc.log_model(
        name="model",
        python_model=ExampleModel(),
        pip_requirements=["openai", f"-c {constraints_file}"],
    )
    pyfunc_model_path = _download_artifact_from_uri(model_info.model_uri, output_path=tmp_path)
    requirements_txt = next(Path(pyfunc_model_path).rglob("requirements.txt"))
    contents = requirements_txt.read_text()
    assert "# Locked requirements" in contents
    assert "mlflow==" in contents
    # The exact version from the constraints file, not just any openai pin.
    assert "openai==1.82.0" in contents
    assert "httpx==" in contents


@pytest.mark.parametrize(
    ("input_example", "expected_result"), [(["Hello", "World"], True), (None, False)]
)
def test_load_context_with_input_example(input_example, expected_result):
    """When an input_example is given, log_model runs predict for signature
    inference; a model whose load_context raises should then produce the
    "Failed to run the predict function" warning. Without an input_example,
    predict is never invoked and no such warning is logged."""

    class MyModel(mlflow.pyfunc.PythonModel):
        def load_context(self, context):
            raise Exception("load_context was called")

        def predict(self, model_input: list[str], params=None):
            return model_input

    msg = "Failed to run the predict function on input example"

    with mock.patch("mlflow.models.signature._logger.warning") as mock_warning:
        mlflow.pyfunc.log_model(
            name="model",
            python_model=MyModel(),
            input_example=input_example,
        )
    assert any(msg in call.args[0] for call in mock_warning.call_args_list) == expected_result