# tests/transformers/test_transformers_prompt_templating.py
from unittest.mock import MagicMock

import pytest
import transformers
import yaml
from packaging.version import Version

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.models.model import MLMODEL_FILE_NAME
from mlflow.transformers import _SUPPORTED_PROMPT_TEMPLATING_TASK_TYPES, _validate_prompt_template
from mlflow.transformers.flavor_config import FlavorKey

TEST_PROMPT_TEMPLATE = "Answer the following question like a pirate:\nQ: {prompt}\nA: "

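# every transformers task type that does not support prompt templating; saving a
# model with `prompt_template` set should raise for each of these. The supported
# set (_SUPPORTED_PROMPT_TEMPLATING_TASK_TYPES) is exercised in the tests below.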
UNSUPPORTED_PIPELINES = [
    "audio-classification",
    "automatic-speech-recognition",
    "text-to-audio",
    "text-to-speech",
    "text-classification",
    "sentiment-analysis",
    "token-classification",
    "ner",
    "question-answering",
    "table-question-answering",
    "visual-question-answering",
    "vqa",
    "document-question-answering",
    "translation",
    "zero-shot-classification",
    "zero-shot-image-classification",
    "zero-shot-audio-classification",
    "conversational",
    "image-classification",
    "image-segmentation",
    "image-to-text",
    "object-detection",
    "zero-shot-object-detection",
    "depth-estimation",
    "video-classification",
    "mask-generation",
    "image-to-image",
]


# session-scoped fixtures to avoid saving and loading the ~400mb model for every test
@pytest.fixture(scope="session")
def small_text_generation_model():
    return transformers.pipeline("text-generation", model="distilgpt2")


@pytest.fixture(scope="session")
def saved_transformers_model_path(tmp_path_factory, small_text_generation_model):
    tmp_path = tmp_path_factory.mktemp("model")
    mlflow.transformers.save_model(
        transformers_model=small_text_generation_model,
        path=tmp_path,
        prompt_template=TEST_PROMPT_TEMPLATE,
    )
    return tmp_path


@pytest.mark.parametrize(
    "template",
    [
        "{multiple} {placeholders}",
        "No placeholders",
        "Placeholder {that} isn't `prompt`",
        "Placeholder without a {} name",
        "Placeholder with {prompt} and {} empty",
        1001,  # not a string
    ],
)
def test_prompt_validation_throws_on_invalid_templates(template):
    match = (
        "Argument `prompt_template` must be a string with a single format arg, 'prompt'."
        if isinstance(template, str)
        else "Argument `prompt_template` must be a string"
    )
    with pytest.raises(MlflowException, match=match):
        _validate_prompt_template(template)


@pytest.mark.parametrize(
    "template",
    [
        "Single placeholder {prompt}",
        "Text can be before {prompt} and after",
        # str.format interprets doubled braces as literal single braces
        "Escaped braces {{ work fine {prompt} }}",
    ],
)
def test_prompt_validation_succeeds_on_valid_templates(template):
    assert _validate_prompt_template(template) is None


# test that the prompt template is saved to the MLmodel file and restored on load
def test_prompt_save_and_load(saved_transformers_model_path):
    mlmodel_path = saved_transformers_model_path / MLMODEL_FILE_NAME
    with open(mlmodel_path) as f:
        mlmodel_dict = yaml.safe_load(f)

    assert mlmodel_dict["metadata"][FlavorKey.PROMPT_TEMPLATE] == TEST_PROMPT_TEMPLATE

    model = mlflow.pyfunc.load_model(saved_transformers_model_path)
    assert model._model_impl.prompt_template == TEST_PROMPT_TEMPLATE
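    # saving with a prompt template defaults `return_full_text` to False so that
    # predictions don't echo the rendered template back to the caller; the next
    # test verifies this default can be overridden via model_config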
    assert model._model_impl.model_config["return_full_text"] is False


def test_model_save_override_return_full_text(tmp_path, small_text_generation_model):
    mlflow.transformers.save_model(
        transformers_model=small_text_generation_model,
        path=tmp_path,
        prompt_template=TEST_PROMPT_TEMPLATE,
        model_config={"return_full_text": True},
    )
    model = mlflow.pyfunc.load_model(tmp_path)
    assert model._model_impl.model_config["return_full_text"] is True


def test_saving_prompt_throws_on_unsupported_task():
    model = transformers.pipeline("text-generation", model="distilgpt2")

    for pipeline_type in UNSUPPORTED_PIPELINES:
        # mock the task by setting it explicitly
        model.task = pipeline_type

        with pytest.raises(
            MlflowException,
            match=f"Prompt templating is not supported for the `{pipeline_type}` task type.",
        ):
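            # validation fails before anything is written, so it is safe to reuse
            # the same `path` across loop iterations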
            mlflow.transformers.save_model(
                transformers_model=model,
                path="model",
                prompt_template=TEST_PROMPT_TEMPLATE,
            )


def test_prompt_formatting(saved_transformers_model_path):
    model_impl = mlflow.pyfunc.load_model(saved_transformers_model_path)._model_impl

    # test that the formatting function throws for unsupported pipelines.
    # this is somewhat redundant, since the function is only ever called
    # on supported pipelines.
    for pipeline_type in UNSUPPORTED_PIPELINES:
        model_impl.pipeline = MagicMock(task=pipeline_type, return_value="")
        with pytest.raises(
            MlflowException,
            match="_format_prompt_template called on an unexpected pipeline type.",
        ):
            model_impl._format_prompt_template("test")

    # test that supported pipelines apply the prompt template, for both a
    # single string and a list of strings
    for pipeline_type in _SUPPORTED_PROMPT_TEMPLATING_TASK_TYPES:
        model_impl.pipeline = MagicMock(task=pipeline_type, return_value="")
        result = model_impl._format_prompt_template("test")
        assert result == TEST_PROMPT_TEMPLATE.format(prompt="test")

        result_list = model_impl._format_prompt_template(["item1", "item2"])
        assert result_list == [
            TEST_PROMPT_TEMPLATE.format(prompt="item1"),
            TEST_PROMPT_TEMPLATE.format(prompt="item2"),
        ]


# test that the prompt template is applied in pyfunc predict
@pytest.mark.parametrize(
    ("task", "pipeline_fixture", "output_key"),
    [
        ("feature-extraction", "feature_extraction_pipeline", None),
        ("fill-mask", "fill_mask_pipeline", "token_str"),
        ("summarization", "summarizer_pipeline", "summary_text"),
        ("text2text-generation", "text2text_generation_pipeline", "generated_text"),
        ("text-generation", "text_generation_pipeline", "generated_text"),
    ],
)
def test_prompt_used_in_predict(task, pipeline_fixture, output_key, request, tmp_path):
    pipeline = request.getfixturevalue(pipeline_fixture)

    if task == "summarization" and Version(transformers.__version__) > Version("4.44.2"):
        pytest.skip(
            reason="Multi-task pipeline has a loading issue with Transformers 4.45.x. "
            "See https://github.com/huggingface/transformers/issues/33398 for more details."
        )

    model_path = tmp_path / "model"
    mlflow.transformers.save_model(
        transformers_model=pipeline,
        path=model_path,
        prompt_template=TEST_PROMPT_TEMPLATE,
    )

    model = mlflow.pyfunc.load_model(model_path)
    prompt = "What is MLflow?"
    formatted_prompt = TEST_PROMPT_TEMPLATE.format(prompt=prompt)
    mock_response = "MLflow be a tool fer machine lernin'"
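    # the mocked pipeline echoes the rendered prompt plus a completion, mimicking
    # a pipeline that returns the full generated text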
    mock_return = [[{output_key: formatted_prompt + mock_response}]]

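    # spec'ing the mock against the real pipeline keeps it pipeline-like
    # (isinstance checks still pass) while stubbing out actual inference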
    model._model_impl.pipeline = MagicMock(
        spec=model._model_impl.pipeline, task=task, return_value=mock_return
    )

    model.predict(prompt)

    # check that the underlying pipeline was called with the formatted prompt template
    if task == "text-generation":
        model._model_impl.pipeline.assert_called_once_with(
            [formatted_prompt], return_full_text=False
        )
    else:
        model._model_impl.pipeline.assert_called_once_with([formatted_prompt])


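# the prompt template should also be applied when the model is saved with the
# llm/v1/completions inference task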
def test_prompt_and_llm_inference_task(tmp_path, request):
    pipeline = request.getfixturevalue("text_generation_pipeline")

    model_path = tmp_path / "model"
    mlflow.transformers.save_model(
        transformers_model=pipeline,
        path=model_path,
        prompt_template=TEST_PROMPT_TEMPLATE,
        task="llm/v1/completions",
    )

    model = mlflow.pyfunc.load_model(model_path)

    prompt = "What is MLflow?"
    formatted_prompt = TEST_PROMPT_TEMPLATE.format(prompt=prompt)
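    # with the completions task, the wrapper requests raw token ids from the
    # pipeline (see the return_tensors=True assertion below), so the mock
    # returns `generated_token_ids` rather than text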
    mock_return = [[{"generated_token_ids": [1, 2, 3]}]]

    model._model_impl.pipeline = MagicMock(
        spec=model._model_impl.pipeline, task="text-generation", return_value=mock_return
    )

    model.predict({"prompt": prompt})

    model._model_impl.pipeline.assert_called_once_with(
        [formatted_prompt], return_full_text=None, return_tensors=True
    )