# test_flavor_configs.py
import pytest

from mlflow.exceptions import MlflowException
from mlflow.transformers import _build_pipeline_from_model_input
from mlflow.transformers.flavor_config import (
    build_flavor_config,
    update_flavor_conf_to_persist_pretrained_model,
)
from mlflow.utils.huggingface_utils import is_valid_hf_repo_id

from tests.transformers.helper import IS_NEW_FEATURE_EXTRACTION_API, IS_TRANSFORMERS_V5_OR_LATER


@pytest.fixture
def multi_modal_pipeline(component_multi_modal):
    """Build a multi-modal (Vilt VQA) pipeline plus the expected component-type map.

    Returns a ``(pipeline, task, processor, components)`` tuple where
    ``components`` maps component name -> expected ``*_type`` value in the
    flavor config, keyed differently depending on the feature-extraction API
    generation of the installed ``transformers`` version.
    """
    # NOTE(fix): the components are for a ViltForQuestionAnswering model
    # ("dandelin/vilt-b32-finetuned-vqa", asserted in the tests below), so the
    # pipeline task must be VQA — the previous "image-classification" task did
    # not match the model these tests exercise.
    task = "visual-question-answering"
    pipeline = _build_pipeline_from_model_input(component_multi_modal, task)

    tokenizer_type = type(pipeline.tokenizer).__name__
    if IS_NEW_FEATURE_EXTRACTION_API:
        # Newer transformers expose the image processor directly.
        processor = pipeline.image_processor
        components = {
            "tokenizer": tokenizer_type,
            "image_processor": "ViltImageProcessor",
            "processor": "ViltImageProcessor",
        }
    else:
        # Older transformers route image preprocessing through feature_extractor.
        processor = pipeline.feature_extractor
        components = {
            "tokenizer": tokenizer_type,
            "feature_extractor": "ViltProcessor",
            "processor": "ViltProcessor",
        }

    return pipeline, task, processor, components


def test_flavor_config_pt_save_pretrained_false(small_qa_pipeline):
    """With save_pretrained=False the config records hub names + commit revisions."""
    expected = {
        "task": "question-answering",
        "instance_type": "QuestionAnsweringPipeline",
        "pipeline_model_type": "MobileBertForQuestionAnswering",
        "source_model_name": "csarron/mobilebert-uncased-squad-v2",
        # "source_model_revision": "SOME_COMMIT_SHA",
        "torch_dtype": "torch.float32",
        "components": ["tokenizer"],
        "tokenizer_type": type(small_qa_pipeline.tokenizer).__name__,
        "tokenizer_name": "csarron/mobilebert-uncased-squad-v2",
        # "tokenizer_revision": "SOME_COMMIT_SHA",
    }
    if not IS_TRANSFORMERS_V5_OR_LATER:
        expected["framework"] = "pt"
    conf = build_flavor_config(small_qa_pipeline, save_pretrained=False)
    # Commit SHAs change whenever the hub repo is updated, so only validate
    # that they look like full 40-char git SHAs rather than pinning a value.
    assert len(conf.pop("source_model_revision")) == 40
    assert len(conf.pop("tokenizer_revision")) == 40
    assert conf == expected


def test_flavor_config_torch_dtype_overridden_when_specified(small_qa_pipeline):
    """An explicit torch_dtype argument overrides the pipeline's native dtype."""
    import torch

    conf = build_flavor_config(small_qa_pipeline, torch_dtype=torch.float16, save_pretrained=False)
    assert conf["torch_dtype"] == "torch.float16"


def test_flavor_config_component_multi_modal(multi_modal_pipeline):
    """With save_pretrained=True (default), no hub name/revision is recorded."""
    pipeline, task, processor, expected_components = multi_modal_pipeline

    # 1. Test with save_pretrained = True
    conf = build_flavor_config(pipeline, processor)

    assert "model_binary" in conf
    assert conf["pipeline_model_type"] == "ViltForQuestionAnswering"
    assert conf["source_model_name"] == "dandelin/vilt-b32-finetuned-vqa"
    assert "source_model_revision" not in conf

    # "processor" is tracked separately from the components list.
    assert set(conf["components"]) == set(expected_components.keys()) - {"processor"}
    for component in expected_components:
        assert conf[f"{component}_type"] == expected_components[component]
        # NOTE(fix): weights are persisted locally, so neither a hub name nor a
        # revision should be recorded per component. The original asserted
        # `_revision` twice; the first check should cover `_name` (mirroring the
        # save_pretrained=False test below).
        assert f"{component}_name" not in conf
        assert f"{component}_revision" not in conf


