test_chat_model_validation.py
import pytest

from mlflow.types.llm import (
    ChatChoice,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatMessage,
    TokenUsageStats,
)

# Chat completion payload with two choices and fully populated token usage.
MOCK_RESPONSE = {
    "id": "123",
    "object": "chat.completion",
    "created": 1677652288,
    "model": "MyChatModel",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "hello",
            },
            "finish_reason": "stop",
        },
        {
            "index": 1,
            "message": {
                "role": "user",
                "content": "world",
            },
            "finish_reason": "stop",
        },
    ],
    "usage": {
        "prompt_tokens": 10,
        "completion_tokens": 10,
        "total_tokens": 20,
    },
}

# Payload whose choices cover three `logprobs` shapes: a populated `content`
# list (with some `bytes` entries set to None), `logprobs` set to None, and
# `logprobs` present with `content` set to None.
MOCK_OPENAI_CHAT_COMPLETION_RESPONSE = {
    "id": "chatcmpl-123",
    "object": "chat.completion",
    "created": 1702685778,
    "model": "gpt-4o-mini",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hello! How can I assist you today?"},
            "logprobs": {
                "content": [
                    {
                        "token": "Hello",
                        "logprob": -0.31725305,
                        "bytes": [72, 101, 108, 108, 111],
                        "top_logprobs": [
                            {
                                "token": "Hello",
                                "logprob": -0.31725305,
                                "bytes": [72, 101, 108, 108, 111],
                            },
                            {"token": "Hi", "logprob": -1.3190403, "bytes": [72, 105]},
                        ],
                    },
                    {
                        "token": "!",
                        "logprob": -0.02380986,
                        "bytes": None,
                        "top_logprobs": [
                            {"token": "!", "logprob": -0.02380986, "bytes": [33]},
                            {
                                "token": " there",
                                "logprob": -3.787621,
                                "bytes": None,
                            },
                        ],
                    },
                ]
            },
            "finish_reason": "stop",
        },
        {
            "index": 1,
            "message": {"role": "user", "content": "I need help with my computer."},
            "logprobs": None,
            "finish_reason": "stop",
        },
        {
            "index": 2,
            "message": {"role": "assistant", "content": "Sure! What seems to be the problem?"},
            "logprobs": {
                "content": None,
            },
            "finish_reason": "stop",
        },
    ],
    "usage": {"prompt_tokens": 9, "completion_tokens": 9, "total_tokens": 18},
}


# Payload whose single message carries `refusal` and no `content` key.
MOCK_OPENAI_CHAT_REFUSAL_RESPONSE = {
    "id": "chatcmpl-123",
    "object": "chat.completion",
    "created": 1721596428,
    "model": "gpt-4o-mini",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "refusal": "I'm sorry, I cannot assist with that request.",
            },
            "logprobs": None,
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 81, "completion_tokens": 11, "total_tokens": 92},
}


@pytest.mark.parametrize(
    ("data", "error", "match"),
    [
        ({"content": "hello"}, TypeError, "required positional argument"),  # missing required field
        (
            {"role": "user", "content": "hello", "name": 1},
            ValueError,
            "`name` must be of type str",
        ),  # field of wrong type
        (
            {"role": "user", "refusal": "I can't answer that.", "content": "hi"},
            ValueError,
            "Both `content` and `refusal` cannot be set",
        ),  # conflicting schema
        (
            {"role": "user", "name": "name"},
            ValueError,
            "`content` is required",
        ),  # missing one-of required field
    ],
)
def test_chat_message_throws_on_invalid_data(data, error, match):
    """Expect ``ChatMessage.from_dict`` to raise ``error`` matching ``match`` for each bad payload."""
    with pytest.raises(error, match=match):
        ChatMessage.from_dict(data)


@pytest.mark.parametrize(
    "data",
    [
        {"role": "user", "content": "hello"},
        {"role": "user", "content": "hello", "name": "world"},
    ],
)
def test_chat_message_succeeds_on_valid_data(data):
    """Expect a valid message dict to round-trip unchanged through ``from_dict``/``to_dict``."""
    assert ChatMessage.from_dict(data).to_dict() == data


