/ tests / tracing / utils / test_truncation.py
test_truncation.py
  1  import json
  2  from unittest.mock import patch
  3  
  4  import pytest
  5  
  6  from mlflow.entities.trace_data import TraceData
  7  from mlflow.entities.trace_info import TraceInfo
  8  from mlflow.entities.trace_location import TraceLocation
  9  from mlflow.entities.trace_state import TraceState
 10  from mlflow.tracing.utils.truncation import _get_truncated_preview, set_request_response_preview
 11  
 12  
 13  @pytest.fixture(autouse=True)
 14  def patch_max_length():
 15      # Patch max length to 50 to make tests faster
 16      with patch("mlflow.tracing.utils.truncation._get_max_length", return_value=50):
 17          yield
 18  
 19  
 20  @pytest.mark.parametrize(
 21      ("input_str", "expected"),
 22      [
 23          ("short string", "short string"),
 24          ("{'a': 'b'}", "{'a': 'b'}"),
 25          ("start" + "a" * 50, "start" + "a" * 42 + "..."),
 26          (None, None),
 27      ],
 28      ids=["short string", "short json", "long string", "none"],
 29  )
 30  def test_truncate_simple_string(input_str, expected):
 31      assert _get_truncated_preview(input_str, role="user") == expected
 32  
 33  
 34  def test_truncate_long_non_message_json():
 35      input_str = json.dumps({
 36          "a": "b" + "a" * 30,
 37          "b": "c" + "a" * 30,
 38      })
 39      result = _get_truncated_preview(input_str, role="user")
 40      assert len(result) == 50
 41      assert result.startswith('{"a": "b')
 42  
 43  
 44  _TEST_MESSAGE_HISTORY = [
 45      {"role": "user", "content": "First"},
 46      {"role": "assistant", "content": "Second"},
 47      {"role": "user", "content": "Third" + "a" * 50},
 48      {"role": "assistant", "content": "Fourth"},
 49  ]
 50  
 51  
 52  @pytest.mark.parametrize(
 53      "input",
 54      [
 55          # ChatCompletion API
 56          {"messages": _TEST_MESSAGE_HISTORY},
 57          # Responses API
 58          {"input": _TEST_MESSAGE_HISTORY},
 59          # Responses Agent
 60          {"request": {"input": _TEST_MESSAGE_HISTORY}},
 61      ],
 62      ids=["chat_completion", "responses", "responses_agent"],
 63  )
 64  def test_truncate_request_messages(input):
 65      input_str = json.dumps(input)
 66      assert _get_truncated_preview(input_str, role="assistant") == "Fourth"
 67      # Long content should be truncated
 68      assert _get_truncated_preview(input_str, role="user") == "Third" + "a" * 42 + "..."
 69      # If non-existing role is provided, return the last message
 70      assert _get_truncated_preview(input_str, role="system") == "Fourth"
 71  
 72  
 73  def test_truncate_request_choices():
 74      input_str = json.dumps({
 75          "choices": [
 76              {
 77                  "index": 1,
 78                  "message": {"role": "assistant", "content": "First" + "a" * 50},
 79                  "finish_reason": "stop",
 80              },
 81          ],
 82          "object": "chat.completions",
 83      })
 84      assert _get_truncated_preview(input_str, role="assistant").startswith("First")
 85  
 86  
 87  def test_truncate_multi_content_messages():
 88      # If text content exists, use it
 89      assert (
 90          _get_truncated_preview(
 91              json.dumps({
 92                  "messages": [{"role": "user", "content": [{"type": "text", "text": "a" * 60}]}]
 93              }),
 94              role="user",
 95          )
 96          == "a" * 47 + "..."
 97      )
 98  
 99      # Ignore non text content
