# tests/genai/utils/test_message_utils.py
  1  import pydantic
  2  import pytest
  3  
  4  from mlflow.genai.utils.message_utils import (
  5      pydantic_to_response_format,
  6      serialize_messages_to_prompts,
  7  )
  8  from mlflow.types.llm import ChatMessage, FunctionToolCallArguments, ToolCall
  9  
 10  
 11  @pytest.mark.parametrize(
 12      ("messages", "expected_user_prompt", "expected_system_prompt"),
 13      [
 14          # Basic user message (object)
 15          (
 16              [ChatMessage(role="user", content="Hello")],
 17              "Hello",
 18              None,
 19          ),
 20          # Basic user message (dict)
 21          (
 22              [{"role": "user", "content": "Hello"}],
 23              "Hello",
 24              None,
 25          ),
 26          # System + user messages (object)
 27          (
 28              [
 29                  ChatMessage(role="system", content="You are helpful."),
 30                  ChatMessage(role="user", content="Hello"),
 31              ],
 32              "Hello",
 33              "You are helpful.",
 34          ),
 35          # System + user messages (dict)
 36          (
 37              [
 38                  {"role": "system", "content": "You are helpful."},
 39                  {"role": "user", "content": "Hello"},
 40              ],
 41              "Hello",
 42              "You are helpful.",
 43          ),
 44          # Multiple user messages (object)
 45          (
 46              [
 47                  ChatMessage(role="user", content="First"),
 48                  ChatMessage(role="user", content="Second"),
 49              ],
 50              "First\n\nSecond",
 51              None,
 52          ),
 53          # Multiple user messages (dict)
 54          (
 55              [
 56                  {"role": "user", "content": "First"},
 57                  {"role": "user", "content": "Second"},
 58              ],
 59              "First\n\nSecond",
 60              None,
 61          ),
 62          # Empty messages
 63          (
 64              [],
 65              "",
 66              None,
 67          ),
 68      ],
 69      ids=[
 70          "basic_user_object",
 71          "basic_user_dict",
 72          "system_user_object",
 73          "system_user_dict",
 74          "multiple_users_object",
 75          "multiple_users_dict",
 76          "empty_messages",
 77      ],
 78  )
 79  def test_serialize_messages_basic(messages, expected_user_prompt, expected_system_prompt):
 80      user_prompt, system_prompt = serialize_messages_to_prompts(messages)
 81      assert user_prompt == expected_user_prompt
 82      assert system_prompt == expected_system_prompt
 83  
 84  
 85  def test_assistant_message_with_content_object():
 86      messages = [
 87          ChatMessage(role="user", content="Hello"),
 88          ChatMessage(role="assistant", content="Hi there!"),
 89      ]
 90      user_prompt, system_prompt = serialize_messages_to_prompts(messages)
 91      assert user_prompt == "Hello\n\nAssistant: Hi there!"
 92      assert system_prompt is None
 93  
 94  
 95  def test_assistant_message_with_content_dict():
 96      messages = [
 97          {"role": "user", "content": "Hello"},
 98          {"role": "assistant", "content": "Hi there!"},
 99      ]
100      user_prompt, system_prompt = serialize_messages_to_prompts(messages)
101      assert user_prompt == "Hello\n\nAssistant: Hi there!"
102      assert system_prompt is None
103  
104  
def test_assistant_message_with_tool_calls():
    """An assistant turn carrying only tool calls is rendered as a fixed placeholder."""
    search_call = ToolCall(
        function=FunctionToolCallArguments(name="search", arguments='{"query": "test"}'),
    )
    conversation = [
        ChatMessage(role="user", content="Search for info"),
        ChatMessage(role="assistant", tool_calls=[search_call]),
    ]
    prompt, system = serialize_messages_to_prompts(conversation)
    # The expected text contains only the "[Called tools]" marker, not the call's
    # name or arguments.
    assert prompt == "\n\n".join(["Search for info", "Assistant: [Called tools]"])
    assert system is None
116  
117  
def test_assistant_message_with_tool_calls_dict():
    """Dict-form assistant turns with content=None and tool_calls get the placeholder."""
    # A minimal tool-call stub (id plus an empty function payload) is sufficient here.
    stub_call = {"id": "1", "function": {}}
    conversation = [
        {"role": "user", "content": "Search for info"},
        {"role": "assistant", "content": None, "tool_calls": [stub_call]},
    ]
    prompt, system = serialize_messages_to_prompts(conversation)
    assert prompt == "\n\n".join(["Search for info", "Assistant: [Called tools]"])
    assert system is None
