# tests/tracing/test_fluent.py
   1  import asyncio
   2  import json
   3  import os
   4  import subprocess
   5  import sys
   6  import threading
   7  import time
   8  import uuid
   9  from concurrent.futures import ThreadPoolExecutor, as_completed
  10  from dataclasses import asdict
  11  from datetime import datetime
  12  from unittest import mock
  13  
  14  import pytest
  15  from opentelemetry.sdk.trace.export import SpanExporter
  16  
  17  import mlflow
  18  from mlflow.entities import (
  19      SpanEvent,
  20      SpanStatusCode,
  21      SpanType,
  22      Trace,
  23      TraceData,
  24      TraceInfo,
  25  )
  26  from mlflow.entities.trace_location import TraceLocation, UCSchemaLocation
  27  from mlflow.entities.trace_state import TraceState
  28  from mlflow.environment_variables import MLFLOW_TRACE_SAMPLING_RATIO, MLFLOW_TRACKING_USERNAME
  29  from mlflow.exceptions import MlflowException
  30  from mlflow.store.entities.paged_list import PagedList
  31  from mlflow.store.tracking import SEARCH_TRACES_DEFAULT_MAX_RESULTS
  32  from mlflow.tracing.client import TracingClient
  33  from mlflow.tracing.constant import (
  34      TRACE_SCHEMA_VERSION_KEY,
  35      SpanAttributeKey,
  36      TraceMetadataKey,
  37      TraceTagKey,
  38  )
  39  from mlflow.tracing.destination import MlflowExperiment
  40  from mlflow.tracing.export.inference_table import pop_trace
  41  from mlflow.tracing.fluent import start_span_no_context
  42  from mlflow.tracing.provider import (
  43      _MLFLOW_TRACE_USER_DESTINATION,
  44      _get_tracer,
  45      safe_set_span_in_context,
  46      set_destination,
  47  )
  48  from mlflow.tracking.fluent import _get_experiment_id
  49  from mlflow.version import IS_TRACING_SDK_ONLY
  50  
  51  from tests.tracing.helper import (
  52      create_test_trace_info,
  53      get_traces,
  54      purge_traces,
  55      skip_when_testing_trace_sdk,
  56  )
  57  
  58  
class DefaultTestModel:
    # Synchronous model fixture. `predict` produces a three-span trace:
    # the root `predict` span, the renamed `add_one` span, and a `square`
    # span created by wrapping the undecorated method at call time.
    # Tests assert the exact span names, attributes, and timing, so the
    # structure here must not change.
    @mlflow.trace()
    def predict(self, x, y):
        z = x + y
        z = self.add_one(z)
        # Wrap an undecorated method on the fly; this creates the "square" span.
        z = mlflow.trace(self.square)(z)
        return z  # noqa: RET504

    @mlflow.trace(span_type=SpanType.LLM, name="add_one_with_custom_name", attributes={"delta": 1})
    def add_one(self, z):
        return z + 1

    def square(self, t):
        res = t**2
        # Sleep so tests can assert a minimum span/trace duration.
        time.sleep(0.1)
        return res
  75  
  76  
class DefaultAsyncTestModel:
    # Async counterpart of DefaultTestModel: same three-span trace shape
    # (`predict` root, renamed `add_one`, call-time wrapped `square`), used by
    # tests that exercise tracing of coroutine functions.
    @mlflow.trace()
    async def predict(self, x, y):
        z = x + y
        z = await self.add_one(z)
        # Wrap an undecorated coroutine method on the fly ("square" span).
        z = await mlflow.trace(self.square)(z)
        return z  # noqa: RET504

    @mlflow.trace(span_type=SpanType.LLM, name="add_one_with_custom_name", attributes={"delta": 1})
    async def add_one(self, z):
        return z + 1

    async def square(self, t):
        res = t**2
        # Blocking sleep (not asyncio.sleep) so tests can assert a minimum duration.
        time.sleep(0.1)
        return res
  93  
  94  
class StreamTestModel:
    # Generator-based model fixture. The `output_reducer` sums all yielded
    # chunks so the trace-level output is a single aggregated value. Tests
    # assert the exact chunk/event counts, so the yield structure is fixed.
    @mlflow.trace(output_reducer=lambda x: sum(x))
    def predict_stream(self, x, y):
        z = x + y
        for i in range(z):
            yield i

        # Generator with a normal func
        for i in range(z):
            yield self.square(i)

        # Nested generator
        yield from self.generate_numbers(z)

    @mlflow.trace
    def square(self, t):
        # Sleep so tests can assert a minimum trace execution time.
        time.sleep(0.1)
        return t**2

    # No output_reducer -> record the list of outputs
    @mlflow.trace
    def generate_numbers(self, z):
        for i in range(z):
            yield i
 119  
 120  
class AsyncStreamTestModel:
    # Async-generator counterpart of StreamTestModel: same chunk structure,
    # with the reducer summing all yielded values into one trace output.
    @mlflow.trace(output_reducer=lambda x: sum(x))
    async def predict_stream(self, x, y):
        z = x + y
        for i in range(z):
            yield i

        # Generator with a normal func
        for i in range(z):
            yield await self.square(i)

        # Nested generator
        async for number in self.generate_numbers(z):
            yield number

    @mlflow.trace
    async def square(self, t):
        # Async sleep so tests can assert a minimum trace execution time.
        await asyncio.sleep(0.1)
        return t**2

    # No output_reducer on the nested generator -> outputs recorded as a list.
    @mlflow.trace
    async def generate_numbers(self, z):
        for i in range(z):
            yield i
 145  
 146  
class ErroringTestModel:
    # Fixture whose traced child call always raises, so tests can verify that
    # the exception propagates and the trace is still logged with ERROR state
    # (two spans: `predict` and `some_operation_raise_error`).
    @mlflow.trace()
    def predict(self, x, y):
        return self.some_operation_raise_error(x, y)

    @mlflow.trace()
    def some_operation_raise_error(self, x, y):
        raise ValueError("Some error")
 155  
 156  
class ErroringAsyncTestModel:
    # Async counterpart of ErroringTestModel: the awaited child call always
    # raises, so tests can verify error propagation and ERROR-state traces
    # for coroutine functions.
    @mlflow.trace()
    async def predict(self, x, y):
        return await self.some_operation_raise_error(x, y)

    @mlflow.trace()
    async def some_operation_raise_error(self, x, y):
        raise ValueError("Some error")
 165  
 166  
class ErroringStreamTestModel:
    # Streaming fixture that succeeds on the first chunk (i == 0) and raises
    # on the second, so tests can check partially-consumed streams produce an
    # ERROR trace that still records the successful chunk.
    @mlflow.trace
    def predict_stream(self, x):
        for i in range(x):
            if i > 0:
                # Ensure distinct start_time_ns on Windows for deterministic span ordering
                time.sleep(0.001)
            yield self.some_operation_raise_error(i)

    @mlflow.trace
    def some_operation_raise_error(self, i):
        if i >= 1:
            raise ValueError("Some error")
        return i
 181  
 182  
 183  @pytest.fixture
 184  def mock_client():
 185      client = mock.MagicMock()
 186      with mock.patch("mlflow.tracing.fluent.TracingClient", return_value=client):
 187          yield client
 188  
 189  
@pytest.fixture
def mock_otel_trace_start_time():
    # mock the start time of a trace, ensuring the root span has
    # a smaller start time than child spans.
    # (Patches the OTel SDK clock used when a span is started.)
    with mock.patch("opentelemetry.sdk.trace.time_ns", return_value=0):
        yield
 196  
 197  
@pytest.mark.parametrize("with_active_run", [True, False])
@pytest.mark.parametrize("wrap_sync_func", [True, False])
def test_trace(wrap_sync_func, with_active_run, async_logging_enabled):
    """End-to-end check of @mlflow.trace for sync and async models, with and
    without an active run, asserting the full trace and span structure."""
    model = DefaultTestModel() if wrap_sync_func else DefaultAsyncTestModel()

    if with_active_run:
        if IS_TRACING_SDK_ONLY:
            pytest.skip("Skipping test because mlflow or mlflow-skinny is not installed.")

        with mlflow.start_run() as run:
            model.predict(2, 5) if wrap_sync_func else asyncio.run(model.predict(2, 5))
            run_id = run.info.run_id
    else:
        model.predict(2, 5) if wrap_sync_func else asyncio.run(model.predict(2, 5))

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.trace_id is not None
    assert trace.info.experiment_id == _get_experiment_id()
    assert trace.info.execution_time_ms >= 0.1 * 1e3  # at least 0.1 sec
    assert trace.info.state == TraceState.OK
    assert trace.info.request_metadata[TraceMetadataKey.INPUTS] == '{"x": 2, "y": 5}'
    assert trace.info.request_metadata[TraceMetadataKey.OUTPUTS] == "64"
    if with_active_run:
        # The trace must be associated with the run it was created under.
        assert trace.info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run_id

    assert trace.data.request == '{"x": 2, "y": 5}'
    assert trace.data.response == "64"
    assert len(trace.data.spans) == 3

    span_name_to_span = {span.name: span for span in trace.data.spans}
    root_span = span_name_to_span["predict"]
    # TODO: Trace info timestamp is not accurate because it is not adjusted to exclude the latency
    # assert root_span.start_time_ns // 1e6 == trace.info.timestamp_ms
    assert root_span.parent_id is None
    assert root_span.attributes == {
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanFunctionName": "predict",
        "mlflow.spanType": "UNKNOWN",
        "mlflow.spanInputs": {"x": 2, "y": 5},
        "mlflow.spanOutputs": 64,
    }

    child_span_1 = span_name_to_span["add_one_with_custom_name"]
    assert child_span_1.parent_id == root_span.span_id
    assert child_span_1.attributes == {
        "delta": 1,
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanFunctionName": "add_one",
        "mlflow.spanType": "LLM",
        "mlflow.spanInputs": {"z": 7},
        "mlflow.spanOutputs": 8,
    }

    child_span_2 = span_name_to_span["square"]
    assert child_span_2.parent_id == root_span.span_id
    # NOTE(review): square() sleeps 0.1s, but 0.1 * 1e6 ns is only 0.1 ms —
    # this bound looks weaker than intended (0.1 s is 1e8 ns); confirm.
    assert child_span_2.start_time_ns <= child_span_2.end_time_ns - 0.1 * 1e6
    assert child_span_2.attributes == {
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanFunctionName": "square",
        "mlflow.spanType": "UNKNOWN",
        "mlflow.spanInputs": {"t": 8},
        "mlflow.spanOutputs": 64,
    }
 266  
 267  
 268  @pytest.mark.parametrize("wrap_sync_func", [True, False])
 269  def test_trace_stream(wrap_sync_func):
 270      model = StreamTestModel() if wrap_sync_func else AsyncStreamTestModel()
 271  
 272      stream = model.predict_stream(1, 2)
 273  
 274      # Trace should not be logged until the generator is consumed
 275      assert get_traces() == []
 276      # The span should not be set to active
 277      # because the generator is not yet consumed
 278      assert mlflow.get_current_active_span() is None
 279  
 280      chunks = []
 281      if wrap_sync_func:
 282          for chunk in stream:
 283              chunks.append(chunk)
 284              # The `predict` span should not be active here.
 285              assert mlflow.get_current_active_span() is None
 286      else:
 287  
 288          async def consume_stream():
 289              async for chunk in stream:
 290                  chunks.append(chunk)
 291                  assert mlflow.get_current_active_span() is None
 292  
 293          asyncio.run(consume_stream())
 294  
 295      traces = get_traces()
 296      assert len(traces) == 1
 297      trace = traces[0]
 298      assert trace.info.trace_id is not None
 299      assert trace.info.experiment_id == _get_experiment_id()
 300      assert trace.info.execution_time_ms >= 0.1 * 1e3  # at least 0.1 sec
 301      assert trace.info.status == SpanStatusCode.OK
 302      metadata = trace.info.request_metadata
 303      assert metadata[TraceMetadataKey.INPUTS] == '{"x": 1, "y": 2}'
 304      assert metadata[TraceMetadataKey.OUTPUTS] == "11"  # sum of the outputs
 305  
 306      assert len(trace.data.spans) == 5  # 1 root span + 3 square + 1 generate_numbers
 307  
 308      root_span = trace.data.spans[0]
 309      assert root_span.name == "predict_stream"
 310      assert root_span.inputs == {"x": 1, "y": 2}
 311      assert root_span.outputs == 11
 312      assert len(root_span.events) == 9
 313      assert root_span.events[0].name == "mlflow.chunk.item.0"
 314      assert root_span.events[0].attributes == {"mlflow.chunk.value": "0"}
 315      assert root_span.events[8].name == "mlflow.chunk.item.8"
 316  
 317      # Spans for the chid 'square' function
 318      for i in range(3):
 319          assert trace.data.spans[i + 1].name == "square"
 320          assert trace.data.spans[i + 1].inputs == {"t": i}
 321          assert trace.data.spans[i + 1].outputs == i**2
 322          assert trace.data.spans[i + 1].parent_id == root_span.span_id
 323  
 324      # Span for the 'generate_numbers' function
 325      assert trace.data.spans[4].name == "generate_numbers"
 326      assert trace.data.spans[4].inputs == {"z": 3}
 327      assert trace.data.spans[4].outputs == [0, 1, 2]  # list of outputs
 328      assert len(trace.data.spans[4].events) == 3
 329  
 330  
def test_trace_with_databricks_tracking_uri(databricks_tracking_uri, monkeypatch):
    """Tracing against a Databricks tracking URI should start the trace via the
    backend store and upload the trace data exactly once."""
    monkeypatch.setenv("MLFLOW_EXPERIMENT_NAME", "test")
    monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "bob")
    monkeypatch.setattr(mlflow.tracking.context.default_context, "_get_source_name", lambda: "test")

    model = DefaultTestModel()

    # Minimal stub of the TraceInfo returned by the mocked store.
    mock_trace_info = mock.MagicMock()
    mock_trace_info.trace_id = "123"
    mock_trace_info.trace_location = mock.MagicMock()
    mock_trace_info.trace_location.uc_schema = None

    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data"
        ) as mock_upload_trace_data,
        mock.patch("mlflow.tracing.client._get_store") as mock_get_store,
    ):
        mock_get_store().start_trace.return_value = mock_trace_info
        model.predict(2, 5)
        mlflow.flush_trace_async_logging(terminate=True)

    mock_get_store().start_trace.assert_called_once()
    mock_upload_trace_data.assert_called_once()
 355  
 356  
