test_utils.py
import os
import random
from typing import Any, NamedTuple
from unittest import mock

import numpy as np
import pandas as pd
import pytest
import sklearn.neighbors as knn
from sklearn import datasets

import mlflow
from mlflow import MlflowClient
from mlflow.entities.model_registry import ModelVersion
from mlflow.environment_variables import MLFLOW_DISABLE_SCHEMA_DETAILS
from mlflow.exceptions import MlflowException
from mlflow.models import add_libraries_to_model
from mlflow.models.utils import (
    _config_context,
    _convert_llm_input_data,
    _enforce_array,
    _enforce_datatype,
    _enforce_mlflow_datatype,
    _enforce_object,
    _enforce_property,
    _flatten_nested_params,
    _validate_and_get_model_code_path,
    _validate_model_code_from_notebook,
    get_model_version_from_model_uri,
)
from mlflow.pyfunc import _enforce_schema, _validate_prediction_input
from mlflow.types import DataType, Schema
from mlflow.types.schema import Array, ColSpec, Object, Property


class ModelWithData(NamedTuple):
    # Pairs a fitted estimator with the data used to run inference on it.
    model: Any
    inference_data: Any


@pytest.fixture(scope="module")
def sklearn_knn_model():
    """Module-scoped fixture: a KNN classifier fitted on the iris dataset."""
    iris = datasets.load_iris()
    X = iris.data[:, :2]  # we only take the first two features.
    y = iris.target
    knn_model = knn.KNeighborsClassifier()
    knn_model.fit(X, y)
    return ModelWithData(model=knn_model, inference_data=X)


def random_int(lo=1, hi=1000000000):
    """Return a random integer in [lo, hi]; used to make registered-model names unique."""
    return random.randint(int(lo), int(hi))


def test_adding_libraries_to_model_default(sklearn_knn_model):
    """With an active run, add_libraries_to_model reuses that run's id by default."""
    model_name = f"wheels-test-{random_int()}"
    artifact_path = "model"
    # Registering under `model_name` creates version 1; the wheeled model becomes version 2.
    model_uri = f"models:/{model_name}/1"
    wheeled_model_uri = f"models:/{model_name}/2"

    # Log a model
    with mlflow.start_run():
        run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
        mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name=artifact_path,
            registered_model_name=model_name,
        )

        # Called inside the run context, so the wheeled model is logged to the same run.
        wheeled_model_info = add_libraries_to_model(model_uri)
        assert wheeled_model_info.run_id == run_id

    # Verify new model version created
    wheeled_model_version = get_model_version_from_model_uri(wheeled_model_uri)
    assert wheeled_model_version.run_id == run_id
    assert wheeled_model_version.name == model_name


def test_adding_libraries_to_model_new_run(sklearn_knn_model):
    """When called under a different active run, the wheeled model uses that run."""
    model_name = f"wheels-test-{random_int()}"
    artifact_path = "model"
    model_uri = f"models:/{model_name}/1"
    wheeled_model_uri = f"models:/{model_name}/2"

    # Log a model
    with mlflow.start_run():
        original_run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
        mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name=artifact_path,
            registered_model_name=model_name,
        )

    # A second, separate run: the wheeled model should attach to it, not the original.
    with mlflow.start_run():
        wheeled_run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
        wheeled_model_info = add_libraries_to_model(model_uri)
        assert original_run_id != wheeled_run_id
        assert wheeled_model_info.run_id == wheeled_run_id

    # Verify new model version created
    wheeled_model_version = get_model_version_from_model_uri(wheeled_model_uri)
    assert wheeled_model_version.run_id == wheeled_run_id
    assert wheeled_model_version.name == model_name


def test_adding_libraries_to_model_run_id_passed(sklearn_knn_model):
    """An explicit run_id argument wins even when no run is active at call time."""
    model_name = f"wheels-test-{random_int()}"
    artifact_path = "model"
    model_uri = f"models:/{model_name}/1"
    wheeled_model_uri = f"models:/{model_name}/2"

    # Log a model
    with mlflow.start_run():
        original_run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
        mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name=artifact_path,
            registered_model_name=model_name,
        )

    # Create (and immediately close) a second run whose id we pass explicitly.
    with mlflow.start_run():
        wheeled_run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id

    wheeled_model_info = add_libraries_to_model(model_uri, run_id=wheeled_run_id)
    assert original_run_id != wheeled_run_id
    assert wheeled_model_info.run_id == wheeled_run_id

    # Verify new model version created
    wheeled_model_version = get_model_version_from_model_uri(wheeled_model_uri)
    assert wheeled_model_version.run_id == wheeled_run_id
    assert wheeled_model_version.name == model_name


