test_truncation.py
import json
from unittest.mock import patch

import pytest

from mlflow.entities.trace_data import TraceData
from mlflow.entities.trace_info import TraceInfo
from mlflow.entities.trace_location import TraceLocation
from mlflow.entities.trace_state import TraceState
from mlflow.tracing.utils.truncation import _get_truncated_preview, set_request_response_preview


@pytest.fixture(autouse=True)
def patch_max_length():
    # Patch max length to 50 to make tests faster. The fixture is autouse, so every
    # test in this module sees a 50-character preview limit.
    with patch("mlflow.tracing.utils.truncation._get_max_length", return_value=50):
        yield


# Plain (non-message) inputs: values within the limit are returned unchanged, an
# over-long string is cut to the limit with a trailing "...", and None stays None.
@pytest.mark.parametrize(
    ("input_str", "expected"),
    [
        ("short string", "short string"),
        ("{'a': 'b'}", "{'a': 'b'}"),
        # 5 + 42 + 3 == 50: the ellipsis counts toward the patched limit
        ("start" + "a" * 50, "start" + "a" * 42 + "..."),
        (None, None),
    ],
    ids=["short string", "short json", "long string", "none"],
)
def test_truncate_simple_string(input_str, expected):
    assert _get_truncated_preview(input_str, role="user") == expected


def test_truncate_long_non_message_json():
    # JSON that carries no message structure is truncated as raw text
    input_str = json.dumps({
        "a": "b" + "a" * 30,
        "b": "c" + "a" * 30,
    })
    result = _get_truncated_preview(input_str, role="user")
    assert len(result) == 50
    assert result.startswith('{"a": "b')


# Shared conversation used by the message-extraction tests below. The third entry
# is deliberately longer than the patched 50-character limit.
_TEST_MESSAGE_HISTORY = [
    {"role": "user", "content": "First"},
    {"role": "assistant", "content": "Second"},
    {"role": "user", "content": "Third" + "a" * 50},
    {"role": "assistant", "content": "Fourth"},
]


# The argument is named `request_data` rather than `input` so that it does not
# shadow the builtin; the explicit `ids` keep the generated test IDs stable.
@pytest.mark.parametrize(
    "request_data",
    [
        # ChatCompletion API
        {"messages": _TEST_MESSAGE_HISTORY},
        # Responses API
        {"input": _TEST_MESSAGE_HISTORY},
        # Responses Agent
        {"request": {"input": _TEST_MESSAGE_HISTORY}},
    ],
    ids=["chat_completion", "responses", "responses_agent"],
)
def test_truncate_request_messages(request_data):
    input_str = json.dumps(request_data)
    # The last message of the requested role is used as the preview
    assert _get_truncated_preview(input_str, role="assistant") == "Fourth"
    # Long content should be truncated
    assert _get_truncated_preview(input_str, role="user") == "Third" + "a" * 42 + "..."
    # If non-existing role is provided, return the last message
    assert _get_truncated_preview(input_str, role="system") == "Fourth"


def test_truncate_request_choices():
    # ChatCompletion-style response: the preview comes from the message in `choices`
    input_str = json.dumps({
        "choices": [
            {
                "index": 1,
                "message": {"role": "assistant", "content": "First" + "a" * 50},
                "finish_reason": "stop",
            },
        ],
        "object": "chat.completions",
    })
    assert _get_truncated_preview(input_str, role="assistant").startswith("First")


def test_truncate_multi_content_messages():
    # If text content exists, use it
    assert (
        _get_truncated_preview(
            json.dumps({
                "messages": [{"role": "user", "content": [{"type": "text", "text": "a" * 60}]}]
            }),
            role="user",
        )
        == "a" * 47 + "..."
    )

    # Ignore non text content
    assert (
        _get_truncated_preview(
            json.dumps({
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": "a" * 60},
                            {"type": "image", "image_url": "http://example.com/image.jpg"},
                        ],
                    },
                ]
            }),
            role="user",
        )
        == "a" * 47 + "..."
    )

    # If only non-text content exists, truncate the full json as-is
    assert _get_truncated_preview(
        json.dumps({
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image_url": "http://example.com/image.jpg" + "a" * 50,
                        }
                    ],
                },
            ]
        }),
        role="user",
    ).startswith('{"messages":')


