"""Tests for saving, logging, loading, and serving SentenceTransformer models via MLflow.

Covers the ``mlflow.sentence_transformers`` flavor: native save/load round-trips,
pip/conda dependency inference, signatures and input examples, pyfunc inference
(including ``llm/v1/embeddings`` task dicts and inference-time params), Spark UDF
usage, and model serving through the scoring server.
"""

import json
import os
from unittest import mock

import numpy as np
import pandas as pd
import pytest
import sentence_transformers
import yaml
from packaging.version import Version
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, DoubleType
from sentence_transformers import SentenceTransformer

import mlflow
import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
import mlflow.sentence_transformers
from mlflow import pyfunc
from mlflow.exceptions import MlflowException
from mlflow.models import Model, infer_signature
from mlflow.models.utils import _read_example, load_serving_example
from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
from mlflow.utils.environment import _mlflow_conda_env

from tests.helper_functions import (
    _assert_pip_requirements,
    _compare_logged_code_paths,
    _mlflow_major_version_string,
    assert_register_model_called_with_local_model_path,
    pyfunc_serve_and_score_model,
)
from tests.transformers.version import IS_TRANSFORMERS_V5_OR_LATER


@pytest.fixture
def model_path(tmp_path):
    # Per-test destination directory for `save_model`.
    return tmp_path.joinpath("model")


@pytest.fixture
def basic_model():
    # Small (384-dim) sentence embedding model used by most tests.
    return SentenceTransformer("all-MiniLM-L6-v2")


@pytest.fixture
def model_with_remote_code():
    # Model whose architecture is defined by custom code on the hub
    # (requires trust_remote_code=True to instantiate).
    return SentenceTransformer("Alibaba-NLP/gte-base-en-v1.5", trust_remote_code=True)


@pytest.fixture(scope="module")
def spark():
    # Single local Spark session shared by all tests in this module.
    with SparkSession.builder.master("local[1]").getOrCreate() as s:
        yield s


def test_model_save_and_load(model_path, basic_model):
    mlflow.sentence_transformers.save_model(model=basic_model, path=model_path)

    loaded_model = mlflow.sentence_transformers.load_model(model_path)

    encoded_single = loaded_model.encode("I'm just a simple string; nothing to see here.")
    encoded_multi = loaded_model.encode(["I'm a string", "I'm also a string", "Please encode me"])

    # all-MiniLM-L6-v2 produces fixed-width 384-dim embeddings.
    assert isinstance(encoded_single, np.ndarray)
    assert len(encoded_single) == 384
    assert isinstance(encoded_multi, np.ndarray)
    assert len(encoded_multi) == 3
    assert all(len(x) == 384 for x in encoded_multi)


@pytest.mark.skipif(
    Version(sentence_transformers.__version__) < Version("2.4.0"),
    reason="`trust_remote_code` is not supported in Sentence Transformers < 2.3.0 "
    "and `include_prompt` from gte-base-en-v1.5 requires 2.4.0 or above",
)
@pytest.mark.skipif(
    IS_TRANSFORMERS_V5_OR_LATER,
    reason="Alibaba-NLP/gte-base-en-v1.5 has corrupted position_ids buffers on transformers 5.x "
    "due to uninitialized meta-device loading (https://github.com/huggingface/transformers/issues/43957)",
)
def test_model_save_and_load_with_custom_code(model_path, model_with_remote_code):
    mlflow.sentence_transformers.save_model(model=model_with_remote_code, path=model_path)
    loaded_model = mlflow.sentence_transformers.load_model(model_path)

    encoded_single = loaded_model.encode("I'm just a simple string; nothing to see here.")
    assert isinstance(encoded_single, np.ndarray)
    # gte-base-en-v1.5 produces 768-dim embeddings.
    assert len(encoded_single) == 768