def test_adding_libraries_to_model_new_model_name(sklearn_knn_model):
    """registered_model_name registers the wheeled model under a new name (version 1)."""
    model_name = f"wheels-test-{random_int()}"
    wheeled_model_name = f"wheels-test-{random_int()}"
    artifact_path = "model"
    model_uri = f"models:/{model_name}/1"
    # New registered name, so the wheeled model is that name's version 1.
    wheeled_model_uri = f"models:/{wheeled_model_name}/1"

    # Log a model
    with mlflow.start_run():
        mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name=artifact_path,
            registered_model_name=model_name,
        )

    with mlflow.start_run():
        new_run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
        wheeled_model_info = add_libraries_to_model(
            model_uri, registered_model_name=wheeled_model_name
        )
        assert wheeled_model_info.run_id == new_run_id

    # Verify new model version created
    wheeled_model_version = get_model_version_from_model_uri(wheeled_model_uri)
    assert wheeled_model_version.run_id == new_run_id
    assert wheeled_model_version.name == wheeled_model_name
    assert wheeled_model_name != model_name


def test_adding_libraries_to_model_when_version_source_None(sklearn_knn_model):
    """If the registry reports no source run for the version, a fresh run is created."""
    model_name = f"wheels-test-{random_int()}"
    artifact_path = "model"
    model_uri = f"models:/{model_name}/1"

    # Log a model
    with mlflow.start_run():
        original_run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
        mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name=artifact_path,
            registered_model_name=model_name,
        )

    # A ModelVersion constructed without a source has run_id None.
    model_version_without_source = ModelVersion(name=model_name, version=1, creation_timestamp=124)
    assert model_version_without_source.run_id is None
    with mock.patch.object(
        MlflowClient, "get_model_version", return_value=model_version_without_source
    ) as mlflow_client_mock:
        wheeled_model_info = add_libraries_to_model(model_uri)
        # No source run to reuse and no active run, so a brand-new run is started.
        assert wheeled_model_info.run_id is not None
        assert wheeled_model_info.run_id != original_run_id
        mlflow_client_mock.assert_called_once_with(model_name, "1")


@pytest.mark.parametrize(
    ("data", "data_type"),
    [
        ("string", DataType.string),
        (np.int32(1), DataType.integer),
        (np.int32(1), DataType.long),
        (np.int32(1), DataType.double),
        (True, DataType.boolean),
        (1.0, DataType.double),
        (np.float32(0.1), DataType.float),
        (np.float32(0.1), DataType.double),
        (np.int64(100), DataType.long),
        (np.datetime64("2023-10-13 00:00:00"), DataType.datetime),
    ],
)
def test_enforce_datatype(data, data_type):
    """Scalars matching (or safely promotable to) the declared dtype pass unchanged."""
    assert _enforce_datatype(data, data_type) == data


def test_enforce_datatype_with_errors():
    """Non-DataType dtype arguments and non-coercible values are rejected."""
    with pytest.raises(MlflowException, match=r"Expected dtype to be DataType, got str"):
        _enforce_datatype("string", "string")

    with pytest.raises(
        MlflowException, match=r"Failed to enforce schema of data `123` with dtype `string`"
    ):
        _enforce_datatype(123, DataType.string)
@pytest.mark.parametrize(
    "dtype",
    [
        pd.StringDtype(),
        "string",
        object,
        None,  # infers object in pandas <3.0, StringDtype in pandas 3.0
    ],
)
def test_enforce_mlflow_datatype_with_string_dtype(dtype):
    # Test that string dtypes are handled correctly (pandas 3.0 compatibility)
    series = pd.Series(["a", "b", "c"], dtype=dtype)
    result = _enforce_mlflow_datatype("col", series, DataType.string)
    # The series must pass through untouched (same object, no copy/conversion).
    assert result is series


