# tests/entities/test_trace.py
  1  import importlib.util
  2  import json
  3  import re
  4  from dataclasses import dataclass
  5  from datetime import datetime
  6  from typing import Any
  7  from unittest import mock
  8  
  9  import pytest
 10  from pydantic import BaseModel
 11  
 12  import mlflow
 13  import mlflow.tracking.context.default_context
 14  from mlflow.entities import (
 15      AssessmentSource,
 16      Feedback,
 17      SpanType,
 18      Trace,
 19      TraceData,
 20      TraceInfo,
 21      TraceLocation,
 22  )
 23  from mlflow.entities.assessment import Expectation
 24  from mlflow.entities.trace_state import TraceState
 25  from mlflow.environment_variables import MLFLOW_TRACKING_USERNAME
 26  from mlflow.exceptions import MlflowException
 27  from mlflow.tracing.constant import TRACE_SCHEMA_VERSION_KEY
 28  from mlflow.tracing.utils import TraceJSONEncoder
 29  from mlflow.utils.mlflow_tags import MLFLOW_ARTIFACT_LOCATION
 30  from mlflow.utils.proto_json_utils import (
 31      milliseconds_to_proto_timestamp,
 32  )
 33  
 34  from tests.tracing.helper import (
 35      V2_TRACE_DICT,
 36      create_test_trace_info,
 37      create_test_trace_info_with_uc_table,
 38  )
 39  
 40  
 41  def _test_model(datetime=datetime.now()):
 42      class TestModel:
 43          @mlflow.trace()
 44          def predict(self, x, y):
 45              z = x + y
 46              z = self.add_one(z)
 47              return z  # noqa: RET504
 48  
 49          @mlflow.trace(
 50              span_type=SpanType.LLM,
 51              name="add_one_with_custom_name",
 52              attributes={
 53                  "delta": 1,
 54                  "metadata": {"foo": "bar"},
 55                  # Test for non-json-serializable input
 56                  "datetime": datetime,
 57              },
 58          )
 59          def add_one(self, z):
 60              return z + 1
 61  
 62      return TestModel()
 63  
 64  
def test_json_deserialization(monkeypatch):
    """Serialize a full trace to JSON and verify the exact V3 payload layout.

    Pins the source name and tracking username so otherwise environment-
    dependent metadata is deterministic, then compares the parsed JSON against
    a literal expected dict (environment-varying fields use ``mock.ANY``).
    """
    monkeypatch.setattr(mlflow.tracking.context.default_context, "_get_source_name", lambda: "test")
    monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "bob")
    # Captured so the non-JSON-serializable span attribute can be asserted below.
    datetime_now = datetime.now()

    model = _test_model(datetime_now)
    model.predict(2, 5)

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    trace_json = trace.to_json()

    trace_json_as_dict = json.loads(trace_json)
    assert trace_json_as_dict == {
        "info": {
            "trace_id": trace.info.request_id,
            "trace_location": {
                "mlflow_experiment": {
                    "experiment_id": "0",
                },
                "type": "MLFLOW_EXPERIMENT",
            },
            "request_time": milliseconds_to_proto_timestamp(trace.info.timestamp_ms),
            "execution_duration_ms": trace.info.execution_time_ms,
            "state": "OK",
            "request_preview": '{"x": 2, "y": 5}',
            "response_preview": "8",
            "trace_metadata": {
                "mlflow.traceInputs": '{"x": 2, "y": 5}',
                "mlflow.traceOutputs": "8",
                "mlflow.source.name": mock.ANY,
                "mlflow.source.type": "LOCAL",
                "mlflow.source.git.branch": mock.ANY,
                "mlflow.source.git.commit": mock.ANY,
                "mlflow.source.git.repoURL": mock.ANY,
                "mlflow.user": mock.ANY,
                "mlflow.trace.sizeBytes": mock.ANY,
                "mlflow.trace.sizeStats": mock.ANY,
                "mlflow.trace_schema.version": "3",
                "mlflow.trace.infoFinalized": "true",
            },
            "tags": {
                "mlflow.traceName": "predict",
                "mlflow.artifactLocation": trace.info.tags[MLFLOW_ARTIFACT_LOCATION],
                "mlflow.trace.spansLocation": mock.ANY,
            },
        },
        "data": {
            "spans": [
                {
                    "name": "predict",
                    "trace_id": mock.ANY,
                    "span_id": mock.ANY,
                    "parent_span_id": None,
                    "start_time_unix_nano": trace.data.spans[0].start_time_ns,
                    "end_time_unix_nano": trace.data.spans[0].end_time_ns,
                    "events": [],
                    "status": {
                        "code": "STATUS_CODE_OK",
                        "message": "",
                    },
                    "attributes": {
                        "mlflow.traceRequestId": json.dumps(trace.info.request_id),
                        "mlflow.spanType": '"UNKNOWN"',
                        "mlflow.spanFunctionName": '"predict"',
                        "mlflow.spanInputs": '{"x": 2, "y": 5}',
                        "mlflow.spanOutputs": "8",
                    },
                },
                {
                    "name": "add_one_with_custom_name",
                    "trace_id": mock.ANY,
                    "span_id": mock.ANY,
                    "parent_span_id": mock.ANY,
                    "start_time_unix_nano": trace.data.spans[1].start_time_ns,
                    "end_time_unix_nano": trace.data.spans[1].end_time_ns,
                    "events": [],
                    "status": {
                        "code": "STATUS_CODE_OK",
                        "message": "",
                    },
                    "attributes": {
                        "mlflow.traceRequestId": json.dumps(trace.info.request_id),
                        "mlflow.spanType": '"LLM"',
                        "mlflow.spanFunctionName": '"add_one"',
                        "mlflow.spanInputs": '{"z": 7}',
                        "mlflow.spanOutputs": "8",
                        "delta": "1",
                        # Non-serializable datetime attribute is stored via str().
                        "datetime": json.dumps(str(datetime_now)),
                        "metadata": '{"foo": "bar"}',
                    },
                },
            ],
        },
    }
