tests/langchain/test_chat_utils.py
from unittest.mock import MagicMock, patch

import pytest
from langchain_core.language_models.chat_models import SimpleChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.outputs import ChatGenerationChunk
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.outputs.generation import Generation

from mlflow.exceptions import MlflowException
from mlflow.langchain.utils.chat import (
    convert_lc_message_to_chat_message,
    parse_token_usage,
    transform_request_json_for_chat_if_necessary,
    try_transform_response_iter_to_chat_format,
    try_transform_response_to_chat_format,
)
from mlflow.types.chat import ChatMessage, Function
from mlflow.types.chat import ToolCall as _ToolCall


@pytest.mark.parametrize(
    ("message", "expected"),
    [
        (
            AIMessage(content="foo", id="123"),
            ChatMessage(role="assistant", content="foo", id="123"),
        ),
        (
            ToolMessage(content="foo", tool_call_id="123"),
            ChatMessage(role="tool", content="foo", tool_call_id="123"),
        ),
        (
            SystemMessage(content="foo"),
            ChatMessage(role="system", content="foo"),
        ),
        (
            HumanMessage(content="foo"),
            ChatMessage(role="user", content="foo"),
        ),
    ],
)
def test_convert_lc_message_to_chat_message(message, expected):
    assert convert_lc_message_to_chat_message(message) == expected


@pytest.mark.parametrize(
    ("message", "expected"),
    [
        (
            AIMessage(
                content=[
                    {"type": "text", "text": "Response text"},
                    {"type": "tool_use", "id": "123", "name": "tool"},
                ],
                tool_calls=[{"id": "123", "name": "tool", "args": {}, "type": "tool_call"}],
            ),
            ChatMessage(
                role="assistant",
                content=[{"type": "text", "text": "Response text"}],
                tool_calls=[
                    _ToolCall(
                        id="123",
                        type="function",
                        function=Function(name="tool", arguments="{}"),
                    )
                ],
            ),
        ),
        (
            AIMessage(
                content="",
                tool_calls=[{"id": "123", "name": "tool_name", "args": {"arg1": "val1"}}],
            ),
            ChatMessage(
                role="assistant",
                content=None,
                tool_calls=[
                    _ToolCall(
                        id="123",
                        type="function",
                        function=Function(name="tool_name", arguments='{"arg1": "val1"}'),
                    )
                ],
            ),
        ),
    ],
)
def test_convert_lc_message_to_chat_message_tool_calls(message, expected):
    assert convert_lc_message_to_chat_message(message) == expected


def test_convert_lc_message_to_chat_message_audio_content():
    message = HumanMessage(
        content=[
            {"type": "text", "text": "What is this audio?"},
            {
                "type": "audio",
                "source_type": "base64",
                "data": "SGVsbG8=",
                "mime_type": "audio/wav",
            },
        ]
    )
    result = convert_lc_message_to_chat_message(message)
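    # A LangChain "audio" content part should come back as an OpenAI-style
    # "input_audio" part, with the mime subtype carried over as the format.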
    assert result.role == "user"
    assert len(result.content) == 2
    assert result.content[0].type == "text"
    assert result.content[0].text == "What is this audio?"
    assert result.content[1].type == "input_audio"
    assert result.content[1].input_audio.data == "SGVsbG8="
    assert result.content[1].input_audio.format == "wav"


def test_convert_lc_message_to_chat_message_audio_mp3():
    message = HumanMessage(
        content=[
            {
                "type": "audio",
                "source_type": "base64",
                "data": "AAAA",
                "mime_type": "audio/mp3",
            },
        ]
    )
    result = convert_lc_message_to_chat_message(message)
    assert result.content[0].type == "input_audio"
    assert result.content[0].input_audio.data == "AAAA"
    assert result.content[0].input_audio.format == "mp3"


def test_convert_lc_message_to_chat_message_audio_mpeg():
    message = HumanMessage(
        content=[
            {
                "type": "audio",
                "source_type": "base64",
                "data": "AAAA",
                "mime_type": "audio/mpeg",
            },
        ]
    )
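    # "audio/mpeg" is expected to normalize to the same "mp3" format label
    # as "audio/mp3" above.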
    result = convert_lc_message_to_chat_message(message)
    assert result.content[0].type == "input_audio"
    assert result.content[0].input_audio.data == "AAAA"
    assert result.content[0].input_audio.format == "mp3"


def test_convert_lc_message_to_chat_message_string_content_unchanged():
    message = HumanMessage(content="just text")
    result = convert_lc_message_to_chat_message(message)
    assert result.content == "just text"


def test_convert_lc_message_audio_url_source_raises():
    message = HumanMessage(
        content=[
            {
                "type": "audio",
                "source_type": "url",
                "url": "https://example.com/audio.wav",
                "mime_type": "audio/wav",
            },
        ]
    )
    with pytest.raises(MlflowException, match="Only base64-encoded audio"):
        convert_lc_message_to_chat_message(message)


def test_convert_lc_message_audio_no_mime_type_raises():
    message = HumanMessage(
        content=[
            {
                "type": "audio",
                "source_type": "base64",
                "data": "SGVsbG8=",
            },
        ]
    )
    with pytest.raises(MlflowException, match="Only base64-encoded audio"):
        convert_lc_message_to_chat_message(message)


def test_convert_lc_message_audio_unsupported_format_raises():
    message = HumanMessage(
        content=[
            {
                "type": "audio",
                "source_type": "base64",
                "data": "SGVsbG8=",
                "mime_type": "audio/ogg",
            },
        ]
    )
    with pytest.raises(MlflowException, match="Unsupported audio format"):
        convert_lc_message_to_chat_message(message)


def test_transform_response_to_chat_format_no_conversion():
    response = ["list_response"]
    assert try_transform_response_to_chat_format(response) == response

    response = {"dict_response": "response"}
    assert try_transform_response_to_chat_format(response) == response


def test_transform_response_to_chat_format_conversion():
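    # Plain strings and AIMessages get wrapped into an OpenAI-style chat
    # completion payload with "id" and "choices" fields.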
    response = "string_response"
    converted_response = try_transform_response_to_chat_format(response)
    assert isinstance(converted_response, dict)
    assert converted_response["id"] is None
    assert converted_response["choices"][0]["message"]["content"] == response

    response = AIMessage(content="ai_message_response")
    converted_response = try_transform_response_to_chat_format(response)
    assert isinstance(converted_response, dict)
    assert converted_response["id"] == getattr(response, "id", None)
    assert converted_response["choices"][0]["message"]["content"] == response.content


def test_transform_response_iter_to_chat_format_no_conversion():
    response = [{"dict_response": "response"}]
    converted_response = list(try_transform_response_iter_to_chat_format(response))
    assert len(converted_response) == 1
    assert converted_response[0] == response[0]


def test_transform_response_iter_to_chat_format_ai_message():
    response = ["string response"]
    converted_response = list(try_transform_response_iter_to_chat_format(response))
    assert len(converted_response) == 1
    assert converted_response[0]["id"] is None
    assert converted_response[0]["choices"][0]["delta"]["content"] == response[0]

    response = [
        AIMessage(
            content="ai_message_response", id="123", response_metadata={"finish_reason": "done"}
        )
    ]
    converted_response = list(try_transform_response_iter_to_chat_format(response))
    assert len(converted_response) == 1
    assert converted_response[0]["id"] == getattr(response[0], "id", None)
    assert converted_response[0]["choices"][0]["delta"]["content"] == response[0].content
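    # A complete (non-chunk) AIMessage is normalized to finish_reason "stop",
    # regardless of the finish_reason recorded in its response_metadata
    # (presumably because a full message ends the stream).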
    assert converted_response[0]["choices"][0]["finish_reason"] == "stop"

    response = [
        AIMessageChunk(
            content="ai_message_chunk_response",
            id="123",
            response_metadata={"finish_reason": "done"},
        ),
        AIMessageChunk(
            content="ai_message_chunk_response",
            id="456",
            response_metadata={"finish_reason": "stop"},
        ),
    ]
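    # AIMessageChunks, by contrast, keep the finish_reason reported in their
    # own response_metadata.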
    converted_response = list(try_transform_response_iter_to_chat_format(response))
    assert len(converted_response) == 2
    for i in range(2):
        assert converted_response[i]["id"] == getattr(response[i], "id", None)
        assert converted_response[i]["choices"][0]["delta"]["content"] == response[i].content
        assert (
            converted_response[i]["choices"][0]["finish_reason"]
            == response[i].response_metadata["finish_reason"]
        )


def test_transform_request_json_for_chat_if_necessary_conversion():
    model = MagicMock(spec=SimpleChatModel)
    request_json = {"messages": [{"role": "user", "content": "some_input"}]}

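    # When the model already declares a "messages" input field, the request is
    # passed through unchanged and the converted flag stays False.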
    with patch("mlflow.langchain.utils.chat._get_lc_model_input_fields", return_value={"messages"}):
        transformed_request = transform_request_json_for_chat_if_necessary(request_json, model)
        assert transformed_request == (request_json, False)

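    # Without declared input fields, the chat request is converted into
    # LangChain messages and the converted flag is True.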
    with patch(
        "mlflow.langchain.utils.chat._get_lc_model_input_fields",
        return_value={},
    ):
        transformed_request = transform_request_json_for_chat_if_necessary(request_json, model)
        assert transformed_request[0][0] == HumanMessage(content="some_input")
        assert transformed_request[1] is True

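    # A list of chat requests is converted element-wise, one LangChain message
    # list per request.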
    request_json = [
        {"messages": [{"role": "system", "content": "You are a helpful assistant."}]},
        {"messages": [{"role": "assistant", "content": "What would you like to ask?"}]},
        {"messages": [{"role": "user", "content": "Who owns MLflow?"}]},
    ]
    with patch(
        "mlflow.langchain.utils.chat._get_lc_model_input_fields",
        return_value={},
    ):
        transformed_request = transform_request_json_for_chat_if_necessary(request_json, model)
        assert transformed_request[0][0][0] == SystemMessage(content="You are a helpful assistant.")
        assert transformed_request[0][1][0] == AIMessage(content="What would you like to ask?")
        assert transformed_request[0][2][0] == HumanMessage(content="Who owns MLflow?")
        assert transformed_request[1] is True


@pytest.mark.parametrize(
    ("generation", "expected"),
    [
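        # No usage metadata at all -> parse_token_usage returns None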
        (ChatGeneration(message=AIMessage(content="foo", id="123")), None),
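        # usage_metadata attached directly to an AIMessage or AIMessageChunk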
        (
            ChatGeneration(
                message=AIMessage(
                    content="foo",
                    id="123",
                    usage_metadata={"input_tokens": 5, "output_tokens": 10, "total_tokens": 15},
                )
            ),
            {"input_tokens": 5, "output_tokens": 10, "total_tokens": 15},
        ),
        (
            ChatGeneration(
                message=AIMessageChunk(
                    content="foo",
                    id="123",
                    usage_metadata={"input_tokens": 5, "output_tokens": 10, "total_tokens": 15},
                )
            ),
            {"input_tokens": 5, "output_tokens": 10, "total_tokens": 15},
        ),
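        # OpenAI-style response_metadata["usage"] keys (prompt_tokens /
        # completion_tokens) mapped to the standardized input/output names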
        (
            ChatGeneration(
                message=AIMessage(
                    content="foo",
                    id="123",
                    response_metadata={
                        "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
                    },
                )
            ),
            {"input_tokens": 5, "output_tokens": 10, "total_tokens": 15},
        ),
        # OpenAI usage_metadata with input_token_details (LangChain standardized format)
        (
            ChatGeneration(
                message=AIMessage(
                    content="foo",
                    id="123",
                    usage_metadata={
                        "input_tokens": 50,
                        "output_tokens": 20,
                        "total_tokens": 70,
                        "input_token_details": {"cache_read": 30, "cache_creation": 0},
                    },
                )
            ),
            {
                "input_tokens": 50,
                "output_tokens": 20,
                "total_tokens": 70,
                "cache_read_input_tokens": 30,
                "cache_creation_input_tokens": 0,
            },
        ),
        # OpenAI usage_metadata with both cache_read and cache_creation
        (
            ChatGeneration(
                message=AIMessage(
                    content="foo",
                    id="123",
                    usage_metadata={
                        "input_tokens": 100,
                        "output_tokens": 50,
                        "total_tokens": 150,
                        "input_token_details": {"cache_read": 25, "cache_creation": 15},
                    },
                )
            ),
            {
                "input_tokens": 100,
                "output_tokens": 50,
                "total_tokens": 150,
                "cache_read_input_tokens": 25,
                "cache_creation_input_tokens": 15,
            },
        ),
        # Raw OpenAI response_metadata with prompt_tokens_details
        (
            ChatGeneration(
                message=AIMessage(
                    content="foo",
                    id="123",
                    response_metadata={
                        "token_usage": {
                            "prompt_tokens": 50,
                            "completion_tokens": 20,
                            "total_tokens": 70,
                            "prompt_tokens_details": {"cached_tokens": 30},
                        }
                    },
                )
            ),
            {
                "input_tokens": 50,
                "output_tokens": 20,
                "total_tokens": 70,
                "cache_read_input_tokens": 30,
            },
        ),
        # Gemini usage_metadata with cached_content_token_count
        (
            ChatGeneration(
                message=AIMessage(
                    content="foo",
                    id="123",
                    usage_metadata={
                        "input_tokens": 50,
                        "output_tokens": 20,
                        "total_tokens": 70,
                        "cached_content_token_count": 30,
                    },
                )
            ),
            {
                "input_tokens": 50,
                "output_tokens": 20,
                "total_tokens": 70,
                "cache_read_input_tokens": 30,
            },
        ),
        # Legacy completion generation object
        (Generation(text="foo"), None),
    ],
)
def test_parse_token_usage(generation, expected):
    assert parse_token_usage([generation]) == expected


def test_parse_token_usage_streaming_chunks():
    """
    Test that streaming chunks with cumulative token usage are handled correctly.

    In streaming mode, each ChatGenerationChunk contains:
    - the same input_tokens (repeated for each chunk)
    - cumulative output_tokens (increasing with each chunk)

    Expected behavior: use only the last chunk's usage (the final cumulative values).
    """
    # Simulate 3 streaming chunks with the same input_tokens but cumulative
    # output_tokens, matching the pattern observed in real streaming scenarios.
    chunks = [
        ChatGenerationChunk(
            message=AIMessageChunk(
                content="Agreement",
                usage_metadata={
                    "input_tokens": 16049,
                    "output_tokens": 2,
                    "total_tokens": 16051,
                },
            )
        ),
        ChatGenerationChunk(
            message=AIMessageChunk(
                content=" ",
                usage_metadata={
                    "input_tokens": 16049,
                    "output_tokens": 58,
                    "total_tokens": 16107,
                },
            )
        ),
        ChatGenerationChunk(
            message=AIMessageChunk(
                content="",
                usage_metadata={
                    "input_tokens": 16049,
                    "output_tokens": 115,
                    "total_tokens": 16164,
                },
            )
        ),
    ]

    result = parse_token_usage(chunks)

    # Should use only the last chunk's usage (final cumulative values)
    assert result is not None
    assert result["input_tokens"] == 16049
    assert result["output_tokens"] == 115
    assert result["total_tokens"] == 16164


def test_parse_token_usage_non_streaming_multiple_calls():
    """
    Test that non-streaming multiple calls still sum correctly (existing behavior).

    When multiple ChatGeneration objects are present (non-streaming), they represent
    separate LLM calls and should be summed.
    """
    # Simulate 2 separate non-streaming calls with different token usage
    generations = [
        ChatGeneration(
            message=AIMessage(
                content="Response 1",
                usage_metadata={
                    "input_tokens": 10,
                    "output_tokens": 20,
                    "total_tokens": 30,
                },
            )
        ),
        ChatGeneration(
            message=AIMessage(
                content="Response 2",
                usage_metadata={
                    "input_tokens": 15,
                    "output_tokens": 25,
                    "total_tokens": 40,
                },
            )
        ),
    ]

    result = parse_token_usage(generations)

    # Should sum all generations (existing non-streaming behavior)
    assert result is not None
    assert result["input_tokens"] == 25  # 10 + 15
    assert result["output_tokens"] == 45  # 20 + 25
    assert result["total_tokens"] == 70  # 30 + 40