def test_enforce_object():
    """Dicts matching an Object schema (with optional properties omitted) pass through."""
    data = {
        "a": "some_sentence",
        "b": b"some_bytes",
        "c": ["sentence1", "sentence2"],
        "d": {"str": "value", "arr": [0.1, 0.2]},
    }
    obj = Object([
        Property("a", DataType.string),
        Property("b", DataType.binary, required=False),
        Property("c", Array(DataType.string)),
        Property(
            "d",
            Object([
                Property("str", DataType.string),
                Property("arr", Array(DataType.double), required=False),
            ]),
        ),
    ])
    assert _enforce_object(data, obj) == data

    # Optional properties ("b", nested "arr") may be omitted.
    data = {"a": "some_sentence", "c": ["sentence1", "sentence2"], "d": {"str": "some_value"}}
    assert _enforce_object(data, obj) == data


def test_enforce_object_with_errors():
    """Wrong container type, missing/extra keys, and type mismatches all raise."""
    with pytest.raises(MlflowException, match=r"Expected data to be dictionary, got list"):
        _enforce_object(["some_sentence"], Object([Property("a", DataType.string)]))

    with pytest.raises(MlflowException, match=r"Expected obj to be Object, got Property"):
        _enforce_object({"a": "some_sentence"}, Property("a", DataType.string))

    obj = Object([Property("a", DataType.string), Property("b", DataType.string, required=False)])
    with pytest.raises(MlflowException, match=r"Missing required properties: {'a'}"):
        _enforce_object({}, obj)

    with pytest.raises(
        MlflowException, match=r"Invalid properties not defined in the schema found: {'c'}"
    ):
        _enforce_object({"a": "some_sentence", "c": "some_sentence"}, obj)

    with pytest.raises(
        MlflowException,
        match=r"Failed to enforce schema for key `a`. Expected type string, received type int",
    ):
        _enforce_object({"a": 1}, obj)


def test_enforce_property():
    """Scalars, arrays (with binary coercion), numpy arrays, and nested Objects validate."""
    data = "some_sentence"
    prop = Property("a", DataType.string)
    assert _enforce_property(data, prop) == data

    data = ["some_sentence1", "some_sentence2"]
    prop = Property("a", Array(DataType.string))
    assert _enforce_property(data, prop) == data

    # Strings are coerced to bytes for a binary array property.
    prop = Property("a", Array(DataType.binary))
    assert _enforce_property(data, prop) == [b"some_sentence1", b"some_sentence2"]

    data = np.array([np.int32(1), np.int32(2)])
    prop = Property("a", Array(DataType.integer))
    assert (_enforce_property(data, prop) == data).all()

    data = {
        "a": "some_sentence",
        "b": b"some_bytes",
        "c": ["sentence1", "sentence2"],
        "d": {"str": "value", "arr": [0.1, 0.2]},
    }
    prop = Property(
        "any_name",
        Object([
            Property("a", DataType.string),
            Property("b", DataType.binary, required=False),
            Property("c", Array(DataType.string), required=False),
            Property(
                "d",
                Object([
                    Property("str", DataType.string),
                    Property("arr", Array(DataType.double), required=False),
                ]),
            ),
        ]),
    )
    assert _enforce_property(data, prop) == data
    data = {"a": "some_sentence", "d": {"str": "some_value"}}
    assert _enforce_property(data, prop) == data


def test_enforce_property_with_errors():
    """Type mismatches and missing required keys inside a property raise."""
    with pytest.raises(
        MlflowException, match=r"Failed to enforce schema of data `123` with dtype `string`"
    ):
        _enforce_property(123, Property("a", DataType.string))

    with pytest.raises(MlflowException, match=r"Missing required properties: {'a'}"):
        _enforce_property(
            {"b": ["some_sentence1", "some_sentence2"]},
            Property(
                "any_name",
                Object([Property("a", DataType.string), Property("b", Array(DataType.string))]),
            ),
        )

    with pytest.raises(
        MlflowException,
        match=r"Failed to enforce schema for key `a`. Expected type string, received type list",
    ):
        _enforce_property(
            {"a": ["some_sentence1", "some_sentence2"]},
            Property("any_name", Object([Property("a", DataType.string)])),
        )