def test_dependency_mapping():
    # Default pip requirements must pin the core runtime packages.
    pip_requirements = mlflow.sentence_transformers.get_default_pip_requirements()

    expected_requirements = {"sentence-transformers", "torch", "transformers"}
    assert {package.split("=")[0] for package in pip_requirements}.intersection(
        expected_requirements
    ) == expected_requirements

    # The conda env nests pip deps at dependencies[2]["pip"] and must also
    # include mlflow itself.
    conda_requirements = mlflow.sentence_transformers.get_default_conda_env()
    pip_in_conda = {
        package.split("=")[0] for package in conda_requirements["dependencies"][2]["pip"]
    }
    expected_conda = {"mlflow"}
    expected_conda.update(expected_requirements)
    assert pip_in_conda.intersection(expected_conda) == expected_conda


def test_logged_data_structure(model_path, basic_model):
    mlflow.sentence_transformers.save_model(model=basic_model, path=model_path)

    # requirements.txt and conda.yaml must both carry the core dependencies.
    with model_path.joinpath("requirements.txt").open() as file:
        requirements = file.read()
    reqs = {req.split("==")[0] for req in requirements.split("\n")}
    expected_requirements = {"sentence-transformers", "torch", "transformers"}
    assert reqs.intersection(expected_requirements) == expected_requirements
    conda_env = yaml.safe_load(model_path.joinpath("conda.yaml").read_bytes())
    assert {req.split("==")[0] for req in conda_env["dependencies"][2]["pip"]}.intersection(
        expected_requirements
    ) == expected_requirements

    # MLmodel metadata: size, pyfunc loader wiring, and flavor-specific fields.
    mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes())
    assert "model_size_bytes" in mlmodel

    pyfunc_flavor = mlmodel["flavors"]["python_function"]
    assert pyfunc_flavor["loader_module"] == "mlflow.sentence_transformers"
    assert pyfunc_flavor["data"] == mlflow.sentence_transformers.SENTENCE_TRANSFORMERS_DATA_PATH

    st_flavor = mlmodel["flavors"]["sentence_transformers"]
    assert st_flavor["pipeline_model_type"] == "BertModel"
    assert st_flavor["source_model_name"] == "sentence-transformers/all-MiniLM-L6-v2"


@pytest.mark.parametrize(
    ("model_name", "expected"),
    [
        (
            "sentence-transformers/all-MiniLM-L6-v2",
            "sentence-transformers/all-MiniLM-L6-v2",
        ),
        (
            "/path./to_/local-/path?/sentence-transformers_all-MiniLM-L6-v2/",
            "sentence-transformers/all-MiniLM-L6-v2",
        ),
        (
            "/path/to/local/path/custom-user-009_model_name_with_underscore/",
            "custom-user-009/model_name_with_underscore",
        ),
    ],
)
def test_get_transformers_model_name(model_name, expected):
    # Hub-style names pass through; local snapshot paths are normalized back
    # to "<org>/<model>" form.
    assert mlflow.sentence_transformers._get_transformers_model_name(model_name) == expected


def test_model_logging_and_inference(basic_model):
    artifact_path = "sentence_transformer"
    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(basic_model, name=artifact_path)

    model = mlflow.sentence_transformers.load_model(model_info.model_uri)

    encoded_single = model.encode(
        "Encodings provide a fixed width output regardless of input size."
    )
    encoded_multi = model.encode([
        "Just a small town girl",
        "livin' in a lonely world",
        "she took the midnight train",
        "going anywhere",
    ])

    assert isinstance(encoded_single, np.ndarray)
    assert len(encoded_single) == 384
    assert isinstance(encoded_multi, np.ndarray)
    assert len(encoded_multi) == 4
    assert all(len(x) == 384 for x in encoded_multi)


