test_transformers_llm_inference_utils.py
import uuid
from typing import Any, NamedTuple
from unittest import mock

import pandas as pd
import pytest
import torch

from mlflow.exceptions import MlflowException
from mlflow.models import infer_signature
from mlflow.transformers.llm_inference_utils import (
    _get_default_task_for_llm_inference_task,
    _get_finish_reason,
    _get_output_and_usage_from_tensor,
    _get_stopping_criteria,
    _get_token_usage,
    convert_messages_to_prompt,
    infer_signature_from_llm_inference_task,
    preprocess_llm_inference_input,
)
from mlflow.types.llm import (
    CHAT_MODEL_INPUT_SCHEMA,
    CHAT_MODEL_OUTPUT_SCHEMA,
    COMPLETIONS_MODEL_INPUT_SCHEMA,
    COMPLETIONS_MODEL_OUTPUT_SCHEMA,
)


def test_infer_signature_from_llm_inference_task():
    signature = infer_signature_from_llm_inference_task("llm/v1/completions")
    assert signature.inputs == COMPLETIONS_MODEL_INPUT_SCHEMA
    assert signature.outputs == COMPLETIONS_MODEL_OUTPUT_SCHEMA

    signature = infer_signature_from_llm_inference_task("llm/v1/chat")
    assert signature.inputs == CHAT_MODEL_INPUT_SCHEMA
    assert signature.outputs == CHAT_MODEL_OUTPUT_SCHEMA

    signature = infer_signature("hello", "world")
    with pytest.raises(MlflowException, match=r".*llm/v1/completions.*signature"):
        infer_signature_from_llm_inference_task("llm/v1/completions", signature)


class DummyTokenizer:
    """Tokenizer stub that treats a prompt as a whitespace-separated list of integer token ids."""

    def __call__(self, text: str, **kwargs):
        input_ids = list(map(int, text.split(" ")))
        return {"input_ids": torch.tensor([input_ids])}

    def decode(self, tensor, **kwargs):
        if isinstance(tensor, torch.Tensor):
            tensor = tensor.tolist()
        return " ".join([str(x) for x in tensor])

    def convert_tokens_to_ids(self, tokens: list[str]):
        return [int(x) for x in tokens]

    def tokenize(self, text: str):
        return [x for x in text.split(" ") if x]

    def apply_chat_template(self, messages: list[dict[str, str]], **kwargs):
        return " ".join(message["content"] for message in messages)


def test_apply_chat_template():
    data1 = [{"role": "A", "content": "one"}, {"role": "B", "content": "two"}]
    # For the chat task, the messages are rendered into a single prompt string
    prompt = convert_messages_to_prompt(data1, DummyTokenizer())
    assert prompt == "one two"

    with pytest.raises(MlflowException, match=r"Input messages should be list of"):
        convert_messages_to_prompt([["one", "two"]], DummyTokenizer())


class _TestCase(NamedTuple):
    data: Any
    params: Any
    expected_data: Any
    expected_params: Any


@pytest.mark.parametrize(
    "case",
    [
        # Case 0: Data only includes prompt
        _TestCase(
            data=pd.DataFrame({"prompt": ["Hello world!"]}),
            params={},
            expected_data=["Hello world!"],
            expected_params={},
        ),
        # Case 1: Data includes prompt and params
        _TestCase(
            data=pd.DataFrame({
                "prompt": ["Hello world!"],
                "temperature": [0.7],
                "max_tokens": [100],
                "stop": [None],
            }),
            params={},
            expected_data=["Hello world!"],
            expected_params={
                "temperature": 0.7,
                # max_tokens is replaced with max_new_tokens
                "max_new_tokens": 100,
                # do not pass `stop` to params as it is None
            },
        ),
        # Case 2: Params are passed if not specified in data
        _TestCase(
            data=pd.DataFrame({
                "prompt": ["Hello world!"],
            }),
            params={
                "temperature": 0.7,
                "max_tokens": 100,
                "stop": ["foo", "bar"],
            },
            expected_data=["Hello world!"],
            expected_params={
                "temperature": 0.7,
                "max_new_tokens": 100,
                # Stopping criteria is a _StopSequenceMatchCriteria instance
                # "stop": ...
            },
        ),
        # Case 3: Data overrides params
        _TestCase(
            data=pd.DataFrame({
                "messages": [
                    [
                        {"role": "user", "content": "Hello!"},
                        {"role": "assistant", "content": "Hi!"},
                    ]
                ],
                "temperature": [0.1],
                "max_tokens": [100],
                "stop": [["foo", "bar"]],
            }),
            params={
                "temperature": [0.2],
                "max_tokens": [200],
                "stop": ["foo", "bar", "baz"],
            },
            expected_data=[
                [
                    {"role": "user", "content": "Hello!"},
                    {"role": "assistant", "content": "Hi!"},
                ]
            ],
            expected_params={
                "temperature": 0.1,
                "max_new_tokens": 100,
            },
        ),
        # Case 4: Batch input
        _TestCase(
            data=pd.DataFrame({
                "prompt": ["Hello!", "Hi", "Hola"],
                "temperature": [0.1, 0.2, 0.3],
                "max_tokens": [None, 200, 300],
            }),
            params={
                "temperature": 0.4,
                "max_tokens": 400,
            },
            expected_data=["Hello!", "Hi", "Hola"],
            # Values from the first row of the data are used; params fill in the rest
            expected_params={
                "temperature": 0.1,
                "max_new_tokens": 400,
            },
        ),
        # Case 5: Raw dict input
        _TestCase(
            data={
                "messages": [
                    {"role": "user", "content": "Hello!"},
                    {"role": "assistant", "content": "Hi!"},
                ],
                "temperature": 0.1,
                "max_tokens": 100,
                "stop": ["foo", "bar"],
            },
            params={},
            expected_data=[
                [
                    {"role": "user", "content": "Hello!"},
                    {"role": "assistant", "content": "Hi!"},
                ]
            ],
            expected_params={
                "temperature": 0.1,
                "max_new_tokens": 100,
            },
        ),
    ],
)
def test_preprocess_llm_inference_input(case):
    task = "llm/v1/completions" if "prompt" in case.data else "llm/v1/chat"
    flavor_config = {"inference_task": task, "source_model_name": "test"}

    with mock.patch(
        "mlflow.transformers.llm_inference_utils._get_stopping_criteria"
    ) as mock_get_stopping_criteria:
        data, params = preprocess_llm_inference_input(case.data, case.params, flavor_config)

    # Test that OpenAI params are separated from data and replaced with Hugging Face params
    assert data == case.expected_data
    if "stopping_criteria" in params:
        assert params.pop("stopping_criteria") is not None
        mock_get_stopping_criteria.assert_called_once_with(["foo", "bar"], "test")
    assert params == case.expected_params