@pytest.mark.parametrize(
    ("data", "schema"),
    [
        # 1. Flat list
        (["some_sentence1", "some_sentence2"], Array(DataType.string)),
        # 2. Nested list
        (
            [
                [["a", "b"], ["c", "d"]],
                [["e", "f", "g"], ["h"]],
                [[]],
            ],
            Array(Array(Array(DataType.string))),
        ),
        # 3. Array of Object
        (
            [
                {"a": "some_sentence1", "b": "some_sentence2"},
                {"a": "some_sentence3", "c": ["some_sentence4", "some_sentence5"]},
            ],
            Array(
                Object([
                    Property("a", DataType.string),
                    Property("b", DataType.string, required=False),
                    Property("c", Array(DataType.string), required=False),
                ])
            ),
        ),
        # 4. Empty list
        ([], Array(DataType.string)),
    ],
)
def test_enforce_array_on_list(data, schema):
    """Valid (possibly nested/empty) Python lists pass Array enforcement unchanged."""
    assert _enforce_array(data, schema) == data


@pytest.mark.parametrize(
    ("data", "schema"),
    [
        # 1. 1D array
        (np.array(["some_sentence1", "some_sentence2"]), Array(DataType.string)),
        # 2. 2D array
        (
            np.array([
                ["a", "b"],
                ["c", "d"],
            ]),
            Array(Array(DataType.string)),
        ),
        # 3. Empty array
        (np.array([[], []]), Array(Array(DataType.string))),
    ],
)
def test_enforce_array_on_numpy_array(data, schema):
    """Numpy arrays (1D, 2D, empty) pass Array enforcement element-wise."""
    assert (_enforce_array(data, schema) == data).all()


def test_enforce_array_with_errors():
    """Non-list data, mixed element types, ragged nesting, and bad Objects raise."""
    with pytest.raises(MlflowException, match=r"Expected data to be list or numpy array, got str"):
        _enforce_array("abc", Array(DataType.string))

    with pytest.raises(MlflowException, match=r"Incompatible input types"):
        _enforce_array([123, 456, 789], Array(DataType.string))

    # Nested array with mixed type elements
    with pytest.raises(MlflowException, match=r"Incompatible input types"):
        _enforce_array([["a", "b"], [1, 2]], Array(Array(DataType.string)))

    # Nested array with different nest level
    with pytest.raises(MlflowException, match=r"Expected data to be list or numpy array, got str"):
        _enforce_array([["a", "b"], "c"], Array(Array(DataType.string)))

    # Missing properties in Object
    with pytest.raises(MlflowException, match=r"Missing required properties: {'b'}"):
        _enforce_array(
            [
                {"a": "some_sentence1", "b": "some_sentence2"},
                {"a": "some_sentence3", "c": ["some_sentence4", "some_sentence5"]},
            ],
            Array(Object([Property("a", DataType.string), Property("b", DataType.string)])),
        )

    # Extra properties
    with pytest.raises(
        MlflowException, match=r"Invalid properties not defined in the schema found: {'c'}"
    ):
        _enforce_array(
            [
                {"a": "some_sentence1", "b": "some_sentence2"},
                {"a": "some_sentence3", "c": ["some_sentence4", "some_sentence5"]},
            ],
            Array(
                Object([
                    Property("a", DataType.string),
                    Property("b", DataType.string, required=False),
                ])
            ),
        )