def test_load_from_remote_uri(model_path, basic_model, mock_s3_bucket):
    # Round-trip through a mocked S3 artifact store.
    mlflow.sentence_transformers.save_model(model=basic_model, path=model_path)
    artifact_root = f"s3://{mock_s3_bucket}"
    artifact_path = "model"
    artifact_repo = S3ArtifactRepository(artifact_root)
    artifact_repo.log_artifacts(model_path, artifact_path=artifact_path)
    model_uri = os.path.join(artifact_root, artifact_path)
    loaded = mlflow.sentence_transformers.load_model(model_uri=str(model_uri))

    encoding = loaded.encode(
        "I can see why these are useful when you do distance calculations on them!"
    )

    assert len(encoding) == 384


def test_log_model_calls_register_model(tmp_path, basic_model):
    artifact_path = "sentence_transformer"
    register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
    with mlflow.start_run(), register_model_patch:
        conda_env = tmp_path.joinpath("conda_env.yaml")
        _mlflow_conda_env(
            conda_env, additional_pip_deps=["transformers", "torch", "sentence-transformers"]
        )
        model_info = mlflow.sentence_transformers.log_model(
            basic_model,
            name=artifact_path,
            conda_env=str(conda_env),
            registered_model_name="My super cool encoder",
        )
        assert_register_model_called_with_local_model_path(
            register_model_mock=mlflow.tracking._model_registry.fluent._register_model,
            model_uri=model_info.model_uri,
            registered_model_name="My super cool encoder",
        )


def test_log_model_with_no_registered_model_name(tmp_path, basic_model):
    artifact_path = "sentence_transformer"
    register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
    with mlflow.start_run(), register_model_patch:
        conda_env = tmp_path.joinpath("conda_env.yaml")
        _mlflow_conda_env(
            conda_env, additional_pip_deps=["transformers", "torch", "sentence-transformers"]
        )
        mlflow.sentence_transformers.log_model(
            basic_model,
            name=artifact_path,
            conda_env=str(conda_env),
        )
        # Without registered_model_name, registration must not happen.
        mlflow.tracking._model_registry.fluent._register_model.assert_not_called()


def test_log_with_pip_requirements(tmp_path, basic_model):
    expected_mlflow_version = _mlflow_major_version_string()

    requirements_file = tmp_path.joinpath("requirements.txt")
    requirements_file.write_text("some-clever-package")
    # Plain requirements file path.
    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(
            basic_model, name="model", pip_requirements=str(requirements_file)
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, "some-clever-package"],
            strict=True,
        )
    # "-r" include of a requirements file plus an inline requirement.
    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(
            basic_model,
            name="model",
            pip_requirements=[f"-r {requirements_file}", "a-hopefully-useful-package"],
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, "some-clever-package", "a-hopefully-useful-package"],
            strict=True,
        )
    # "-c" constraints file: constrained packages land in constraints.txt,
    # not in requirements.
    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(
            basic_model,
            name="model",
            pip_requirements=[f"-c {requirements_file}", "i-dunno-maybe-its-good"],
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, "i-dunno-maybe-its-good", "-c constraints.txt"],
            ["some-clever-package"],
            strict=True,
        )


def test_log_with_extra_pip_requirements(basic_model, tmp_path):
    expected_mlflow_version = _mlflow_major_version_string()
    default_requirements = mlflow.sentence_transformers.get_default_pip_requirements()
    requirements_file = tmp_path.joinpath("requirements.txt")
    requirements_file.write_text("effective-package")
    # extra_pip_requirements appends to (rather than replaces) the defaults.
    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(
            basic_model, name="model", extra_pip_requirements=str(requirements_file)
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, *default_requirements, "effective-package"],
        )
    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(
            basic_model,
            name="model",
            extra_pip_requirements=[f"-r {requirements_file}", "useful-package"],
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, *default_requirements, "effective-package", "useful-package"],
        )
    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(
            basic_model,
            name="model",
            extra_pip_requirements=[f"-c {requirements_file}", "constrained-pkg"],
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [
                expected_mlflow_version,
                *default_requirements,
                "constrained-pkg",
                "-c constraints.txt",
            ],
            ["effective-package"],
        )


