/ tests / pyfunc / test_responses_agent_validation.py
test_responses_agent_validation.py
  1  import pytest
  2  from pydantic import ValidationError
  3  
  4  from mlflow.types.responses import (
  5      ResponsesAgentRequest,
  6      ResponsesAgentResponse,
  7      ResponsesAgentStreamEvent,
  8      responses_to_cc,
  9      to_chat_completions_input,
 10  )
 11  from mlflow.types.responses_helpers import FunctionCallOutput, Message
 12  
 13  
 14  def test_responses_request_validation():
 15      with pytest.raises(ValueError, match="content.0.text"):
 16          ResponsesAgentRequest(**{
 17              "input": [
 18                  {
 19                      "type": "message",
 20                      "id": "1",
 21                      "status": "completed",
 22                      "role": "assistant",
 23                      "content": [
 24                          {
 25                              "type": "output_text",
 26                          }
 27                      ],
 28                  }
 29              ],
 30          })
 31  
 32      with pytest.raises(ValueError, match="role"):
 33          ResponsesAgentRequest(**{
 34              "input": [
 35                  {
 36                      "type": "message",
 37                      "id": "1",
 38                      "status": "completed",
 39                      "role": "asdf",
 40                      "content": [
 41                          {
 42                              "type": "output_text",
 43                              "text": "asdf",
 44                          }
 45                      ],
 46                  }
 47              ],
 48          })
 49  
 50  
 51  def test_message_content_validation():
 52      # Test that None content is rejected (by Pydantic validation)
 53      with pytest.raises(ValidationError, match="Input should be a valid"):
 54          Message(role="assistant", content=None, type="message")
 55  
 56      # Test that empty string content is allowed
 57      message_empty_str = Message(role="assistant", content="", type="message")
 58      assert message_empty_str.content == ""
 59  
 60      # Test that empty list content is allowed
 61      message_empty_list = Message(role="assistant", content=[], type="message")
 62      assert message_empty_list.content == []
 63  
 64  
 65  def test_responses_response_validation():
 66      with pytest.raises(ValueError, match="output.0.content.0.text"):
 67          ResponsesAgentResponse(**{
 68              "output": [
 69                  {
 70                      "type": "message",
 71                      "id": "1",
 72                      "status": "completed",
 73                      "role": "assistant",
 74                      "content": [
 75                          {
 76                              "type": "output_text",
 77                          }
 78                      ],
 79                  }
 80              ],
 81          })
 82  
 83  
 84  def test_responses_stream_event_validation():
 85      with pytest.raises(ValueError, match="content must not be an empty"):
 86          ResponsesAgentStreamEvent(**{
 87              "type": "response.output_item.done",
 88              "output_index": 0,
 89              "item": {
 90                  "type": "message",
 91                  "status": "in_progress",
 92                  "role": "assistant",
 93                  "content": [],
 94                  "id": "1",
 95              },
 96          })
 97  
 98      with pytest.raises(ValueError, match="Invalid status"):
 99          ResponsesAgentStreamEvent(**{
100              "type": "response.output_item.done",
101              "output_index": 0,
102              "item": {
103                  "type": "message",
104                  "status": "asdf",
105                  "role": "assistant",
106                  "content": [
107                      {
108                          "type": "output_text",
109                          "text": "asdf",
110                      }
111                  ],
112                  "id": "1",
113              },
114          })
115  
116      with pytest.raises(ValueError, match="item.content.0.annotations.0.url"):
117          ResponsesAgentStreamEvent(
118              **{
119                  "type": "response.output_item.done",
120                  "output_index": 1,
121                  "item": {
122                      "type": "message",
123                      "id": "msg_67ed73ed2c288191b0f0f445e21c66540fbd8030171e9b0c",
124                      "status": "completed",
125                      "role": "assistant",
126                      "content": [
127                          {
128                              "type": "output_text",
129                              "text": "On T",
130                              "annotations": [
131                                  {
132                                      "type": "url_citation",
133                                      "start_index": 359,
134                                      "end_index": 492,
135                                      "title": "NBA roundup:",
136                                  },
137                              ],
138                          }
139                      ],
140                  },
141              },
142          )
143      with pytest.raises(ValueError, match="delta"):
144          ResponsesAgentStreamEvent(
145              **{
146                  "type": "response.output_text.delta",
147                  "item_id": "msg_67eda402cba48191a1c35b84af04fc3c0a4363ad71e9395a",
148                  "output_index": 0,
149                  "content_index": 0,
150              },
151          )
152  
153      with pytest.raises(ValueError, match="annotation.url"):
154          ResponsesAgentStreamEvent(
155              **{
156                  "type": "response.output_text.annotation.added",
157                  "item_id": "msg_67ed73ed2c288191b0f0f445e21c66540fbd8030171e9b0c",
158                  "output_index": 1,
159                  "content_index": 0,
160                  "annotation_index": 0,
161                  "annotation": {
162                      "type": "url_citation",
163                      "start_index": 359,
164                      "end_index": 492,
165                      "title": "NBA roundup: Wolves overcome Nikola",
166                  },
167              },
168          )
169  
170  
@pytest.mark.parametrize(
    "output",
    [
        "hello",
        [{"type": "input_text", "text": "result"}],
        [
            {
                "type": "input_text",
                "text": '{"content":{"queryAttachments":[]},"status":"COMPLETED"}',
            }
        ],
    ],
)
def test_function_call_output_accepts_string_and_list(output):
    """Construct function-call outputs from a string and from lists of parts.

    The test makes no assertions: it passes as long as neither the direct
    ``FunctionCallOutput`` construction nor the ``output_item.done`` stream
    event wrapping the same payload raises.
    """
    # Direct model construction.
    FunctionCallOutput(call_id="c", output=output)

    # The same payload embedded as the item of a stream event.
    wrapped_item = {"type": "function_call_output", "call_id": "c", "output": output}
    ResponsesAgentStreamEvent(type="response.output_item.done", item=wrapped_item)
