test_message_utils.py
import pydantic
import pytest

from mlflow.genai.utils.message_utils import (
    pydantic_to_response_format,
    serialize_messages_to_prompts,
)
from mlflow.types.llm import ChatMessage, FunctionToolCallArguments, ToolCall


@pytest.mark.parametrize(
    ("messages", "expected_user_prompt", "expected_system_prompt"),
    [
        # Each case is exercised twice where applicable: once with ChatMessage
        # objects and once with plain dicts carrying the same role/content.
        pytest.param(
            [ChatMessage(role="user", content="Hello")],
            "Hello",
            None,
            id="basic_user_object",
        ),
        pytest.param(
            [{"role": "user", "content": "Hello"}],
            "Hello",
            None,
            id="basic_user_dict",
        ),
        pytest.param(
            [
                ChatMessage(role="system", content="You are helpful."),
                ChatMessage(role="user", content="Hello"),
            ],
            "Hello",
            "You are helpful.",
            id="system_user_object",
        ),
        pytest.param(
            [
                {"role": "system", "content": "You are helpful."},
                {"role": "user", "content": "Hello"},
            ],
            "Hello",
            "You are helpful.",
            id="system_user_dict",
        ),
        pytest.param(
            [
                ChatMessage(role="user", content="First"),
                ChatMessage(role="user", content="Second"),
            ],
            "First\n\nSecond",
            None,
            id="multiple_users_object",
        ),
        pytest.param(
            [
                {"role": "user", "content": "First"},
                {"role": "user", "content": "Second"},
            ],
            "First\n\nSecond",
            None,
            id="multiple_users_dict",
        ),
        pytest.param(
            [],
            "",
            None,
            id="empty_messages",
        ),
    ],
)
def test_serialize_messages_basic(messages, expected_user_prompt, expected_system_prompt):
    """Check the (user prompt, system prompt) pair produced for simple message lists."""
    user_text, system_text = serialize_messages_to_prompts(messages)
    assert user_text == expected_user_prompt
    assert system_text == expected_system_prompt


def test_assistant_message_with_content_object():
    """An assistant ChatMessage with content is rendered with an ``Assistant: `` prefix."""
    user_turn = ChatMessage(role="user", content="Hello")
    assistant_turn = ChatMessage(role="assistant", content="Hi there!")

    user_text, system_text = serialize_messages_to_prompts([user_turn, assistant_turn])

    assert user_text == "Hello\n\nAssistant: Hi there!"
    assert system_text is None


def test_assistant_message_with_content_dict():
    """An assistant dict message with content is rendered with an ``Assistant: `` prefix."""
    user_turn = {"role": "user", "content": "Hello"}
    assistant_turn = {"role": "assistant", "content": "Hi there!"}

    user_text, system_text = serialize_messages_to_prompts([user_turn, assistant_turn])

    assert user_text == "Hello\n\nAssistant: Hi there!"
    assert system_text is None


def test_assistant_message_with_tool_calls():
    """An assistant ChatMessage carrying only tool calls is rendered as ``[Called tools]``."""
    search_call = ToolCall(
        function=FunctionToolCallArguments(name="search", arguments='{"query": "test"}')
    )
    conversation = [
        ChatMessage(role="user", content="Search for info"),
        ChatMessage(role="assistant", tool_calls=[search_call]),
    ]

    user_text, system_text = serialize_messages_to_prompts(conversation)

    assert user_text == "Search for info\n\nAssistant: [Called tools]"
    assert system_text is None


def test_assistant_message_with_tool_calls_dict():
    """An assistant dict with ``content=None`` and tool calls is rendered as ``[Called tools]``."""
    conversation = [
        {"role": "user", "content": "Search for info"},
        {"role": "assistant", "content": None, "tool_calls": [{"id": "1", "function": {}}]},
    ]

    user_text, system_text = serialize_messages_to_prompts(conversation)

    assert user_text == "Search for info\n\nAssistant: [Called tools]"
    assert system_text is None


def test_tool_message_with_name_object():
    """A named tool ChatMessage is rendered as ``Tool <name>: <content>``."""
    conversation = [
        ChatMessage(role="user", content="Search"),
        ChatMessage(role="tool", name="search_tool", content='{"results": ["a", "b"]}'),
    ]

    user_text, system_text = serialize_messages_to_prompts(conversation)

    assert user_text == 'Search\n\nTool search_tool: {"results": ["a", "b"]}'
    assert system_text is None


def test_tool_message_with_name_dict():
    """A named tool dict message is rendered as ``Tool <name>: <content>``."""
    conversation = [
        {"role": "user", "content": "Search"},
        {"role": "tool", "name": "search_tool", "content": '{"results": ["a", "b"]}'},
    ]

    user_text, system_text = serialize_messages_to_prompts(conversation)

    assert user_text == 'Search\n\nTool search_tool: {"results": ["a", "b"]}'
    assert system_text is None


def test_tool_message_without_name_dict():
    """A tool dict message lacking ``name`` is rendered with the raw role as prefix."""
    conversation = [
        {"role": "user", "content": "Hello"},
        {"role": "tool", "content": "Tool result"},
    ]

    user_text, system_text = serialize_messages_to_prompts(conversation)

    assert user_text == "Hello\n\ntool: Tool result"
    assert system_text is None


def test_custom_role_dict():
    """A dict message with a ``developer`` role is rendered as ``developer: <content>``."""
    conversation = [
        {"role": "user", "content": "Hello"},
        {"role": "developer", "content": "Custom message"},
    ]

    user_text, system_text = serialize_messages_to_prompts(conversation)

    assert user_text == "Hello\n\ndeveloper: Custom message"
    assert system_text is None


def test_full_conversation_object():
    """A multi-turn ChatMessage conversation with system, assistant, and tool turns."""
    search_call = ToolCall(
        function=FunctionToolCallArguments(name="search", arguments='{"query": "test"}')
    )
    conversation = [
        ChatMessage(role="system", content="Be helpful"),
        ChatMessage(role="user", content="Query"),
        ChatMessage(role="assistant", content="Response"),
        ChatMessage(role="user", content="Search please"),
        ChatMessage(role="assistant", tool_calls=[search_call]),
        ChatMessage(role="tool", name="search", content="Results"),
        ChatMessage(role="user", content="Follow-up"),
    ]
    expected_user_text = (
        "Query\n\nAssistant: Response\n\nSearch please\n\n"
        "Assistant: [Called tools]\n\nTool search: Results\n\nFollow-up"
    )

    user_text, system_text = serialize_messages_to_prompts(conversation)

    assert user_text == expected_user_text
    assert system_text == "Be helpful"


def test_full_conversation_dict():
    """A multi-turn dict conversation with system, user, and assistant turns."""
    conversation = [
        {"role": "system", "content": "Be helpful"},
        {"role": "user", "content": "Query"},
        {"role": "assistant", "content": "Response"},
        {"role": "user", "content": "Follow-up"},
    ]

    user_text, system_text = serialize_messages_to_prompts(conversation)

    assert user_text == "Query\n\nAssistant: Response\n\nFollow-up"
    assert system_text == "Be helpful"


def test_pydantic_to_response_format():
    """The response format built from a pydantic model exposes its name and field properties."""

    class MySchema(pydantic.BaseModel):
        name: str
        score: int

    response_format = pydantic_to_response_format(MySchema)

    assert response_format["type"] == "json_schema"
    assert response_format["json_schema"]["name"] == "MySchema"
    properties = response_format["json_schema"]["schema"]["properties"]
    assert "name" in properties
    assert "score" in properties