def test_model_save_without_conda_env_uses_default_env_with_expected_dependencies(
    basic_model, model_path
):
    mlflow.sentence_transformers.save_model(basic_model, model_path)
    _assert_pip_requirements(
        model_path, mlflow.sentence_transformers.get_default_pip_requirements()
    )


def test_model_log_without_conda_env_uses_default_env_with_expected_dependencies(
    basic_model,
):
    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(basic_model, name=artifact_path)
    _assert_pip_requirements(
        model_info.model_uri, mlflow.sentence_transformers.get_default_pip_requirements()
    )


def test_log_model_with_code_paths(basic_model):
    artifact_path = "model"
    with (
        mlflow.start_run(),
        mock.patch("mlflow.sentence_transformers._add_code_from_conf_to_system_path") as add_mock,
    ):
        model_info = mlflow.sentence_transformers.log_model(
            basic_model, name=artifact_path, code_paths=[__file__]
        )
        _compare_logged_code_paths(
            __file__, model_info.model_uri, mlflow.sentence_transformers.FLAVOR_NAME
        )
        mlflow.sentence_transformers.load_model(model_info.model_uri)
        add_mock.assert_called()


def test_default_signature_assignment():
    # The flavor's fallback signature: string in, float64 tensor out, no params.
    expected_signature = {
        "inputs": '[{"type": "string", "required": true}]',
        "outputs": '[{"type": "tensor", "tensor-spec": {"dtype": "float64", "shape": [-1]}}]',
        "params": None,
    }

    default_signature = mlflow.sentence_transformers._get_default_signature()

    assert default_signature.to_dict() == expected_signature


def test_model_pyfunc_save_load(basic_model, model_path):
    mlflow.sentence_transformers.save_model(basic_model, model_path)
    loaded_pyfunc = pyfunc.load_model(model_uri=model_path)

    sentence = "hello world and hello mlflow"
    sentences = [sentence, "goodbye my friends", "i am a sentence"]
    embedding_dim = basic_model.get_sentence_embedding_dimension()

    # A single string is treated as a batch of one.
    emb0 = loaded_pyfunc.predict(sentence)
    assert emb0.shape == (1, embedding_dim)

    # list, Series, and ndarray inputs must all produce identical embeddings.
    emb1 = loaded_pyfunc.predict(sentences)
    emb2 = loaded_pyfunc.predict(pd.Series(sentences))
    emb3 = loaded_pyfunc.predict(pd.Series(sentences).to_numpy())

    for emb in [emb1, emb2, emb3]:
        assert emb.shape == (3, embedding_dim)

    np.testing.assert_array_equal(emb1, emb2)
    np.testing.assert_array_equal(emb1, emb3)


def test_model_pyfunc_predict_with_params(basic_model, tmp_path):
    sentence = "hello world and hello mlflow"
    params = {"batch_size": 16}

    # Model saved with a params schema accepts matching params...
    model_path = tmp_path / "model1"
    signature = infer_signature(sentence, params=params)
    mlflow.sentence_transformers.save_model(basic_model, model_path, signature=signature)
    loaded_pyfunc = pyfunc.load_model(model_uri=model_path)
    embedding_dim = basic_model.get_sentence_embedding_dimension()

    emb0 = loaded_pyfunc.predict(sentence, params)
    assert emb0.shape == (1, embedding_dim)

    # ...but rejects params of the wrong type.
    with pytest.raises(MlflowException, match=r"Invalid parameters found"):
        loaded_pyfunc.predict(sentence, {"batch_size": "16"})

    # Model saved without a params schema warns and ignores provided params.
    model_path = tmp_path / "model3"
    mlflow.sentence_transformers.save_model(basic_model, model_path)
    loaded_pyfunc = pyfunc.load_model(model_uri=model_path)
    with mock.patch("mlflow.models.utils._logger.warning") as mock_warning:
        loaded_pyfunc.predict(sentence, params)
        mock_warning.assert_called_with(
            "`params` can only be specified at inference time if the model signature defines a params "
            "schema. This model does not define a params schema. Ignoring provided params: "
            "['batch_size']"
        )


