"""test_transformers_peft_model.py

Tests for the MLflow ``transformers`` flavor's support of PEFT
(parameter-efficient fine-tuning) models: adapter-only persistence,
flavor-config contents, ``base_model_path`` save/load overrides, and
input validation. The whole module is skipped when ``peft`` is absent.

Fixtures ``peft_pipeline`` and ``small_qa_pipeline`` are assumed to be
provided by a shared conftest (not visible in this file) — presumably a
LoRA-wrapped BERT classifier and a plain MobileBERT QA pipeline; verify
against the conftest.
"""

import importlib
import os
import re

import pytest
import transformers

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.models import Model
from mlflow.transformers.flavor_config import FlavorKey
from mlflow.transformers.peft import get_peft_base_model, is_peft_model
from mlflow.utils.logging_utils import suppress_logs

# Skip every test in this module when the optional `peft` package is missing.
SKIP_IF_PEFT_NOT_AVAILABLE = pytest.mark.skipif(
    importlib.util.find_spec("peft") is None,
    reason="PEFT is not installed",
)
pytestmark = SKIP_IF_PEFT_NOT_AVAILABLE


def test_is_peft_model(peft_pipeline, small_qa_pipeline):
    """is_peft_model distinguishes PEFT-wrapped models from plain HF models."""
    assert is_peft_model(peft_pipeline.model)
    assert not is_peft_model(small_qa_pipeline.model)


def test_get_peft_base_model(peft_pipeline):
    """get_peft_base_model unwraps an adapter down to the original base model."""
    base_model = get_peft_base_model(peft_pipeline.model)
    assert base_model.__class__.__name__ == "BertForSequenceClassification"
    assert base_model.name_or_path == "Elron/bleurt-tiny-512"


def test_get_peft_base_model_prompt_learning(small_qa_pipeline):
    """Unwrapping also works for prompt-learning adapters (no LoRA layers)."""
    from peft import PeftModel, PromptTuningConfig, TaskType

    peft_config = PromptTuningConfig(
        task_type=TaskType.QUESTION_ANS,
        num_virtual_tokens=10,
        peft_type="PROMPT_TUNING",
    )
    peft_model = PeftModel(small_qa_pipeline.model, peft_config)

    base_model = get_peft_base_model(peft_model)
    assert base_model.__class__.__name__ == "MobileBertForQuestionAnswering"
    assert base_model.name_or_path == "csarron/mobilebert-uncased-squad-v2"


def test_save_and_load_peft_pipeline(peft_pipeline, tmp_path):
    """Saving a PEFT pipeline stores only the adapter, not the base model."""
    import peft

    from tests.transformers.test_transformers_model_export import HF_COMMIT_HASH_PATTERN

    mlflow.transformers.save_model(
        transformers_model=peft_pipeline,
        path=tmp_path,
    )

    # For PEFT, only the adapter model should be saved
    assert tmp_path.joinpath("peft").exists()
    assert not tmp_path.joinpath("model").exists()
    assert not tmp_path.joinpath("components").exists()

    # Validate the contents of MLModel file
    flavor_conf = Model.load(str(tmp_path.joinpath("MLmodel"))).flavors["transformers"]
    assert "model_binary" not in flavor_conf
    assert HF_COMMIT_HASH_PATTERN.match(flavor_conf["source_model_revision"])
    # Use the FlavorKey constant rather than the hard-coded "peft_adaptor"
    # string, for consistency with test_save_and_load_peft_with_base_model_path.
    assert flavor_conf[FlavorKey.PEFT] == "peft"

    # Validate peft is recorded to requirements.txt
    with open(tmp_path.joinpath("requirements.txt")) as f:
        assert f"peft=={peft.__version__}" in f.read()

    loaded_pipeline = mlflow.transformers.load_model(tmp_path)
    assert isinstance(loaded_pipeline.model, peft.PeftModel)
    loaded_pipeline.predict("Hi")