def test_model_code_validation():
    """Notebook-exported model code: dbutils warnings, magic-command commenting, pass-through."""
    # Invalid code with dbutils
    invalid_code = "dbutils.library.restartPython()\nsome_python_variable = 5"

    with mock.patch("mlflow.models.utils._logger.warning") as mock_warning:
        _validate_model_code_from_notebook(invalid_code)
        mock_warning.assert_called_once_with(
            "The model file uses 'dbutils' commands which are not supported. To ensure your "
            "code functions correctly, make sure that it does not rely on these dbutils "
            "commands for correctness."
        )

    # Code with commented magic commands displays warning
    warning_code = "# dbutils.library.restartPython()\n# MAGIC %run ../wheel_installer"

    with mock.patch("mlflow.models.utils._logger.warning") as mock_warning:
        _validate_model_code_from_notebook(warning_code)
        mock_warning.assert_called_once_with(
            "The model file uses magic commands which have been commented out. To ensure your code "
            "functions correctly, make sure that it does not rely on these magic commands for "
            "correctness."
        )

    # Code with commented pip magic commands does not warn
    warning_code = "# MAGIC %pip install mlflow"
    with mock.patch("mlflow.models.utils._logger.warning") as mock_warning:
        _validate_model_code_from_notebook(warning_code)
        mock_warning.assert_not_called()

    # Test valid code
    valid_code = "some_valid_python_code = 'valid'"

    validated_code = _validate_model_code_from_notebook(valid_code).decode("utf-8")
    assert validated_code == valid_code

    # Test uncommented magic commands
    code_with_magic_command = (
        "valid_python_code = 'valid'\n%pip install sqlparse\nvalid_python_code = 'valid'\n# Comment"
    )
    expected_validated_code = (
        "valid_python_code = 'valid'\n# MAGIC %pip install sqlparse\nvalid_python_code = "
        "'valid'\n# Comment"
    )

    validated_code_with_magic_command = _validate_model_code_from_notebook(
        code_with_magic_command
    ).decode("utf-8")
    assert validated_code_with_magic_command == expected_validated_code


def test_config_context():
    """_config_context sets the module-level model config path and restores None on exit."""
    with _config_context("tests/langchain/config.yml"):
        assert mlflow.models.model_config.__mlflow_model_config__ == "tests/langchain/config.yml"

    assert mlflow.models.model_config.__mlflow_model_config__ is None


def test_flatten_nested_params():
    """Nested param dicts flatten with the given separator; non-dict leaves are kept as-is."""
    nested_params = {
        "a": 1,
        "b": {"c": 2, "d": {"e": 3}},
        "f": {"g": {"h": 4}},
    }
    expected_flattened_params = {
        "a": 1,
        "b.c": 2,
        "b.d.e": 3,
        "f.g.h": 4,
    }
    assert _flatten_nested_params(nested_params, sep=".") == expected_flattened_params
    assert _flatten_nested_params(nested_params, sep="/") == {
        "a": 1,
        "b/c": 2,
        "b/d/e": 3,
        "f/g/h": 4,
    }
    assert _flatten_nested_params({}) == {}

    # Already-flat dicts are returned unchanged.
    params = {"a": 1, "b": 2, "c": 3}
    assert _flatten_nested_params(params) == params

    # Lists and None are leaves, not containers to recurse into; default sep is "/".
    params = {
        "a": 1,
        "b": {"c": 2, "d": {"e": 3, "f": [1, 2, 3]}, "g": "hello"},
        "h": {"i": None},
    }
    expected_flattened_params = {
        "a": 1,
        "b/c": 2,
        "b/d/e": 3,
        "b/d/f": [1, 2, 3],
        "b/g": "hello",
        "h/i": None,
    }
    assert _flatten_nested_params(params) == expected_flattened_params

    # Non-string keys are stringified when joined.
    nested_params = {1: {2: {3: 4}}, "a": {"b": {"c": 5}}}
    expected_flattened_params_mixed = {
        "1/2/3": 4,
        "a/b/c": 5,
    }
    assert _flatten_nested_params(nested_params) == expected_flattened_params_mixed

    # Realistic RAG config: only nested dicts are flattened, everything else passes through.
    rag_params = {
        "workspace_url": "https://e2-dogfood.staging.cloud.databricks.com",
        "vector_search_endpoint_name": "dbdemos_vs_endpoint",
        "vector_search_index": "monitoring.rag.databricks_docs_index",
        "embedding_model_endpoint_name": "databricks-bge-large-en",
        "embedding_model_query_instructions": "Represent this sentence for searching",
        "llm_model": "databricks-dbrx-instruct",
        "llm_prompt_template": "You are a trustful assistant for Databricks users.",
        "retriever_config": {"k": 5, "use_mmr": "false"},
        "llm_parameters": {"temperature": 0.01, "max_tokens": 200},
        "llm_prompt_template_variables": ["chat_history", "context", "question"],
        "secret_scope": "dbdemos",
        "secret_key": "rag_sunish",
    }

    expected_rag_flattened_params = {
        "workspace_url": "https://e2-dogfood.staging.cloud.databricks.com",
        "vector_search_endpoint_name": "dbdemos_vs_endpoint",
        "vector_search_index": "monitoring.rag.databricks_docs_index",
        "embedding_model_endpoint_name": "databricks-bge-large-en",
        "embedding_model_query_instructions": "Represent this sentence for searching",
        "llm_model": "databricks-dbrx-instruct",
        "llm_prompt_template": "You are a trustful assistant for Databricks users.",
        "retriever_config/k": 5,
        "retriever_config/use_mmr": "false",
        "llm_parameters/temperature": 0.01,
        "llm_parameters/max_tokens": 200,
        "llm_prompt_template_variables": ["chat_history", "context", "question"],
        "secret_scope": "dbdemos",
        "secret_key": "rag_sunish",
    }

    assert _flatten_nested_params(rag_params) == expected_rag_flattened_params


