# test_langchain_output_parsers.py
"""Tests for the LangChain output parsers in ``mlflow.langchain.output_parsers``.

Each parser is exercised through ``parse`` (one complete string) and, where the parser
streams, through ``transform`` / ``atransform`` (a sequence of ``BaseMessage`` chunks).
"""

import pytest
from langchain_core.messages.base import BaseMessage
from langchain_core.runnables.config import RunnableConfig

from mlflow.langchain.output_parsers import (
    ChatAgentOutputParser,
    ChatCompletionOutputParser,
    ChatCompletionsOutputParser,
    StringResponseOutputParser,
)
from mlflow.types.llm import ChatCompletionChunk

# A complete response, and the same text split into the chunks used for streaming tests.
_MESSAGE = "The weather today is"
_STREAMING_MESSAGES = ("The ", "weather ", "today ", "is")


def test_chatcompletions_output_parser_parse_response():
    """``ChatCompletionsOutputParser.parse`` wraps text in a single-choice chat completion."""
    parser = ChatCompletionsOutputParser()

    parsed_response = parser.parse(_MESSAGE)

    assert parsed_response == {
        "choices": [
            {
                "finish_reason": "stop",
                "index": 0,
                "message": {"content": _MESSAGE, "role": "assistant"},
            }
        ],
        "object": "chat.completion",
    }


def test_chatcompletions_output_parser_is_lc_serializable():
    """``StringResponseOutputParser.parse`` wraps text as ``{"content": text}``.

    NOTE(review): despite its name, this test exercises ``StringResponseOutputParser.parse``
    and does not check LangChain serializability. Consider renaming it (e.g. to
    ``test_string_response_output_parser_parse_response``); the name is kept here so that
    existing selections by test name keep working.
    """
    parser = StringResponseOutputParser()

    assert parser.parse(_MESSAGE) == {"content": _MESSAGE}


def test_chatcompletion_output_parser_parse_response():
    """``ChatCompletionOutputParser`` emits a completion from ``parse`` and chunks from ``transform``."""
    parser = ChatCompletionOutputParser()

    parsed_response = parser.parse(_MESSAGE)
    # Only the type of "created" is asserted; the rest of the payload is compared exactly.
    created = parsed_response.pop("created")
    assert isinstance(created, int)
    assert parsed_response == {
        "choices": [
            {
                "finish_reason": "stop",
                "index": 0,
                "message": {"content": _MESSAGE, "role": "assistant"},
            }
        ],
        "object": "chat.completion",
    }

    base_messages = [BaseMessage(content=m, type="test") for m in _STREAMING_MESSAGES]
    # Materialise the stream and check its length first, so a parser that yields too few
    # (or no) chunks fails instead of passing vacuously.
    chunks = list(parser.transform(base_messages, RunnableConfig()))
    assert len(chunks) == len(_STREAMING_MESSAGES)
    for expected_content, chunk in zip(_STREAMING_MESSAGES, chunks):
        chunk_created = chunk.pop("created")
        assert isinstance(chunk_created, int)
        assert chunk == {
            "choices": [
                {
                    "delta": {"content": expected_content, "role": "assistant"},
                    "index": 0,
                }
            ],
            "object": "chat.completion.chunk",
        }


def test_chat_agent_output_parser_parse_response():
    """``ChatAgentOutputParser`` emits agent messages from ``parse`` and deltas from ``transform``."""
    parser = ChatAgentOutputParser()

    parsed_response = parser.parse(_MESSAGE)
    # The message id is not fixed by the input; only require that one was set.
    message_id = parsed_response["messages"][0].pop("id")
    assert message_id is not None
    assert parsed_response == {"messages": [{"content": _MESSAGE, "role": "assistant"}]}

    base_messages = [BaseMessage(content=m, type="test", id="1") for m in _STREAMING_MESSAGES]
    # Check the chunk count so a short or empty stream cannot pass vacuously.
    chunks = list(parser.transform(base_messages, RunnableConfig()))
    assert len(chunks) == len(_STREAMING_MESSAGES)
    for expected_content, chunk in zip(_STREAMING_MESSAGES, chunks):
        assert chunk == {
            "delta": {"content": expected_content, "role": "assistant", "id": "1"}
        }


async def async_message_generator(messages):
    """Yield each item of ``messages`` from an async generator (an async stand-in for a stream)."""
    for message in messages:
        yield message


@pytest.mark.asyncio
async def test_chatcompletion_output_parser_atransform():
    """``ChatCompletionOutputParser.atransform`` yields one chat-completion chunk per input message."""
    parser = ChatCompletionOutputParser()
    base_messages = [BaseMessage(content=m, type="test") for m in _STREAMING_MESSAGES]

    async_chunks = parser.atransform(async_message_generator(base_messages), RunnableConfig())
    chunks = [chunk async for chunk in async_chunks]

    assert len(chunks) == len(_STREAMING_MESSAGES)
    for expected_content, chunk in zip(_STREAMING_MESSAGES, chunks):
        parsed_chunk = ChatCompletionChunk.from_dict(chunk)
        assert parsed_chunk.choices[0].delta.content == expected_content