159  
160  
@pytest.mark.skipif(
    importlib.util.find_spec("pydantic") is None, reason="Pydantic is not installed"
)
def test_trace_serialize_pydantic_model():
    """TraceJSONEncoder should serialize a pydantic model as its field dict."""

    class MyModel(BaseModel):
        x: int
        y: str

    serialized = json.dumps(MyModel(x=1, y="foo"), cls=TraceJSONEncoder)
    assert serialized == '{"x": 1, "y": "foo"}'
    assert json.loads(serialized) == {"x": 1, "y": "foo"}
173  
174  
def test_trace_serialize_dataclass():
    """Dataclass instances should encode to their plain field dictionary."""

    @dataclass
    class Config:
        model: str
        temperature: float
        tags: list[str]

    cfg = Config(model="gpt-4o", temperature=0.5, tags=["a", "b"])
    round_tripped = json.loads(json.dumps(cfg, cls=TraceJSONEncoder))
    expected = {"model": "gpt-4o", "temperature": 0.5, "tags": ["a", "b"]}
    assert round_tripped == expected
185  
186  
def test_trace_serialize_dataclass_with_non_copyable_field():
    """A dataclass holding a field that refuses deepcopy (as asyncio internals
    do) must still serialize without raising."""

    class _NonCopyable:
        def __deepcopy__(self, memo):
            raise RuntimeError("deepcopy not supported")

    @dataclass
    class RunConfig:
        name: str
        client: _NonCopyable

    run_config = RunConfig(name="test-run", client=_NonCopyable())
    # Serialization must succeed: the non-serializable client is expected to
    # fall back to a string representation rather than propagate an error.
    loaded = json.loads(json.dumps(run_config, cls=TraceJSONEncoder))
    assert loaded["name"] == "test-run"
    assert "client" in loaded
206  
207  
@pytest.mark.skipif(
    importlib.util.find_spec("langchain") is None, reason="langchain is not installed"
)
def test_trace_serialize_langchain_base_message():
    from langchain_core.messages import BaseMessage

    chat_content = [
        {
            "role": "system",
            "content": "Hello, World!",
        },
        {
            "role": "user",
            "content": "Hi!",
        },
    ]
    message = BaseMessage(content=chat_content, type="chat")

    loaded = json.loads(json.dumps(message, cls=TraceJSONEncoder))
    # LangChain message model contains a few more default fields actually. But we
    # only check if the following subset of the expected dictionary is present in
    # the loaded JSON rather than exact equality, because the LangChain BaseModel
    # has been changing frequently and the additional default fields may differ
    # across versions installed on developers' machines.
    expected_subset = {
        "content": [
            {
                "role": "system",
                "content": "Hello, World!",
            },
            {
                "role": "user",
                "content": "Hi!",
            },
        ],
        "type": "chat",
    }
    assert expected_subset.items() <= loaded.items()
