# tests/transformers/test_flavor_configs.py
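"""Tests for the transformers flavor-config helpers.

Covers build_flavor_config (with and without save_pretrained), torch_dtype
overrides, multi-modal component handling, Hugging Face repo-id validation,
and update_flavor_conf_to_persist_pretrained_model.
"""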
import pytest

from mlflow.exceptions import MlflowException
from mlflow.transformers import _build_pipeline_from_model_input
from mlflow.transformers.flavor_config import (
    build_flavor_config,
    update_flavor_conf_to_persist_pretrained_model,
)
from mlflow.utils.huggingface_utils import is_valid_hf_repo_id

from tests.transformers.helper import IS_NEW_FEATURE_EXTRACTION_API, IS_TRANSFORMERS_V5_OR_LATER


@pytest.fixture
def multi_modal_pipeline(component_multi_modal):
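    """Build a visual-question-answering pipeline from multi-modal components.

    Returns (pipeline, task, processor, expected_components), where
    expected_components maps each component name to the class name the flavor
    config is expected to record. Newer transformers versions expose the image
    processor via `pipeline.image_processor` rather than
    `pipeline.feature_extractor`, hence the branch below.
    """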
    task = "visual-question-answering"
    pipeline = _build_pipeline_from_model_input(component_multi_modal, task)

    tokenizer_type = type(pipeline.tokenizer).__name__
    if IS_NEW_FEATURE_EXTRACTION_API:
        processor = pipeline.image_processor
        components = {
            "tokenizer": tokenizer_type,
            "image_processor": "ViltImageProcessor",
            "processor": "ViltImageProcessor",
        }
    else:
        processor = pipeline.feature_extractor
        components = {
            "tokenizer": tokenizer_type,
            "feature_extractor": "ViltProcessor",
            "processor": "ViltProcessor",
        }

    return pipeline, task, processor, components


def test_flavor_config_pt_save_pretrained_false(small_qa_pipeline):
    expected = {
        "task": "question-answering",
        "instance_type": "QuestionAnsweringPipeline",
        "pipeline_model_type": "MobileBertForQuestionAnswering",
        "source_model_name": "csarron/mobilebert-uncased-squad-v2",
        # "source_model_revision": "SOME_COMMIT_SHA",
        "torch_dtype": "torch.float32",
        "components": ["tokenizer"],
        "tokenizer_type": type(small_qa_pipeline.tokenizer).__name__,
        "tokenizer_name": "csarron/mobilebert-uncased-squad-v2",
        # "tokenizer_revision": "SOME_COMMIT_SHA",
    }
    if not IS_TRANSFORMERS_V5_OR_LATER:
        expected["framework"] = "pt"
    conf = build_flavor_config(small_qa_pipeline, save_pretrained=False)
    assert len(conf.pop("source_model_revision")) == 40
    assert len(conf.pop("tokenizer_revision")) == 40
    assert conf == expected


def test_flavor_config_torch_dtype_overridden_when_specified(small_qa_pipeline):
    import torch

    conf = build_flavor_config(small_qa_pipeline, torch_dtype=torch.float16, save_pretrained=False)
    assert conf["torch_dtype"] == "torch.float16"
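    # Note: the dtype is recorded in its canonical string form
    # (str(torch.float16) == "torch.float16"), not as a torch.dtype object.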


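# The two tests below cover both persistence modes: with the default
# save_pretrained=True the config records a "model_binary" location and no
# Hub coordinates, while with save_pretrained=False it records the source
# repo name plus a 40-character commit revision for the model and for each
# component.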
def test_flavor_config_component_multi_modal(multi_modal_pipeline):
    pipeline, task, processor, expected_components = multi_modal_pipeline

    # save_pretrained defaults to True, so the model weights are persisted
    conf = build_flavor_config(pipeline, processor)

    assert "model_binary" in conf
    assert conf["pipeline_model_type"] == "ViltForQuestionAnswering"
    assert conf["source_model_name"] == "dandelin/vilt-b32-finetuned-vqa"
    assert "source_model_revision" not in conf

    assert set(conf["components"]) == set(expected_components.keys()) - {"processor"}
    for component in expected_components:
        assert conf[f"{component}_type"] == expected_components[component]
        assert f"{component}_name" not in conf
        assert f"{component}_revision" not in conf


def test_flavor_config_component_multi_modal_save_pretrained_false(multi_modal_pipeline):
    pipeline, task, processor, expected_components = multi_modal_pipeline

    conf = build_flavor_config(pipeline, processor, save_pretrained=False)

    assert "model_binary" not in conf
    assert conf["pipeline_model_type"] == "ViltForQuestionAnswering"
    assert conf["source_model_name"] == pipeline.model.name_or_path
    assert len(conf["source_model_revision"]) == 40

    assert set(conf["components"]) == set(expected_components.keys()) - {"processor"}

    for component in expected_components:
        assert conf[f"{component}_type"] == expected_components[component]
        assert conf[f"{component}_name"] == pipeline.model.name_or_path
        assert len(conf[f"{component}_revision"]) == 40


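# is_valid_hf_repo_id should reject None, local filesystem paths, and ids
# with more than one slash, while accepting canonical "<namespace>/<name>"
# Hub ids. A minimal sketch of such a check (hypothetical helper for
# illustration only; the real logic lives in mlflow.utils.huggingface_utils):
#
#     def _looks_like_repo_id(candidate):
#         return (
#             isinstance(candidate, str)
#             and not os.path.isdir(candidate)
#             and candidate.count("/") <= 1
#         )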
def test_is_valid_hf_repo_id(tmp_path):
    assert is_valid_hf_repo_id(None) is False
    assert is_valid_hf_repo_id(str(tmp_path)) is False
    assert is_valid_hf_repo_id("invalid/repo/name") is False
    assert is_valid_hf_repo_id("google-t5/t5-small") is True


_COMMON_CONF = {
    "task": "text-classification",
    "instance_type": "TextClassificationPipeline",
    "pipeline_model_type": "TFMobileBertForSequenceClassification",
    "source_model_name": "lordtt13/emo-mobilebert",
    "framework": "tf",
    "components": ["tokenizer"],
    "tokenizer_type": "MobileBertTokenizerFast",
    "transformers_version": "4.37.1",
}
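# A 40-character hex string standing in for a git commit SHA; the flavor
# config records revisions as full commit hashes, matching the
# len(...) == 40 checks above.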
_COMMIT_HASH = "26d8fcb41762ae83cc9fa03005cb63cde06ef340"


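# update_flavor_conf_to_persist_pretrained_model converts a config that was
# logged with save_pretrained=False into one that points at locally saved
# weights: it adds a "model_binary" entry and drops the Hub coordinates
# (the *_name and *_revision keys), as the assertions below verify.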
def test_update_flavor_conf_to_persist_pretrained_model():
    flavor_conf = {
        **_COMMON_CONF,
        "components": ["tokenizer"],
        "source_model_revision": _COMMIT_HASH,
        "tokenizer_name": "lordtt13/emo-mobilebert",
        "tokenizer_revision": _COMMIT_HASH,
    }
    updated_flavor_conf = update_flavor_conf_to_persist_pretrained_model(flavor_conf)

    assert updated_flavor_conf["model_binary"] == "model"
    assert "source_model_revision" not in updated_flavor_conf
    assert "tokenizer_revision" not in updated_flavor_conf
    assert "tokenizer_name" not in updated_flavor_conf


def test_update_flavor_conf_to_persist_pretrained_model_multi_modal():
    flavor_conf = {
        **_COMMON_CONF,
        "components": ["tokenizer", "image_processor"],
        "source_model_revision": _COMMIT_HASH,
        "tokenizer_revision": _COMMIT_HASH,
        "image_processor_type": "ViltImageProcessor",
        "image_processor_name": "dandelin/vilt-b32-finetuned-vqa",
        "image_processor_revision": _COMMIT_HASH,
        "processor_type": "ViltImageProcessor",
        "processor_name": "dandelin/vilt-b32-finetuned-vqa",
        "processor_revision": _COMMIT_HASH,
    }
    updated_flavor_conf = update_flavor_conf_to_persist_pretrained_model(flavor_conf)

    assert updated_flavor_conf["model_binary"] == "model"
    assert "source_model_revision" not in updated_flavor_conf
    for component in ["tokenizer", "image_processor", "processor"]:
        assert f"{component}_revision" not in updated_flavor_conf
        assert f"{component}_name" not in updated_flavor_conf


def test_update_flavor_conf_to_persist_pretrained_model_raise_if_already_persisted():
    flavor_conf = {
        **_COMMON_CONF,
        "components": ["tokenizer"],
        "model_binary": "model",
    }

    with pytest.raises(MlflowException, match="It appears that the pretrained model weight"):
        update_flavor_conf_to_persist_pretrained_model(flavor_conf)
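

# The guard exercised above prevents double-persisting: a flavor config that
# already contains "model_binary" was logged with its weights saved, so there
# is nothing left to update.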