@pytest.mark.skipif(
    Version(sentence_transformers.__version__) >= Version("3.1.0"),
    reason="This test only passes for Sentence Transformers < 3.1.0",
)
def test_model_pyfunc_predict_with_invalid_params(basic_model, tmp_path):
    sentence = "hello world and hello mlflow"
    model_path = tmp_path / "model"
    mlflow.sentence_transformers.save_model(
        basic_model,
        model_path,
        signature=infer_signature(sentence, params={"invalid_param": "value"}),
    )
    loaded_pyfunc = pyfunc.load_model(model_uri=model_path)

    with pytest.raises(
        MlflowException, match=r"Received invalid parameter value for `params` argument"
    ):
        loaded_pyfunc.predict(sentence, {"invalid_param": "random_value"})


def test_spark_udf(basic_model, spark):
    params = {"batch_size": 16}
    with mlflow.start_run():
        signature = infer_signature(SENTENCES, basic_model.encode(SENTENCES), params)
        model_info = mlflow.sentence_transformers.log_model(
            basic_model, name="my_model", signature=signature
        )

    result_type = ArrayType(DoubleType())
    loaded_model = mlflow.pyfunc.spark_udf(
        spark,
        model_info.model_uri,
        result_type=result_type,
        params=params,
    )

    df = spark.createDataFrame([("hello MLflow",), ("bye world",)], ["text"])
    df = df.withColumn("embedding", loaded_model("text"))
    assert df.schema[1].dataType == result_type

    pdf = df.toPandas()
    assert pdf.shape == (2, 2)
    assert pdf["embedding"].dtype == "object"

    embeddings = np.array(pdf.embedding.to_list())
    assert embeddings.shape == (2, basic_model.get_sentence_embedding_dimension())


@pytest.mark.parametrize(
    ("input1", "input2"),
    [
        (["hello world"], ["goodbye world!"]),
        (["hello world", "i am mlflow"], ["goodbye world!", "i am mlflow"]),
    ],
)
def test_pyfunc_serve_and_score(input1, input2, basic_model):
    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(
            basic_model, name="my_model", input_example=input1
        )
    loaded_pyfunc = pyfunc.load_model(model_uri=model_info.model_uri)
    local_predict = loaded_pyfunc.predict(input1)

    # Check that giving the same string to the served model results in the same result
    inference_data = load_serving_example(model_info.model_uri)
    assert json.loads(inference_data) == {"inputs": input1}
    resp = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_data,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    serving_result = json.loads(resp.content.decode("utf-8"))["predictions"]
    np.testing.assert_array_equal(local_predict, serving_result)

    # Check that giving a different string to the served model results in a different result
    inference_data = json.dumps({"inputs": input2})
    resp = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_data,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    serving_result = json.loads(resp.content.decode("utf-8"))["predictions"]
    assert not np.equal(local_predict, serving_result).all()


# Shared fixtures for signature tests (evaluated at import time; the model
# download is reused by the tests below).
SENTENCES = ["hello world", "i am mlflow"]
SENTENCES_DF = pd.DataFrame(SENTENCES)
SIGNATURE = infer_signature(
    model_input=SENTENCES,
    model_output=SentenceTransformer("all-MiniLM-L6-v2").encode(SENTENCES),
)
SIGNATURE_FROM_EXAMPLE = infer_signature(
    model_input=SENTENCES_DF,
    model_output=SentenceTransformer("all-MiniLM-L6-v2").encode(SENTENCES),
)