249  
250  
def test_trace_to_from_dict_and_json():
    """A trace must survive both dict and JSON round trips with identical
    info, request/response payloads, and span-level fields."""
    model = _test_model()
    model.predict(2, 5)

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())

    assert len(trace.search_spans(span_type=SpanType.LLM)) == 1
    assert len(trace.search_spans(name="predict")) == 1

    reloaded_traces = [
        Trace.from_dict(trace.to_dict()),
        Trace.from_json(trace.to_json()),
    ]
    # Span fields that must be preserved exactly by either round trip.
    compared_attrs = [
        "name",
        "request_id",
        "span_id",
        "start_time_ns",
        "end_time_ns",
        "parent_id",
        "status",
        "inputs",
        "outputs",
        "_trace_id",
        "attributes",
        "events",
    ]
    for reloaded in reloaded_traces:
        assert trace.info == reloaded.info
        assert trace.data.request == reloaded.data.request
        assert trace.data.response == reloaded.data.response
        assert len(trace.data.spans) == len(reloaded.data.spans)
        for original_span, reloaded_span in zip(trace.data.spans, reloaded.data.spans):
            for attr in compared_attrs:
                assert getattr(original_span, attr) == getattr(reloaded_span, attr)
290  
291  
def test_trace_pandas_dataframe_columns():
    """The static column list must match the keys of a rendered dataframe row
    for both experiment-backed and UC-table-backed trace info."""
    for info in (
        create_test_trace_info("a"),
        create_test_trace_info_with_uc_table("a", "catalog", "schema"),
    ):
        trace = Trace(info=info, data=TraceData())
        assert Trace.pandas_dataframe_columns() == list(trace.to_pandas_dataframe_row())
304  
305  
@pytest.mark.parametrize(
    ("span_type", "name", "expected"),
    [
        (None, None, ["run", "add_one", "add_one", "add_two", "multiply_by_two"]),
        (SpanType.CHAIN, None, ["run"]),
        (None, "add_two", ["add_two"]),
        (None, re.compile(r"add.*"), ["add_one", "add_one", "add_two"]),
        (None, re.compile(r"^add"), ["add_one", "add_one", "add_two"]),
        (None, re.compile(r"_two$"), ["add_two", "multiply_by_two"]),
        (None, re.compile(r".*ONE", re.IGNORECASE), ["add_one", "add_one"]),
        (SpanType.TOOL, "multiply_by_two", ["multiply_by_two"]),
        (SpanType.AGENT, None, []),
        (None, "non_existent", []),
    ],
)
def test_search_spans(span_type, name, expected):
    """search_spans filters by span type, exact name, or regex pattern."""

    @mlflow.trace(span_type=SpanType.TOOL)
    def add_one(x: int) -> int:
        return x + 1

    @mlflow.trace(span_type=SpanType.TOOL)
    def add_two(x: int) -> int:
        return x + 2

    @mlflow.trace(span_type=SpanType.TOOL)
    def multiply_by_two(x: int) -> int:
        return x * 2

    @mlflow.trace(span_type=SpanType.CHAIN)
    def run(x: int) -> int:
        x = add_one(x)
        x = add_one(x)
        x = add_two(x)
        return multiply_by_two(x)

    run(2)
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())

    matched = trace.search_spans(span_type=span_type, name=name)
    assert [span.name for span in matched] == expected
347  
348  
def test_search_spans_raise_for_invalid_param_type():
    """Non-string/non-SpanType filter arguments must raise MlflowException."""

    @mlflow.trace(span_type=SpanType.CHAIN)
    def run(x: int) -> int:
        return x + 1

    run(2)
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())

    for kwargs, pattern in [
        ({"span_type": 123}, "Invalid type for 'span_type'"),
        ({"name": 123}, "Invalid type for 'name'"),
    ]:
        with pytest.raises(MlflowException, match=pattern):
            trace.search_spans(**kwargs)
362  
363  
def test_from_v2_dict():
    """Loading a V2 trace dict via `Trace.from_dict` must yield a valid trace
    with ids, timing, spans, and request metadata intact."""
    trace = Trace.from_dict(V2_TRACE_DICT)
    assert trace.info.request_id == "58f4e27101304034b15c512b603bf1b2"
    assert trace.info.request_time == 100
    assert trace.info.execution_duration == 200
    assert len(trace.data.spans) == 2

    # NOTE(review): a previous comment claimed the schema version is *updated*
    # from "2" during V2 -> V3 conversion, but the assertion below shows the
    # metadata still reads "2" — confirm which behavior is actually intended.
    assert trace.info.trace_metadata[TRACE_SCHEMA_VERSION_KEY] == "2"

    # Verify that other metadata was preserved
    assert trace.info.trace_metadata["mlflow.traceInputs"] == '{"x": 2, "y": 5}'
    assert trace.info.trace_metadata["mlflow.traceOutputs"] == "8"