# NB: async logging should be no-op for model serving,
# but we test it here to make sure it doesn't break
@skip_when_testing_trace_sdk
def test_trace_in_databricks_model_serving(
    mock_databricks_serving_with_tracing_env, async_logging_enabled
):
    """Simulate a Databricks model-serving request and verify the trace is
    captured in the inference-table buffer (keyed by request ID), returned in
    the response payload, and excluded from the normal fluent trace store."""
    # Dummy flask app for prediction
    import flask

    from mlflow.pyfunc.context import Context, set_prediction_context

    app = flask.Flask(__name__)

    @app.route("/invocations", methods=["POST"])
    def predict():
        data = json.loads(flask.request.data.decode("utf-8"))
        request_id = flask.request.headers.get("X-Request-ID")

        # Run the prediction under the serving request context so the trace is
        # associated with the incoming request ID.
        with set_prediction_context(Context(request_id=request_id)):
            prediction = TestModel().predict(**data)

        # Retrieve (and remove) the trace from the inference-table buffer.
        trace = pop_trace(request_id=request_id)

        result = json.dumps(
            {
                "prediction": prediction,
                "trace": trace,
            },
            default=str,
        )
        return flask.Response(response=result, status=200, mimetype="application/json")

    class TestModel:
        @mlflow.trace()
        def predict(self, x, y):
            z = x + y
            z = self.add_one(z)
            with mlflow.start_span(name="square") as span:
                z = self.square(z)
                span.add_event(SpanEvent("event", 0, attributes={"foo": "bar"}))
            return z

        @mlflow.trace(span_type=SpanType.LLM, name="custom", attributes={"delta": 1})
        def add_one(self, z):
            return z + 1

        def square(self, t):
            return t**2

    # Mimic scoring request
    databricks_request_id = "request-12345"
    response = app.test_client().post(
        "/invocations",
        headers={"X-Request-ID": databricks_request_id},
        data=json.dumps({"x": 2, "y": 5}),
    )

    assert response.status_code == 200
    assert response.json["prediction"] == 64

    # The serialized trace returned in the response should round-trip.
    trace_dict = response.json["trace"]
    trace = Trace.from_dict(trace_dict)
    assert trace.info.trace_id.startswith("tr-")
    assert trace.info.client_request_id == databricks_request_id
    assert trace.info.request_metadata[TRACE_SCHEMA_VERSION_KEY] == "3"
    assert len(trace.data.spans) == 3

    span_name_to_span = {span.name: span for span in trace.data.spans}
    root_span = span_name_to_span["predict"]
    assert isinstance(root_span._trace_id, str)
    assert isinstance(root_span.span_id, str)
    assert isinstance(root_span.start_time_ns, int)
    assert isinstance(root_span.end_time_ns, int)
    assert root_span.status.status_code.value == "OK"
    assert root_span.status.description == ""
    assert root_span.attributes == {
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": SpanType.UNKNOWN,
        "mlflow.spanFunctionName": "predict",
        "mlflow.spanInputs": {"x": 2, "y": 5},
        "mlflow.spanOutputs": 64,
    }
    assert root_span.events == []

    child_span_1 = span_name_to_span["custom"]
    assert child_span_1.parent_id == root_span.span_id
    assert child_span_1.attributes == {
        "delta": 1,
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": SpanType.LLM,
        "mlflow.spanFunctionName": "add_one",
        "mlflow.spanInputs": {"z": 7},
        "mlflow.spanOutputs": 8,
    }
    assert child_span_1.events == []

    child_span_2 = span_name_to_span["square"]
    assert child_span_2.parent_id == root_span.span_id
    assert child_span_2.attributes == {
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": SpanType.UNKNOWN,
    }
    assert asdict(child_span_2.events[0]) == {
        "name": "event",
        "timestamp": 0,
        "attributes": {"foo": "bar"},
    }

    # The trace should be removed from the buffer after being retrieved
    assert pop_trace(request_id=databricks_request_id) is None

    # In model serving, the traces should not be stored in the fluent API buffer
    traces = get_traces()
    assert len(traces) == 0
 471  
 472  
@skip_when_testing_trace_sdk
def test_trace_in_model_evaluation(monkeypatch, async_logging_enabled):
    """Traces created under an evaluation prediction context should record the
    source run and carry the evaluation request ID as a tag."""
    from mlflow.pyfunc.context import Context, set_prediction_context

    monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "bob")
    monkeypatch.setattr(mlflow.tracking.context.default_context, "_get_source_name", lambda: "test")

    class TestModel:
        @mlflow.trace()
        def predict(self, x, y):
            return x + y

    model = TestModel()

    # mock _upload_trace_data to avoid generating trace data file
    with mlflow.start_run() as run:
        run_id = run.info.run_id
        request_id_1 = "tr-eval-123"
        with set_prediction_context(Context(request_id=request_id_1, is_evaluate=True)):
            model.predict(1, 2)

        request_id_2 = "tr-eval-456"
        with set_prediction_context(Context(request_id=request_id_2, is_evaluate=True)):
            model.predict(3, 4)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    # Each trace must be retrievable by its eval request ID and linked to the run.
    trace = mlflow.get_trace(request_id_1)
    assert trace.info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run_id
    assert trace.info.tags[TraceTagKey.EVAL_REQUEST_ID] == request_id_1

    trace = mlflow.get_trace(request_id_2)
    assert trace.info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run_id
    assert trace.info.tags[TraceTagKey.EVAL_REQUEST_ID] == request_id_2
 508  
 509  
 510  @pytest.mark.parametrize("sync", [True, False])
 511  def test_trace_handle_exception_during_prediction(sync):
 512      # This test is to make sure that the exception raised by the main prediction
 513      # logic is raised properly and the trace is still logged.
 514      model = ErroringTestModel() if sync else ErroringAsyncTestModel()
 515  
 516      with pytest.raises(ValueError, match=r"Some error"):
 517          model.predict(2, 5) if sync else asyncio.run(model.predict(2, 5))
 518  
 519      # Trace should be logged even if the function fails, with status code ERROR
 520      trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
 521      assert trace.info.trace_id is not None
 522      assert trace.info.state == TraceState.ERROR
 523      assert trace.info.request_metadata[TraceMetadataKey.INPUTS] == '{"x": 2, "y": 5}'
 524      assert trace.info.request_metadata[TraceMetadataKey.OUTPUTS] == ""
 525  
 526      assert trace.data.request == '{"x": 2, "y": 5}'
 527      assert trace.data.response is None
 528      assert len(trace.data.spans) == 2
 529  
 530  
 531  def test_trace_handle_exception_during_streaming():
 532      model = ErroringStreamTestModel()
 533  
 534      stream = model.predict_stream(2)
 535  
 536      chunks = []
 537      with pytest.raises(ValueError, match=r"Some error"):  # noqa: PT012
 538          for chunk in stream:
 539              chunks.append(chunk)
 540  
 541      # The test model raises an error after the first chunk
 542      assert len(chunks) == 1
 543  
 544      traces = get_traces()
 545      assert len(traces) == 1
 546      trace = traces[0]
 547      assert trace.info.state == TraceState.ERROR
 548      assert trace.info.request_metadata[TraceMetadataKey.INPUTS] == '{"x": 2}'
 549  
 550      # The test model is expected to produce three spans
 551      # 1. Root span (error - inherited from the child)
 552      # 2. First chunk span (OK)
 553      # 3. Second chunk span (error)
 554      spans = trace.data.spans
 555      assert len(spans) == 3
 556      assert spans[0].name == "predict_stream"
 557      assert spans[0].status.status_code == SpanStatusCode.ERROR
 558      assert spans[1].name == "some_operation_raise_error"
 559      assert spans[1].status.status_code == SpanStatusCode.OK
 560      assert spans[2].name == "some_operation_raise_error"
 561      assert spans[2].status.status_code == SpanStatusCode.ERROR
 562  
 563      # One chunk event + one exception event
 564      assert len(spans[0].events) == 2
 565      assert spans[0].events[0].name == "mlflow.chunk.item.0"
 566      assert spans[0].events[1].name == "exception"
 567  
 568  
@pytest.mark.parametrize(
    "model",
    [
        DefaultTestModel(),
        DefaultAsyncTestModel(),
        StreamTestModel(),
        AsyncStreamTestModel(),
    ],
)
def test_trace_ignore_exception(monkeypatch, model):
    # This test is to make sure that the main prediction logic is not affected
    # by the exception raised by the tracing logic.
    def _call_model_and_assert_output(model):
        # Invoke the model according to its type and verify the prediction
        # result is correct even when tracing machinery fails.
        if isinstance(model, DefaultTestModel):
            output = model.predict(2, 5)
            assert output == 64
        elif isinstance(model, DefaultAsyncTestModel):
            output = asyncio.run(model.predict(2, 5))
            assert output == 64
        elif isinstance(model, StreamTestModel):
            stream = model.predict_stream(2, 5)
            assert len(list(stream)) == 21
        elif isinstance(model, AsyncStreamTestModel):
            astream = model.predict_stream(2, 5)

            async def _consume_stream():
                return [chunk async for chunk in astream]

            stream = asyncio.run(_consume_stream())
            assert len(list(stream)) == 21
        else:
            raise ValueError("Unknown model type")

    # Exception during starting span: trace should not be logged.
    with mock.patch("mlflow.tracing.provider._get_tracer", side_effect=ValueError("Some error")):
        _call_model_and_assert_output(model)

    assert get_traces() == []

    # Exception during ending span: trace should not be logged.
    tracer = _get_tracer(__name__)

    def _always_fail(*args, **kwargs):
        raise ValueError("Some error")

    monkeypatch.setattr(tracer.span_processor, "on_end", _always_fail)
    _call_model_and_assert_output(model)
    assert len(get_traces()) == 0
 617  
 618  
 619  def test_trace_skip_resolving_unrelated_tags_to_traces():
 620      with mock.patch("mlflow.tracking.context.registry.DatabricksRepoRunContext") as mock_context:
 621          mock_context.in_context.return_value = ["unrelated tags"]
 622  
 623          model = DefaultTestModel()
 624          model.predict(2, 5)
 625  
 626      trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
 627      assert "unrelated tags" not in trace.info.tags
 628  
 629  
# Tracing SDK doesn't have `create_experiment` support
@skip_when_testing_trace_sdk
def test_trace_with_experiment_id():
    """`trace_destination=MlflowExperiment(...)` routes the whole trace
    (including child spans) to that experiment instead of the active one."""
    exp_1 = mlflow.create_experiment("exp_1")
    exp_2 = mlflow.set_experiment("exp_2").experiment_id  # active experiment

    @mlflow.trace(trace_destination=MlflowExperiment(exp_1))
    def predict_1():
        with mlflow.start_span(name="child_span"):
            return

    @mlflow.trace()
    def predict_2():
        pass

    # Explicit destination wins over the active experiment.
    predict_1()
    traces = get_traces(experiment_id=exp_1)
    assert len(traces) == 1
    assert traces[0].info.experiment_id == exp_1
    assert len(traces[0].data.spans) == 2
    assert get_traces(experiment_id=exp_2) == []

    # Without a destination, the trace goes to the active experiment.
    predict_2()
    traces = get_traces(experiment_id=exp_2)
    assert len(traces) == 1
    assert traces[0].info.experiment_id == exp_2
 656  
 657  
# Tracing SDK doesn't have `create_experiment` support
@skip_when_testing_trace_sdk
def test_trace_with_experiment_id_issue_warning_when_not_root_span():
    """Specifying a trace destination on a non-root span is ignored and should
    emit exactly one warning naming the offending span."""
    exp_1 = mlflow.create_experiment("exp_1")

    @mlflow.trace(trace_destination=MlflowExperiment(exp_1))
    def predict_1():
        return predict_2()

    @mlflow.trace(trace_destination=MlflowExperiment(exp_1))
    def predict_2():
        return

    with mock.patch("mlflow.tracing.provider._logger") as mock_logger:
        predict_1()

    # Only the nested `predict_2` span triggers the warning.
    assert mock_logger.warning.call_count == 1
    assert mock_logger.warning.call_args[0][0] == (
        "The `experiment_id` parameter can only be used for root spans, but the span "
        "`predict_2` is not a root span. The specified value `1` will be ignored."
    )
 679  
 680  
def test_start_span_context_manager(async_logging_enabled):
    """Exercise nested `mlflow.start_span` context managers and assert the
    resulting trace structure, span parentage, and attributes."""
    datetime_now = datetime.now()

    class TestModel:
        def predict(self, x, y):
            with mlflow.start_span(name="root_span") as root_span:
                root_span.set_inputs({"x": x, "y": y})
                z = x + y

                with mlflow.start_span(name="child_span", span_type=SpanType.LLM) as child_span:
                    child_span.set_inputs(z)
                    z = z + 2
                    child_span.set_outputs(z)
                    child_span.set_attributes({"delta": 2, "time": datetime_now})

                # Ensure deterministic span order on Windows by forcing different start_time_ns
                time.sleep(0.001)
                res = self.square(z)
                root_span.set_outputs(res)
            return res

        def square(self, t):
            with mlflow.start_span(name="child_span") as span:
                span.set_inputs({"t": t})
                res = t**2
                time.sleep(0.1)
                span.set_outputs(res)
                return res

    model = TestModel()
    model.predict(1, 2)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.trace_id is not None
    assert trace.info.experiment_id == _get_experiment_id()
    assert trace.info.execution_time_ms >= 0.1 * 1e3  # at least 0.1 sec
    assert trace.info.state == TraceState.OK
    assert trace.info.request_metadata[TraceMetadataKey.INPUTS] == '{"x": 1, "y": 2}'
    assert trace.info.request_metadata[TraceMetadataKey.OUTPUTS] == "25"

    assert trace.data.request == '{"x": 1, "y": 2}'
    assert trace.data.response == "25"
    assert len(trace.data.spans) == 3

    root_span = trace.data.spans[0]
    assert root_span.name == "root_span"
    assert root_span.parent_id is None
    assert root_span.attributes == {
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": "UNKNOWN",
        "mlflow.spanInputs": {"x": 1, "y": 2},
        "mlflow.spanOutputs": 25,
    }

    # Non-JSON-serializable attribute values (datetime) are stringified.
    child_span_1 = trace.data.spans[1]
    assert child_span_1.name == "child_span"
    assert child_span_1.parent_id == root_span.span_id
    assert child_span_1.attributes == {
        "delta": 2,
        "time": str(datetime_now),
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": "LLM",
        "mlflow.spanInputs": 3,
        "mlflow.spanOutputs": 5,
    }

    child_span_2 = trace.data.spans[2]
    assert child_span_2.name == "child_span"
    assert child_span_2.parent_id == root_span.span_id
    assert child_span_2.attributes == {
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": "UNKNOWN",
        "mlflow.spanInputs": {"t": 5},
        "mlflow.spanOutputs": 25,
    }
    # NOTE(review): square() sleeps 0.1s, but 0.1 * 1e6 ns is only 0.1 ms —
    # this bound looks weaker than intended (0.1 s is 1e8 ns); confirm.
    assert child_span_2.start_time_ns <= child_span_2.end_time_ns - 0.1 * 1e6
 762  
 763  