@pytest.mark.parametrize(
    ("data", "match"),
    [
        ({"messages": "not a list"}, "`messages` must be a list"),
        (
            {"messages": ["not a dict"]},
            "Items in `messages` must all have the same type: ChatMessage or dict",
        ),
        (
            {
                "messages": [
                    {"role": "user", "content": "not all the same"},
                    ChatMessage.from_dict({"role": "user", "content": "hello"}),
                ]
            },
            "Items in `messages` must all have the same type: ChatMessage or dict",
        ),
    ],
)
def test_list_validation_throws_on_invalid_lists(data, match):
    """Expect ``ChatCompletionRequest.from_dict`` to raise ``ValueError`` for bad ``messages``."""
    with pytest.raises(ValueError, match=match):
        ChatCompletionRequest.from_dict(data)


@pytest.mark.parametrize(
    "sample_output",
    [MOCK_RESPONSE, MOCK_OPENAI_CHAT_COMPLETION_RESPONSE, MOCK_OPENAI_CHAT_REFUSAL_RESPONSE],
)
def test_dataclass_constructs_nested_types_from_dict(sample_output):
    """Expect ``from_dict`` to build nested usage, choice and message objects from plain dicts."""
    response = ChatCompletionResponse.from_dict(sample_output)
    assert isinstance(response.usage, TokenUsageStats)
    assert isinstance(response.choices[0], ChatChoice)
    assert isinstance(response.choices[0].message, ChatMessage)


@pytest.mark.parametrize(
    "sample_output",
    [MOCK_RESPONSE, MOCK_OPENAI_CHAT_COMPLETION_RESPONSE, MOCK_OPENAI_CHAT_REFUSAL_RESPONSE],
)
def test_to_dict_converts_nested_dataclasses(sample_output):
    """Expect ``to_dict`` to turn the nested choice, usage and message objects back into dicts."""
    response = ChatCompletionResponse.from_dict(sample_output).to_dict()
    assert isinstance(response["choices"][0], dict)
    assert isinstance(response["usage"], dict)
    assert isinstance(response["choices"][0]["message"], dict)


def test_to_dict_excludes_nones():
    """Expect ``to_dict`` to omit ``name`` from a message that was built without one."""
    response = ChatCompletionResponse.from_dict(MOCK_RESPONSE).to_dict()
    assert "name" not in response["choices"][0]["message"]


def test_chat_response_defaults():
    """Expect omitted fields to be ``None``, and ``finish_reason`` to default to ``"stop"``."""
    tokens = TokenUsageStats()
    message = ChatMessage("user", "Hello")
    choice = ChatChoice(message)
    response = ChatCompletionResponse([choice], tokens)

    assert response.usage.prompt_tokens is None
    assert response.usage.completion_tokens is None
    assert response.usage.total_tokens is None
    assert response.model is None
    assert response.id is None
    assert response.choices[0].finish_reason == "stop"


@pytest.mark.parametrize(
    ("custom_inputs", "match"),
    [
        (1, r"Expected `custom_inputs` to be a dictionary, received `int`"),
        ({1: "example"}, r"received key of type `int` \(key: 1\)"),
    ],
)
def test_chat_request_custom_inputs_must_be_valid_map(custom_inputs, match):
    """Expect ``ValueError`` when ``custom_inputs`` is an int or a dict with an int key."""
    message = ChatMessage("user", "Hello")
    with pytest.raises(ValueError, match=match):
        ChatCompletionRequest(messages=[message], custom_inputs=custom_inputs)


@pytest.mark.parametrize(
    ("cls", "data", "match"),
    [
        (
            ChatChoice,
            {"index": 0, "message": 123},
            "Expected `message` to be either an instance of `ChatMessage` or a dict",
        ),
        (
            ChatCompletionResponse,
            {"choices": [], "usage": 123},
            "Expected `usage` to be either an instance of `TokenUsageStats` or a dict",
        ),
    ],
)
def test_convert_dataclass_throws_on_invalid_data(cls, data, match):
    """Expect ``cls.from_dict`` to raise ``ValueError`` when a nested field is given as an int."""
    with pytest.raises(ValueError, match=match):
        cls.from_dict(data)


@pytest.mark.parametrize(
    ("cls", "data"),
    [
        (ChatMessage, {"role": "user", "content": "hello", "extra": "field"}),
        (
            TokenUsageStats,
            {
                "completion_tokens": 10,
                "prompt_tokens": 57,
                "total_tokens": 67,
                # this field is not in the TokenUsageStats schema
                "completion_tokens_details": {"reasoning_tokens": 0},
            },
        ),
    ],
)
def test_from_dict_ignores_extra_fields(cls, data):
    """Expect ``cls.from_dict`` to still return a ``cls`` instance when given an extra key."""
    assert isinstance(cls.from_dict(data), cls)