tests/transformers/test_transformers_llm_inference_utils.py
import uuid
from typing import Any, NamedTuple
from unittest import mock

import pandas as pd
import pytest
import torch

from mlflow.exceptions import MlflowException
from mlflow.models import infer_signature
from mlflow.transformers.llm_inference_utils import (
    _get_default_task_for_llm_inference_task,
    _get_finish_reason,
    _get_output_and_usage_from_tensor,
    _get_stopping_criteria,
    _get_token_usage,
    convert_messages_to_prompt,
    infer_signature_from_llm_inference_task,
    preprocess_llm_inference_input,
)
from mlflow.types.llm import (
    CHAT_MODEL_INPUT_SCHEMA,
    CHAT_MODEL_OUTPUT_SCHEMA,
    COMPLETIONS_MODEL_INPUT_SCHEMA,
    COMPLETIONS_MODEL_OUTPUT_SCHEMA,
)

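# The helpers under test adapt OpenAI-style llm/v1/chat and llm/v1/completions
# requests and responses to Hugging Face text-generation pipelines.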

def test_infer_signature_from_llm_inference_task():
    signature = infer_signature_from_llm_inference_task("llm/v1/completions")
    assert signature.inputs == COMPLETIONS_MODEL_INPUT_SCHEMA
    assert signature.outputs == COMPLETIONS_MODEL_OUTPUT_SCHEMA

    signature = infer_signature_from_llm_inference_task("llm/v1/chat")
    assert signature.inputs == CHAT_MODEL_INPUT_SCHEMA
    assert signature.outputs == CHAT_MODEL_OUTPUT_SCHEMA

    signature = infer_signature("hello", "world")
    with pytest.raises(MlflowException, match=r".*llm/v1/completions.*signature"):
        infer_signature_from_llm_inference_task("llm/v1/completions", signature)


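# Minimal stand-in for a Hugging Face tokenizer: prompts are whitespace-separated
# integer strings, so "1 2 3" encodes to token ids [1, 2, 3] and decoding joins the
# ids back with spaces. `apply_chat_template` simply concatenates message contents.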
class DummyTokenizer:
    def __call__(self, text: str, **kwargs):
        input_ids = list(map(int, text.split(" ")))
        return {"input_ids": torch.tensor([input_ids])}

    def decode(self, tensor, **kwargs):
        if isinstance(tensor, torch.Tensor):
            tensor = tensor.tolist()
        return " ".join([str(x) for x in tensor])

    def convert_tokens_to_ids(self, tokens: list[str]):
        return [int(x) for x in tokens]

    def tokenize(self, text: str):
        return [x for x in text.split(" ") if x]

    def apply_chat_template(self, messages: list[dict[str, str]], **kwargs):
        return " ".join(message["content"] for message in messages)


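# `convert_messages_to_prompt` should delegate to the tokenizer's chat template and
# reject inputs that are not a list of {"role": ..., "content": ...} dicts.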
def test_apply_chat_template():
    data1 = [{"role": "A", "content": "one"}, {"role": "B", "content": "two"}]
    # The chat messages are flattened into a single prompt string for the chat task
    prompt = convert_messages_to_prompt(data1, DummyTokenizer())
    assert prompt == "one two"

    with pytest.raises(MlflowException, match=r"Input messages should be list of"):
        convert_messages_to_prompt([["one", "two"]], DummyTokenizer())


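# Each case pairs raw model input (`data`) and caller-supplied `params` with the
# prompt list and Hugging Face generation params expected after preprocessing.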
class _TestCase(NamedTuple):
    data: Any
    params: Any
    expected_data: Any
    expected_params: Any


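# `preprocess_llm_inference_input` pulls OpenAI-style params (temperature, max_tokens,
# stop) out of the input data, renames max_tokens to the Hugging Face max_new_tokens,
# and lets values in the data override those passed via `params`.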
@pytest.mark.parametrize(
    "case",
    [
        # Case 0: Data only includes prompt
        _TestCase(
            data=pd.DataFrame({"prompt": ["Hello world!"]}),
            params={},
            expected_data=["Hello world!"],
            expected_params={},
        ),
        # Case 1: Data includes prompt and params
        _TestCase(
            data=pd.DataFrame({
                "prompt": ["Hello world!"],
                "temperature": [0.7],
                "max_tokens": [100],
                "stop": [None],
            }),
            params={},
            expected_data=["Hello world!"],
            expected_params={
                "temperature": 0.7,
                # max_tokens is replaced with max_new_tokens
                "max_new_tokens": 100,
                # do not pass `stop` to params as it is None
            },
        ),
        # Case 2: Params are passed if not specified in data
        _TestCase(
            data=pd.DataFrame({
                "prompt": ["Hello world!"],
            }),
            params={
                "temperature": 0.7,
                "max_tokens": 100,
                "stop": ["foo", "bar"],
            },
            expected_data=["Hello world!"],
            expected_params={
                "temperature": 0.7,
                "max_new_tokens": 100,
                # `stop` is converted to a stopping_criteria entry (a
                # _StopSequenceMatchCriteria instance), which the test body
                # checks separately, so it does not appear here
            },
        ),
        # Case 3: Data overrides params
        _TestCase(
            data=pd.DataFrame({
                "messages": [
                    [
                        {"role": "user", "content": "Hello!"},
                        {"role": "assistant", "content": "Hi!"},
                    ]
                ],
                "temperature": [0.1],
                "max_tokens": [100],
                "stop": [["foo", "bar"]],
            }),
            params={
                "temperature": [0.2],
                "max_tokens": [200],
                "stop": ["foo", "bar", "baz"],
            },
            expected_data=[
                [
                    {"role": "user", "content": "Hello!"},
                    {"role": "assistant", "content": "Hi!"},
                ]
            ],
            expected_params={
                "temperature": 0.1,
                "max_new_tokens": 100,
            },
        ),
        # Case 4: Batch input
        _TestCase(
            data=pd.DataFrame({
                "prompt": ["Hello!", "Hi", "Hola"],
                "temperature": [0.1, 0.2, 0.3],
                "max_tokens": [None, 200, 300],
            }),
            params={
                "temperature": 0.4,
                "max_tokens": 400,
            },
            expected_data=["Hello!", "Hi", "Hola"],
            # Values from the first row of the data take precedence; params fill the gaps
            expected_params={
                "temperature": 0.1,
                "max_new_tokens": 400,
            },
        ),
        # Case 5: Raw dict input
        _TestCase(
            data={
                "messages": [
                    {"role": "user", "content": "Hello!"},
                    {"role": "assistant", "content": "Hi!"},
                ],
                "temperature": 0.1,
                "max_tokens": 100,
                "stop": ["foo", "bar"],
            },
            params={},
            expected_data=[
                [
                    {"role": "user", "content": "Hello!"},
                    {"role": "assistant", "content": "Hi!"},
                ]
            ],
            expected_params={
                "temperature": 0.1,
                "max_new_tokens": 100,
            },
        ),
    ],
)
def test_preprocess_llm_inference_input(case):
    task = "llm/v1/completions" if "prompt" in case.data else "llm/v1/chat"
    flavor_config = {"inference_task": task, "source_model_name": "test"}

    with mock.patch(
        "mlflow.transformers.llm_inference_utils._get_stopping_criteria"
    ) as mock_get_stopping_criteria:
        data, params = preprocess_llm_inference_input(case.data, case.params, flavor_config)

    # Test that OpenAI params are separated from data and replaced with Hugging Face params
    assert data == case.expected_data
    if "stopping_criteria" in params:
        assert params.pop("stopping_criteria") is not None
        mock_get_stopping_criteria.assert_called_once_with(["foo", "bar"], "test")
    assert params == case.expected_params


def test_preprocess_llm_inference_input_raise_if_key_invalid():
    # Missing input key
    with pytest.raises(MlflowException, match=r"Transformer model saved with"):
        preprocess_llm_inference_input(
            pd.DataFrame({"invalid_key": [1, 2, 3]}),
            flavor_config={"inference_task": "llm/v1/completions"},
        )

    # Unmatched key (should be "messages" for chat task)
    with pytest.raises(MlflowException, match=r"Transformer model saved with"):
        preprocess_llm_inference_input(
            pd.DataFrame({"prompt": ["Hi"]}), flavor_config={"inference_task": "llm/v1/chat"}
        )


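# With stop strings set, `_get_stopping_criteria` is expected to return a list of
# criteria callables (two per stop string, presumably covering the sequence with and
# without a leading space); each fires when the generated ids end with the stop
# sequence's token ids, which is why only the "5" variants match [1, 2, 3, 4, 5].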
def test_stopping_criteria():
    with mock.patch("transformers.AutoTokenizer.from_pretrained") as mock_from_pretrained:
        mock_from_pretrained.return_value = DummyTokenizer()

        stopping_criteria = _get_stopping_criteria(stop=None, model_name=None)
        assert stopping_criteria is None

        input_ids = torch.tensor([[1, 2, 3, 4, 5]])
        scores = torch.ones(1, 5)

        stopping_criteria = _get_stopping_criteria(stop="5", model_name="my/model")
        stopping_criteria_matches = [f(input_ids, scores) for f in stopping_criteria]
        assert stopping_criteria_matches == [True, True]

        stopping_criteria = _get_stopping_criteria(stop=["100", "5"], model_name="my/model")
        stopping_criteria_matches = [f(input_ids, scores) for f in stopping_criteria]
        assert stopping_criteria_matches == [False, False, True, True]


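# For the llm/v1/completions task, the output dict should follow the OpenAI completion
# format. With the DummyTokenizer, prompt "1 2 3" and output ids [1, 2, 3, 4, 5] should
# yield something like (field values illustrative):
# {
#     "id": "<uuid4>",
#     "object": "text_completion",
#     "model": "gpt2",
#     "choices": [{"text": "4 5", "finish_reason": "length", ...}],
#     "usage": {"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5},
# }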
def test_output_dict_for_completions():
    prompt = "1 2 3"
    output_tensor = [1, 2, 3, 4, 5]
    flavor_config = {"source_model_name": "gpt2"}
    model_config = {"max_new_tokens": 2}
    inference_task = "llm/v1/completions"

    pipeline = mock.MagicMock()
    pipeline.tokenizer = DummyTokenizer()

    output_dict = _get_output_and_usage_from_tensor(
        prompt, output_tensor, pipeline, flavor_config, model_config, inference_task
    )

    # The id must parse as a valid UUID (uuid.UUID raises otherwise)
    uuid.UUID(output_dict["id"])

    assert output_dict["object"] == "text_completion"
    assert output_dict["model"] == "gpt2"

    assert output_dict["choices"][0]["text"] == "4 5"
    assert output_dict["choices"][0]["finish_reason"] == "length"

    usage = output_dict["usage"]
    assert usage["prompt_tokens"] + usage["completion_tokens"] == usage["total_tokens"]


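# Prompt tokens come from encoding the prompt; completion tokens are the remaining ids
# in the output tensor beyond the prompt.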
def test_token_usage():
    prompt = "1 2 3"
    output_tensor = [1, 2, 3, 4, 5]

    pipeline = mock.MagicMock()
    pipeline.tokenizer = DummyTokenizer()

    usage = _get_token_usage(prompt, output_tensor, pipeline, {})
    assert usage["prompt_tokens"] == 3
    assert usage["completion_tokens"] == 2
    assert usage["total_tokens"] == 5


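# finish_reason is "length" when generation stopped because it hit max_new_tokens or
# max_length, and "stop" otherwise.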
def test_finish_reason():
    assert _get_finish_reason(total_tokens=20, completion_tokens=10, model_config={}) == "stop"

    assert (
        _get_finish_reason(
            total_tokens=20, completion_tokens=10, model_config={"max_new_tokens": 10}
        )
        == "length"
    )

    assert (
        _get_finish_reason(total_tokens=20, completion_tokens=10, model_config={"max_length": 15})
        == "length"
    )


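# Both OpenAI-style inference tasks map to the Transformers "text-generation" pipeline
# task; with no inference task set, there is no default.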
@pytest.mark.parametrize(
    ("inference_task", "expected_task"),
    [
        ("llm/v1/completions", "text-generation"),
        ("llm/v1/chat", "text-generation"),
        (None, None),
    ],
)
def test_default_task_for_llm_inference_task(inference_task, expected_task):
    assert _get_default_task_for_llm_inference_task(inference_task) == expected_task