# tests/types/test_genai_types.py
 1  import pytest
 2  from pydantic import ValidationError
 3  
 4  from mlflow.types.chat import ChatCompletionResponse
 5  
 6  
 7  def test_instantiation_chat_completion():
 8      response_structure = {
 9          "id": "1",
10          "object": "1",
11          "created": 1,
12          "model": "model",
13          "choices": [
14              {
15                  "index": 0,
16                  "message": {"role": "user", "content": "hi"},
17                  "finish_reason": None,
18              },
19              {
20                  "index": 1,
21                  "message": {"role": "user", "content": "there"},
22                  "finish_reason": "STOP",
23              },
24          ],
25          "usage": {"prompt_tokens": 12, "completion_tokens": 22, "total_tokens": 34},
26      }
27  
28      response = ChatCompletionResponse(**response_structure)
29  
30      assert response.id == "1"
31      assert response.object == "1"
32      assert response.created == 1
33      assert response.model == "model"
34      assert len(response.choices) == 2
35      assert response.choices[0].index == 0
36      assert response.choices[0].message.content == "hi"
37      assert response.choices[1].finish_reason == "STOP"
38      assert response.usage.prompt_tokens == 12
39      assert response.usage.completion_tokens == 22
40      assert response.usage.total_tokens == 34
41  
42  
def test_invalid_chat_completion():
    """An incomplete payload must be rejected when building a ChatCompletionResponse.

    Compared with a fully populated response, this payload leaves out three keys:
    the top-level ``object`` and ``created`` keys, and the choice's ``finish_reason``.
    The test expects pydantic to report exactly one validation error for it.

    NOTE(review): which of the omitted keys produces that single error depends on
    the model's field defaults, which are not visible here. Confirm against
    ChatCompletionResponse.
    """
    single_choice = {
        "index": 0,
        "message": {"role": "user", "content": "hi"},
    }
    payload = dict(
        id="1",
        model="model",
        choices=[single_choice],
        usage={"prompt_tokens": 12, "completion_tokens": 22, "total_tokens": 34},
    )
    expected_message = "1 validation error for ChatCompletionResponse"

    with pytest.raises(ValidationError, match=expected_message):
        ChatCompletionResponse(**payload)