# tests/pyfunc/test_chat_model_validation.py
  1  import pytest
  2  
  3  from mlflow.types.llm import (
  4      ChatChoice,
  5      ChatCompletionRequest,
  6      ChatCompletionResponse,
  7      ChatMessage,
  8      TokenUsageStats,
  9  )
 10  
# Minimal OpenAI-style chat.completion payload (two choices + usage stats).
# Used to exercise nested dataclass construction, serialization, and
# None-field exclusion in the tests below.
MOCK_RESPONSE = {
    "id": "123",
    "object": "chat.completion",
    "created": 1677652288,
    "model": "MyChatModel",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "hello",
            },
            "finish_reason": "stop",
        },
        {
            "index": 1,
            "message": {
                "role": "user",
                "content": "world",
            },
            "finish_reason": "stop",
        },
    ],
    "usage": {
        "prompt_tokens": 10,
        "completion_tokens": 10,
        "total_tokens": 20,
    },
}
 40  
# Realistic OpenAI chat.completion payload. Deliberately covers three
# `logprobs` variants across the choices: populated (with and without
# `bytes`), entirely None, and present-but-`content: None` — so parsing
# must tolerate each shape.
MOCK_OPENAI_CHAT_COMPLETION_RESPONSE = {
    "id": "chatcmpl-123",
    "object": "chat.completion",
    "created": 1702685778,
    "model": "gpt-4o-mini",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hello! How can I assist you today?"},
            "logprobs": {
                "content": [
                    {
                        "token": "Hello",
                        "logprob": -0.31725305,
                        "bytes": [72, 101, 108, 108, 111],
                        "top_logprobs": [
                            {
                                "token": "Hello",
                                "logprob": -0.31725305,
                                "bytes": [72, 101, 108, 108, 111],
                            },
                            {"token": "Hi", "logprob": -1.3190403, "bytes": [72, 105]},
                        ],
                    },
                    {
                        "token": "!",
                        "logprob": -0.02380986,
                        "bytes": None,
                        "top_logprobs": [
                            {"token": "!", "logprob": -0.02380986, "bytes": [33]},
                            {
                                "token": " there",
                                "logprob": -3.787621,
                                "bytes": None,
                            },
                        ],
                    },
                ]
            },
            "finish_reason": "stop",
        },
        {
            "index": 1,
            "message": {"role": "user", "content": "I need help with my computer."},
            "logprobs": None,
            "finish_reason": "stop",
        },
        {
            "index": 2,
            "message": {"role": "assistant", "content": "Sure! What seems to be the problem?"},
            "logprobs": {
                "content": None,
            },
            "finish_reason": "stop",
        },
    ],
    "usage": {"prompt_tokens": 9, "completion_tokens": 9, "total_tokens": 18},
}
 99  
100  
# OpenAI chat.completion payload where the assistant message carries a
# `refusal` instead of `content` — verifies that refusal-only messages
# parse and serialize correctly.
MOCK_OPENAI_CHAT_REFUSAL_RESPONSE = {
    "id": "chatcmpl-123",
    "object": "chat.completion",
    "created": 1721596428,
    "model": "gpt-4o-mini",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "refusal": "I'm sorry, I cannot assist with that request.",
            },
            "logprobs": None,
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 81, "completion_tokens": 11, "total_tokens": 92},
}
119  
120  
@pytest.mark.parametrize(
    "data, error, match",
    [
        # missing the required positional `role` field
        ({"content": "hello"}, TypeError, "required positional argument"),
        # `name` has the wrong type (int instead of str)
        (
            {"role": "user", "content": "hello", "name": 1},
            ValueError,
            "`name` must be of type str",
        ),
        # `content` and `refusal` are mutually exclusive
        (
            {"role": "user", "refusal": "I can't answer that.", "content": "hi"},
            ValueError,
            "Both `content` and `refusal` cannot be set",
        ),
        # exactly one of `content`/`refusal` must be supplied
        (
            {"role": "user", "name": "name"},
            ValueError,
            "`content` is required",
        ),
    ],
)
def test_chat_message_throws_on_invalid_data(data, error, match):
    """Malformed message dicts are rejected with a descriptive error."""
    with pytest.raises(error, match=match):
        ChatMessage.from_dict(data)
145  
146  
@pytest.mark.parametrize(
    "data",
    [
        {"role": "user", "content": "hello"},
        {"role": "user", "content": "hello", "name": "world"},
    ],
)
def test_chat_message_succeeds_on_valid_data(data):
    """Well-formed message dicts round-trip through from_dict/to_dict unchanged."""
    message = ChatMessage.from_dict(data)
    assert message.to_dict() == data
