# tests/langchain/test_langchain_output_parsers.py
  1  import pytest
  2  from langchain_core.messages.base import BaseMessage
  3  from langchain_core.runnables.config import RunnableConfig
  4  
  5  from mlflow.langchain.output_parsers import (
  6      ChatAgentOutputParser,
  7      ChatCompletionOutputParser,
  8      ChatCompletionsOutputParser,
  9      StringResponseOutputParser,
 10  )
 11  from mlflow.types.llm import ChatCompletionChunk
 12  
 13  
 14  def test_chatcompletions_output_parser_parse_response():
 15      parser = ChatCompletionsOutputParser()
 16      message = "The weather today is"
 17  
 18      parsed_response = parser.parse(message)
 19      assert parsed_response == {
 20          "choices": [
 21              {
 22                  "finish_reason": "stop",
 23                  "index": 0,
 24                  "message": {"content": "The weather today is", "role": "assistant"},
 25              }
 26          ],
 27          "object": "chat.completion",
 28      }
 29  
 30  
 31  def test_chatcompletions_output_parser_is_lc_serializable():
 32      parser = StringResponseOutputParser()
 33      message = "The weather today is"
 34  
 35      parsed_response = parser.parse(message)
 36      assert parsed_response == {"content": "The weather today is"}
 37  
 38  
 39  def test_chatcompletion_output_parser_parse_response():
 40      parser = ChatCompletionOutputParser()
 41      message = "The weather today is"
 42  
 43      parsed_response = parser.parse(message)
 44      assert isinstance(parsed_response["created"], int)
 45      del parsed_response["created"]
 46  
 47      assert parsed_response == {
 48          "choices": [
 49              {
 50                  "finish_reason": "stop",
 51                  "index": 0,
 52                  "message": {
 53                      "content": "The weather today is",
 54                      "role": "assistant",
 55                  },
 56              }
 57          ],
 58          "object": "chat.completion",
 59      }
 60  
 61      streaming_messages = ["The ", "weather ", "today ", "is"]
 62      base_messages = [BaseMessage(content=m, type="test") for m in streaming_messages]
 63      streaming_chunks = parser.transform(base_messages, RunnableConfig())
 64      for i, chunk in enumerate(streaming_chunks):
 65          assert isinstance(chunk["created"], int)
 66          del chunk["created"]
 67          assert chunk == {
 68              "choices": [
 69                  {
 70                      "delta": {
 71                          "content": streaming_messages[i],
 72                          "role": "assistant",
 73                      },
 74                      "index": 0,
 75                  }
 76              ],
 77              "object": "chat.completion.chunk",
 78          }
 79  
 80  
 81  def test_chat_agent_output_parser_parse_response():
 82      parser = ChatAgentOutputParser()
 83      message = "The weather today is"
 84  
 85      parsed_response = parser.parse(message)
 86      assert parsed_response["messages"][0]["id"] is not None
 87      del parsed_response["messages"][0]["id"]
 88      assert parsed_response == {
 89          "messages": [{"content": "The weather today is", "role": "assistant"}],
 90      }
 91  
 92      streaming_messages = ["The ", "weather ", "today ", "is"]
 93      base_messages = [BaseMessage(content=m, type="test", id="1") for m in streaming_messages]
 94      streaming_chunks = parser.transform(base_messages, RunnableConfig())
 95      for i, chunk in enumerate(streaming_chunks):
 96          assert chunk == {
 97              "delta": {"content": streaming_messages[i], "role": "assistant", "id": "1"}
 98          }
 99  
100  
async def async_message_generator(messages):
    """Asynchronously yield the items of *messages* one at a time."""
    for item in messages:
        yield item
104  
105  
@pytest.mark.asyncio
async def test_chatcompletion_output_parser_atransform():
    """atransform yields one chat.completion.chunk per input message, in order."""
    parser = ChatCompletionOutputParser()
    pieces = ["The ", "weather ", "today ", "is"]
    source = async_message_generator(
        [BaseMessage(content=piece, type="test") for piece in pieces]
    )

    collected = [
        chunk async for chunk in parser.atransform(source, RunnableConfig())
    ]

    assert len(collected) == len(pieces)

    for piece, raw_chunk in zip(pieces, collected):
        parsed_chunk = ChatCompletionChunk.from_dict(raw_chunk)
        assert parsed_chunk.choices[0].delta.content == piece