# test_trace.py
"""Tests for MLflow trace entities: JSON/proto/dict round-trips, span search,
assessment search, preview truncation, and backward compatibility with V2 traces.
"""

import importlib.util
import json
import re
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from unittest import mock

import pytest
from pydantic import BaseModel

import mlflow
import mlflow.tracking.context.default_context
from mlflow.entities import (
    AssessmentSource,
    Feedback,
    SpanType,
    Trace,
    TraceData,
    TraceInfo,
    TraceLocation,
)
from mlflow.entities.assessment import Expectation
from mlflow.entities.trace_state import TraceState
from mlflow.environment_variables import MLFLOW_TRACKING_USERNAME
from mlflow.exceptions import MlflowException
from mlflow.tracing.constant import TRACE_SCHEMA_VERSION_KEY
from mlflow.tracing.utils import TraceJSONEncoder
from mlflow.utils.mlflow_tags import MLFLOW_ARTIFACT_LOCATION
from mlflow.utils.proto_json_utils import (
    milliseconds_to_proto_timestamp,
)

from tests.tracing.helper import (
    V2_TRACE_DICT,
    create_test_trace_info,
    create_test_trace_info_with_uc_table,
)


def _test_model(datetime=None):
    """Build a model with two traced spans: ``predict`` (default span type)
    calling ``add_one`` (LLM span with custom name and attributes).

    FIX: the original signature was ``_test_model(datetime=datetime.now())``,
    which both shadowed the imported ``datetime`` class and froze the default
    value at import time (call-time-default anti-pattern). A ``None`` sentinel
    keeps call sites backward compatible while evaluating a fresh timestamp per
    call. The parameter name is kept for compatibility; because it shadows the
    class, the class is re-imported locally under an alias.
    """
    from datetime import datetime as _datetime_cls

    if datetime is None:
        datetime = _datetime_cls.now()

    class TestModel:
        @mlflow.trace()
        def predict(self, x, y):
            z = x + y
            z = self.add_one(z)
            return z  # noqa: RET504

        @mlflow.trace(
            span_type=SpanType.LLM,
            name="add_one_with_custom_name",
            attributes={
                "delta": 1,
                "metadata": {"foo": "bar"},
                # Test for non-json-serializable input
                "datetime": datetime,
            },
        )
        def add_one(self, z):
            return z + 1

    return TestModel()


def test_json_deserialization(monkeypatch):
    """Serialize a captured trace to JSON and compare the full payload shape."""
    monkeypatch.setattr(mlflow.tracking.context.default_context, "_get_source_name", lambda: "test")
    monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "bob")
    datetime_now = datetime.now()

    model = _test_model(datetime_now)
    model.predict(2, 5)

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    trace_json = trace.to_json()

    trace_json_as_dict = json.loads(trace_json)
    assert trace_json_as_dict == {
        "info": {
            "trace_id": trace.info.request_id,
            "trace_location": {
                "mlflow_experiment": {
                    "experiment_id": "0",
                },
                "type": "MLFLOW_EXPERIMENT",
            },
            "request_time": milliseconds_to_proto_timestamp(trace.info.timestamp_ms),
            "execution_duration_ms": trace.info.execution_time_ms,
            "state": "OK",
            "request_preview": '{"x": 2, "y": 5}',
            "response_preview": "8",
            "trace_metadata": {
                "mlflow.traceInputs": '{"x": 2, "y": 5}',
                "mlflow.traceOutputs": "8",
                "mlflow.source.name": mock.ANY,
                "mlflow.source.type": "LOCAL",
                "mlflow.source.git.branch": mock.ANY,
                "mlflow.source.git.commit": mock.ANY,
                "mlflow.source.git.repoURL": mock.ANY,
                "mlflow.user": mock.ANY,
                "mlflow.trace.sizeBytes": mock.ANY,
                "mlflow.trace.sizeStats": mock.ANY,
                "mlflow.trace_schema.version": "3",
                "mlflow.trace.infoFinalized": "true",
            },
            "tags": {
                "mlflow.traceName": "predict",
                "mlflow.artifactLocation": trace.info.tags[MLFLOW_ARTIFACT_LOCATION],
                "mlflow.trace.spansLocation": mock.ANY,
            },
        },
        "data": {
            "spans": [
                {
                    "name": "predict",
                    "trace_id": mock.ANY,
                    "span_id": mock.ANY,
                    "parent_span_id": None,
                    "start_time_unix_nano": trace.data.spans[0].start_time_ns,
                    "end_time_unix_nano": trace.data.spans[0].end_time_ns,
                    "events": [],
                    "status": {
                        "code": "STATUS_CODE_OK",
                        "message": "",
                    },
                    "attributes": {
                        "mlflow.traceRequestId": json.dumps(trace.info.request_id),
                        "mlflow.spanType": '"UNKNOWN"',
                        "mlflow.spanFunctionName": '"predict"',
                        "mlflow.spanInputs": '{"x": 2, "y": 5}',
                        "mlflow.spanOutputs": "8",
                    },
                },
                {
                    "name": "add_one_with_custom_name",
                    "trace_id": mock.ANY,
                    "span_id": mock.ANY,
                    "parent_span_id": mock.ANY,
                    "start_time_unix_nano": trace.data.spans[1].start_time_ns,
                    "end_time_unix_nano": trace.data.spans[1].end_time_ns,
                    "events": [],
                    "status": {
                        "code": "STATUS_CODE_OK",
                        "message": "",
                    },
                    "attributes": {
                        "mlflow.traceRequestId": json.dumps(trace.info.request_id),
                        "mlflow.spanType": '"LLM"',
                        "mlflow.spanFunctionName": '"add_one"',
                        "mlflow.spanInputs": '{"z": 7}',
                        "mlflow.spanOutputs": "8",
                        "delta": "1",
                        "datetime": json.dumps(str(datetime_now)),
                        "metadata": '{"foo": "bar"}',
                    },
                },
            ],
        },
    }