156  
157  
@pytest.mark.parametrize(
    "data, match",
    [
        # `messages` must be a list at all
        ({"messages": "not a list"}, "`messages` must be a list"),
        # list items must be dicts or ChatMessage instances...
        (
            {"messages": ["not a dict"]},
            "Items in `messages` must all have the same type: ChatMessage or dict",
        ),
        # ...and the two representations may not be mixed in one list
        (
            {
                "messages": [
                    {"role": "user", "content": "not all the same"},
                    ChatMessage.from_dict({"role": "user", "content": "hello"}),
                ]
            },
            "Items in `messages` must all have the same type: ChatMessage or dict",
        ),
    ],
)
def test_list_validation_throws_on_invalid_lists(data, match):
    """Invalid `messages` payloads are rejected during request construction."""
    with pytest.raises(ValueError, match=match):
        ChatCompletionRequest.from_dict(data)
180  
181  
@pytest.mark.parametrize(
    "sample_output",
    [MOCK_RESPONSE, MOCK_OPENAI_CHAT_COMPLETION_RESPONSE, MOCK_OPENAI_CHAT_REFUSAL_RESPONSE],
)
def test_dataclass_constructs_nested_types_from_dict(sample_output):
    """from_dict hydrates nested payload sections into typed dataclasses."""
    response = ChatCompletionResponse.from_dict(sample_output)
    first_choice = response.choices[0]
    assert isinstance(response.usage, TokenUsageStats)
    assert isinstance(first_choice, ChatChoice)
    assert isinstance(first_choice.message, ChatMessage)
191  
192  
@pytest.mark.parametrize(
    "sample_output",
    [MOCK_RESPONSE, MOCK_OPENAI_CHAT_COMPLETION_RESPONSE, MOCK_OPENAI_CHAT_REFUSAL_RESPONSE],
)
def test_to_dict_converts_nested_dataclasses(sample_output):
    """to_dict flattens every nested dataclass back into a plain dict."""
    serialized = ChatCompletionResponse.from_dict(sample_output).to_dict()
    first_choice = serialized["choices"][0]
    assert isinstance(first_choice, dict)
    assert isinstance(serialized["usage"], dict)
    assert isinstance(first_choice["message"], dict)
202  
203  
def test_to_dict_excludes_nones():
    """Optional fields left unset (None) are omitted from the serialized output."""
    serialized = ChatCompletionResponse.from_dict(MOCK_RESPONSE).to_dict()
    message = serialized["choices"][0]["message"]
    assert "name" not in message
207  
208  
def test_chat_response_defaults():
    """Constructing a response from bare components fills in the documented defaults."""
    response = ChatCompletionResponse(
        [ChatChoice(ChatMessage("user", "Hello"))],
        TokenUsageStats(),
    )

    # All token counters default to None when not supplied
    usage = response.usage
    assert usage.prompt_tokens is None
    assert usage.completion_tokens is None
    assert usage.total_tokens is None
    # Top-level metadata also defaults to None
    assert response.model is None
    assert response.id is None
    # finish_reason has a non-None default of "stop"
    assert response.choices[0].finish_reason == "stop"
221  
222  
@pytest.mark.parametrize(
    "custom_inputs, match",
    [
        # must be a dict, not a scalar
        (1, r"Expected `custom_inputs` to be a dictionary, received `int`"),
        # keys must be strings
        ({1: "example"}, r"received key of type `int` \(key: 1\)"),
    ],
)
def test_chat_request_custom_inputs_must_be_valid_map(custom_inputs, match):
    """custom_inputs must be a dictionary keyed by strings."""
    with pytest.raises(ValueError, match=match):
        ChatCompletionRequest(
            messages=[ChatMessage("user", "Hello")],
            custom_inputs=custom_inputs,
        )
234  
235  
@pytest.mark.parametrize(
    "cls, data, match",
    [
        # nested `message` must be a ChatMessage or dict
        (
            ChatChoice,
            {"index": 0, "message": 123},
            "Expected `message` to be either an instance of `ChatMessage` or a dict",
        ),
        # nested `usage` must be a TokenUsageStats or dict
        (
            ChatCompletionResponse,
            {"choices": [], "usage": 123},
            "Expected `usage` to be either an instance of `TokenUsageStats` or a dict",
        ),
    ],
)
def test_convert_dataclass_throws_on_invalid_data(cls, data, match):
    """Nested fields must be dataclass instances or dicts — anything else raises."""
    with pytest.raises(ValueError, match=match):
        cls.from_dict(data)
254  
255  
@pytest.mark.parametrize(
    "cls, data",
    [
        (ChatMessage, {"role": "user", "content": "hello", "extra": "field"}),
        (
            TokenUsageStats,
            {
                "completion_tokens": 10,
                "prompt_tokens": 57,
                "total_tokens": 67,
                # not part of the TokenUsageStats schema; should be dropped
                "completion_tokens_details": {"reasoning_tokens": 0},
            },
        ),
    ],
)
def test_from_dict_ignores_extra_fields(cls, data):
    """Unknown keys in the payload are silently ignored by from_dict."""
    parsed = cls.from_dict(data)
    assert isinstance(parsed, cls)