def test_start_span_context_manager_with_imperative_apis(async_logging_enabled):
    # This test is to make sure that the spans created with fluent APIs and imperative APIs
    # (via MLflow client) are correctly linked together. This usage is not recommended but
    # should be supported for the advanced use cases like using LangChain callbacks as a
    # part of broader tracing.
    class TestModel:
        def predict(self, x, y):
            with mlflow.start_span(name="root_span") as root_span:
                root_span.set_inputs({"x": x, "y": y})
                z = x + y

                # Imperative span creation, explicitly parented under the
                # fluent root span; must be ended manually.
                child_span = start_span_no_context(
                    name="child_span_1",
                    span_type=SpanType.LLM,
                    parent_span=root_span,
                )
                child_span.set_inputs(z)

                z = z + 2
                time.sleep(0.1)

                child_span.set_outputs(z)
                child_span.set_attributes({"delta": 2})
                child_span.end()

                root_span.set_outputs(z)
            return z

    model = TestModel()
    model.predict(1, 2)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.trace_id is not None
    assert trace.info.experiment_id == _get_experiment_id()
    assert trace.info.execution_time_ms >= 0.1 * 1e3  # at least 0.1 sec
    assert trace.info.state == TraceState.OK
    assert trace.info.request_metadata[TraceMetadataKey.INPUTS] == '{"x": 1, "y": 2}'
    assert trace.info.request_metadata[TraceMetadataKey.OUTPUTS] == "5"

    assert trace.data.request == '{"x": 1, "y": 2}'
    assert trace.data.response == "5"
    assert len(trace.data.spans) == 2

    span_name_to_span = {span.name: span for span in trace.data.spans}
    root_span = span_name_to_span["root_span"]
    assert root_span.parent_id is None
    assert root_span.attributes == {
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": "UNKNOWN",
        "mlflow.spanInputs": {"x": 1, "y": 2},
        "mlflow.spanOutputs": 5,
    }

    # The imperatively-created span must be linked under the fluent root span.
    child_span_1 = span_name_to_span["child_span_1"]
    assert child_span_1.parent_id == root_span.span_id
    assert child_span_1.attributes == {
        "delta": 2,
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": "LLM",
        "mlflow.spanInputs": 3,
        "mlflow.spanOutputs": 5,
    }
 831  
 832  
def test_mlflow_trace_isolated_from_other_otel_processors():
    """MLflow tracing must not pick up spans from a user-installed OpenTelemetry
    provider, and vice versa."""
    # Set up non-MLFlow tracer
    import opentelemetry.sdk.trace as trace_sdk
    from opentelemetry import trace

    class MockOtelExporter(trace_sdk.export.SpanExporter):
        # Collects exported spans in memory so the test can inspect them.
        def __init__(self):
            self.exported_spans = []

        def export(self, spans):
            self.exported_spans.extend(spans)

    other_exporter = MockOtelExporter()
    provider = trace_sdk.TracerProvider()
    processor = trace_sdk.export.SimpleSpanProcessor(other_exporter)
    provider.add_span_processor(processor)
    # Install the non-MLflow provider as the global OTel tracer provider.
    trace.set_tracer_provider(provider)

    # Create MLflow trace
    with mlflow.start_span(name="mlflow_span"):
        pass

    # Create non-MLflow trace
    tracer = trace.get_tracer(__name__)
    with tracer.start_as_current_span("non_mlflow_span"):
        pass

    # MLflow only processes spans created with MLflow APIs
    assert len(get_traces()) == 1
    assert (
        mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True).data.spans[0].name
        == "mlflow_span"
    )

    # Other spans are processed by the other processor
    assert len(other_exporter.exported_spans) == 1
    assert other_exporter.exported_spans[0].name == "non_mlflow_span"
 870  
 871  
 872  def test_get_trace():
 873      with mock.patch("mlflow.tracing.display.get_display_handler") as mock_get_display_handler:
 874          model = DefaultTestModel()
 875          model.predict(2, 5)
 876  
 877          trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
 878          trace_id = trace.info.trace_id
 879          mock_get_display_handler.reset_mock()
 880  
 881          # Fetch trace from in-memory buffer
 882          trace_in_memory = mlflow.get_trace(trace_id)
 883          assert trace.info.trace_id == trace_in_memory.info.trace_id
 884          mock_get_display_handler.assert_not_called()
 885  
 886          # Fetch trace from backend
 887          trace_from_backend = mlflow.get_trace(trace.info.trace_id)
 888          assert trace.info.trace_id == trace_from_backend.info.trace_id
 889          mock_get_display_handler.assert_not_called()
 890  
 891      # If not found, return None with warning
 892      with mock.patch("mlflow.tracing.fluent._logger") as mock_logger:
 893          assert mlflow.get_trace("not_found") is None
 894          mock_logger.warning.assert_called_once()
 895  
 896  
 897  def test_test_search_traces_empty(mock_client):
 898      mock_client.search_traces.return_value = PagedList([], token=None)
 899  
 900      traces = mlflow.search_traces()
 901      assert len(traces) == 0
 902  
 903      if not IS_TRACING_SDK_ONLY:
 904          default_columns = Trace.pandas_dataframe_columns()
 905          assert traces.columns.tolist() == default_columns
 906  
 907          traces = mlflow.search_traces(extract_fields=["foo.inputs.bar"])
 908          assert traces.columns.tolist() == [*default_columns, "foo.inputs.bar"]
 909  
 910          mock_client.search_traces.assert_called()
 911  
 912  
 913  @pytest.mark.parametrize("return_type", ["pandas", "list"])
 914  def test_search_traces(return_type, mock_client):
 915      if return_type == "pandas" and IS_TRACING_SDK_ONLY:
 916          pytest.skip("Skipping test because mlflow or mlflow-skinny is not installed.")
 917  
 918      mock_client.search_traces.return_value = PagedList(
 919          [
 920              Trace(
 921                  info=create_test_trace_info(f"tr-{i}"),
 922                  data=TraceData([]),
 923              )
 924              for i in range(10)
 925          ],
 926          token=None,
 927      )
 928  
 929      traces = mlflow.search_traces(
 930          locations=["1"],
 931          filter_string="name = 'foo'",
 932          max_results=10,
 933          order_by=["timestamp DESC"],
 934          return_type=return_type,
 935      )
 936  
 937      if return_type == "pandas":
 938          import pandas as pd
 939  
 940          assert isinstance(traces, pd.DataFrame)
 941      else:
 942          assert isinstance(traces, list)
 943          assert all(isinstance(trace, Trace) for trace in traces)
 944  
 945      assert len(traces) == 10
 946      mock_client.search_traces.assert_called_once_with(
 947          experiment_ids=None,
 948          run_id=None,
 949          filter_string="name = 'foo'",
 950          max_results=10,
 951          order_by=["timestamp DESC"],
 952          page_token=None,
 953          model_id=None,
 954          include_spans=True,
 955          locations=["1"],
 956      )
 957  
 958  
 959  def test_search_traces_invalid_return_types(mock_client):
 960      with pytest.raises(MlflowException, match=r"Invalid return type"):
 961          mlflow.search_traces(return_type="invalid")
 962  
 963      with pytest.raises(MlflowException, match=r"The `extract_fields`"):
 964          mlflow.search_traces(extract_fields=["foo.inputs.bar"], return_type="list")
 965  
 966  
 967  def test_search_traces_validates_experiment_ids_type():
 968      with pytest.raises(MlflowException, match=r"locations must be a list"):
 969          mlflow.search_traces(locations=4)
 970  
 971      with pytest.raises(MlflowException, match=r"locations must be a list"):
 972          mlflow.search_traces(locations="4")
 973  
 974  
 975  def test_search_traces_with_pagination(mock_client):
 976      traces = [
 977          Trace(
 978              info=create_test_trace_info(f"tr-{i}"),
 979              data=TraceData([]),
 980          )
 981          for i in range(30)
 982      ]
 983  
 984      mock_client.search_traces.side_effect = [
 985          PagedList(traces[:10], token="token-1"),
 986          PagedList(traces[10:20], token="token-2"),
 987          PagedList(traces[20:], token=None),
 988      ]
 989  
 990      traces = mlflow.search_traces(locations=["1"])
 991  
 992      assert len(traces) == 30
 993      common_args = {
 994          "experiment_ids": None,
 995          "run_id": None,
 996          "max_results": SEARCH_TRACES_DEFAULT_MAX_RESULTS,
 997          "filter_string": None,
 998          "order_by": None,
 999          "include_spans": True,
1000          "model_id": None,
1001          "locations": ["1"],
1002      }
1003      mock_client.search_traces.assert_has_calls([
1004          mock.call(**common_args, page_token=None),
1005          mock.call(**common_args, page_token="token-1"),
1006          mock.call(**common_args, page_token="token-2"),
1007      ])
1008  
1009  
def test_search_traces_with_default_experiment_id(mock_client):
    """When no locations are given, the active experiment id is used as the
    search location."""
    mock_client.search_traces.return_value = PagedList([], token=None)
    with mock.patch("mlflow.tracking.fluent._get_experiment_id", return_value="123"):
        mlflow.search_traces()

    # The patched default experiment id ("123") should appear in `locations`.
    mock_client.search_traces.assert_called_once_with(
        experiment_ids=None,
        run_id=None,
        filter_string=None,
        max_results=SEARCH_TRACES_DEFAULT_MAX_RESULTS,
        order_by=None,
        page_token=None,
        model_id=None,
        include_spans=True,
        locations=["123"],
    )
1026  
1027  
@skip_when_testing_trace_sdk
def test_search_traces_yields_expected_dataframe_contents(monkeypatch):
    """Each logged trace surfaces as one DataFrame row whose columns mirror the
    corresponding Trace object's fields."""
    model = DefaultTestModel()
    expected_traces = []
    for _ in range(10):
        model.predict(2, 5)
        # Space out request times so `timestamp ASC` ordering is deterministic.
        time.sleep(0.1)

        trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
        expected_traces.append(trace)

    df = mlflow.search_traces(max_results=10, order_by=["timestamp ASC"], flush=True)
    assert df.columns.tolist() == [
        "trace_id",
        "trace",
        "client_request_id",
        "state",
        "request_time",
        "execution_duration",
        "request",
        "response",
        "trace_metadata",
        "tags",
        "spans",
        "assessments",
    ]
    # Rows come back oldest-first, matching the order the traces were created.
    for idx, trace in enumerate(expected_traces):
        assert df.iloc[idx].trace_id == trace.info.trace_id
        assert Trace.from_json(df.iloc[idx].trace).info.trace_id == trace.info.trace_id
        assert df.iloc[idx].client_request_id == trace.info.client_request_id
        assert df.iloc[idx].state == trace.info.state
        assert df.iloc[idx].request_time == trace.info.request_time
        assert df.iloc[idx].execution_duration == pytest.approx(
            trace.info.execution_duration, abs=1
        )
        assert df.iloc[idx].request == json.loads(trace.data.request)
        assert df.iloc[idx].response == json.loads(trace.data.response)
        assert df.iloc[idx].trace_metadata == trace.info.trace_metadata
        assert df.iloc[idx].spans == [s.to_dict() for s in trace.data.spans]
        assert df.iloc[idx].tags == trace.info.tags
        assert df.iloc[idx].assessments == trace.info.assessments
1069  
1070  
@skip_when_testing_trace_sdk
def test_search_traces_handles_missing_response_tags_and_metadata(mock_client):
    """Traces with no response, tags, or metadata produce nulls / empty dicts
    rather than raising."""
    bare_trace = Trace(
        info=TraceInfo(
            trace_id="5",
            trace_location=TraceLocation.from_experiment_id("test"),
            request_time=1,
            execution_duration=2,
            state=TraceState.OK,
        ),
        data=TraceData(spans=[]),
    )
    mock_client.search_traces.return_value = PagedList([bare_trace], token=None)

    df = mlflow.search_traces()
    assert df["response"].isnull().all()
    assert df["tags"].tolist() == [{}]
    assert df["trace_metadata"].tolist() == [{}]
1093  
1094  
@skip_when_testing_trace_sdk
def test_search_traces_extracts_fields_as_expected():
    """Requested span fields become their own DataFrame columns."""
    DefaultTestModel().predict(2, 5)

    requested = [
        "predict.inputs.x",
        "predict.outputs",
        "add_one_with_custom_name.inputs.z",
    ]
    df = mlflow.search_traces(extract_fields=requested, flush=True)
    assert df["predict.inputs.x"].tolist() == [2]
    assert df["predict.outputs"].tolist() == [64]
    assert df["add_one_with_custom_name.inputs.z"].tolist() == [7]
1107  
1108  
# Covers extraction when a span sets inputs but never sets outputs: the
# requested output column should come back all-null instead of raising.
@skip_when_testing_trace_sdk
def test_search_traces_with_input_and_no_output():
    """Extracting an output field from a span that never set outputs yields nulls."""
    with mlflow.start_span(name="with_input_and_no_output") as span:
        span.set_inputs({"a": 1})

    requested = [
        "with_input_and_no_output.inputs.a",
        "with_input_and_no_output.outputs",
    ]
    df = mlflow.search_traces(extract_fields=requested, flush=True)
    assert df["with_input_and_no_output.inputs.a"].tolist() == [1]
    # The outputs column is present but entirely null.
    assert df["with_input_and_no_output.outputs"].isnull().all()
1122  
1123  
@skip_when_testing_trace_sdk
def test_search_traces_with_non_dict_span_inputs_outputs():
    """Span inputs/outputs may be arbitrary JSON values, not just dicts."""
    with mlflow.start_span(name="non_dict_span") as span:
        span.set_inputs(["a", "b"])
        span.set_outputs([1, 2, 3])

    requested = ["non_dict_span.inputs", "non_dict_span.outputs", "non_dict_span.inputs.x"]
    df = mlflow.search_traces(extract_fields=requested, flush=True)
    # Whole-value extraction works; key-level extraction on a list yields null.
    assert df["non_dict_span.inputs"].tolist() == [["a", "b"]]
    assert df["non_dict_span.outputs"].tolist() == [[1, 2, 3]]
    assert df["non_dict_span.inputs.x"].isnull().all()
