/ tests / transformers / test_transformers_model_export.py
test_transformers_model_export.py
   1  import base64
   2  import gc
   3  import importlib.util
   4  import json
   5  import math
   6  import os
   7  import pathlib
   8  import re
   9  import shutil
  10  import textwrap
  11  from pathlib import Path
  12  from unittest import mock
  13  
  14  import huggingface_hub
  15  import librosa
  16  import numpy as np
  17  import pandas as pd
  18  import pytest
  19  import torch
  20  import transformers
  21  import yaml
  22  from datasets import load_dataset
  23  from huggingface_hub import ModelCard
  24  from packaging.version import Version
  25  
  26  import mlflow
  27  import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
  28  from mlflow import pyfunc
  29  from mlflow.deployments import PredictionsResponse
  30  from mlflow.exceptions import MlflowException
  31  from mlflow.models import Model, ModelSignature, infer_signature
  32  from mlflow.models.model import METADATA_FILES
  33  from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
  34  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
  35  from mlflow.transformers import (
  36      _CARD_DATA_FILE_NAME,
  37      _CARD_TEXT_FILE_NAME,
  38      _build_pipeline_from_model_input,
  39      _fetch_model_card,
  40      _get_task_for_model,
  41      _is_model_distributed_in_memory,
  42      _should_add_pyfunc_to_model,
  43      _TransformersWrapper,
  44      _try_import_conversational_pipeline,
  45      _validate_llm_inference_task_type,
  46      _write_card_data,
  47      _write_license_information,
  48      get_default_conda_env,
  49      get_default_pip_requirements,
  50  )
  51  from mlflow.types.schema import Array, ColSpec, DataType, ParamSchema, ParamSpec, Schema
  52  from mlflow.utils.environment import _mlflow_conda_env
  53  
  54  from tests.helper_functions import (
  55      _assert_pip_requirements,
  56      _compare_conda_env_requirements,
  57      _compare_logged_code_paths,
  58      _get_deps_from_requirement_file,
  59      _mlflow_major_version_string,
  60      assert_register_model_called_with_local_model_path,
  61      flaky,
  62      pyfunc_scoring_endpoint,
  63      pyfunc_serve_and_score_model,
  64  )
  65  from tests.transformers.helper import (
  66      CHAT_TEMPLATE,
  67      IS_NEW_FEATURE_EXTRACTION_API,
  68      IS_TRANSFORMERS_V5_OR_LATER,
  69  )
  70  from tests.transformers.test_transformers_peft_model import SKIP_IF_PEFT_NOT_AVAILABLE
  71  
  72  # NB: Some pipelines under test in this suite come very close or outright exceed the
  73  # default runner containers specs of 7GB RAM. Due to this inability to run the suite without
  74  # generating a SIGTERM Error (143), some tests are marked as local only.
  75  # See: https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted- \
  76  # runners#supported-runners-and-hardware-resources for instance specs.
  77  RUNNING_IN_GITHUB_ACTIONS = os.environ.get("GITHUB_ACTIONS") == "true"
  78  GITHUB_ACTIONS_SKIP_REASON = "Test consumes too much memory"
  79  
  80  skip_transformers_v5_or_later = pytest.mark.skipif(
  81      IS_TRANSFORMERS_V5_OR_LATER,
  82      reason="Incompatible API changes in transformers 5.x",
  83  )
  84  image_url = "https://raw.githubusercontent.com/mlflow/mlflow/master/tests/datasets/cat.png"
  85  image_file_path = pathlib.Path(pathlib.Path(__file__).parent.parent, "datasets", "cat.png")
  86  # Test that can only be run locally:
  87  # - Summarization pipeline tests
  88  # - TextClassifier pipeline tests
  89  # - Text2TextGeneration pipeline tests
  90  # - Conversational pipeline tests
  91  
  92  
  93  @pytest.fixture(autouse=True)
  94  def force_gc():
  95      # This reduces the memory pressure for the usage of the larger pipeline fixtures ~500MB - 1GB
  96      gc.disable()
  97      gc.collect()
  98      gc.set_threshold(0)
  99      gc.collect()
 100      gc.enable()
 101  
 102  
 103  @pytest.fixture
 104  def model_path(tmp_path):
 105      model_path = tmp_path.joinpath("model")
 106      yield model_path
 107  
 108      # Pytest keeps the temporary directory created by `tmp_path` fixture for 3 recent test sessions
 109      # by default. This is useful for debugging during local testing, but in CI it just wastes the
 110      # disk space.
 111      if os.environ.get("GITHUB_ACTIONS") == "true":
 112          shutil.rmtree(model_path, ignore_errors=True)
 113  
 114  
 115  @pytest.fixture
 116  def transformers_custom_env(tmp_path):
 117      conda_env = tmp_path.joinpath("conda_env.yml")
 118      _mlflow_conda_env(conda_env, additional_pip_deps=["transformers"])
 119      return conda_env
 120  
 121  
 122  @pytest.fixture
 123  def mock_pyfunc_wrapper():
 124      return mlflow.transformers._TransformersWrapper("mock")
 125  
 126  
 127  @pytest.fixture
 128  @flaky()
 129  def image_for_test():
 130      dataset = load_dataset("hf-internal-testing/dummy_image_text_data")
 131      return dataset["train"]["image"][3]
 132  
 133  
 134  @pytest.mark.parametrize(
 135      ("pipeline", "expected_requirements"),
 136      [
 137          ("small_qa_pipeline", {"transformers", "torch", "torchvision"}),
 138          pytest.param(
 139              "peft_pipeline",
 140              {"peft", "transformers", "torch", "torchvision"},
 141              marks=SKIP_IF_PEFT_NOT_AVAILABLE,
 142          ),
 143      ],
 144  )
 145  def test_default_requirements(pipeline, expected_requirements, request):
 146      if "torch" in expected_requirements and importlib.util.find_spec("accelerate"):
 147          expected_requirements.add("accelerate")
 148  
 149      model = request.getfixturevalue(pipeline).model
 150      pip_requirements = get_default_pip_requirements(model)
 151      conda_requirements = get_default_conda_env(model)["dependencies"][2]["pip"]
 152  
 153      def _strip_requirements(requirements):
 154          return {req.split("==")[0] for req in requirements}
 155  
 156      assert _strip_requirements(pip_requirements) == expected_requirements
 157      assert _strip_requirements(conda_requirements) == (expected_requirements | {"mlflow"})
 158  
 159  
 160  def test_inference_task_validation(small_qa_pipeline):
 161      with pytest.raises(
 162          MlflowException, match="The task provided is invalid. 'llm/v1/invalid' is not"
 163      ):
 164          _validate_llm_inference_task_type("llm/v1/invalid", "text-generation")
 165      with pytest.raises(
 166          MlflowException, match="The task provided is invalid. 'llm/v1/completions' is not"
 167      ):
 168          _validate_llm_inference_task_type("llm/v1/completions", small_qa_pipeline)
 169      _validate_llm_inference_task_type("llm/v1/completions", "text-generation")
 170  
 171  
 172  @pytest.mark.parametrize(
 173      ("model", "result"),
 174      [
 175          ("small_qa_pipeline", True),
 176          ("small_multi_modal_pipeline", False),
 177          ("small_vision_model", True),
 178      ],
 179  )
 180  def test_pipeline_eligibility_for_pyfunc_registration(model, result, request):
 181      pipeline = request.getfixturevalue(model)
 182      assert _should_add_pyfunc_to_model(pipeline) == result
 183  
 184  
 185  def test_component_multi_modal_model_ineligible_for_pyfunc(component_multi_modal):
 186      task = transformers.pipelines.get_task(component_multi_modal["model"].name_or_path)
 187      pipeline = _build_pipeline_from_model_input(component_multi_modal, task)
 188      assert not _should_add_pyfunc_to_model(pipeline)
 189  
 190  
 191  def test_pipeline_construction_from_base_nlp_model(small_qa_pipeline):
 192      generated = _build_pipeline_from_model_input(
 193          {"model": small_qa_pipeline.model, "tokenizer": small_qa_pipeline.tokenizer},
 194          "question-answering",
 195      )
 196      assert isinstance(generated, type(small_qa_pipeline))
 197      assert isinstance(generated.tokenizer, type(small_qa_pipeline.tokenizer))
 198  
 199  
 200  def test_pipeline_construction_from_base_vision_model(small_vision_model):
 201      model = {"model": small_vision_model.model, "tokenizer": small_vision_model.tokenizer}
 202      if IS_NEW_FEATURE_EXTRACTION_API:
 203          model.update({"image_processor": small_vision_model.image_processor})
 204      else:
 205          model.update({"feature_extractor": small_vision_model.feature_extractor})
 206      generated = _build_pipeline_from_model_input(model, task="image-classification")
 207      assert isinstance(generated, type(small_vision_model))
 208      assert isinstance(generated.tokenizer, type(small_vision_model.tokenizer))
 209      if IS_NEW_FEATURE_EXTRACTION_API:
 210          assert isinstance(generated.image_processor, type(small_vision_model.image_processor))
 211      else:
 212          assert isinstance(generated.feature_extractor, transformers.MobileNetV2ImageProcessor)
 213  
 214  
 215  def test_saving_with_invalid_dict_as_model(model_path):
 216      with pytest.raises(
 217          MlflowException, match="Invalid dictionary submitted for 'transformers_model'. The "
 218      ):
 219          mlflow.transformers.save_model(transformers_model={"invalid": "key"}, path=model_path)
 220  
 221      with pytest.raises(
 222          MlflowException, match="The 'transformers_model' dictionary must have an entry"
 223      ):
 224          mlflow.transformers.save_model(
 225              transformers_model={"tokenizer": "some_tokenizer"}, path=model_path
 226          )
 227  
 228  
 229  def test_model_card_acquisition_vision_model(small_vision_model):
 230      model_provided_card = _fetch_model_card(small_vision_model.model.name_or_path)
 231      assert model_provided_card.data.to_dict()["tags"] == ["vision", "image-classification"]
 232      assert len(model_provided_card.text) > 0
 233  
 234  
 235  @pytest.mark.parametrize(
 236      ("repo_id", "license_file"),
 237      [
 238          ("google/mobilenet_v2_1.0_224", "LICENSE.txt"),  # no license declared
 239          ("csarron/mobilebert-uncased-squad-v2", "LICENSE.txt"),  # mit license
 240          ("codellama/CodeLlama-34b-hf", "LICENSE"),  # custom license
 241          ("openai/whisper-tiny", "LICENSE.txt"),  # apache license
 242          ("stabilityai/stable-code-3b", "LICENSE"),  # custom
 243          ("mistralai/Mixtral-8x7B-Instruct-v0.1", "LICENSE.txt"),  # apache
 244      ],
 245  )
 246  def test_license_acquisition(repo_id, license_file, tmp_path):
 247      card_data = _fetch_model_card(repo_id)
 248      _write_license_information(repo_id, card_data, tmp_path)
 249      license_file = list(tmp_path.glob("*LICENSE*"))
 250      assert len(license_file) == 1
 251      assert tmp_path.joinpath(license_file[0]).stat().st_size > 0
 252  
 253  
 254  def test_license_fallback(tmp_path):
 255      _write_license_information("not a real repo", None, tmp_path)
 256      assert tmp_path.joinpath("LICENSE.txt").stat().st_size > 0
 257  
 258  
 259  def test_vision_model_save_pipeline_with_defaults(small_vision_model, model_path):
 260      mlflow.transformers.save_model(transformers_model=small_vision_model, path=model_path)
 261      # validate inferred pip requirements
 262      requirements = model_path.joinpath("requirements.txt").read_text()
 263      reqs = {req.split("==")[0] for req in requirements.split("\n")}
 264      expected_requirements = {"torch", "torchvision", "transformers"}
 265      assert reqs.intersection(expected_requirements) == expected_requirements
 266      # validate inferred model card data
 267      card_data = yaml.safe_load(model_path.joinpath("model_card_data.yaml").read_bytes())
 268      assert card_data["tags"] == ["vision", "image-classification"]
 269      # verify the license file has been written
 270      license_file = model_path.joinpath("LICENSE.txt").read_text()
 271      assert len(license_file) > 0
 272      # Validate inferred model card text
 273      with model_path.joinpath("model_card.md").open() as file:
 274          card_text = file.read()
 275      assert len(card_text) > 0
 276      # Validate conda.yaml
 277      conda_env = yaml.safe_load(model_path.joinpath("conda.yaml").read_bytes())
 278      assert {req.split("==")[0] for req in conda_env["dependencies"][2]["pip"]}.intersection(
 279          expected_requirements
 280      ) == expected_requirements
 281      # Validate the MLModel file
 282      mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes())
 283      flavor_config = mlmodel["flavors"]["transformers"]
 284      assert flavor_config["instance_type"] == "ImageClassificationPipeline"
 285      assert flavor_config["pipeline_model_type"] == "MobileNetV2ForImageClassification"
 286      assert flavor_config["task"] == "image-classification"
 287      assert flavor_config["source_model_name"] == "google/mobilenet_v2_1.0_224"
 288  
 289  
 290  def test_vision_model_save_model_for_task_and_card_inference(small_vision_model, model_path):
 291      mlflow.transformers.save_model(transformers_model=small_vision_model, path=model_path)
 292      # validate inferred pip requirements
 293      requirements = model_path.joinpath("requirements.txt").read_text()
 294      reqs = {req.split("==")[0] for req in requirements.split("\n")}
 295      expected_requirements = {"torch", "torchvision", "transformers"}
 296      assert reqs.intersection(expected_requirements) == expected_requirements
 297      # validate inferred model card data
 298      card_data = yaml.safe_load(model_path.joinpath("model_card_data.yaml").read_bytes())
 299      assert card_data["tags"] == ["vision", "image-classification"]
 300      # Validate inferred model card text
 301      card_text = model_path.joinpath("model_card.md").read_text(encoding="utf-8")
 302      assert len(card_text) > 0
 303      # verify the license file has been written
 304      license_file = model_path.joinpath("LICENSE.txt").read_text()
 305      assert len(license_file) > 0
 306      # Validate the MLModel file
 307      mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes())
 308      flavor_config = mlmodel["flavors"]["transformers"]
 309      assert flavor_config["instance_type"] == "ImageClassificationPipeline"
 310      assert flavor_config["pipeline_model_type"] == "MobileNetV2ForImageClassification"
 311      assert flavor_config["task"] == "image-classification"
 312      assert flavor_config["source_model_name"] == "google/mobilenet_v2_1.0_224"
 313  
 314  
 315  def test_qa_model_save_model_for_task_and_card_inference(small_qa_pipeline, model_path):
 316      mlflow.transformers.save_model(
 317          transformers_model={
 318              "model": small_qa_pipeline.model,
 319              "tokenizer": small_qa_pipeline.tokenizer,
 320          },
 321          path=model_path,
 322      )
 323      # validate inferred pip requirements
 324      with model_path.joinpath("requirements.txt").open() as file:
 325          requirements = file.read()
 326      reqs = {req.split("==")[0] for req in requirements.split("\n")}
 327      expected_requirements = {"torch", "transformers"}
 328      assert reqs.intersection(expected_requirements) == expected_requirements
 329      # validate that the card was acquired by model reference
 330      card_data = yaml.safe_load(model_path.joinpath("model_card_data.yaml").read_bytes())
 331      assert card_data["datasets"] == ["squad_v2"]
 332      assert "tags" in card_data
 333      # verify the license file has been written
 334      license_file = model_path.joinpath("LICENSE.txt").read_text()
 335      assert len(license_file) > 0
 336      # Validate inferred model card text
 337      with model_path.joinpath("model_card.md").open() as file:
 338          card_text = file.read()
 339      assert len(card_text) > 0
 340      # validate MLmodel files
 341      mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes())
 342      flavor_config = mlmodel["flavors"]["transformers"]
 343      assert flavor_config["instance_type"] == "QuestionAnsweringPipeline"
 344      assert flavor_config["pipeline_model_type"] == "MobileBertForQuestionAnswering"
 345      assert flavor_config["task"] == "question-answering"
 346      assert flavor_config["source_model_name"] == "csarron/mobilebert-uncased-squad-v2"
 347  
 348  
 349  def test_qa_model_save_and_override_card(small_qa_pipeline, model_path):
 350      supplied_card = """
 351                      ---
 352                      language: en
 353                      license: bsd
 354                      ---
 355  
 356                      # I made a new model!
 357                      """
 358      card_info = textwrap.dedent(supplied_card)
 359      card = ModelCard(card_info)
 360      # save the model instance
 361      mlflow.transformers.save_model(
 362          transformers_model=small_qa_pipeline,
 363          path=model_path,
 364          model_card=card,
 365      )
 366      # validate that the card was acquired by model reference
 367      card_data = yaml.safe_load(model_path.joinpath("model_card_data.yaml").read_bytes())
 368      assert card_data["language"] == "en"
 369      assert card_data["license"] == "bsd"
 370      # Validate inferred model card text
 371      with model_path.joinpath("model_card.md").open() as file:
 372          card_text = file.read()
 373      # verify the license file has been written
 374      license_file = model_path.joinpath("LICENSE.txt").read_text()
 375      assert len(license_file) > 0
 376      assert card_text.startswith("\n# I made a new model!")
 377      # validate MLmodel files
 378      mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes())
 379      flavor_config = mlmodel["flavors"]["transformers"]
 380      assert flavor_config["instance_type"] == "QuestionAnsweringPipeline"
 381      assert flavor_config["pipeline_model_type"] == "MobileBertForQuestionAnswering"
 382      assert flavor_config["task"] == "question-answering"
 383      assert flavor_config["source_model_name"] == "csarron/mobilebert-uncased-squad-v2"
 384  
 385  
 386  def test_basic_save_model_and_load_text_pipeline(text_classification_pipeline, model_path):
 387      mlflow.transformers.save_model(
 388          transformers_model={
 389              "model": text_classification_pipeline.model,
 390              "tokenizer": text_classification_pipeline.tokenizer,
 391          },
 392          path=model_path,
 393      )
 394      loaded = mlflow.transformers.load_model(model_path)
 395      result = loaded("MLflow is a really neat tool!")
 396      assert result[0]["label"] == "POSITIVE"
 397      assert result[0]["score"] > 0.5
 398  
 399  
 400  @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float64])
 401  def test_basic_save_model_with_torch_dtype(text2text_generation_pipeline, model_path, dtype):
 402      mlflow.transformers.save_model(
 403          transformers_model=text2text_generation_pipeline,
 404          path=model_path,
 405          torch_dtype=dtype,
 406      )
 407  
 408      loaded = mlflow.transformers.load_model(model_path)
 409      assert loaded.model.dtype == dtype
 410  
 411      loaded = mlflow.transformers.load_model(model_path, torch_dtype=torch.float32)
 412      assert loaded.model.dtype == torch.float32
 413  
 414  
 415  def test_basic_save_model_and_load_vision_pipeline(small_vision_model, model_path, image_for_test):
 416      if IS_NEW_FEATURE_EXTRACTION_API:
 417          model = {
 418              "model": small_vision_model.model,
 419              "image_processor": small_vision_model.image_processor,
 420              "tokenizer": small_vision_model.tokenizer,
 421          }
 422      else:
 423          model = {
 424              "model": small_vision_model.model,
 425              "feature_extractor": small_vision_model.feature_extractor,
 426              "tokenizer": small_vision_model.tokenizer,
 427          }
 428      mlflow.transformers.save_model(
 429          transformers_model=model,
 430          path=model_path,
 431      )
 432      loaded = mlflow.transformers.load_model(model_path)
 433      prediction = loaded(image_for_test)
 434      assert prediction[0]["label"] == "wall clock"
 435      assert prediction[0]["score"] > 0.5
 436  
 437  
 438  @flaky()
 439  def test_multi_modal_pipeline_save_and_load(small_multi_modal_pipeline, model_path, image_for_test):
 440      mlflow.transformers.save_model(transformers_model=small_multi_modal_pipeline, path=model_path)
 441      question = "How many wall clocks are in the picture?"
 442      # Load components
 443      components = mlflow.transformers.load_model(model_path, return_type="components")
 444      if IS_NEW_FEATURE_EXTRACTION_API:
 445          expected_components = {"model", "task", "tokenizer", "image_processor"}
 446      else:
 447          expected_components = {"model", "task", "tokenizer", "feature_extractor"}
 448      assert set(components.keys()).intersection(expected_components) == expected_components
 449      constructed_pipeline = transformers.pipeline(**components)
 450      answer = constructed_pipeline(image=image_for_test, question=question)
 451      assert answer[0]["answer"] == "1"
 452      # Load pipeline
 453      pipeline = mlflow.transformers.load_model(model_path)
 454      pipeline_answer = pipeline(image=image_for_test, question=question)
 455      assert pipeline_answer[0]["answer"] == "1"
 456      # Test invalid loading mode
 457      with pytest.raises(MlflowException, match="The specified return_type mode 'magic' is"):
 458          mlflow.transformers.load_model(model_path, return_type="magic")
 459  
 460  
 461  def test_multi_modal_component_save_and_load(component_multi_modal, model_path, image_for_test):
 462      if IS_NEW_FEATURE_EXTRACTION_API:
 463          processor = component_multi_modal["image_processor"]
 464      else:
 465          processor = component_multi_modal["feature_extractor"]
 466      mlflow.transformers.save_model(
 467          transformers_model=component_multi_modal,
 468          path=model_path,
 469          processor=processor,
 470      )
 471      # Ensure that the appropriate Processor object was detected and loaded with the pipeline.
 472      loaded_components = mlflow.transformers.load_model(
 473          model_uri=model_path, return_type="components"
 474      )
 475      assert isinstance(loaded_components["model"], transformers.ViltForQuestionAnswering)
 476      assert isinstance(loaded_components["tokenizer"], transformers.BertTokenizerFast)
 477      # This is to simulate a post-processing processor that would be used externally to a Pipeline
 478      # This isn't being tested on an actual use case of such a model type due to the size of
 479      # these types of models that have this interface being ill-suited for CI testing.
 480  
 481      if IS_NEW_FEATURE_EXTRACTION_API:
 482          processor_key = "image_processor"
 483          assert isinstance(loaded_components[processor_key], transformers.ViltImageProcessor)
 484      else:
 485          processor_key = "feature_extractor"
 486          assert isinstance(loaded_components[processor_key], transformers.ViltProcessor)
 487          assert isinstance(loaded_components["processor"], transformers.ViltProcessor)
 488      if not IS_NEW_FEATURE_EXTRACTION_API:
 489          # NB: This simulated behavior is no longer valid in versions 4.27.4 and above.
 490          # With the port of functionality away from feature extractor types, the new architecture
 491          # for multi-modal models is entirely pipeline based.
 492          # Make sure that the component usage works correctly when extracted from inference loading
 493          model = loaded_components["model"]
 494          processor = loaded_components["processor"]
 495          question = "What are the cats doing?"
 496          inputs = processor(image_for_test, question, return_tensors="pt")
 497          outputs = model(**inputs)
 498          logits = outputs.logits
 499          idx = logits.argmax(-1).item()
 500          answer = model.config.id2label[idx]
 501          assert answer == "sleeping"
 502  
 503  
 504  @flaky()
 505  def test_pipeline_saved_model_with_processor_cannot_be_loaded_as_pipeline(
 506      component_multi_modal, model_path
 507  ):
 508      invalid_pipeline = transformers.pipeline(
 509          task="visual-question-answering", **component_multi_modal
 510      )
 511      if IS_NEW_FEATURE_EXTRACTION_API:
 512          processor = component_multi_modal["image_processor"]
 513      else:
 514          processor = component_multi_modal["feature_extractor"]
 515      mlflow.transformers.save_model(
 516          transformers_model=invalid_pipeline,
 517          path=model_path,
 518          processor=processor,  # If this is specified, we cannot guarantee correct inference
 519      )
 520      with pytest.raises(
 521          MlflowException, match="This model has been saved with a processor. Processor objects"
 522      ):
 523          mlflow.transformers.load_model(model_uri=model_path, return_type="pipeline")
 524  
 525  
 526  def test_component_saved_model_with_processor_cannot_be_loaded_as_pipeline(
 527      component_multi_modal, model_path
 528  ):
 529      if IS_NEW_FEATURE_EXTRACTION_API:
 530          processor = component_multi_modal["image_processor"]
 531      else:
 532          processor = component_multi_modal["feature_extractor"]
 533      mlflow.transformers.save_model(
 534          transformers_model=component_multi_modal,
 535          path=model_path,
 536          processor=processor,
 537      )
 538      with pytest.raises(
 539          MlflowException,
 540          match="This model has been saved with a processor. Processor objects are not compatible "
 541          "with Pipelines. Please load",
 542      ):
 543          mlflow.transformers.load_model(model_uri=model_path, return_type="pipeline")
 544  
 545  
 546  @pytest.mark.parametrize("should_start_run", [True, False])
 547  def test_log_and_load_transformers_pipeline(small_qa_pipeline, tmp_path, should_start_run):
 548      try:
 549          if should_start_run:
 550              mlflow.start_run()
 551          artifact_path = "transformers"
 552          conda_env = tmp_path.joinpath("conda_env.yaml")
 553          _mlflow_conda_env(conda_env, additional_pip_deps=["transformers"])
 554          model_info = mlflow.transformers.log_model(
 555              small_qa_pipeline,
 556              name=artifact_path,
 557              conda_env=str(conda_env),
 558          )
 559          reloaded_model = mlflow.transformers.load_model(
 560              model_uri=model_info.model_uri, return_type="pipeline"
 561          )
 562          assert (
 563              reloaded_model(
 564                  question="Who's house?", context="The house is owned by a man named Run."
 565              )["answer"]
 566              == "Run"
 567          )
 568          model_path = pathlib.Path(_download_artifact_from_uri(artifact_uri=model_info.model_uri))
 569          model_config = Model.load(str(model_path.joinpath("MLmodel")))
 570          assert pyfunc.FLAVOR_NAME in model_config.flavors
 571          assert pyfunc.ENV in model_config.flavors[pyfunc.FLAVOR_NAME]
 572          env_path = model_config.flavors[pyfunc.FLAVOR_NAME][pyfunc.ENV]["conda"]
 573          assert model_path.joinpath(env_path).exists()
 574      finally:
 575          mlflow.end_run()
 576  
 577  
 578  def test_load_pipeline_from_remote_uri_succeeds(
 579      text_classification_pipeline, model_path, mock_s3_bucket
 580  ):
 581      mlflow.transformers.save_model(transformers_model=text_classification_pipeline, path=model_path)
 582      artifact_root = f"s3://{mock_s3_bucket}"
 583      artifact_path = "model"
 584      artifact_repo = S3ArtifactRepository(artifact_root)
 585      artifact_repo.log_artifacts(model_path, artifact_path=artifact_path)
 586      model_uri = os.path.join(artifact_root, artifact_path)
 587      loaded = mlflow.transformers.load_model(model_uri=str(model_uri), return_type="pipeline")
 588      assert loaded("I like it when CI checks pass and are never flaky!")[0]["label"] == "POSITIVE"
 589  
 590  
 591  def test_transformers_log_model_calls_register_model(small_qa_pipeline, tmp_path):
 592      artifact_path = "transformers"
 593      register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
 594      with mlflow.start_run(), register_model_patch:
 595          conda_env = tmp_path.joinpath("conda_env.yaml")
 596          _mlflow_conda_env(conda_env, additional_pip_deps=["transformers", "torch", "torchvision"])
 597          model_info = mlflow.transformers.log_model(
 598              small_qa_pipeline,
 599              name=artifact_path,
 600              conda_env=str(conda_env),
 601              registered_model_name="Question-Answering Model 1",
 602          )
 603          assert_register_model_called_with_local_model_path(
 604              register_model_mock=mlflow.tracking._model_registry.fluent._register_model,
 605              model_uri=model_info.model_uri,
 606              registered_model_name="Question-Answering Model 1",
 607          )
 608  
 609  
 610  def test_transformers_log_model_with_no_registered_model_name(small_vision_model, tmp_path):
 611      if IS_NEW_FEATURE_EXTRACTION_API:
 612          model = {
 613              "model": small_vision_model.model,
 614              "image_processor": small_vision_model.image_processor,
 615              "tokenizer": small_vision_model.tokenizer,
 616          }
 617      else:
 618          model = {
 619              "model": small_vision_model.model,
 620              "feature_extractor": small_vision_model.feature_extractor,
 621              "tokenizer": small_vision_model.tokenizer,
 622          }
 623  
 624      artifact_path = "transformers"
 625      registered_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
 626      with mlflow.start_run(), registered_model_patch:
 627          conda_env = tmp_path.joinpath("conda_env.yaml")
 628          _mlflow_conda_env(conda_env, additional_pip_deps=["tensorflow", "transformers"])
 629          mlflow.transformers.log_model(
 630              model,
 631              name=artifact_path,
 632              conda_env=str(conda_env),
 633          )
 634          mlflow.tracking._model_registry.fluent._register_model.assert_not_called()
 635  
 636  
 637  def test_transformers_log_model_with_prompt_template_sets_return_full_text_false(
 638      text_generation_pipeline,
 639  ):
 640      artifact_path = "text_generation_with_prompt_template"
 641      prompt_template = "User: {prompt}"
 642  
 643      with mlflow.start_run():
 644          model_info = mlflow.transformers.log_model(
 645              text_generation_pipeline,
 646              name=artifact_path,
 647              prompt_template=prompt_template,
 648          )
 649  
 650      model_path = pathlib.Path(_download_artifact_from_uri(model_info.model_uri))
 651      mlmodel = Model.load(str(model_path.joinpath("MLmodel")))
 652  
 653      pyfunc_flavor = mlmodel.flavors["python_function"]
 654      config = pyfunc_flavor.get("config")
 655  
 656      assert config.get("return_full_text") is False
 657  
 658  
 659  def test_transformers_save_persists_requirements_in_mlflow_directory(
 660      small_qa_pipeline, model_path, transformers_custom_env
 661  ):
 662      mlflow.transformers.save_model(
 663          transformers_model=small_qa_pipeline,
 664          path=model_path,
 665          conda_env=str(transformers_custom_env),
 666      )
 667      saved_pip_req_path = model_path.joinpath("requirements.txt")
 668      _compare_conda_env_requirements(transformers_custom_env, saved_pip_req_path)
 669  
 670  
 671  def test_transformers_log_with_pip_requirements(small_multi_modal_pipeline, tmp_path):
 672      expected_mlflow_version = _mlflow_major_version_string()
 673  
 674      requirements_file = tmp_path.joinpath("requirements.txt")
 675      requirements_file.write_text("coolpackage")
 676      with mlflow.start_run():
 677          model_info = mlflow.transformers.log_model(
 678              small_multi_modal_pipeline, name="model", pip_requirements=str(requirements_file)
 679          )
 680          _assert_pip_requirements(
 681              model_info.model_uri, [expected_mlflow_version, "coolpackage"], strict=True
 682          )
 683      with mlflow.start_run():
 684          model_info = mlflow.transformers.log_model(
 685              small_multi_modal_pipeline,
 686              name="model",
 687              pip_requirements=[f"-r {requirements_file}", "alsocool"],
 688          )
 689          _assert_pip_requirements(
 690              model_info.model_uri,
 691              [expected_mlflow_version, "coolpackage", "alsocool"],
 692              strict=True,
 693          )
 694      with mlflow.start_run():
 695          model_info = mlflow.transformers.log_model(
 696              small_multi_modal_pipeline,
 697              name="model",
 698              pip_requirements=[f"-c {requirements_file}", "constrainedcool"],
 699          )
 700          _assert_pip_requirements(
 701              model_info.model_uri,
 702              [expected_mlflow_version, "constrainedcool", "-c constraints.txt"],
 703              ["coolpackage"],
 704              strict=True,
 705          )
 706  
 707  
 708  def test_transformers_log_with_extra_pip_requirements(small_multi_modal_pipeline, tmp_path):
 709      expected_mlflow_version = _mlflow_major_version_string()
 710      default_requirements = mlflow.transformers.get_default_pip_requirements(
 711          small_multi_modal_pipeline.model
 712      )
 713      requirements_file = tmp_path.joinpath("requirements.txt")
 714      requirements_file.write_text("coolpackage")
 715      with mlflow.start_run():
 716          model_info = mlflow.transformers.log_model(
 717              small_multi_modal_pipeline, name="model", extra_pip_requirements=str(requirements_file)
 718          )
 719          _assert_pip_requirements(
 720              model_info.model_uri,
 721              [expected_mlflow_version, *default_requirements, "coolpackage"],
 722              strict=True,
 723          )
 724      with mlflow.start_run():
 725          model_info = mlflow.transformers.log_model(
 726              small_multi_modal_pipeline,
 727              name="model",
 728              extra_pip_requirements=[f"-r {requirements_file}", "alsocool"],
 729          )
 730          _assert_pip_requirements(
 731              model_info.model_uri,
 732              [expected_mlflow_version, *default_requirements, "coolpackage", "alsocool"],
 733              strict=True,
 734          )
 735      with mlflow.start_run():
 736          model_info = mlflow.transformers.log_model(
 737              small_multi_modal_pipeline,
 738              name="model",
 739              extra_pip_requirements=[f"-c {requirements_file}", "constrainedcool"],
 740          )
 741          _assert_pip_requirements(
 742              model_info.model_uri,
 743              [
 744                  expected_mlflow_version,
 745                  *default_requirements,
 746                  "constrainedcool",
 747                  "-c constraints.txt",
 748              ],
 749              ["coolpackage"],
 750              strict=True,
 751          )
 752  
 753  
 754  def test_transformers_log_with_duplicate_extra_pip_requirements(small_multi_modal_pipeline):
 755      with pytest.raises(
 756          MlflowException, match="The specified requirements versions are incompatible"
 757      ):
 758          with mlflow.start_run():
 759              mlflow.transformers.log_model(
 760                  small_multi_modal_pipeline,
 761                  name="model",
 762                  extra_pip_requirements=["transformers==1.1.0"],
 763              )
 764  
 765  
 766  def test_transformers_pt_model_save_without_conda_env_uses_default_env_with_expected_dependencies(
 767      small_qa_pipeline, model_path
 768  ):
 769      mlflow.transformers.save_model(small_qa_pipeline, model_path)
 770      _assert_pip_requirements(
 771          model_path, mlflow.transformers.get_default_pip_requirements(small_qa_pipeline.model)
 772      )
 773      pip_requirements = _get_deps_from_requirement_file(model_path)
 774      assert "tensorflow" not in pip_requirements
 775      assert "accelerate" in pip_requirements
 776      assert "torch" in pip_requirements
 777  
 778  
 779  @pytest.mark.skipif(
 780      importlib.util.find_spec("accelerate") is not None, reason="fails when accelerate is installed"
 781  )
 782  def test_transformers_pt_model_save_dependencies_without_accelerate(
 783      text_generation_pipeline, model_path
 784  ):
 785      mlflow.transformers.save_model(text_generation_pipeline, model_path)
 786      _assert_pip_requirements(
 787          model_path, mlflow.transformers.get_default_pip_requirements(text_generation_pipeline.model)
 788      )
 789      pip_requirements = _get_deps_from_requirement_file(model_path)
 790      assert "tensorflow" not in pip_requirements
 791      assert "accelerate" not in pip_requirements
 792      assert "torch" in pip_requirements
 793  
 794  
 795  def test_transformers_pt_model_log_without_conda_env_uses_default_env_with_expected_dependencies(
 796      small_qa_pipeline,
 797  ):
 798      artifact_path = "model"
 799      with mlflow.start_run():
 800          model_info = mlflow.transformers.log_model(small_qa_pipeline, name=artifact_path)
 801      _assert_pip_requirements(
 802          model_info.model_uri,
 803          mlflow.transformers.get_default_pip_requirements(small_qa_pipeline.model),
 804      )
 805      pip_requirements = _get_deps_from_requirement_file(model_info.model_uri)
 806      assert "tensorflow" not in pip_requirements
 807      assert "torch" in pip_requirements
 808  
 809  
 810  def test_log_model_with_code_paths(small_qa_pipeline):
 811      artifact_path = "model"
 812      with (
 813          mlflow.start_run(),
 814          mock.patch("mlflow.transformers._add_code_from_conf_to_system_path") as add_mock,
 815      ):
 816          model_info = mlflow.transformers.log_model(
 817              small_qa_pipeline, name=artifact_path, code_paths=[__file__]
 818          )
 819          model_uri = model_info.model_uri
 820          _compare_logged_code_paths(__file__, model_uri, mlflow.transformers.FLAVOR_NAME)
 821          mlflow.transformers.load_model(model_uri)
 822          add_mock.assert_called()
 823  
 824  
 825  def test_non_existent_model_card_entry(small_qa_pipeline, model_path):
 826      with mock.patch("mlflow.transformers._fetch_model_card", return_value=None):
 827          mlflow.transformers.save_model(transformers_model=small_qa_pipeline, path=model_path)
 828  
 829          contents = {item.name for item in model_path.iterdir()}
 830          assert not contents.intersection({"model_card.txt", "model_card_data.yaml"})
 831  
 832  
 833  def test_huggingface_hub_not_installed(small_qa_pipeline, model_path):
 834      with mock.patch.dict("sys.modules", {"huggingface_hub": None}):
 835          result = mlflow.transformers._fetch_model_card(small_qa_pipeline.model.name_or_path)
 836  
 837          assert result is None
 838  
 839          mlflow.transformers.save_model(transformers_model=small_qa_pipeline, path=model_path)
 840  
 841          contents = {item.name for item in model_path.iterdir()}
 842          assert not contents.intersection({"model_card.txt", "model_card_data.yaml"})
 843  
 844          license_data = model_path.joinpath("LICENSE.txt").read_text()
 845          assert license_data.rstrip().endswith("mobilebert-uncased-squad-v2")
 846  
 847  
 848  @pytest.mark.skipif(
 849      _try_import_conversational_pipeline() is None,
 850      reason="Conversation model is deprecated and removed.",
 851  )
 852  def test_save_pipeline_without_defined_components(small_conversational_model, model_path):
 853      # This pipeline type explicitly does not have a configuration for an image_processor
 854      with mlflow.start_run():
 855          mlflow.transformers.save_model(
 856              transformers_model=small_conversational_model, path=model_path
 857          )
 858      pipe = mlflow.transformers.load_model(model_path)
 859      convo = transformers.Conversation("How are you today?")
 860      convo = pipe(convo)
 861      assert convo.generated_responses[-1] == "good"
 862  
 863  
 864  @flaky()
 865  def test_invalid_model_type_without_registered_name_does_not_save(model_path):
 866      invalid_pipeline = transformers.pipeline(task="text-generation", model="gpt2")
 867      del invalid_pipeline.model.name_or_path
 868  
 869      with pytest.raises(MlflowException, match="The submitted model type"):
 870          mlflow.transformers.save_model(transformers_model=invalid_pipeline, path=model_path)
 871  
 872  
 873  def test_invalid_input_to_pyfunc_signature_output_wrapper_raises(component_multi_modal):
 874      with pytest.raises(MlflowException, match="The pipeline type submitted is not a valid"):
 875          mlflow.transformers.generate_signature_output(component_multi_modal["model"], "bogus")
 876  
 877  
 878  @pytest.mark.parametrize(
 879      "inference_payload",
 880      [
 881          ({"question": "Who's house?", "context": "The house is owned by a man named Run."}),
 882          ([
 883              {
 884                  "question": "What color is it?",
 885                  "context": "Some people said it was green but I know that it's definitely blue",
 886              },
 887              {
 888                  "question": "How do the wheels go?",
 889                  "context": "The wheels on the bus go round and round. Round and round.",
 890              },
 891          ]),
 892          ([
 893              {
 894                  "question": "What color is it?",
 895                  "context": "Some people said it was green but I know that it's pink.",
 896              },
 897              {
 898                  "context": "The people on the bus go up and down. Up and down.",
 899                  "question": "How do the people go?",
 900              },
 901          ]),
 902      ],
 903  )
 904  def test_qa_pipeline_pyfunc_load_and_infer(small_qa_pipeline, model_path, inference_payload):
 905      signature = infer_signature(
 906          inference_payload,
 907          mlflow.transformers.generate_signature_output(small_qa_pipeline, inference_payload),
 908      )
 909  
 910      mlflow.transformers.save_model(
 911          transformers_model=small_qa_pipeline,
 912          path=model_path,
 913          signature=signature,
 914      )
 915      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
 916  
 917      inference = pyfunc_loaded.predict(inference_payload)
 918  
 919      assert isinstance(inference, list)
 920      assert all(isinstance(element, str) for element in inference)
 921  
 922      pd_input = (
 923          pd.DataFrame([inference_payload])
 924          if isinstance(inference_payload, dict)
 925          else pd.DataFrame(inference_payload)
 926      )
 927      pd_inference = pyfunc_loaded.predict(pd_input)
 928  
 929      assert isinstance(pd_inference, list)
 930      assert all(isinstance(element, str) for element in inference)
 931  
 932  
 933  @pytest.mark.parametrize(
 934      "inference_payload",
 935      [
 936          image_url,
 937          str(image_file_path),
 938          "base64",
 939          pytest.param(
 940              "base64_encodebytes",
 941              marks=pytest.mark.skipif(
 942                  Version(transformers.__version__) < Version("4.41"),
 943                  reason="base64 encodebytes feature not present",
 944              ),
 945          ),
 946      ],
 947  )
 948  def test_vision_pipeline_pyfunc_load_and_infer(small_vision_model, model_path, inference_payload):
 949      if inference_payload == "base64":
 950          inference_payload = base64.b64encode(image_file_path.read_bytes()).decode("utf-8")
 951      elif inference_payload == "base64_encodebytes":
 952          inference_payload = base64.encodebytes(image_file_path.read_bytes()).decode("utf-8")
 953      signature = infer_signature(
 954          inference_payload,
 955          mlflow.transformers.generate_signature_output(small_vision_model, inference_payload),
 956      )
 957      mlflow.transformers.save_model(
 958          transformers_model=small_vision_model,
 959          path=model_path,
 960          signature=signature,
 961      )
 962      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
 963      predictions = pyfunc_loaded.predict(inference_payload)
 964  
 965      transformers_loaded_model = mlflow.transformers.load_model(model_path)
 966      expected_predictions = transformers_loaded_model.predict(inference_payload)
 967      assert list(predictions.to_dict("records")[0].values()) == expected_predictions
 968  
 969  
 970  @pytest.mark.parametrize(
 971      ("data", "result"),
 972      [
 973          ("muppet keyboard type", ["A man is typing a muppet on a keyboard."]),
 974          (
 975              ["pencil draw paper", "pie apple eat"],
 976              # NB: The result of this test case, without inference config overrides is:
 977              # ["A man drawing on paper with pencil", "A man eating a pie with applies"]
 978              # The inference config override forces additional insertion of more grammatically
 979              # correct responses to validate that the inference config is being applied.
 980              ["A man draws a pencil on a paper.", "A man eats a pie of apples."],
 981          ),
 982      ],
 983  )
 984  def test_text2text_generation_pipeline_with_model_configs(
 985      text2text_generation_pipeline, tmp_path, data, result
 986  ):
 987      signature = infer_signature(
 988          data, mlflow.transformers.generate_signature_output(text2text_generation_pipeline, data)
 989      )
 990  
 991      model_config = {
 992          "top_k": 2,
 993          "num_beams": 5,
 994          "max_length": 30,
 995          "temperature": 0.62,
 996          "top_p": 0.85,
 997          "repetition_penalty": 1.15,
 998      }
 999      model_path1 = tmp_path.joinpath("model1")