@pytest.mark.skipif(
    importlib.util.find_spec("pydantic") is None, reason="Pydantic is not installed"
)
def test_trace_serialize_pydantic_model():
    """TraceJSONEncoder serializes a pydantic model as its plain dict."""

    class MyModel(BaseModel):
        x: int
        y: str

    data = MyModel(x=1, y="foo")
    data_json = json.dumps(data, cls=TraceJSONEncoder)
    assert data_json == '{"x": 1, "y": "foo"}'
    assert json.loads(data_json) == {"x": 1, "y": "foo"}


def test_trace_serialize_dataclass():
    """TraceJSONEncoder serializes a dataclass as its field dict."""

    @dataclass
    class Config:
        model: str
        temperature: float
        tags: list[str]

    config = Config(model="gpt-4o", temperature=0.5, tags=["a", "b"])
    result = json.loads(json.dumps(config, cls=TraceJSONEncoder))
    assert result == {"model": "gpt-4o", "temperature": 0.5, "tags": ["a", "b"]}


def test_trace_serialize_dataclass_with_non_copyable_field():
    """Dataclasses whose fields cannot be deepcopied (e.g. contain asyncio internals)
    must serialize without raising an exception.
    """

    class _NonCopyable:
        def __deepcopy__(self, memo):
            raise RuntimeError("deepcopy not supported")

    @dataclass
    class RunConfig:
        name: str
        client: _NonCopyable

    config = RunConfig(name="test-run", client=_NonCopyable())
    # Should not raise; non-serializable client falls back to str representation
    result = json.loads(json.dumps(config, cls=TraceJSONEncoder))
    assert result["name"] == "test-run"
    assert "client" in result


@pytest.mark.skipif(
    importlib.util.find_spec("langchain") is None, reason="langchain is not installed"
)
def test_trace_serialize_langchain_base_message():
    """TraceJSONEncoder serializes a LangChain BaseMessage to a JSON dict."""
    from langchain_core.messages import BaseMessage

    message = BaseMessage(
        content=[
            {
                "role": "system",
                "content": "Hello, World!",
            },
            {
                "role": "user",
                "content": "Hi!",
            },
        ],
        type="chat",
    )

    message_json = json.dumps(message, cls=TraceJSONEncoder)
    # LangChain message model contains a few more default fields actually. But we
    # only check if the following subset of the expected dictionary is present in
    # the loaded JSON rather than exact equality, because the LangChain BaseModel
    # has been changing frequently and the additional default fields may differ
    # across versions installed on developers' machines.
    expected_dict_subset = {
        "content": [
            {
                "role": "system",
                "content": "Hello, World!",
            },
            {
                "role": "user",
                "content": "Hi!",
            },
        ],
        "type": "chat",
    }
    loaded = json.loads(message_json)
    assert expected_dict_subset.items() <= loaded.items()