1137  
1138  
@skip_when_testing_trace_sdk
def test_search_traces_extract_fields_preserves_standard_columns():
    """Using extract_fields adds columns without dropping the standard ones."""
    with mlflow.start_span(name="test_span") as span:
        span.set_inputs({"x": 1})
        span.set_outputs({"y": 2})

    df = mlflow.search_traces(extract_fields=["test_span.inputs.x"], flush=True)

    # Verify standard columns still exist
    for column in ("trace_id", "spans", "tags", "request", "response"):
        assert column in df.columns

    # Verify extract field was added
    assert "test_span.inputs.x" in df.columns
    assert df["test_span.inputs.x"].tolist() == [1]
1157  
1158  
@skip_when_testing_trace_sdk
def test_search_traces_with_multiple_spans_with_same_name():
    """Field extraction behavior when multiple spans in a trace share a name."""

    class TestModel:
        @mlflow.trace(name="duplicate_name")
        def predict(self, x, y):
            z = x + y
            z = self.add_one(z)
            z = mlflow.trace(self.square)(z)
            return z  # noqa: RET504

        # This second span deliberately reuses the name "duplicate_name".
        @mlflow.trace(span_type=SpanType.LLM, name="duplicate_name", attributes={"delta": 1})
        def add_one(self, z):
            return z + 1

        def square(self, t):
            res = t**2
            time.sleep(0.1)
            return res

    model = TestModel()
    model.predict(2, 5)

    df = mlflow.search_traces(
        extract_fields=[
            "duplicate_name.inputs.x",
            "duplicate_name.inputs.y",
            "duplicate_name.inputs.z",
        ],
        flush=True,
    )
    # Duplicate spans would all be null
    assert df["duplicate_name.inputs.x"].isnull().all()
    assert df["duplicate_name.inputs.y"].isnull().all()
    # NOTE(review): `z` does resolve despite the duplicate span name — presumably
    # because only one of the duplicates has a `z` input; confirm against the
    # extraction implementation.
    assert df["duplicate_name.inputs.z"].tolist() == [7]
1193  
1194  
1195  # Test a field that doesn't exist for extraction - we shouldn't throw, just return empty column
@skip_when_testing_trace_sdk
def test_search_traces_with_non_existent_field():
    """Requesting an unknown field yields an all-null column instead of raising."""
    DefaultTestModel().predict(2, 5)

    df = mlflow.search_traces(
        extract_fields=[
            "predict.inputs.k",
            "predict.inputs.x",
            "predict.outputs",
            "add_one_with_custom_name.inputs.z",
        ],
        flush=True,
    )
    # The unknown field is null; the known fields are unaffected.
    assert df["predict.inputs.k"].isnull().all()
    assert df["predict.inputs.x"].tolist() == [2]
    assert df["predict.outputs"].tolist() == [64]
    assert df["add_one_with_custom_name.inputs.z"].tolist() == [7]
1214  
1215  
@skip_when_testing_trace_sdk
def test_search_traces_span_and_field_name_with_dot():
    """Dots inside span/field names are escaped with backticks in extract_fields;
    the resulting column names drop the backticks."""
    with mlflow.start_span(name="span.name") as span:
        span.set_inputs({"a.b": 0})
        span.set_outputs({"x.y": 1})

    df = mlflow.search_traces(
        extract_fields=[
            "`span.name`.inputs",
            "`span.name`.inputs.`a.b`",
            "`span.name`.outputs",
            "`span.name`.outputs.`x.y`",
        ],
        flush=True,
    )

    expected = {
        "span.name.inputs": [{"a.b": 0}],
        "span.name.inputs.a.b": [0],
        "span.name.outputs": [{"x.y": 1}],
        "span.name.outputs.x.y": [1],
    }
    for column, values in expected.items():
        assert df[column].tolist() == values
1236  
1237  
@skip_when_testing_trace_sdk
def test_search_traces_with_run_id():
    """Traces created inside a run can be filtered by that run's id, and run_id
    composes with filter_string unless the filter already references the source run."""

    def _create_trace(name, tags=None):
        # Create a one-span trace and attach the given tags to it.
        with mlflow.start_span(name=name) as span:
            for k, v in (tags or {}).items():
                mlflow.set_trace_tag(trace_id=span.request_id, key=k, value=v)
        return span.request_id

    def _get_names(traces):
        # Read each trace's name from its trace-name tag.
        tags = traces["tags"].tolist()
        return [tags[i].get(TraceTagKey.TRACE_NAME) for i in range(len(tags))]

    with mlflow.start_run() as run1:
        _create_trace(name="tr-1")
        _create_trace(name="tr-2", tags={"fruit": "apple"})

    with mlflow.start_run() as run2:
        _create_trace(name="tr-3")
        _create_trace(name="tr-4", tags={"fruit": "banana"})
        _create_trace(name="tr-5", tags={"fruit": "apple"})

    # Without run_id: traces from both runs are returned.
    traces = mlflow.search_traces(flush=True)
    assert set(_get_names(traces)) == {"tr-5", "tr-4", "tr-3", "tr-2", "tr-1"}

    traces = mlflow.search_traces(run_id=run1.info.run_id, flush=True)
    assert set(_get_names(traces)) == {"tr-2", "tr-1"}

    # run_id combines with a tag filter.
    traces = mlflow.search_traces(
        run_id=run2.info.run_id,
        filter_string="tag.fruit = 'apple'",
        flush=True,
    )
    assert _get_names(traces) == ["tr-5"]

    # Passing run_id while also filtering on mlflow.sourceRun is rejected.
    with pytest.raises(MlflowException, match="You cannot filter by run_id when it is already"):
        mlflow.search_traces(
            run_id=run2.info.run_id,
            filter_string="metadata.mlflow.sourceRun = '123'",
        )

    # run_id must belong to one of the requested locations.
    with pytest.raises(MlflowException, match=f"Run {run1.info.run_id} belongs to"):
        mlflow.search_traces(run_id=run1.info.run_id, locations=["1"])
1280  
1281  
@pytest.mark.parametrize(
    "extract_fields",
    [
        ["span.llm.inputs"],
        ["span.llm.inputs.x"],
        ["span.llm.outputs"],
    ],
)
@skip_when_testing_trace_sdk
def test_search_traces_invalid_extract_fields(extract_fields):
    """Field strings that don't follow the `span.inputs[.key]` / `span.outputs[.key]`
    shape are rejected with an "Invalid field type" error."""
    with pytest.raises(MlflowException, match="Invalid field type"):
        mlflow.search_traces(extract_fields=extract_fields)
1294  
1295  
def test_get_last_active_trace_id():
    """`get_last_active_trace_id` tracks the most recent trace and the returned
    trace is a copy that can be mutated without affecting the backend."""
    assert mlflow.get_last_active_trace_id() is None

    @mlflow.trace()
    def predict(x, y):
        return x + y

    for args in [(1, 2), (2, 5), (3, 6)]:
        predict(*args)

    trace_id = mlflow.get_last_active_trace_id()
    trace = mlflow.get_trace(trace_id, flush=True)
    assert trace.info.trace_id is not None
    # The last call (x=3, y=6) is the active trace.
    assert trace.data.request == '{"x": 3, "y": 6}'

    # Mutation of the copy should not affect the original trace logged in the backend
    trace.info.state = TraceState.ERROR
    assert mlflow.get_trace(trace.info.trace_id).info.state == TraceState.OK
1316  
1317  
def test_get_last_active_trace_thread_local():
    """With `thread_local=True`, each worker thread only sees its own last trace."""
    assert mlflow.get_last_active_trace_id() is None

    def run(worker_id):
        @mlflow.trace(name=f"predict_{worker_id}")
        def predict(x, y):
            return x + y

        predict(1, 2)

        # Thread-local lookup: only this worker's trace is visible.
        return mlflow.get_last_active_trace_id(thread_local=True)

    with ThreadPoolExecutor(
        max_workers=4, thread_name_prefix="test-tracing-fluent-last-active"
    ) as executor:
        pending = [executor.submit(run, i) for i in range(10)]
        trace_ids = [f.result() for f in pending]

    assert len(trace_ids) == 10
    for i, trace_id in enumerate(trace_ids):
        trace = mlflow.get_trace(trace_id, flush=True)
        assert trace.info.state == TraceState.OK
        assert trace.data.spans[0].name == f"predict_{i}"
1341  
1342  
def test_trace_with_classmethod():
    """@mlflow.trace applied above @classmethod still records inputs/outputs."""

    class TestModel:
        @mlflow.trace
        @classmethod
        def predict(cls, x, y):
            return x + y

    # Invoke the traced classmethod and check the return value.
    assert TestModel.predict(1, 2) == 3

    # A trace must have been recorded for the call.
    trace_id = mlflow.get_last_active_trace_id()
    assert trace_id is not None

    trace = mlflow.get_trace(trace_id, flush=True)
    assert trace is not None
    assert len(trace.data.spans) > 0

    # The root span captures the function name, inputs, and output.
    root_span = trace.data.spans[0]
    assert root_span.name == "predict"
    assert root_span.inputs == {"x": 1, "y": 2}
    assert root_span.outputs == 3
1367  
1368  
def test_trace_with_classmethod_order_reversed():
    """@classmethod applied above @mlflow.trace still records inputs/outputs."""

    class TestModel:
        @classmethod
        @mlflow.trace
        def predict(cls, x, y):
            return x + y

    # Invoke the traced classmethod and check the return value.
    assert TestModel.predict(1, 2) == 3

    # A trace must have been recorded for the call.
    trace_id = mlflow.get_last_active_trace_id()
    assert trace_id is not None

    trace = mlflow.get_trace(trace_id, flush=True)
    assert trace is not None
    assert len(trace.data.spans) > 0

    # The root span captures the function name, inputs, and output.
    root_span = trace.data.spans[0]
    assert root_span.name == "predict"
    assert root_span.inputs == {"x": 1, "y": 2}
    assert root_span.outputs == 3
1393  
1394  
def test_trace_with_staticmethod():
    """@mlflow.trace applied above @staticmethod still records inputs/outputs."""

    class TestModel:
        @mlflow.trace
        @staticmethod
        def predict(x, y):
            return x + y

    # Invoke the traced staticmethod and check the return value.
    assert TestModel.predict(1, 2) == 3

    # A trace must have been recorded for the call.
    trace_id = mlflow.get_last_active_trace_id()
    assert trace_id is not None

    trace = mlflow.get_trace(trace_id, flush=True)
    assert trace is not None
    assert len(trace.data.spans) > 0

    # The root span captures the function name, inputs, and output.
    root_span = trace.data.spans[0]
    assert root_span.name == "predict"
    assert root_span.inputs == {"x": 1, "y": 2}
    assert root_span.outputs == 3
1419  
1420  
def test_trace_with_staticmethod_order_reversed():
    """@staticmethod applied above @mlflow.trace still records inputs/outputs."""

    class TestModel:
        @staticmethod
        @mlflow.trace
        def predict(x, y):
            return x + y

    # Invoke the traced staticmethod and check the return value.
    assert TestModel.predict(1, 2) == 3

    # A trace must have been recorded for the call.
    trace_id = mlflow.get_last_active_trace_id()
    assert trace_id is not None

    trace = mlflow.get_trace(trace_id, flush=True)
    assert trace is not None
    assert len(trace.data.spans) > 0

    # The root span captures the function name, inputs, and output.
    root_span = trace.data.spans[0]
    assert root_span.name == "predict"
    assert root_span.inputs == {"x": 1, "y": 2}
    assert root_span.outputs == 3
1445  
1446  
def test_update_current_trace():
    """`update_current_trace` called at any nesting depth attaches tags to the
    single enclosing trace; later calls can overwrite earlier tag values."""

    @mlflow.trace(name="root_function")
    def f(x):
        mlflow.update_current_trace(tags={"fruit": "apple", "animal": "dog"})
        return g(x) + 1

    @mlflow.trace(name="level_1_function")
    def g(y):
        with mlflow.start_span(name="level_2_span"):
            # Overwrites "fruit" set by the root call and adds "vegetable".
            mlflow.update_current_trace(tags={"fruit": "orange", "vegetable": "carrot"})
            return h(y) * 2

    @mlflow.trace(name="level_3_function")
    def h(z):
        with mlflow.start_span(name="level_4_span"):
            with mlflow.start_span(name="level_5_span"):
                mlflow.update_current_trace(tags={"depth": "deep", "level": "5"})
                return z + 10

    f(1)

    # Union of all three update calls, with the level-2 "fruit" winning.
    expected_tags = {
        "animal": "dog",
        "fruit": "orange",
        "vegetable": "carrot",
        "depth": "deep",
        "level": "5",
    }

    # Validate in-memory trace
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    assert trace.info.state == TraceState.OK
    tags = {k: v for k, v in trace.info.tags.items() if not k.startswith("mlflow.")}
    assert tags == expected_tags

    # Validate backend trace
    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.state == TraceState.OK
    tags = {k: v for k, v in traces[0].info.tags.items() if not k.startswith("mlflow.")}
    assert tags == expected_tags

    # Verify trace can be searched by span names (only when database backend is available)
    if not IS_TRACING_SDK_ONLY:
        trace_by_root_span = mlflow.search_traces(
            filter_string='span.name = "root_function"', return_type="list", flush=True
        )
        assert len(trace_by_root_span) == 1

        trace_by_level_2_span = mlflow.search_traces(
            filter_string='span.name = "level_2_span"', return_type="list", flush=True
        )
        assert len(trace_by_level_2_span) == 1

        trace_by_level_5_span = mlflow.search_traces(
            filter_string='span.name = "level_5_span"', return_type="list", flush=True
        )
        assert len(trace_by_level_5_span) == 1

        # All searches should return the same trace
        assert trace_by_root_span[0].info.request_id == trace.info.request_id
        assert trace_by_level_2_span[0].info.request_id == trace.info.request_id
        assert trace_by_level_5_span[0].info.request_id == trace.info.request_id