@pytest.mark.parametrize(
    ("example", "signature", "expected_signature"),
    [
        (None, None, mlflow.sentence_transformers._get_default_signature()),
        (SENTENCES_DF, None, SIGNATURE_FROM_EXAMPLE),
        (None, SIGNATURE, SIGNATURE),
        (SENTENCES, SIGNATURE, SIGNATURE),
    ],
)
def test_signature_and_examples_are_saved_correctly(
    example, signature, expected_signature, basic_model, model_path
):
    mlflow.sentence_transformers.save_model(
        basic_model,
        path=model_path,
        signature=signature,
        input_example=example,
    )
    mlflow_model = Model.load(model_path)

    assert mlflow_model.signature == expected_signature

    if example is None:
        assert mlflow_model.saved_input_example_info is None
    else:
        if isinstance(example, pd.DataFrame):
            assert mlflow_model.saved_input_example_info["type"] == "dataframe"
            pd.testing.assert_frame_equal(_read_example(mlflow_model, model_path), example)
        else:
            assert mlflow_model.saved_input_example_info["type"] == "json_object"
            np.testing.assert_equal(_read_example(mlflow_model, model_path), example)


def test_model_log_with_signature_inference(basic_model):
    artifact_path = "model"

    with mlflow.start_run():
        model_info = mlflow.sentence_transformers.log_model(
            basic_model, name=artifact_path, input_example=SENTENCES
        )

    loaded_model_info = Model.load(model_info.model_uri)
    assert loaded_model_info.signature == SIGNATURE


def test_verify_task_and_update_metadata():
    # Update embedding task with empty metadata
    metadata = mlflow.sentence_transformers._verify_task_and_update_metadata("llm/v1/embeddings")
    assert metadata == {"task": "llm/v1/embeddings"}
    # Update embedding task with metadata containing task
    metadata = mlflow.sentence_transformers._verify_task_and_update_metadata(
        "llm/v1/embeddings", metadata
    )
    assert metadata == {"task": "llm/v1/embeddings"}

    # Update embedding task with metadata containing different task
    metadata = {"task": "llm/v1/completions"}
    with pytest.raises(
        MlflowException, match=r"Task type is inconsistent with the task value from metadata"
    ):
        mlflow.sentence_transformers._verify_task_and_update_metadata("llm/v1/embeddings", metadata)

    # Invalid task type
    with pytest.raises(MlflowException, match=r"Task type could only be llm/v1/embeddings"):
        mlflow.sentence_transformers._verify_task_and_update_metadata("llm/v1/completions")


def test_model_pyfunc_with_dict_input(basic_model, model_path):
    # The llm/v1/embeddings task switches pyfunc to an OpenAI-style
    # {"input": ...} -> {"data": [...], "usage": {...}} contract.
    mlflow.sentence_transformers.save_model(basic_model, model_path, task="llm/v1/embeddings")
    loaded_pyfunc = pyfunc.load_model(model_uri=model_path)

    sentence = "hello world and hello mlflow"
    sentences = [sentence, "goodbye my friends", "i am a sentence"]
    embedding_dim = basic_model.get_sentence_embedding_dimension()

    single_input = {"input": sentence}
    emb_single_input = loaded_pyfunc.predict(single_input)

    assert isinstance(emb_single_input, dict)
    assert len(emb_single_input["data"]) == 1
    assert isinstance(emb_single_input["data"][0], dict)
    assert emb_single_input["data"][0]["embedding"].shape == (embedding_dim,)
    assert emb_single_input["usage"]["prompt_tokens"] == 8

    multiple_input = {"input": sentences}
    emb_multiple_input = loaded_pyfunc.predict(multiple_input)

    assert isinstance(emb_multiple_input, dict)
    assert len(emb_multiple_input["data"]) == 3
    assert emb_multiple_input["data"][0]["embedding"].shape == (embedding_dim,)
    assert emb_multiple_input["usage"]["prompt_tokens"] == 19