def test_save_and_load_peft_components(peft_pipeline, tmp_path, capsys):
    """Saving model + tokenizer as a component dict round-trips a PEFT model."""
    from peft import PeftModel

    mlflow.transformers.save_model(
        transformers_model={
            "model": peft_pipeline.model,
            "tokenizer": peft_pipeline.tokenizer,
        },
        path=tmp_path,
    )

    # PEFT pipeline construction error should not be raised
    peft_err_msg = (
        "The model 'PeftModelForSequenceClassification' is not supported for text-classification"
    )
    assert peft_err_msg not in capsys.readouterr().err

    loaded_pipeline = mlflow.transformers.load_model(tmp_path)
    assert isinstance(loaded_pipeline.model, PeftModel)
    loaded_pipeline.predict("Hi")


def test_log_peft_pipeline(peft_pipeline):
    """log_model within a run round-trips a PEFT pipeline via its model URI."""
    from peft import PeftModel

    with mlflow.start_run():
        model_info = mlflow.transformers.log_model(peft_pipeline, name="model", input_example="hi")

    loaded_pipeline = mlflow.transformers.load_model(model_info.model_uri)
    assert isinstance(loaded_pipeline.model, PeftModel)
    loaded_pipeline.predict("Hi")


@pytest.fixture
def peft_model_with_local_base(tmp_path_factory):
    """Build a text-classification pipeline whose PEFT base model lives on disk.

    Downloads the base checkpoint once, re-saves it to a temp directory, and
    wraps a fresh LoRA adapter around the locally loaded copy.

    Returns:
        (pipeline, base_dir): the transformers pipeline and the path (str) of
        the local base-model checkpoint directory.
    """
    from peft import LoraConfig, TaskType, get_peft_model

    _PEFT_PIPELINE_ERROR_MSG = re.compile(r"is not supported for")

    base_model_id = "Elron/bleurt-tiny-512"
    base_dir = tmp_path_factory.mktemp("base_model")

    base_model = transformers.AutoModelForSequenceClassification.from_pretrained(base_model_id)
    tokenizer = transformers.AutoTokenizer.from_pretrained(base_model_id)

    base_model.save_pretrained(str(base_dir))
    tokenizer.save_pretrained(str(base_dir))

    # Reload from the local directory so the PEFT model's base path points at
    # the on-disk checkpoint instead of the Hub repo id.
    local_model = transformers.AutoModelForSequenceClassification.from_pretrained(str(base_dir))
    local_tokenizer = transformers.AutoTokenizer.from_pretrained(str(base_dir))

    peft_config = LoraConfig(
        task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
    )
    peft_model = get_peft_model(local_model, peft_config)

    # transformers logs a noisy "... is not supported for ..." warning when a
    # pipeline wraps a PeftModel; suppress it to keep test output clean.
    with suppress_logs("transformers.pipelines.base", filter_regex=_PEFT_PIPELINE_ERROR_MSG):
        pipeline = transformers.pipeline(
            task="text-classification", model=peft_model, tokenizer=local_tokenizer
        )

    return pipeline, str(base_dir)


def test_save_and_load_peft_with_base_model_path(peft_model_with_local_base, tmp_path):
    """base_model_path records the local checkpoint and skips saving the base."""
    from peft import PeftModel

    pipeline, base_dir = peft_model_with_local_base

    mlflow.transformers.save_model(
        transformers_model=pipeline,
        path=tmp_path,
        base_model_path=base_dir,
    )

    # PEFT adapter should be saved, components should be saved, but base model should NOT
    assert tmp_path.joinpath("peft").exists()
    assert not tmp_path.joinpath("model").exists()
    assert tmp_path.joinpath("components").exists()

    # Validate flavor config
    flavor_conf = Model.load(str(tmp_path.joinpath("MLmodel"))).flavors["transformers"]
    assert "model_binary" not in flavor_conf
    assert "source_model_revision" not in flavor_conf
    assert flavor_conf[FlavorKey.MODEL_LOCAL_BASE] == os.path.abspath(base_dir)
    assert flavor_conf[FlavorKey.PEFT] == "peft"

    loaded_pipeline = mlflow.transformers.load_model(tmp_path)
    assert isinstance(loaded_pipeline.model, PeftModel)
    loaded_pipeline.predict("Hi")