def test_trace_to_from_dict_and_json():
    """Round-trip a trace through dict and JSON; spans must survive intact."""
    model = _test_model()
    model.predict(2, 5)

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())

    spans = trace.search_spans(span_type=SpanType.LLM)
    assert len(spans) == 1

    spans = trace.search_spans(name="predict")
    assert len(spans) == 1

    trace_dict = trace.to_dict()
    trace_from_dict = Trace.from_dict(trace_dict)
    trace_json = trace.to_json()
    trace_from_json = Trace.from_json(trace_json)
    for loaded_trace in [trace_from_dict, trace_from_json]:
        assert trace.info == loaded_trace.info
        assert trace.data.request == loaded_trace.data.request
        assert trace.data.response == loaded_trace.data.response
        assert len(trace.data.spans) == len(loaded_trace.data.spans)
        for i in range(len(trace.data.spans)):
            for attr in [
                "name",
                "request_id",
                "span_id",
                "start_time_ns",
                "end_time_ns",
                "parent_id",
                "status",
                "inputs",
                "outputs",
                "_trace_id",
                "attributes",
                "events",
            ]:
                assert getattr(trace.data.spans[i], attr) == getattr(
                    loaded_trace.data.spans[i], attr
                )


def test_trace_pandas_dataframe_columns():
    """Column list must match a real dataframe row for both trace locations."""
    t = Trace(
        info=create_test_trace_info("a"),
        data=TraceData(),
    )
    assert Trace.pandas_dataframe_columns() == list(t.to_pandas_dataframe_row())

    t = Trace(
        info=create_test_trace_info_with_uc_table("a", "catalog", "schema"),
        data=TraceData(),
    )
    assert Trace.pandas_dataframe_columns() == list(t.to_pandas_dataframe_row())


@pytest.mark.parametrize(
    ("span_type", "name", "expected"),
    [
        (None, None, ["run", "add_one", "add_one", "add_two", "multiply_by_two"]),
        (SpanType.CHAIN, None, ["run"]),
        (None, "add_two", ["add_two"]),
        (None, re.compile(r"add.*"), ["add_one", "add_one", "add_two"]),
        (None, re.compile(r"^add"), ["add_one", "add_one", "add_two"]),
        (None, re.compile(r"_two$"), ["add_two", "multiply_by_two"]),
        (None, re.compile(r".*ONE", re.IGNORECASE), ["add_one", "add_one"]),
        (SpanType.TOOL, "multiply_by_two", ["multiply_by_two"]),
        (SpanType.AGENT, None, []),
        (None, "non_existent", []),
    ],
)
def test_search_spans(span_type, name, expected):
    """search_spans filters by span type and by exact-string or regex name."""

    @mlflow.trace(span_type=SpanType.CHAIN)
    def run(x: int) -> int:
        x = add_one(x)
        x = add_one(x)
        x = add_two(x)
        return multiply_by_two(x)

    @mlflow.trace(span_type=SpanType.TOOL)
    def add_one(x: int) -> int:
        return x + 1

    @mlflow.trace(span_type=SpanType.TOOL)
    def add_two(x: int) -> int:
        return x + 2

    @mlflow.trace(span_type=SpanType.TOOL)
    def multiply_by_two(x: int) -> int:
        return x * 2

    run(2)
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())

    spans = trace.search_spans(span_type=span_type, name=name)

    assert [span.name for span in spans] == expected


def test_search_spans_raise_for_invalid_param_type():
    """Non-string/non-SpanType filters must raise MlflowException."""

    @mlflow.trace(span_type=SpanType.CHAIN)
    def run(x: int) -> int:
        return x + 1

    run(2)
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())

    with pytest.raises(MlflowException, match="Invalid type for 'span_type'"):
        trace.search_spans(span_type=123)

    with pytest.raises(MlflowException, match="Invalid type for 'name'"):
        trace.search_spans(name=123)