def test_truncate_responses_api_output():
    # Responses API output: the `output_text` part of the message is extracted
    # and truncated to the patched limit
    input_str = json.dumps({
        "output": [
            {
                "type": "message",
                "id": "test",
                "role": "assistant",
                "content": [{"type": "output_text", "text": "a" * 60}],
            }
        ],
    })

    assert _get_truncated_preview(input_str, role="assistant") == "a" * 47 + "..."


# Payloads whose top-level keys look like a message container but whose values
# do not hold a usable message list.
@pytest.mark.parametrize(
    "input_data",
    [
        {"messages": 123, "long_data": "a" * 50},
        {"messages": []},
        {"input": "string"},
        {"output": 123},
        {"choices": {"0": "value"}},
        {"request": "string"},
        {"choices": [{"message": "not a dict"}]},
        {"choices": [{"message": {"role": "user"}}]},
    ],
)
def test_truncate_invalid_messages(input_data):
    serialized = json.dumps(input_data)
    preview = _get_truncated_preview(serialized, role="user")

    if "long_data" not in input_data:
        # Short payloads are expected back exactly as they were serialized
        assert preview == serialized
        return

    # The oversized payload is cut to the 50-character limit patched in by the
    # module's autouse fixture, keeping the start of the raw JSON
    assert len(preview) == 50
    assert preview.startswith(serialized[:20])


# Well-formed agent-style requests: the preview is the message content itself,
# with none of the surrounding JSON structure.
@pytest.mark.parametrize(
    ("request_data", "expected_content", "should_not_contain"),
    [
        pytest.param(
            {"request": {"input": [{"role": "user", "content": "Hello"}]}},
            "Hello",
            "request",
            id="short_structured_json",
        ),
        pytest.param(
            {"request": {"tool_choice": None, "input": [{"role": "user", "content": "Weather?"}]}},
            "Weather?",
            '"tool_choice"',
            id="agent_format_with_null_fields",
        ),
        pytest.param(
            {"request": {"input": [{"role": "user", "content": "Hi"}]}},
            "Hi",
            '"request"',
            id="responses_agent_short",
        ),
    ],
)
def test_truncate_structured_json_extracts_content(
    request_data, expected_content, should_not_contain
):
    preview = _get_truncated_preview(json.dumps(request_data), role="user")
    assert preview == expected_content
    assert should_not_contain not in preview


# Message content that is not a non-empty string: the preview is expected to
# contain the raw JSON for the content field, or to end in "...".
@pytest.mark.parametrize(
    ("content_value", "expected_in_result"),
    [
        pytest.param(None, '"content": null', id="null_content"),
        pytest.param("", '"content": ""', id="empty_string_content"),
        pytest.param(123, '"content": 123', id="numeric_content"),
    ],
)
def test_truncate_invalid_content_falls_back_to_json(content_value, expected_in_result):
    serialized = json.dumps({"input": [{"role": "user", "content": content_value}]})
    preview = _get_truncated_preview(serialized, role="user")
    assert expected_in_result in preview or preview.endswith("...")


def test_set_request_response_preview_skips_none_data():
    # Trace data that carries neither a request nor a response
    empty_data = TraceData(spans=[], request=None, response=None)
    info = TraceInfo(
        trace_id="tr-test",
        trace_location=TraceLocation.from_experiment_id("0"),
        request_time=1000,
        state=TraceState.OK,
    )

    set_request_response_preview(info, empty_data)

    # Both previews are expected to still be None afterwards
    assert info.request_preview is None
    assert info.response_preview is None