def test_flavor_config_component_multi_modal_save_pretrained_false(multi_modal_pipeline):
    """With save_pretrained=False every component records its hub name + revision."""
    pipeline, task, processor, expected_components = multi_modal_pipeline

    conf = build_flavor_config(pipeline, processor, save_pretrained=False)

    assert "model_binary" not in conf
    assert conf["pipeline_model_type"] == "ViltForQuestionAnswering"
    assert conf["source_model_name"] == pipeline.model.name_or_path
    assert len(conf["source_model_revision"]) == 40

    assert set(conf["components"]) == set(expected_components.keys()) - {"processor"}

    for component in expected_components:
        assert conf[f"{component}_type"] == expected_components[component]
        assert conf[f"{component}_name"] == pipeline.model.name_or_path
        assert len(conf[f"{component}_revision"]) == 40


def test_is_valid_hf_repo_id(tmp_path):
    """Repo-id validation rejects None, local paths, and malformed ids."""
    assert is_valid_hf_repo_id(None) is False
    assert is_valid_hf_repo_id(str(tmp_path)) is False
    assert is_valid_hf_repo_id("invalid/repo/name") is False
    assert is_valid_hf_repo_id("google-t5/t5-small") is True


# Shared base config for the update_flavor_conf_to_persist_pretrained_model tests.
_COMMON_CONF = {
    "task": "text-classification",
    "instance_type": "TextClassificationPipeline",
    "pipeline_model_type": "TFMobileBertForSequenceClassification",
    "source_model_name": "lordtt13/emo-mobilebert",
    "framework": "tf",
    "components": ["tokenizer"],
    "tokenizer_type": "MobileBertTokenizerFast",
    "transformers_version": "4.37.1",
}
_COMMIT_HASH = "26d8fcb41762ae83cc9fa03005cb63cde06ef340"


def test_update_flavor_conf_to_persist_pretrained_model():
    """Updating a repo-reference config adds model_binary and strips hub metadata."""
    flavor_conf = {
        **_COMMON_CONF,
        "components": ["tokenizer"],
        "source_model_revision": _COMMIT_HASH,
        "tokenizer_name": "lordtt13/emo-mobilebert",
        "tokenizer_revision": _COMMIT_HASH,
    }
    updated_flavor_conf = update_flavor_conf_to_persist_pretrained_model(flavor_conf)

    assert updated_flavor_conf["model_binary"] == "model"
    assert "source_model_revision" not in updated_flavor_conf
    assert "tokenizer_revision" not in updated_flavor_conf
    assert "tokenizer_name" not in updated_flavor_conf


def test_update_flavor_conf_to_persist_pretrained_model_multi_modal():
    """Hub metadata is stripped for every component, including the processor."""
    flavor_conf = {
        **_COMMON_CONF,
        "components": ["tokenizer", "image_processor"],
        "source_model_revision": _COMMIT_HASH,
        # NOTE(fix): keys below previously contained typos ("tokznier_revision",
        # "image-processor_type") and omitted "tokenizer_name", which made the
        # tokenizer assertions in the loop pass vacuously.
        "tokenizer_name": "lordtt13/emo-mobilebert",
        "tokenizer_revision": _COMMIT_HASH,
        "image_processor_type": "ViltImageProcessor",
        "image_processor_name": "dandelin/vilt-b32-finetuned-vqa",
        "image_processor_revision": _COMMIT_HASH,
        "processor_type": "ViltImageProcessor",
        "processor_name": "dandelin/vilt-b32-finetuned-vqa",
        "processor_revision": _COMMIT_HASH,
    }
    updated_flavor_conf = update_flavor_conf_to_persist_pretrained_model(flavor_conf)

    assert updated_flavor_conf["model_binary"] == "model"
    assert "source_model_revision" not in updated_flavor_conf
    for component in ["tokenizer", "image_processor", "processor"]:
        assert f"{component}_revision" not in updated_flavor_conf
        assert f"{component}_name" not in updated_flavor_conf


def test_update_flavor_conf_to_persist_pretrained_model_raise_if_already_persisted():
    """Updating a config that already persists weights is an error."""
    flavor_conf = {
        **_COMMON_CONF,
        "components": ["tokenizer"],
        "model_binary": "model",
    }

    with pytest.raises(MlflowException, match="It appears that the pretrained model weight"):
        update_flavor_conf_to_persist_pretrained_model(flavor_conf)