def test_save_peft_with_base_model_path_components(peft_model_with_local_base, tmp_path):
    """Tokenizer components are persisted even when the base model is skipped."""
    pipeline, base_dir = peft_model_with_local_base

    mlflow.transformers.save_model(
        transformers_model=pipeline,
        path=tmp_path,
        base_model_path=base_dir,
    )

    components_dir = tmp_path / "components" / "tokenizer"
    assert components_dir.exists()
    assert any(components_dir.iterdir())


def test_log_peft_with_base_model_path(peft_model_with_local_base):
    """log_model accepts base_model_path and the logged model loads back."""
    from peft import PeftModel

    pipeline, base_dir = peft_model_with_local_base

    with mlflow.start_run():
        model_info = mlflow.transformers.log_model(
            pipeline,
            name="model",
            base_model_path=base_dir,
            input_example="hi",
        )

    loaded_pipeline = mlflow.transformers.load_model(model_info.model_uri)
    assert isinstance(loaded_pipeline.model, PeftModel)
    loaded_pipeline.predict("Hi")


def test_base_model_path_rejects_non_peft_model(small_qa_pipeline, tmp_path):
    """base_model_path is rejected for pipelines that are not PEFT models."""
    with pytest.raises(MlflowException, match="only supported for PEFT models"):
        mlflow.transformers.save_model(
            transformers_model=small_qa_pipeline,
            path=tmp_path,
            base_model_path="/some/path",
        )


def test_base_model_path_rejects_invalid_path(peft_model_with_local_base, tmp_path):
    """A nonexistent base_model_path fails fast at save time."""
    pipeline, _ = peft_model_with_local_base

    with pytest.raises(MlflowException, match="does not exist"):
        mlflow.transformers.save_model(
            transformers_model=pipeline,
            path=tmp_path,
            base_model_path="/nonexistent/path/to/model",
        )


def test_load_peft_with_base_model_path_override(peft_model_with_local_base, tmp_path):
    """load_model accepts an explicit base_model_path override at load time."""
    from peft import PeftModel

    pipeline, base_dir = peft_model_with_local_base
    save_dir = tmp_path / "model_output"

    # Save with a dummy path (simulating save on a different machine)
    mlflow.transformers.save_model(
        transformers_model=pipeline,
        path=save_dir,
        base_model_path=base_dir,
    )

    # Load with an explicit override path (simulating different mount point)
    loaded_pipeline = mlflow.transformers.load_model(save_dir, base_model_path=base_dir)
    assert isinstance(loaded_pipeline.model, PeftModel)
    loaded_pipeline.predict("Hi")


def test_base_model_path_rejects_non_checkpoint_dir(peft_model_with_local_base, tmp_path):
    """A directory without config.json is rejected as a base model at save time."""
    pipeline, _ = peft_model_with_local_base

    empty_dir = tmp_path / "empty_base"
    empty_dir.mkdir()

    save_dir = tmp_path / "model_output"
    with pytest.raises(MlflowException, match="config.json"):
        mlflow.transformers.save_model(
            transformers_model=pipeline,
            path=save_dir,
            base_model_path=str(empty_dir),
        )


def test_load_base_model_path_override_rejects_non_checkpoint_dir(
    peft_model_with_local_base, tmp_path
):
    """A load-time base_model_path override is validated the same way as at save."""
    pipeline, base_dir = peft_model_with_local_base
    save_dir = tmp_path / "model_output"
    mlflow.transformers.save_model(
        transformers_model=pipeline,
        path=save_dir,
        base_model_path=base_dir,
    )

    empty_dir = tmp_path / "empty_override"
    empty_dir.mkdir()

    with pytest.raises(MlflowException, match="config.json"):
        mlflow.transformers.load_model(save_dir, base_model_path=str(empty_dir))