377  
378  
def test_request_response_smart_truncation():
    """Chat-format request/response previews are truncated to 1000 chars while
    keeping the message text at the start of the preview."""

    @mlflow.trace
    def f(messages: list[dict[str, Any]]) -> dict[str, Any]:
        return {"choices": [{"message": {"role": "assistant", "content": "Hi!" * 1000}}]}

    # NB: Since MLflow OSS backend still uses v2 tracing schema, the most accurate way to
    # check if the preview is truncated properly is to mock the upload_trace_data call.
    patched_target = "mlflow.tracing.export.mlflow_v3.TracingClient.start_trace"
    with mock.patch(patched_target) as start_trace_mock:
        f([{"role": "user", "content": "Hello!" * 1000}])

    info = start_trace_mock.call_args[0][0]
    assert len(info.request_preview) == 1000
    assert info.request_preview.startswith("Hello!")
    assert len(info.response_preview) == 1000
    assert info.response_preview.startswith("Hi!")
396  
397  
def test_request_response_smart_truncation_non_chat_format():
    # Non-chat request/response will be naively truncated
    @mlflow.trace
    def f(question: str) -> list[str]:
        return ["a" * 5000, "b" * 5000, "c" * 5000]

    patched_target = "mlflow.tracing.export.mlflow_v3.TracingClient.start_trace"
    with mock.patch(patched_target) as start_trace_mock:
        f("start" + "a" * 1000)

    info = start_trace_mock.call_args[0][0]
    assert len(info.request_preview) == 1000
    assert info.request_preview.startswith('{"question": "startaaa')
    assert len(info.response_preview) == 1000
    assert info.response_preview.startswith('["aaaaa')
414  
415  
def test_request_response_custom_truncation():
    """Previews set explicitly via `update_current_trace` replace the
    automatically truncated ones."""

    @mlflow.trace
    def f(messages: list[dict[str, Any]]) -> dict[str, Any]:
        mlflow.update_current_trace(
            request_preview="custom request preview",
            response_preview="custom response preview",
        )
        return {"choices": [{"message": {"role": "assistant", "content": "Hi!" * 10000}}]}

    patched_target = "mlflow.tracing.export.mlflow_v3.TracingClient.start_trace"
    with mock.patch(patched_target) as start_trace_mock:
        f([{"role": "user", "content": "Hello!" * 10000}])

    info = start_trace_mock.call_args[0][0]
    assert info.request_preview == "custom request preview"
    assert info.response_preview == "custom response preview"
433  
434  
def test_search_assessments():
    """Exercise `Trace.search_assessments` filtering by name, span_id, type,
    and the `all` flag against a fixed set of four assessments."""
    assessments = [
        # [0]: human feedback that declares it overrides another assessment.
        Feedback(
            trace_id="trace_id",
            name="relevance",
            value=False,
            source=AssessmentSource(source_type="HUMAN", source_id="user_1"),
            rationale="The judge is wrong",
            span_id=None,
            overrides="2",
        ),
        # [1]: marked valid=False — excluded by default search (see asserts).
        Feedback(
            trace_id="trace_id",
            name="relevance",
            value=True,
            source=AssessmentSource(source_type="LLM_JUDGE", source_id="databricks"),
            span_id=None,
            valid=False,
        ),
        # [2]: span-scoped feedback.
        Feedback(
            trace_id="trace_id",
            name="relevance",
            value=True,
            source=AssessmentSource(source_type="LLM_JUDGE", source_id="databricks"),
            span_id="123",
        ),
        # [3]: the only Expectation-typed assessment.
        Expectation(
            trace_id="trace_id",
            name="guidelines",
            value="The response should be concise and to the point.",
            source=AssessmentSource(source_type="LLM_JUDGE", source_id="databricks"),
            span_id="123",
        ),
    ]
    trace_info = TraceInfo(
        trace_id="trace_id",
        client_request_id="client_request_id",
        trace_location=TraceLocation.from_experiment_id("123"),
        request_preview="request",
        response_preview="response",
        request_time=1234567890,
        execution_duration=100,
        assessments=assessments,
        state=TraceState.OK,
    )
    trace = Trace(
        info=trace_info,
        data=TraceData(
            spans=[],
        ),
    )

    # Default search drops the valid=False entry; all=True returns everything.
    assert trace.search_assessments() == [assessments[0], assessments[2], assessments[3]]
    assert trace.search_assessments(all=True) == assessments
    assert trace.search_assessments("relevance") == [assessments[0], assessments[2]]
    assert trace.search_assessments("relevance", all=True) == assessments[:3]
    assert trace.search_assessments(span_id="123") == [assessments[2], assessments[3]]
    assert trace.search_assessments(span_id="123", name="relevance") == [assessments[2]]
    assert trace.search_assessments(type="expectation") == [assessments[3]]