1510  
1511  
def test_update_current_trace_with_client_request_id():
    """Tags and client_request_id can be set together or independently while a
    span is still active."""
    from mlflow.tracing.trace_manager import InMemoryTraceManager

    trace_manager = InMemoryTraceManager.get_instance()

    def _user_tags(info):
        # Drop the reserved mlflow.* tags so only user-set tags remain.
        return {k: v for k, v in info.tags.items() if not k.startswith("mlflow.")}

    # Both tags and client_request_id, updated during span execution.
    with mlflow.start_span("test_span") as span:
        mlflow.update_current_trace(tags={"operation": "test"}, client_request_id="req-12345")

        with trace_manager.get_trace(span.trace_id) as trace:
            assert trace.info.client_request_id == "req-12345"
            assert _user_tags(trace.info)["operation"] == "test"

    # Tags only: client_request_id stays unset.
    with mlflow.start_span("test_span_2") as span:
        mlflow.update_current_trace(tags={"operation": "tags_only"})

        with trace_manager.get_trace(span.trace_id) as trace:
            assert trace.info.client_request_id is None
            assert _user_tags(trace.info)["operation"] == "tags_only"

    # client_request_id only.
    with mlflow.start_span("test_span_3") as span:
        mlflow.update_current_trace(client_request_id="req-67890")

        with trace_manager.get_trace(span.trace_id) as trace:
            assert trace.info.client_request_id == "req-67890"
1544  
1545  
def test_update_current_trace_client_request_id_overwrites():
    """A second `update_current_trace` call replaces a previously set client_request_id."""
    from mlflow.tracing.trace_manager import InMemoryTraceManager

    with mlflow.start_span("overwrite_test") as span:
        # Set an initial id, then replace it.
        mlflow.update_current_trace(client_request_id="req-initial")
        mlflow.update_current_trace(client_request_id="req-updated")

        # The most recent value wins.
        with InMemoryTraceManager.get_instance().get_trace(span.trace_id) as trace:
            assert trace.info.client_request_id == "req-updated"
1561  
1562  
def test_update_current_trace_client_request_id_stringification():
    """Non-string client_request_id values are stringified; None leaves it unset."""
    from mlflow.tracing.trace_manager import InMemoryTraceManager

    test_cases = [
        (123, "123"),
        (45.67, "45.67"),
        (True, "True"),
        (False, "False"),
        (None, None),  # None should remain None
        (["list", "value"], "['list', 'value']"),
        ({"dict": "value"}, "{'dict': 'value'}"),
    ]

    for input_value, expected_output in test_cases:
        with mlflow.start_span(f"stringification_test_{input_value}") as span:
            mlflow.update_current_trace(client_request_id=input_value)
            trace_manager = InMemoryTraceManager.get_instance()
            with trace_manager.get_trace(span.trace_id) as trace:
                if expected_output is None:
                    # None should not update the client_request_id
                    assert trace.info.client_request_id is None
                else:
                    assert trace.info.client_request_id == expected_output
                    assert isinstance(trace.info.client_request_id, str)
1590  
1591  
def test_update_current_trace_with_metadata():
    """Metadata passed to `update_current_trace` is stored on the trace, with
    non-string values stringified."""

    @mlflow.trace
    def f():
        mlflow.update_current_trace(
            metadata={
                "mlflow.source.name": "inference.py",
                "mlflow.source.git.commit": "1234567890",
                "mlflow.source.git.repoURL": "https://github.com/mlflow/mlflow",
                "non-string-metadata": 123,
            },
        )

    f()

    expected_metadata = {
        "mlflow.source.name": "inference.py",
        "mlflow.source.git.commit": "1234567890",
        "mlflow.source.git.repoURL": "https://github.com/mlflow/mlflow",
        "non-string-metadata": "123",  # Should be stringified
    }

    # Validate in-memory trace
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    for k, v in expected_metadata.items():
        assert trace.info.trace_metadata[k] == v

    # Validate backend trace
    traces = get_traces()
    assert len(traces) == 1
    # Use `state` (not the legacy `status` alias) for consistency with the
    # other assertions in this file.
    assert traces[0].info.state == TraceState.OK
    for k, v in expected_metadata.items():
        assert traces[0].info.trace_metadata[k] == v
1624  
1625  
@skip_when_testing_trace_sdk
def test_update_current_trace_with_model_id():
    """model_id is recorded under the MODEL_ID trace-metadata key."""
    with mlflow.start_span("test_span"):
        mlflow.update_current_trace(model_id="model-123")

    logged_trace = get_traces()[0]
    assert logged_trace.info.trace_metadata[TraceMetadataKey.MODEL_ID] == "model-123"
1633  
1634  
@skip_when_testing_trace_sdk
def test_update_current_trace_should_not_raise_during_model_logging():
    """
    Tracing is disabled while model logging. When the model includes
    `update_current_trace` call, it should be no-op.
    """

    class MyModel(mlflow.pyfunc.PythonModel):
        @mlflow.trace
        def predict(self, model_inputs):
            mlflow.update_current_trace(tags={"fruit": "apple"})
            return [model_inputs[0] + 1]

    model = MyModel()

    # Sanity check: outside of model logging, prediction produces a trace
    # with the tag applied by update_current_trace.
    model.predict([1])
    trace = get_traces()[0]
    assert trace.info.state == "OK"
    assert trace.info.tags["fruit"] == "apple"
    # Clear traces so the assertions below only see traces emitted (or not)
    # during/after model logging.
    purge_traces()

    # log_model runs predict() on the input example to infer the signature;
    # that call must not raise even though tracing is disabled.
    model_info = mlflow.pyfunc.log_model(
        python_model=model,
        name="model",
        input_example=[0],
    )
    # Trace should not be generated while logging the model
    assert get_traces() == []

    # Signature should be inferred properly without raising any exception
    assert model_info.signature is not None
    assert model_info.signature.inputs is not None
    assert model_info.signature.outputs is not None

    # Loading back the model
    # After loading, tracing is active again and predict produces a trace.
    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    loaded_model.predict([1])
    trace = get_traces()[0]
    assert trace.info.status == "OK"
    assert trace.info.tags["fruit"] == "apple"
1675  
1676  
def test_update_current_trace_with_state():
    """state accepts both TraceState enum values and 'OK'/'ERROR' strings."""
    from mlflow.tracing.trace_manager import InMemoryTraceManager

    # Enum value is stored as-is.
    with mlflow.start_span("test_span") as span:
        mlflow.update_current_trace(state=TraceState.ERROR)

        manager = InMemoryTraceManager.get_instance()
        with manager.get_trace(span.trace_id) as trace:
            assert trace.info.state == TraceState.ERROR

    # String value is normalized to the enum.
    with mlflow.start_span("test_span_2") as span:
        mlflow.update_current_trace(state="OK")

        manager = InMemoryTraceManager.get_instance()
        with manager.get_trace(span.trace_id) as trace:
            assert trace.info.state == TraceState.OK

    # state may be combined with tags and client_request_id in one call.
    with mlflow.start_span("test_span_3") as span:
        mlflow.update_current_trace(
            state="ERROR", tags={"error_type": "validation"}, client_request_id="req-123"
        )

        manager = InMemoryTraceManager.get_instance()
        with manager.get_trace(span.trace_id) as trace:
            assert trace.info.state == TraceState.ERROR
            assert trace.info.tags["error_type"] == "validation"
            assert trace.info.client_request_id == "req-123"
1707  
1708  
def test_update_current_trace_state_none():
    """state=None leaves the previously set state untouched."""
    from mlflow.tracing.trace_manager import InMemoryTraceManager

    with mlflow.start_span("test_span") as span:
        # Give the trace an explicit OK state first.
        mlflow.update_current_trace(state="OK")

        # A follow-up call with state=None must still apply the other fields
        # (tags) while preserving the existing state.
        mlflow.update_current_trace(state=None, tags={"test": "value"})

        manager = InMemoryTraceManager.get_instance()
        with manager.get_trace(span.trace_id) as trace:
            assert trace.info.state == TraceState.OK
            assert trace.info.tags["test"] == "value"
1723  
1724  
def test_update_current_trace_state_validation():
    """Only 'OK'/'ERROR' (string or enum) are accepted as a trace state."""
    with mlflow.start_span("test_span"):
        # Every valid spelling is accepted without raising.
        for valid_state in ("OK", "ERROR", TraceState.OK, TraceState.ERROR):
            mlflow.update_current_trace(state=valid_state)

        # Anything else — other enum members, arbitrary strings, wrong types —
        # is rejected with a descriptive error message.
        invalid_cases = [
            ("IN_PROGRESS", r"State must be either 'OK' or 'ERROR', but got 'IN_PROGRESS'"),
            (
                TraceState.STATE_UNSPECIFIED,
                r"State must be either 'OK' or 'ERROR', but got 'STATE_UNSPECIFIED'",
            ),
            ("CUSTOM_STATE", r"State must be either 'OK' or 'ERROR', but got 'CUSTOM_STATE'"),
            (123, r"State must be either 'OK' or 'ERROR', but got '123'"),
        ]
        for invalid_state, expected_message in invalid_cases:
            with pytest.raises(MlflowException, match=expected_message):
                mlflow.update_current_trace(state=invalid_state)
1757  
1758  
def test_span_record_exception_with_string():
    """record_exception with a plain string sets ERROR status and an event."""
    with mlflow.start_span("test_span") as span:
        span.record_exception("Something went wrong")

    # Inspect the persisted trace's single span.
    recorded_span = get_traces()[0].data.spans[0]

    # Recording an exception flips the span status to ERROR.
    assert recorded_span.status.status_code == SpanStatusCode.ERROR

    # Exactly one exception event should have been attached.
    exception_events = [e for e in recorded_span.events if "exception" in e.name.lower()]
    assert len(exception_events) == 1

    # The message text is carried in the event attributes.
    assert "Something went wrong" in str(exception_events[0].attributes)
1778  
1779  
def test_span_record_exception_with_exception():
    """record_exception with an Exception instance captures type and message."""
    test_exception = ValueError("Custom error message")

    with mlflow.start_span("test_span") as span:
        span.record_exception(test_exception)

    # Inspect the persisted trace's single span.
    recorded_span = get_traces()[0].data.spans[0]

    # The span status must be ERROR after recording an exception.
    assert recorded_span.status.status_code == SpanStatusCode.ERROR

    # Exactly one exception event carrying the exception details is attached.
    exception_events = [e for e in recorded_span.events if "exception" in e.name.lower()]
    assert len(exception_events) == 1

    event_attrs = str(exception_events[0].attributes)
    assert "ValueError" in event_attrs
    assert "Custom error message" in event_attrs
1802  
1803  
def test_span_record_exception_invalid_type():
    """record_exception rejects values that are neither Exception nor str."""
    with mlflow.start_span("test_span") as active_span:
        with pytest.raises(
            MlflowException,
            match="The `exception` parameter must be an Exception instance or a string",
        ):
            active_span.record_exception(123)
1811  
1812  
def test_combined_state_and_record_exception():
    """Span-level record_exception and trace-level state update work together."""

    @mlflow.trace
    def test_function():
        # Record an exception on the active span...
        mlflow.get_current_active_span().record_exception("Processing failed")
        # ...and independently set the trace-level state and tags.
        mlflow.update_current_trace(state="ERROR", tags={"error_source": "processing"})
        return "result"

    test_function()

    trace = get_traces()[0]

    # Trace-level state and tag were applied.
    assert trace.info.state == TraceState.ERROR
    assert trace.info.tags["error_source"] == "processing"

    # The root span carries the ERROR status and the exception event.
    root_span = trace.data.spans[0]
    assert root_span.status.status_code == SpanStatusCode.ERROR

    exc_events = [e for e in root_span.events if "exception" in e.name.lower()]
    assert len(exc_events) == 1
    assert "Processing failed" in str(exc_events[0].attributes)
1841  
1842  
def test_span_record_exception_no_op_span():
    """record_exception on a NoOpSpan is silently ignored."""
    from mlflow.entities.span import NoOpSpan

    # This must not raise.
    NoOpSpan().record_exception("This should be ignored")

    # And no trace should have been produced as a side effect.
    assert get_traces() == []
1852  
1853  
def test_update_current_trace_state_isolation():
    """Updating the trace state must not touch the span's own status."""
    with mlflow.start_span("test_span") as span:
        # Give the span an explicit OK status, then flip the trace state.
        span.set_status("OK")
        mlflow.update_current_trace(state="ERROR")

        # The trace-level update must not leak into the span status.
        assert span.status.status_code == SpanStatusCode.OK

    # The persisted trace keeps the ERROR state...
    persisted = get_traces()[0]
    assert persisted.info.state == TraceState.ERROR

    # ...while its span keeps the OK status it was given.
    assert persisted.data.spans[0].status.status_code == SpanStatusCode.OK
1873  
1874  
@skip_when_testing_trace_sdk
def test_non_ascii_characters_not_encoded_as_unicode():
    """Non-ASCII span inputs round-trip without \\uXXXX escaping."""
    payload = {"japanese": "あ", "emoji": "👍"}
    with mlflow.start_span() as span:
        span.set_inputs(payload)

    stored_span = mlflow.get_trace(span.trace_id, flush=True).data.spans[0]
    assert stored_span.inputs == payload