1000      mlflow.transformers.save_model(
1001          text2text_generation_pipeline,
1002          path=model_path1,
1003          model_config=model_config,
1004          signature=signature,
1005      )
1006      pyfunc_loaded = mlflow.pyfunc.load_model(model_path1)
1007  
1008      inference = pyfunc_loaded.predict(data)
1009  
1010      assert inference == result
1011  
1012      pd_input = pd.DataFrame([data]) if isinstance(data, str) else pd.DataFrame(data)
1013      pd_inference = pyfunc_loaded.predict(pd_input)
1014      assert pd_inference == result
1015  
1016      model_path2 = tmp_path.joinpath("model2")
1017      signature_with_params = infer_signature(
1018          data,
1019          mlflow.transformers.generate_signature_output(text2text_generation_pipeline, data),
1020          model_config,
1021      )
1022      mlflow.transformers.save_model(
1023          text2text_generation_pipeline,
1024          path=model_path2,
1025          signature=signature_with_params,
1026      )
1027      pyfunc_loaded = mlflow.pyfunc.load_model(model_path2)
1028  
1029      dict_inference = pyfunc_loaded.predict(
1030          data,
1031          params=model_config,
1032      )
1033  
1034      assert dict_inference == inference
1035  
1036  
1037  def test_text2text_generation_pipeline_with_model_config_and_params(
1038      text2text_generation_pipeline, model_path
1039  ):
1040      data = "muppet keyboard type"
1041      model_config = {
1042          "top_k": 2,
1043          "num_beams": 5,
1044          "top_p": 0.85,
1045          "repetition_penalty": 1.15,
1046          "do_sample": True,
1047      }
1048      parameters = {"top_k": 3, "max_new_tokens": 30}
1049      generated_output = mlflow.transformers.generate_signature_output(
1050          text2text_generation_pipeline, data
1051      )
1052      signature = infer_signature(
1053          data,
1054          generated_output,
1055          parameters,
1056      )
1057  
1058      mlflow.transformers.save_model(
1059          text2text_generation_pipeline,
1060          path=model_path,
1061          model_config=model_config,
1062          signature=signature,
1063      )
1064      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
1065  
1066      # model_config and default params are all applied
1067      res = pyfunc_loaded.predict(data)
1068      applied_params = model_config.copy()
1069      applied_params.update(parameters)
1070      res2 = pyfunc_loaded.predict(data, applied_params)
1071      assert res == res2
1072  
1073      assert res != pyfunc_loaded.predict(data, {"max_new_tokens": 3})
1074  
1075      # Extra params are ignored
1076      assert res == pyfunc_loaded.predict(data, {"extra_param": "extra_value"})
1077  
1078  
1079  def test_text2text_generation_pipeline_with_params_success(
1080      text2text_generation_pipeline, model_path
1081  ):
1082      data = "muppet keyboard type"
1083      parameters = {"top_k": 2, "num_beams": 5, "do_sample": True}
1084      generated_output = mlflow.transformers.generate_signature_output(
1085          text2text_generation_pipeline, data
1086      )
1087      signature = infer_signature(
1088          data,
1089          generated_output,
1090          parameters,
1091      )
1092  
1093      mlflow.transformers.save_model(
1094          text2text_generation_pipeline,
1095          path=model_path,
1096          signature=signature,
1097      )
1098      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
1099  
1100      # parameters saved with ModelSignature is applied by default
1101      res = pyfunc_loaded.predict(data)
1102      res2 = pyfunc_loaded.predict(data, parameters)
1103      assert res == res2
1104  
1105  
1106  def test_text2text_generation_pipeline_with_params_with_errors(
1107      text2text_generation_pipeline, model_path
1108  ):
1109      data = "muppet keyboard type"
1110      parameters = {"top_k": 2, "num_beams": 5, "invalid_param": "invalid_param", "do_sample": True}
1111      generated_output = mlflow.transformers.generate_signature_output(
1112          text2text_generation_pipeline, data
1113      )
1114  
1115      mlflow.transformers.save_model(
1116          text2text_generation_pipeline,
1117          path=model_path,
1118          signature=infer_signature(
1119              data,
1120              generated_output,
1121              parameters,
1122          ),
1123      )
1124      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
1125      with pytest.raises(
1126          MlflowException,
1127          match=r"The params provided to the `predict` method are "
1128          r"not valid for pipeline Text2TextGenerationPipeline.",
1129      ):
1130          pyfunc_loaded.predict(data, parameters)
1131  
1132      # Type validation of params failure
1133      with pytest.raises(MlflowException, match=r"Invalid parameters found"):
1134          pyfunc_loaded.predict(data, {"top_k": "2"})
1135  
1136  
1137  def test_text2text_generation_pipeline_with_inferred_schema(text2text_generation_pipeline):
1138      with mlflow.start_run():
1139          model_info = mlflow.transformers.log_model(text2text_generation_pipeline, name="my_model")
1140      pyfunc_loaded = mlflow.pyfunc.load_model(model_info.model_uri)
1141  
1142      assert pyfunc_loaded.predict("muppet board nails hammer")[0].startswith("A hammer")
1143  
1144  
1145  @pytest.mark.parametrize(
1146      "invalid_data",
1147      [
1148          ({"answer": "something", "context": ["nothing", "that", "makes", "sense"]}),
1149          ([{"answer": ["42"], "context": "life"}, {"unmatched": "keys", "cause": "failure"}]),
1150      ],
1151  )
1152  def test_invalid_input_to_text2text_pipeline(text2text_generation_pipeline, invalid_data):
1153      # Adding this validation test due to the fact that we're constructing the input to the
1154      # Pipeline. The Pipeline requires a format of a pseudo-dict-like string. An example of
1155      # a valid input string: "answer: green. context: grass is primarily green in color."
1156      # We generate this string from a dict or generate a list of these strings from a list of
1157      # dictionaries.
1158      with pytest.raises(
1159          MlflowException, match=r"An invalid type has been supplied: .+\. Please supply"
1160      ):
1161          infer_signature(
1162              invalid_data,
1163              mlflow.transformers.generate_signature_output(
1164                  text2text_generation_pipeline, invalid_data
1165              ),
1166          )
1167  
1168  
1169  @pytest.mark.parametrize(
1170      "data", ["Generative models are", (["Generative models are", "Computers are"])]
1171  )
1172  def test_text_generation_pipeline(text_generation_pipeline, model_path, data):
1173      signature = infer_signature(
1174          data, mlflow.transformers.generate_signature_output(text_generation_pipeline, data)
1175      )
1176  
1177      model_config = {
1178          "prefix": "software",
1179          "top_k": 2,
1180          "num_beams": 5,
1181          "max_length": 30,
1182          "temperature": 0.62,
1183          "top_p": 0.85,
1184          "repetition_penalty": 1.15,
1185      }
1186      mlflow.transformers.save_model(
1187          text_generation_pipeline,
1188          path=model_path,
1189          model_config=model_config,
1190          signature=signature,
1191      )
1192      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
1193  
1194      inference = pyfunc_loaded.predict(data)
1195  
1196      if isinstance(data, list):
1197          assert inference[0].startswith(data[0])
1198          assert inference[1].startswith(data[1])
1199      else:
1200          assert inference[0].startswith(data)
1201  
1202      pd_input = pd.DataFrame([data], index=[0]) if isinstance(data, str) else pd.DataFrame(data)
1203      pd_inference = pyfunc_loaded.predict(pd_input)
1204  
1205      if isinstance(data, list):
1206          assert pd_inference[0].startswith(data[0])
1207          assert pd_inference[1].startswith(data[1])
1208      else:
1209          assert pd_inference[0].startswith(data)
1210  
1211  
1212  @pytest.mark.parametrize(
1213      "invalid_data",
1214      [
1215          ({"my_input": "something to predict"}),
1216          ([{"bogus_input": "invalid"}, "not_valid"]),
1217          (["tell me a story", {"of": "a properly configured pipeline input"}]),
1218      ],
1219  )
1220  def test_invalid_input_to_text_generation_pipeline(text_generation_pipeline, invalid_data):
1221      if isinstance(invalid_data, list):
1222          match = "If supplying a list, all values must be of string type"
1223      else:
1224          match = "The input data is of an incorrect type"
1225      with pytest.raises(MlflowException, match=match):
1226          infer_signature(
1227              invalid_data,
1228              mlflow.transformers.generate_signature_output(text_generation_pipeline, invalid_data),
1229          )
1230  
1231  
1232  @pytest.mark.parametrize(
1233      ("inference_payload", "result"),
1234      [
1235          ("Riding a <mask> on the beach is fun!", ["bike"]),
1236          (["If I had <mask>, I would fly to the top of a mountain"], ["wings"]),
1237          (
1238              ["I use stacks of <mask> to buy things", "I <mask> the whole bowl of cherries"],
1239              ["cash", "ate"],
1240          ),
1241      ],
1242  )
1243  def test_fill_mask_pipeline(fill_mask_pipeline, model_path, inference_payload, result):
1244      signature = infer_signature(
1245          inference_payload,
1246          mlflow.transformers.generate_signature_output(fill_mask_pipeline, inference_payload),
1247      )
1248  
1249      mlflow.transformers.save_model(fill_mask_pipeline, path=model_path, signature=signature)
1250      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
1251  
1252      inference = pyfunc_loaded.predict(inference_payload)
1253      assert inference == result
1254  
1255      if len(inference_payload) > 1 and isinstance(inference_payload, list):
1256          pd_input = pd.DataFrame([{"inputs": v} for v in inference_payload])
1257      elif isinstance(inference_payload, list) and len(inference_payload) == 1:
1258          pd_input = pd.DataFrame([{"inputs": v} for v in inference_payload], index=[0])
1259      else:
1260          pd_input = pd.DataFrame({"inputs": inference_payload}, index=[0])
1261  
1262      pd_inference = pyfunc_loaded.predict(pd_input)
1263      assert pd_inference == result
1264  
1265  
1266  def test_fill_mask_pipeline_with_multiple_masks(fill_mask_pipeline, model_path):
1267      data = ["I <mask> the whole <mask> of <mask>", "I <mask> the whole <mask> of <mask>"]
1268  
1269      mlflow.transformers.save_model(fill_mask_pipeline, path=model_path)
1270      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
1271  
1272      inference = pyfunc_loaded.predict(data)
1273      assert len(inference) == 2
1274      assert all(len(value) == 3 for value in inference)
1275  
1276  
1277  @pytest.mark.parametrize(
1278      "invalid_data",
1279      [
1280          ({"a": "b"}),
1281          ([{"a": "b"}, [{"a": "c"}]]),
1282      ],
1283  )
1284  def test_invalid_input_to_fill_mask_pipeline(fill_mask_pipeline, invalid_data):
1285      if isinstance(invalid_data, list):
1286          match = "Invalid data submission. Ensure all"
1287      else:
1288          match = "The input data is of an incorrect type"
1289      with pytest.raises(MlflowException, match=match):
1290          infer_signature(
1291              invalid_data,
1292              mlflow.transformers.generate_signature_output(fill_mask_pipeline, invalid_data),
1293          )
1294  
1295  
1296  @pytest.mark.parametrize(
1297      "data",
1298      [
1299          {
1300              "sequences": "I love the latest update to this IDE!",
1301              "candidate_labels": ["happy", "sad"],
1302          },
1303          {
1304              "sequences": ["My dog loves to eat spaghetti", "My dog hates going to the vet"],
1305              "candidate_labels": ["happy", "sad"],
1306              "hypothesis_template": "This example talks about how the dog is {}",
1307          },
1308      ],
1309  )
1310  def test_zero_shot_classification_pipeline(zero_shot_pipeline, model_path, data):
1311      # NB: The list submission for this pipeline type can accept json-encoded lists or lists within
1312      # the values of the dictionary.
1313      signature = infer_signature(
1314          data, mlflow.transformers.generate_signature_output(zero_shot_pipeline, data)
1315      )
1316  
1317      mlflow.transformers.save_model(zero_shot_pipeline, model_path, signature=signature)
1318  
1319      loaded_pyfunc = mlflow.pyfunc.load_model(model_path)
1320      inference = loaded_pyfunc.predict(data)
1321  
1322      assert isinstance(inference, pd.DataFrame)
1323      if isinstance(data["sequences"], str):
1324          assert len(inference) == len(data["candidate_labels"])
1325      else:
1326          assert len(inference) == len(data["sequences"]) * len(data["candidate_labels"])
1327  
1328  
1329  @pytest.mark.parametrize(
1330      "query",
1331      [
1332          "What should we order more of?",
1333          [
1334              "What is our highest sales?",
1335              "What should we order more of?",
1336          ],
1337      ],
1338  )
1339  def test_table_question_answering_pipeline(table_question_answering_pipeline, model_path, query):
1340      table = {
1341          "Fruit": ["Apples", "Bananas", "Oranges", "Watermelon", "Blueberries"],
1342          "Sales": ["1230945.55", "86453.12", "11459.23", "8341.23", "2325.88"],
1343          "Inventory": ["910", "4589", "11200", "80", "3459"],
1344      }
1345      json_table = json.dumps(table)
1346      data = {"query": query, "table": json_table}
1347      signature = infer_signature(
1348          data, mlflow.transformers.generate_signature_output(table_question_answering_pipeline, data)
1349      )
1350  
1351      mlflow.transformers.save_model(
1352          table_question_answering_pipeline, model_path, signature=signature
1353      )
1354      loaded = mlflow.pyfunc.load_model(model_path)
1355  
1356      inference = loaded.predict(data)
1357      assert len(inference) == 1 if isinstance(query, str) else len(query)
1358  
1359      pd_input = pd.DataFrame([data])
1360      pd_inference = loaded.predict(pd_input)
1361      assert pd_inference is not None
1362  
1363  
1364  def test_custom_code_pipeline(custom_code_pipeline, model_path):
1365      data = "hello"
1366  
1367      signature = infer_signature(
1368          data, mlflow.transformers.generate_signature_output(custom_code_pipeline, data)
1369      )
1370  
1371      mlflow.transformers.save_model(
1372          custom_code_pipeline,
1373          path=model_path,
1374          signature=signature,
1375      )
1376  
1377      # just test that it doesn't blow up when performing inference
1378      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
1379      pyfunc_pred = pyfunc_loaded.predict(data)
1380      assert isinstance(pyfunc_pred[0][0], float)
1381  
1382      transformers_loaded = mlflow.transformers.load_model(model_path)
1383      transformers_pred = transformers_loaded(data)
1384      assert pyfunc_pred[0][0] == transformers_pred[0][0][0]
1385  
1386  
1387  def test_custom_components_pipeline(custom_components_pipeline, model_path):
1388      data = "hello"
1389  
1390      signature = infer_signature(
1391          data, mlflow.transformers.generate_signature_output(custom_components_pipeline, data)
1392      )
1393  
1394      components = {
1395          "model": custom_components_pipeline.model,
1396          "tokenizer": custom_components_pipeline.tokenizer,
1397          "feature_extractor": custom_components_pipeline.feature_extractor,
1398      }
1399  
1400      mlflow.transformers.save_model(
1401          transformers_model=components, path=model_path, signature=signature
1402      )
1403  
1404      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
1405      pyfunc_pred = pyfunc_loaded.predict(data)
1406      assert isinstance(pyfunc_pred[0][0], float)
1407  
1408      transformers_loaded = mlflow.transformers.load_model(model_path)
1409      transformers_pred = transformers_loaded(data)
1410      assert pyfunc_pred[0][0] == transformers_pred[0][0][0]
1411  
1412      # assert that all the reloaded components exist
1413      # and have the same class name as pre-save
1414      for name, component in components.items():
1415          assert component.__class__.__name__ == getattr(transformers_loaded, name).__class__.__name__
1416  
1417  
1418  @pytest.mark.parametrize(
1419      ("data", "result"),
1420      [
1421          ("I've got a lovely bunch of coconuts!", ["Ich habe eine schöne Haufe von Kokos!"]),
1422          (
1423              [
1424                  "I am the very model of a modern major general",
1425                  "Once upon a time, there was a little turtle",
1426              ],
1427              [
1428                  "Ich bin das Modell eines modernen Generals.",
1429                  "Einmal gab es eine kleine Schildkröte.",
1430              ],
1431          ),
1432      ],
1433  )
1434  def test_translation_pipeline(translation_pipeline, model_path, data, result):
1435      signature = infer_signature(
1436          data, mlflow.transformers.generate_signature_output(translation_pipeline, data)
1437      )
1438  
1439      mlflow.transformers.save_model(translation_pipeline, path=model_path, signature=signature)
1440      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
1441      inference = pyfunc_loaded.predict(data)
1442      assert inference == result
1443  
1444      if len(data) > 1 and isinstance(data, list):
1445          pd_input = pd.DataFrame([{"inputs": v} for v in data])
1446      elif isinstance(data, list) and len(data) == 1:
1447          pd_input = pd.DataFrame([{"inputs": v} for v in data], index=[0])
1448      else:
1449          pd_input = pd.DataFrame({"inputs": data}, index=[0])
1450  
1451      pd_inference = pyfunc_loaded.predict(pd_input)
1452      assert pd_inference == result
1453  
1454  
1455  @pytest.mark.parametrize(
1456      "data",
1457      [
1458          "There once was a boy",
1459          ["Dolly isn't just a sheep anymore"],
1460          ["Baking cookies is quite easy", "Writing unittests is good for"],
1461      ],
1462  )
1463  def test_summarization_pipeline(summarizer_pipeline, model_path, data):
1464      model_config = {
1465          "top_k": 2,
1466          "num_beams": 5,
1467          "max_length": 90,
1468          "temperature": 0.62,
1469          "top_p": 0.85,
1470          "repetition_penalty": 1.15,
1471      }
1472      signature = infer_signature(
1473          data, mlflow.transformers.generate_signature_output(summarizer_pipeline, data)
1474      )
1475  
1476      mlflow.transformers.save_model(
1477          summarizer_pipeline, path=model_path, model_config=model_config, signature=signature
1478      )
1479      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
1480  
1481      inference = pyfunc_loaded.predict(data)
1482      if isinstance(data, list) and len(data) > 1:
1483          for i, entry in enumerate(data):
1484              assert inference[i].strip().startswith(entry)
1485      elif isinstance(data, list) and len(data) == 1:
1486          assert inference[0].strip().startswith(data[0])
1487      else:
1488          assert inference[0].strip().startswith(data)
1489  
1490      if len(data) > 1 and isinstance(data, list):
1491          pd_input = pd.DataFrame([{"inputs": v} for v in data])
1492      elif isinstance(data, list) and len(data) == 1:
1493          pd_input = pd.DataFrame([{"inputs": v} for v in data], index=[0])
1494      else:
1495          pd_input = pd.DataFrame({"inputs": data}, index=[0])
1496  
1497      pd_inference = pyfunc_loaded.predict(pd_input)
1498      if isinstance(data, list) and len(data) > 1:
1499          for i, entry in enumerate(data):
1500              assert pd_inference[i].strip().startswith(entry)
1501      elif isinstance(data, list) and len(data) == 1:
1502          assert pd_inference[0].strip().startswith(data[0])
1503      else:
1504          assert pd_inference[0].strip().startswith(data)
1505  
1506  
1507  @pytest.mark.parametrize(
1508      "data",
1509      [
1510          "I'm telling you that Han shot first!",
1511          [
1512              "I think this sushi might have gone off",
1513              "That gym smells like feet, hot garbage, and sadness",
1514              "I love that we have a moon",
1515          ],
1516          [{"text": "test1", "text_pair": "test2"}],
1517          [{"text": "test1", "text_pair": "pair1"}, {"text": "test2", "text_pair": "pair2"}],
1518      ],
1519  )
1520  def test_classifier_pipeline(text_classification_pipeline, model_path, data):
1521      signature = infer_signature(
1522          data, mlflow.transformers.generate_signature_output(text_classification_pipeline, data)
1523      )
1524      mlflow.transformers.save_model(
1525          text_classification_pipeline, path=model_path, signature=signature
1526      )
1527  
1528      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
1529      inference = pyfunc_loaded.predict(data)
1530  
1531      # verify that native transformers outputs match the pyfunc return values
1532      native_inference = text_classification_pipeline(data)
1533      inference_dict = inference.to_dict()
1534  
1535      if isinstance(data, str):
1536          assert len(inference) == 1
1537          assert inference_dict["label"][0] == native_inference[0]["label"]
1538          assert inference_dict["score"][0] == native_inference[0]["score"]
1539      else:
1540          assert len(inference) == len(data)
1541          for key in ["score", "label"]:
1542              for value in range(0, len(data)):
1543                  if key == "label":
1544                      assert inference_dict[key][value] == native_inference[value][key]
1545                  else:
1546                      assert math.isclose(
1547                          native_inference[value][key], inference_dict[key][value], rel_tol=1e-6
1548                      )
1549  
1550  
1551  @pytest.mark.parametrize(
1552      ("data", "result"),
1553      [
1554          (
1555              "I have a dog and his name is Willy!",
1556              ["PRON,VERB,DET,NOUN,CCONJ,PRON,NOUN,AUX,PROPN,PUNCT"],
1557          ),
1558          (["I like turtles"], ["PRON,VERB,NOUN"]),
1559          (
1560              ["We are the knights who say nee!", "Houston, we may have a problem."],
1561              [
1562                  "PRON,AUX,DET,PROPN,PRON,VERB,INTJ,PUNCT",
1563                  "PROPN,PUNCT,PRON,AUX,VERB,DET,NOUN,PUNCT",
1564              ],
1565          ),
1566      ],
1567  )
1568  @pytest.mark.parametrize("pipeline_name", ["ner_pipeline", "ner_pipeline_aggregation"])
1569  def test_ner_pipeline(pipeline_name, model_path, data, result, request):
1570      pipeline = request.getfixturevalue(pipeline_name)
1571  
1572      signature = infer_signature(data, mlflow.transformers.generate_signature_output(pipeline, data))
1573  
1574      mlflow.transformers.save_model(pipeline, model_path, signature=signature)
1575      loaded_pyfunc = mlflow.pyfunc.load_model(model_path)
1576      inference = loaded_pyfunc.predict(data)
1577  
1578      assert inference == result
1579  
1580      if len(data) > 1 and isinstance(data, list):
1581          pd_input = pd.DataFrame([{"inputs": v} for v in data])
1582      elif isinstance(data, list) and len(data) == 1:
1583          pd_input = pd.DataFrame([{"inputs": v} for v in data], index=[0])
1584      else:
1585          pd_input = pd.DataFrame({"inputs": data}, index=[0])
1586      pd_inference = loaded_pyfunc.predict(pd_input)
1587      assert pd_inference == result
1588  
1589  
1590  @pytest.mark.skipif(
1591      _try_import_conversational_pipeline() is None,
1592      reason="Conversation model is deprecated and removed.",
1593  )
1594  def test_conversational_pipeline(conversational_pipeline, model_path):
1595      assert mlflow.transformers._is_conversational_pipeline(conversational_pipeline)
1596  
1597      signature = infer_signature(
1598          "Hi there!",
1599          mlflow.transformers.generate_signature_output(conversational_pipeline, "Hi there!"),
1600      )
1601  
1602      mlflow.transformers.save_model(conversational_pipeline, model_path, signature=signature)
1603      loaded_pyfunc = mlflow.pyfunc.load_model(model_path)
1604  
1605      first_response = loaded_pyfunc.predict("What is the best way to get to Antarctica?")
1606  
1607      assert first_response == "The best way would be to go to space."
1608  
1609      second_response = loaded_pyfunc.predict("What kind of boat should I use?")
1610  
1611      assert second_response == "The best way to get to space would be to reach out and touch it."
1612  
1613      # Test that a new loaded instance has no context.
1614      loaded_again_pyfunc = mlflow.pyfunc.load_model(model_path)
1615      third_response = loaded_again_pyfunc.predict("What kind of boat should I use?")
1616  
1617      assert third_response == "The one with the guns."
1618  
1619      fourth_response = loaded_again_pyfunc.predict("Can I use it to go to the moon?")
1620  
1621      assert fourth_response == "Sure."
1622  
1623  
1624  def test_qa_pipeline_pyfunc_predict(small_qa_pipeline):
1625      artifact_path = "qa_model"
1626      with mlflow.start_run():
1627          model_info = mlflow.transformers.log_model(
1628              small_qa_pipeline,
1629              name=artifact_path,
1630          )
1631  
1632      inference_payload = json.dumps({
1633          "inputs": {
1634              "question": [
1635                  "What color is it?",
1636                  "How do the people go?",
1637                  "What does the 'wolf' howl at?",
1638              ],
1639              "context": [
1640                  "Some people said it was green but I know that it's pink.",
1641                  "The people on the bus go up and down. Up and down.",
1642                  "The pack of 'wolves' stood on the cliff and a 'lone wolf' howled at "
1643                  "the moon for hours.",
1644              ],
1645          }
1646      })
1647      response = pyfunc_serve_and_score_model(
1648          model_info.model_uri,
1649          data=inference_payload,
1650          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
1651          extra_args=["--env-manager", "local"],
1652      )
1653      values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
1654  
1655      assert values.to_dict(orient="records") == [{0: "pink"}, {0: "up and down"}, {0: "the moon"}]
1656  
1657      inference_payload = json.dumps({
1658          "inputs": {
1659              "question": "Who's house?",
1660              "context": "The house is owned by a man named Run.",
1661          }
1662      })
1663  
1664      response = pyfunc_serve_and_score_model(
1665          model_info.model_uri,
1666          data=inference_payload,
1667          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
1668          extra_args=["--env-manager", "local"],
1669      )
1670      values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
1671  
1672      assert values.to_dict(orient="records") == [{0: "Run"}]
1673  
1674  
1675  @pytest.mark.parametrize(
1676      ("input_image", "result"),
1677      [
1678          (str(image_file_path), False),
1679          (image_url, False),
1680          ("base64", True),
1681          ("random string", False),
1682      ],
1683  )
1684  def test_vision_is_base64_image(input_image, result):
1685      if input_image == "base64":
1686          input_image = base64.b64encode(image_file_path.read_bytes()).decode("utf-8")
1687      assert _TransformersWrapper.is_base64_image(input_image) == result
1688  
1689  
1690  @pytest.mark.parametrize(
1691      "inference_payload",
1692      [
1693          [str(image_file_path)],
1694          [image_url],
1695          "base64",
1696          pytest.param(
1697              "base64_encodebytes",
1698              marks=pytest.mark.skipif(
1699                  Version(transformers.__version__) < Version("4.41"),
1700                  reason="base64 encodebytes feature not present",
1701              ),
1702          ),
1703      ],
1704  )
1705  def test_vision_pipeline_pyfunc_predict(small_vision_model, inference_payload):
1706      if inference_payload == "base64":
1707          inference_payload = [
1708              base64.b64encode(image_file_path.read_bytes()).decode("utf-8"),
1709          ]
1710      elif inference_payload == "base64_encodebytes":
1711          inference_payload = [
1712              base64.encodebytes(image_file_path.read_bytes()).decode("utf-8"),
1713          ]
1714      artifact_path = "image_classification_model"
1715  
1716      # Log the image classification model
1717      with mlflow.start_run():
1718          model_info = mlflow.transformers.log_model(
1719              small_vision_model,
1720              name=artifact_path,
1721          )
1722      pyfunc_inference_payload = json.dumps({"inputs": inference_payload})
1723      response = pyfunc_serve_and_score_model(
1724          model_info.model_uri,
1725          data=pyfunc_inference_payload,
1726          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
1727          extra_args=["--env-manager", "local"],
1728      )
1729  
1730      predictions = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
1731  
1732      transformers_loaded_model = mlflow.transformers.load_model(model_info.model_uri)
1733      expected_predictions = transformers_loaded_model.predict(inference_payload)
1734  
1735      assert [list(pred.values()) for pred in predictions.to_dict("records")] == expected_predictions
1736  
1737  
1738  def test_classifier_pipeline_pyfunc_predict(text_classification_pipeline):
1739      artifact_path = "text_classifier_model"
1740      data = [
1741          "I think this sushi might have gone off",
1742          "That gym smells like feet, hot garbage, and sadness",
1743          "I love that we have a moon",
1744          "I 'love' debugging subprocesses",
1745          'Quote "in" the string',
1746      ]
1747      signature = infer_signature(data)
1748      with mlflow.start_run():
1749          model_info = mlflow.transformers.log_model(
1750              text_classification_pipeline,
1751              name=artifact_path,
1752              signature=signature,
1753          )
1754  
1755      response = pyfunc_serve_and_score_model(
1756          model_info.model_uri,
1757          data=json.dumps({"inputs": data}),
1758          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
1759          extra_args=["--env-manager", "local"],
1760      )
1761      values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
1762  
1763      assert len(values.to_dict()) == 2
1764      assert len(values.to_dict()["score"]) == 5
1765  
1766      # test simple string input
1767      inference_payload = json.dumps({"inputs": ["testing"]})
1768  
1769      response = pyfunc_serve_and_score_model(
1770          model_info.model_uri,
1771          data=inference_payload,
1772          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
1773          extra_args=["--env-manager", "local"],
1774      )
1775      values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
1776  
1777      assert len(values.to_dict()) == 2
1778      assert len(values.to_dict()["score"]) == 1
1779  
1780      # Test the alternate TextClassificationPipeline input structure where text_pair is used
1781      # and ensure that model serving and direct native inference match
1782      inference_data = [
1783          {"text": "test1", "text_pair": "pair1"},
1784          {"text": "test2", "text_pair": "pair2"},
1785          {"text": "test 'quote", "text_pair": "pair 'quote'"},
1786      ]
1787      signature = infer_signature(inference_data)
1788      with mlflow.start_run():
1789          model_info = mlflow.transformers.log_model(
1790              text_classification_pipeline,
1791              name=artifact_path,
1792              signature=signature,
1793          )
1794  
1795      inference_payload = json.dumps({"inputs": inference_data})
1796      response = pyfunc_serve_and_score_model(
1797          model_info.model_uri,
1798          data=inference_payload,
1799          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
1800          extra_args=["--env-manager", "local"],
1801      )
1802      values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
1803      values_dict = values.to_dict()
1804      native_predict = text_classification_pipeline(inference_data)
1805  
1806      # validate that the pyfunc served model registers text_pair in the same manner as native
1807      for key in ["score", "label"]:
1808          for value in [0, 1]:
1809              if key == "label":
1810                  assert values_dict[key][value] == native_predict[value][key]
1811              else:
1812                  assert math.isclose(
1813                      values_dict[key][value], native_predict[value][key], rel_tol=1e-6
1814                  )
1815  
1816  
1817  def test_zero_shot_pipeline_pyfunc_predict(zero_shot_pipeline):
1818      artifact_path = "zero_shot_classifier_model"
1819      with mlflow.start_run():
1820          model_info = mlflow.transformers.log_model(
1821              zero_shot_pipeline,
1822              name=artifact_path,
1823          )
1824          model_uri = model_info.model_uri
1825  
1826      inference_payload = json.dumps({
1827          "inputs": {
1828              "sequences": "My dog loves running through troughs of spaghetti with his mouth open",
1829              "candidate_labels": ["happy", "sad"],
1830              "hypothesis_template": "This example talks about how the dog is {}",
1831          }
1832      })
1833  
1834      response = pyfunc_serve_and_score_model(
1835          model_uri,
1836          data=inference_payload,
1837          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
1838          extra_args=["--env-manager", "local"],
1839      )
1840      values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
1841  
1842      assert len(values.to_dict()) == 3
1843      assert len(values.to_dict()["labels"]) == 2
1844  
1845      inference_payload = json.dumps({
1846          "inputs": {
1847              "sequences": [
1848                  "My dog loves to eat spaghetti",
1849                  "My dog hates going to the vet",
1850                  "My 'hamster' loves to play with my 'friendly' dog",
1851              ],
1852              "candidate_labels": '["happy", "sad"]',
1853              "hypothesis_template": "This example talks about how the dog is {}",
1854          }
1855      })
1856      response = pyfunc_serve_and_score_model(
1857          model_uri,
1858          data=inference_payload,
1859          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
1860          extra_args=["--env-manager", "local"],
1861      )
1862      values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
1863  
1864      assert len(values.to_dict()) == 3
1865      assert len(values.to_dict()["labels"]) == 6
1866  
1867  
1868  def test_table_question_answering_pyfunc_predict(table_question_answering_pipeline):
1869      artifact_path = "table_qa_model"
1870      with mlflow.start_run():
1871          model_info = mlflow.transformers.log_model(
1872              table_question_answering_pipeline,
1873              name=artifact_path,
1874          )
1875  
1876      table = {
1877          "Fruit": ["Apples", "Bananas", "Oranges", "Watermelon 'small'", "Blueberries"],
1878          "Sales": ["1230945.55", "86453.12", "11459.23", "8341.23", "2325.88"],
1879          "Inventory": ["910", "4589", "11200", "80", "3459"],
1880      }
1881  
1882      inference_payload = json.dumps({
1883          "inputs": {
1884              "query": "What should we order more of?",
1885              "table": table,
1886          }
1887      })
1888  
1889      response = pyfunc_serve_and_score_model(
1890          model_info.model_uri,
1891          data=inference_payload,
1892          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
1893          extra_args=["--env-manager", "local"],
1894      )
1895      values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
1896  
1897      assert len(values.to_dict(orient="records")) == 1
1898  
1899      inference_payload = json.dumps({
1900          "inputs": {
1901              "query": [
1902                  "What is our highest sales?",
1903                  "What should we order more of?",
1904                  "Which 'fruit' has the 'highest' 'sales'?",
1905              ],
1906              "table": table,
1907          }
1908      })
1909      response = pyfunc_serve_and_score_model(
1910          model_info.model_uri,
1911          data=inference_payload,
1912          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
1913          extra_args=["--env-manager", "local"],
1914      )
1915      values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
1916  
1917      assert len(values.to_dict(orient="records")) == 3
1918  
1919  
1920  def test_feature_extraction_pipeline(feature_extraction_pipeline):
1921      sentences = ["hi", "hello"]
1922      signature = infer_signature(
1923          sentences,
1924          mlflow.transformers.generate_signature_output(feature_extraction_pipeline, sentences),
1925      )
1926  
1927      artifact_path = "feature_extraction_pipeline"
1928      with mlflow.start_run():
1929          model_info = mlflow.transformers.log_model(
1930              feature_extraction_pipeline,
1931              name=artifact_path,
1932              signature=signature,
1933              input_example=["A sentence", "Another sentence"],
1934          )
1935  
1936      # Load as native
1937      loaded_pipeline = mlflow.transformers.load_model(model_info.model_uri)
1938  
1939      inference_single = "Testing"
1940      inference_mult = ["Testing something", "Testing something else"]
1941  
1942      pred = loaded_pipeline(inference_single)
1943      assert len(pred[0][0]) > 10
1944      assert isinstance(pred[0][0][0], float)
1945  
1946      pred_multiple = loaded_pipeline(inference_mult)
1947      assert len(pred_multiple[0][0]) > 2
1948      assert isinstance(pred_multiple[0][0][0][0], float)
1949  
1950      loaded_pyfunc = mlflow.pyfunc.load_model(model_info.model_uri)
1951  
1952      pyfunc_pred = loaded_pyfunc.predict(inference_single)
1953  
1954      assert isinstance(pyfunc_pred, np.ndarray)
1955  
1956      assert np.array_equal(np.array(pred[0]), pyfunc_pred)
1957  
1958      pyfunc_pred_multiple = loaded_pyfunc.predict(inference_mult)
1959  
1960      assert np.array_equal(np.array(pred_multiple[0][0]), pyfunc_pred_multiple)
1961  
1962  
1963  def test_feature_extraction_pipeline_pyfunc_predict(feature_extraction_pipeline):
1964      artifact_path = "feature_extraction"
1965      with mlflow.start_run():
1966          model_info = mlflow.transformers.log_model(
1967              feature_extraction_pipeline,
1968              name=artifact_path,
1969          )
1970  
1971      inference_payload = json.dumps({"inputs": ["sentence one", "sentence two"]})
1972  
1973      response = pyfunc_serve_and_score_model(
1974          model_info.model_uri,
1975          data=inference_payload,
1976          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
1977          extra_args=["--env-manager", "local"],
1978      )
1979      values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
1980  
1981      assert len(values.columns) == 384
1982      assert len(values) == 4
1983  
1984      inference_payload = json.dumps({"inputs": "sentence three"})
1985  
1986      response = pyfunc_serve_and_score_model(
1987          model_info.model_uri,
1988          data=inference_payload,
1989          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
1990          extra_args=["--env-manager", "local"],
1991      )
1992      assert response.status_code == 200
1993      prediction = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
1994      assert len(prediction.columns) == 384
1995      assert len(prediction) == 4
1996  
1997  
1998  def test_loading_unsupported_pipeline_type_as_pyfunc(small_multi_modal_pipeline, model_path):
1999      mlflow.transformers.save_model(small_multi_modal_pipeline, model_path)
2000      with pytest.raises(MlflowException, match='Model does not have the "python_function" flavor'):
2001          mlflow.pyfunc.load_model(model_path)
2002  
2003  
2004  def test_pyfunc_input_validations(mock_pyfunc_wrapper):
2005      def ensure_raises(data, match):
2006          with pytest.raises(MlflowException, match=match):
2007              mock_pyfunc_wrapper._validate_str_or_list_str(data)
2008  
2009      match1 = "The input data is of an incorrect type"
2010      match2 = "If supplying a list, all values must"
2011      ensure_raises({"a": "b"}, match1)
2012      ensure_raises(("a", "b"), match1)
2013      ensure_raises({"a", "b"}, match1)
2014      ensure_raises(True, match1)
2015      ensure_raises(12, match1)
2016      ensure_raises([1, 2, 3], match2)
2017      ensure_raises([{"a", "b"}], match2)
2018      ensure_raises([["a", "b", "c'"]], match2)
2019      ensure_raises([{"a": "b"}, {"a": "c"}], match2)
2020      ensure_raises([[1], [2]], match2)
2021  
2022  
2023  def test_pyfunc_json_encoded_dict_parsing(mock_pyfunc_wrapper):
2024      plain_dict = {"a": "b", "b": "c"}
2025      list_dict = [plain_dict, plain_dict]
2026  
2027      plain_input = {"in": json.dumps(plain_dict)}
2028      list_input = {"in": json.dumps(list_dict)}
2029  
2030      plain_parsed = mock_pyfunc_wrapper._parse_json_encoded_dict_payload_to_dict(plain_input, "in")
2031      assert plain_parsed == {"in": plain_dict}
2032  
2033      list_parsed = mock_pyfunc_wrapper._parse_json_encoded_dict_payload_to_dict(list_input, "in")
2034      assert list_parsed == {"in": list_dict}
2035  
2036      invalid_parsed = mock_pyfunc_wrapper._parse_json_encoded_dict_payload_to_dict(
2037          plain_input, "invalid"
2038      )
2039      assert invalid_parsed != {"in": plain_dict}
2040      assert invalid_parsed == plain_input
2041  
2042  
2043  def test_pyfunc_json_encoded_list_parsing(mock_pyfunc_wrapper):
2044      plain_list = ["a", "b", "c"]
2045      nested_list = [plain_list, plain_list]
2046      list_dict = [{"a": "b"}, {"a": "c"}]
2047  
2048      plain_input = {"in": json.dumps(plain_list)}
2049      nested_input = {"in": json.dumps(nested_list)}
2050      list_dict_input = {"in": json.dumps(list_dict)}
2051  
2052      plain_parsed = mock_pyfunc_wrapper._parse_json_encoded_list(plain_input, "in")
2053      assert plain_parsed == {"in": plain_list}
2054  
2055      nested_parsed = mock_pyfunc_wrapper._parse_json_encoded_list(nested_input, "in")
2056      assert nested_parsed == {"in": nested_list}
2057  
2058      list_dict_parsed = mock_pyfunc_wrapper._parse_json_encoded_list(list_dict_input, "in")
2059      assert list_dict_parsed == {"in": list_dict}
2060  
2061      with pytest.raises(MlflowException, match="Invalid key in inference payload. The "):
2062          mock_pyfunc_wrapper._parse_json_encoded_list(list_dict_input, "invalid")
2063  
2064  
2065  def test_pyfunc_text_to_text_input(mock_pyfunc_wrapper):
2066      text2text_input = {"context": "a", "answer": "b"}
2067      parsed_input = mock_pyfunc_wrapper._parse_text2text_input(text2text_input)
2068      assert parsed_input == "context: a answer: b"
2069  
2070      text2text_input_list = [text2text_input, text2text_input]
2071      parsed_input_list = mock_pyfunc_wrapper._parse_text2text_input(text2text_input_list)
2072      assert parsed_input_list == ["context: a answer: b", "context: a answer: b"]
2073  
2074      parsed_with_inputs = mock_pyfunc_wrapper._parse_text2text_input({"inputs": "a"})
2075      assert parsed_with_inputs == ["a"]
2076  
2077      parsed_str = mock_pyfunc_wrapper._parse_text2text_input("a")
2078      assert parsed_str == "a"
2079  
2080      parsed_list_str = mock_pyfunc_wrapper._parse_text2text_input(["a", "b"])
2081      assert parsed_list_str == ["a", "b"]
2082  
2083      with pytest.raises(MlflowException, match="An invalid type has been supplied"):
2084          mock_pyfunc_wrapper._parse_text2text_input([1, 2, 3])
2085  
2086      with pytest.raises(MlflowException, match="An invalid type has been supplied"):
2087          mock_pyfunc_wrapper._parse_text2text_input([{"a": [{"b": "c"}]}])
2088  
2089  
2090  def test_pyfunc_qa_input(mock_pyfunc_wrapper):
2091      single_input = {"question": "a", "context": "b"}
2092      parsed_single_input = mock_pyfunc_wrapper._parse_question_answer_input(single_input)
2093      assert parsed_single_input == single_input
2094  
2095      multi_input = [single_input, single_input]
2096      parsed_multi_input = mock_pyfunc_wrapper._parse_question_answer_input(multi_input)
2097      assert parsed_multi_input == multi_input
2098  
2099      with pytest.raises(MlflowException, match="Invalid keys were submitted. Keys must"):
2100          mock_pyfunc_wrapper._parse_question_answer_input({"q": "a", "c": "b"})
2101  
2102      with pytest.raises(MlflowException, match="An invalid type has been supplied"):
2103          mock_pyfunc_wrapper._parse_question_answer_input("a")
2104  
2105      with pytest.raises(MlflowException, match="An invalid type has been supplied"):
2106          mock_pyfunc_wrapper._parse_question_answer_input(["a", "b", "c"])
2107  
2108  
2109  def test_list_of_dict_to_list_of_str_parsing(mock_pyfunc_wrapper):
2110      # Test with a single list of dictionaries
2111      output_data = [{"a": "foo"}, {"a": "bar"}, {"a": "baz"}]
2112      expected_output = ["foo", "bar", "baz"]
2113      assert (
2114          mock_pyfunc_wrapper._parse_lists_of_dict_to_list_of_str(output_data, "a") == expected_output
2115      )
2116  
2117      # Test with a nested list of dictionaries
2118      output_data = [
2119          {"a": "foo", "b": [{"a": "bar"}]},
2120          {"a": "baz", "b": [{"a": "qux"}]},
2121      ]
2122      expected_output = ["foo", "bar", "baz", "qux"]
2123      assert (
2124          mock_pyfunc_wrapper._parse_lists_of_dict_to_list_of_str(output_data, "a") == expected_output
2125      )
2126  
2127      # Test with nested list with exclusion data
2128      output_data = [
2129          {"a": "valid", "b": [{"a": "another valid"}, {"b": "invalid"}]},
2130          {"a": "valid 2", "b": [{"a": "another valid 2"}, {"c": "invalid"}]},
2131      ]
2132      expected_output = ["valid", "another valid", "valid 2", "another valid 2"]
2133      assert (
2134          mock_pyfunc_wrapper._parse_lists_of_dict_to_list_of_str(output_data, "a") == expected_output
2135      )
2136  
2137  
2138  def test_parsing_tokenizer_output(mock_pyfunc_wrapper):
2139      output_data = [{"a": "b"}, {"a": "c"}, {"a": "d"}]
2140      expected_output = "b,c,d"
2141      assert mock_pyfunc_wrapper._parse_tokenizer_output(output_data, {"a"}) == expected_output
2142  
2143      output_data = [output_data, output_data]
2144      expected_output = [expected_output, expected_output]
2145      assert mock_pyfunc_wrapper._parse_tokenizer_output(output_data, {"a"}) == expected_output
2146  
2147  
2148  def test_parse_list_of_multiple_dicts(mock_pyfunc_wrapper):
2149      output_data = [{"a": "b", "d": "f"}, {"a": "z", "d": "g"}]
2150      target_dict_key = "a"
2151      expected_output = ["b"]
2152  
2153      assert (
2154          mock_pyfunc_wrapper._parse_list_of_multiple_dicts(output_data, target_dict_key)
2155          == expected_output
2156      )
2157  
2158      output_data = [
2159          [{"a": "c", "d": "q"}, {"a": "o", "d": "q"}, {"a": "d", "d": "q"}, {"a": "e", "d": "r"}],
2160          [{"a": "m", "d": "s"}, {"a": "e", "d": "t"}],
2161      ]
2162      target_dict_key = "a"
2163      expected_output = ["c", "m"]
2164  
2165      assert (
2166          mock_pyfunc_wrapper._parse_list_of_multiple_dicts(output_data, target_dict_key)
2167          == expected_output
2168      )
2169  
2170  
2171  @pytest.mark.parametrize(
2172      (
2173          "pipeline_input",
2174          "pipeline_output",
2175          "expected_output",
2176          "flavor_config",
2177          "include_prompt",
2178          "collapse_whitespace",
2179      ),
2180      [
2181          (
2182              "What answers?",
2183              [{"generated_text": "What answers?\n\nA collection of\n\nanswers"}],
2184              "A collection of\n\nanswers",
2185              {"instance_type": "InstructionTextGenerationPipeline"},
2186              False,
2187              False,
2188          ),
2189          (
2190              "What answers?",
2191              [{"generated_text": "What answers?\n\nA collection of\n\nanswers"}],
2192              "A collection of answers",
2193              {"instance_type": "InstructionTextGenerationPipeline"},
2194              False,
2195              True,
2196          ),
2197          (
2198              "Hello!",
2199              [{"generated_text": "Hello!\n\nHow are you?"}],
2200              "How are you?",
2201              {"instance_type": "InstructionTextGenerationPipeline"},
2202              False,
2203              False,
2204          ),
2205          (
2206              "Hello!",
2207              [{"generated_text": "Hello!\n\nA: How are you?\n\n"}],
2208              "How are you?",
2209              {"instance_type": "InstructionTextGenerationPipeline"},
2210              False,
2211              True,
2212          ),
2213          (
2214              "Hello!",
2215              [{"generated_text": "Hello!\n\nA: How are you?\n\n"}],
2216              "Hello! A: How are you?",
2217              {"instance_type": "InstructionTextGenerationPipeline"},
2218              True,
2219              True,
2220          ),
2221          (
2222              "Hello!",
2223              [{"generated_text": "Hello!\n\nA: How\nare\nyou?\n\n"}],
2224              "How\nare\nyou?\n\n",
2225              {"instance_type": "InstructionTextGenerationPipeline"},
2226              False,
2227              False,
2228          ),
2229          (
2230              ["Hi!", "What's up?"],
2231              [[{"generated_text": "Hi!\n\nHello there"}, {"generated_text": "Not much, and you?"}]],
2232              ["Hello there", "Not much, and you?"],
2233              {"instance_type": "InstructionTextGenerationPipeline"},
2234              False,
2235              False,
2236          ),
2237          # Tests disabling parsing of newline characters
2238          (
2239              ["Hi!", "What's up?"],
2240              [
2241                  [
2242                      {"generated_text": "Hi!\n\nHello there"},
2243                      {"generated_text": "What's up?\n\nNot much, and you?"},
2244                  ]
2245              ],
2246              ["Hi!\n\nHello there", "What's up?\n\nNot much, and you?"],
2247              {"instance_type": "InstructionTextGenerationPipeline"},
2248              True,
2249              False,
2250          ),
2251          (
2252              "Hello!",
2253              [{"generated_text": "Hello!\n\nHow are you?"}],
2254              "Hello!\n\nHow are you?",
2255              {"instance_type": "InstructionTextGenerationPipeline"},
2256              True,
2257              False,
2258          ),
2259          # Tests a standard TextGenerationPipeline output
2260          (
2261              ["We like to", "Open the"],
2262              [
2263                  [
2264                      {"generated_text": "We like to party"},
2265                      {"generated_text": "Open the door get on the floor everybody do the dinosaur"},
2266                  ]
2267              ],
2268              ["We like to party", "Open the door get on the floor everybody do the dinosaur"],
2269              {"instance_type": "TextGenerationPipeline"},
2270              True,
2271              True,
2272          ),
2273          # Tests a standard TextGenerationPipeline output with setting "include_prompt" (noop)
2274          (
2275              ["We like to", "Open the"],
2276              [
2277                  [
2278                      {"generated_text": "We like to party"},
2279                      {"generated_text": "Open the door get on the floor everybody do the dinosaur"},
2280                  ]
2281              ],
2282              ["We like to party", "Open the door get on the floor everybody do the dinosaur"],
2283              {"instance_type": "TextGenerationPipeline"},
2284              False,
2285              False,
2286          ),
2287          # Test TextGenerationPipeline removes whitespace
2288          (
2289              ["We like to", "Open the"],
2290              [
2291                  [
2292                      {"generated_text": "  We like   to    party"},
2293                      {
2294                          "generated_text": "Open the   door get on the floor   everybody    "
2295                          "do\nthe dinosaur"
2296                      },
2297                  ]
2298              ],
2299              ["We like to party", "Open the door get on the floor everybody do the dinosaur"],
2300              {"instance_type": "TextGenerationPipeline"},
2301              False,
2302              True,
2303          ),
2304      ],
2305  )
2306  def test_parse_input_from_instruction_pipeline(
2307      mock_pyfunc_wrapper,
2308      pipeline_input,
2309      pipeline_output,
2310      expected_output,
2311      flavor_config,
2312      include_prompt,
2313      collapse_whitespace,
2314  ):
2315      assert (
2316          mock_pyfunc_wrapper._strip_input_from_response_in_instruction_pipelines(
2317              pipeline_input,
2318              pipeline_output,
2319              "generated_text",
2320              flavor_config,
2321              include_prompt,
2322              collapse_whitespace,
2323          )
2324          == expected_output
2325      )
2326  
2327  
2328  @pytest.mark.parametrize(
2329      "flavor_config",
2330      [
2331          {"instance_type": "InstructionTextGenerationPipeline"},
2332          {"instance_type": "TextGenerationPipeline"},
2333      ],
2334  )
2335  def test_invalid_instruction_pipeline_parsing(mock_pyfunc_wrapper, flavor_config):
2336      prompt = "What is your favorite boba flavor?"
2337  
2338      bad_output = {"generated_text": ["Strawberry Milk Cap", "Honeydew with boba"]}
2339  
2340      with pytest.raises(MlflowException, match="Unable to parse the pipeline output. Expected"):
2341          mock_pyfunc_wrapper._strip_input_from_response_in_instruction_pipelines(
2342              prompt, bad_output, "generated_text", flavor_config, True
2343          )
2344  
2345  
2346  @pytest.mark.skipif(RUNNING_IN_GITHUB_ACTIONS, reason=GITHUB_ACTIONS_SKIP_REASON)
2347  def test_instructional_pipeline_no_prompt_in_output(model_path):
2348      architecture = "databricks/dolly-v2-3b"
2349      dolly = transformers.pipeline(model=architecture, trust_remote_code=True)
2350  
2351      mlflow.transformers.save_model(
2352          transformers_model=dolly,
2353          path=model_path,
2354          # Validate removal of prompt but inclusion of newlines by default
2355          model_config={"max_length": 100, "include_prompt": False},
2356          input_example="Hello, Dolly!",
2357      )
2358  
2359      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
2360  
2361      inference = pyfunc_loaded.predict("What is MLflow?")
2362  
2363      assert not inference[0].startswith("What is MLflow?")
2364      assert "\n" in inference[0]
2365  
2366  
2367  @pytest.mark.skipif(RUNNING_IN_GITHUB_ACTIONS, reason=GITHUB_ACTIONS_SKIP_REASON)
2368  def test_instructional_pipeline_no_prompt_in_output_and_removal_of_newlines(model_path):
2369      architecture = "databricks/dolly-v2-3b"
2370      dolly = transformers.pipeline(model=architecture, trust_remote_code=True)
2371  
2372      mlflow.transformers.save_model(
2373          transformers_model=dolly,
2374          path=model_path,
2375          # Validate removal of prompt but inclusion of newlines by default
2376          model_config={"max_length": 100, "include_prompt": False, "collapse_whitespace": True},
2377          input_example="Hello, Dolly!",
2378      )
2379  
2380      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
2381  
2382      inference = pyfunc_loaded.predict("What is MLflow?")
2383  
2384      assert not inference[0].startswith("What is MLflow?")
2385      assert "\n" not in inference[0]
2386  
2387  
2388  @pytest.mark.skipif(RUNNING_IN_GITHUB_ACTIONS, reason=GITHUB_ACTIONS_SKIP_REASON)
2389  def test_instructional_pipeline_with_prompt_in_output(model_path):
2390      architecture = "databricks/dolly-v2-3b"
2391      dolly = transformers.pipeline(model=architecture, trust_remote_code=True)
2392  
2393      mlflow.transformers.save_model(
2394          transformers_model=dolly,
2395          path=model_path,
2396          # test default propagation of `include_prompt`=True and `collapse_whitespace`=False
2397          model_config={"max_length": 100},
2398          input_example="Hello, Dolly!",
2399      )
2400  
2401      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
2402  
2403      inference = pyfunc_loaded.predict("What is MLflow?")
2404  
2405      assert inference[0].startswith("What is MLflow?")
2406      assert "\n\n" in inference[0]
2407  
2408  
2409  def read_audio_data(format: str):
2410      datasets_path = pathlib.Path(__file__).resolve().parent.parent.joinpath("datasets")
2411      wav_file_path = datasets_path.joinpath("apollo11_launch.wav")
2412      if format == "float":
2413          audio, _ = librosa.load(wav_file_path, sr=16000)
2414          return audio
2415      elif format == "bytes":
2416          return wav_file_path.read_bytes()
2417      elif format == "file":
2418          return wav_file_path.as_posix()
2419      else:
2420          raise ValueError(f"Invalid format: {format}")
2421  
2422  
2423  @pytest.mark.parametrize("input_format", ["float", "bytes", "file"])
2424  @pytest.mark.parametrize("with_input_example", [True, False])
2425  def test_whisper_model_predict(model_path, whisper_pipeline, input_format, with_input_example):
2426      if input_format == "float" and not with_input_example:
2427          pytest.skip("If the input format is float, the default signature must be overridden.")
2428  
2429      audio = read_audio_data(input_format)
2430      mlflow.transformers.save_model(
2431          transformers_model=whisper_pipeline,
2432          path=model_path,
2433          input_example=audio if with_input_example else None,
2434          save_pretrained=False,
2435      )
2436  
2437      # 1. Single prediction with native transformer pipeline
2438      loaded_pipeline = mlflow.transformers.load_model(model_path)
2439      transcription = loaded_pipeline(audio)
2440      assert transcription["text"].startswith(" 30")
2441      # strip the leading space
2442      expected_text = transcription["text"].lstrip()
2443  
2444      # 2. Single prediction with Pyfunc
2445      loaded_pyfunc = mlflow.pyfunc.load_model(model_path)
2446      pyfunc_transcription = loaded_pyfunc.predict(audio)[0]
2447      assert pyfunc_transcription == expected_text
2448  
2449      # 3. Batch prediction with Pyfunc. Float input format is not supported for batch prediction,
2450      # because our signature framework doesn't support a list of numpy array.
2451      if input_format != "float":
2452          batch_transcription = loaded_pyfunc.predict([audio, audio])
2453          assert len(batch_transcription) == 2
2454          assert all(ts == expected_text for ts in batch_transcription)
2455  
2456  
2457  def test_whisper_model_serve_and_score(whisper_pipeline):
2458      # Request payload to the model serving endpoint contains base64 encoded audio data
2459      audio = read_audio_data("bytes")
2460      encoded_audio = base64.b64encode(audio).decode("ascii")
2461  
2462      with mlflow.start_run():
2463          model_info = mlflow.transformers.log_model(
2464              whisper_pipeline,
2465              name="whisper",
2466              save_pretrained=False,
2467          )
2468  
2469      def _assert_response(response, length=1):
2470          preds = json.loads(response.content.decode("utf-8"))["predictions"]
2471          assert len(preds) == length
2472          assert all(pred.startswith("30") for pred in preds)
2473  
2474      with pyfunc_scoring_endpoint(
2475          model_info.model_uri,
2476          extra_args=["--env-manager", "local"],
2477      ) as endpoint:
2478          content_type = pyfunc_scoring_server.CONTENT_TYPE_JSON
2479  
2480          # Test payload with "inputs" key
2481          inputs_dict = {"inputs": [encoded_audio]}
2482          payload = json.dumps(inputs_dict)
2483          response = endpoint.invoke(payload, content_type=content_type)
2484          _assert_response(response)
2485  
2486          # Test payload with "dataframe_split" key
2487          inference_df = pd.DataFrame(pd.Series([encoded_audio], name="audio_file"))
2488          split_dict = {"dataframe_split": inference_df.to_dict(orient="split")}
2489          payload = json.dumps(split_dict)
2490          response = endpoint.invoke(payload, content_type=content_type)
2491          _assert_response(response)
2492  
2493          # Test payload with "dataframe_records" key
2494          records_dict = {"dataframe_records": inference_df.to_dict(orient="records")}
2495          payload = json.dumps(records_dict)
2496          response = endpoint.invoke(payload, content_type=content_type)
2497          _assert_response(response)
2498  
2499          # Test batch prediction
2500          inputs_dict = {"inputs": [encoded_audio, encoded_audio]}
2501          payload = json.dumps(inputs_dict)
2502          response = endpoint.invoke(payload, content_type=content_type)
2503          _assert_response(response, length=2)
2504  
2505          # Scoring with audio file URI is not supported yet (pyfunc prediction works tho)
2506          inputs_dict = {"inputs": [read_audio_data("file")]}
2507          payload = json.dumps(inputs_dict)
2508          response = endpoint.invoke(payload, content_type=content_type)
2509          response = json.loads(response.content.decode("utf-8"))
2510          assert response["error_code"] == "INVALID_PARAMETER_VALUE"
2511          assert "Failed to process the input audio data. Either" in response["message"]
2512  
2513  
2514  # https://github.com/huggingface/transformers/commit/9c500015c556f9ddf6e7a7449d3f46b2e3ff8ea5
2515  # caused a regression in beam search.
2516  # https://github.com/huggingface/transformers/commit/a6b51e7341d702127a4a45f37439640840b5abf0
2517  # fixed the regression but has not been released yet as of May 30, 2025.
2518  @pytest.mark.skipif(
2519      Version("4.52.0") <= Version(transformers.__version__) < Version("4.53.0"),
2520      reason="Transformers 4.52 has a bug for beam search in whiper implementation",
2521  )
2522  def test_whisper_model_support_timestamps(whisper_pipeline):
2523      # Request payload to the model serving endpoint contains base64 encoded audio data
2524      audio = read_audio_data("bytes")
2525      encoded_audio = base64.b64encode(audio).decode("ascii")
2526  
2527      model_config = {
2528          "return_timestamps": "word",
2529          "chunk_length_s": 20,
2530          "stride_length_s": [5, 3],
2531      }
2532  
2533      with mlflow.start_run():
2534          model_info = mlflow.transformers.log_model(
2535              whisper_pipeline,
2536              name="whisper_timestamps",
2537              model_config=model_config,
2538              input_example=(audio, model_config),
2539          )
2540  
2541      # Native transformers prediction as ground truth
2542      gt = whisper_pipeline(audio, **model_config)
2543  
2544      def _assert_prediction(pred):
2545          assert pred["text"] == gt["text"]
2546          assert len(pred["chunks"]) == len(gt["chunks"])
2547          for pred_chunk, gt_chunk in zip(pred["chunks"], gt["chunks"]):
2548              assert pred_chunk["text"] == gt_chunk["text"]
2549              # Timestamps are tuples, but converted to list when serialized to JSON.
2550              assert tuple(pred_chunk["timestamp"]) == gt_chunk["timestamp"]
2551  
2552      # Prediction with Pyfunc
2553      loaded_pyfunc = mlflow.pyfunc.load_model(model_info.model_uri)
2554      prediction = json.loads(loaded_pyfunc.predict(audio)[0])
2555      _assert_prediction(prediction)
2556  
2557      # Serve and score
2558      with pyfunc_scoring_endpoint(
2559          model_info.model_uri,
2560          extra_args=["--env-manager", "local"],
2561      ) as endpoint:
2562          content_type = pyfunc_scoring_server.CONTENT_TYPE_JSON
2563          payload = json.dumps({"inputs": [encoded_audio]})
2564          response = endpoint.invoke(payload, content_type=content_type)
2565  
2566          predictions = json.loads(response.content.decode("utf-8"))["predictions"]
2567          # When return_timestamps is specified, the predictions list contains json
2568          # serialized output from the pipeline.
2569          _assert_prediction(json.loads(predictions[0]))
2570  
2571          # Request with inference params
2572          payload = json.dumps({
2573              "inputs": [encoded_audio],
2574              "model_config": model_config,
2575          })
2576          response = endpoint.invoke(payload, content_type=content_type)
2577          predictions = json.loads(response.content.decode("utf-8"))["predictions"]
2578          _assert_prediction(json.loads(predictions[0]))
2579  
2580  
2581  def test_whisper_model_pyfunc_with_malformed_input(whisper_pipeline, model_path):
2582      mlflow.transformers.save_model(
2583          transformers_model=whisper_pipeline,
2584          path=model_path,
2585          save_pretrained=False,
2586      )
2587  
2588      pyfunc_model = mlflow.pyfunc.load_model(model_path)
2589  
2590      invalid_audio = b"This isn't a real audio file"
2591      with pytest.raises(MlflowException, match="Failed to process the input audio data. Either"):
2592          pyfunc_model.predict([invalid_audio])
2593  
2594      bad_uri_msg = "An invalid string input was provided. String"
2595  
2596      with pytest.raises(MlflowException, match=bad_uri_msg):
2597          pyfunc_model.predict("An invalid path")
2598  
2599      with pytest.raises(MlflowException, match=bad_uri_msg):
2600          pyfunc_model.predict("//www.invalid.net/audio.wav")
2601  
2602      with pytest.raises(MlflowException, match=bad_uri_msg):
2603          pyfunc_model.predict("https:///my/audio.mp3")
2604  
2605  
2606  @pytest.mark.parametrize("with_input_example", [True, False])
2607  def test_audio_classification_pipeline(audio_classification_pipeline, with_input_example):
2608      audio = read_audio_data("bytes")
2609  
2610      with mlflow.start_run():
2611          model_info = mlflow.transformers.log_model(
2612              audio_classification_pipeline,
2613              name="audio_classification",
2614              input_example=audio if with_input_example else None,
2615              save_pretrained=False,
2616          )
2617  
2618      inference_payload = json.dumps({"inputs": [base64.b64encode(audio).decode("ascii")]})
2619  
2620      response = pyfunc_serve_and_score_model(
2621          model_info.model_uri,
2622          data=inference_payload,
2623          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
2624          extra_args=["--env-manager", "local"],
2625      )
2626  
2627      values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
2628      assert isinstance(values, pd.DataFrame)
2629      assert len(values) > 1
2630      assert list(values.columns) == ["score", "label"]
2631  
2632  
2633  @pytest.mark.parametrize(
2634      "model_name",
2635      [
2636          "tiiuae/falcon-7b",
2637          "openai-community/gpt2",
2638          "PrunaAI/runwayml-stable-diffusion-v1-5-turbo-tiny-green-smashed",
2639      ],
2640  )
2641  def test_save_model_card_with_non_utf_characters(tmp_path, model_name):
2642      # non-ascii unicode characters
2643      test_text = (
2644          "Emoji testing! \u2728 \U0001f600 \U0001f609 \U0001f606 "
2645          "\U0001f970 \U0001f60e \U0001f917 \U0001f9d0"
2646      )
2647  
2648      card_data: ModelCard = huggingface_hub.ModelCard.load(model_name)
2649      card_data.text = card_data.text + "\n\n" + test_text
2650      custom_data = card_data.data.to_dict()
2651      custom_data["emojis"] = test_text
2652  
2653      card_data.data = huggingface_hub.CardData(**custom_data)
2654      _write_card_data(card_data, tmp_path)
2655  
2656      txt = tmp_path.joinpath(_CARD_TEXT_FILE_NAME).read_text()
2657      assert txt == card_data.text
2658      data = yaml.safe_load(tmp_path.joinpath(_CARD_DATA_FILE_NAME).read_text())
2659      assert data == card_data.data.to_dict()
2660  
2661  
2662  def test_vision_pipeline_pyfunc_predict_with_kwargs(small_vision_model):
2663      artifact_path = "image_classification_model"
2664  
2665      parameters = {
2666          "top_k": 2,
2667      }
2668      inference_payload = json.dumps({
2669          "inputs": [image_url],
2670          "params": parameters,
2671      })
2672  
2673      with mlflow.start_run():
2674          model_info = mlflow.transformers.log_model(
2675              small_vision_model,
2676              name=artifact_path,
2677              signature=infer_signature(
2678                  image_url,
2679                  mlflow.transformers.generate_signature_output(small_vision_model, image_url),
2680                  params=parameters,
2681              ),
2682          )
2683          model_uri = model_info.model_uri
2684      transformers_loaded_model = mlflow.transformers.load_model(model_uri)
2685      expected_predictions = transformers_loaded_model.predict(image_url)
2686  
2687      response = pyfunc_serve_and_score_model(
2688          model_uri,
2689          data=inference_payload,
2690          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
2691          extra_args=["--env-manager", "local"],
2692      )
2693  
2694      predictions = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
2695  
2696      assert (
2697          list(predictions.to_dict("records")[0].values())
2698          == expected_predictions[: parameters["top_k"]]
2699      )
2700  
2701  
2702  def test_qa_pipeline_pyfunc_predict_with_kwargs(small_qa_pipeline):
2703      artifact_path = "qa_model"
2704      data = {
2705          "question": [
2706              "What color is it?",
2707              "What does the 'wolf' howl at?",
2708          ],
2709          "context": [
2710              "Some people said it was green but I know that it's pink.",
2711              "The pack of 'wolves' stood on the cliff and a 'lone wolf' howled at "
2712              "the moon for hours.",
2713          ],
2714      }
2715      parameters = {
2716          "top_k": 2,
2717          "max_answer_len": 5,
2718      }
2719      inference_payload = json.dumps({
2720          "inputs": data,
2721          "params": parameters,
2722      })
2723      output = mlflow.transformers.generate_signature_output(small_qa_pipeline, data)
2724      signature_with_params = infer_signature(
2725          data,
2726          output,
2727          parameters,
2728      )
2729      expected_signature = ModelSignature(
2730          Schema([
2731              ColSpec(Array(DataType.string), name="question"),
2732              ColSpec(Array(DataType.string), name="context"),
2733          ]),
2734          Schema([ColSpec(DataType.string)]),
2735          ParamSchema([
2736              ParamSpec("top_k", DataType.long, 2),
2737              ParamSpec("max_answer_len", DataType.long, 5),
2738          ]),
2739      )
2740      assert signature_with_params == expected_signature
2741  
2742      with mlflow.start_run():
2743          model_info = mlflow.transformers.log_model(
2744              small_qa_pipeline,
2745              name=artifact_path,
2746              signature=signature_with_params,
2747          )
2748  
2749      response = pyfunc_serve_and_score_model(
2750          model_info.model_uri,
2751          data=inference_payload,
2752          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
2753          extra_args=["--env-manager", "local"],
2754      )
2755      values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
2756  
2757      assert values.to_dict(orient="records") == [
2758          {0: "pink"},
2759          {0: "pink."},
2760          {0: "the moon"},
2761          {0: "moon"},
2762      ]
2763  
2764  
2765  def test_uri_directory_renaming_handling_pipeline(model_path, text_classification_pipeline):
2766      with mlflow.start_run():
2767          mlflow.transformers.save_model(
2768              transformers_model=text_classification_pipeline, path=model_path
2769          )
2770  
2771      absolute_model_directory = os.path.join(model_path, "model")
2772      renamed_to_old_convention = os.path.join(model_path, "pipeline")
2773      os.rename(absolute_model_directory, renamed_to_old_convention)
2774  
2775      # remove the 'model_binary' entries to emulate older versions of MLflow
2776      mlmodel_file = os.path.join(model_path, "MLmodel")
2777      with open(mlmodel_file) as yaml_file:
2778          mlmodel = yaml.safe_load(yaml_file)
2779  
2780      mlmodel["flavors"]["python_function"].pop("model_binary", None)
2781      mlmodel["flavors"]["transformers"].pop("model_binary", None)
2782  
2783      with open(mlmodel_file, "w") as yaml_file:
2784          yaml.safe_dump(mlmodel, yaml_file)
2785  
2786      loaded_model = mlflow.pyfunc.load_model(model_path)
2787  
2788      prediction = loaded_model.predict("test")
2789      assert isinstance(prediction, pd.DataFrame)
2790      assert isinstance(prediction["label"][0], str)
2791  
2792  
2793  def test_uri_directory_renaming_handling_components(model_path, text_classification_pipeline):
2794      components = {
2795          "tokenizer": text_classification_pipeline.tokenizer,
2796          "model": text_classification_pipeline.model,
2797      }
2798  
2799      with mlflow.start_run():
2800          mlflow.transformers.save_model(transformers_model=components, path=model_path)
2801  
2802      absolute_model_directory = os.path.join(model_path, "model")
2803      renamed_to_old_convention = os.path.join(model_path, "pipeline")
2804      os.rename(absolute_model_directory, renamed_to_old_convention)
2805  
2806      # remove the 'model_binary' entries to emulate older versions of MLflow
2807      mlmodel_file = os.path.join(model_path, "MLmodel")
2808      with open(mlmodel_file) as yaml_file:
2809          mlmodel = yaml.safe_load(yaml_file)
2810  
2811      mlmodel["flavors"]["python_function"].pop("model_binary", None)
2812      mlmodel["flavors"]["transformers"].pop("model_binary", None)
2813  
2814      with open(mlmodel_file, "w") as yaml_file:
2815          yaml.safe_dump(mlmodel, yaml_file)
2816  
2817      loaded_model = mlflow.pyfunc.load_model(model_path)
2818  
2819      prediction = loaded_model.predict("test")
2820      assert isinstance(prediction, pd.DataFrame)
2821      assert isinstance(prediction["label"][0], str)
2822  
2823  
2824  @skip_transformers_v5_or_later
2825  def test_pyfunc_model_log_load_with_artifacts_snapshot():
2826      architecture = "prajjwal1/bert-tiny"
2827      tokenizer = transformers.AutoTokenizer.from_pretrained(architecture)
2828      model = transformers.BertForQuestionAnswering.from_pretrained(architecture)
2829      bert_tiny_pipeline = transformers.pipeline(
2830          task="question-answering", model=model, tokenizer=tokenizer
2831      )
2832  
2833      class QAModel(mlflow.pyfunc.PythonModel):
2834          def load_context(self, context):
2835              """
2836              This method initializes the tokenizer and language model
2837              using the specified snapshot location.
2838              """
2839              snapshot_location = context.artifacts["bert-tiny-model"].removeprefix("hf:/")
2840              # Initialize tokenizer and language model
2841              tokenizer = transformers.AutoTokenizer.from_pretrained(snapshot_location)
2842              model = transformers.BertForQuestionAnswering.from_pretrained(snapshot_location)
2843              self.pipeline = transformers.pipeline(
2844                  task="question-answering", model=model, tokenizer=tokenizer
2845              )
2846  
2847          def predict(self, context, model_input, params=None):
2848              question = model_input["question"][0]
2849              if isinstance(question, np.ndarray):
2850                  question = question.item()
2851              ctx = model_input["context"][0]
2852              if isinstance(ctx, np.ndarray):
2853                  ctx = ctx.item()
2854              return self.pipeline(question=question, context=ctx)
2855  
2856      data = {"question": "Who's house?", "context": "The house is owned by Run."}
2857      pyfunc_artifact_path = "question_answering_model"
2858      with mlflow.start_run():
2859          model_info = mlflow.pyfunc.log_model(
2860              name=pyfunc_artifact_path,
2861              python_model=QAModel(),
2862              artifacts={"bert-tiny-model": "hf:/prajjwal1/bert-tiny"},
2863              input_example=data,
2864              signature=infer_signature(
2865                  data, mlflow.transformers.generate_signature_output(bert_tiny_pipeline, data)
2866              ),
2867              extra_pip_requirements=["transformers", "torch", "numpy"],
2868          )
2869  
2870          pyfunc_model_path = _download_artifact_from_uri(model_info.model_uri)
2871          assert len(os.listdir(os.path.join(pyfunc_model_path, "artifacts"))) != 0
2872          model_config = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
2873  
2874      loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
2875      assert model_config.to_yaml() == loaded_pyfunc_model.metadata.to_yaml()
2876      assert loaded_pyfunc_model.predict(data)["answer"] != ""
2877  
2878      # Test model serving
2879      inference_payload = json.dumps({"inputs": data})
2880      response = pyfunc_serve_and_score_model(
2881          model_info.model_uri,
2882          data=inference_payload,
2883          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
2884          extra_args=["--env-manager", "local"],
2885      )
2886      values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
2887  
2888      assert values.to_dict(orient="records")[0]["answer"] != ""
2889  
2890  
2891  def test_pyfunc_model_log_load_with_artifacts_snapshot_errors():
2892      class TestModel(mlflow.pyfunc.PythonModel):
2893          def predict(self, context, model_input, params=None):
2894              return model_input
2895  
2896      with mlflow.start_run():
2897          with pytest.raises(
2898              MlflowException,
2899              match=r"Failed to download snapshot from Hugging Face Hub "
2900              r"with artifact_uri: hf:/invalid-repo-id.",
2901          ):
2902              mlflow.pyfunc.log_model(
2903                  name="pyfunc_artifact_path",
2904                  python_model=TestModel(),
2905                  artifacts={"some-model": "hf:/invalid-repo-id"},
2906              )
2907  
2908  
2909  def test_model_distributed_across_devices():
2910      mock_model = mock.Mock()
2911      mock_model.device.type = "meta"
2912      mock_model.hf_device_map = {
2913          "layer1": mock.Mock(type="cpu"),
2914          "layer2": mock.Mock(type="cpu"),
2915          "layer3": mock.Mock(type="gpu"),
2916          "layer4": mock.Mock(type="disk"),
2917      }
2918  
2919      assert _is_model_distributed_in_memory(mock_model)
2920  
2921  
2922  def test_model_on_single_device():
2923      mock_model = mock.Mock()
2924      mock_model.device.type = "cpu"
2925      mock_model.hf_device_map = {}
2926  
2927      assert not _is_model_distributed_in_memory(mock_model)
2928  
2929  
2930  @skip_transformers_v5_or_later
2931  def test_basic_model_with_accelerate_device_mapping_fails_save(tmp_path, model_path):
2932      task = "translation_en_to_de"
2933      architecture = "t5-small"
2934      model = transformers.T5ForConditionalGeneration.from_pretrained(
2935          pretrained_model_name_or_path=architecture,
2936          device_map={"shared": "cpu", "encoder": "cpu", "decoder": "disk", "lm_head": "disk"},
2937          offload_folder=str(tmp_path / "weights"),
2938          low_cpu_mem_usage=True,
2939      )
2940  
2941      tokenizer = transformers.T5TokenizerFast.from_pretrained(
2942          pretrained_model_name_or_path=architecture, model_max_length=100
2943      )
2944      pipeline = transformers.pipeline(task=task, model=model, tokenizer=tokenizer)
2945  
2946      with pytest.raises(
2947          MlflowException,
2948          match="The model that is attempting to be saved has been loaded into memory",
2949      ):
2950          mlflow.transformers.save_model(transformers_model=pipeline, path=model_path)
2951  
2952  
2953  @pytest.mark.skipif(
2954      Version(transformers.__version__) > Version("4.44.2"),
2955      reason="Multi-task pipeline (t5) has a loading issue with Transformers 4.45.x. "
2956      "See https://github.com/huggingface/transformers/issues/33398 for more details.",
2957  )
2958  def test_basic_model_with_accelerate_homogeneous_mapping_works(model_path):
2959      task = "translation_en_to_de"
2960      architecture = "t5-small"
2961      model = transformers.T5ForConditionalGeneration.from_pretrained(
2962          pretrained_model_name_or_path=architecture,
2963          device_map={"shared": "cpu", "encoder": "cpu", "decoder": "cpu", "lm_head": "cpu"},
2964          low_cpu_mem_usage=True,
2965      )
2966  
2967      tokenizer = transformers.T5TokenizerFast.from_pretrained(
2968          pretrained_model_name_or_path=architecture, model_max_length=100
2969      )
2970      pipeline = transformers.pipeline(task=task, model=model, tokenizer=tokenizer)
2971  
2972      mlflow.transformers.save_model(transformers_model=pipeline, path=model_path)
2973  
2974      loaded = mlflow.transformers.load_model(model_path)
2975      text = "Apples are delicious"
2976      assert loaded(text) == pipeline(text)
2977  
2978  
2979  def test_qa_model_model_size_bytes(small_qa_pipeline, tmp_path):
2980      def _calculate_expected_size(path_or_dir):
2981          # this helper function does not consider subdirectories
2982          expected_size = 0
2983          if path_or_dir.is_dir():
2984              for path in path_or_dir.iterdir():
2985                  if not path.is_file():
2986                      continue
2987                  expected_size += path.stat().st_size
2988          elif path_or_dir.is_file():
2989              expected_size = path_or_dir.stat().st_size
2990          return expected_size
2991  
2992      mlflow.transformers.save_model(
2993          transformers_model=small_qa_pipeline,
2994          path=tmp_path,
2995      )
2996  
2997      # expected size only counts for files saved before the MLmodel file is saved
2998      model_dir = tmp_path.joinpath("model")
2999      tokenizer_dir = tmp_path.joinpath("components").joinpath("tokenizer")
3000      expected_size = 0
3001      for folder in [model_dir, tokenizer_dir]:
3002          expected_size += _calculate_expected_size(folder)
3003      other_files = ["model_card.md", "model_card_data.yaml", "LICENSE.txt"]
3004      for file in other_files:
3005          path = tmp_path.joinpath(file)
3006          expected_size += _calculate_expected_size(path)
3007  
3008      mlmodel = yaml.safe_load(tmp_path.joinpath("MLmodel").read_bytes())
3009      assert mlmodel["model_size_bytes"] == expected_size
3010  
3011  
3012  @pytest.mark.parametrize(
3013      ("task", "input_example"),
3014      [
3015          ("llm/v1/completions", None),
3016          ("llm/v1/chat", None),
3017          (
3018              "llm/v1/completions",
3019              {
3020                  "prompt": "How to learn Python in 3 weeks?",
3021                  "max_tokens": 10,
3022                  "stop": "Python",
3023              },
3024          ),
3025          (
3026              "llm/v1/chat",
3027              {
3028                  "messages": [
3029                      {"role": "system", "content": "Hello, how are you?"},
3030                  ],
3031                  "temperature": 0.5,
3032                  "max_tokens": 50,
3033              },
3034          ),
3035      ],
3036  )
3037  def test_text_generation_save_model_with_inference_task(
3038      monkeypatch, task, input_example, text_generation_pipeline, model_path
3039  ):
3040      # Strictly raise error during requirements inference for testing purposes
3041      monkeypatch.setenv("MLFLOW_REQUIREMENTS_INFERENCE_RAISE_ERRORS", "true")
3042  
3043      mlflow.transformers.save_model(
3044          transformers_model=text_generation_pipeline,
3045          path=model_path,
3046          task=task,
3047          input_example=input_example,
3048      )
3049  
3050      mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes())
3051      flavor_config = mlmodel["flavors"]["transformers"]
3052      assert flavor_config["inference_task"] == task
3053      assert mlmodel["metadata"]["task"] == task
3054  
3055      if input_example:
3056          saved_input_example = json.loads(model_path.joinpath("input_example.json").read_text())
3057          assert saved_input_example == input_example
3058  
3059  
3060  def test_text_generation_save_model_with_invalid_inference_task(
3061      text_generation_pipeline, model_path
3062  ):
3063      with pytest.raises(
3064          MlflowException, match=r"The task provided is invalid.*Must be.*llm/v1/completions"
3065      ):
3066          mlflow.transformers.save_model(
3067              transformers_model=text_generation_pipeline,
3068              path=model_path,
3069              task="llm/v1/invalid",
3070          )
3071  
3072  
3073  def test_text_generation_log_model_with_mismatched_task(text_generation_pipeline):
3074      with pytest.raises(
3075          MlflowException, match=r"LLM v1 task type 'llm/v1/chat' is specified in metadata, but"
3076      ):
3077          with mlflow.start_run():
3078              mlflow.transformers.log_model(
3079                  text_generation_pipeline,
3080                  name="model",
3081                  # Task argument and metadata task are different
3082                  task=None,
3083                  metadata={"task": "llm/v1/chat"},
3084              )
3085  
3086  
3087  def test_text_generation_task_completions_predict_with_max_tokens(
3088      text_generation_pipeline, model_path
3089  ):
3090      mlflow.transformers.save_model(
3091          transformers_model=text_generation_pipeline,
3092          path=model_path,
3093          task="llm/v1/completions",
3094          model_config={"max_tokens": 10},
3095      )
3096  
3097      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
3098  
3099      inference = pyfunc_loaded.predict(
3100          {"prompt": "How to learn Python in 3 weeks?", "max_tokens": 10},
3101      )
3102  
3103      assert isinstance(inference[0], dict)
3104      assert inference[0]["model"] == "distilgpt2"
3105      assert inference[0]["object"] == "text_completion"
3106      assert (
3107          inference[0]["choices"][0]["finish_reason"] == "length"
3108          and inference[0]["usage"]["completion_tokens"] == 10
3109      ) or (
3110          inference[0]["choices"][0]["finish_reason"] == "stop"
3111          and inference[0]["usage"]["completion_tokens"] < 10
3112      )
3113  
3114      # Override model_config with runtime params
3115      inference = pyfunc_loaded.predict(
3116          {"prompt": "How to learn Python in 3 weeks?", "max_tokens": 5},
3117      )
3118      assert 6 > inference[0]["usage"]["completion_tokens"] > 0
3119  
3120  
3121  def test_text_generation_task_completions_predict_with_stop(text_generation_pipeline, model_path):
3122      mlflow.transformers.save_model(
3123          transformers_model=text_generation_pipeline,
3124          path=model_path,
3125          task="llm/v1/completions",
3126          metadata={"foo": "bar"},
3127          model_config={"stop": ["Python"], "max_tokens": 50},
3128      )
3129  
3130      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
3131      inference = pyfunc_loaded.predict(
3132          {"prompt": "How to learn Python in 3 weeks?"},
3133      )
3134  
3135      if "Python" not in inference[0]["choices"][0]["text"]:
3136          pytest.skip(
3137              "Model did not generate text containing 'Python', "
3138              "skipping validation of stop parameter in inference"
3139          )
3140  
3141      assert (
3142          inference[0]["choices"][0]["finish_reason"] == "stop"
3143          and inference[0]["usage"]["completion_tokens"] < 50
3144      ) or (
3145          inference[0]["choices"][0]["finish_reason"] == "length"
3146          and inference[0]["usage"]["completion_tokens"] == 50
3147      )
3148  
3149      assert inference[0]["choices"][0]["text"].endswith("Python")
3150  
3151      # Override model_config with runtime params
3152      inference = pyfunc_loaded.predict(
3153          {"prompt": "How to learn Python in 3 weeks?", "stop": ["Abracadabra"]},
3154      )
3155  
3156      response_text = inference[0]["choices"][0]["text"]
3157  
3158      # Only check for early stopping if we stop on the word "Python".
3159      # If we exhaust the token limit, there is a non-insignificant chance of
3160      # terminating on the word due to max tokens, which should not count as
3161      # a stop word abort if there are multiple instances of the word in the text.
3162      if 0 < response_text.count("Python") < 2:
3163          assert not inference[0]["choices"][0]["text"].endswith("Python")
3164  
3165  
3166  def test_text_generation_task_completions_serve(text_generation_pipeline):
3167      data = {"prompt": "How to learn Python in 3 weeks?"}
3168  
3169      with mlflow.start_run():
3170          model_info = mlflow.transformers.log_model(
3171              text_generation_pipeline,
3172              name="model",
3173              task="llm/v1/completions",
3174          )
3175  
3176      inference_payload = json.dumps({"inputs": data})
3177  
3178      response = pyfunc_serve_and_score_model(
3179          model_info.model_uri,
3180          data=inference_payload,
3181          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
3182          extra_args=["--env-manager", "local"],
3183      )
3184      values = PredictionsResponse.from_json(response.content.decode("utf-8")).get_predictions()
3185      output_dict = values.to_dict("records")[0]
3186      assert output_dict["choices"][0]["text"] is not None
3187      assert output_dict["choices"][0]["finish_reason"] == "stop"
3188      assert output_dict["usage"]["prompt_tokens"] < 20
3189  
3190  
3191  def test_llm_v1_task_embeddings_predict(feature_extraction_pipeline, model_path):
3192      mlflow.transformers.save_model(
3193          transformers_model=feature_extraction_pipeline,
3194          path=model_path,
3195          input_examples=["Football", "Soccer"],
3196          task="llm/v1/embeddings",
3197      )
3198  
3199      mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes())
3200  
3201      flavor_config = mlmodel["flavors"]["transformers"]
3202      assert flavor_config["inference_task"] == "llm/v1/embeddings"
3203      assert mlmodel["metadata"]["task"] == "llm/v1/embeddings"
3204  
3205      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
3206  
3207      # Predict with single string input
3208      prediction = pyfunc_loaded.predict({"input": "A great day"})
3209      assert prediction["object"] == "list"
3210      assert len(prediction["data"]) == 1
3211      assert prediction["data"][0]["object"] == "embedding"
3212      assert prediction["usage"]["prompt_tokens"] == 5
3213      assert len(prediction["data"][0]["embedding"]) == 384
3214  
3215      # Predict with list of string input
3216      prediction = pyfunc_loaded.predict({"input": ["A great day", "A bad day"]})
3217      assert prediction["object"] == "list"
3218      assert len(prediction["data"]) == 2
3219      assert prediction["data"][0]["object"] == "embedding"
3220      assert prediction["usage"]["prompt_tokens"] == 10
3221      assert len(prediction["data"][0]["embedding"]) == 384
3222  
3223      # Predict with pandas dataframe input
3224      df = pd.DataFrame({"input": ["A great day", "A bad day", "A good day"]})
3225      prediction = pyfunc_loaded.predict(df)
3226      assert prediction["object"] == "list"
3227      assert len(prediction["data"]) == 3
3228      assert prediction["data"][0]["object"] == "embedding"
3229      assert prediction["usage"]["prompt_tokens"] == 15
3230      assert len(prediction["data"][0]["embedding"]) == 384
3231  
3232  
3233  @pytest.mark.parametrize(
3234      "request_payload",
3235      [
3236          {"input": "A single string"},
3237          {
3238              "inputs": {"input": ["A list of strings"]},
3239          },
3240      ],
3241  )
3242  def test_llm_v1_task_embeddings_serve(feature_extraction_pipeline, request_payload):
3243      with mlflow.start_run():
3244          model_info = mlflow.transformers.log_model(
3245              feature_extraction_pipeline,
3246              name="model",
3247              input_examples=["Football", "Soccer"],
3248              task="llm/v1/embeddings",
3249          )
3250  
3251      response = pyfunc_serve_and_score_model(
3252          model_info.model_uri,
3253          data=json.dumps(request_payload),
3254          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
3255          extra_args=["--env-manager", "local"],
3256      )
3257      response = json.loads(response.content.decode("utf-8"))
3258      prediction = response["predictions"] if "inputs" in request_payload else response
3259  
3260      assert prediction["object"] == "list"
3261      assert len(prediction["data"]) == 1
3262      assert prediction["data"][0]["object"] == "embedding"
3263      assert len(prediction["data"][0]["embedding"]) == 384
3264  
3265  
3266  def test_get_task_for_model():
3267      with mock.patch("transformers.pipelines.get_task") as mock_get_task:
3268          mock_get_task.return_value = "feature-extraction"
3269          assert _get_task_for_model("model") == "feature-extraction"
3270  
3271          # Some model task is not supported by Transformers pipeline yet. Then fall back
3272          # to the default task if provided, otherwise raise an exception.
3273          mock_get_task.return_value = "unsupported-task"
3274          assert (
3275              _get_task_for_model("model", default_task="feature-extraction") == "feature-extraction"
3276          )
3277  
3278          with pytest.raises(MlflowException, match="Cannot construct transformers pipeline"):
3279              _get_task_for_model("model")
3280  
3281          # If get_task raises an exception, fall back to the default task if provided.
3282          mock_get_task.side_effect = RuntimeError("Some error")
3283          assert (
3284              _get_task_for_model("model", default_task="feature-extraction") == "feature-extraction"
3285          )
3286  
3287          with pytest.raises(MlflowException, match="The task could not be inferred"):
3288              _get_task_for_model("model")
3289  
3290  
3291  @skip_transformers_v5_or_later
3292  def test_local_custom_model_save_and_load(text_generation_pipeline, model_path, tmp_path):
3293      local_repo_path = tmp_path / "local_repo"
3294      text_generation_pipeline.save_pretrained(local_repo_path)
3295  
3296      locally_loaded_model = transformers.AutoModelForCausalLM.from_pretrained(local_repo_path)
3297      tokenizer = transformers.AutoTokenizer.from_pretrained(
3298          local_repo_path, chat_template=CHAT_TEMPLATE
3299      )
3300      model_dict = {"model": locally_loaded_model, "tokenizer": tokenizer}
3301  
3302      # 1. Save local custom model without specifying task -> raises MlflowException
3303      with pytest.raises(MlflowException, match=r"The task could not be inferred"):
3304          mlflow.transformers.save_model(transformers_model=model_dict, path=model_path)
3305  
3306      # 2. Save local custom model with task -> saves successfully
3307      mlflow.transformers.save_model(
3308          transformers_model=model_dict,
3309          path=model_path,
3310          task="text-generation",
3311      )
3312  
3313      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
3314  
3315      inference = pyfunc_loaded.predict("How to save Transformer model?")
3316      assert isinstance(inference[0], str)
3317      assert inference[0].startswith("How to save Transformer model?")
3318  
3319      # 3. Save local custom model with LLM v1 chat inference task -> saves successfully
3320      #    with the corresponding Transformers task
3321      shutil.rmtree(model_path)
3322  
3323      mlflow.transformers.save_model(
3324          transformers_model=model_dict,
3325          path=model_path,
3326          task="llm/v1/chat",
3327      )
3328  
3329      mlmodel = yaml.safe_load(model_path.joinpath("MLmodel").read_bytes())
3330      flavor_config = mlmodel["flavors"]["transformers"]
3331      assert flavor_config["task"] == "text-generation"
3332      assert flavor_config["inference_task"] == "llm/v1/chat"
3333      assert mlmodel["metadata"]["task"] == "llm/v1/chat"
3334  
3335      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
3336  
3337      inference = pyfunc_loaded.predict({
3338          "messages": [
3339              {
3340                  "role": "user",
3341                  "content": "How to save Transformer model?",
3342              }
3343          ]
3344      })
3345      assert isinstance(inference[0], dict)
3346      assert inference[0]["choices"][0]["message"]["role"] == "assistant"
3347  
3348  
3349  def test_model_config_is_not_mutated_after_prediction(text2text_generation_pipeline):
3350      model_config = {
3351          "top_k": 2,
3352          "num_beams": 5,
3353          "max_length": 30,
3354          "max_new_tokens": 500,
3355      }
3356  
3357      # Params will be used to override the values of model_config but should not mutate it
3358      params = {
3359          "top_k": 30,
3360          "max_length": 500,
3361          "max_new_tokens": 5,
3362      }
3363  
3364      pyfunc_model = _TransformersWrapper(text2text_generation_pipeline, model_config=model_config)
3365      assert pyfunc_model.model_config["top_k"] == 2
3366  
3367      prediction_output = pyfunc_model.predict(
3368          "rocket moon ship astronaut space gravity", params=params
3369      )
3370  
3371      assert pyfunc_model.model_config["top_k"] == 2
3372      assert pyfunc_model.model_config["num_beams"] == 5
3373      assert pyfunc_model.model_config["max_length"] == 30
3374      assert pyfunc_model.model_config["max_new_tokens"] == 500
3375      assert len(prediction_output[0].split(" ")) <= 5
3376  
3377  
3378  def test_text_generation_task_chat_predict(text_generation_pipeline, model_path):
3379      mlflow.transformers.save_model(
3380          transformers_model=text_generation_pipeline,
3381          path=model_path,
3382          task="llm/v1/chat",
3383      )
3384  
3385      pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
3386  
3387      inference = pyfunc_loaded.predict({
3388          "messages": [
3389              {"role": "system", "content": "Hello, how can I help you today?"},
3390              {"role": "user", "content": "How to learn Python in 3 weeks?"},
3391          ],
3392          "max_tokens": 10,
3393      })
3394  
3395      assert inference[0]["choices"][0]["message"]["role"] == "assistant"
3396      assert (
3397          inference[0]["choices"][0]["finish_reason"] == "length"
3398          and inference[0]["usage"]["completion_tokens"] == 10
3399      ) or (
3400          inference[0]["choices"][0]["finish_reason"] == "stop"
3401          and inference[0]["usage"]["completion_tokens"] < 10
3402      )
3403  
3404  
3405  def test_text_generation_task_chat_serve(text_generation_pipeline):
3406      data = {
3407          "messages": [
3408              {"role": "user", "content": "How to learn Python in 3 weeks?"},
3409          ],
3410          "max_tokens": 10,
3411      }
3412  
3413      with mlflow.start_run():
3414          model_info = mlflow.transformers.log_model(
3415              text_generation_pipeline,
3416              name="model",
3417              task="llm/v1/chat",
3418          )
3419  
3420      inference_payload = json.dumps(data)
3421  
3422      response = pyfunc_serve_and_score_model(
3423          model_info.model_uri,
3424          data=inference_payload,
3425          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
3426          extra_args=["--env-manager", "local"],
3427      )
3428  
3429      output_dict = json.loads(response.content)[0]
3430      assert output_dict["choices"][0]["message"] is not None
3431      assert (
3432          output_dict["choices"][0]["finish_reason"] == "length"
3433          and output_dict["usage"]["completion_tokens"] == 10
3434      ) or (
3435          output_dict["choices"][0]["finish_reason"] == "stop"
3436          and output_dict["usage"]["completion_tokens"] < 10
3437      )
3438      assert output_dict["usage"]["prompt_tokens"] < 20
3439  
3440  
3441  HF_COMMIT_HASH_PATTERN = re.compile(r"^[a-z0-9]{40}$")
3442  
3443  
3444  @pytest.mark.parametrize(
3445      ("model_fixture", "input_example", "components"),
3446      [
3447          ("text2text_generation_pipeline", "What is MLflow?", {"tokenizer"}),
3448          ("text_generation_pipeline", "What is MLflow?", {"tokenizer"}),
3449          (
3450              "small_vision_model",
3451              image_url,
3452              {"image_processor"} if IS_NEW_FEATURE_EXTRACTION_API else {"feature_extractor"},
3453          ),
3454          (
3455              "component_multi_modal",
3456              {"text": "What is MLflow?", "image": image_url},
3457              {"image_processor", "tokenizer"}
3458              if IS_NEW_FEATURE_EXTRACTION_API
3459              else {"feature_extractor", "tokenizer"},
3460          ),
3461          ("fill_mask_pipeline", "The quick brown <mask> jumps over the lazy dog.", {"tokenizer"}),
3462          ("whisper_pipeline", lambda: read_audio_data("bytes"), {"feature_extractor", "tokenizer"}),
3463          ("feature_extraction_pipeline", "What is MLflow?", {"tokenizer"}),
3464      ],
3465  )
3466  def test_save_and_load_pipeline_without_save_pretrained_false(
3467      model_fixture, input_example, components, model_path, request
3468  ):
3469      pipeline = request.getfixturevalue(model_fixture)
3470      model = pipeline["model"] if isinstance(pipeline, dict) else pipeline.model
3471  
3472      mlflow.transformers.save_model(
3473          transformers_model=pipeline,
3474          path=model_path,
3475          save_pretrained=False,
3476      )
3477  
3478      # No weights should be saved
3479      assert not model_path.joinpath("model").exists()
3480      assert not model_path.joinpath("components").exists()
3481  
3482      # Validate the contents of MLModel file
3483      mlmodel = Model.load(str(model_path.joinpath("MLmodel")))
3484      flavor_conf = mlmodel.flavors["transformers"]
3485      assert "model_binary" not in flavor_conf
3486      assert flavor_conf["source_model_name"] == model.name_or_path
3487      assert HF_COMMIT_HASH_PATTERN.match(flavor_conf["source_model_revision"])
3488      assert set(flavor_conf["components"]) == components
3489      for c in components:
3490          component = pipeline[c] if isinstance(pipeline, dict) else getattr(pipeline, c)
3491          assert flavor_conf[f"{c}_name"] == getattr(component, "name_or_path", model.name_or_path)
3492          assert HF_COMMIT_HASH_PATTERN.match(flavor_conf[f"{c}_revision"])
3493  
3494      # Validate pyfunc load and prediction (if pyfunc supported)
3495      if "python_function" in mlmodel.flavors:
3496          if callable(input_example):
3497              input_example = input_example()
3498          mlflow.pyfunc.load_model(model_path).predict(input_example)
3499  
3500  
3501  # Patch tempdir just to verify the invocation
3502  def test_persist_pretrained_model(small_qa_pipeline):
3503      with mlflow.start_run():
3504          model_info = mlflow.transformers.log_model(
3505              small_qa_pipeline,
3506              name="model",
3507              save_pretrained=False,
3508              pip_requirements=["mlflow"],  # For speed up logging
3509          )
3510  
3511      artifact_path = Path(mlflow.artifacts.download_artifacts(model_info.model_uri))
3512      model_path = artifact_path / "model"
3513      tokenizer_path = artifact_path / "components" / "tokenizer"
3514  
3515      original_config = Model.load(artifact_path).flavors["transformers"]
3516      assert "model_binary" not in original_config
3517      assert "source_model_revision" in original_config
3518      assert not model_path.exists()
3519      assert not tokenizer_path.exists()
3520  
3521      with mock.patch(
3522          "mlflow.transformers.TempDir", side_effect=mlflow.utils.file_utils.TempDir
3523      ) as mock_tmpdir:
3524          mlflow.transformers.persist_pretrained_model(model_info.model_uri)
3525          mock_tmpdir.assert_called_once()
3526  
3527      updated_config = Model.load(model_info.model_uri).flavors["transformers"]
3528      assert "model_binary" in updated_config
3529      assert "source_model_revision" not in updated_config
3530      assert model_path.exists()
3531      model_path_files = list(model_path.iterdir())
3532      assert len(model_path_files) > 0
3533      assert tokenizer_path.exists()
3534      assert (tokenizer_path / "tokenizer.json").exists()
3535  
3536      # Repeat persisting the model will no-op
3537      with mock.patch(
3538          "mlflow.transformers.TempDir", side_effect=mlflow.utils.file_utils.TempDir
3539      ) as mock_tmpdir:
3540          mlflow.transformers.persist_pretrained_model(model_info.model_uri)
3541          mock_tmpdir.assert_not_called()
3542  
3543  
3544  def test_small_qa_pipeline_copy_metadata_in_databricks(
3545      mock_is_in_databricks, small_qa_pipeline, tmp_path
3546  ):
3547      artifact_path = "transformers"
3548      with mlflow.start_run():
3549          model_info = mlflow.transformers.log_model(
3550              small_qa_pipeline,
3551              name=artifact_path,
3552          )
3553      artifact_path = mlflow.artifacts.download_artifacts(
3554          artifact_uri=model_info.model_uri, dst_path=tmp_path.as_posix()
3555      )
3556  
3557      # Metadata should be copied only in Databricks
3558      metadata_path = os.path.join(artifact_path, "metadata")
3559      if mock_is_in_databricks.return_value:
3560          assert set(os.listdir(metadata_path)) == set(METADATA_FILES)
3561      else:
3562          assert not os.path.exists(metadata_path)
3563      mock_is_in_databricks.assert_called_once()
3564  
3565  
3566  def test_peft_pipeline_copy_metadata_in_databricks(mock_is_in_databricks, peft_pipeline, tmp_path):
3567      artifact_path = "transformers"
3568      with mlflow.start_run():
3569          model_info = mlflow.transformers.log_model(
3570              peft_pipeline,
3571              name=artifact_path,
3572          )
3573  
3574      artifact_path = mlflow.artifacts.download_artifacts(
3575          artifact_uri=model_info.model_uri, dst_path=tmp_path.as_posix()
3576      )
3577  
3578      # Metadata should be copied only in Databricks
3579      metadata_path = os.path.join(artifact_path, "metadata")
3580      if mock_is_in_databricks.return_value:
3581          assert set(os.listdir(metadata_path)) == set(METADATA_FILES)
3582      else:
3583          assert not os.path.exists(metadata_path)
3584      mock_is_in_databricks.assert_called_once()
3585  
3586  
3587  @pytest.mark.parametrize("device", ["cpu", "cuda", 0, -1, None])
3588  def test_device_param_on_load_model(device, small_qa_pipeline, model_path, monkeypatch):
3589      mlflow.transformers.save_model(small_qa_pipeline, path=model_path)
3590      conf = mlflow.transformers.load_model(model_path, return_type="components", device=device)
3591      assert conf.get("device") == device
3592  
3593      monkeypatch.setenv("MLFLOW_HUGGINGFACE_USE_DEVICE_MAP", "true")
3594      if device is None:
3595          conf = mlflow.transformers.load_model(model_path, return_type="components", device=device)
3596          assert conf.get("device") is None
3597      else:
3598          with pytest.raises(
3599              MlflowException,
3600              match=r"The environment variable MLFLOW_HUGGINGFACE_USE_DEVICE_MAP is set to True, "
3601              rf"but the `device` argument is provided with value {device}.",
3602          ):
3603              mlflow.transformers.load_model(model_path, return_type="components", device=device)
3604  
3605  
3606  @pytest.fixture
3607  def local_checkpoint_path(tmp_path):
3608      """
3609      Fixture to create a local model checkpoint for testing fine-tuning scenario.
3610      """
3611      model = transformers.AutoModelForCausalLM.from_pretrained("distilgpt2")
3612  
3613      class DummyDataset(torch.utils.data.Dataset):
3614          def __getitem__(self, idx):
3615              pass
3616  
3617          def __len__(self):
3618              return 1
3619  
3620      # Create a trainer and save model, but not running the actual training
3621      training_args = transformers.TrainingArguments(
3622          output_dir=tmp_path / "result",
3623          num_train_epochs=1,
3624          per_device_train_batch_size=4,
3625          report_to="none",
3626      )
3627      trainer = transformers.Trainer(model=model, args=training_args, train_dataset=DummyDataset())
3628  
3629      checkpoint_path = tmp_path / "checkpoint"
3630      trainer.save_model(checkpoint_path)
3631  
3632      # The tokenizer should also be saved in the checkpoint
3633      tokenizer = transformers.AutoTokenizer.from_pretrained(
3634          # Chat template is required to test with llm/v1/chat task
3635          "distilgpt2",
3636          chat_template=CHAT_TEMPLATE,
3637      )
3638      tokenizer.save_pretrained(checkpoint_path)
3639  
3640      return str(checkpoint_path)
3641  
3642  
3643  def test_save_model_from_local_checkpoint(model_path, local_checkpoint_path):
3644      with mock.patch("mlflow.transformers._logger") as mock_logger:
3645          mlflow.transformers.save_model(
3646              transformers_model=local_checkpoint_path,
3647              task="text-generation",
3648              path=model_path,
3649              input_example=["What is MLflow?"],
3650          )
3651  
3652      logged_info = Model.load(model_path)
3653      flavor_conf = logged_info.flavors["transformers"]
3654      assert flavor_conf["source_model_name"] == local_checkpoint_path
3655      assert flavor_conf["task"] == "text-generation"
3656      if not IS_TRANSFORMERS_V5_OR_LATER:
3657          assert flavor_conf["framework"] == "pt"
3658      assert flavor_conf["instance_type"] == "TextGenerationPipeline"
3659      expected_tokenizer_type = (
3660          "GPT2Tokenizer" if IS_TRANSFORMERS_V5_OR_LATER else "GPT2TokenizerFast"
3661      )
3662      assert flavor_conf["tokenizer_type"] == expected_tokenizer_type
3663  
3664      # Default task signature should be used
3665      assert logged_info.signature.inputs == Schema([ColSpec(DataType.string)])
3666      assert logged_info.signature.outputs == Schema([ColSpec(DataType.string)])
3667  
3668      # Default requirements should be used
3669      info_calls = mock_logger.info.call_args_list
3670      assert any("A local checkpoint path or PEFT model" in c[0][0] for c in info_calls)
3671      with model_path.joinpath("requirements.txt").open() as f:
3672          reqs = {req.split("==")[0] for req in f.read().split("\n")}
3673      assert reqs == {"mlflow", "accelerate", "transformers", "torch", "torchvision"}
3674  
3675      # Load as native pipeline
3676      loaded_pipeline = mlflow.transformers.load_model(model_path)
3677      assert isinstance(loaded_pipeline, transformers.TextGenerationPipeline)
3678  
3679      query = "What is MLflow?"
3680      pred_native = loaded_pipeline(query)[0]
3681      assert pred_native["generated_text"].startswith(query)
3682  
3683      # Load as pyfunc
3684      loaded_pyfunc = mlflow.pyfunc.load_model(model_path)
3685      pred_pyfunc = loaded_pyfunc.predict(query)[0]
3686      assert pred_pyfunc.startswith(query)
3687  
3688      # Serve
3689      response = pyfunc_serve_and_score_model(
3690          model_path,
3691          data=json.dumps({"inputs": [query]}),
3692          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
3693          extra_args=["--env-manager", "local"],
3694      )
3695      pred_serve = json.loads(response.content.decode("utf-8"))
3696      assert pred_serve["predictions"][0].startswith(query)
3697  
3698  
3699  @skip_transformers_v5_or_later
3700  def test_save_model_from_local_checkpoint_with_custom_tokenizer(model_path, local_checkpoint_path):
3701      # When a custom tokenizer is also saved in the checkpoint, MLflow should save and load it.
3702      tokenizer = transformers.AutoTokenizer.from_pretrained("distilroberta-base")
3703      tokenizer.add_special_tokens({"additional_special_tokens": ["<sushi>"]})
3704      tokenizer.save_pretrained(local_checkpoint_path)
3705  
3706      mlflow.transformers.save_model(
3707          transformers_model=local_checkpoint_path,
3708          path=model_path,
3709          task="text-generation",
3710          input_example=["What is MLflow?"],
3711      )
3712  
3713      # The custom tokenizer should be loaded
3714      loaded_pipeline = mlflow.transformers.load_model(model_path)
3715      tokenizer = loaded_pipeline.tokenizer
3716      assert tokenizer.special_tokens_map["additional_special_tokens"] == ["<sushi>"]
3717  
3718  
3719  def test_save_model_from_local_checkpoint_with_llm_inference_task(
3720      model_path, local_checkpoint_path
3721  ):
3722      mlflow.transformers.save_model(
3723          transformers_model=local_checkpoint_path,
3724          path=model_path,
3725          task="llm/v1/chat",
3726          input_example=["What is MLflow?"],
3727      )
3728  
3729      logged_info = Model.load(model_path)
3730      flavor_conf = logged_info.flavors["transformers"]
3731      assert flavor_conf["source_model_name"] == local_checkpoint_path
3732      assert flavor_conf["task"] == "text-generation"
3733      assert flavor_conf["inference_task"] == "llm/v1/chat"
3734  
3735      # Load as pyfunc
3736      loaded_pyfunc = mlflow.pyfunc.load_model(model_path)
3737      response = loaded_pyfunc.predict({
3738          "messages": [
3739              {"role": "system", "content": "Hello, how can I help you today?"},
3740              {"role": "user", "content": "What is MLflow?"},
3741          ],
3742      })
3743      assert response[0]["choices"][0]["message"]["role"] == "assistant"
3744      assert response[0]["choices"][0]["message"]["content"] is not None
3745  
3746  
3747  def test_save_model_from_local_checkpoint_invalid_arguments(model_path, local_checkpoint_path):
3748      with pytest.raises(MlflowException, match=r"The `task` argument must be specified"):
3749          mlflow.transformers.save_model(
3750              transformers_model=local_checkpoint_path,
3751              path=model_path,
3752          )
3753  
3754      with pytest.raises(
3755          MlflowException, match=r"The `save_pretrained` argument must be set to True"
3756      ):
3757          mlflow.transformers.save_model(
3758              transformers_model=local_checkpoint_path,
3759              path=model_path,
3760              task="fill-mask",
3761              save_pretrained=False,
3762          )
3763  
3764      with pytest.raises(
3765          MlflowException,
3766          match=r"The provided directory invalid path does not contain a config.json file.",
3767      ):
3768          mlflow.transformers.save_model(
3769              transformers_model="invalid path",
3770              path=model_path,
3771              task="fill-mask",
3772          )
3773  
3774  
3775  @pytest.mark.parametrize(
3776      ("model_fixture", "should_skip_validation"),
3777      [
3778          ("local_checkpoint_path", True),
3779          ("fill_mask_pipeline", False),
3780      ],
3781  )
3782  def test_log_model_skip_validating_serving_input_for_local_checkpoint(
3783      model_fixture,
3784      should_skip_validation,
3785      tmp_path,
3786      request,
3787  ):
3788      # input to avoid expensive computation
3789      model = request.getfixturevalue(model_fixture)
3790      with mock.patch("mlflow.models.validate_serving_input") as mock_validate_input:
3791          # Ensure mlflow skips serving input validation for local checkpoint
3792          with mlflow.start_run():
3793              model_info = mlflow.transformers.log_model(
3794                  model,
3795                  name="model",
3796                  task="fill-mask",
3797                  input_example=["How are you?"],
3798              )
3799  
3800      # Serving input should exist regardless of the skip validation
3801      mlflow_model = Model.load(model_info.model_uri)
3802      local_path = _download_artifact_from_uri(model_info.model_uri, output_path=tmp_path)
3803      serving_input = mlflow_model.get_serving_input(local_path)
3804      assert json.loads(serving_input) == {"inputs": ["How are you?"]}
3805  
3806      if should_skip_validation:
3807          mock_validate_input.assert_not_called()
3808      else:
3809          mock_validate_input.assert_called_once()