test_chat_utils.py
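"""Tests for the LangChain chat helpers in ``mlflow.langchain.utils.chat``:
message conversion, chat request/response transformation, and token usage
parsing."""
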
from unittest.mock import MagicMock, patch

import pytest
from langchain_core.language_models.chat_models import SimpleChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.outputs import ChatGenerationChunk
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.outputs.generation import Generation

from mlflow.exceptions import MlflowException
from mlflow.langchain.utils.chat import (
    convert_lc_message_to_chat_message,
    parse_token_usage,
    transform_request_json_for_chat_if_necessary,
    try_transform_response_iter_to_chat_format,
    try_transform_response_to_chat_format,
)
from mlflow.types.chat import ChatMessage, Function
from mlflow.types.chat import ToolCall as _ToolCall


@pytest.mark.parametrize(
    ("message", "expected"),
    [
        (
            AIMessage(content="foo", id="123"),
            ChatMessage(role="assistant", content="foo", id="123"),
        ),
        (
            ToolMessage(content="foo", tool_call_id="123"),
            ChatMessage(role="tool", content="foo", tool_call_id="123"),
        ),
        (
            SystemMessage(content="foo"),
            ChatMessage(role="system", content="foo"),
        ),
        (
            HumanMessage(content="foo"),
            ChatMessage(role="user", content="foo"),
        ),
    ],
)
def test_convert_lc_message_to_chat_message(message, expected):
    assert convert_lc_message_to_chat_message(message) == expected


@pytest.mark.parametrize(
    ("message", "expected"),
    [
        (
            AIMessage(
                content=[
                    {"type": "text", "text": "Response text"},
                    {"type": "tool_use", "id": "123", "name": "tool"},
                ],
                tool_calls=[{"id": "123", "name": "tool", "args": {}, "type": "tool_call"}],
            ),
            ChatMessage(
                role="assistant",
                content=[{"type": "text", "text": "Response text"}],
                tool_calls=[
                    _ToolCall(
                        id="123",
                        type="function",
                        function=Function(name="tool", arguments="{}"),
                    )
                ],
            ),
        ),
        (
            AIMessage(
                content="",
                tool_calls=[{"id": "123", "name": "tool_name", "args": {"arg1": "val1"}}],
            ),
            ChatMessage(
                role="assistant",
                content=None,
                tool_calls=[
                    _ToolCall(
                        id="123",
                        type="function",
                        function=Function(name="tool_name", arguments='{"arg1": "val1"}'),
                    )
                ],
            ),
        ),
    ],
)
def test_convert_lc_message_to_chat_message_tool_calls(message, expected):
    assert convert_lc_message_to_chat_message(message) == expected


def test_convert_lc_message_to_chat_message_audio_content():
    message = HumanMessage(
        content=[
            {"type": "text", "text": "What is this audio?"},
            {
                "type": "audio",
                "source_type": "base64",
                "data": "SGVsbG8=",
                "mime_type": "audio/wav",
            },
        ]
    )
    result = convert_lc_message_to_chat_message(message)
    assert result.role == "user"
    assert len(result.content) == 2
    assert result.content[0].type == "text"
    assert result.content[0].text == "What is this audio?"
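    # The base64 audio part is expected to surface as an OpenAI-style
    # input_audio entry, with the format taken from the MIME subtype.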
    assert result.content[1].type == "input_audio"
    assert result.content[1].input_audio.data == "SGVsbG8="
    assert result.content[1].input_audio.format == "wav"


def test_convert_lc_message_to_chat_message_audio_mp3():
    message = HumanMessage(
        content=[
            {
                "type": "audio",
                "source_type": "base64",
                "data": "AAAA",
                "mime_type": "audio/mp3",
            },
        ]
    )
    result = convert_lc_message_to_chat_message(message)
    assert result.content[0].type == "input_audio"
    assert result.content[0].input_audio.data == "AAAA"
    assert result.content[0].input_audio.format == "mp3"


def test_convert_lc_message_to_chat_message_audio_mpeg():
    message = HumanMessage(
        content=[
            {
                "type": "audio",
                "source_type": "base64",
                "data": "AAAA",
                "mime_type": "audio/mpeg",
            },
        ]
    )
    result = convert_lc_message_to_chat_message(message)
    assert result.content[0].type == "input_audio"
    assert result.content[0].input_audio.data == "AAAA"
    assert result.content[0].input_audio.format == "mp3"


def test_convert_lc_message_to_chat_message_string_content_unchanged():
    message = HumanMessage(content="just text")
    result = convert_lc_message_to_chat_message(message)
    assert result.content == "just text"


def test_convert_lc_message_audio_url_source_raises():
    message = HumanMessage(
        content=[
            {
                "type": "audio",
                "source_type": "url",
                "url": "https://example.com/audio.wav",
                "mime_type": "audio/wav",
            },
        ]
    )
    with pytest.raises(MlflowException, match="Only base64-encoded audio"):
        convert_lc_message_to_chat_message(message)


def test_convert_lc_message_audio_no_mime_type_raises():
    message = HumanMessage(
        content=[
            {
                "type": "audio",
                "source_type": "base64",
                "data": "SGVsbG8=",
            },
        ]
    )
    with pytest.raises(MlflowException, match="Only base64-encoded audio"):
        convert_lc_message_to_chat_message(message)


def test_convert_lc_message_audio_unsupported_format_raises():
    message = HumanMessage(
        content=[
            {
                "type": "audio",
                "source_type": "base64",
                "data": "SGVsbG8=",
                "mime_type": "audio/ogg",
            },
        ]
    )
    with pytest.raises(MlflowException, match="Unsupported audio format"):
        convert_lc_message_to_chat_message(message)


def test_transform_response_to_chat_format_no_conversion():
    response = ["list_response"]
    assert try_transform_response_to_chat_format(response) == response

    response = {"dict_response": "response"}
    assert try_transform_response_to_chat_format(response) == response


def test_transform_response_to_chat_format_conversion():
    response = "string_response"
    converted_response = try_transform_response_to_chat_format(response)
    assert isinstance(converted_response, dict)
    assert converted_response["id"] is None
    assert converted_response["choices"][0]["message"]["content"] == response

    response = AIMessage(content="ai_message_response")
    converted_response = try_transform_response_to_chat_format(response)
    assert isinstance(converted_response, dict)
    assert converted_response["id"] == getattr(response, "id", None)
    assert converted_response["choices"][0]["message"]["content"] == response.content


def test_transform_response_iter_to_chat_format_no_conversion():
    response = [{"dict_response": "response"}]
    converted_response = list(try_transform_response_iter_to_chat_format(response))
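    # Dict chunks are already chat-shaped, so they should pass through unchanged.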
    assert len(converted_response) == 1
    assert converted_response[0] == response[0]


def test_transform_response_iter_to_chat_format_ai_message():
    response = ["string response"]
    converted_response = list(try_transform_response_iter_to_chat_format(response))
    assert len(converted_response) == 1
    assert converted_response[0]["id"] is None
    assert converted_response[0]["choices"][0]["delta"]["content"] == response[0]

    response = [
        AIMessage(
            content="ai_message_response", id="123", response_metadata={"finish_reason": "done"}
        )
    ]
    converted_response = list(try_transform_response_iter_to_chat_format(response))
    assert len(converted_response) == 1
    assert converted_response[0]["id"] == getattr(response[0], "id", None)
    assert converted_response[0]["choices"][0]["delta"]["content"] == response[0].content
    assert converted_response[0]["choices"][0]["finish_reason"] == "stop"

    response = [
        AIMessageChunk(
            content="ai_message_chunk_response",
            id="123",
            response_metadata={"finish_reason": "done"},
        ),
        AIMessageChunk(
            content="ai_message_chunk_response",
            id="456",
            response_metadata={"finish_reason": "stop"},
        ),
    ]
    converted_response = list(try_transform_response_iter_to_chat_format(response))
    assert len(converted_response) == 2
    for i in range(2):
        assert converted_response[i]["id"] == getattr(response[i], "id", None)
        assert converted_response[i]["choices"][0]["delta"]["content"] == response[i].content
        assert (
            converted_response[i]["choices"][0]["finish_reason"]
            == response[i].response_metadata["finish_reason"]
        )


def test_transform_request_json_for_chat_if_necessary_conversion():
    model = MagicMock(spec=SimpleChatModel)
    request_json = {"messages": [{"role": "user", "content": "some_input"}]}

    with patch("mlflow.langchain.utils.chat._get_lc_model_input_fields", return_value={"messages"}):
        transformed_request = transform_request_json_for_chat_if_necessary(request_json, model)
        assert transformed_request == (request_json, False)

    with patch(
        "mlflow.langchain.utils.chat._get_lc_model_input_fields",
        return_value={},
    ):
        transformed_request = transform_request_json_for_chat_if_necessary(request_json, model)
        assert transformed_request[0][0] == HumanMessage(content="some_input")
        assert transformed_request[1] is True

    request_json = [
        {"messages": [{"role": "system", "content": "You are a helpful assistant."}]},
        {"messages": [{"role": "assistant", "content": "What would you like to ask?"}]},
        {"messages": [{"role": "user", "content": "Who owns MLflow?"}]},
    ]
    with patch(
        "mlflow.langchain.utils.chat._get_lc_model_input_fields",
        return_value={},
    ):
        transformed_request = transform_request_json_for_chat_if_necessary(request_json, model)
        assert transformed_request[0][0][0] == SystemMessage(content="You are a helpful assistant.")
        assert transformed_request[0][1][0] == AIMessage(content="What would you like to ask?")
        assert transformed_request[0][2][0] == HumanMessage(content="Who owns MLflow?")
        assert transformed_request[1] is True
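

# parse_token_usage should normalize the provider-specific usage payloads below
# (LangChain usage_metadata, raw OpenAI usage/token_usage in response_metadata,
# and Gemini cached_content_token_count) into a single input/output/total token
# dict, surfacing cache reads and creation as *_input_tokens fields.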
@pytest.mark.parametrize(
    ("generation", "expected"),
    [
        (ChatGeneration(message=AIMessage(content="foo", id="123")), None),
        (
            ChatGeneration(
                message=AIMessage(
                    content="foo",
                    id="123",
                    usage_metadata={"input_tokens": 5, "output_tokens": 10, "total_tokens": 15},
                )
            ),
            {"input_tokens": 5, "output_tokens": 10, "total_tokens": 15},
        ),
        (
            ChatGeneration(
                message=AIMessageChunk(
                    content="foo",
                    id="123",
                    usage_metadata={"input_tokens": 5, "output_tokens": 10, "total_tokens": 15},
                )
            ),
            {"input_tokens": 5, "output_tokens": 10, "total_tokens": 15},
        ),
        (
            ChatGeneration(
                message=AIMessage(
                    content="foo",
                    id="123",
                    response_metadata={
                        "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
                    },
                )
            ),
            {"input_tokens": 5, "output_tokens": 10, "total_tokens": 15},
        ),
        # OpenAI usage_metadata with input_token_details (LangChain standardized format)
        (
            ChatGeneration(
                message=AIMessage(
                    content="foo",
                    id="123",
                    usage_metadata={
                        "input_tokens": 50,
                        "output_tokens": 20,
                        "total_tokens": 70,
                        "input_token_details": {"cache_read": 30, "cache_creation": 0},
                    },
                )
            ),
            {
                "input_tokens": 50,
                "output_tokens": 20,
                "total_tokens": 70,
                "cache_read_input_tokens": 30,
                "cache_creation_input_tokens": 0,
            },
        ),
        # OpenAI usage_metadata with both cache_read and cache_creation
        (
            ChatGeneration(
                message=AIMessage(
                    content="foo",
                    id="123",
                    usage_metadata={
                        "input_tokens": 100,
                        "output_tokens": 50,
                        "total_tokens": 150,
                        "input_token_details": {"cache_read": 25, "cache_creation": 15},
                    },
                )
            ),
            {
                "input_tokens": 100,
                "output_tokens": 50,
                "total_tokens": 150,
                "cache_read_input_tokens": 25,
                "cache_creation_input_tokens": 15,
            },
        ),
        # Raw OpenAI response_metadata with prompt_tokens_details
        (
            ChatGeneration(
                message=AIMessage(
                    content="foo",
                    id="123",
                    response_metadata={
                        "token_usage": {
                            "prompt_tokens": 50,
                            "completion_tokens": 20,
                            "total_tokens": 70,
                            "prompt_tokens_details": {"cached_tokens": 30},
                        }
                    },
                )
            ),
            {
                "input_tokens": 50,
                "output_tokens": 20,
                "total_tokens": 70,
                "cache_read_input_tokens": 30,
            },
        ),
        # Gemini usage_metadata with cached_content_token_count
        (
            ChatGeneration(
                message=AIMessage(
                    content="foo",
                    id="123",
                    usage_metadata={
                        "input_tokens": 50,
                        "output_tokens": 20,
                        "total_tokens": 70,
                        "cached_content_token_count": 30,
                    },
                )
            ),
            {
                "input_tokens": 50,
                "output_tokens": 20,
                "total_tokens": 70,
                "cache_read_input_tokens": 30,
            },
        ),
        # Legacy completion generation object
        (Generation(text="foo"), None),
    ],
)
def test_parse_token_usage(generation, expected):
    assert parse_token_usage([generation]) == expected
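

# The two tests below pin down how usage is aggregated across multiple
# generations: streaming chunks carry cumulative usage and only the final
# chunk's values should be kept, while separate non-streaming generations
# carry independent usage and should be summed.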
def test_parse_token_usage_streaming_chunks():
    """
    Test that streaming chunks with cumulative token usage are handled correctly.

    In streaming mode, each ChatGenerationChunk contains:
    - Same input_tokens (repeated for each chunk)
    - Cumulative output_tokens (increasing with each chunk)

    Expected behavior: Use only the last chunk's usage (final cumulative values)
    """
    # Simulate 3 streaming chunks with same input_tokens but cumulative output_tokens
    # This matches the pattern observed in real streaming scenarios
    chunks = [
        ChatGenerationChunk(
            message=AIMessageChunk(
                content="Agreement",
                usage_metadata={
                    "input_tokens": 16049,
                    "output_tokens": 2,
                    "total_tokens": 16051,
                },
            )
        ),
        ChatGenerationChunk(
            message=AIMessageChunk(
                content=" ",
                usage_metadata={
                    "input_tokens": 16049,
                    "output_tokens": 58,
                    "total_tokens": 16107,
                },
            )
        ),
        ChatGenerationChunk(
            message=AIMessageChunk(
                content="",
                usage_metadata={
                    "input_tokens": 16049,
                    "output_tokens": 115,
                    "total_tokens": 16164,
                },
            )
        ),
    ]

    result = parse_token_usage(chunks)

    # Should use only the last chunk's usage (final cumulative values)
    assert result is not None
    assert result["input_tokens"] == 16049
    assert result["output_tokens"] == 115
    assert result["total_tokens"] == 16164


def test_parse_token_usage_non_streaming_multiple_calls():
    """
    Test that non-streaming multiple calls still sum correctly (existing behavior).

    When multiple ChatGeneration objects are present (non-streaming), they represent
    separate LLM calls and should be summed.
    """
    # Simulate 2 separate non-streaming calls with different token usage
    generations = [
        ChatGeneration(
            message=AIMessage(
                content="Response 1",
                usage_metadata={
                    "input_tokens": 10,
                    "output_tokens": 20,
                    "total_tokens": 30,
                },
            )
        ),
        ChatGeneration(
            message=AIMessage(
                content="Response 2",
                usage_metadata={
                    "input_tokens": 15,
                    "output_tokens": 25,
                    "total_tokens": 40,
                },
            )
        ),
    ]

    result = parse_token_usage(generations)

    # Should sum all generations (existing non-streaming behavior)
    assert result is not None
    assert result["input_tokens"] == 25  # 10 + 15
    assert result["output_tokens"] == 45  # 20 + 25
    assert result["total_tokens"] == 70  # 30 + 40
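

# A rough sketch of the aggregation rule the two tests above exercise, inferred
# from their assertions rather than from the parse_token_usage implementation:
#
#     if all generations are ChatGenerationChunk:  # streaming
#         usage = usage_metadata of the last chunk  # values are cumulative
#     else:  # separate non-streaming calls
#         usage = field-wise sum of each generation's usage_metadata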