1883  
1884  
# Serialized trace dict (schema version 2) mimicking what a remote MLflow
# service would return alongside a prediction: one root span ("remote") with
# a single child ("remote-child"). Used by the add_trace tests below as the
# remote payload to merge into a local trace.
_SAMPLE_REMOTE_TRACE = {
    "info": {
        "request_id": "2e72d64369624e6888324462b62dc120",
        "experiment_id": "0",
        "timestamp_ms": 1726145090860,
        "execution_time_ms": 162,
        "status": "OK",
        "request_metadata": {
            "mlflow.trace_schema.version": "2",
            "mlflow.traceInputs": '{"x": 1}',
            "mlflow.traceOutputs": '{"prediction": 1}',
        },
        "tags": {
            "fruit": "apple",
            "food": "pizza",
        },
    },
    "data": {
        "spans": [
            {
                "name": "remote",
                "context": {
                    "span_id": "0x337af925d6629c01",
                    "trace_id": "0x05e82d1fc4486f3986fae6dd7b5352b1",
                },
                "parent_id": None,
                "start_time": 1726145091022155863,
                "end_time": 1726145091022572053,
                "status_code": "OK",
                "status_message": "",
                "attributes": {
                    "mlflow.traceRequestId": '"2e72d64369624e6888324462b62dc120"',
                    "mlflow.spanType": '"UNKNOWN"',
                    "mlflow.spanInputs": '{"x": 1}',
                    "mlflow.spanOutputs": '{"prediction": 1}',
                },
                "events": [
                    {"name": "event", "timestamp": 1726145091022287, "attributes": {"foo": "bar"}}
                ],
            },
            {
                "name": "remote-child",
                "context": {
                    "span_id": "0xa3dde9f2ebac1936",
                    "trace_id": "0x05e82d1fc4486f3986fae6dd7b5352b1",
                },
                "parent_id": "0x337af925d6629c01",
                "start_time": 1726145091022419340,
                "end_time": 1726145091022497944,
                "status_code": "OK",
                "status_message": "",
                "attributes": {
                    "mlflow.traceRequestId": '"2e72d64369624e6888324462b62dc120"',
                    "mlflow.spanType": '"UNKNOWN"',
                },
                "events": [],
            },
        ],
        "request": '{"x": 1}',
        "response": '{"prediction": 1}',
    },
}
1947  
1948  
def test_add_trace(mock_otel_trace_start_time):
    """add_trace merges a remote trace's spans under the current active trace."""

    # Mimic a remote service call that returns a trace as a part of the response
    def dummy_remote_call():
        return {"prediction": 1, "trace": _SAMPLE_REMOTE_TRACE}

    @mlflow.trace
    def predict(add_trace: bool):
        resp = dummy_remote_call()

        if add_trace:
            mlflow.add_trace(resp["trace"])
        return resp["prediction"]

    # If we don't call add_trace, the trace from the remote service should be discarded
    predict(add_trace=False)
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    assert len(trace.data.spans) == 1

    # If we call add_trace, the trace from the remote service should be merged
    predict(add_trace=True)
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    trace_id = trace.info.trace_id
    assert trace_id is not None
    assert trace.data.request == '{"add_trace": true}'
    assert trace.data.response == "1"
    # Remote spans should be merged
    assert len(trace.data.spans) == 3
    # All spans (local root + 2 remote) are re-homed onto the local trace ID.
    assert all(span.trace_id == trace_id for span in trace.data.spans)
    parent_span, child_span, grandchild_span = trace.data.spans
    assert child_span.parent_id == parent_span.span_id
    assert child_span._trace_id == parent_span._trace_id
    assert grandchild_span.parent_id == child_span.span_id
    assert grandchild_span._trace_id == parent_span._trace_id
    # Check if span information is correctly copied
    # (compare against the remote trace's root span)
    rs = Trace.from_dict(_SAMPLE_REMOTE_TRACE).data.spans[0]
    assert child_span.name == rs.name
    assert child_span.start_time_ns == rs.start_time_ns
    assert child_span.end_time_ns == rs.end_time_ns
    assert child_span.status == rs.status
    assert child_span.span_type == rs.span_type
    assert child_span.events == rs.events
    # exclude request ID attribute from comparison
    # (the request ID is rewritten to the local trace during the merge)
    for k in rs.attributes.keys() - {SpanAttributeKey.REQUEST_ID}:
        assert child_span.attributes[k] == rs.attributes[k]
1993  
1994  
def test_add_trace_no_current_active_trace():
    """add_trace without an active trace wraps the remote spans in a new trace.

    A synthetic root span named after the remote trace is created to host the
    copied remote spans.
    """
    # Use the remote trace without any active trace
    remote_trace = Trace.from_dict(_SAMPLE_REMOTE_TRACE)

    mlflow.add_trace(remote_trace)

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    assert len(trace.data.spans) == 3
    parent_span, child_span, grandchild_span = trace.data.spans
    assert parent_span.name == "Remote Trace <remote>"
    rs = remote_trace.data.spans[0]
    # The synthetic root starts 1ns before the earliest remote span.
    assert parent_span.start_time_ns == rs.start_time_ns - 1
    assert parent_span.end_time_ns == rs.end_time_ns
    assert child_span.name == rs.name
    # Fix: span IDs are values, not singletons — compare with `==`, not `is`.
    # The previous identity check could only pass by accidental object reuse.
    assert child_span.parent_id == parent_span.span_id
    assert child_span.start_time_ns == rs.start_time_ns
    assert child_span.end_time_ns == rs.end_time_ns
    assert child_span.status == rs.status
    assert child_span.span_type == rs.span_type
    assert child_span.events == rs.events
    assert grandchild_span.parent_id == child_span.span_id
    # exclude request ID attribute from comparison
    for k in rs.attributes.keys() - {SpanAttributeKey.REQUEST_ID}:
        assert child_span.attributes[k] == rs.attributes[k]
2019  
2020  
def test_add_trace_specific_target_span(mock_otel_trace_start_time):
    """add_trace(target=...) attaches the remote spans under the given span."""
    span = start_span_no_context(name="parent")
    mlflow.add_trace(_SAMPLE_REMOTE_TRACE, target=span)
    span.end()

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    assert len(trace.data.spans) == 3
    parent_span, child_span, grandchild_span = trace.data.spans
    assert parent_span.span_id == span.span_id
    rs = Trace.from_dict(_SAMPLE_REMOTE_TRACE).data.spans[0]
    assert child_span.name == rs.name
    # Fix: span IDs are values, not singletons — compare with `==`, not `is`.
    # The previous identity check could only pass by accidental object reuse.
    assert child_span.parent_id == parent_span.span_id
    assert grandchild_span.parent_id == child_span.span_id
2034  
2035  
def test_add_trace_merge_tags():
    """Tags from the remote trace are merged in; parent-trace tags win on conflict."""
    client = TracingClient()

    # Tag the parent trace before merging the remote one ("food" collides
    # with a tag on the remote trace).
    with mlflow.start_span(name="parent") as span:
        client.set_trace_tag(span.trace_id, "vegetable", "carrot")
        client.set_trace_tag(span.trace_id, "food", "sushi")

        mlflow.add_trace(Trace.from_dict(_SAMPLE_REMOTE_TRACE))

    merged = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    user_tags = {
        key: value for key, value in merged.info.tags.items() if not key.startswith("mlflow.")
    }
    assert user_tags == {
        "fruit": "apple",
        "vegetable": "carrot",
        # Tag value from the parent trace should prevail
        "food": "sushi",
    }
2054  
2055  
def test_add_trace_raise_for_invalid_trace():
    """add_trace rejects non-traces, unfinished traces, and misordered spans."""
    # None is not a trace at all.
    with pytest.raises(MlflowException, match="Invalid trace object"):
        mlflow.add_trace(None)

    # A dict that cannot be deserialized into a Trace.
    with pytest.raises(MlflowException, match="Failed to load a trace object"):
        mlflow.add_trace({"info": {}, "data": {}})

    # A trace that is still in progress must be rejected.
    in_progress_trace = Trace(
        info=TraceInfo(
            trace_id="123",
            trace_location=TraceLocation.from_experiment_id("0"),
            request_time=0,
            execution_duration=0,
            state=TraceState.IN_PROGRESS,
        ),
        data=TraceData(),
    )
    with pytest.raises(MlflowException, match="The trace must be ended"):
        mlflow.add_trace(in_progress_trace)

    # Spans listed child-before-parent must be rejected.
    valid_trace = Trace.from_dict(_SAMPLE_REMOTE_TRACE)
    reordered_spans = list(reversed(valid_trace.data.spans))
    unordered_trace = Trace(info=valid_trace.info, data=TraceData(spans=reordered_spans))
    with pytest.raises(MlflowException, match="Span with ID "):
        mlflow.add_trace(unordered_trace)
2081  
2082  
@skip_when_testing_trace_sdk
def test_add_trace_in_databricks_model_serving(mock_databricks_serving_with_tracing_env):
    """In Databricks serving, the merged trace is routed to the inference table
    keyed by the serving request ID rather than logged to a tracking store."""
    from mlflow.pyfunc.context import Context, set_prediction_context

    # Mimic a remote service call that returns a trace as a part of the response
    def dummy_remote_call():
        return {"prediction": 1, "trace": _SAMPLE_REMOTE_TRACE}

    # The parent function that invokes the dummy remote service
    @mlflow.trace
    def predict():
        resp = dummy_remote_call()
        remote_trace = Trace.from_dict(resp["trace"])
        mlflow.add_trace(remote_trace)
        return resp["prediction"]

    # Run the prediction under a serving prediction context so the trace is
    # associated with the Databricks request ID.
    db_request_id = "databricks-request-id"
    with set_prediction_context(Context(request_id=db_request_id)):
        predict()

    # Pop the trace to be written to the inference table
    trace = Trace.from_dict(pop_trace(request_id=db_request_id))

    assert trace.info.trace_id.startswith("tr-")
    # The serving request ID is surfaced as the client request ID.
    assert trace.info.client_request_id == db_request_id
    assert len(trace.data.spans) == 3
    assert all(span.trace_id == trace.info.trace_id for span in trace.data.spans)
    parent_span, child_span, grandchild_span = trace.data.spans
    assert child_span.parent_id == parent_span.span_id
    assert child_span._trace_id == parent_span._trace_id
    assert grandchild_span.parent_id == child_span.span_id
    assert grandchild_span._trace_id == parent_span._trace_id
    # Check if span information is correctly copied
    rs = Trace.from_dict(_SAMPLE_REMOTE_TRACE).data.spans[0]
    assert child_span.name == rs.name
    assert child_span.start_time_ns == rs.start_time_ns
    assert child_span.end_time_ns == rs.end_time_ns
2120  
2121  
@skip_when_testing_trace_sdk
def test_add_trace_logging_model_from_code():
    """A model-from-code file using add_trace only emits traces at predict time."""
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model",
            python_model="tests/tracing/sample_code/model_with_add_trace.py",
            input_example=[1, 2],
        )

    pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
    # Neither logging nor loading the model should have produced a trace.
    assert mlflow.get_trace(mlflow.get_last_active_trace_id()) is None

    # A real prediction produces the merged trace (local root + remote span).
    pyfunc_model.predict(1)
    logged_trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    assert logged_trace is not None
    assert len(logged_trace.data.spans) == 2
2139  
2140  
@pytest.mark.parametrize(
    "inputs", [{"question": "Does mlflow support tracing?"}, "Does mlflow support tracing?", None]
)
@pytest.mark.parametrize("outputs", [{"answer": "Yes"}, "Yes", None])
@pytest.mark.parametrize(
    "intermediate_outputs",
    [
        {
            "retrieved_documents": ["mlflow documentation"],
            "system_prompt": ["answer the question with yes or no"],
        },
        None,
    ],
)
def test_log_trace_success(inputs, outputs, intermediate_outputs):
    """log_trace stores request/response (JSON-serialized) and timing info."""
    start_time_ms = 1736144700
    execution_time_ms = 5129

    mlflow.log_trace(
        name="test",
        request=inputs,
        response=outputs,
        intermediate_outputs=intermediate_outputs,
        start_time_ms=start_time_ms,
        execution_time_ms=execution_time_ms,
    )

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)

    # None stays None; any other value is stored as its JSON serialization.
    expected_request = None if inputs is None else json.dumps(inputs)
    assert trace.data.request == expected_request
    expected_response = None if outputs is None else json.dumps(outputs)
    assert trace.data.response == expected_response
    if intermediate_outputs is not None:
        assert trace.data.intermediate_outputs == intermediate_outputs

    # Exactly one root span, covering the supplied wall-clock window (ns).
    (root_span,) = trace.data.spans
    assert root_span.name == "test"
    assert root_span.start_time_ns == start_time_ms * 1000000
    assert root_span.end_time_ns == (start_time_ms + execution_time_ms) * 1000000
2185  
2186  
def test_set_delete_trace_tag():
    """set/delete_trace_tag work with both trace_id and legacy request_id kwargs."""
    with mlflow.start_span("span1") as span:
        trace_id = span.trace_id

    # Round-trip via the trace_id keyword.
    mlflow.set_trace_tag(trace_id=trace_id, key="key1", value="value1")
    assert mlflow.get_trace(trace_id=trace_id, flush=True).info.tags["key1"] == "value1"

    mlflow.delete_trace_tag(trace_id=trace_id, key="key1")
    assert "key1" not in mlflow.get_trace(trace_id=trace_id, flush=True).info.tags

    # The legacy request_id keyword must keep working (backward compatibility).
    mlflow.set_trace_tag(request_id=trace_id, key="key3", value="value3")
    assert mlflow.get_trace(request_id=trace_id, flush=True).info.tags["key3"] == "value3"

    mlflow.delete_trace_tag(request_id=trace_id, key="key3")
    assert "key3" not in mlflow.get_trace(request_id=trace_id, flush=True).info.tags
2207  
2208  
@pytest.mark.parametrize("is_databricks", [True, False])
def test_search_traces_with_run_id_validates_store_filter_string(is_databricks):
    """search_traces(run_id=...) is translated into an attribute filter string."""
    store = mock.MagicMock()
    store.search_traces.return_value = ([], None)
    store.get_run.return_value = mock.MagicMock()
    store.get_run.return_value.info.experiment_id = "test_exp_id"

    test_run_id = "test_run_123"
    with (
        mock.patch("mlflow.tracing.client._get_store", return_value=store),
        mock.patch("mlflow.tracking.fluent._get_experiment_id", return_value="test_exp_id"),
    ):
        mlflow.search_traces(run_id=test_run_id)

        # The run ID must be forwarded to the store as a filter string.
        store.search_traces.assert_called()
        filter_string = store.search_traces.call_args[1]["filter_string"]
        assert filter_string == f"attribute.run_id = '{test_run_id}'"
2229  
2230  
def test_search_traces_with_locations(mock_client):
    """locations is forwarded verbatim to the client; experiment_ids stays unset."""
    mock_client.search_traces.return_value = PagedList([], token=None)

    mlflow.search_traces(locations=["catalog1.schema1", "catalog2.schema2"])

    mock_client.search_traces.assert_called_once()
    kwargs = mock_client.search_traces.call_args.kwargs
    assert kwargs["locations"] == ["catalog1.schema1", "catalog2.schema2"]
    assert kwargs.get("experiment_ids") is None
2242  
2243  
@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_search_traces_experiment_ids_deprecation_warning(mock_client):
    """experiment_ids warns as deprecated and is converted into locations."""
    mock_client.search_traces.return_value = PagedList([], token=None)

    # The deprecated kwarg must emit a FutureWarning pointing at `locations`.
    with pytest.warns(FutureWarning, match="experiment_ids.*deprecated.*use.*locations"):
        mlflow.search_traces(experiment_ids=["123"])

    # Under the hood the IDs are still passed on — as locations.
    mock_client.search_traces.assert_called_once()
    kwargs = mock_client.search_traces.call_args.kwargs
    assert kwargs["locations"] == ["123"]
    assert kwargs["experiment_ids"] is None