190  
191  
@pytest.mark.parametrize(
    ("output", "expected"),
    [
        ("hello", "hello"),
        ([{"key": "value"}], '[{"key": "value"}]'),
        ({"a": 1}, '{"a": 1}'),
        (12345, "12345"),
    ],
)
def test_responses_to_cc_stringifies_function_call_output(output, expected):
    """Check the ``content`` that ``responses_to_cc`` yields for tool outputs.

    For each parametrized ``output`` value, the first converted message's
    ``content`` must equal the ``expected`` string.
    """
    call_output_item = {"type": "function_call_output", "call_id": "c", "output": output}
    converted = responses_to_cc(call_output_item)
    first_message = converted[0]
    assert first_message["content"] == expected
204  
205  
def test_responses_to_cc_fallback_to_str_on_non_serializable():
    """Check that ``responses_to_cc`` still yields string content.

    The function-call output here is a list holding an instance of a bare
    local class; the first converted message's ``content`` must be a ``str``.
    """

    class NonSerializable:
        pass

    call_output_item = {
        "type": "function_call_output",
        "call_id": "c",
        "output": [NonSerializable()],
    }
    converted = responses_to_cc(call_output_item)
    first_message = converted[0]
    assert isinstance(first_message["content"], str)
216  
217  
def test_function_call_output_round_trip():
    """Round-trip a function-call output through a stream event.

    The raw item is wrapped in an ``output_item.done`` event, the event's item
    is dumped back to plain data when it exposes ``model_dump``, and the result
    is passed to ``to_chat_completions_input``. The first resulting message
    must have role ``"tool"`` and string content containing ``"input_text"``.
    """
    raw_item = {
        "call_id": "toolu_bdrk_017fvUyTS6oaCDYg6GVL3X7j",
        "output": [{"type": "input_text", "text": '{"status":"COMPLETED"}'}],
        "type": "function_call_output",
    }
    event = ResponsesAgentStreamEvent(type="response.output_item.done", item=raw_item)

    # Dump model instances to plain data; pass anything else through untouched.
    dumped_items = []
    for response_item in [event.item]:
        if hasattr(response_item, "model_dump"):
            dumped_items.append(response_item.model_dump())
        else:
            dumped_items.append(response_item)

    cc_messages = to_chat_completions_input(dumped_items)
    tool_message = cc_messages[0]
    assert tool_message["role"] == "tool"
    assert isinstance(tool_message["content"], str)
    assert "input_text" in tool_message["content"]