def test_from_v2_dict():
    """Loading a V2 trace dict must preserve identifiers, timing, and metadata."""
    trace = Trace.from_dict(V2_TRACE_DICT)
    assert trace.info.request_id == "58f4e27101304034b15c512b603bf1b2"
    assert trace.info.request_time == 100
    assert trace.info.execution_duration == 200
    assert len(trace.data.spans) == 2

    # NOTE(review): the original comment claimed the schema version is updated
    # from "2" to the current version during the V2 -> V3 conversion, but the
    # assertion checks that the key still holds "2" — i.e. it is preserved
    # as-is. The assertion is treated as the source of truth here.
    assert trace.info.trace_metadata[TRACE_SCHEMA_VERSION_KEY] == "2"

    # Verify that other metadata was preserved
    assert trace.info.trace_metadata["mlflow.traceInputs"] == '{"x": 2, "y": 5}'
    assert trace.info.trace_metadata["mlflow.traceOutputs"] == "8"


def test_request_response_smart_truncation():
    """Chat-format payloads are truncated smartly, keeping the message head."""

    @mlflow.trace
    def f(messages: list[dict[str, Any]]) -> dict[str, Any]:
        return {"choices": [{"message": {"role": "assistant", "content": "Hi!" * 1000}}]}

    # NB: Since MLflow OSS backend still uses v2 tracing schema, the most accurate way to
    # check if the preview is truncated properly is to mock the upload_trace_data call.
    with mock.patch(
        "mlflow.tracing.export.mlflow_v3.TracingClient.start_trace"
    ) as mock_start_trace:
        f([{"role": "user", "content": "Hello!" * 1000}])

    trace_info = mock_start_trace.call_args[0][0]
    assert len(trace_info.request_preview) == 1000
    assert trace_info.request_preview.startswith("Hello!")
    assert len(trace_info.response_preview) == 1000
    assert trace_info.response_preview.startswith("Hi!")


def test_request_response_smart_truncation_non_chat_format():
    """Non-chat request/response will be naively truncated."""

    @mlflow.trace
    def f(question: str) -> list[str]:
        return ["a" * 5000, "b" * 5000, "c" * 5000]

    with mock.patch(
        "mlflow.tracing.export.mlflow_v3.TracingClient.start_trace"
    ) as mock_start_trace:
        f("start" + "a" * 1000)

    trace_info = mock_start_trace.call_args[0][0]
    assert len(trace_info.request_preview) == 1000
    assert trace_info.request_preview.startswith('{"question": "startaaa')
    assert len(trace_info.response_preview) == 1000
    assert trace_info.response_preview.startswith('["aaaaa')


def test_request_response_custom_truncation():
    """Explicit previews set via update_current_trace win over auto-truncation."""

    @mlflow.trace
    def f(messages: list[dict[str, Any]]) -> dict[str, Any]:
        mlflow.update_current_trace(
            request_preview="custom request preview",
            response_preview="custom response preview",
        )
        return {"choices": [{"message": {"role": "assistant", "content": "Hi!" * 10000}}]}

    with mock.patch(
        "mlflow.tracing.export.mlflow_v3.TracingClient.start_trace"
    ) as mock_start_trace:
        f([{"role": "user", "content": "Hello!" * 10000}])

    trace_info = mock_start_trace.call_args[0][0]
    assert trace_info.request_preview == "custom request preview"
    assert trace_info.response_preview == "custom response preview"


def test_search_assessments():
    """search_assessments filters by validity, name, span_id, and type."""
    assessments = [
        Feedback(
            trace_id="trace_id",
            name="relevance",
            value=False,
            source=AssessmentSource(source_type="HUMAN", source_id="user_1"),
            rationale="The judge is wrong",
            span_id=None,
            overrides="2",
        ),
        Feedback(
            trace_id="trace_id",
            name="relevance",
            value=True,
            source=AssessmentSource(source_type="LLM_JUDGE", source_id="databricks"),
            span_id=None,
            valid=False,
        ),
        Feedback(
            trace_id="trace_id",
            name="relevance",
            value=True,
            source=AssessmentSource(source_type="LLM_JUDGE", source_id="databricks"),
            span_id="123",
        ),
        Expectation(
            trace_id="trace_id",
            name="guidelines",
            value="The response should be concise and to the point.",
            source=AssessmentSource(source_type="LLM_JUDGE", source_id="databricks"),
            span_id="123",
        ),
    ]
    trace_info = TraceInfo(
        trace_id="trace_id",
        client_request_id="client_request_id",
        trace_location=TraceLocation.from_experiment_id("123"),
        request_preview="request",
        response_preview="response",
        request_time=1234567890,
        execution_duration=100,
        assessments=assessments,
        state=TraceState.OK,
    )
    trace = Trace(
        info=trace_info,
        data=TraceData(
            spans=[],
        ),
    )

    # Default search excludes the invalid assessment (assessments[1])
    assert trace.search_assessments() == [assessments[0], assessments[2], assessments[3]]
    assert trace.search_assessments(all=True) == assessments
    assert trace.search_assessments("relevance") == [assessments[0], assessments[2]]
    assert trace.search_assessments("relevance", all=True) == assessments[:3]
    assert trace.search_assessments(span_id="123") == [assessments[2], assessments[3]]
    assert trace.search_assessments(span_id="123", name="relevance") == [assessments[2]]
    assert trace.search_assessments(type="expectation") == [assessments[3]]