126  
127  
def test_tool_message_with_name_object():
    """A named tool ChatMessage is rendered as "Tool <name>: <content>"."""
    tool_output = '{"results": ["a", "b"]}'
    conversation = [
        ChatMessage(role="user", content="Search"),
        ChatMessage(role="tool", name="search_tool", content=tool_output),
    ]
    prompt, system = serialize_messages_to_prompts(conversation)
    assert prompt == "\n\n".join(["Search", 'Tool search_tool: {"results": ["a", "b"]}'])
    assert system is None
136  
137  
def test_tool_message_with_name_dict():
    """Dict-form named tool messages serialize exactly like their ChatMessage form."""
    tool_output = '{"results": ["a", "b"]}'
    conversation = [
        {"role": "user", "content": "Search"},
        {"role": "tool", "name": "search_tool", "content": tool_output},
    ]
    prompt, system = serialize_messages_to_prompts(conversation)
    assert prompt == "\n\n".join(["Search", 'Tool search_tool: {"results": ["a", "b"]}'])
    assert system is None
146  
147  
def test_tool_message_without_name_dict():
    """A tool message with no name falls back to the raw "tool: " role label."""
    conversation = [
        {"role": "user", "content": "Hello"},
        {"role": "tool", "content": "Tool result"},
    ]
    prompt, system = serialize_messages_to_prompts(conversation)
    # Note the lowercase label, unlike the "Tool <name>: " form used when a name exists.
    assert prompt == "\n\n".join(["Hello", "tool: Tool result"])
    assert system is None
156  
157  
def test_custom_role_dict():
    """A role outside user/assistant/system/tool is prefixed with its raw role name."""
    conversation = [
        {"role": "user", "content": "Hello"},
        {"role": "developer", "content": "Custom message"},
    ]
    prompt, system = serialize_messages_to_prompts(conversation)
    assert prompt == "\n\n".join(["Hello", "developer: Custom message"])
    assert system is None
166  
167  
def test_full_conversation_object():
    """A multi-turn ChatMessage conversation mixing every role serializes in order.

    The system message becomes the system prompt. All other turns are joined, in
    their original order, into the user prompt with blank-line separators.
    """
    search_call = ToolCall(
        function=FunctionToolCallArguments(name="search", arguments='{"query": "test"}'),
    )
    conversation = [
        ChatMessage(role="system", content="Be helpful"),
        ChatMessage(role="user", content="Query"),
        ChatMessage(role="assistant", content="Response"),
        ChatMessage(role="user", content="Search please"),
        ChatMessage(role="assistant", tool_calls=[search_call]),
        ChatMessage(role="tool", name="search", content="Results"),
        ChatMessage(role="user", content="Follow-up"),
    ]
    expected_turns = [
        "Query",
        "Assistant: Response",
        "Search please",
        "Assistant: [Called tools]",
        "Tool search: Results",
        "Follow-up",
    ]
    prompt, system = serialize_messages_to_prompts(conversation)
    assert prompt == "\n\n".join(expected_turns)
    assert system == "Be helpful"
188  
189  
def test_full_conversation_dict():
    """A dict-form multi-turn conversation splits out the system prompt and keeps order."""
    conversation = [
        {"role": "system", "content": "Be helpful"},
        {"role": "user", "content": "Query"},
        {"role": "assistant", "content": "Response"},
        {"role": "user", "content": "Follow-up"},
    ]
    expected_turns = ["Query", "Assistant: Response", "Follow-up"]
    prompt, system = serialize_messages_to_prompts(conversation)
    assert prompt == "\n\n".join(expected_turns)
    assert system == "Be helpful"
200  
201  
def test_pydantic_to_response_format():
    """A Pydantic model is converted into a json_schema response_format dict.

    The check covers the envelope ("type", "json_schema.name") and the embedded
    schema. The schema must expose exactly the model's fields, with their JSON types
    and required-ness, so dropped types or stray properties are caught.
    """

    class MySchema(pydantic.BaseModel):
        name: str
        score: int

    result = pydantic_to_response_format(MySchema)

    assert result["type"] == "json_schema"
    assert result["json_schema"]["name"] == "MySchema"

    schema = result["json_schema"]["schema"]
    properties = schema["properties"]
    # Exact key set, not mere presence, so extra/unexpected fields fail the test.
    assert set(properties) == {"name", "score"}
    assert properties["name"]["type"] == "string"
    assert properties["score"]["type"] == "integer"
    # Both fields have no default, so both should be listed as required.
    assert set(schema["required"]) == {"name", "score"}