def test_preprocess_llm_inference_input_raise_if_key_invalid():
    # Missing input key
    with pytest.raises(MlflowException, match=r"Transformer model saved with"):
        preprocess_llm_inference_input(
            pd.DataFrame({"invalid_key": [1, 2, 3]}),
            flavor_config={"inference_task": "llm/v1/completions"},
        )

    # Unmatched key (should be "messages" for chat task)
    with pytest.raises(MlflowException, match=r"Transformer model saved with"):
        preprocess_llm_inference_input(
            pd.DataFrame({"prompt": ["Hi"]}), flavor_config={"inference_task": "llm/v1/chat"}
        )


def test_stopping_criteria():
    with mock.patch("transformers.AutoTokenizer.from_pretrained") as mock_from_pretrained:
        mock_from_pretrained.return_value = DummyTokenizer()

        stopping_criteria = _get_stopping_criteria(stop=None, model_name=None)
        assert stopping_criteria is None

        input_ids = torch.tensor([[1, 2, 3, 4, 5]])
        scores = torch.ones(1, 5)

        stopping_criteria = _get_stopping_criteria(stop="5", model_name="my/model")
        stopping_criteria_matches = [f(input_ids, scores) for f in stopping_criteria]
        assert stopping_criteria_matches == [True, True]

        stopping_criteria = _get_stopping_criteria(stop=["100", "5"], model_name="my/model")
        stopping_criteria_matches = [f(input_ids, scores) for f in stopping_criteria]
        assert stopping_criteria_matches == [False, False, True, True]


def test_output_dict_for_completions():
    prompt = "1 2 3"
    output_tensor = [1, 2, 3, 4, 5]
    flavor_config = {"source_model_name": "gpt2"}
    model_config = {"max_new_tokens": 2}
    inference_task = "llm/v1/completions"

    pipeline = mock.MagicMock()
    pipeline.tokenizer = DummyTokenizer()

    output_dict = _get_output_and_usage_from_tensor(
        prompt, output_tensor, pipeline, flavor_config, model_config, inference_task
    )

    # Test UUID validity
    uuid.UUID(output_dict["id"])

    assert output_dict["object"] == "text_completion"
    assert output_dict["model"] == "gpt2"

    assert output_dict["choices"][0]["text"] == "4 5"
    assert output_dict["choices"][0]["finish_reason"] == "length"

    usage = output_dict["usage"]
    assert usage["prompt_tokens"] + usage["completion_tokens"] == usage["total_tokens"]


def test_token_usage():
    prompt = "1 2 3"
    output_tensor = [1, 2, 3, 4, 5]

    pipeline = mock.MagicMock()
    pipeline.tokenizer = DummyTokenizer()

    usage = _get_token_usage(prompt, output_tensor, pipeline, {})
    assert usage["prompt_tokens"] == 3
    assert usage["completion_tokens"] == 2
    assert usage["total_tokens"] == 5


def test_finish_reason():
    assert _get_finish_reason(total_tokens=20, completion_tokens=10, model_config={}) == "stop"

    assert (
        _get_finish_reason(
            total_tokens=20, completion_tokens=10, model_config={"max_new_tokens": 10}
        )
        == "length"
    )

    assert (
        _get_finish_reason(total_tokens=20, completion_tokens=10, model_config={"max_length": 15})
        == "length"
    )


@pytest.mark.parametrize(
    ("inference_task", "expected_task"),
    [
        ("llm/v1/completions", "text-generation"),
        ("llm/v1/chat", "text-generation"),
        (None, None),
    ],
)
def test_default_task_for_llm_inference_task(inference_task, expected_task):
    assert _get_default_task_for_llm_inference_task(inference_task) == expected_task