tests/transformers/test_transformers_peft_model.py
import importlib
import os
import re

import pytest
import transformers

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.models import Model
from mlflow.transformers.flavor_config import FlavorKey
from mlflow.transformers.peft import get_peft_base_model, is_peft_model
from mlflow.utils.logging_utils import suppress_logs

SKIP_IF_PEFT_NOT_AVAILABLE = pytest.mark.skipif(
    importlib.util.find_spec("peft") is None,
    reason="PEFT is not installed",
)
pytestmark = SKIP_IF_PEFT_NOT_AVAILABLE


def test_is_peft_model(peft_pipeline, small_qa_pipeline):
    assert is_peft_model(peft_pipeline.model)
    assert not is_peft_model(small_qa_pipeline.model)

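
# A minimal sketch of what the check above implies, not MLflow's actual
# implementation: is_peft_model is expected to behave like an isinstance check
# against peft.PeftModel, guarded so it also works when peft is not installed.
#
#   def _sketch_is_peft_model(model):
#       if importlib.util.find_spec("peft") is None:
#           return False
#       from peft import PeftModel
#       return isinstance(model, PeftModel)
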

def test_get_peft_base_model(peft_pipeline):
    base_model = get_peft_base_model(peft_pipeline.model)
    assert base_model.__class__.__name__ == "BertForSequenceClassification"
    assert base_model.name_or_path == "Elron/bleurt-tiny-512"


def test_get_peft_base_model_prompt_learning(small_qa_pipeline):
    from peft import PeftModel, PromptTuningConfig, TaskType

    peft_config = PromptTuningConfig(
        task_type=TaskType.QUESTION_ANS,
        num_virtual_tokens=10,
        peft_type="PROMPT_TUNING",
    )
    peft_model = PeftModel(small_qa_pipeline.model, peft_config)

    base_model = get_peft_base_model(peft_model)
    assert base_model.__class__.__name__ == "MobileBertForQuestionAnswering"
    assert base_model.name_or_path == "csarron/mobilebert-uncased-squad-v2"

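
# A sketch of the unwrapping exercised by the two tests above; an assumption
# about the behavior, not MLflow's code. peft.PeftModel.get_base_model()
# returns the underlying transformers model, including for prompt-learning
# adapters such as prompt tuning, which do not wrap the base model:
#
#   base = peft_model.get_base_model()
#   assert not is_peft_model(base)
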

def test_save_and_load_peft_pipeline(peft_pipeline, tmp_path):
    import peft

    from tests.transformers.test_transformers_model_export import HF_COMMIT_HASH_PATTERN

    mlflow.transformers.save_model(
        transformers_model=peft_pipeline,
        path=tmp_path,
    )

    # For PEFT, only the adapter model should be saved
    assert tmp_path.joinpath("peft").exists()
    assert not tmp_path.joinpath("model").exists()
    assert not tmp_path.joinpath("components").exists()

    # Validate the contents of the MLmodel file
    flavor_conf = Model.load(str(tmp_path.joinpath("MLmodel"))).flavors["transformers"]
    assert "model_binary" not in flavor_conf
    assert HF_COMMIT_HASH_PATTERN.match(flavor_conf["source_model_revision"])
    assert flavor_conf[FlavorKey.PEFT] == "peft"

    # Validate that peft is recorded in requirements.txt
    with open(tmp_path.joinpath("requirements.txt")) as f:
        assert f"peft=={peft.__version__}" in f.read()

    loaded_pipeline = mlflow.transformers.load_model(tmp_path)
    assert isinstance(loaded_pipeline.model, peft.PeftModel)
    loaded_pipeline.predict("Hi")

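
def test_save_and_load_peft_pipeline_as_pyfunc(peft_pipeline, tmp_path):
    # A minimal sketch of generic pyfunc inference on the saved PEFT artifact.
    # It assumes the default pyfunc wrapper for text-classification pipelines
    # accepts a plain string input, as the native predict calls above do.
    mlflow.transformers.save_model(
        transformers_model=peft_pipeline,
        path=tmp_path,
    )

    pyfunc_model = mlflow.pyfunc.load_model(str(tmp_path))
    assert pyfunc_model.predict("Hi") is not None
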

def test_save_and_load_peft_components(peft_pipeline, tmp_path, capsys):
    from peft import PeftModel

    mlflow.transformers.save_model(
        transformers_model={
            "model": peft_pipeline.model,
            "tokenizer": peft_pipeline.tokenizer,
        },
        path=tmp_path,
    )

    # The PEFT pipeline construction warning should not be emitted to stderr
    peft_err_msg = (
        "The model 'PeftModelForSequenceClassification' is not supported for text-classification"
    )
    assert peft_err_msg not in capsys.readouterr().err

    loaded_pipeline = mlflow.transformers.load_model(tmp_path)
    assert isinstance(loaded_pipeline.model, PeftModel)
    loaded_pipeline.predict("Hi")


def test_log_peft_pipeline(peft_pipeline):
    from peft import PeftModel

    with mlflow.start_run():
        model_info = mlflow.transformers.log_model(peft_pipeline, name="model", input_example="hi")

    loaded_pipeline = mlflow.transformers.load_model(model_info.model_uri)
    assert isinstance(loaded_pipeline.model, PeftModel)
    loaded_pipeline.predict("Hi")

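
def test_log_peft_pipeline_flavor_config(peft_pipeline):
    # A sketch checking that log_model records the same PEFT flavor keys that
    # save_model does above; it assumes ModelInfo.flavors mirrors the MLmodel
    # file written to the artifact location.
    with mlflow.start_run():
        model_info = mlflow.transformers.log_model(peft_pipeline, name="model", input_example="hi")

    flavor_conf = model_info.flavors["transformers"]
    assert flavor_conf[FlavorKey.PEFT] == "peft"
    assert "model_binary" not in flavor_conf
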

@pytest.fixture
def peft_model_with_local_base(tmp_path_factory):
    from peft import LoraConfig, TaskType, get_peft_model

    base_model_id = "Elron/bleurt-tiny-512"
    base_dir = tmp_path_factory.mktemp("base_model")

    # Download the base model and persist it to a local checkpoint directory
    base_model = transformers.AutoModelForSequenceClassification.from_pretrained(base_model_id)
    tokenizer = transformers.AutoTokenizer.from_pretrained(base_model_id)

    base_model.save_pretrained(str(base_dir))
    tokenizer.save_pretrained(str(base_dir))

    # Reload from the local directory so the PEFT model references a local base
    local_model = transformers.AutoModelForSequenceClassification.from_pretrained(str(base_dir))
    local_tokenizer = transformers.AutoTokenizer.from_pretrained(str(base_dir))

    peft_config = LoraConfig(
        task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
    )
    peft_model = get_peft_model(local_model, peft_config)

    # Suppress the transformers warning that PeftModel is not a registered
    # pipeline class for text-classification
    peft_pipeline_error_msg = re.compile(r"is not supported for")
    with suppress_logs("transformers.pipelines.base", filter_regex=peft_pipeline_error_msg):
        pipeline = transformers.pipeline(
            task="text-classification", model=peft_model, tokenizer=local_tokenizer
        )

    return pipeline, str(base_dir)


def test_save_and_load_peft_with_base_model_path(peft_model_with_local_base, tmp_path):
    from peft import PeftModel

    pipeline, base_dir = peft_model_with_local_base

    mlflow.transformers.save_model(
        transformers_model=pipeline,
        path=tmp_path,
        base_model_path=base_dir,
    )

    # The PEFT adapter and components should be saved, but the base model should NOT be
    assert tmp_path.joinpath("peft").exists()
    assert not tmp_path.joinpath("model").exists()
    assert tmp_path.joinpath("components").exists()

    # Validate the flavor config
    flavor_conf = Model.load(str(tmp_path.joinpath("MLmodel"))).flavors["transformers"]
    assert "model_binary" not in flavor_conf
    assert "source_model_revision" not in flavor_conf
    assert flavor_conf[FlavorKey.MODEL_LOCAL_BASE] == os.path.abspath(base_dir)
    assert flavor_conf[FlavorKey.PEFT] == "peft"

    loaded_pipeline = mlflow.transformers.load_model(tmp_path)
    assert isinstance(loaded_pipeline.model, PeftModel)
    loaded_pipeline.predict("Hi")


def test_save_peft_with_base_model_path_components(peft_model_with_local_base, tmp_path):
    pipeline, base_dir = peft_model_with_local_base

    mlflow.transformers.save_model(
        transformers_model=pipeline,
        path=tmp_path,
        base_model_path=base_dir,
    )

    components_dir = tmp_path / "components" / "tokenizer"
    assert components_dir.exists()
    assert any(components_dir.iterdir())


def test_log_peft_with_base_model_path(peft_model_with_local_base):
    from peft import PeftModel

    pipeline, base_dir = peft_model_with_local_base

    with mlflow.start_run():
        model_info = mlflow.transformers.log_model(
            pipeline,
            name="model",
            base_model_path=base_dir,
            input_example="hi",
        )

    loaded_pipeline = mlflow.transformers.load_model(model_info.model_uri)
    assert isinstance(loaded_pipeline.model, PeftModel)
    loaded_pipeline.predict("Hi")


def test_base_model_path_rejects_non_peft_model(small_qa_pipeline, tmp_path):
    with pytest.raises(MlflowException, match="only supported for PEFT models"):
        mlflow.transformers.save_model(
            transformers_model=small_qa_pipeline,
            path=tmp_path,
            base_model_path="/some/path",
        )


def test_base_model_path_rejects_invalid_path(peft_model_with_local_base, tmp_path):
    pipeline, _ = peft_model_with_local_base

    with pytest.raises(MlflowException, match="does not exist"):
        mlflow.transformers.save_model(
            transformers_model=pipeline,
            path=tmp_path,
            base_model_path="/nonexistent/path/to/model",
        )


def test_load_peft_with_base_model_path_override(peft_model_with_local_base, tmp_path):
    from peft import PeftModel

    pipeline, base_dir = peft_model_with_local_base
    save_dir = tmp_path / "model_output"

    # Save with the local base model path so it is recorded in the flavor config
    mlflow.transformers.save_model(
        transformers_model=pipeline,
        path=save_dir,
        base_model_path=base_dir,
    )

    # Load with an explicit base_model_path override; here it is the same
    # directory, standing in for a different mount point on another machine
    loaded_pipeline = mlflow.transformers.load_model(save_dir, base_model_path=base_dir)
    assert isinstance(loaded_pipeline.model, PeftModel)
    loaded_pipeline.predict("Hi")


def test_base_model_path_rejects_non_checkpoint_dir(peft_model_with_local_base, tmp_path):
    pipeline, _ = peft_model_with_local_base

    empty_dir = tmp_path / "empty_base"
    empty_dir.mkdir()

    save_dir = tmp_path / "model_output"
    with pytest.raises(MlflowException, match="config.json"):
        mlflow.transformers.save_model(
            transformers_model=pipeline,
            path=save_dir,
            base_model_path=str(empty_dir),
        )


def test_load_base_model_path_override_rejects_non_checkpoint_dir(
    peft_model_with_local_base, tmp_path
):
    pipeline, base_dir = peft_model_with_local_base
    save_dir = tmp_path / "model_output"
    mlflow.transformers.save_model(
        transformers_model=pipeline,
        path=save_dir,
        base_model_path=base_dir,
    )

    empty_dir = tmp_path / "empty_override"
    empty_dir.mkdir()

    with pytest.raises(MlflowException, match="config.json"):
        mlflow.transformers.load_model(save_dir, base_model_path=str(empty_dir))
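

# A sketch of the checkpoint validation the two tests above rely on; this is
# an assumption about the implementation, not MLflow's actual code: a local
# base model directory is treated as a valid transformers checkpoint only if
# it contains a config.json file.
#
#   def _sketch_is_checkpoint_dir(path):
#       return os.path.isfile(os.path.join(path, "config.json"))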