2257  
2258  
def test_search_traces_with_sql_warehouse_id(mock_client):
    """sql_warehouse_id is propagated via environment variable, not as a kwarg."""
    mock_client.search_traces.return_value = PagedList([], token=None)

    mlflow.search_traces(locations=["123"], sql_warehouse_id="warehouse456")

    mock_client.search_traces.assert_called_once()
    kwargs = mock_client.search_traces.call_args.kwargs
    assert kwargs["locations"] == ["123"]
    # The warehouse ID must not leak into the client call...
    assert "sql_warehouse_id" not in kwargs
    # ...it is communicated through the environment instead.
    assert os.environ["MLFLOW_TRACING_SQL_WAREHOUSE_ID"] == "warehouse456"
2271  
2272  
@skip_when_testing_trace_sdk
@pytest.mark.flaky(attempts=3, condition=sys.platform == "win32")
@pytest.mark.parametrize("use_batch_processor", [False, True])
def test_set_destination_in_threads(async_logging_enabled, use_batch_processor, monkeypatch):
    """set_destination(context_local=True) is scoped to the calling thread;
    threads without their own destination fall back to the global one."""
    monkeypatch.setenv("MLFLOW_USE_BATCH_SPAN_PROCESSOR", str(use_batch_processor))

    # This test makes sure `set_destination` obeys thread-local behavior.
    class TestModel:
        def predict(self, x):
            with mlflow.start_span(name="root_span") as root_span:

                def child_span_thread(z):
                    # Child span is created in a separate thread with an
                    # explicit parent, so it should land in the same trace.
                    child_span = start_span_no_context(
                        name="child_span_1",
                        parent_span=root_span,
                    )
                    child_span.set_inputs(z)
                    time.sleep(0.5)
                    child_span.end()

                thread = threading.Thread(
                    name="test-fluent-child-span", target=child_span_thread, args=(x + 1,)
                )
                thread.start()
                thread.join()
            return x

    model = TestModel()

    def func(experiment_id: str | None, x: int):
        # experiment_id=None means "use whatever destination is in effect"
        # (i.e. the global one set on the main thread).
        if experiment_id is not None:
            set_destination(MlflowExperiment(experiment_id), context_local=True)

        time.sleep(0.5)
        model.predict(x)

    # Main thread: global config
    experiment_id1 = mlflow.create_experiment(uuid.uuid4().hex)
    set_destination(MlflowExperiment(experiment_id1))
    func(None, 3)

    # Thread 1: context-local config
    experiment_id2 = mlflow.create_experiment(uuid.uuid4().hex)
    thread1 = threading.Thread(
        name="test-fluent-destination-thread1", target=func, args=(experiment_id2, 3)
    )

    # Thread 2: context-local config
    experiment_id3 = mlflow.create_experiment(uuid.uuid4().hex)
    thread2 = threading.Thread(
        name="test-fluent-destination-thread2", target=func, args=(experiment_id3, 40)
    )

    # Thread 3: no config -> fallback to global config
    thread3 = threading.Thread(name="test-fluent-destination-thread3", target=func, args=(None, 40))

    thread1.start()
    thread2.start()
    thread3.start()

    thread1.join()
    thread2.join()
    thread3.join()

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    # Global destination received the main-thread trace plus thread 3's
    # (which had no context-local destination of its own).
    traces = get_traces(experiment_id1)
    assert len(traces) == 2  # main thread + thread 3
    assert traces[0].info.experiment_id == experiment_id1
    assert len(traces[0].data.spans) == 2
    assert traces[1].info.experiment_id == experiment_id1
    assert len(traces[1].data.spans) == 2

    # Each context-local destination received exactly its own thread's trace.
    for exp_id in [experiment_id2, experiment_id3]:
        traces = get_traces(exp_id)
        assert len(traces) == 1
        assert traces[0].info.experiment_id == exp_id
        assert len(traces[0].data.spans) == 2
2352  
2353  
@pytest.mark.asyncio
@skip_when_testing_trace_sdk
async def test_set_destination_in_async_contexts(async_logging_enabled):
    """set_destination(context_local=True) is isolated per asyncio task:
    two concurrent tasks each log to their own experiment."""

    class TestModel:
        async def predict(self, x):
            with mlflow.start_span(name="root_span") as root_span:

                async def child_span_task(z):
                    # Child span created with an explicit parent inside the
                    # same task; it should join the root span's trace.
                    child_span = start_span_no_context(
                        name="child_span_1",
                        parent_span=root_span,
                    )
                    child_span.set_inputs(z)
                    await asyncio.sleep(0.5)
                    child_span.end()

                await child_span_task(x + 1)
            return x

    model = TestModel()

    async def async_func(experiment_id: str, x: int):
        # Each task sets its own context-local destination; the sleep lets the
        # two tasks interleave, exercising the isolation.
        set_destination(MlflowExperiment(experiment_id), context_local=True)
        await asyncio.sleep(0.5)
        await model.predict(x)

    experiment_id1 = mlflow.create_experiment(uuid.uuid4().hex)
    task1 = asyncio.create_task(async_func(experiment_id1, 3))

    experiment_id2 = mlflow.create_experiment(uuid.uuid4().hex)
    task2 = asyncio.create_task(async_func(experiment_id2, 40))

    await asyncio.gather(task1, task2)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    # Each experiment received exactly one trace, from its own task.
    for exp_id in [experiment_id1, experiment_id2]:
        traces = get_traces(exp_id)
        assert len(traces) == 1
        assert traces[0].info.experiment_id == exp_id
        assert len(traces[0].data.spans) == 2
2396  
2397  
def test_set_destination_from_env_var_databricks_uc(monkeypatch):
    """Setting MLFLOW_TRACING_DESTINATION to "catalog.schema" resolves to a UC
    schema destination and the tracking URI becomes "databricks"."""
    monkeypatch.setenv("MLFLOW_TRACING_DESTINATION", "catalog.schema")

    resolved = _MLFLOW_TRACE_USER_DESTINATION.get()

    assert isinstance(resolved, UCSchemaLocation)
    assert (resolved.catalog_name, resolved.schema_name) == ("catalog", "schema")
    assert mlflow.get_tracking_uri() == "databricks"
2405  
2406  
@skip_when_testing_trace_sdk
def test_traces_can_be_searched_by_span_properties(async_logging_enabled):
    """Traces are discoverable via a span-level filter (span.name)."""

    @mlflow.trace(name="test_span")
    def test_function():
        return "result"

    test_function()

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    matches = mlflow.search_traces(filter_string='span.name = "test_span"', return_type="list")
    assert len(matches) == 1, "Should find exactly one trace with span name 'test_span'"
    assert "test_span" in [s.name for s in matches[0].data.spans]
2422  
2423  
@pytest.mark.skipif(
    IS_TRACING_SDK_ONLY, reason="Skipping test because mlflow or mlflow-skinny is not installed."
)
def test_search_traces_with_full_text():
    """Full-text search (trace.text LIKE ...) matches span attributes and
    outputs, including values containing quotes and percent signs."""
    with mlflow.start_span(name="test_span") as span:
        span.set_attribute("llm.inputs", "How's the result?")
        span.set_attribute("llm.outputs", "the number increased 90%")
        trace_id_1 = span.trace_id

    with mlflow.start_span(name="test_span") as span:
        span.set_outputs({"outputs": 1234567})
        span.set_attribute("test", "the number increased")
        trace_id_2 = span.trace_id

    with mlflow.start_span(name="test_span") as span:
        span.set_attribute("test", "result including 'single quotes'")
        trace_id_3 = span.trace_id

    # (filter string, expected matching trace) — one search per case.
    cases = [
        ('trace.text LIKE "%How\'s the result?%"', trace_id_1),
        ('trace.text LIKE "%1234567%"', trace_id_2),
        ("trace.text LIKE \"%result including 'single quotes'%\"", trace_id_3),
        ('trace.text LIKE "%increased 90%%"', trace_id_1),
    ]
    for filter_string, expected_trace_id in cases:
        traces = mlflow.search_traces(
            filter_string=filter_string, return_type="list", flush=True
        )
        assert len(traces) == 1
        assert traces[0].info.trace_id == expected_trace_id
2467  
2468  
def _create_trace_with_session(session_id: str, name: str = "test_span") -> str:
    """Log a one-span trace carrying *session_id* in its metadata; return its trace ID."""
    with mlflow.start_span(name=name) as span:
        mlflow.update_current_trace(metadata={TraceMetadataKey.TRACE_SESSION: session_id})
        span.set_inputs({"input": "test"})
        span.set_outputs({"output": "test"})
        trace_id = span.trace_id
    mlflow.flush_trace_async_logging()
    return trace_id
2476  
2477  
def _create_trace_without_session(name: str = "test_span") -> str:
    """Log a one-span trace with no session metadata; return its trace ID."""
    with mlflow.start_span(name=name) as span:
        span.set_inputs({"input": "test"})
        span.set_outputs({"output": "test"})
        trace_id = span.trace_id
    mlflow.flush_trace_async_logging()
    return trace_id
2484  
2485  
@pytest.mark.skipif(
    IS_TRACING_SDK_ONLY, reason="Skipping test because mlflow or mlflow-skinny is not installed."
)
def test_search_sessions_empty():
    """A trace without a session ID produces no sessions."""
    _create_trace_without_session()
    assert mlflow.search_sessions() == []
2494  
2495  
@pytest.mark.skipif(
    IS_TRACING_SDK_ONLY, reason="Skipping test because mlflow or mlflow-skinny is not installed."
)
def test_search_sessions_returns_grouped_traces():
    """search_sessions groups traces by their session metadata value."""
    session_id_1 = f"session-1-{uuid.uuid4().hex[:8]}"
    session_id_2 = f"session-2-{uuid.uuid4().hex[:8]}"

    # Two traces in session 1, one in session 2.
    trace_id_1 = _create_trace_with_session(session_id_1, "session1_trace1")
    trace_id_2 = _create_trace_with_session(session_id_1, "session1_trace2")
    trace_id_3 = _create_trace_with_session(session_id_2, "session2_trace1")

    sessions = mlflow.search_sessions()
    assert len(sessions) == 2

    # Key sessions by their ID so assertions don't depend on ordering.
    by_id = {session.id: session for session in sessions}
    assert len(by_id[session_id_1]) == 2
    assert len(by_id[session_id_2]) == 1

    assert {trace_id_1, trace_id_2} <= {t.info.trace_id for t in by_id[session_id_1]}
    assert by_id[session_id_2][0].info.trace_id == trace_id_3
2525  
2526  
@pytest.mark.skipif(
    IS_TRACING_SDK_ONLY, reason="Skipping test because mlflow or mlflow-skinny is not installed."
)
def test_search_sessions_respects_max_results():
    """max_results caps the number of sessions returned."""
    # One trace in each of three distinct sessions.
    for i in range(3):
        _create_trace_with_session(f"session-{i}-{uuid.uuid4().hex[:8]}")

    assert len(mlflow.search_sessions(max_results=2)) == 2
2540  
2541  
@pytest.mark.skipif(
    IS_TRACING_SDK_ONLY, reason="Skipping test because mlflow or mlflow-skinny is not installed."
)
def test_search_sessions_skips_traces_without_session_id():
    """Traces lacking a session ID are excluded from session search results."""
    session_id = f"session-{uuid.uuid4().hex[:8]}"

    _create_trace_without_session("no_session_trace")
    trace_id = _create_trace_with_session(session_id, "with_session_trace")

    sessions = mlflow.search_sessions()

    assert len(sessions) == 1
    (only_session,) = sessions
    assert len(only_session) == 1
    assert only_session[0].info.trace_id == trace_id
2559  
2560  
def test_search_sessions_validates_locations_type():
    """Non-list `locations` arguments are rejected up front."""
    for bad_locations in (4, "4"):
        with pytest.raises(MlflowException, match=r"locations must be a list"):
            mlflow.search_sessions(locations=bad_locations)
2567  
2568  
@pytest.mark.skipif(
    IS_TRACING_SDK_ONLY, reason="Skipping test because mlflow or mlflow-skinny is not installed."
)
def test_search_sessions_with_default_experiment_id():
    """With no explicit locations, search_sessions uses the default experiment."""
    _create_trace_with_session(f"session-{uuid.uuid4().hex[:8]}")

    assert len(mlflow.search_sessions()) == 1
2579  
2580  
def test_search_sessions_raises_without_experiment():
    """search_sessions fails loudly when no experiment can be resolved."""
    with (
        mock.patch("mlflow.tracking.fluent._get_experiment_id", return_value=None),
        pytest.raises(MlflowException, match=r"No active experiment found"),
    ):
        mlflow.search_sessions()
2585  
2586  
@pytest.mark.skipif(
    IS_TRACING_SDK_ONLY, reason="Skipping test because mlflow or mlflow-skinny is not installed."
)
def test_search_sessions_include_spans_true():
    """include_spans=True populates span data on the returned traces."""
    _create_trace_with_session(f"session-{uuid.uuid4().hex[:8]}")

    sessions = mlflow.search_sessions(include_spans=True)

    assert len(sessions) == 1
    assert len(sessions[0]) == 1
    assert len(sessions[0][0].data.spans) > 0
2600  
2601  
@pytest.mark.skipif(
    IS_TRACING_SDK_ONLY, reason="Skipping test because mlflow or mlflow-skinny is not installed."
)
def test_search_sessions_include_spans_false():
    """include_spans=False leaves span data empty on the returned traces."""
    _create_trace_with_session(f"session-{uuid.uuid4().hex[:8]}")

    sessions = mlflow.search_sessions(include_spans=False)

    assert len(sessions) == 1
    assert len(sessions[0]) == 1
    assert len(sessions[0][0].data.spans) == 0
2615  
2616  
@pytest.mark.parametrize("invalid_ratio", [-0.1, 1.1, -1, 2, 100])
def test_trace_decorator_sampling_ratio_validation(invalid_ratio: float):
    """Out-of-range sampling ratios are rejected when the decorator is built."""
    expected = r"sampling_ratio_override must be between 0\.0 and 1\.0"
    with pytest.raises(MlflowException, match=expected):
        mlflow.trace(sampling_ratio_override=invalid_ratio)
