test_transformers_prompt_templating.py
from unittest.mock import MagicMock

import pytest
import transformers
import yaml
from packaging.version import Version

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.models.model import MLMODEL_FILE_NAME
from mlflow.transformers import _SUPPORTED_PROMPT_TEMPLATING_TASK_TYPES, _validate_prompt_template
from mlflow.transformers.flavor_config import FlavorKey

TEST_PROMPT_TEMPLATE = "Answer the following question like a pirate:\nQ: {prompt}\nA: "

UNSUPPORTED_PIPELINES = [
    "audio-classification",
    "automatic-speech-recognition",
    "text-to-audio",
    "text-to-speech",
    "text-classification",
    "sentiment-analysis",
    "token-classification",
    "ner",
    "question-answering",
    "table-question-answering",
    "visual-question-answering",
    "vqa",
    "document-question-answering",
    "translation",
    "zero-shot-classification",
    "zero-shot-image-classification",
    "zero-shot-audio-classification",
    "conversational",
    "image-classification",
    "image-segmentation",
    "image-to-text",
    "object-detection",
    "zero-shot-object-detection",
    "depth-estimation",
    "video-classification",
    "mask-generation",
    "image-to-image",
]


# session-scoped fixtures to avoid saving and loading a ~400MB model for every test
@pytest.fixture(scope="session")
def small_text_generation_model():
    return transformers.pipeline("text-generation", model="distilgpt2")


@pytest.fixture(scope="session")
def saved_transformers_model_path(tmp_path_factory, small_text_generation_model):
    tmp_path = tmp_path_factory.mktemp("model")
    mlflow.transformers.save_model(
        transformers_model=small_text_generation_model,
        path=tmp_path,
        prompt_template=TEST_PROMPT_TEMPLATE,
    )
    return tmp_path


@pytest.mark.parametrize(
    "template",
    [
        "{multiple} {placeholders}",
        "No placeholders",
        "Placeholder {that} isn't `prompt`",
        "Placeholder without a {} name",
        "Placeholder with {prompt} and {} empty",
        1001,  # not a string
    ],
)
def test_prompt_validation_throws_on_invalid_templates(template):
    match = (
        "Argument `prompt_template` must be a string with a single format arg, 'prompt'."
        if isinstance(template, str)
        else "Argument `prompt_template` must be a string"
    )
    with pytest.raises(MlflowException, match=match):
        _validate_prompt_template(template)


@pytest.mark.parametrize(
    "template",
    [
        "Single placeholder {prompt}",
        "Text can be before {prompt} and after",
        # the formatter interprets double braces as literal single braces
        "Escaped braces {{ work fine {prompt} }}",
    ],
)
def test_prompt_validation_succeeds_on_valid_templates(template):
    assert _validate_prompt_template(template) is None


# test that the prompt template is saved to the MLmodel file and restored on load
def test_prompt_save_and_load(saved_transformers_model_path):
    mlmodel_path = saved_transformers_model_path / MLMODEL_FILE_NAME
    with open(mlmodel_path) as f:
        mlmodel_dict = yaml.safe_load(f)

    assert mlmodel_dict["metadata"][FlavorKey.PROMPT_TEMPLATE] == TEST_PROMPT_TEMPLATE

    model = mlflow.pyfunc.load_model(saved_transformers_model_path)
    assert model._model_impl.prompt_template == TEST_PROMPT_TEMPLATE
    assert model._model_impl.model_config["return_full_text"] is False


def test_model_save_override_return_full_text(tmp_path, small_text_generation_model):
    mlflow.transformers.save_model(
        transformers_model=small_text_generation_model,
        path=tmp_path,
        prompt_template=TEST_PROMPT_TEMPLATE,
        model_config={"return_full_text": True},
    )
    model = mlflow.pyfunc.load_model(tmp_path)
    assert model._model_impl.model_config["return_full_text"] is True


def test_saving_prompt_throws_on_unsupported_task(tmp_path):
    model = transformers.pipeline("text-generation", model="distilgpt2")

    for pipeline_type in UNSUPPORTED_PIPELINES:
        # mock the task by setting it explicitly
        model.task = pipeline_type

        with pytest.raises(
            MlflowException,
            match=f"Prompt templating is not supported for the `{pipeline_type}` task type.",
        ):
            mlflow.transformers.save_model(
                transformers_model=model,
                path=tmp_path / "model",
                prompt_template=TEST_PROMPT_TEMPLATE,
            )


def test_prompt_formatting(saved_transformers_model_path):
    model_impl = mlflow.pyfunc.load_model(saved_transformers_model_path)._model_impl

    # verify that the formatting function raises for unsupported pipelines.
    # this is somewhat redundant, since the function is only ever invoked
    # for supported pipelines.
    for pipeline_type in UNSUPPORTED_PIPELINES:
        model_impl.pipeline = MagicMock(task=pipeline_type, return_value="")
        with pytest.raises(
            MlflowException,
            match="_format_prompt_template called on an unexpected pipeline type.",
        ):
            model_impl._format_prompt_template("test")

    # verify that supported pipelines apply the prompt template
    for pipeline_type in _SUPPORTED_PROMPT_TEMPLATING_TASK_TYPES:
        model_impl.pipeline = MagicMock(task=pipeline_type, return_value="")
        result = model_impl._format_prompt_template("test")
        assert result == TEST_PROMPT_TEMPLATE.format(prompt="test")

        result_list = model_impl._format_prompt_template(["item1", "item2"])
        assert result_list == [
            TEST_PROMPT_TEMPLATE.format(prompt="item1"),
            TEST_PROMPT_TEMPLATE.format(prompt="item2"),
        ]


# test that the prompt template is applied during pyfunc predict
@pytest.mark.parametrize(
    ("task", "pipeline_fixture", "output_key"),
    [
        ("feature-extraction", "feature_extraction_pipeline", None),
        ("fill-mask", "fill_mask_pipeline", "token_str"),
        ("summarization", "summarizer_pipeline", "summary_text"),
        ("text2text-generation", "text2text_generation_pipeline", "generated_text"),
        ("text-generation", "text_generation_pipeline", "generated_text"),
    ],
)
def test_prompt_used_in_predict(task, pipeline_fixture, output_key, request, tmp_path):
    # skip before resolving the fixture, since the affected pipeline is the
    # one that fails to load on Transformers 4.45.x
    if task == "summarization" and Version(transformers.__version__) > Version("4.44.2"):
        pytest.skip(
            reason="Multi-task pipeline has a loading issue with Transformers 4.45.x. "
            "See https://github.com/huggingface/transformers/issues/33398 for more details."
        )

    pipeline = request.getfixturevalue(pipeline_fixture)

    model_path = tmp_path / "model"
    mlflow.transformers.save_model(
        transformers_model=pipeline,
        path=model_path,
        prompt_template=TEST_PROMPT_TEMPLATE,
    )

    model = mlflow.pyfunc.load_model(model_path)
    prompt = "What is MLflow?"
    formatted_prompt = TEST_PROMPT_TEMPLATE.format(prompt=prompt)
    mock_response = "MLflow be a tool fer machine lernin'"
    mock_return = [[{output_key: formatted_prompt + mock_response}]]

    model._model_impl.pipeline = MagicMock(
        spec=model._model_impl.pipeline, task=task, return_value=mock_return
    )

    model.predict(prompt)

    # check that the underlying pipeline was called with the formatted prompt
    if task == "text-generation":
        model._model_impl.pipeline.assert_called_once_with(
            [formatted_prompt], return_full_text=False
        )
    else:
        model._model_impl.pipeline.assert_called_once_with([formatted_prompt])


def test_prompt_and_llm_inference_task(tmp_path, request):
    pipeline = request.getfixturevalue("text_generation_pipeline")

    model_path = tmp_path / "model"
    mlflow.transformers.save_model(
        transformers_model=pipeline,
        path=model_path,
        prompt_template=TEST_PROMPT_TEMPLATE,
        task="llm/v1/completions",
    )

    model = mlflow.pyfunc.load_model(model_path)

    prompt = "What is MLflow?"
    formatted_prompt = TEST_PROMPT_TEMPLATE.format(prompt=prompt)
    mock_return = [[{"generated_token_ids": [1, 2, 3]}]]

    model._model_impl.pipeline = MagicMock(
        spec=model._model_impl.pipeline, task="text-generation", return_value=mock_return
    )

    model.predict({"prompt": prompt})

    model._model_impl.pipeline.assert_called_once_with(
        [formatted_prompt], return_full_text=None, return_tensors=True
    )
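

# For reference, a minimal sketch (not part of the test suite) of the plain
# str.format semantics these tests rely on: a valid template contains exactly
# one `{prompt}` placeholder, and doubled braces render as literal braces.
#
#   TEST_PROMPT_TEMPLATE.format(prompt="What is MLflow?")
#   # -> "Answer the following question like a pirate:\nQ: What is MLflow?\nA: "
#
#   "Escaped braces {{ work fine {prompt} }}".format(prompt="hi")
#   # -> "Escaped braces { work fine hi }"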