@pytest.mark.parametrize(
    ("data", "target", "target_type"),
    [
        (pd.DataFrame([{"a": [1, 2, 3]}]), [{"a": [1, 2, 3]}], list),
        (pd.DataFrame([{"a": np.array([1, 2, 3])}]), [{"a": [1, 2, 3]}], list),
        (pd.DataFrame([{0: np.array(["abc"])[0]}]), ["abc"], list),
        (np.array([1, 2, 3]), [1, 2, 3], list),
        (np.array([123])[0], 123, int),
        (np.array(["abc"])[0], "abc", str),
    ],
)
def test_convert_llm_input_data(data, target, target_type):
    """LLM input data converts to plain Python containers/scalars of the exact type."""
    result = _convert_llm_input_data(data)
    assert result == target
    # Use identity (`is`) for exact type comparison (E721); `==` on types is a lint smell
    # and would also pass for metaclass-overridden equality.
    assert type(result) is target_type


@pytest.mark.parametrize(
    ("model_path", "error_message"),
    [
        (
            "model.py",
            f"The provided model path '{os.getcwd()}/model.py' does not exist. "
            "Ensure the file path is valid and try again.",
        ),
        (
            "model",
            f"The provided model path '{os.getcwd()}/model' does not exist. "
            "Ensure the file path is valid and try again. "
            f"Perhaps you meant '{os.getcwd()}/model.py'?",
        ),
    ],
)
def test_validate_and_get_model_code_path_not_found(model_path, error_message, tmp_path):
    """Nonexistent model paths raise with an actionable (and suggested-fix) message."""
    import re  # local import: only this test needs regex escaping

    # `pytest.raises(match=...)` treats the argument as a regex (re.search). The expected
    # message embeds a filesystem path and a literal '?', so escape it to match literally.
    with pytest.raises(MlflowException, match=re.escape(error_message)):
        _validate_and_get_model_code_path(model_path, tmp_path)


def test_validate_and_get_model_code_path_success(tmp_path):
    # if the model file exists, return the path as is
    model_path = os.path.abspath(__file__)
    actual = _validate_and_get_model_code_path(model_path, tmp_path)

    assert actual == model_path


def test_suppress_schema_error(monkeypatch):
    """With MLFLOW_DISABLE_SCHEMA_DETAILS set, input-schema errors hide column details."""
    schema = Schema([
        ColSpec("double", "id"),
        ColSpec("string", "name"),
    ])
    monkeypatch.setenv(MLFLOW_DISABLE_SCHEMA_DETAILS.name, "true")
    data = pd.DataFrame({"id": [1, 2]}, dtype="float64")  # missing required "name" column

    with pytest.raises(
        MlflowException,
        match=r"Failed to enforce model input schema. Please check your input data.",
    ):
        _validate_prediction_input(data, None, schema, None)


def test_enforce_schema_with_missing_and_extra_columns(monkeypatch):
    """Missing and extra columns are reported generically when details are disabled."""
    schema = Schema([
        ColSpec("long", "id"),
        ColSpec("string", "name"),
    ])
    monkeypatch.setenv(MLFLOW_DISABLE_SCHEMA_DETAILS.name, "true")
    input_data = pd.DataFrame({"id": [1, 2], "extra_col": ["mlflow", "oss"]})
    with pytest.raises(
        MlflowException, match=r"Input schema validation failed.*extra inputs provided"
    ):
        _enforce_schema(input_data, schema)