494  
495  
def test_trace_to_and_from_proto():
    """A two-span trace must survive a protobuf round trip unchanged."""

    @mlflow.trace
    def invoke(x):
        return x + 1

    @mlflow.trace
    def test(x):
        return invoke(x)

    test(1)
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())

    proto_trace = trace.to_proto()
    assert proto_trace.trace_info.trace_id == trace.info.request_id
    assert proto_trace.trace_info.trace_location == trace.info.trace_location.to_proto()
    # Parent span first, child second.
    assert [span.name for span in proto_trace.spans] == ["test", "invoke"]

    assert Trace.from_proto(proto_trace).to_dict() == trace.to_dict()
516  
517  
def test_trace_from_dict_load_old_trace():
    """Dicts persisted by an older MLflow (ISO-8601 `request_time`, base64
    span/trace ids) must still load correctly via `Trace.from_dict`."""
    trace_dict = {
        "info": {
            "trace_id": "tr-ee17184669c265ffdcf9299b36f6dccc",
            "trace_location": {
                "type": "MLFLOW_EXPERIMENT",
                "mlflow_experiment": {"experiment_id": "0"},
            },
            # Old format stores the request time as an ISO-8601 string.
            "request_time": "2025-10-22T04:14:54.524Z",
            "state": "OK",
            "trace_metadata": {
                "mlflow.trace_schema.version": "3",
                "mlflow.traceInputs": '"abc"',
                "mlflow.source.type": "LOCAL",
                "mlflow.source.git.branch": "branch-3.4",
                "mlflow.source.name": "a.py",
                "mlflow.source.git.commit": "78d075062b120597050bf2b3839a426feea5ea4c",
                "mlflow.user": "serena.ruan",
                "mlflow.traceOutputs": '"def"',
                "mlflow.source.git.repoURL": "git@github.com:mlflow/mlflow.git",
                "mlflow.trace.sizeBytes": "1226",
            },
            "tags": {
                "mlflow.artifactLocation": "mlflow-artifacts:/0/traces",
                "mlflow.traceName": "test",
            },
            "request_preview": '"abc"',
            "response_preview": '"def"',
            "execution_duration_ms": 60,
        },
        "data": {
            "spans": [
                {
                    # Old format stores span/trace ids base64-encoded.
                    "trace_id": "7hcYRmnCZf/c+SmbNvbczA==",
                    "span_id": "3ElmHER9IVU=",
                    "trace_state": "",
                    "parent_span_id": "",
                    "name": "test",
                    "start_time_unix_nano": 1761106494524157000,
                    "end_time_unix_nano": 1761106494584860000,
                    "attributes": {
                        "mlflow.spanOutputs": '"def"',
                        "mlflow.spanType": '"UNKNOWN"',
                        "mlflow.spanInputs": '"abc"',
                        "mlflow.traceRequestId": '"tr-ee17184669c265ffdcf9299b36f6dccc"',
                        "test": '"test"',
                    },
                    "status": {"message": "", "code": "STATUS_CODE_OK"},
                }
            ]
        },
    }
    trace = Trace.from_dict(trace_dict)
    assert trace.info.trace_id == "tr-ee17184669c265ffdcf9299b36f6dccc"
    # The ISO timestamp above parses to this epoch-milliseconds value.
    assert trace.info.request_time == 1761106494524
    assert trace.info.execution_duration == 60
    assert trace.info.trace_location == TraceLocation.from_experiment_id("0")
    assert len(trace.data.spans) == 1
    assert trace.data.spans[0].name == "test"
    assert trace.data.spans[0].inputs == "abc"
    assert trace.data.spans[0].outputs == "def"
    assert trace.data.spans[0].start_time_ns == 1761106494524157000
    assert trace.data.spans[0].end_time_ns == 1761106494584860000