test_transformers_model_export.py
import base64
import gc
import importlib.util
import json
import math
import os
import pathlib
import re
import shutil
import textwrap
from pathlib import Path
from unittest import mock

import huggingface_hub
import librosa
import numpy as np
import pandas as pd
import pytest
import torch
import transformers
import yaml
from datasets import load_dataset
from huggingface_hub import ModelCard
from packaging.version import Version

import mlflow
import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
from mlflow import pyfunc
from mlflow.deployments import PredictionsResponse
from mlflow.exceptions import MlflowException
from mlflow.models import Model, ModelSignature, infer_signature
from mlflow.models.model import METADATA_FILES
from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.transformers import (
    _CARD_DATA_FILE_NAME,
    _CARD_TEXT_FILE_NAME,
    _build_pipeline_from_model_input,
    _fetch_model_card,
    _get_task_for_model,
    _is_model_distributed_in_memory,
    _should_add_pyfunc_to_model,
    _TransformersWrapper,
    _try_import_conversational_pipeline,
    _validate_llm_inference_task_type,
    _write_card_data,
    _write_license_information,
    get_default_conda_env,
    get_default_pip_requirements,
)
from mlflow.types.schema import Array, ColSpec, DataType, ParamSchema, ParamSpec, Schema
from mlflow.utils.environment import _mlflow_conda_env

from tests.helper_functions import (
    _assert_pip_requirements,
    _compare_conda_env_requirements,
    _compare_logged_code_paths,
    _get_deps_from_requirement_file,
    _mlflow_major_version_string,
    assert_register_model_called_with_local_model_path,
    flaky,
    pyfunc_scoring_endpoint,
    pyfunc_serve_and_score_model,
)
from tests.transformers.helper import (
    CHAT_TEMPLATE,
    IS_NEW_FEATURE_EXTRACTION_API,
    IS_TRANSFORMERS_V5_OR_LATER,
)
from tests.transformers.test_transformers_peft_model import SKIP_IF_PEFT_NOT_AVAILABLE

# NB: Some pipelines under test in this suite come very close to or outright exceed the
# default runner container spec of 7GB RAM. Due to the inability to run the suite without
# generating a SIGTERM Error (143), some tests are marked as local only.
# See: https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted- \
# runners#supported-runners-and-hardware-resources for instance specs.
RUNNING_IN_GITHUB_ACTIONS = os.environ.get("GITHUB_ACTIONS") == "true"
GITHUB_ACTIONS_SKIP_REASON = "Test consumes too much memory"

skip_transformers_v5_or_later = pytest.mark.skipif(
    IS_TRANSFORMERS_V5_OR_LATER,
    reason="Incompatible API changes in transformers 5.x",
)
image_url = "https://raw.githubusercontent.com/mlflow/mlflow/master/tests/datasets/cat.png"
image_file_path = pathlib.Path(pathlib.Path(__file__).parent.parent, "datasets", "cat.png")
# Tests that can only be run locally:
# - Summarization pipeline tests
# - TextClassifier pipeline tests
# - Text2TextGeneration pipeline tests
# - Conversational pipeline tests


@pytest.fixture(autouse=True)
def force_gc():
    # This reduces the memory pressure for the usage of the larger pipeline fixtures ~500MB - 1GB
    gc.disable()
    gc.collect()
    gc.set_threshold(0)
    gc.collect()
    gc.enable()


@pytest.fixture
def model_path(tmp_path):
    model_path = tmp_path.joinpath("model")
    yield model_path

    # Pytest keeps the temporary directory created by the `tmp_path` fixture for 3 recent test
    # sessions by default. This is useful for debugging during local testing, but in CI it just
    # wastes disk space.
    if os.environ.get("GITHUB_ACTIONS") == "true":
        shutil.rmtree(model_path, ignore_errors=True)


@pytest.fixture
def transformers_custom_env(tmp_path):
    conda_env = tmp_path.joinpath("conda_env.yml")
    _mlflow_conda_env(conda_env, additional_pip_deps=["transformers"])
    return conda_env


@pytest.fixture
def mock_pyfunc_wrapper():
    return mlflow.transformers._TransformersWrapper("mock")


@pytest.fixture
@flaky()
def image_for_test():
    dataset = load_dataset("hf-internal-testing/dummy_image_text_data")
    return dataset["train"]["image"][3]


@pytest.mark.parametrize(
    ("pipeline", "expected_requirements"),
    [
        ("small_qa_pipeline", {"transformers", "torch", "torchvision"}),
        pytest.param(
            "peft_pipeline",
            {"peft", "transformers", "torch", "torchvision"},
            marks=SKIP_IF_PEFT_NOT_AVAILABLE,
        ),
    ],
)
def test_default_requirements(pipeline, expected_requirements, request):
    if "torch" in expected_requirements and importlib.util.find_spec("accelerate"):
        expected_requirements.add("accelerate")

    model = request.getfixturevalue(pipeline).model
    pip_requirements = get_default_pip_requirements(model)
    conda_requirements = get_default_conda_env(model)["dependencies"][2]["pip"]

    def _strip_requirements(requirements):
        return {req.split("==")[0] for req in requirements}

    assert _strip_requirements(pip_requirements) == expected_requirements
    assert _strip_requirements(conda_requirements) == (expected_requirements | {"mlflow"})


def test_inference_task_validation(small_qa_pipeline):
    with pytest.raises(
        MlflowException, match="The task provided is invalid. 'llm/v1/invalid' is not"
    ):
        _validate_llm_inference_task_type("llm/v1/invalid", "text-generation")
    with pytest.raises(
        MlflowException, match="The task provided is invalid. 'llm/v1/completions' is not"
    ):
        _validate_llm_inference_task_type("llm/v1/completions", small_qa_pipeline)
    _validate_llm_inference_task_type("llm/v1/completions", "text-generation")


@pytest.mark.parametrize(
    ("model", "result"),
    [
        ("small_qa_pipeline", True),
        ("small_multi_modal_pipeline", False),
        ("small_vision_model", True),
    ],
)
def test_pipeline_eligibility_for_pyfunc_registration(model, result, request):
    pipeline = request.getfixturevalue(model)
    assert _should_add_pyfunc_to_model(pipeline) == result


def test_component_multi_modal_model_ineligible_for_pyfunc(component_multi_modal):
    task = transformers.pipelines.get_task(component_multi_modal["model"].name_or_path)
    pipeline = _build_pipeline_from_model_input(component_multi_modal, task)
    assert not _should_add_pyfunc_to_model(pipeline)


def test_pipeline_construction_from_base_nlp_model(small_qa_pipeline):
    generated = _build_pipeline_from_model_input(
        {"model": small_qa_pipeline.model, "tokenizer": small_qa_pipeline.tokenizer},
        "question-answering",
    )
    assert isinstance(generated, type(small_qa_pipeline))
    assert isinstance(generated.tokenizer, type(small_qa_pipeline.tokenizer))


def test_pipeline_construction_from_base_vision_model(small_vision_model):
    model = {"model": small_vision_model.model, "tokenizer": small_vision_model.tokenizer}
    if IS_NEW_FEATURE_EXTRACTION_API:
        model.update({"image_processor": small_vision_model.image_processor})
    else:
        model.update({"feature_extractor": small_vision_model.feature_extractor})
    generated = _build_pipeline_from_model_input(model, task="image-classification")
    assert isinstance(generated, type(small_vision_model))
    assert isinstance(generated.tokenizer, type(small_vision_model.tokenizer))
    if IS_NEW_FEATURE_EXTRACTION_API:
        assert isinstance(generated.image_processor, type(small_vision_model.image_processor))
    else:
        assert isinstance(generated.feature_extractor, transformers.MobileNetV2ImageProcessor)


def test_saving_with_invalid_dict_as_model(model_path):
    with pytest.raises(
        MlflowException, match="Invalid dictionary submitted for 'transformers_model'. The "
    ):
        mlflow.transformers.save_model(transformers_model={"invalid": "key"}, path=model_path)

    with pytest.raises(
        MlflowException, match="The 'transformers_model' dictionary must have an entry"
    ):
        mlflow.transformers.save_model(
            transformers_model={"tokenizer": "some_tokenizer"}, path=model_path
        )


def test_model_card_acquisition_vision_model(small_vision_model):
    model_provided_card = _fetch_model_card(small_vision_model.model.name_or_path)
    assert model_provided_card.data.to_dict()["tags"] == ["vision", "image-classification"]
    assert len(model_provided_card.text) > 0


@pytest.mark.parametrize(
    ("repo_id", "license_file"),
    [
        ("google/mobilenet_v2_1.0_224", "LICENSE.txt"),  # no license declared
        ("csarron/mobilebert-uncased-squad-v2", "LICENSE.txt"),  # mit license
        ("codellama/CodeLlama-34b-hf", "LICENSE"),  # custom license
        ("openai/whisper-tiny", "LICENSE.txt"),  # apache license
        ("stabilityai/stable-code-3b", "LICENSE"),  # custom
        ("mistralai/Mixtral-8x7B-Instruct-v0.1", "LICENSE.txt"),  # apache
    ],
)
def test_license_acquisition(repo_id, license_file, tmp_path):
    card_data = _fetch_model_card(repo_id)
    _write_license_information(repo_id, card_data, tmp_path)
    license_file = list(tmp_path.glob("*LICENSE*"))
    assert len(license_file) == 1
    assert tmp_path.joinpath(license_file[0]).stat().st_size > 0


def test_license_fallback(tmp_path):
    _write_license_information("not a real repo", None, tmp_path)
    assert tmp_path.joinpath("LICENSE.txt").stat().st_size > 0


def test_vision_model_save_pipeline_with_defaults(small_vision_model, model_path):
    mlflow.transformers.save_model(transformers_model=small_vision_model, path=model_path)
    # validate inferred pip requirements
    requirements = model_path.joinpath("requirements.txt").read_text()
    reqs = {req.split("==")[0] for req in requirements.split("\n")}
    expected_requirements = {"torch", "torchvision", "transformers"}
    assert reqs.intersection(expected_requirements) == expected_requirements
    # validate inferred model card data
    card_data = yaml.safe_load(model_path.joinpath("model_card_data.yaml").read_bytes())
    assert card_data["tags"] == ["vision", "image-classification"]
    # verify the license file has been written
    license_file = model_path.joinpath("LICENSE.txt").read_text()
    assert len(license_file) > 0
    # Validate inferred model card text
    with model_path.joinpath("model_card.md").open() as file:
        card_text = file.read()
    assert len(card_text) > 0
    # Validate conda.yaml
    conda_env = yaml.safe_load(model_path.joinpath("conda.yaml").read_bytes())
    assert {req.split("==")[0] for req in conda_env["dependencies"][2]["pip"]}.intersection(
        expected_requirements
    ) == expected_requirements
    # Validate the MLModel file
    mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes())
    flavor_config = mlmodel["flavors"]["transformers"]
    assert flavor_config["instance_type"] == "ImageClassificationPipeline"
    assert flavor_config["pipeline_model_type"] == "MobileNetV2ForImageClassification"
    assert flavor_config["task"] == "image-classification"
    assert flavor_config["source_model_name"] == "google/mobilenet_v2_1.0_224"


def test_vision_model_save_model_for_task_and_card_inference(small_vision_model, model_path):
    mlflow.transformers.save_model(transformers_model=small_vision_model, path=model_path)
    # validate inferred pip requirements
    requirements = model_path.joinpath("requirements.txt").read_text()
    reqs = {req.split("==")[0] for req in requirements.split("\n")}
    expected_requirements = {"torch", "torchvision", "transformers"}
    assert reqs.intersection(expected_requirements) == expected_requirements
    # validate inferred model card data
    card_data = yaml.safe_load(model_path.joinpath("model_card_data.yaml").read_bytes())
    assert card_data["tags"] == ["vision", "image-classification"]
    # Validate inferred model card text
    card_text = model_path.joinpath("model_card.md").read_text(encoding="utf-8")
    assert len(card_text) > 0
    # verify the license file has been written
    license_file = model_path.joinpath("LICENSE.txt").read_text()
    assert len(license_file) > 0
    # Validate the MLModel file
    mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes())
    flavor_config = mlmodel["flavors"]["transformers"]
    assert flavor_config["instance_type"] == "ImageClassificationPipeline"
    assert flavor_config["pipeline_model_type"] == "MobileNetV2ForImageClassification"
    assert flavor_config["task"] == "image-classification"
    assert flavor_config["source_model_name"] == "google/mobilenet_v2_1.0_224"


def test_qa_model_save_model_for_task_and_card_inference(small_qa_pipeline, model_path):
    mlflow.transformers.save_model(
        transformers_model={
            "model": small_qa_pipeline.model,
            "tokenizer": small_qa_pipeline.tokenizer,
        },
        path=model_path,
    )
    # validate inferred pip requirements
    with model_path.joinpath("requirements.txt").open() as file:
        requirements = file.read()
    reqs = {req.split("==")[0] for req in requirements.split("\n")}
    expected_requirements = {"torch", "transformers"}
    assert reqs.intersection(expected_requirements) == expected_requirements
    # validate that the card was acquired by model reference
    card_data = yaml.safe_load(model_path.joinpath("model_card_data.yaml").read_bytes())
    assert card_data["datasets"] == ["squad_v2"]
    assert "tags" in card_data
    # verify the license file has been written
    license_file = model_path.joinpath("LICENSE.txt").read_text()
    assert len(license_file) > 0
    # Validate inferred model card text
    with model_path.joinpath("model_card.md").open() as file:
        card_text = file.read()
    assert len(card_text) > 0
    # validate MLmodel files
    mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes())
    flavor_config = mlmodel["flavors"]["transformers"]
    assert flavor_config["instance_type"] == "QuestionAnsweringPipeline"
    assert flavor_config["pipeline_model_type"] == "MobileBertForQuestionAnswering"
    assert flavor_config["task"] == "question-answering"
    assert flavor_config["source_model_name"] == "csarron/mobilebert-uncased-squad-v2"


def test_qa_model_save_and_override_card(small_qa_pipeline, model_path):
    supplied_card = """
        ---
        language: en
        license: bsd
        ---

        # I made a new model!
        """
    card_info = textwrap.dedent(supplied_card)
    card = ModelCard(card_info)
    # save the model instance
    mlflow.transformers.save_model(
        transformers_model=small_qa_pipeline,
        path=model_path,
        model_card=card,
    )
    # validate that the card was acquired by model reference
    card_data = yaml.safe_load(model_path.joinpath("model_card_data.yaml").read_bytes())
    assert card_data["language"] == "en"
    assert card_data["license"] == "bsd"
    # Validate inferred model card text
    with model_path.joinpath("model_card.md").open() as file:
        card_text = file.read()
    # verify the license file has been written
    license_file = model_path.joinpath("LICENSE.txt").read_text()
    assert len(license_file) > 0
    assert card_text.startswith("\n# I made a new model!")
    # validate MLmodel files
    mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes())
    flavor_config = mlmodel["flavors"]["transformers"]
    assert flavor_config["instance_type"] == "QuestionAnsweringPipeline"
    assert flavor_config["pipeline_model_type"] == "MobileBertForQuestionAnswering"
    assert flavor_config["task"] == "question-answering"
    assert flavor_config["source_model_name"] == "csarron/mobilebert-uncased-squad-v2"


def test_basic_save_model_and_load_text_pipeline(text_classification_pipeline, model_path):
    mlflow.transformers.save_model(
        transformers_model={
            "model": text_classification_pipeline.model,
            "tokenizer": text_classification_pipeline.tokenizer,
        },
        path=model_path,
    )
    loaded = mlflow.transformers.load_model(model_path)
    result = loaded("MLflow is a really neat tool!")
    assert result[0]["label"] == "POSITIVE"
    assert result[0]["score"] > 0.5


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float64])
def test_basic_save_model_with_torch_dtype(text2text_generation_pipeline, model_path, dtype):
    mlflow.transformers.save_model(
        transformers_model=text2text_generation_pipeline,
        path=model_path,
        torch_dtype=dtype,
    )

    loaded = mlflow.transformers.load_model(model_path)
    assert loaded.model.dtype == dtype

    loaded = mlflow.transformers.load_model(model_path, torch_dtype=torch.float32)
    assert loaded.model.dtype == torch.float32


def test_basic_save_model_and_load_vision_pipeline(small_vision_model, model_path, image_for_test):
    if IS_NEW_FEATURE_EXTRACTION_API:
        model = {
            "model": small_vision_model.model,
            "image_processor": small_vision_model.image_processor,
            "tokenizer": small_vision_model.tokenizer,
        }
    else:
        model = {
            "model": small_vision_model.model,
            "feature_extractor": small_vision_model.feature_extractor,
            "tokenizer": small_vision_model.tokenizer,
        }
    mlflow.transformers.save_model(
        transformers_model=model,
        path=model_path,
    )
    loaded = mlflow.transformers.load_model(model_path)
    prediction = loaded(image_for_test)
    assert prediction[0]["label"] == "wall clock"
    assert prediction[0]["score"] > 0.5


@flaky()
def test_multi_modal_pipeline_save_and_load(small_multi_modal_pipeline, model_path, image_for_test):
    mlflow.transformers.save_model(transformers_model=small_multi_modal_pipeline, path=model_path)
    question = "How many wall clocks are in the picture?"
    # Load components
    components = mlflow.transformers.load_model(model_path, return_type="components")
    if IS_NEW_FEATURE_EXTRACTION_API:
        expected_components = {"model", "task", "tokenizer", "image_processor"}
    else:
        expected_components = {"model", "task", "tokenizer", "feature_extractor"}
    assert set(components.keys()).intersection(expected_components) == expected_components
    constructed_pipeline = transformers.pipeline(**components)
    answer = constructed_pipeline(image=image_for_test, question=question)
    assert answer[0]["answer"] == "1"
    # Load pipeline
    pipeline = mlflow.transformers.load_model(model_path)
    pipeline_answer = pipeline(image=image_for_test, question=question)
    assert pipeline_answer[0]["answer"] == "1"
    # Test invalid loading mode
    with pytest.raises(MlflowException, match="The specified return_type mode 'magic' is"):
        mlflow.transformers.load_model(model_path, return_type="magic")


def test_multi_modal_component_save_and_load(component_multi_modal, model_path, image_for_test):
    if IS_NEW_FEATURE_EXTRACTION_API:
        processor = component_multi_modal["image_processor"]
    else:
        processor = component_multi_modal["feature_extractor"]
    mlflow.transformers.save_model(
        transformers_model=component_multi_modal,
        path=model_path,
        processor=processor,
    )
    # Ensure that the appropriate Processor object was detected and loaded with the pipeline.
    loaded_components = mlflow.transformers.load_model(
        model_uri=model_path, return_type="components"
    )
    assert isinstance(loaded_components["model"], transformers.ViltForQuestionAnswering)
    assert isinstance(loaded_components["tokenizer"], transformers.BertTokenizerFast)
    # This is to simulate a post-processing processor that would be used externally to a Pipeline.
    # This isn't being tested on an actual use case of such a model type due to the size of
    # these types of models that have this interface being ill-suited for CI testing.

    if IS_NEW_FEATURE_EXTRACTION_API:
        processor_key = "image_processor"
        assert isinstance(loaded_components[processor_key], transformers.ViltImageProcessor)
    else:
        processor_key = "feature_extractor"
        assert isinstance(loaded_components[processor_key], transformers.ViltProcessor)
    assert isinstance(loaded_components["processor"], transformers.ViltProcessor)
    if not IS_NEW_FEATURE_EXTRACTION_API:
        # NB: This simulated behavior is no longer valid in versions 4.27.4 and above.
        # With the port of functionality away from feature extractor types, the new architecture
        # for multi-modal models is entirely pipeline based.
        # Make sure that the component usage works correctly when extracted from inference loading
        model = loaded_components["model"]
        processor = loaded_components["processor"]
        question = "What are the cats doing?"
        inputs = processor(image_for_test, question, return_tensors="pt")
        outputs = model(**inputs)
        logits = outputs.logits
        idx = logits.argmax(-1).item()
        answer = model.config.id2label[idx]
        assert answer == "sleeping"


@flaky()
def test_pipeline_saved_model_with_processor_cannot_be_loaded_as_pipeline(
    component_multi_modal, model_path
):
    invalid_pipeline = transformers.pipeline(
        task="visual-question-answering", **component_multi_modal
    )
    if IS_NEW_FEATURE_EXTRACTION_API:
        processor = component_multi_modal["image_processor"]
    else:
        processor = component_multi_modal["feature_extractor"]
    mlflow.transformers.save_model(
        transformers_model=invalid_pipeline,
        path=model_path,
        processor=processor,  # If this is specified, we cannot guarantee correct inference
    )
    with pytest.raises(
        MlflowException, match="This model has been saved with a processor. Processor objects"
    ):
        mlflow.transformers.load_model(model_uri=model_path, return_type="pipeline")


def test_component_saved_model_with_processor_cannot_be_loaded_as_pipeline(
    component_multi_modal, model_path
):
    if IS_NEW_FEATURE_EXTRACTION_API:
        processor = component_multi_modal["image_processor"]
    else:
        processor = component_multi_modal["feature_extractor"]
    mlflow.transformers.save_model(
        transformers_model=component_multi_modal,
        path=model_path,
        processor=processor,
    )
    with pytest.raises(
        MlflowException,
        match="This model has been saved with a processor. Processor objects are not compatible "
        "with Pipelines. Please load",
    ):
        mlflow.transformers.load_model(model_uri=model_path, return_type="pipeline")


@pytest.mark.parametrize("should_start_run", [True, False])
def test_log_and_load_transformers_pipeline(small_qa_pipeline, tmp_path, should_start_run):
    try:
        if should_start_run:
            mlflow.start_run()
        artifact_path = "transformers"
        conda_env = tmp_path.joinpath("conda_env.yaml")
        _mlflow_conda_env(conda_env, additional_pip_deps=["transformers"])
        model_info = mlflow.transformers.log_model(
            small_qa_pipeline,
            name=artifact_path,
            conda_env=str(conda_env),
        )
        reloaded_model = mlflow.transformers.load_model(
            model_uri=model_info.model_uri, return_type="pipeline"
        )
        assert (
            reloaded_model(
                question="Who's house?", context="The house is owned by a man named Run."
            )["answer"]
            == "Run"
        )
        model_path = pathlib.Path(_download_artifact_from_uri(artifact_uri=model_info.model_uri))
        model_config = Model.load(str(model_path.joinpath("MLmodel")))
        assert pyfunc.FLAVOR_NAME in model_config.flavors
        assert pyfunc.ENV in model_config.flavors[pyfunc.FLAVOR_NAME]
        env_path = model_config.flavors[pyfunc.FLAVOR_NAME][pyfunc.ENV]["conda"]
        assert model_path.joinpath(env_path).exists()
    finally:
        mlflow.end_run()


def test_load_pipeline_from_remote_uri_succeeds(
    text_classification_pipeline, model_path, mock_s3_bucket
):
    mlflow.transformers.save_model(transformers_model=text_classification_pipeline, path=model_path)
    artifact_root = f"s3://{mock_s3_bucket}"
    artifact_path = "model"
    artifact_repo = S3ArtifactRepository(artifact_root)
    artifact_repo.log_artifacts(model_path, artifact_path=artifact_path)
    model_uri = os.path.join(artifact_root, artifact_path)
    loaded = mlflow.transformers.load_model(model_uri=str(model_uri), return_type="pipeline")
    assert loaded("I like it when CI checks pass and are never flaky!")[0]["label"] == "POSITIVE"


def test_transformers_log_model_calls_register_model(small_qa_pipeline, tmp_path):
    artifact_path = "transformers"
    register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
    with mlflow.start_run(), register_model_patch:
        conda_env = tmp_path.joinpath("conda_env.yaml")
        _mlflow_conda_env(conda_env, additional_pip_deps=["transformers", "torch", "torchvision"])
        model_info = mlflow.transformers.log_model(
            small_qa_pipeline,
            name=artifact_path,
            conda_env=str(conda_env),
            registered_model_name="Question-Answering Model 1",
        )
        assert_register_model_called_with_local_model_path(
            register_model_mock=mlflow.tracking._model_registry.fluent._register_model,
            model_uri=model_info.model_uri,
            registered_model_name="Question-Answering Model 1",
        )


def test_transformers_log_model_with_no_registered_model_name(small_vision_model, tmp_path):
    if IS_NEW_FEATURE_EXTRACTION_API:
        model = {
            "model": small_vision_model.model,
            "image_processor": small_vision_model.image_processor,
            "tokenizer": small_vision_model.tokenizer,
        }
    else:
        model = {
            "model": small_vision_model.model,
            "feature_extractor": small_vision_model.feature_extractor,
            "tokenizer": small_vision_model.tokenizer,
        }

    artifact_path = "transformers"
    registered_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
    with mlflow.start_run(), registered_model_patch:
        conda_env = tmp_path.joinpath("conda_env.yaml")
        _mlflow_conda_env(conda_env, additional_pip_deps=["tensorflow", "transformers"])
        mlflow.transformers.log_model(
            model,
            name=artifact_path,
            conda_env=str(conda_env),
        )
        mlflow.tracking._model_registry.fluent._register_model.assert_not_called()


def test_transformers_log_model_with_prompt_template_sets_return_full_text_false(
    text_generation_pipeline,
):
    artifact_path = "text_generation_with_prompt_template"
    prompt_template = "User: {prompt}"

    with mlflow.start_run():
        model_info = mlflow.transformers.log_model(
            text_generation_pipeline,
            name=artifact_path,
            prompt_template=prompt_template,
        )

    model_path = pathlib.Path(_download_artifact_from_uri(model_info.model_uri))
    mlmodel = Model.load(str(model_path.joinpath("MLmodel")))

    pyfunc_flavor = mlmodel.flavors["python_function"]
    config = pyfunc_flavor.get("config")

    assert config.get("return_full_text") is False


def test_transformers_save_persists_requirements_in_mlflow_directory(
    small_qa_pipeline, model_path, transformers_custom_env
):
    mlflow.transformers.save_model(
        transformers_model=small_qa_pipeline,
        path=model_path,
        conda_env=str(transformers_custom_env),
    )
    saved_pip_req_path = model_path.joinpath("requirements.txt")
    _compare_conda_env_requirements(transformers_custom_env, saved_pip_req_path)


def test_transformers_log_with_pip_requirements(small_multi_modal_pipeline, tmp_path):
    expected_mlflow_version = _mlflow_major_version_string()

    requirements_file = tmp_path.joinpath("requirements.txt")
    requirements_file.write_text("coolpackage")
    with mlflow.start_run():
        model_info = mlflow.transformers.log_model(
            small_multi_modal_pipeline, name="model", pip_requirements=str(requirements_file)
        )
        _assert_pip_requirements(
            model_info.model_uri, [expected_mlflow_version, "coolpackage"], strict=True
        )
    with mlflow.start_run():
        model_info = mlflow.transformers.log_model(
            small_multi_modal_pipeline,
            name="model",
            pip_requirements=[f"-r {requirements_file}", "alsocool"],
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, "coolpackage", "alsocool"],
            strict=True,
        )
    with mlflow.start_run():
        model_info = mlflow.transformers.log_model(
            small_multi_modal_pipeline,
            name="model",
            pip_requirements=[f"-c {requirements_file}", "constrainedcool"],
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, "constrainedcool", "-c constraints.txt"],
            ["coolpackage"],
            strict=True,
        )


def test_transformers_log_with_extra_pip_requirements(small_multi_modal_pipeline, tmp_path):
    expected_mlflow_version = _mlflow_major_version_string()
    default_requirements = mlflow.transformers.get_default_pip_requirements(
        small_multi_modal_pipeline.model
    )
    requirements_file = tmp_path.joinpath("requirements.txt")
    requirements_file.write_text("coolpackage")
    with mlflow.start_run():
        model_info = mlflow.transformers.log_model(
            small_multi_modal_pipeline, name="model", extra_pip_requirements=str(requirements_file)
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, *default_requirements, "coolpackage"],
            strict=True,
        )
    with mlflow.start_run():
        model_info = mlflow.transformers.log_model(
            small_multi_modal_pipeline,
            name="model",
            extra_pip_requirements=[f"-r {requirements_file}", "alsocool"],
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, *default_requirements, "coolpackage", "alsocool"],
            strict=True,
        )
    with mlflow.start_run():
        model_info = mlflow.transformers.log_model(
            small_multi_modal_pipeline,
            name="model",
            extra_pip_requirements=[f"-c {requirements_file}", "constrainedcool"],
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [
                expected_mlflow_version,
                *default_requirements,
                "constrainedcool",
                "-c constraints.txt",
            ],
            ["coolpackage"],
            strict=True,
        )


def test_transformers_log_with_duplicate_extra_pip_requirements(small_multi_modal_pipeline):
    with pytest.raises(
        MlflowException, match="The specified requirements versions are incompatible"
    ):
        with mlflow.start_run():
            mlflow.transformers.log_model(
                small_multi_modal_pipeline,
                name="model",
                extra_pip_requirements=["transformers==1.1.0"],
            )


def test_transformers_pt_model_save_without_conda_env_uses_default_env_with_expected_dependencies(
    small_qa_pipeline, model_path
):
    mlflow.transformers.save_model(small_qa_pipeline, model_path)
    _assert_pip_requirements(
        model_path, mlflow.transformers.get_default_pip_requirements(small_qa_pipeline.model)
    )
    pip_requirements = _get_deps_from_requirement_file(model_path)
    assert "tensorflow" not in pip_requirements
    assert "accelerate" in pip_requirements
    assert "torch" in pip_requirements


@pytest.mark.skipif(
    importlib.util.find_spec("accelerate") is not None, reason="fails when accelerate is installed"
)
def test_transformers_pt_model_save_dependencies_without_accelerate(
    text_generation_pipeline, model_path
):
    mlflow.transformers.save_model(text_generation_pipeline, model_path)
    _assert_pip_requirements(
        model_path, mlflow.transformers.get_default_pip_requirements(text_generation_pipeline.model)
    )
    pip_requirements = _get_deps_from_requirement_file(model_path)
    assert "tensorflow" not in pip_requirements
    assert "accelerate" not in pip_requirements
    assert "torch" in pip_requirements


def test_transformers_pt_model_log_without_conda_env_uses_default_env_with_expected_dependencies(
    small_qa_pipeline,
):
    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.transformers.log_model(small_qa_pipeline, name=artifact_path)
    _assert_pip_requirements(
        model_info.model_uri,
        mlflow.transformers.get_default_pip_requirements(small_qa_pipeline.model),
    )
    pip_requirements = _get_deps_from_requirement_file(model_info.model_uri)
    assert "tensorflow" not in pip_requirements
    assert "torch" in pip_requirements


def test_log_model_with_code_paths(small_qa_pipeline):
    artifact_path = "model"
    with (
        mlflow.start_run(),
        mock.patch("mlflow.transformers._add_code_from_conf_to_system_path") as add_mock,
    ):
        model_info = mlflow.transformers.log_model(
            small_qa_pipeline, name=artifact_path, code_paths=[__file__]
        )
        model_uri = model_info.model_uri
        _compare_logged_code_paths(__file__, model_uri, mlflow.transformers.FLAVOR_NAME)
        mlflow.transformers.load_model(model_uri)
        add_mock.assert_called()


def test_non_existent_model_card_entry(small_qa_pipeline, model_path):
    with mock.patch("mlflow.transformers._fetch_model_card", return_value=None):
        mlflow.transformers.save_model(transformers_model=small_qa_pipeline, path=model_path)

    contents = {item.name for item in model_path.iterdir()}
    assert not contents.intersection({"model_card.txt", "model_card_data.yaml"})


def test_huggingface_hub_not_installed(small_qa_pipeline, model_path):
    with mock.patch.dict("sys.modules", {"huggingface_hub": None}):
        result = mlflow.transformers._fetch_model_card(small_qa_pipeline.model.name_or_path)

        assert result is None

        mlflow.transformers.save_model(transformers_model=small_qa_pipeline, path=model_path)

        contents = {item.name for item in model_path.iterdir()}
        assert not contents.intersection({"model_card.txt", "model_card_data.yaml"})

        license_data = model_path.joinpath("LICENSE.txt").read_text()
        assert license_data.rstrip().endswith("mobilebert-uncased-squad-v2")


@pytest.mark.skipif(
    _try_import_conversational_pipeline() is None,
    reason="Conversation model is deprecated and removed.",
)
def test_save_pipeline_without_defined_components(small_conversational_model, model_path):
    # This pipeline type explicitly does not have a configuration for an image_processor
    with mlflow.start_run():
        mlflow.transformers.save_model(
            transformers_model=small_conversational_model, path=model_path
        )
    pipe = mlflow.transformers.load_model(model_path)
    convo = transformers.Conversation("How are you today?")
    convo = pipe(convo)
    assert convo.generated_responses[-1] == "good"


@flaky()
def test_invalid_model_type_without_registered_name_does_not_save(model_path):
    invalid_pipeline = transformers.pipeline(task="text-generation", model="gpt2")
    del invalid_pipeline.model.name_or_path

    with pytest.raises(MlflowException, match="The submitted model type"):
        mlflow.transformers.save_model(transformers_model=invalid_pipeline, path=model_path)


def test_invalid_input_to_pyfunc_signature_output_wrapper_raises(component_multi_modal):
    with pytest.raises(MlflowException, match="The pipeline type submitted is not a valid"):
        mlflow.transformers.generate_signature_output(component_multi_modal["model"], "bogus")


@pytest.mark.parametrize(
    "inference_payload",
    [
        ({"question": "Who's house?", "context": "The house is owned by a man named Run."}),
        ([
            {
                "question": "What color is it?",
                "context": "Some people said it was green but I know that it's definitely blue",
            },
            {
                "question": "How do the wheels go?",
                "context": "The wheels on the bus go round and round. Round and round.",
            },
        ]),
        ([
            {
                "question": "What color is it?",
                "context": "Some people said it was green but I know that it's pink.",
            },
            {
                "context": "The people on the bus go up and down. Up and down.",
                "question": "How do the people go?",
            },
        ]),
    ],
)
def test_qa_pipeline_pyfunc_load_and_infer(small_qa_pipeline, model_path, inference_payload):
    signature = infer_signature(
        inference_payload,
        mlflow.transformers.generate_signature_output(small_qa_pipeline, inference_payload),
    )

    mlflow.transformers.save_model(
        transformers_model=small_qa_pipeline,
        path=model_path,
        signature=signature,
    )
    pyfunc_loaded = mlflow.pyfunc.load_model(model_path)

    inference = pyfunc_loaded.predict(inference_payload)

    assert isinstance(inference, list)
    assert all(isinstance(element, str) for element in inference)

    pd_input = (
        pd.DataFrame([inference_payload])
        if isinstance(inference_payload, dict)
        else pd.DataFrame(inference_payload)
    )
    pd_inference = pyfunc_loaded.predict(pd_input)

    assert isinstance(pd_inference, list)
    assert all(isinstance(element, str) for element in pd_inference)


@pytest.mark.parametrize(
    "inference_payload",
    [
        image_url,
        str(image_file_path),
        "base64",
        pytest.param(
            "base64_encodebytes",
            marks=pytest.mark.skipif(
                Version(transformers.__version__) < Version("4.41"),
                reason="base64 encodebytes feature not present",
            ),
        ),
    ],
)
def test_vision_pipeline_pyfunc_load_and_infer(small_vision_model, model_path, inference_payload):
    if inference_payload == "base64":
        inference_payload = base64.b64encode(image_file_path.read_bytes()).decode("utf-8")
    elif inference_payload == "base64_encodebytes":
        inference_payload = base64.encodebytes(image_file_path.read_bytes()).decode("utf-8")
    signature = infer_signature(
        inference_payload,
        mlflow.transformers.generate_signature_output(small_vision_model, inference_payload),
    )
    mlflow.transformers.save_model(
        transformers_model=small_vision_model,
        path=model_path,
        signature=signature,
    )
    pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
    predictions = pyfunc_loaded.predict(inference_payload)

    transformers_loaded_model = mlflow.transformers.load_model(model_path)
    expected_predictions = transformers_loaded_model.predict(inference_payload)
    assert list(predictions.to_dict("records")[0].values()) == expected_predictions


@pytest.mark.parametrize(
    ("data", "result"),
    [
        ("muppet keyboard type", ["A man is typing a muppet on a keyboard."]),
        (
            ["pencil draw paper", "pie apple eat"],
            # NB: The result of this test case, without inference config overrides is:
            # ["A man drawing on paper with pencil", "A man eating a pie with applies"]
            # The inference config override forces additional insertion of more grammatically
            # correct responses to validate that the inference config is being applied.
            ["A man draws a pencil on a paper.", "A man eats a pie of apples."],
        ),
    ],
)
def test_text2text_generation_pipeline_with_model_configs(
    text2text_generation_pipeline, tmp_path, data, result
):
    signature = infer_signature(
        data, mlflow.transformers.generate_signature_output(text2text_generation_pipeline, data)
    )

    model_config = {
        "top_k": 2,
        "num_beams": 5,
        "max_length": 30,
        "temperature": 0.62,
        "top_p": 0.85,
        "repetition_penalty": 1.15,
    }
    model_path1 = tmp_path.joinpath("model1")
    mlflow.transformers.save_model(
        text2text_generation_pipeline,
        path=model_path1,
        model_config=model_config,
        signature=signature,
    )
    pyfunc_loaded = mlflow.pyfunc.load_model(model_path1)

    inference = pyfunc_loaded.predict(data)

    assert inference == result

    pd_input = pd.DataFrame([data]) if isinstance(data, str) else pd.DataFrame(data)
    pd_inference = pyfunc_loaded.predict(pd_input)
    assert pd_inference == result

    model_path2 = tmp_path.joinpath("model2")
    signature_with_params = infer_signature(
        data,
        mlflow.transformers.generate_signature_output(text2text_generation_pipeline, data),
        model_config,
    )
    mlflow.transformers.save_model(
        text2text_generation_pipeline,
        path=model_path2,
        signature=signature_with_params,
    )
    pyfunc_loaded = mlflow.pyfunc.load_model(model_path2)

    dict_inference = pyfunc_loaded.predict(
        data,
        params=model_config,
    )

    assert dict_inference == inference


def test_text2text_generation_pipeline_with_model_config_and_params(
    text2text_generation_pipeline, model_path
):
    data = "muppet keyboard type"
    model_config = {
        "top_k": 2,
        "num_beams": 5,
        "top_p": 0.85,
        "repetition_penalty": 1.15,
        "do_sample": True,
    }
    parameters = {"top_k": 3, "max_new_tokens": 30}
    generated_output = mlflow.transformers.generate_signature_output(
        text2text_generation_pipeline, data
    )
    signature = infer_signature(
        data,
        generated_output,
        parameters,
    )

    mlflow.transformers.save_model(
        text2text_generation_pipeline,
        path=model_path,
        model_config=model_config,
        signature=signature,
    )
    pyfunc_loaded = mlflow.pyfunc.load_model(model_path)

    # model_config and default params are all applied
    res = pyfunc_loaded.predict(data)
    applied_params = model_config.copy()
    applied_params.update(parameters)
    res2 = pyfunc_loaded.predict(data, applied_params)
    assert res == res2

    assert res != pyfunc_loaded.predict(data, {"max_new_tokens": 3})

    # Extra params are ignored
    assert res == pyfunc_loaded.predict(data, {"extra_param": "extra_value"})


def test_text2text_generation_pipeline_with_params_success(
    text2text_generation_pipeline, model_path
):
    data = "muppet keyboard type"
    parameters = {"top_k": 2, "num_beams": 5, "do_sample": True}
    generated_output = mlflow.transformers.generate_signature_output(
        text2text_generation_pipeline, data
    )
    signature = infer_signature(
        data,
        generated_output,
        parameters,
    )

    mlflow.transformers.save_model(
        text2text_generation_pipeline,
        path=model_path,
        signature=signature,
    )
    pyfunc_loaded = mlflow.pyfunc.load_model(model_path)

    # parameters saved with the ModelSignature are applied by default
    res = pyfunc_loaded.predict(data)
    res2 = pyfunc_loaded.predict(data, parameters)
    assert res == res2


def test_text2text_generation_pipeline_with_params_with_errors(
    text2text_generation_pipeline, model_path
):
    data = "muppet keyboard type"
    parameters = {"top_k": 2, "num_beams": 5, "invalid_param": "invalid_param", "do_sample": True}
    generated_output = mlflow.transformers.generate_signature_output(
        text2text_generation_pipeline, data
    )

    mlflow.transformers.save_model(
        text2text_generation_pipeline,
        path=model_path,
        signature=infer_signature(
            data,
            generated_output,
            parameters,
        ),
    )
    pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
    with pytest.raises(
        MlflowException,
        match=r"The params provided to the `predict` method are "
        r"not valid for pipeline Text2TextGenerationPipeline.",
    ):
        pyfunc_loaded.predict(data, parameters)

    # Type validation of params failure
    with pytest.raises(MlflowException, match=r"Invalid parameters found"):
        pyfunc_loaded.predict(data, {"top_k": "2"})


def test_text2text_generation_pipeline_with_inferred_schema(text2text_generation_pipeline):
    with mlflow.start_run():
        model_info = mlflow.transformers.log_model(text2text_generation_pipeline, name="my_model")
    pyfunc_loaded = mlflow.pyfunc.load_model(model_info.model_uri)

    assert pyfunc_loaded.predict("muppet board nails hammer")[0].startswith("A hammer")


@pytest.mark.parametrize(
    "invalid_data",
    [
        ({"answer": "something", "context": ["nothing", "that", "makes", "sense"]}),
        ([{"answer": ["42"], "context": "life"}, {"unmatched": "keys", "cause": "failure"}]),
    ],
)
def test_invalid_input_to_text2text_pipeline(text2text_generation_pipeline, invalid_data):
    # Adding this validation test due to the fact that we're constructing the input to the
    # Pipeline. The Pipeline requires a format of a pseudo-dict-like string. An example of
    # a valid input string: "answer: green. context: grass is primarily green in color."
    # We generate this string from a dict or generate a list of these strings from a list of
    # dictionaries.
    with pytest.raises(
        MlflowException, match=r"An invalid type has been supplied: .+\. Please supply"
    ):
        infer_signature(
            invalid_data,
            mlflow.transformers.generate_signature_output(
                text2text_generation_pipeline, invalid_data
            ),
        )


@pytest.mark.parametrize(
    "data", ["Generative models are", (["Generative models are", "Computers are"])]
)
def test_text_generation_pipeline(text_generation_pipeline, model_path, data):
    signature = infer_signature(
        data, mlflow.transformers.generate_signature_output(text_generation_pipeline, data)
    )

    model_config = {
        "prefix": "software",
        "top_k": 2,
        "num_beams": 5,
        "max_length": 30,
        "temperature": 0.62,
        "top_p": 0.85,
        "repetition_penalty": 1.15,
    }
    mlflow.transformers.save_model(
        text_generation_pipeline,
        path=model_path,
        model_config=model_config,
        signature=signature,
    )
    pyfunc_loaded = mlflow.pyfunc.load_model(model_path)

    inference = pyfunc_loaded.predict(data)

    if isinstance(data, list):
        assert inference[0].startswith(data[0])
        assert inference[1].startswith(data[1])
    else:
        assert inference[0].startswith(data)

    pd_input = pd.DataFrame([data], index=[0]) if isinstance(data, str) else pd.DataFrame(data)
    pd_inference = pyfunc_loaded.predict(pd_input)

    if isinstance(data, list):
        assert pd_inference[0].startswith(data[0])
        assert pd_inference[1].startswith(data[1])
    else:
        assert pd_inference[0].startswith(data)


@pytest.mark.parametrize(
    "invalid_data",
    [
        ({"my_input": "something to predict"}),
        ([{"bogus_input": "invalid"}, "not_valid"]),
        (["tell me a story", {"of": "a properly configured pipeline input"}]),
    ],
)
def test_invalid_input_to_text_generation_pipeline(text_generation_pipeline, invalid_data):
    if isinstance(invalid_data, list):
        match = "If supplying a list, all values must be of string type"
    else:
        match = "The input data is of an incorrect type"
    with pytest.raises(MlflowException, match=match):
        infer_signature(
            invalid_data,
            mlflow.transformers.generate_signature_output(text_generation_pipeline, invalid_data),
        )


@pytest.mark.parametrize(
    ("inference_payload", "result"),
    [
        ("Riding a <mask> on the beach is fun!", ["bike"]),
        (["If I had <mask>, I would fly to the top of a mountain"], ["wings"]),
        (
            ["I use stacks of <mask> to buy things", "I <mask> the whole bowl of cherries"],
            ["cash", "ate"],
        ),
    ],
)
def test_fill_mask_pipeline(fill_mask_pipeline, model_path, inference_payload, result):
    signature = infer_signature(
        inference_payload,
        mlflow.transformers.generate_signature_output(fill_mask_pipeline, inference_payload),
    )

    mlflow.transformers.save_model(fill_mask_pipeline, path=model_path, signature=signature)
    pyfunc_loaded = mlflow.pyfunc.load_model(model_path)

    inference = pyfunc_loaded.predict(inference_payload)
    assert inference == result

    if len(inference_payload) > 1 and isinstance(inference_payload, list):
        pd_input = pd.DataFrame([{"inputs": v} for v in inference_payload])
    elif isinstance(inference_payload, list) and len(inference_payload) == 1:
        pd_input = pd.DataFrame([{"inputs": v} for v in inference_payload], index=[0])
    else:
        pd_input = pd.DataFrame({"inputs": inference_payload}, index=[0])

    pd_inference = pyfunc_loaded.predict(pd_input)
    assert pd_inference == result


def test_fill_mask_pipeline_with_multiple_masks(fill_mask_pipeline, model_path):
    data = ["I <mask> the whole <mask> of <mask>", "I <mask> the whole <mask> of <mask>"]

    mlflow.transformers.save_model(fill_mask_pipeline, path=model_path)
    pyfunc_loaded = mlflow.pyfunc.load_model(model_path)

    inference = pyfunc_loaded.predict(data)
    assert len(inference) == 2
    assert all(len(value) == 3 for value in inference)


@pytest.mark.parametrize(
    "invalid_data",
    [
        ({"a": "b"}),
        ([{"a": "b"}, [{"a": "c"}]]),
    ],
)
def test_invalid_input_to_fill_mask_pipeline(fill_mask_pipeline, invalid_data):
    if isinstance(invalid_data, list):
        match = "Invalid data submission. Ensure all"
    else:
        match = "The input data is of an incorrect type"
    with pytest.raises(MlflowException, match=match):
        infer_signature(
            invalid_data,
            mlflow.transformers.generate_signature_output(fill_mask_pipeline, invalid_data),
        )


@pytest.mark.parametrize(
    "data",
    [
        {
            "sequences": "I love the latest update to this IDE!",
            "candidate_labels": ["happy", "sad"],
        },
        {
            "sequences": ["My dog loves to eat spaghetti", "My dog hates going to the vet"],
            "candidate_labels": ["happy", "sad"],
            "hypothesis_template": "This example talks about how the dog is {}",
        },
    ],
)
def test_zero_shot_classification_pipeline(zero_shot_pipeline, model_path, data):
    # NB: The list submission for this pipeline type can accept json-encoded lists or lists within
    # the values of the dictionary.
    signature = infer_signature(
        data, mlflow.transformers.generate_signature_output(zero_shot_pipeline, data)
    )

    mlflow.transformers.save_model(zero_shot_pipeline, model_path, signature=signature)

    loaded_pyfunc = mlflow.pyfunc.load_model(model_path)
    inference = loaded_pyfunc.predict(data)

    assert isinstance(inference, pd.DataFrame)
    if isinstance(data["sequences"], str):
        assert len(inference) == len(data["candidate_labels"])
    else:
        assert len(inference) == len(data["sequences"]) * len(data["candidate_labels"])


@pytest.mark.parametrize(
    "query",
    [
        "What should we order more of?",
        [
            "What is our highest sales?",
            "What should we order more of?",
        ],
    ],
)
def test_table_question_answering_pipeline(table_question_answering_pipeline, model_path, query):
    table = {
        "Fruit": ["Apples", "Bananas", "Oranges", "Watermelon", "Blueberries"],
        "Sales": ["1230945.55", "86453.12", "11459.23", "8341.23", "2325.88"],
        "Inventory": ["910", "4589", "11200", "80", "3459"],
    }
    json_table = json.dumps(table)
    data = {"query": query, "table": json_table}
    signature = infer_signature(
        data, mlflow.transformers.generate_signature_output(table_question_answering_pipeline, data)
    )

    mlflow.transformers.save_model(
        table_question_answering_pipeline, model_path, signature=signature
    )
    loaded = mlflow.pyfunc.load_model(model_path)

    inference = loaded.predict(data)
    assert len(inference) == 1 if isinstance(query, str) else len(query)

    pd_input = pd.DataFrame([data])
    pd_inference = loaded.predict(pd_input)
    assert pd_inference is not None


def test_custom_code_pipeline(custom_code_pipeline, model_path):
    data = "hello"

    signature = infer_signature(
        data, mlflow.transformers.generate_signature_output(custom_code_pipeline, data)
    )

    mlflow.transformers.save_model(
        custom_code_pipeline,
        path=model_path,
        signature=signature,
    )

    # just test that it doesn't blow up when performing inference
    pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
    pyfunc_pred = pyfunc_loaded.predict(data)
    assert isinstance(pyfunc_pred[0][0], float)

    transformers_loaded = mlflow.transformers.load_model(model_path)
    transformers_pred = transformers_loaded(data)
    assert pyfunc_pred[0][0] == transformers_pred[0][0][0]


def test_custom_components_pipeline(custom_components_pipeline, model_path):
    data = "hello"

    signature = infer_signature(
        data, mlflow.transformers.generate_signature_output(custom_components_pipeline, data)
    )

    components = {
        "model": custom_components_pipeline.model,
        "tokenizer": custom_components_pipeline.tokenizer,
        "feature_extractor": custom_components_pipeline.feature_extractor,
    }

    mlflow.transformers.save_model(
        transformers_model=components, path=model_path, signature=signature
    )

    pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
    pyfunc_pred = pyfunc_loaded.predict(data)
    assert isinstance(pyfunc_pred[0][0], float)

    transformers_loaded = mlflow.transformers.load_model(model_path)
    transformers_pred = transformers_loaded(data)
    assert pyfunc_pred[0][0] == transformers_pred[0][0][0]

    # assert that all the reloaded components exist
    # and have the same class name as pre-save
    for name, component in components.items():
        assert component.__class__.__name__ == getattr(transformers_loaded, name).__class__.__name__


@pytest.mark.parametrize(
    ("data", "result"),
    [
        ("I've got a lovely bunch of coconuts!", ["Ich habe eine schöne Haufe von Kokos!"]),
        (
            [
                "I am the very model of a modern major general",
                "Once upon a time, there was a little turtle",
            ],
            [
                "Ich bin das Modell eines modernen Generals.",
                "Einmal gab es eine kleine Schildkröte.",
            ],
        ),
    ],
)
def test_translation_pipeline(translation_pipeline, model_path, data, result):
    signature = infer_signature(
        data, mlflow.transformers.generate_signature_output(translation_pipeline, data)
    )

    mlflow.transformers.save_model(translation_pipeline, path=model_path, signature=signature)
    pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
    inference = pyfunc_loaded.predict(data)
    assert inference == result

    if len(data) > 1 and isinstance(data, list):
        pd_input = pd.DataFrame([{"inputs": v} for v in data])
    elif isinstance(data, list) and len(data) == 1:
        pd_input = pd.DataFrame([{"inputs": v} for v in data], index=[0])
    else:
        pd_input = pd.DataFrame({"inputs": data}, index=[0])

    pd_inference = pyfunc_loaded.predict(pd_input)
    assert pd_inference == result


@pytest.mark.parametrize(
    "data",
    [
        "There once was a boy",
        ["Dolly isn't just a sheep anymore"],
        ["Baking cookies is quite easy", "Writing unittests is good for"],
    ],
)
def test_summarization_pipeline(summarizer_pipeline, model_path, data):
    model_config = {
        "top_k": 2,
        "num_beams": 5,
        "max_length": 90,
        "temperature": 0.62,
        "top_p": 0.85,
        "repetition_penalty": 1.15,
    }
    signature = infer_signature(
        data, mlflow.transformers.generate_signature_output(summarizer_pipeline, data)
    )

    mlflow.transformers.save_model(
        summarizer_pipeline, path=model_path, model_config=model_config, signature=signature
    )
    pyfunc_loaded = mlflow.pyfunc.load_model(model_path)

    inference = pyfunc_loaded.predict(data)
    if isinstance(data, list) and len(data) > 1:
        for i, entry in enumerate(data):
            assert inference[i].strip().startswith(entry)
    elif isinstance(data, list) and len(data) == 1:
        assert inference[0].strip().startswith(data[0])
    else:
        assert inference[0].strip().startswith(data)

    if len(data) > 1 and isinstance(data, list):
        pd_input = pd.DataFrame([{"inputs": v} for v in data])
    elif isinstance(data, list) and len(data) == 1:
        pd_input = pd.DataFrame([{"inputs": v} for v in data], index=[0])
    else:
        pd_input = pd.DataFrame({"inputs": data}, index=[0])

    pd_inference = pyfunc_loaded.predict(pd_input)
    if isinstance(data, list) and len(data) > 1:
        for i, entry in enumerate(data):
            assert pd_inference[i].strip().startswith(entry)
    elif isinstance(data, list) and len(data) == 1:
        assert pd_inference[0].strip().startswith(data[0])
    else:
        assert pd_inference[0].strip().startswith(data)


@pytest.mark.parametrize(
    "data",
    [
        "I'm telling you that Han shot first!",
        [
            "I think this sushi might have gone off",
            "That gym smells like feet, hot garbage, and sadness",
            "I love that we have a moon",
        ],
        [{"text": "test1", "text_pair": "test2"}],
        [{"text": "test1", "text_pair": "pair1"}, {"text": "test2", "text_pair": "pair2"}],
    ],
)
def test_classifier_pipeline(text_classification_pipeline, model_path, data):
    signature = infer_signature(
        data, mlflow.transformers.generate_signature_output(text_classification_pipeline, data)
    )
    mlflow.transformers.save_model(
        text_classification_pipeline, path=model_path, signature=signature
    )

    pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
    inference = pyfunc_loaded.predict(data)

    # verify that native transformers outputs match the pyfunc return values
    native_inference = text_classification_pipeline(data)
    inference_dict = inference.to_dict()

    if isinstance(data, str):
        assert len(inference) == 1
        assert inference_dict["label"][0] == native_inference[0]["label"]
        assert inference_dict["score"][0] == native_inference[0]["score"]
    else:
        assert len(inference) == len(data)
        for key in ["score", "label"]:
            for value in range(0, len(data)):
                if key == "label":
                    assert inference_dict[key][value] == native_inference[value][key]
                else:
                    assert math.isclose(
                        native_inference[value][key], inference_dict[key][value], rel_tol=1e-6
                    )


@pytest.mark.parametrize(
    ("data", "result"),
    [
        (
            "I have a dog and his name is Willy!",
            ["PRON,VERB,DET,NOUN,CCONJ,PRON,NOUN,AUX,PROPN,PUNCT"],
        ),
        (["I like turtles"], ["PRON,VERB,NOUN"]),
        (
            ["We are the knights who say nee!", "Houston, we may have a problem."],
            [
                "PRON,AUX,DET,PROPN,PRON,VERB,INTJ,PUNCT",
                "PROPN,PUNCT,PRON,AUX,VERB,DET,NOUN,PUNCT",
            ],
        ),
    ],
)
@pytest.mark.parametrize("pipeline_name", ["ner_pipeline", "ner_pipeline_aggregation"])
test_ner_pipeline(pipeline_name, model_path, data, result, request): 1570 pipeline = request.getfixturevalue(pipeline_name) 1571 1572 signature = infer_signature(data, mlflow.transformers.generate_signature_output(pipeline, data)) 1573 1574 mlflow.transformers.save_model(pipeline, model_path, signature=signature) 1575 loaded_pyfunc = mlflow.pyfunc.load_model(model_path) 1576 inference = loaded_pyfunc.predict(data) 1577 1578 assert inference == result 1579 1580 if len(data) > 1 and isinstance(data, list): 1581 pd_input = pd.DataFrame([{"inputs": v} for v in data]) 1582 elif isinstance(data, list) and len(data) == 1: 1583 pd_input = pd.DataFrame([{"inputs": v} for v in data], index=[0]) 1584 else: 1585 pd_input = pd.DataFrame({"inputs": data}, index=[0]) 1586 pd_inference = loaded_pyfunc.predict(pd_input) 1587 assert pd_inference == result 1588 1589 1590 @pytest.mark.skipif( 1591 _try_import_conversational_pipeline() is None, 1592 reason="Conversation model is deprecated and removed.", 1593 ) 1594 def test_conversational_pipeline(conversational_pipeline, model_path): 1595 assert mlflow.transformers._is_conversational_pipeline(conversational_pipeline) 1596 1597 signature = infer_signature( 1598 "Hi there!", 1599 mlflow.transformers.generate_signature_output(conversational_pipeline, "Hi there!"), 1600 ) 1601 1602 mlflow.transformers.save_model(conversational_pipeline, model_path, signature=signature) 1603 loaded_pyfunc = mlflow.pyfunc.load_model(model_path) 1604 1605 first_response = loaded_pyfunc.predict("What is the best way to get to Antarctica?") 1606 1607 assert first_response == "The best way would be to go to space." 1608 1609 second_response = loaded_pyfunc.predict("What kind of boat should I use?") 1610 1611 assert second_response == "The best way to get to space would be to reach out and touch it." 1612 1613 # Test that a new loaded instance has no context. 1614 loaded_again_pyfunc = mlflow.pyfunc.load_model(model_path) 1615 third_response = loaded_again_pyfunc.predict("What kind of boat should I use?") 1616 1617 assert third_response == "The one with the guns." 1618 1619 fourth_response = loaded_again_pyfunc.predict("Can I use it to go to the moon?") 1620 1621 assert fourth_response == "Sure." 1622 1623 1624 def test_qa_pipeline_pyfunc_predict(small_qa_pipeline): 1625 artifact_path = "qa_model" 1626 with mlflow.start_run(): 1627 model_info = mlflow.transformers.log_model( 1628 small_qa_pipeline, 1629 name=artifact_path, 1630 ) 1631 1632 inference_payload = json.dumps({ 1633 "inputs": { 1634 "question": [ 1635 "What color is it?", 1636 "How do the people go?", 1637 "What does the 'wolf' howl at?", 1638 ], 1639 "context": [ 1640 "Some people said it was green but I know that it's pink.", 1641 "The people on the bus go up and down. 
Up and down.", 1642 "The pack of 'wolves' stood on the cliff and a 'lone wolf' howled at " 1643 "the moon for hours.", 1644 ], 1645 } 1646 }) 1647 response = pyfunc_serve_and_score_model( 1648 model_info.model_uri, 1649 data=inference_payload, 1650 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 1651 extra_args=["--env-manager", "local"], 1652 ) 1653 values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions() 1654 1655 assert values.to_dict(orient="records") == [{0: "pink"}, {0: "up and down"}, {0: "the moon"}] 1656 1657 inference_payload = json.dumps({ 1658 "inputs": { 1659 "question": "Who's house?", 1660 "context": "The house is owned by a man named Run.", 1661 } 1662 }) 1663 1664 response = pyfunc_serve_and_score_model( 1665 model_info.model_uri, 1666 data=inference_payload, 1667 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 1668 extra_args=["--env-manager", "local"], 1669 ) 1670 values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions() 1671 1672 assert values.to_dict(orient="records") == [{0: "Run"}] 1673 1674 1675 @pytest.mark.parametrize( 1676 ("input_image", "result"), 1677 [ 1678 (str(image_file_path), False), 1679 (image_url, False), 1680 ("base64", True), 1681 ("random string", False), 1682 ], 1683 ) 1684 def test_vision_is_base64_image(input_image, result): 1685 if input_image == "base64": 1686 input_image = base64.b64encode(image_file_path.read_bytes()).decode("utf-8") 1687 assert _TransformersWrapper.is_base64_image(input_image) == result 1688 1689 1690 @pytest.mark.parametrize( 1691 "inference_payload", 1692 [ 1693 [str(image_file_path)], 1694 [image_url], 1695 "base64", 1696 pytest.param( 1697 "base64_encodebytes", 1698 marks=pytest.mark.skipif( 1699 Version(transformers.__version__) < Version("4.41"), 1700 reason="base64 encodebytes feature not present", 1701 ), 1702 ), 1703 ], 1704 ) 1705 def test_vision_pipeline_pyfunc_predict(small_vision_model, inference_payload): 1706 if inference_payload == "base64": 1707 inference_payload = [ 1708 base64.b64encode(image_file_path.read_bytes()).decode("utf-8"), 1709 ] 1710 elif inference_payload == "base64_encodebytes": 1711 inference_payload = [ 1712 base64.encodebytes(image_file_path.read_bytes()).decode("utf-8"), 1713 ] 1714 artifact_path = "image_classification_model" 1715 1716 # Log the image classification model 1717 with mlflow.start_run(): 1718 model_info = mlflow.transformers.log_model( 1719 small_vision_model, 1720 name=artifact_path, 1721 ) 1722 pyfunc_inference_payload = json.dumps({"inputs": inference_payload}) 1723 response = pyfunc_serve_and_score_model( 1724 model_info.model_uri, 1725 data=pyfunc_inference_payload, 1726 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 1727 extra_args=["--env-manager", "local"], 1728 ) 1729 1730 predictions = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions() 1731 1732 transformers_loaded_model = mlflow.transformers.load_model(model_info.model_uri) 1733 expected_predictions = transformers_loaded_model.predict(inference_payload) 1734 1735 assert [list(pred.values()) for pred in predictions.to_dict("records")] == expected_predictions 1736 1737 1738 def test_classifier_pipeline_pyfunc_predict(text_classification_pipeline): 1739 artifact_path = "text_classifier_model" 1740 data = [ 1741 "I think this sushi might have gone off", 1742 "That gym smells like feet, hot garbage, and sadness", 1743 "I love that we have a moon", 1744 "I 'love' debugging subprocesses", 1745 'Quote "in" 
the string', 1746 ] 1747 signature = infer_signature(data) 1748 with mlflow.start_run(): 1749 model_info = mlflow.transformers.log_model( 1750 text_classification_pipeline, 1751 name=artifact_path, 1752 signature=signature, 1753 ) 1754 1755 response = pyfunc_serve_and_score_model( 1756 model_info.model_uri, 1757 data=json.dumps({"inputs": data}), 1758 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 1759 extra_args=["--env-manager", "local"], 1760 ) 1761 values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions() 1762 1763 assert len(values.to_dict()) == 2 1764 assert len(values.to_dict()["score"]) == 5 1765 1766 # test simple string input 1767 inference_payload = json.dumps({"inputs": ["testing"]}) 1768 1769 response = pyfunc_serve_and_score_model( 1770 model_info.model_uri, 1771 data=inference_payload, 1772 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 1773 extra_args=["--env-manager", "local"], 1774 ) 1775 values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions() 1776 1777 assert len(values.to_dict()) == 2 1778 assert len(values.to_dict()["score"]) == 1 1779 1780 # Test the alternate TextClassificationPipeline input structure where text_pair is used 1781 # and ensure that model serving and direct native inference match 1782 inference_data = [ 1783 {"text": "test1", "text_pair": "pair1"}, 1784 {"text": "test2", "text_pair": "pair2"}, 1785 {"text": "test 'quote", "text_pair": "pair 'quote'"}, 1786 ] 1787 signature = infer_signature(inference_data) 1788 with mlflow.start_run(): 1789 model_info = mlflow.transformers.log_model( 1790 text_classification_pipeline, 1791 name=artifact_path, 1792 signature=signature, 1793 ) 1794 1795 inference_payload = json.dumps({"inputs": inference_data}) 1796 response = pyfunc_serve_and_score_model( 1797 model_info.model_uri, 1798 data=inference_payload, 1799 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 1800 extra_args=["--env-manager", "local"], 1801 ) 1802 values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions() 1803 values_dict = values.to_dict() 1804 native_predict = text_classification_pipeline(inference_data) 1805 1806 # validate that the pyfunc served model registers text_pair in the same manner as native 1807 for key in ["score", "label"]: 1808 for value in [0, 1]: 1809 if key == "label": 1810 assert values_dict[key][value] == native_predict[value][key] 1811 else: 1812 assert math.isclose( 1813 values_dict[key][value], native_predict[value][key], rel_tol=1e-6 1814 ) 1815 1816 1817 def test_zero_shot_pipeline_pyfunc_predict(zero_shot_pipeline): 1818 artifact_path = "zero_shot_classifier_model" 1819 with mlflow.start_run(): 1820 model_info = mlflow.transformers.log_model( 1821 zero_shot_pipeline, 1822 name=artifact_path, 1823 ) 1824 model_uri = model_info.model_uri 1825 1826 inference_payload = json.dumps({ 1827 "inputs": { 1828 "sequences": "My dog loves running through troughs of spaghetti with his mouth open", 1829 "candidate_labels": ["happy", "sad"], 1830 "hypothesis_template": "This example talks about how the dog is {}", 1831 } 1832 }) 1833 1834 response = pyfunc_serve_and_score_model( 1835 model_uri, 1836 data=inference_payload, 1837 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 1838 extra_args=["--env-manager", "local"], 1839 ) 1840 values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions() 1841 1842 assert len(values.to_dict()) == 3 1843 assert len(values.to_dict()["labels"]) == 2 1844 
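    # NB: The following payload supplies candidate_labels as a JSON-encoded string rather than
    # a list and scores three sequences against the two labels, so the flattened response is
    # expected to contain 3 x 2 = 6 label entries.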
1845 inference_payload = json.dumps({ 1846 "inputs": { 1847 "sequences": [ 1848 "My dog loves to eat spaghetti", 1849 "My dog hates going to the vet", 1850 "My 'hamster' loves to play with my 'friendly' dog", 1851 ], 1852 "candidate_labels": '["happy", "sad"]', 1853 "hypothesis_template": "This example talks about how the dog is {}", 1854 } 1855 }) 1856 response = pyfunc_serve_and_score_model( 1857 model_uri, 1858 data=inference_payload, 1859 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 1860 extra_args=["--env-manager", "local"], 1861 ) 1862 values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions() 1863 1864 assert len(values.to_dict()) == 3 1865 assert len(values.to_dict()["labels"]) == 6 1866 1867 1868 def test_table_question_answering_pyfunc_predict(table_question_answering_pipeline): 1869 artifact_path = "table_qa_model" 1870 with mlflow.start_run(): 1871 model_info = mlflow.transformers.log_model( 1872 table_question_answering_pipeline, 1873 name=artifact_path, 1874 ) 1875 1876 table = { 1877 "Fruit": ["Apples", "Bananas", "Oranges", "Watermelon 'small'", "Blueberries"], 1878 "Sales": ["1230945.55", "86453.12", "11459.23", "8341.23", "2325.88"], 1879 "Inventory": ["910", "4589", "11200", "80", "3459"], 1880 } 1881 1882 inference_payload = json.dumps({ 1883 "inputs": { 1884 "query": "What should we order more of?", 1885 "table": table, 1886 } 1887 }) 1888 1889 response = pyfunc_serve_and_score_model( 1890 model_info.model_uri, 1891 data=inference_payload, 1892 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 1893 extra_args=["--env-manager", "local"], 1894 ) 1895 values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions() 1896 1897 assert len(values.to_dict(orient="records")) == 1 1898 1899 inference_payload = json.dumps({ 1900 "inputs": { 1901 "query": [ 1902 "What is our highest sales?", 1903 "What should we order more of?", 1904 "Which 'fruit' has the 'highest' 'sales'?", 1905 ], 1906 "table": table, 1907 } 1908 }) 1909 response = pyfunc_serve_and_score_model( 1910 model_info.model_uri, 1911 data=inference_payload, 1912 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 1913 extra_args=["--env-manager", "local"], 1914 ) 1915 values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions() 1916 1917 assert len(values.to_dict(orient="records")) == 3 1918 1919 1920 def test_feature_extraction_pipeline(feature_extraction_pipeline): 1921 sentences = ["hi", "hello"] 1922 signature = infer_signature( 1923 sentences, 1924 mlflow.transformers.generate_signature_output(feature_extraction_pipeline, sentences), 1925 ) 1926 1927 artifact_path = "feature_extraction_pipeline" 1928 with mlflow.start_run(): 1929 model_info = mlflow.transformers.log_model( 1930 feature_extraction_pipeline, 1931 name=artifact_path, 1932 signature=signature, 1933 input_example=["A sentence", "Another sentence"], 1934 ) 1935 1936 # Load as native 1937 loaded_pipeline = mlflow.transformers.load_model(model_info.model_uri) 1938 1939 inference_single = "Testing" 1940 inference_mult = ["Testing something", "Testing something else"] 1941 1942 pred = loaded_pipeline(inference_single) 1943 assert len(pred[0][0]) > 10 1944 assert isinstance(pred[0][0][0], float) 1945 1946 pred_multiple = loaded_pipeline(inference_mult) 1947 assert len(pred_multiple[0][0]) > 2 1948 assert isinstance(pred_multiple[0][0][0][0], float) 1949 1950 loaded_pyfunc = mlflow.pyfunc.load_model(model_info.model_uri) 1951 1952 pyfunc_pred = 
loaded_pyfunc.predict(inference_single) 1953 1954 assert isinstance(pyfunc_pred, np.ndarray) 1955 1956 assert np.array_equal(np.array(pred[0]), pyfunc_pred) 1957 1958 pyfunc_pred_multiple = loaded_pyfunc.predict(inference_mult) 1959 1960 assert np.array_equal(np.array(pred_multiple[0][0]), pyfunc_pred_multiple) 1961 1962 1963 def test_feature_extraction_pipeline_pyfunc_predict(feature_extraction_pipeline): 1964 artifact_path = "feature_extraction" 1965 with mlflow.start_run(): 1966 model_info = mlflow.transformers.log_model( 1967 feature_extraction_pipeline, 1968 name=artifact_path, 1969 ) 1970 1971 inference_payload = json.dumps({"inputs": ["sentence one", "sentence two"]}) 1972 1973 response = pyfunc_serve_and_score_model( 1974 model_info.model_uri, 1975 data=inference_payload, 1976 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 1977 extra_args=["--env-manager", "local"], 1978 ) 1979 values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions() 1980 1981 assert len(values.columns) == 384 1982 assert len(values) == 4 1983 1984 inference_payload = json.dumps({"inputs": "sentence three"}) 1985 1986 response = pyfunc_serve_and_score_model( 1987 model_info.model_uri, 1988 data=inference_payload, 1989 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 1990 extra_args=["--env-manager", "local"], 1991 ) 1992 assert response.status_code == 200 1993 prediction = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions() 1994 assert len(prediction.columns) == 384 1995 assert len(prediction) == 4 1996 1997 1998 def test_loading_unsupported_pipeline_type_as_pyfunc(small_multi_modal_pipeline, model_path): 1999 mlflow.transformers.save_model(small_multi_modal_pipeline, model_path) 2000 with pytest.raises(MlflowException, match='Model does not have the "python_function" flavor'): 2001 mlflow.pyfunc.load_model(model_path) 2002 2003 2004 def test_pyfunc_input_validations(mock_pyfunc_wrapper): 2005 def ensure_raises(data, match): 2006 with pytest.raises(MlflowException, match=match): 2007 mock_pyfunc_wrapper._validate_str_or_list_str(data) 2008 2009 match1 = "The input data is of an incorrect type" 2010 match2 = "If supplying a list, all values must" 2011 ensure_raises({"a": "b"}, match1) 2012 ensure_raises(("a", "b"), match1) 2013 ensure_raises({"a", "b"}, match1) 2014 ensure_raises(True, match1) 2015 ensure_raises(12, match1) 2016 ensure_raises([1, 2, 3], match2) 2017 ensure_raises([{"a", "b"}], match2) 2018 ensure_raises([["a", "b", "c'"]], match2) 2019 ensure_raises([{"a": "b"}, {"a": "c"}], match2) 2020 ensure_raises([[1], [2]], match2) 2021 2022 2023 def test_pyfunc_json_encoded_dict_parsing(mock_pyfunc_wrapper): 2024 plain_dict = {"a": "b", "b": "c"} 2025 list_dict = [plain_dict, plain_dict] 2026 2027 plain_input = {"in": json.dumps(plain_dict)} 2028 list_input = {"in": json.dumps(list_dict)} 2029 2030 plain_parsed = mock_pyfunc_wrapper._parse_json_encoded_dict_payload_to_dict(plain_input, "in") 2031 assert plain_parsed == {"in": plain_dict} 2032 2033 list_parsed = mock_pyfunc_wrapper._parse_json_encoded_dict_payload_to_dict(list_input, "in") 2034 assert list_parsed == {"in": list_dict} 2035 2036 invalid_parsed = mock_pyfunc_wrapper._parse_json_encoded_dict_payload_to_dict( 2037 plain_input, "invalid" 2038 ) 2039 assert invalid_parsed != {"in": plain_dict} 2040 assert invalid_parsed == plain_input 2041 2042 2043 def test_pyfunc_json_encoded_list_parsing(mock_pyfunc_wrapper): 2044 plain_list = ["a", "b", "c"] 2045 nested_list = 
[plain_list, plain_list] 2046 list_dict = [{"a": "b"}, {"a": "c"}] 2047 2048 plain_input = {"in": json.dumps(plain_list)} 2049 nested_input = {"in": json.dumps(nested_list)} 2050 list_dict_input = {"in": json.dumps(list_dict)} 2051 2052 plain_parsed = mock_pyfunc_wrapper._parse_json_encoded_list(plain_input, "in") 2053 assert plain_parsed == {"in": plain_list} 2054 2055 nested_parsed = mock_pyfunc_wrapper._parse_json_encoded_list(nested_input, "in") 2056 assert nested_parsed == {"in": nested_list} 2057 2058 list_dict_parsed = mock_pyfunc_wrapper._parse_json_encoded_list(list_dict_input, "in") 2059 assert list_dict_parsed == {"in": list_dict} 2060 2061 with pytest.raises(MlflowException, match="Invalid key in inference payload. The "): 2062 mock_pyfunc_wrapper._parse_json_encoded_list(list_dict_input, "invalid") 2063 2064 2065 def test_pyfunc_text_to_text_input(mock_pyfunc_wrapper): 2066 text2text_input = {"context": "a", "answer": "b"} 2067 parsed_input = mock_pyfunc_wrapper._parse_text2text_input(text2text_input) 2068 assert parsed_input == "context: a answer: b" 2069 2070 text2text_input_list = [text2text_input, text2text_input] 2071 parsed_input_list = mock_pyfunc_wrapper._parse_text2text_input(text2text_input_list) 2072 assert parsed_input_list == ["context: a answer: b", "context: a answer: b"] 2073 2074 parsed_with_inputs = mock_pyfunc_wrapper._parse_text2text_input({"inputs": "a"}) 2075 assert parsed_with_inputs == ["a"] 2076 2077 parsed_str = mock_pyfunc_wrapper._parse_text2text_input("a") 2078 assert parsed_str == "a" 2079 2080 parsed_list_str = mock_pyfunc_wrapper._parse_text2text_input(["a", "b"]) 2081 assert parsed_list_str == ["a", "b"] 2082 2083 with pytest.raises(MlflowException, match="An invalid type has been supplied"): 2084 mock_pyfunc_wrapper._parse_text2text_input([1, 2, 3]) 2085 2086 with pytest.raises(MlflowException, match="An invalid type has been supplied"): 2087 mock_pyfunc_wrapper._parse_text2text_input([{"a": [{"b": "c"}]}]) 2088 2089 2090 def test_pyfunc_qa_input(mock_pyfunc_wrapper): 2091 single_input = {"question": "a", "context": "b"} 2092 parsed_single_input = mock_pyfunc_wrapper._parse_question_answer_input(single_input) 2093 assert parsed_single_input == single_input 2094 2095 multi_input = [single_input, single_input] 2096 parsed_multi_input = mock_pyfunc_wrapper._parse_question_answer_input(multi_input) 2097 assert parsed_multi_input == multi_input 2098 2099 with pytest.raises(MlflowException, match="Invalid keys were submitted. 
Keys must"): 2100 mock_pyfunc_wrapper._parse_question_answer_input({"q": "a", "c": "b"}) 2101 2102 with pytest.raises(MlflowException, match="An invalid type has been supplied"): 2103 mock_pyfunc_wrapper._parse_question_answer_input("a") 2104 2105 with pytest.raises(MlflowException, match="An invalid type has been supplied"): 2106 mock_pyfunc_wrapper._parse_question_answer_input(["a", "b", "c"]) 2107 2108 2109 def test_list_of_dict_to_list_of_str_parsing(mock_pyfunc_wrapper): 2110 # Test with a single list of dictionaries 2111 output_data = [{"a": "foo"}, {"a": "bar"}, {"a": "baz"}] 2112 expected_output = ["foo", "bar", "baz"] 2113 assert ( 2114 mock_pyfunc_wrapper._parse_lists_of_dict_to_list_of_str(output_data, "a") == expected_output 2115 ) 2116 2117 # Test with a nested list of dictionaries 2118 output_data = [ 2119 {"a": "foo", "b": [{"a": "bar"}]}, 2120 {"a": "baz", "b": [{"a": "qux"}]}, 2121 ] 2122 expected_output = ["foo", "bar", "baz", "qux"] 2123 assert ( 2124 mock_pyfunc_wrapper._parse_lists_of_dict_to_list_of_str(output_data, "a") == expected_output 2125 ) 2126 2127 # Test with nested list with exclusion data 2128 output_data = [ 2129 {"a": "valid", "b": [{"a": "another valid"}, {"b": "invalid"}]}, 2130 {"a": "valid 2", "b": [{"a": "another valid 2"}, {"c": "invalid"}]}, 2131 ] 2132 expected_output = ["valid", "another valid", "valid 2", "another valid 2"] 2133 assert ( 2134 mock_pyfunc_wrapper._parse_lists_of_dict_to_list_of_str(output_data, "a") == expected_output 2135 ) 2136 2137 2138 def test_parsing_tokenizer_output(mock_pyfunc_wrapper): 2139 output_data = [{"a": "b"}, {"a": "c"}, {"a": "d"}] 2140 expected_output = "b,c,d" 2141 assert mock_pyfunc_wrapper._parse_tokenizer_output(output_data, {"a"}) == expected_output 2142 2143 output_data = [output_data, output_data] 2144 expected_output = [expected_output, expected_output] 2145 assert mock_pyfunc_wrapper._parse_tokenizer_output(output_data, {"a"}) == expected_output 2146 2147 2148 def test_parse_list_of_multiple_dicts(mock_pyfunc_wrapper): 2149 output_data = [{"a": "b", "d": "f"}, {"a": "z", "d": "g"}] 2150 target_dict_key = "a" 2151 expected_output = ["b"] 2152 2153 assert ( 2154 mock_pyfunc_wrapper._parse_list_of_multiple_dicts(output_data, target_dict_key) 2155 == expected_output 2156 ) 2157 2158 output_data = [ 2159 [{"a": "c", "d": "q"}, {"a": "o", "d": "q"}, {"a": "d", "d": "q"}, {"a": "e", "d": "r"}], 2160 [{"a": "m", "d": "s"}, {"a": "e", "d": "t"}], 2161 ] 2162 target_dict_key = "a" 2163 expected_output = ["c", "m"] 2164 2165 assert ( 2166 mock_pyfunc_wrapper._parse_list_of_multiple_dicts(output_data, target_dict_key) 2167 == expected_output 2168 ) 2169 2170 2171 @pytest.mark.parametrize( 2172 ( 2173 "pipeline_input", 2174 "pipeline_output", 2175 "expected_output", 2176 "flavor_config", 2177 "include_prompt", 2178 "collapse_whitespace", 2179 ), 2180 [ 2181 ( 2182 "What answers?", 2183 [{"generated_text": "What answers?\n\nA collection of\n\nanswers"}], 2184 "A collection of\n\nanswers", 2185 {"instance_type": "InstructionTextGenerationPipeline"}, 2186 False, 2187 False, 2188 ), 2189 ( 2190 "What answers?", 2191 [{"generated_text": "What answers?\n\nA collection of\n\nanswers"}], 2192 "A collection of answers", 2193 {"instance_type": "InstructionTextGenerationPipeline"}, 2194 False, 2195 True, 2196 ), 2197 ( 2198 "Hello!", 2199 [{"generated_text": "Hello!\n\nHow are you?"}], 2200 "How are you?", 2201 {"instance_type": "InstructionTextGenerationPipeline"}, 2202 False, 2203 False, 2204 ), 2205 ( 2206 "Hello!", 2207 
[{"generated_text": "Hello!\n\nA: How are you?\n\n"}], 2208 "How are you?", 2209 {"instance_type": "InstructionTextGenerationPipeline"}, 2210 False, 2211 True, 2212 ), 2213 ( 2214 "Hello!", 2215 [{"generated_text": "Hello!\n\nA: How are you?\n\n"}], 2216 "Hello! A: How are you?", 2217 {"instance_type": "InstructionTextGenerationPipeline"}, 2218 True, 2219 True, 2220 ), 2221 ( 2222 "Hello!", 2223 [{"generated_text": "Hello!\n\nA: How\nare\nyou?\n\n"}], 2224 "How\nare\nyou?\n\n", 2225 {"instance_type": "InstructionTextGenerationPipeline"}, 2226 False, 2227 False, 2228 ), 2229 ( 2230 ["Hi!", "What's up?"], 2231 [[{"generated_text": "Hi!\n\nHello there"}, {"generated_text": "Not much, and you?"}]], 2232 ["Hello there", "Not much, and you?"], 2233 {"instance_type": "InstructionTextGenerationPipeline"}, 2234 False, 2235 False, 2236 ), 2237 # Tests disabling parsing of newline characters 2238 ( 2239 ["Hi!", "What's up?"], 2240 [ 2241 [ 2242 {"generated_text": "Hi!\n\nHello there"}, 2243 {"generated_text": "What's up?\n\nNot much, and you?"}, 2244 ] 2245 ], 2246 ["Hi!\n\nHello there", "What's up?\n\nNot much, and you?"], 2247 {"instance_type": "InstructionTextGenerationPipeline"}, 2248 True, 2249 False, 2250 ), 2251 ( 2252 "Hello!", 2253 [{"generated_text": "Hello!\n\nHow are you?"}], 2254 "Hello!\n\nHow are you?", 2255 {"instance_type": "InstructionTextGenerationPipeline"}, 2256 True, 2257 False, 2258 ), 2259 # Tests a standard TextGenerationPipeline output 2260 ( 2261 ["We like to", "Open the"], 2262 [ 2263 [ 2264 {"generated_text": "We like to party"}, 2265 {"generated_text": "Open the door get on the floor everybody do the dinosaur"}, 2266 ] 2267 ], 2268 ["We like to party", "Open the door get on the floor everybody do the dinosaur"], 2269 {"instance_type": "TextGenerationPipeline"}, 2270 True, 2271 True, 2272 ), 2273 # Tests a standard TextGenerationPipeline output with setting "include_prompt" (noop) 2274 ( 2275 ["We like to", "Open the"], 2276 [ 2277 [ 2278 {"generated_text": "We like to party"}, 2279 {"generated_text": "Open the door get on the floor everybody do the dinosaur"}, 2280 ] 2281 ], 2282 ["We like to party", "Open the door get on the floor everybody do the dinosaur"], 2283 {"instance_type": "TextGenerationPipeline"}, 2284 False, 2285 False, 2286 ), 2287 # Test TextGenerationPipeline removes whitespace 2288 ( 2289 ["We like to", "Open the"], 2290 [ 2291 [ 2292 {"generated_text": " We like to party"}, 2293 { 2294 "generated_text": "Open the door get on the floor everybody " 2295 "do\nthe dinosaur" 2296 }, 2297 ] 2298 ], 2299 ["We like to party", "Open the door get on the floor everybody do the dinosaur"], 2300 {"instance_type": "TextGenerationPipeline"}, 2301 False, 2302 True, 2303 ), 2304 ], 2305 ) 2306 def test_parse_input_from_instruction_pipeline( 2307 mock_pyfunc_wrapper, 2308 pipeline_input, 2309 pipeline_output, 2310 expected_output, 2311 flavor_config, 2312 include_prompt, 2313 collapse_whitespace, 2314 ): 2315 assert ( 2316 mock_pyfunc_wrapper._strip_input_from_response_in_instruction_pipelines( 2317 pipeline_input, 2318 pipeline_output, 2319 "generated_text", 2320 flavor_config, 2321 include_prompt, 2322 collapse_whitespace, 2323 ) 2324 == expected_output 2325 ) 2326 2327 2328 @pytest.mark.parametrize( 2329 "flavor_config", 2330 [ 2331 {"instance_type": "InstructionTextGenerationPipeline"}, 2332 {"instance_type": "TextGenerationPipeline"}, 2333 ], 2334 ) 2335 def test_invalid_instruction_pipeline_parsing(mock_pyfunc_wrapper, flavor_config): 2336 prompt = "What is your 
favorite boba flavor?" 2337 2338 bad_output = {"generated_text": ["Strawberry Milk Cap", "Honeydew with boba"]} 2339 2340 with pytest.raises(MlflowException, match="Unable to parse the pipeline output. Expected"): 2341 mock_pyfunc_wrapper._strip_input_from_response_in_instruction_pipelines( 2342 prompt, bad_output, "generated_text", flavor_config, True 2343 ) 2344 2345 2346 @pytest.mark.skipif(RUNNING_IN_GITHUB_ACTIONS, reason=GITHUB_ACTIONS_SKIP_REASON) 2347 def test_instructional_pipeline_no_prompt_in_output(model_path): 2348 architecture = "databricks/dolly-v2-3b" 2349 dolly = transformers.pipeline(model=architecture, trust_remote_code=True) 2350 2351 mlflow.transformers.save_model( 2352 transformers_model=dolly, 2353 path=model_path, 2354 # Validate removal of prompt but inclusion of newlines by default 2355 model_config={"max_length": 100, "include_prompt": False}, 2356 input_example="Hello, Dolly!", 2357 ) 2358 2359 pyfunc_loaded = mlflow.pyfunc.load_model(model_path) 2360 2361 inference = pyfunc_loaded.predict("What is MLflow?") 2362 2363 assert not inference[0].startswith("What is MLflow?") 2364 assert "\n" in inference[0] 2365 2366 2367 @pytest.mark.skipif(RUNNING_IN_GITHUB_ACTIONS, reason=GITHUB_ACTIONS_SKIP_REASON) 2368 def test_instructional_pipeline_no_prompt_in_output_and_removal_of_newlines(model_path): 2369 architecture = "databricks/dolly-v2-3b" 2370 dolly = transformers.pipeline(model=architecture, trust_remote_code=True) 2371 2372 mlflow.transformers.save_model( 2373 transformers_model=dolly, 2374 path=model_path, 2375 # Validate removal of the prompt and collapsing of newlines with `collapse_whitespace`=True 2376 model_config={"max_length": 100, "include_prompt": False, "collapse_whitespace": True}, 2377 input_example="Hello, Dolly!", 2378 ) 2379 2380 pyfunc_loaded = mlflow.pyfunc.load_model(model_path) 2381 2382 inference = pyfunc_loaded.predict("What is MLflow?") 2383 2384 assert not inference[0].startswith("What is MLflow?") 2385 assert "\n" not in inference[0] 2386 2387 2388 @pytest.mark.skipif(RUNNING_IN_GITHUB_ACTIONS, reason=GITHUB_ACTIONS_SKIP_REASON) 2389 def test_instructional_pipeline_with_prompt_in_output(model_path): 2390 architecture = "databricks/dolly-v2-3b" 2391 dolly = transformers.pipeline(model=architecture, trust_remote_code=True) 2392 2393 mlflow.transformers.save_model( 2394 transformers_model=dolly, 2395 path=model_path, 2396 # test default propagation of `include_prompt`=True and `collapse_whitespace`=False 2397 model_config={"max_length": 100}, 2398 input_example="Hello, Dolly!", 2399 ) 2400 2401 pyfunc_loaded = mlflow.pyfunc.load_model(model_path) 2402 2403 inference = pyfunc_loaded.predict("What is MLflow?") 2404 2405 assert inference[0].startswith("What is MLflow?") 2406 assert "\n\n" in inference[0] 2407 2408 2409 def read_audio_data(format: str): 2410 datasets_path = pathlib.Path(__file__).resolve().parent.parent.joinpath("datasets") 2411 wav_file_path = datasets_path.joinpath("apollo11_launch.wav") 2412 if format == "float": 2413 audio, _ = librosa.load(wav_file_path, sr=16000) 2414 return audio 2415 elif format == "bytes": 2416 return wav_file_path.read_bytes() 2417 elif format == "file": 2418 return wav_file_path.as_posix() 2419 else: 2420 raise ValueError(f"Invalid format: {format}") 2421 2422 2423 @pytest.mark.parametrize("input_format", ["float", "bytes", "file"]) 2424 @pytest.mark.parametrize("with_input_example", [True, False]) 2425 def test_whisper_model_predict(model_path, whisper_pipeline, input_format, with_input_example): 2426 if input_format ==
"float" and not with_input_example: 2427 pytest.skip("If the input format is float, the default signature must be overridden.") 2428 2429 audio = read_audio_data(input_format) 2430 mlflow.transformers.save_model( 2431 transformers_model=whisper_pipeline, 2432 path=model_path, 2433 input_example=audio if with_input_example else None, 2434 save_pretrained=False, 2435 ) 2436 2437 # 1. Single prediction with native transformer pipeline 2438 loaded_pipeline = mlflow.transformers.load_model(model_path) 2439 transcription = loaded_pipeline(audio) 2440 assert transcription["text"].startswith(" 30") 2441 # strip the leading space 2442 expected_text = transcription["text"].lstrip() 2443 2444 # 2. Single prediction with Pyfunc 2445 loaded_pyfunc = mlflow.pyfunc.load_model(model_path) 2446 pyfunc_transcription = loaded_pyfunc.predict(audio)[0] 2447 assert pyfunc_transcription == expected_text 2448 2449 # 3. Batch prediction with Pyfunc. Float input format is not supported for batch prediction, 2450 # because our signature framework doesn't support a list of numpy array. 2451 if input_format != "float": 2452 batch_transcription = loaded_pyfunc.predict([audio, audio]) 2453 assert len(batch_transcription) == 2 2454 assert all(ts == expected_text for ts in batch_transcription) 2455 2456 2457 def test_whisper_model_serve_and_score(whisper_pipeline): 2458 # Request payload to the model serving endpoint contains base64 encoded audio data 2459 audio = read_audio_data("bytes") 2460 encoded_audio = base64.b64encode(audio).decode("ascii") 2461 2462 with mlflow.start_run(): 2463 model_info = mlflow.transformers.log_model( 2464 whisper_pipeline, 2465 name="whisper", 2466 save_pretrained=False, 2467 ) 2468 2469 def _assert_response(response, length=1): 2470 preds = json.loads(response.content.decode("utf-8"))["predictions"] 2471 assert len(preds) == length 2472 assert all(pred.startswith("30") for pred in preds) 2473 2474 with pyfunc_scoring_endpoint( 2475 model_info.model_uri, 2476 extra_args=["--env-manager", "local"], 2477 ) as endpoint: 2478 content_type = pyfunc_scoring_server.CONTENT_TYPE_JSON 2479 2480 # Test payload with "inputs" key 2481 inputs_dict = {"inputs": [encoded_audio]} 2482 payload = json.dumps(inputs_dict) 2483 response = endpoint.invoke(payload, content_type=content_type) 2484 _assert_response(response) 2485 2486 # Test payload with "dataframe_split" key 2487 inference_df = pd.DataFrame(pd.Series([encoded_audio], name="audio_file")) 2488 split_dict = {"dataframe_split": inference_df.to_dict(orient="split")} 2489 payload = json.dumps(split_dict) 2490 response = endpoint.invoke(payload, content_type=content_type) 2491 _assert_response(response) 2492 2493 # Test payload with "dataframe_records" key 2494 records_dict = {"dataframe_records": inference_df.to_dict(orient="records")} 2495 payload = json.dumps(records_dict) 2496 response = endpoint.invoke(payload, content_type=content_type) 2497 _assert_response(response) 2498 2499 # Test batch prediction 2500 inputs_dict = {"inputs": [encoded_audio, encoded_audio]} 2501 payload = json.dumps(inputs_dict) 2502 response = endpoint.invoke(payload, content_type=content_type) 2503 _assert_response(response, length=2) 2504 2505 # Scoring with audio file URI is not supported yet (pyfunc prediction works, though) 2506 inputs_dict = {"inputs": [read_audio_data("file")]} 2507 payload = json.dumps(inputs_dict) 2508 response = endpoint.invoke(payload, content_type=content_type) 2509 response = json.loads(response.content.decode("utf-8")) 2510 assert
response["error_code"] == "INVALID_PARAMETER_VALUE" 2511 assert "Failed to process the input audio data. Either" in response["message"] 2512 2513 2514 # https://github.com/huggingface/transformers/commit/9c500015c556f9ddf6e7a7449d3f46b2e3ff8ea5 2515 # caused a regression in beam search. 2516 # https://github.com/huggingface/transformers/commit/a6b51e7341d702127a4a45f37439640840b5abf0 2517 # fixed the regression but has not been released yet as of May 30, 2025. 2518 @pytest.mark.skipif( 2519 Version("4.52.0") <= Version(transformers.__version__) < Version("4.53.0"), 2520 reason="Transformers 4.52 has a beam search bug in the whisper implementation", 2521 ) 2522 def test_whisper_model_support_timestamps(whisper_pipeline): 2523 # Request payload to the model serving endpoint contains base64 encoded audio data 2524 audio = read_audio_data("bytes") 2525 encoded_audio = base64.b64encode(audio).decode("ascii") 2526 2527 model_config = { 2528 "return_timestamps": "word", 2529 "chunk_length_s": 20, 2530 "stride_length_s": [5, 3], 2531 } 2532 2533 with mlflow.start_run(): 2534 model_info = mlflow.transformers.log_model( 2535 whisper_pipeline, 2536 name="whisper_timestamps", 2537 model_config=model_config, 2538 input_example=(audio, model_config), 2539 ) 2540 2541 # Native transformers prediction as ground truth 2542 gt = whisper_pipeline(audio, **model_config) 2543 2544 def _assert_prediction(pred): 2545 assert pred["text"] == gt["text"] 2546 assert len(pred["chunks"]) == len(gt["chunks"]) 2547 for pred_chunk, gt_chunk in zip(pred["chunks"], gt["chunks"]): 2548 assert pred_chunk["text"] == gt_chunk["text"] 2549 # Timestamps are tuples, but are converted to lists when serialized to JSON. 2550 assert tuple(pred_chunk["timestamp"]) == gt_chunk["timestamp"] 2551 2552 # Prediction with Pyfunc 2553 loaded_pyfunc = mlflow.pyfunc.load_model(model_info.model_uri) 2554 prediction = json.loads(loaded_pyfunc.predict(audio)[0]) 2555 _assert_prediction(prediction) 2556 2557 # Serve and score 2558 with pyfunc_scoring_endpoint( 2559 model_info.model_uri, 2560 extra_args=["--env-manager", "local"], 2561 ) as endpoint: 2562 content_type = pyfunc_scoring_server.CONTENT_TYPE_JSON 2563 payload = json.dumps({"inputs": [encoded_audio]}) 2564 response = endpoint.invoke(payload, content_type=content_type) 2565 2566 predictions = json.loads(response.content.decode("utf-8"))["predictions"] 2567 # When return_timestamps is specified, the predictions list contains json 2568 # serialized output from the pipeline. 2569 _assert_prediction(json.loads(predictions[0])) 2570 2571 # Request with inference params 2572 payload = json.dumps({ 2573 "inputs": [encoded_audio], 2574 "model_config": model_config, 2575 }) 2576 response = endpoint.invoke(payload, content_type=content_type) 2577 predictions = json.loads(response.content.decode("utf-8"))["predictions"] 2578 _assert_prediction(json.loads(predictions[0])) 2579 2580 2581 def test_whisper_model_pyfunc_with_malformed_input(whisper_pipeline, model_path): 2582 mlflow.transformers.save_model( 2583 transformers_model=whisper_pipeline, 2584 path=model_path, 2585 save_pretrained=False, 2586 ) 2587 2588 pyfunc_model = mlflow.pyfunc.load_model(model_path) 2589 2590 invalid_audio = b"This isn't a real audio file" 2591 with pytest.raises(MlflowException, match="Failed to process the input audio data. Either"): 2592 pyfunc_model.predict([invalid_audio]) 2593 2594 bad_uri_msg = "An invalid string input was provided. 
String" 2595 2596 with pytest.raises(MlflowException, match=bad_uri_msg): 2597 pyfunc_model.predict("An invalid path") 2598 2599 with pytest.raises(MlflowException, match=bad_uri_msg): 2600 pyfunc_model.predict("//www.invalid.net/audio.wav") 2601 2602 with pytest.raises(MlflowException, match=bad_uri_msg): 2603 pyfunc_model.predict("https:///my/audio.mp3") 2604 2605 2606 @pytest.mark.parametrize("with_input_example", [True, False]) 2607 def test_audio_classification_pipeline(audio_classification_pipeline, with_input_example): 2608 audio = read_audio_data("bytes") 2609 2610 with mlflow.start_run(): 2611 model_info = mlflow.transformers.log_model( 2612 audio_classification_pipeline, 2613 name="audio_classification", 2614 input_example=audio if with_input_example else None, 2615 save_pretrained=False, 2616 ) 2617 2618 inference_payload = json.dumps({"inputs": [base64.b64encode(audio).decode("ascii")]}) 2619 2620 response = pyfunc_serve_and_score_model( 2621 model_info.model_uri, 2622 data=inference_payload, 2623 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 2624 extra_args=["--env-manager", "local"], 2625 ) 2626 2627 values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions() 2628 assert isinstance(values, pd.DataFrame) 2629 assert len(values) > 1 2630 assert list(values.columns) == ["score", "label"] 2631 2632 2633 @pytest.mark.parametrize( 2634 "model_name", 2635 [ 2636 "tiiuae/falcon-7b", 2637 "openai-community/gpt2", 2638 "PrunaAI/runwayml-stable-diffusion-v1-5-turbo-tiny-green-smashed", 2639 ], 2640 ) 2641 def test_save_model_card_with_non_utf_characters(tmp_path, model_name): 2642 # non-ascii unicode characters 2643 test_text = ( 2644 "Emoji testing! \u2728 \U0001f600 \U0001f609 \U0001f606 " 2645 "\U0001f970 \U0001f60e \U0001f917 \U0001f9d0" 2646 ) 2647 2648 card_data: ModelCard = huggingface_hub.ModelCard.load(model_name) 2649 card_data.text = card_data.text + "\n\n" + test_text 2650 custom_data = card_data.data.to_dict() 2651 custom_data["emojis"] = test_text 2652 2653 card_data.data = huggingface_hub.CardData(**custom_data) 2654 _write_card_data(card_data, tmp_path) 2655 2656 txt = tmp_path.joinpath(_CARD_TEXT_FILE_NAME).read_text() 2657 assert txt == card_data.text 2658 data = yaml.safe_load(tmp_path.joinpath(_CARD_DATA_FILE_NAME).read_text()) 2659 assert data == card_data.data.to_dict() 2660 2661 2662 def test_vision_pipeline_pyfunc_predict_with_kwargs(small_vision_model): 2663 artifact_path = "image_classification_model" 2664 2665 parameters = { 2666 "top_k": 2, 2667 } 2668 inference_payload = json.dumps({ 2669 "inputs": [image_url], 2670 "params": parameters, 2671 }) 2672 2673 with mlflow.start_run(): 2674 model_info = mlflow.transformers.log_model( 2675 small_vision_model, 2676 name=artifact_path, 2677 signature=infer_signature( 2678 image_url, 2679 mlflow.transformers.generate_signature_output(small_vision_model, image_url), 2680 params=parameters, 2681 ), 2682 ) 2683 model_uri = model_info.model_uri 2684 transformers_loaded_model = mlflow.transformers.load_model(model_uri) 2685 expected_predictions = transformers_loaded_model.predict(image_url) 2686 2687 response = pyfunc_serve_and_score_model( 2688 model_uri, 2689 data=inference_payload, 2690 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 2691 extra_args=["--env-manager", "local"], 2692 ) 2693 2694 predictions = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions() 2695 2696 assert ( 2697 list(predictions.to_dict("records")[0].values()) 2698 == 
expected_predictions[: parameters["top_k"]] 2699 ) 2700 2701 2702 def test_qa_pipeline_pyfunc_predict_with_kwargs(small_qa_pipeline): 2703 artifact_path = "qa_model" 2704 data = { 2705 "question": [ 2706 "What color is it?", 2707 "What does the 'wolf' howl at?", 2708 ], 2709 "context": [ 2710 "Some people said it was green but I know that it's pink.", 2711 "The pack of 'wolves' stood on the cliff and a 'lone wolf' howled at " 2712 "the moon for hours.", 2713 ], 2714 } 2715 parameters = { 2716 "top_k": 2, 2717 "max_answer_len": 5, 2718 } 2719 inference_payload = json.dumps({ 2720 "inputs": data, 2721 "params": parameters, 2722 }) 2723 output = mlflow.transformers.generate_signature_output(small_qa_pipeline, data) 2724 signature_with_params = infer_signature( 2725 data, 2726 output, 2727 parameters, 2728 ) 2729 expected_signature = ModelSignature( 2730 Schema([ 2731 ColSpec(Array(DataType.string), name="question"), 2732 ColSpec(Array(DataType.string), name="context"), 2733 ]), 2734 Schema([ColSpec(DataType.string)]), 2735 ParamSchema([ 2736 ParamSpec("top_k", DataType.long, 2), 2737 ParamSpec("max_answer_len", DataType.long, 5), 2738 ]), 2739 ) 2740 assert signature_with_params == expected_signature 2741 2742 with mlflow.start_run(): 2743 model_info = mlflow.transformers.log_model( 2744 small_qa_pipeline, 2745 name=artifact_path, 2746 signature=signature_with_params, 2747 ) 2748 2749 response = pyfunc_serve_and_score_model( 2750 model_info.model_uri, 2751 data=inference_payload, 2752 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 2753 extra_args=["--env-manager", "local"], 2754 ) 2755 values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions() 2756 2757 assert values.to_dict(orient="records") == [ 2758 {0: "pink"}, 2759 {0: "pink."}, 2760 {0: "the moon"}, 2761 {0: "moon"}, 2762 ] 2763 2764 2765 def test_uri_directory_renaming_handling_pipeline(model_path, text_classification_pipeline): 2766 with mlflow.start_run(): 2767 mlflow.transformers.save_model( 2768 transformers_model=text_classification_pipeline, path=model_path 2769 ) 2770 2771 absolute_model_directory = os.path.join(model_path, "model") 2772 renamed_to_old_convention = os.path.join(model_path, "pipeline") 2773 os.rename(absolute_model_directory, renamed_to_old_convention) 2774 2775 # remove the 'model_binary' entries to emulate older versions of MLflow 2776 mlmodel_file = os.path.join(model_path, "MLmodel") 2777 with open(mlmodel_file) as yaml_file: 2778 mlmodel = yaml.safe_load(yaml_file) 2779 2780 mlmodel["flavors"]["python_function"].pop("model_binary", None) 2781 mlmodel["flavors"]["transformers"].pop("model_binary", None) 2782 2783 with open(mlmodel_file, "w") as yaml_file: 2784 yaml.safe_dump(mlmodel, yaml_file) 2785 2786 loaded_model = mlflow.pyfunc.load_model(model_path) 2787 2788 prediction = loaded_model.predict("test") 2789 assert isinstance(prediction, pd.DataFrame) 2790 assert isinstance(prediction["label"][0], str) 2791 2792 2793 def test_uri_directory_renaming_handling_components(model_path, text_classification_pipeline): 2794 components = { 2795 "tokenizer": text_classification_pipeline.tokenizer, 2796 "model": text_classification_pipeline.model, 2797 } 2798 2799 with mlflow.start_run(): 2800 mlflow.transformers.save_model(transformers_model=components, path=model_path) 2801 2802 absolute_model_directory = os.path.join(model_path, "model") 2803 renamed_to_old_convention = os.path.join(model_path, "pipeline") 2804 os.rename(absolute_model_directory, renamed_to_old_convention) 
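    # Renaming the serialized weights directory from "model" to the legacy "pipeline" name
    # emulates the layout written by older versions of MLflow, so the load below must fall
    # back to the old directory resolution convention.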
2805 2806 # remove the 'model_binary' entries to emulate older versions of MLflow 2807 mlmodel_file = os.path.join(model_path, "MLmodel") 2808 with open(mlmodel_file) as yaml_file: 2809 mlmodel = yaml.safe_load(yaml_file) 2810 2811 mlmodel["flavors"]["python_function"].pop("model_binary", None) 2812 mlmodel["flavors"]["transformers"].pop("model_binary", None) 2813 2814 with open(mlmodel_file, "w") as yaml_file: 2815 yaml.safe_dump(mlmodel, yaml_file) 2816 2817 loaded_model = mlflow.pyfunc.load_model(model_path) 2818 2819 prediction = loaded_model.predict("test") 2820 assert isinstance(prediction, pd.DataFrame) 2821 assert isinstance(prediction["label"][0], str) 2822 2823 2824 @skip_transformers_v5_or_later 2825 def test_pyfunc_model_log_load_with_artifacts_snapshot(): 2826 architecture = "prajjwal1/bert-tiny" 2827 tokenizer = transformers.AutoTokenizer.from_pretrained(architecture) 2828 model = transformers.BertForQuestionAnswering.from_pretrained(architecture) 2829 bert_tiny_pipeline = transformers.pipeline( 2830 task="question-answering", model=model, tokenizer=tokenizer 2831 ) 2832 2833 class QAModel(mlflow.pyfunc.PythonModel): 2834 def load_context(self, context): 2835 """ 2836 This method initializes the tokenizer and language model 2837 using the specified snapshot location. 2838 """ 2839 snapshot_location = context.artifacts["bert-tiny-model"].removeprefix("hf:/") 2840 # Initialize tokenizer and language model 2841 tokenizer = transformers.AutoTokenizer.from_pretrained(snapshot_location) 2842 model = transformers.BertForQuestionAnswering.from_pretrained(snapshot_location) 2843 self.pipeline = transformers.pipeline( 2844 task="question-answering", model=model, tokenizer=tokenizer 2845 ) 2846 2847 def predict(self, context, model_input, params=None): 2848 question = model_input["question"][0] 2849 if isinstance(question, np.ndarray): 2850 question = question.item() 2851 ctx = model_input["context"][0] 2852 if isinstance(ctx, np.ndarray): 2853 ctx = ctx.item() 2854 return self.pipeline(question=question, context=ctx) 2855 2856 data = {"question": "Who's house?", "context": "The house is owned by Run."} 2857 pyfunc_artifact_path = "question_answering_model" 2858 with mlflow.start_run(): 2859 model_info = mlflow.pyfunc.log_model( 2860 name=pyfunc_artifact_path, 2861 python_model=QAModel(), 2862 artifacts={"bert-tiny-model": "hf:/prajjwal1/bert-tiny"}, 2863 input_example=data, 2864 signature=infer_signature( 2865 data, mlflow.transformers.generate_signature_output(bert_tiny_pipeline, data) 2866 ), 2867 extra_pip_requirements=["transformers", "torch", "numpy"], 2868 ) 2869 2870 pyfunc_model_path = _download_artifact_from_uri(model_info.model_uri) 2871 assert len(os.listdir(os.path.join(pyfunc_model_path, "artifacts"))) != 0 2872 model_config = Model.load(os.path.join(pyfunc_model_path, "MLmodel")) 2873 2874 loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri) 2875 assert model_config.to_yaml() == loaded_pyfunc_model.metadata.to_yaml() 2876 assert loaded_pyfunc_model.predict(data)["answer"] != "" 2877 2878 # Test model serving 2879 inference_payload = json.dumps({"inputs": data}) 2880 response = pyfunc_serve_and_score_model( 2881 model_info.model_uri, 2882 data=inference_payload, 2883 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 2884 extra_args=["--env-manager", "local"], 2885 ) 2886 values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions() 2887 2888 assert values.to_dict(orient="records")[0]["answer"] != "" 2889 2890 2891 
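# NB: The "hf:/<repo-id>" artifact URI used above instructs log_model to download a snapshot of
# the referenced Hugging Face Hub repository into the model's artifacts directory. The test below
# covers the failure mode when that snapshot download cannot be performed for an invalid repo id.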
def test_pyfunc_model_log_load_with_artifacts_snapshot_errors(): 2892 class TestModel(mlflow.pyfunc.PythonModel): 2893 def predict(self, context, model_input, params=None): 2894 return model_input 2895 2896 with mlflow.start_run(): 2897 with pytest.raises( 2898 MlflowException, 2899 match=r"Failed to download snapshot from Hugging Face Hub " 2900 r"with artifact_uri: hf:/invalid-repo-id.", 2901 ): 2902 mlflow.pyfunc.log_model( 2903 name="pyfunc_artifact_path", 2904 python_model=TestModel(), 2905 artifacts={"some-model": "hf:/invalid-repo-id"}, 2906 ) 2907 2908 2909 def test_model_distributed_across_devices(): 2910 mock_model = mock.Mock() 2911 mock_model.device.type = "meta" 2912 mock_model.hf_device_map = { 2913 "layer1": mock.Mock(type="cpu"), 2914 "layer2": mock.Mock(type="cpu"), 2915 "layer3": mock.Mock(type="gpu"), 2916 "layer4": mock.Mock(type="disk"), 2917 } 2918 2919 assert _is_model_distributed_in_memory(mock_model) 2920 2921 2922 def test_model_on_single_device(): 2923 mock_model = mock.Mock() 2924 mock_model.device.type = "cpu" 2925 mock_model.hf_device_map = {} 2926 2927 assert not _is_model_distributed_in_memory(mock_model) 2928 2929 2930 @skip_transformers_v5_or_later 2931 def test_basic_model_with_accelerate_device_mapping_fails_save(tmp_path, model_path): 2932 task = "translation_en_to_de" 2933 architecture = "t5-small" 2934 model = transformers.T5ForConditionalGeneration.from_pretrained( 2935 pretrained_model_name_or_path=architecture, 2936 device_map={"shared": "cpu", "encoder": "cpu", "decoder": "disk", "lm_head": "disk"}, 2937 offload_folder=str(tmp_path / "weights"), 2938 low_cpu_mem_usage=True, 2939 ) 2940 2941 tokenizer = transformers.T5TokenizerFast.from_pretrained( 2942 pretrained_model_name_or_path=architecture, model_max_length=100 2943 ) 2944 pipeline = transformers.pipeline(task=task, model=model, tokenizer=tokenizer) 2945 2946 with pytest.raises( 2947 MlflowException, 2948 match="The model that is attempting to be saved has been loaded into memory", 2949 ): 2950 mlflow.transformers.save_model(transformers_model=pipeline, path=model_path) 2951 2952 2953 @pytest.mark.skipif( 2954 Version(transformers.__version__) > Version("4.44.2"), 2955 reason="Multi-task pipeline (t5) has a loading issue with Transformers 4.45.x. 
" 2956 "See https://github.com/huggingface/transformers/issues/33398 for more details.", 2957 ) 2958 def test_basic_model_with_accelerate_homogeneous_mapping_works(model_path): 2959 task = "translation_en_to_de" 2960 architecture = "t5-small" 2961 model = transformers.T5ForConditionalGeneration.from_pretrained( 2962 pretrained_model_name_or_path=architecture, 2963 device_map={"shared": "cpu", "encoder": "cpu", "decoder": "cpu", "lm_head": "cpu"}, 2964 low_cpu_mem_usage=True, 2965 ) 2966 2967 tokenizer = transformers.T5TokenizerFast.from_pretrained( 2968 pretrained_model_name_or_path=architecture, model_max_length=100 2969 ) 2970 pipeline = transformers.pipeline(task=task, model=model, tokenizer=tokenizer) 2971 2972 mlflow.transformers.save_model(transformers_model=pipeline, path=model_path) 2973 2974 loaded = mlflow.transformers.load_model(model_path) 2975 text = "Apples are delicious" 2976 assert loaded(text) == pipeline(text) 2977 2978 2979 def test_qa_model_model_size_bytes(small_qa_pipeline, tmp_path): 2980 def _calculate_expected_size(path_or_dir): 2981 # this helper function does not consider subdirectories 2982 expected_size = 0 2983 if path_or_dir.is_dir(): 2984 for path in path_or_dir.iterdir(): 2985 if not path.is_file(): 2986 continue 2987 expected_size += path.stat().st_size 2988 elif path_or_dir.is_file(): 2989 expected_size = path_or_dir.stat().st_size 2990 return expected_size 2991 2992 mlflow.transformers.save_model( 2993 transformers_model=small_qa_pipeline, 2994 path=tmp_path, 2995 ) 2996 2997 # expected size only counts for files saved before the MLmodel file is saved 2998 model_dir = tmp_path.joinpath("model") 2999 tokenizer_dir = tmp_path.joinpath("components").joinpath("tokenizer") 3000 expected_size = 0 3001 for folder in [model_dir, tokenizer_dir]: 3002 expected_size += _calculate_expected_size(folder) 3003 other_files = ["model_card.md", "model_card_data.yaml", "LICENSE.txt"] 3004 for file in other_files: 3005 path = tmp_path.joinpath(file) 3006 expected_size += _calculate_expected_size(path) 3007 3008 mlmodel = yaml.safe_load(tmp_path.joinpath("MLmodel").read_bytes()) 3009 assert mlmodel["model_size_bytes"] == expected_size 3010 3011 3012 @pytest.mark.parametrize( 3013 ("task", "input_example"), 3014 [ 3015 ("llm/v1/completions", None), 3016 ("llm/v1/chat", None), 3017 ( 3018 "llm/v1/completions", 3019 { 3020 "prompt": "How to learn Python in 3 weeks?", 3021 "max_tokens": 10, 3022 "stop": "Python", 3023 }, 3024 ), 3025 ( 3026 "llm/v1/chat", 3027 { 3028 "messages": [ 3029 {"role": "system", "content": "Hello, how are you?"}, 3030 ], 3031 "temperature": 0.5, 3032 "max_tokens": 50, 3033 }, 3034 ), 3035 ], 3036 ) 3037 def test_text_generation_save_model_with_inference_task( 3038 monkeypatch, task, input_example, text_generation_pipeline, model_path 3039 ): 3040 # Strictly raise error during requirements inference for testing purposes 3041 monkeypatch.setenv("MLFLOW_REQUIREMENTS_INFERENCE_RAISE_ERRORS", "true") 3042 3043 mlflow.transformers.save_model( 3044 transformers_model=text_generation_pipeline, 3045 path=model_path, 3046 task=task, 3047 input_example=input_example, 3048 ) 3049 3050 mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes()) 3051 flavor_config = mlmodel["flavors"]["transformers"] 3052 assert flavor_config["inference_task"] == task 3053 assert mlmodel["metadata"]["task"] == task 3054 3055 if input_example: 3056 saved_input_example = json.loads(model_path.joinpath("input_example.json").read_text()) 3057 assert saved_input_example == 
input_example 3058 3059 3060 def test_text_generation_save_model_with_invalid_inference_task( 3061 text_generation_pipeline, model_path 3062 ): 3063 with pytest.raises( 3064 MlflowException, match=r"The task provided is invalid.*Must be.*llm/v1/completions" 3065 ): 3066 mlflow.transformers.save_model( 3067 transformers_model=text_generation_pipeline, 3068 path=model_path, 3069 task="llm/v1/invalid", 3070 ) 3071 3072 3073 def test_text_generation_log_model_with_mismatched_task(text_generation_pipeline): 3074 with pytest.raises( 3075 MlflowException, match=r"LLM v1 task type 'llm/v1/chat' is specified in metadata, but" 3076 ): 3077 with mlflow.start_run(): 3078 mlflow.transformers.log_model( 3079 text_generation_pipeline, 3080 name="model", 3081 # Task argument and metadata task are different 3082 task=None, 3083 metadata={"task": "llm/v1/chat"}, 3084 ) 3085 3086 3087 def test_text_generation_task_completions_predict_with_max_tokens( 3088 text_generation_pipeline, model_path 3089 ): 3090 mlflow.transformers.save_model( 3091 transformers_model=text_generation_pipeline, 3092 path=model_path, 3093 task="llm/v1/completions", 3094 model_config={"max_tokens": 10}, 3095 ) 3096 3097 pyfunc_loaded = mlflow.pyfunc.load_model(model_path) 3098 3099 inference = pyfunc_loaded.predict( 3100 {"prompt": "How to learn Python in 3 weeks?", "max_tokens": 10}, 3101 ) 3102 3103 assert isinstance(inference[0], dict) 3104 assert inference[0]["model"] == "distilgpt2" 3105 assert inference[0]["object"] == "text_completion" 3106 assert ( 3107 inference[0]["choices"][0]["finish_reason"] == "length" 3108 and inference[0]["usage"]["completion_tokens"] == 10 3109 ) or ( 3110 inference[0]["choices"][0]["finish_reason"] == "stop" 3111 and inference[0]["usage"]["completion_tokens"] < 10 3112 ) 3113 3114 # Override model_config with runtime params 3115 inference = pyfunc_loaded.predict( 3116 {"prompt": "How to learn Python in 3 weeks?", "max_tokens": 5}, 3117 ) 3118 assert 6 > inference[0]["usage"]["completion_tokens"] > 0 3119 3120 3121 def test_text_generation_task_completions_predict_with_stop(text_generation_pipeline, model_path): 3122 mlflow.transformers.save_model( 3123 transformers_model=text_generation_pipeline, 3124 path=model_path, 3125 task="llm/v1/completions", 3126 metadata={"foo": "bar"}, 3127 model_config={"stop": ["Python"], "max_tokens": 50}, 3128 ) 3129 3130 pyfunc_loaded = mlflow.pyfunc.load_model(model_path) 3131 inference = pyfunc_loaded.predict( 3132 {"prompt": "How to learn Python in 3 weeks?"}, 3133 ) 3134 3135 if "Python" not in inference[0]["choices"][0]["text"]: 3136 pytest.skip( 3137 "Model did not generate text containing 'Python', " 3138 "skipping validation of stop parameter in inference" 3139 ) 3140 3141 assert ( 3142 inference[0]["choices"][0]["finish_reason"] == "stop" 3143 and inference[0]["usage"]["completion_tokens"] < 50 3144 ) or ( 3145 inference[0]["choices"][0]["finish_reason"] == "length" 3146 and inference[0]["usage"]["completion_tokens"] == 50 3147 ) 3148 3149 assert inference[0]["choices"][0]["text"].endswith("Python") 3150 3151 # Override model_config with runtime params 3152 inference = pyfunc_loaded.predict( 3153 {"prompt": "How to learn Python in 3 weeks?", "stop": ["Abracadabra"]}, 3154 ) 3155 3156 response_text = inference[0]["choices"][0]["text"] 3157 3158 # Only check for early stopping if we stop on the word "Python". 
3159 # If we exhaust the token limit, there is a non-insignificant chance of 3160 # terminating on the word due to max tokens, which should not count as 3161 # a stop word abort if there are multiple instances of the word in the text. 3162 if 0 < response_text.count("Python") < 2: 3163 assert not inference[0]["choices"][0]["text"].endswith("Python") 3164 3165 3166 def test_text_generation_task_completions_serve(text_generation_pipeline): 3167 data = {"prompt": "How to learn Python in 3 weeks?"} 3168 3169 with mlflow.start_run(): 3170 model_info = mlflow.transformers.log_model( 3171 text_generation_pipeline, 3172 name="model", 3173 task="llm/v1/completions", 3174 ) 3175 3176 inference_payload = json.dumps({"inputs": data}) 3177 3178 response = pyfunc_serve_and_score_model( 3179 model_info.model_uri, 3180 data=inference_payload, 3181 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 3182 extra_args=["--env-manager", "local"], 3183 ) 3184 values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions() 3185 output_dict = values.to_dict("records")[0] 3186 assert output_dict["choices"][0]["text"] is not None 3187 assert output_dict["choices"][0]["finish_reason"] == "stop" 3188 assert output_dict["usage"]["prompt_tokens"] < 20 3189 3190 3191 def test_llm_v1_task_embeddings_predict(feature_extraction_pipeline, model_path): 3192 mlflow.transformers.save_model( 3193 transformers_model=feature_extraction_pipeline, 3194 path=model_path, 3195 input_examples=["Football", "Soccer"], 3196 task="llm/v1/embeddings", 3197 ) 3198 3199 mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes()) 3200 3201 flavor_config = mlmodel["flavors"]["transformers"] 3202 assert flavor_config["inference_task"] == "llm/v1/embeddings" 3203 assert mlmodel["metadata"]["task"] == "llm/v1/embeddings" 3204 3205 pyfunc_loaded = mlflow.pyfunc.load_model(model_path) 3206 3207 # Predict with single string input 3208 prediction = pyfunc_loaded.predict({"input": "A great day"}) 3209 assert prediction["object"] == "list" 3210 assert len(prediction["data"]) == 1 3211 assert prediction["data"][0]["object"] == "embedding" 3212 assert prediction["usage"]["prompt_tokens"] == 5 3213 assert len(prediction["data"][0]["embedding"]) == 384 3214 3215 # Predict with list of string input 3216 prediction = pyfunc_loaded.predict({"input": ["A great day", "A bad day"]}) 3217 assert prediction["object"] == "list" 3218 assert len(prediction["data"]) == 2 3219 assert prediction["data"][0]["object"] == "embedding" 3220 assert prediction["usage"]["prompt_tokens"] == 10 3221 assert len(prediction["data"][0]["embedding"]) == 384 3222 3223 # Predict with pandas dataframe input 3224 df = pd.DataFrame({"input": ["A great day", "A bad day", "A good day"]}) 3225 prediction = pyfunc_loaded.predict(df) 3226 assert prediction["object"] == "list" 3227 assert len(prediction["data"]) == 3 3228 assert prediction["data"][0]["object"] == "embedding" 3229 assert prediction["usage"]["prompt_tokens"] == 15 3230 assert len(prediction["data"][0]["embedding"]) == 384 3231 3232 3233 @pytest.mark.parametrize( 3234 "request_payload", 3235 [ 3236 {"input": "A single string"}, 3237 { 3238 "inputs": {"input": ["A list of strings"]}, 3239 }, 3240 ], 3241 ) 3242 def test_llm_v1_task_embeddings_serve(feature_extraction_pipeline, request_payload): 3243 with mlflow.start_run(): 3244 model_info = mlflow.transformers.log_model( 3245 feature_extraction_pipeline, 3246 name="model", 3247 input_examples=["Football", "Soccer"], 3248 
task="llm/v1/embeddings", 3249 ) 3250 3251 response = pyfunc_serve_and_score_model( 3252 model_info.model_uri, 3253 data=json.dumps(request_payload), 3254 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 3255 extra_args=["--env-manager", "local"], 3256 ) 3257 response = json.loads(response.content.decode("utf-8")) 3258 prediction = response["predictions"] if "inputs" in request_payload else response 3259 3260 assert prediction["object"] == "list" 3261 assert len(prediction["data"]) == 1 3262 assert prediction["data"][0]["object"] == "embedding" 3263 assert len(prediction["data"][0]["embedding"]) == 384 3264 3265 3266 def test_get_task_for_model(): 3267 with mock.patch("transformers.pipelines.get_task") as mock_get_task: 3268 mock_get_task.return_value = "feature-extraction" 3269 assert _get_task_for_model("model") == "feature-extraction" 3270 3271 # Some model task is not supported by Transformers pipeline yet. Then fall back 3272 # to the default task if provided, otherwise raise an exception. 3273 mock_get_task.return_value = "unsupported-task" 3274 assert ( 3275 _get_task_for_model("model", default_task="feature-extraction") == "feature-extraction" 3276 ) 3277 3278 with pytest.raises(MlflowException, match="Cannot construct transformers pipeline"): 3279 _get_task_for_model("model") 3280 3281 # If get_task raises an exception, fall back to the default task if provided. 3282 mock_get_task.side_effect = RuntimeError("Some error") 3283 assert ( 3284 _get_task_for_model("model", default_task="feature-extraction") == "feature-extraction" 3285 ) 3286 3287 with pytest.raises(MlflowException, match="The task could not be inferred"): 3288 _get_task_for_model("model") 3289 3290 3291 @skip_transformers_v5_or_later 3292 def test_local_custom_model_save_and_load(text_generation_pipeline, model_path, tmp_path): 3293 local_repo_path = tmp_path / "local_repo" 3294 text_generation_pipeline.save_pretrained(local_repo_path) 3295 3296 locally_loaded_model = transformers.AutoModelForCausalLM.from_pretrained(local_repo_path) 3297 tokenizer = transformers.AutoTokenizer.from_pretrained( 3298 local_repo_path, chat_template=CHAT_TEMPLATE 3299 ) 3300 model_dict = {"model": locally_loaded_model, "tokenizer": tokenizer} 3301 3302 # 1. Save local custom model without specifying task -> raises MlflowException 3303 with pytest.raises(MlflowException, match=r"The task could not be inferred"): 3304 mlflow.transformers.save_model(transformers_model=model_dict, path=model_path) 3305 3306 # 2. Save local custom model with task -> saves successfully 3307 mlflow.transformers.save_model( 3308 transformers_model=model_dict, 3309 path=model_path, 3310 task="text-generation", 3311 ) 3312 3313 pyfunc_loaded = mlflow.pyfunc.load_model(model_path) 3314 3315 inference = pyfunc_loaded.predict("How to save Transformer model?") 3316 assert isinstance(inference[0], str) 3317 assert inference[0].startswith("How to save Transformer model?") 3318 3319 # 3. 
Save local custom model with LLM v1 chat inference task -> saves successfully 3320 # with the corresponding Transformers task 3321 shutil.rmtree(model_path) 3322 3323 mlflow.transformers.save_model( 3324 transformers_model=model_dict, 3325 path=model_path, 3326 task="llm/v1/chat", 3327 ) 3328 3329 mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes()) 3330 flavor_config = mlmodel["flavors"]["transformers"] 3331 assert flavor_config["task"] == "text-generation" 3332 assert flavor_config["inference_task"] == "llm/v1/chat" 3333 assert mlmodel["metadata"]["task"] == "llm/v1/chat" 3334 3335 pyfunc_loaded = mlflow.pyfunc.load_model(model_path) 3336 3337 inference = pyfunc_loaded.predict({ 3338 "messages": [ 3339 { 3340 "role": "user", 3341 "content": "How to save Transformer model?", 3342 } 3343 ] 3344 }) 3345 assert isinstance(inference[0], dict) 3346 assert inference[0]["choices"][0]["message"]["role"] == "assistant" 3347 3348 3349 def test_model_config_is_not_mutated_after_prediction(text2text_generation_pipeline): 3350 model_config = { 3351 "top_k": 2, 3352 "num_beams": 5, 3353 "max_length": 30, 3354 "max_new_tokens": 500, 3355 } 3356 3357 # Params will be used to override the values of model_config but should not mutate it 3358 params = { 3359 "top_k": 30, 3360 "max_length": 500, 3361 "max_new_tokens": 5, 3362 } 3363 3364 pyfunc_model = _TransformersWrapper(text2text_generation_pipeline, model_config=model_config) 3365 assert pyfunc_model.model_config["top_k"] == 2 3366 3367 prediction_output = pyfunc_model.predict( 3368 "rocket moon ship astronaut space gravity", params=params 3369 ) 3370 3371 assert pyfunc_model.model_config["top_k"] == 2 3372 assert pyfunc_model.model_config["num_beams"] == 5 3373 assert pyfunc_model.model_config["max_length"] == 30 3374 assert pyfunc_model.model_config["max_new_tokens"] == 500 3375 assert len(prediction_output[0].split(" ")) <= 5 3376 3377 3378 def test_text_generation_task_chat_predict(text_generation_pipeline, model_path): 3379 mlflow.transformers.save_model( 3380 transformers_model=text_generation_pipeline, 3381 path=model_path, 3382 task="llm/v1/chat", 3383 ) 3384 3385 pyfunc_loaded = mlflow.pyfunc.load_model(model_path) 3386 3387 inference = pyfunc_loaded.predict({ 3388 "messages": [ 3389 {"role": "system", "content": "Hello, how can I help you today?"}, 3390 {"role": "user", "content": "How to learn Python in 3 weeks?"}, 3391 ], 3392 "max_tokens": 10, 3393 }) 3394 3395 assert inference[0]["choices"][0]["message"]["role"] == "assistant" 3396 assert ( 3397 inference[0]["choices"][0]["finish_reason"] == "length" 3398 and inference[0]["usage"]["completion_tokens"] == 10 3399 ) or ( 3400 inference[0]["choices"][0]["finish_reason"] == "stop" 3401 and inference[0]["usage"]["completion_tokens"] < 10 3402 ) 3403 3404 3405 def test_text_generation_task_chat_serve(text_generation_pipeline): 3406 data = { 3407 "messages": [ 3408 {"role": "user", "content": "How to learn Python in 3 weeks?"}, 3409 ], 3410 "max_tokens": 10, 3411 } 3412 3413 with mlflow.start_run(): 3414 model_info = mlflow.transformers.log_model( 3415 text_generation_pipeline, 3416 name="model", 3417 task="llm/v1/chat", 3418 ) 3419 3420 inference_payload = json.dumps(data) 3421 3422 response = pyfunc_serve_and_score_model( 3423 model_info.model_uri, 3424 data=inference_payload, 3425 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 3426 extra_args=["--env-manager", "local"], 3427 ) 3428 3429 output_dict = json.loads(response.content)[0] 3430 assert 
output_dict["choices"][0]["message"] is not None 3431 assert ( 3432 output_dict["choices"][0]["finish_reason"] == "length" 3433 and output_dict["usage"]["completion_tokens"] == 10 3434 ) or ( 3435 output_dict["choices"][0]["finish_reason"] == "stop" 3436 and output_dict["usage"]["completion_tokens"] < 10 3437 ) 3438 assert output_dict["usage"]["prompt_tokens"] < 20 3439 3440 3441 HF_COMMIT_HASH_PATTERN = re.compile(r"^[a-z0-9]{40}$") 3442 3443 3444 @pytest.mark.parametrize( 3445 ("model_fixture", "input_example", "components"), 3446 [ 3447 ("text2text_generation_pipeline", "What is MLflow?", {"tokenizer"}), 3448 ("text_generation_pipeline", "What is MLflow?", {"tokenizer"}), 3449 ( 3450 "small_vision_model", 3451 image_url, 3452 {"image_processor"} if IS_NEW_FEATURE_EXTRACTION_API else {"feature_extractor"}, 3453 ), 3454 ( 3455 "component_multi_modal", 3456 {"text": "What is MLflow?", "image": image_url}, 3457 {"image_processor", "tokenizer"} 3458 if IS_NEW_FEATURE_EXTRACTION_API 3459 else {"feature_extractor", "tokenizer"}, 3460 ), 3461 ("fill_mask_pipeline", "The quick brown <mask> jumps over the lazy dog.", {"tokenizer"}), 3462 ("whisper_pipeline", lambda: read_audio_data("bytes"), {"feature_extractor", "tokenizer"}), 3463 ("feature_extraction_pipeline", "What is MLflow?", {"tokenizer"}), 3464 ], 3465 ) 3466 def test_save_and_load_pipeline_without_save_pretrained_false( 3467 model_fixture, input_example, components, model_path, request 3468 ): 3469 pipeline = request.getfixturevalue(model_fixture) 3470 model = pipeline["model"] if isinstance(pipeline, dict) else pipeline.model 3471 3472 mlflow.transformers.save_model( 3473 transformers_model=pipeline, 3474 path=model_path, 3475 save_pretrained=False, 3476 ) 3477 3478 # No weights should be saved 3479 assert not model_path.joinpath("model").exists() 3480 assert not model_path.joinpath("components").exists() 3481 3482 # Validate the contents of MLModel file 3483 mlmodel = Model.load(str(model_path.joinpath("MLmodel"))) 3484 flavor_conf = mlmodel.flavors["transformers"] 3485 assert "model_binary" not in flavor_conf 3486 assert flavor_conf["source_model_name"] == model.name_or_path 3487 assert HF_COMMIT_HASH_PATTERN.match(flavor_conf["source_model_revision"]) 3488 assert set(flavor_conf["components"]) == components 3489 for c in components: 3490 component = pipeline[c] if isinstance(pipeline, dict) else getattr(pipeline, c) 3491 assert flavor_conf[f"{c}_name"] == getattr(component, "name_or_path", model.name_or_path) 3492 assert HF_COMMIT_HASH_PATTERN.match(flavor_conf[f"{c}_revision"]) 3493 3494 # Validate pyfunc load and prediction (if pyfunc supported) 3495 if "python_function" in mlmodel.flavors: 3496 if callable(input_example): 3497 input_example = input_example() 3498 mlflow.pyfunc.load_model(model_path).predict(input_example) 3499 3500 3501 # Patch tempdir just to verify the invocation 3502 def test_persist_pretrained_model(small_qa_pipeline): 3503 with mlflow.start_run(): 3504 model_info = mlflow.transformers.log_model( 3505 small_qa_pipeline, 3506 name="model", 3507 save_pretrained=False, 3508 pip_requirements=["mlflow"], # For speed up logging 3509 ) 3510 3511 artifact_path = Path(mlflow.artifacts.download_artifacts(model_info.model_uri)) 3512 model_path = artifact_path / "model" 3513 tokenizer_path = artifact_path / "components" / "tokenizer" 3514 3515 original_config = Model.load(artifact_path).flavors["transformers"] 3516 assert "model_binary" not in original_config 3517 assert "source_model_revision" in original_config 3518 
assert not model_path.exists() 3519 assert not tokenizer_path.exists() 3520 3521 with mock.patch( 3522 "mlflow.transformers.TempDir", side_effect=mlflow.utils.file_utils.TempDir 3523 ) as mock_tmpdir: 3524 mlflow.transformers.persist_pretrained_model(model_info.model_uri) 3525 mock_tmpdir.assert_called_once() 3526 3527 updated_config = Model.load(model_info.model_uri).flavors["transformers"] 3528 assert "model_binary" in updated_config 3529 assert "source_model_revision" not in updated_config 3530 assert model_path.exists() 3531 model_path_files = list(model_path.iterdir()) 3532 assert len(model_path_files) > 0 3533 assert tokenizer_path.exists() 3534 assert (tokenizer_path / "tokenizer.json").exists() 3535 3536 # Persisting the model again should be a no-op 3537 with mock.patch( 3538 "mlflow.transformers.TempDir", side_effect=mlflow.utils.file_utils.TempDir 3539 ) as mock_tmpdir: 3540 mlflow.transformers.persist_pretrained_model(model_info.model_uri) 3541 mock_tmpdir.assert_not_called() 3542 3543 3544 def test_small_qa_pipeline_copy_metadata_in_databricks( 3545 mock_is_in_databricks, small_qa_pipeline, tmp_path 3546 ): 3547 artifact_path = "transformers" 3548 with mlflow.start_run(): 3549 model_info = mlflow.transformers.log_model( 3550 small_qa_pipeline, 3551 name=artifact_path, 3552 ) 3553 artifact_path = mlflow.artifacts.download_artifacts( 3554 artifact_uri=model_info.model_uri, dst_path=tmp_path.as_posix() 3555 ) 3556 3557 # Metadata should be copied only in Databricks 3558 metadata_path = os.path.join(artifact_path, "metadata") 3559 if mock_is_in_databricks.return_value: 3560 assert set(os.listdir(metadata_path)) == set(METADATA_FILES) 3561 else: 3562 assert not os.path.exists(metadata_path) 3563 mock_is_in_databricks.assert_called_once() 3564 3565 3566 def test_peft_pipeline_copy_metadata_in_databricks(mock_is_in_databricks, peft_pipeline, tmp_path): 3567 artifact_path = "transformers" 3568 with mlflow.start_run(): 3569 model_info = mlflow.transformers.log_model( 3570 peft_pipeline, 3571 name=artifact_path, 3572 ) 3573 3574 artifact_path = mlflow.artifacts.download_artifacts( 3575 artifact_uri=model_info.model_uri, dst_path=tmp_path.as_posix() 3576 ) 3577 3578 # Metadata should be copied only in Databricks 3579 metadata_path = os.path.join(artifact_path, "metadata") 3580 if mock_is_in_databricks.return_value: 3581 assert set(os.listdir(metadata_path)) == set(METADATA_FILES) 3582 else: 3583 assert not os.path.exists(metadata_path) 3584 mock_is_in_databricks.assert_called_once() 3585 3586 3587 @pytest.mark.parametrize("device", ["cpu", "cuda", 0, -1, None]) 3588 def test_device_param_on_load_model(device, small_qa_pipeline, model_path, monkeypatch): 3589 mlflow.transformers.save_model(small_qa_pipeline, path=model_path) 3590 conf = mlflow.transformers.load_model(model_path, return_type="components", device=device) 3591 assert conf.get("device") == device 3592 3593 monkeypatch.setenv("MLFLOW_HUGGINGFACE_USE_DEVICE_MAP", "true") 3594 if device is None: 3595 conf = mlflow.transformers.load_model(model_path, return_type="components", device=device) 3596 assert conf.get("device") is None 3597 else: 3598 with pytest.raises( 3599 MlflowException, 3600 match=r"The environment variable MLFLOW_HUGGINGFACE_USE_DEVICE_MAP is set to True, " 3601 rf"but the `device` argument is provided with value {device}.", 3602 ): 3603 mlflow.transformers.load_model(model_path, return_type="components", device=device) 3604 3605 3606 @pytest.fixture 3607 def local_checkpoint_path(tmp_path): 3608 """ 3609 Fixture to
create a local model checkpoint for testing the fine-tuning scenario. 3610 """ 3611 model = transformers.AutoModelForCausalLM.from_pretrained("distilgpt2") 3612 3613 class DummyDataset(torch.utils.data.Dataset): 3614 def __getitem__(self, idx): 3615 pass 3616 3617 def __len__(self): 3618 return 1 3619 3620 # Create a trainer and save the model, without running the actual training 3621 training_args = transformers.TrainingArguments( 3622 output_dir=tmp_path / "result", 3623 num_train_epochs=1, 3624 per_device_train_batch_size=4, 3625 report_to="none", 3626 ) 3627 trainer = transformers.Trainer(model=model, args=training_args, train_dataset=DummyDataset()) 3628 3629 checkpoint_path = tmp_path / "checkpoint" 3630 trainer.save_model(checkpoint_path) 3631 3632 # The tokenizer should also be saved in the checkpoint 3633 tokenizer = transformers.AutoTokenizer.from_pretrained( 3634 # A chat template is required to test with the llm/v1/chat task 3635 "distilgpt2", 3636 chat_template=CHAT_TEMPLATE, 3637 ) 3638 tokenizer.save_pretrained(checkpoint_path) 3639 3640 return str(checkpoint_path) 3641 3642 3643 def test_save_model_from_local_checkpoint(model_path, local_checkpoint_path): 3644 with mock.patch("mlflow.transformers._logger") as mock_logger: 3645 mlflow.transformers.save_model( 3646 transformers_model=local_checkpoint_path, 3647 task="text-generation", 3648 path=model_path, 3649 input_example=["What is MLflow?"], 3650 ) 3651 3652 logged_info = Model.load(model_path) 3653 flavor_conf = logged_info.flavors["transformers"] 3654 assert flavor_conf["source_model_name"] == local_checkpoint_path 3655 assert flavor_conf["task"] == "text-generation" 3656 if not IS_TRANSFORMERS_V5_OR_LATER: 3657 assert flavor_conf["framework"] == "pt" 3658 assert flavor_conf["instance_type"] == "TextGenerationPipeline" 3659 expected_tokenizer_type = ( 3660 "GPT2Tokenizer" if IS_TRANSFORMERS_V5_OR_LATER else "GPT2TokenizerFast" 3661 ) 3662 assert flavor_conf["tokenizer_type"] == expected_tokenizer_type 3663 3664 # The default task signature should be used 3665 assert logged_info.signature.inputs == Schema([ColSpec(DataType.string)]) 3666 assert logged_info.signature.outputs == Schema([ColSpec(DataType.string)]) 3667 3668 # The default requirements should be used 3669 info_calls = mock_logger.info.call_args_list 3670 assert any("A local checkpoint path or PEFT model" in c[0][0] for c in info_calls) 3671 with model_path.joinpath("requirements.txt").open() as f: 3672 reqs = {req.split("==")[0] for req in f.read().split("\n")} 3673 assert reqs == {"mlflow", "accelerate", "transformers", "torch", "torchvision"} 3674 3675 # Load as a native pipeline 3676 loaded_pipeline = mlflow.transformers.load_model(model_path) 3677 assert isinstance(loaded_pipeline, transformers.TextGenerationPipeline) 3678 3679 query = "What is MLflow?"
3680 pred_native = loaded_pipeline(query)[0] 3681 assert pred_native["generated_text"].startswith(query) 3682 3683 # Load as pyfunc 3684 loaded_pyfunc = mlflow.pyfunc.load_model(model_path) 3685 pred_pyfunc = loaded_pyfunc.predict(query)[0] 3686 assert pred_pyfunc.startswith(query) 3687 3688 # Serve 3689 response = pyfunc_serve_and_score_model( 3690 model_path, 3691 data=json.dumps({"inputs": [query]}), 3692 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 3693 extra_args=["--env-manager", "local"], 3694 ) 3695 pred_serve = json.loads(response.content.decode("utf-8")) 3696 assert pred_serve["predictions"][0].startswith(query) 3697 3698 3699 @skip_transformers_v5_or_later 3700 def test_save_model_from_local_checkpoint_with_custom_tokenizer(model_path, local_checkpoint_path): 3701 # When a custom tokenizer is also saved in the checkpoint, MLflow should save and load it. 3702 tokenizer = transformers.AutoTokenizer.from_pretrained("distilroberta-base") 3703 tokenizer.add_special_tokens({"additional_special_tokens": ["<sushi>"]}) 3704 tokenizer.save_pretrained(local_checkpoint_path) 3705 3706 mlflow.transformers.save_model( 3707 transformers_model=local_checkpoint_path, 3708 path=model_path, 3709 task="text-generation", 3710 input_example=["What is MLflow?"], 3711 ) 3712 3713 # The custom tokenizer should be loaded 3714 loaded_pipeline = mlflow.transformers.load_model(model_path) 3715 tokenizer = loaded_pipeline.tokenizer 3716 assert tokenizer.special_tokens_map["additional_special_tokens"] == ["<sushi>"] 3717 3718 3719 def test_save_model_from_local_checkpoint_with_llm_inference_task( 3720 model_path, local_checkpoint_path 3721 ): 3722 mlflow.transformers.save_model( 3723 transformers_model=local_checkpoint_path, 3724 path=model_path, 3725 task="llm/v1/chat", 3726 input_example=["What is MLflow?"], 3727 ) 3728 3729 logged_info = Model.load(model_path) 3730 flavor_conf = logged_info.flavors["transformers"] 3731 assert flavor_conf["source_model_name"] == local_checkpoint_path 3732 assert flavor_conf["task"] == "text-generation" 3733 assert flavor_conf["inference_task"] == "llm/v1/chat" 3734 3735 # Load as pyfunc 3736 loaded_pyfunc = mlflow.pyfunc.load_model(model_path) 3737 response = loaded_pyfunc.predict({ 3738 "messages": [ 3739 {"role": "system", "content": "Hello, how can I help you today?"}, 3740 {"role": "user", "content": "What is MLflow?"}, 3741 ], 3742 }) 3743 assert response[0]["choices"][0]["message"]["role"] == "assistant" 3744 assert response[0]["choices"][0]["message"]["content"] is not None 3745 3746 3747 def test_save_model_from_local_checkpoint_invalid_arguments(model_path, local_checkpoint_path): 3748 with pytest.raises(MlflowException, match=r"The `task` argument must be specified"): 3749 mlflow.transformers.save_model( 3750 transformers_model=local_checkpoint_path, 3751 path=model_path, 3752 ) 3753 3754 with pytest.raises( 3755 MlflowException, match=r"The `save_pretrained` argument must be set to True" 3756 ): 3757 mlflow.transformers.save_model( 3758 transformers_model=local_checkpoint_path, 3759 path=model_path, 3760 task="fill-mask", 3761 save_pretrained=False, 3762 ) 3763 3764 with pytest.raises( 3765 MlflowException, 3766 match=r"The provided directory invalid path does not contain a config.json file.", 3767 ): 3768 mlflow.transformers.save_model( 3769 transformers_model="invalid path", 3770 path=model_path, 3771 task="fill-mask", 3772 ) 3773 3774 3775 @pytest.mark.parametrize( 3776 ("model_fixture", "should_skip_validation"), 3777 [ 3778 ("local_checkpoint_path", 
True), 3779 ("fill_mask_pipeline", False), 3780 ], 3781 ) 3782 def test_log_model_skip_validating_serving_input_for_local_checkpoint( 3783 model_fixture, 3784 should_skip_validation, 3785 tmp_path, 3786 request, 3787 ): 3788 # input to avoid expensive computation 3789 model = request.getfixturevalue(model_fixture) 3790 with mock.patch("mlflow.models.validate_serving_input") as mock_validate_input: 3791 # Ensure mlflow skips serving input validation for local checkpoint 3792 with mlflow.start_run(): 3793 model_info = mlflow.transformers.log_model( 3794 model, 3795 name="model", 3796 task="fill-mask", 3797 input_example=["How are you?"], 3798 ) 3799 3800 # Serving input should exist regardless of the skip validation 3801 mlflow_model = Model.load(model_info.model_uri) 3802 local_path = _download_artifact_from_uri(model_info.model_uri, output_path=tmp_path) 3803 serving_input = mlflow_model.get_serving_input(local_path) 3804 assert json.loads(serving_input) == {"inputs": ["How are you?"]} 3805 3806 if should_skip_validation: 3807 mock_validate_input.assert_not_called() 3808 else: 3809 mock_validate_input.assert_called_once()