2623  
2624  
@pytest.mark.parametrize(
    ("sampling_ratio", "num_calls", "expected_min", "expected_max"),
    [
        (0.0, 10, 0, 0),
        (0.5, 100, 30, 70),
        (1.0, 10, 10, 10),
    ],
)
def test_trace_decorator_sampling_ratio(
    sampling_ratio: float, num_calls: int, expected_min: int, expected_max: int
):
    """The decorator-level sampling ratio controls roughly what fraction of
    calls produce a trace (loose bounds for the probabilistic 0.5 case)."""
    sampled: list[str] = []

    @mlflow.trace(sampling_ratio_override=sampling_ratio)
    def traced_func():
        active = mlflow.get_active_trace_id()
        if active:
            sampled.append(active)
        return "result"

    for _ in range(num_calls):
        assert traced_func() == "result"

    assert expected_min <= len(sampled) <= expected_max
2648  
2649  
@pytest.mark.parametrize(
    ("outer_ratio", "inner_ratio", "expected_outer", "expected_inner"),
    [
        (1.0, 0.0, 5, 5),  # Parent sampled -> child also sampled (inner ratio ignored)
        (0.0, 1.0, 0, 0),  # Parent not sampled -> child also dropped (follows parent)
    ],
)
def test_trace_decorator_sampling_ratio_nested(
    outer_ratio: float, inner_ratio: float, expected_outer: int, expected_inner: int
):
    """A nested traced call follows the parent's sampling decision regardless
    of its own sampling_ratio_override."""
    outer_hits: list[str] = []
    inner_hits: list[str] = []

    def record(bucket: list[str]):
        # Record the active trace ID only when one exists (i.e. sampled).
        active = mlflow.get_active_trace_id()
        if active:
            bucket.append(active)

    @mlflow.trace(sampling_ratio_override=outer_ratio)
    def outer():
        record(outer_hits)
        return inner()

    @mlflow.trace(sampling_ratio_override=inner_ratio)
    def inner():
        record(inner_hits)
        return "inner result"

    for _ in range(5):
        assert outer() == "inner result"

    assert len(outer_hits) == expected_outer
    assert len(inner_hits) == expected_inner
2680  
2681  
def test_global_sampling_ratio_nested(monkeypatch):
    """A globally-dropped parent propagates to nested traced calls even when
    the nested decorator would sample at 100%."""
    monkeypatch.setenv(MLFLOW_TRACE_SAMPLING_RATIO.name, "0.0")
    mlflow.tracing.reset()

    sampled_inner: list[str] = []

    @mlflow.trace
    def outer():
        return inner()

    # Inner uses sampling_ratio_override=1.0 so it would create a sampled
    # root trace if the dropped parent context were not propagated.
    @mlflow.trace(sampling_ratio_override=1.0)
    def inner():
        active = mlflow.get_active_trace_id()
        if active:
            sampled_inner.append(active)
        return "result"

    for _ in range(5):
        assert outer() == "result"

    assert not sampled_inner
2704  
2705  
def test_start_span_no_context_preserves_dropped_parent_context(monkeypatch):
    """Spans built via start_span_no_context under a globally-dropped root must
    keep the dropped sampling decision for children started in their context."""
    monkeypatch.setenv(MLFLOW_TRACE_SAMPLING_RATIO.name, "0.0")
    mlflow.tracing.reset()

    sampled: list[str] = []

    @mlflow.trace(sampling_ratio_override=1.0)
    def child():
        active = mlflow.get_active_trace_id()
        if active:
            sampled.append(active)
        return "result"

    root = start_span_no_context("root")
    nested_noop = start_span_no_context("nested_noop", parent_span=root)

    with safe_set_span_in_context(nested_noop):
        assert child() == "result"

    assert not sampled
2725  
2726  
@pytest.mark.parametrize(
    ("sampling_ratio", "expected_count"),
    [
        (0.0, 0),
        (1.0, 2),
    ],
)
def test_trace_decorator_sampling_ratio_generator(sampling_ratio: float, expected_count: int):
    """Sampling applies to generator functions; two consumptions yield 0 or 2 traces."""
    sampled: list[str] = []

    @mlflow.trace(sampling_ratio_override=sampling_ratio)
    def gen():
        active = mlflow.get_active_trace_id()
        if active:
            sampled.append(active)
        yield from range(3)

    for _ in range(2):
        assert list(gen()) == [0, 1, 2]

    assert len(sampled) == expected_count
2747  
2748  
@pytest.mark.parametrize(
    ("sampling_ratio", "expected_child_count"),
    [
        (0.0, 0),
        (1.0, 6),
    ],
)
def test_trace_decorator_sampling_ratio_generator_with_child_spans(
    sampling_ratio: float, expected_child_count: int
):
    """Child spans created inside a generator follow the generator's sampling
    decision (3 children per consumption, consumed twice -> 0 or 6)."""
    child_hits: list[str] = []

    @mlflow.trace
    def child_func(value):
        active = mlflow.get_active_trace_id()
        if active:
            child_hits.append(active)
        return value * 2

    @mlflow.trace(sampling_ratio_override=sampling_ratio)
    def gen():
        for i in range(3):
            yield child_func(i)

    for _ in range(2):
        assert list(gen()) == [0, 2, 4]

    assert len(child_hits) == expected_child_count
2775  
2776  
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("sampling_ratio", "num_calls", "expected_min", "expected_max"),
    [
        (0.0, 10, 0, 0),
        (0.5, 100, 30, 70),
        (1.0, 10, 10, 10),
    ],
)
async def test_trace_decorator_sampling_ratio_async(
    sampling_ratio: float, num_calls: int, expected_min: int, expected_max: int
):
    """The decorator-level sampling ratio applies to async functions too."""
    sampled: list[str] = []

    @mlflow.trace(sampling_ratio_override=sampling_ratio)
    async def traced_func():
        active = mlflow.get_active_trace_id()
        if active:
            sampled.append(active)
        return "result"

    for _ in range(num_calls):
        assert await traced_func() == "result"

    assert expected_min <= len(sampled) <= expected_max
2801  
2802  
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("sampling_ratio", "expected_count"),
    [
        (0.0, 0),
        (1.0, 2),
    ],
)
async def test_trace_decorator_sampling_ratio_async_generator(
    sampling_ratio: float, expected_count: int
):
    """Sampling applies to async generators; two consumptions yield 0 or 2 traces."""
    sampled: list[str] = []

    @mlflow.trace(sampling_ratio_override=sampling_ratio)
    async def gen():
        active = mlflow.get_active_trace_id()
        if active:
            sampled.append(active)
        for i in range(3):
            yield i

    for _ in range(2):
        assert [item async for item in gen()] == [0, 1, 2]

    assert len(sampled) == expected_count
2826  
2827  
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("sampling_ratio", "expected_child_count"),
    [
        (0.0, 0),
        (1.0, 6),
    ],
)
async def test_trace_decorator_sampling_ratio_async_generator_with_child_spans(
    sampling_ratio: float, expected_child_count: int
):
    """Child spans awaited inside an async generator follow the generator's
    sampling decision (3 children per consumption, consumed twice -> 0 or 6)."""
    child_hits: list[str] = []

    @mlflow.trace
    async def child_func(value):
        active = mlflow.get_active_trace_id()
        if active:
            child_hits.append(active)
        return value * 2

    @mlflow.trace(sampling_ratio_override=sampling_ratio)
    async def gen():
        for i in range(3):
            yield await child_func(i)

    for _ in range(2):
        assert [i async for i in gen()] == [0, 2, 4]

    assert len(child_hits) == expected_child_count
2855  
2856  
@skip_when_testing_trace_sdk
def test_trace_decorator_sampling_ratio_overrides_global():
    """Verify sampling_ratio_override=1.0 wins over a global ratio of 0.0.

    The scenario runs in a subprocess so the MLFLOW_TRACE_SAMPLING_RATIO
    environment setting is isolated from this test process's tracing state.
    check_call raises if any assertion in the child script fails.
    """
    code = """
import mlflow

trace_ids: list[str] = []


@mlflow.trace  # Should respect global 0.0
def not_traced():
    if trace_id := mlflow.get_active_trace_id():
        trace_ids.append(trace_id)
    return "not traced"


for _ in range(5):
    assert not_traced() == "not traced"

assert len(trace_ids) == 0


@mlflow.trace(sampling_ratio_override=1.0)  # Should override global 0.0
def traced():
    if trace_id := mlflow.get_active_trace_id():
        trace_ids.append(trace_id)
    return "traced"


for _ in range(5):
    assert traced() == "traced"

assert len(trace_ids) == 5
"""
    # Inherit the current environment but force the global sampling ratio to 0.
    subprocess.check_call(
        [sys.executable, "-c", code],
        env={
            **os.environ,
            "MLFLOW_TRACE_SAMPLING_RATIO": "0.0",
        },
    )
2897  
2898  
# Module-level traced helper reused by the mlflow.tracing.context tests below.
@mlflow.trace
def my_func():
    return "hello"
2902  
2903  
def test_tracing_context_injects_metadata_and_tags():
    """Metadata/tags supplied via mlflow.tracing.context attach to traces
    created inside the block, and only inside it."""
    with mlflow.tracing.context(
        metadata={"custom_key": "custom_value"},
        tags={"my_tag": "tag_value"},
    ):
        my_func()

    inside_trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    assert inside_trace.info.request_metadata["custom_key"] == "custom_value"
    assert inside_trace.info.tags["my_tag"] == "tag_value"

    # Trace created outside the block should NOT have the metadata
    my_func()
    outside_trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    assert "session" not in outside_trace.info.request_metadata
2919  
2920  
def test_tracing_context_session_id_and_user():
    """The session_id/user shortcuts populate the reserved metadata keys and
    can be combined with explicit metadata."""
    with mlflow.tracing.context(session_id="sess-123", user="user-456"):
        my_func()

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    metadata = trace.info.request_metadata
    assert metadata["mlflow.trace.session"] == "sess-123"
    assert metadata["mlflow.trace.user"] == "user-456"

    # session_id and user can coexist with explicit metadata
    with mlflow.tracing.context(
        session_id="sess-abc",
        user="user-xyz",
        metadata={"custom_key": "custom_value"},
    ):
        my_func()

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    metadata = trace.info.request_metadata
    assert metadata["mlflow.trace.session"] == "sess-abc"
    assert metadata["mlflow.trace.user"] == "user-xyz"
    assert metadata["custom_key"] == "custom_value"
2941  
2942  
def test_tracing_context_session_id_and_user_nesting():
    """Nested contexts: the innermost session_id wins; fields the inner context
    leaves unset are inherited from the enclosing one."""
    with mlflow.tracing.context(session_id="outer-sess", user="outer-user"):
        with mlflow.tracing.context(session_id="inner-sess"):
            my_func()

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    metadata = trace.info.request_metadata
    # Inner session_id overrides outer
    assert metadata["mlflow.trace.session"] == "inner-sess"
    # Outer user is inherited
    assert metadata["mlflow.trace.user"] == "outer-user"
2953  
2954  
def test_tracing_context_nesting_merges():
    """Nested contexts merge metadata and tags; the inner context wins on
    conflicting keys."""
    with mlflow.tracing.context(
        metadata={"outer_key": "outer_val", "shared": "outer"},
        tags={"outer_tag": "outer"},
    ), mlflow.tracing.context(
        metadata={"inner_key": "inner_val", "shared": "inner"},
        tags={"inner_tag": "inner"},
    ):
        my_func()

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    metadata = trace.info.request_metadata
    # Both outer and inner metadata present; inner wins on conflict
    assert metadata["outer_key"] == "outer_val"
    assert metadata["inner_key"] == "inner_val"
    assert metadata["shared"] == "inner"
    # Both tags present
    assert trace.info.tags["outer_tag"] == "outer"
    assert trace.info.tags["inner_tag"] == "inner"
2975  
2976  
def test_tracing_context_enabled_false_suppresses_traces():
    """enabled=False suppresses all traces created inside the block — including
    nested contexts and no-context spans — and tracing resumes on exit."""
    with mlflow.tracing.context(enabled=False):
        my_func()

        # Child context should inherit the enabled=False from the parent
        with mlflow.tracing.context(metadata={"k": "v"}):
            my_func()

        # Start trace with start_trace_no_context (used in autologging)
        span = mlflow.start_span_no_context("test")
        span.end()

    # None of the three attempts above should have produced a trace.
    assert mlflow.get_last_active_trace_id() is None

    # After exiting, tracing should work normally
    my_func()
    assert mlflow.get_last_active_trace_id() is not None
2994  
2995  
def test_tracing_context_enabled_is_thread_safe():
    """The enabled flag is per-thread: disabled contexts in some worker threads
    must not suppress tracing in the others."""

    def run_with_context(enabled):
        with mlflow.tracing.context(enabled=enabled):
            my_func()
            return mlflow.get_last_active_trace_id(thread_local=True)

    with ThreadPoolExecutor(
        max_workers=10, thread_name_prefix="test-fluent-tracing-context"
    ) as pool:
        # Alternate enabled/disabled across ten submissions.
        futures = {}
        for i in range(10):
            flag = i % 2 == 0
            futures[pool.submit(run_with_context, enabled=flag)] = flag
        for future in as_completed(futures):
            # A trace ID exists exactly when tracing was enabled in that thread.
            assert (future.result() is not None) == futures[future]
3012  
3013  
def test_flush_trace_async_logging_calls_flush_when_async_queue_exists():
    """flush_trace_async_logging forwards `terminate` to the exporter's async queue."""
    exporter = mock.MagicMock()
    with mock.patch("mlflow.tracking.fluent._get_trace_exporter", return_value=exporter):
        mlflow.flush_trace_async_logging(terminate=False)
    exporter._async_queue.flush.assert_called_once_with(terminate=False)
3019  
3020  
def test_flush_trace_async_logging_skips_when_async_queue_missing():
    """A bare SpanExporter (as used by StrandsSpanProcessor,
    mlflow/strands/autolog.py:40) has no _async_queue attribute; flushing must
    be a silent no-op rather than hitting the error handler."""
    exporter = SpanExporter()
    assert not hasattr(exporter, "_async_queue")

    patch_exporter = mock.patch(
        "mlflow.tracking.fluent._get_trace_exporter", return_value=exporter
    )
    # Turn any error-log call into a hard failure.
    patch_error = mock.patch(
        "mlflow.tracking.fluent._logger.error",
        side_effect=AssertionError("flush should not reach error handler"),
    )
    with patch_exporter, patch_error:
        mlflow.flush_trace_async_logging(terminate=False)
3035  
3036  
def test_flush_trace_async_logging_no_spurious_error_when_tracing_disabled():
    """Flushing while tracing is disabled must not emit error logs."""
    mlflow.tracing.disable()
    with mock.patch("mlflow.tracking.fluent._logger") as mock_logger:
        mlflow.flush_trace_async_logging(terminate=True)
    mock_logger.error.assert_not_called()