def test_trace_to_and_from_proto():
    """Proto round-trip preserves trace identity, location, and span order."""

    @mlflow.trace
    def invoke(x):
        return x + 1

    @mlflow.trace
    def test(x):
        return invoke(x)

    test(1)
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    proto_trace = trace.to_proto()
    assert proto_trace.trace_info.trace_id == trace.info.request_id
    assert proto_trace.trace_info.trace_location == trace.info.trace_location.to_proto()
    assert len(proto_trace.spans) == 2
    assert proto_trace.spans[0].name == "test"
    assert proto_trace.spans[1].name == "invoke"

    trace_from_proto = Trace.from_proto(proto_trace)
    assert trace_from_proto.to_dict() == trace.to_dict()


def test_trace_from_dict_load_old_trace():
    """A previously persisted trace dict (old on-disk format) must still load."""
    trace_dict = {
        "info": {
            "trace_id": "tr-ee17184669c265ffdcf9299b36f6dccc",
            "trace_location": {
                "type": "MLFLOW_EXPERIMENT",
                "mlflow_experiment": {"experiment_id": "0"},
            },
            "request_time": "2025-10-22T04:14:54.524Z",
            "state": "OK",
            "trace_metadata": {
                "mlflow.trace_schema.version": "3",
                "mlflow.traceInputs": '"abc"',
                "mlflow.source.type": "LOCAL",
                "mlflow.source.git.branch": "branch-3.4",
                "mlflow.source.name": "a.py",
                "mlflow.source.git.commit": "78d075062b120597050bf2b3839a426feea5ea4c",
                "mlflow.user": "serena.ruan",
                "mlflow.traceOutputs": '"def"',
                "mlflow.source.git.repoURL": "git@github.com:mlflow/mlflow.git",
                "mlflow.trace.sizeBytes": "1226",
            },
            "tags": {
                "mlflow.artifactLocation": "mlflow-artifacts:/0/traces",
                "mlflow.traceName": "test",
            },
            "request_preview": '"abc"',
            "response_preview": '"def"',
            "execution_duration_ms": 60,
        },
        "data": {
            "spans": [
                {
                    "trace_id": "7hcYRmnCZf/c+SmbNvbczA==",
                    "span_id": "3ElmHER9IVU=",
                    "trace_state": "",
                    "parent_span_id": "",
                    "name": "test",
                    "start_time_unix_nano": 1761106494524157000,
                    "end_time_unix_nano": 1761106494584860000,
                    "attributes": {
                        "mlflow.spanOutputs": '"def"',
                        "mlflow.spanType": '"UNKNOWN"',
                        "mlflow.spanInputs": '"abc"',
                        "mlflow.traceRequestId": '"tr-ee17184669c265ffdcf9299b36f6dccc"',
                        "test": '"test"',
                    },
                    "status": {"message": "", "code": "STATUS_CODE_OK"},
                }
            ]
        },
    }
    trace = Trace.from_dict(trace_dict)
    assert trace.info.trace_id == "tr-ee17184669c265ffdcf9299b36f6dccc"
    assert trace.info.request_time == 1761106494524
    assert trace.info.execution_duration == 60
    assert trace.info.trace_location == TraceLocation.from_experiment_id("0")
    assert len(trace.data.spans) == 1
    assert trace.data.spans[0].name == "test"
    assert trace.data.spans[0].inputs == "abc"
    assert trace.data.spans[0].outputs == "def"
    assert trace.data.spans[0].start_time_ns == 1761106494524157000
    assert trace.data.spans[0].end_time_ns == 1761106494584860000