100      assert (
101          _get_truncated_preview(
102              json.dumps({
103                  "messages": [
104                      {
105                          "role": "user",
106                          "content": [
107                              {"type": "text", "text": "a" * 60},
108                              {"type": "image", "image_url": "http://example.com/image.jpg"},
109                          ],
110                      },
111                  ]
112              }),
113              role="user",
114          )
115          == "a" * 47 + "..."
116      )
117  
118      # If non-text content exists, truncate the full json as-is
119      assert _get_truncated_preview(
120          json.dumps({
121              "messages": [
122                  {
123                      "role": "user",
124                      "content": [
125                          {
126                              "type": "image",
127                              "image_url": "http://example.com/image.jpg" + "a" * 50,
128                          }
129                      ],
130                  },
131              ]
132          }),
133          role="user",
134      ).startswith('{"messages":')
135  
136  
def test_truncate_responses_api_output():
    """Check that ``output_text`` content in a Responses API output is previewed."""
    output_message = {
        "type": "message",
        "id": "test",
        "role": "assistant",
        "content": [{"type": "output_text", "text": "a" * 60}],
    }
    serialized = json.dumps({"output": [output_message]})

    preview = _get_truncated_preview(serialized, role="assistant")
    # 60 characters of text are cut to 47 plus the "..." suffix.
    assert preview == "a" * 47 + "..."
150  
151  
@pytest.mark.parametrize(
    "input_data",
    [
        {"messages": 123, "long_data": "a" * 50},
        {"messages": []},
        {"input": "string"},
        {"output": 123},
        {"choices": {"0": "value"}},
        {"request": "string"},
        {"choices": [{"message": "not a dict"}]},
        {"choices": [{"message": {"role": "user"}}]},
    ],
)
def test_truncate_invalid_messages(input_data):
    """Check that malformed message containers are previewed as the raw JSON string."""
    serialized = json.dumps(input_data)
    preview = _get_truncated_preview(serialized, role="user")

    if "long_data" not in input_data:
        # Short payloads are expected back unchanged.
        assert preview == serialized
        return

    # The one oversized payload is cut to the patched limit but keeps its prefix.
    assert len(preview) == 50
    assert preview.startswith(serialized[:20])
173  
174  
@pytest.mark.parametrize(
    ("request_data", "expected_content", "should_not_contain"),
    [
        pytest.param(
            {"request": {"input": [{"role": "user", "content": "Hello"}]}},
            "Hello",
            "request",
            id="short_structured_json",
        ),
        pytest.param(
            {"request": {"tool_choice": None, "input": [{"role": "user", "content": "Weather?"}]}},
            "Weather?",
            '"tool_choice"',
            id="agent_format_with_null_fields",
        ),
        pytest.param(
            {"request": {"input": [{"role": "user", "content": "Hi"}]}},
            "Hi",
            '"request"',
            id="responses_agent_short",
        ),
    ],
)
def test_truncate_structured_json_extracts_content(
    request_data, expected_content, should_not_contain
):
    """Check that short structured requests yield the message text, not the JSON wrapper."""
    preview = _get_truncated_preview(json.dumps(request_data), role="user")
    assert preview == expected_content
    # Wrapper keys must not appear in the preview.
    assert should_not_contain not in preview
203  
204  
@pytest.mark.parametrize(
    ("content_value", "expected_in_result"),
    [
        (None, '"content": null'),
        ("", '"content": ""'),
        (123, '"content": 123'),
    ],
    ids=["null_content", "empty_string_content", "numeric_content"],
)
def test_truncate_invalid_content_falls_back_to_json(content_value, expected_in_result):
    """Check that a message without usable text content is previewed as the raw JSON.

    Each serialized payload is at most 46 characters, below the patched
    50-character limit, so the fallback preview must equal the input exactly.
    """
    request_data = {"input": [{"role": "user", "content": content_value}]}
    input_str = json.dumps(request_data)
    # Guard the precondition that makes the exact-match assertion below valid.
    assert len(input_str) <= 50
    result = _get_truncated_preview(input_str, role="user")
    # The earlier `or result.endswith("...")` alternative accepted any truncated
    # output, so it is dropped; nothing here may be truncated.
    assert result == input_str
    assert expected_in_result in result
219  
220  
def test_set_request_response_preview_skips_none_data():
    """Check that both previews stay ``None`` when the trace has no request or response."""
    empty_data = TraceData(spans=[], request=None, response=None)
    info = TraceInfo(
        trace_id="tr-test",
        trace_location=TraceLocation.from_experiment_id("0"),
        request_time=1000,
        state=TraceState.OK,
    )

    set_request_response_preview(info, empty_data)

    assert info.request_preview is None
    assert info.response_preview is None
232      assert trace_info.response_preview is None