import asyncio
import json
import os
import subprocess
import sys
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import asdict
from datetime import datetime
from unittest import mock

import pytest
from opentelemetry.sdk.trace.export import SpanExporter

import mlflow
from mlflow.entities import (
    SpanEvent,
    SpanStatusCode,
    SpanType,
    Trace,
    TraceData,
    TraceInfo,
)
from mlflow.entities.trace_location import TraceLocation, UCSchemaLocation
from mlflow.entities.trace_state import TraceState
from mlflow.environment_variables import MLFLOW_TRACE_SAMPLING_RATIO, MLFLOW_TRACKING_USERNAME
from mlflow.exceptions import MlflowException
from mlflow.store.entities.paged_list import PagedList
from mlflow.store.tracking import SEARCH_TRACES_DEFAULT_MAX_RESULTS
from mlflow.tracing.client import TracingClient
from mlflow.tracing.constant import (
    TRACE_SCHEMA_VERSION_KEY,
    SpanAttributeKey,
    TraceMetadataKey,
    TraceTagKey,
)
from mlflow.tracing.destination import MlflowExperiment
from mlflow.tracing.export.inference_table import pop_trace
from mlflow.tracing.fluent import start_span_no_context
from mlflow.tracing.provider import (
    _MLFLOW_TRACE_USER_DESTINATION,
    _get_tracer,
    safe_set_span_in_context,
    set_destination,
)
from mlflow.tracking.fluent import _get_experiment_id
from mlflow.version import IS_TRACING_SDK_ONLY

from tests.tracing.helper import (
    create_test_trace_info,
    get_traces,
    purge_traces,
    skip_when_testing_trace_sdk,
)


class DefaultTestModel:
    @mlflow.trace()
    def predict(self, x, y):
        z = x + y
        z = self.add_one(z)
        z = mlflow.trace(self.square)(z)
        return z  # noqa: RET504

    @mlflow.trace(span_type=SpanType.LLM, name="add_one_with_custom_name", attributes={"delta": 1})
    def add_one(self, z):
        return z + 1

    def square(self, t):
        res = t**2
        time.sleep(0.1)
        return res


class DefaultAsyncTestModel:
    @mlflow.trace()
    async def predict(self, x, y):
        z = x + y
        z = await self.add_one(z)
        z = await mlflow.trace(self.square)(z)
        return z  # noqa: RET504

    @mlflow.trace(span_type=SpanType.LLM, name="add_one_with_custom_name", attributes={"delta": 1})
    async def add_one(self, z):
        return z + 1

    async def square(self, t):
        res = t**2
        time.sleep(0.1)
        return res


class StreamTestModel:
    @mlflow.trace(output_reducer=lambda x: sum(x))
    def predict_stream(self, x, y):
        z = x + y
        for i in range(z):
            yield i

        # Generator with a normal func
        for i in range(z):
            yield self.square(i)

        # Nested generator
        yield from self.generate_numbers(z)

    @mlflow.trace
    def square(self, t):
        time.sleep(0.1)
        return t**2

    # No output_reducer -> record the list of outputs
    @mlflow.trace
    def generate_numbers(self, z):
        for i in range(z):
            yield i


class AsyncStreamTestModel:
    @mlflow.trace(output_reducer=lambda x: sum(x))
    async def predict_stream(self, x, y):
        z = x + y
        for i in range(z):
            yield i

        # Generator with a normal func
        for i in range(z):
            yield await self.square(i)

        # Nested generator
        async for number in self.generate_numbers(z):
            yield number

    @mlflow.trace
    async def square(self, t):
        await asyncio.sleep(0.1)
        return t**2

    @mlflow.trace
    async def generate_numbers(self, z):
        for i in range(z):
            yield i


class ErroringTestModel:
    @mlflow.trace()
    def predict(self, x, y):
        return self.some_operation_raise_error(x, y)

    @mlflow.trace()
    def some_operation_raise_error(self, x, y):
        raise ValueError("Some error")


class ErroringAsyncTestModel:
    @mlflow.trace()
    async def predict(self, x, y):
        return await self.some_operation_raise_error(x, y)

    @mlflow.trace()
    async def some_operation_raise_error(self, x, y):
        raise ValueError("Some error")


class ErroringStreamTestModel:
    @mlflow.trace
    def predict_stream(self, x):
        for i in range(x):
            if i > 0:
                # Ensure distinct start_time_ns on Windows for deterministic span ordering
                time.sleep(0.001)
            yield self.some_operation_raise_error(i)

    @mlflow.trace
    def some_operation_raise_error(self, i):
        if i >= 1:
            raise ValueError("Some error")
        return i


@pytest.fixture
def mock_client():
    client = mock.MagicMock()
    with mock.patch("mlflow.tracing.fluent.TracingClient", return_value=client):
        yield client


@pytest.fixture
def mock_otel_trace_start_time():
    # mock the start time of a trace, ensuring the root span has
    # a smaller start time than child spans.
    with mock.patch("opentelemetry.sdk.trace.time_ns", return_value=0):
        yield


@pytest.mark.parametrize("with_active_run", [True, False])
@pytest.mark.parametrize("wrap_sync_func", [True, False])
def test_trace(wrap_sync_func, with_active_run, async_logging_enabled):
    model = DefaultTestModel() if wrap_sync_func else DefaultAsyncTestModel()

    if with_active_run:
        if IS_TRACING_SDK_ONLY:
            pytest.skip("Skipping test because mlflow or mlflow-skinny is not installed.")

        with mlflow.start_run() as run:
            model.predict(2, 5) if wrap_sync_func else asyncio.run(model.predict(2, 5))
            run_id = run.info.run_id
    else:
        model.predict(2, 5) if wrap_sync_func else asyncio.run(model.predict(2, 5))

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.trace_id is not None
    assert trace.info.experiment_id == _get_experiment_id()
    assert trace.info.execution_time_ms >= 0.1 * 1e3  # at least 0.1 sec
    assert trace.info.state == TraceState.OK
    assert trace.info.request_metadata[TraceMetadataKey.INPUTS] == '{"x": 2, "y": 5}'
    assert trace.info.request_metadata[TraceMetadataKey.OUTPUTS] == "64"
    if with_active_run:
        assert trace.info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run_id

    assert trace.data.request == '{"x": 2, "y": 5}'
    assert trace.data.response == "64"
    assert len(trace.data.spans) == 3

    span_name_to_span = {span.name: span for span in trace.data.spans}
    root_span = span_name_to_span["predict"]
    # TODO: Trace info timestamp is not accurate because it is not adjusted to exclude the latency
    # assert root_span.start_time_ns // 1e6 == trace.info.timestamp_ms
    assert root_span.parent_id is None
    assert root_span.attributes == {
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanFunctionName": "predict",
        "mlflow.spanType": "UNKNOWN",
        "mlflow.spanInputs": {"x": 2, "y": 5},
        "mlflow.spanOutputs": 64,
    }

    child_span_1 = span_name_to_span["add_one_with_custom_name"]
    assert child_span_1.parent_id == root_span.span_id
    assert child_span_1.attributes == {
        "delta": 1,
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanFunctionName": "add_one",
        "mlflow.spanType": "LLM",
        "mlflow.spanInputs": {"z": 7},
        "mlflow.spanOutputs": 8,
    }

    child_span_2 = span_name_to_span["square"]
    assert child_span_2.parent_id == root_span.span_id
    assert child_span_2.start_time_ns <= child_span_2.end_time_ns - 0.1 * 1e6
    assert child_span_2.attributes == {
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanFunctionName": "square",
        "mlflow.spanType": "UNKNOWN",
        "mlflow.spanInputs": {"t": 8},
        "mlflow.spanOutputs": 64,
    }


@pytest.mark.parametrize("wrap_sync_func", [True, False])
def test_trace_stream(wrap_sync_func):
    model = StreamTestModel() if wrap_sync_func else AsyncStreamTestModel()

    stream = model.predict_stream(1, 2)

    # Trace should not be logged until the generator is consumed
    assert get_traces() == []
    # The span should not be set to active
    # because the generator is not yet consumed
    assert mlflow.get_current_active_span() is None

    chunks = []
    if wrap_sync_func:
        for chunk in stream:
            chunks.append(chunk)
            # The `predict` span should not be active here.
            assert mlflow.get_current_active_span() is None
    else:

        async def consume_stream():
            async for chunk in stream:
                chunks.append(chunk)
                assert mlflow.get_current_active_span() is None

        asyncio.run(consume_stream())

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.trace_id is not None
    assert trace.info.experiment_id == _get_experiment_id()
    assert trace.info.execution_time_ms >= 0.1 * 1e3  # at least 0.1 sec
    assert trace.info.status == SpanStatusCode.OK
    metadata = trace.info.request_metadata
    assert metadata[TraceMetadataKey.INPUTS] == '{"x": 1, "y": 2}'
    assert metadata[TraceMetadataKey.OUTPUTS] == "11"  # sum of the outputs

    assert len(trace.data.spans) == 5  # 1 root span + 3 square + 1 generate_numbers

    root_span = trace.data.spans[0]
    assert root_span.name == "predict_stream"
    assert root_span.inputs == {"x": 1, "y": 2}
    assert root_span.outputs == 11
    assert len(root_span.events) == 9
    assert root_span.events[0].name == "mlflow.chunk.item.0"
    assert root_span.events[0].attributes == {"mlflow.chunk.value": "0"}
    assert root_span.events[8].name == "mlflow.chunk.item.8"

    # Spans for the child 'square' function
    for i in range(3):
        assert trace.data.spans[i + 1].name == "square"
        assert trace.data.spans[i + 1].inputs == {"t": i}
        assert trace.data.spans[i + 1].outputs == i**2
        assert trace.data.spans[i + 1].parent_id == root_span.span_id

    # Span for the 'generate_numbers' function
    assert trace.data.spans[4].name == "generate_numbers"
    assert trace.data.spans[4].inputs == {"z": 3}
    assert trace.data.spans[4].outputs == [0, 1, 2]  # list of outputs
    assert len(trace.data.spans[4].events) == 3


def test_trace_with_databricks_tracking_uri(databricks_tracking_uri, monkeypatch):
    monkeypatch.setenv("MLFLOW_EXPERIMENT_NAME", "test")
    monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "bob")
    monkeypatch.setattr(mlflow.tracking.context.default_context, "_get_source_name", lambda: "test")

    model = DefaultTestModel()

    mock_trace_info = mock.MagicMock()
    mock_trace_info.trace_id = "123"
    mock_trace_info.trace_location = mock.MagicMock()
    mock_trace_info.trace_location.uc_schema = None

    with (
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data"
        ) as mock_upload_trace_data,
        mock.patch("mlflow.tracing.client._get_store") as mock_get_store,
    ):
        mock_get_store().start_trace.return_value = mock_trace_info
        model.predict(2, 5)
        mlflow.flush_trace_async_logging(terminate=True)

    mock_get_store().start_trace.assert_called_once()
    mock_upload_trace_data.assert_called_once()


# NB: async logging should be no-op for model serving,
# but we test it here to make sure it doesn't break
@skip_when_testing_trace_sdk
def test_trace_in_databricks_model_serving(
    mock_databricks_serving_with_tracing_env, async_logging_enabled
):
    # Dummy flask app for prediction
    import flask

    from mlflow.pyfunc.context import Context, set_prediction_context

    app = flask.Flask(__name__)

    @app.route("/invocations", methods=["POST"])
    def predict():
        data = json.loads(flask.request.data.decode("utf-8"))
        request_id = flask.request.headers.get("X-Request-ID")

        with set_prediction_context(Context(request_id=request_id)):
            prediction = TestModel().predict(**data)

        trace = pop_trace(request_id=request_id)

        result = json.dumps(
            {
                "prediction": prediction,
                "trace": trace,
            },
            default=str,
        )
        return flask.Response(response=result, status=200, mimetype="application/json")

    class TestModel:
        @mlflow.trace()
        def predict(self, x, y):
            z = x + y
            z = self.add_one(z)
            with mlflow.start_span(name="square") as span:
                z = self.square(z)
                span.add_event(SpanEvent("event", 0, attributes={"foo": "bar"}))
            return z

        @mlflow.trace(span_type=SpanType.LLM, name="custom", attributes={"delta": 1})
        def add_one(self, z):
            return z + 1

        def square(self, t):
            return t**2

    # Mimic scoring request
    databricks_request_id = "request-12345"
    response = app.test_client().post(
        "/invocations",
        headers={"X-Request-ID": databricks_request_id},
        data=json.dumps({"x": 2, "y": 5}),
    )

    assert response.status_code == 200
    assert response.json["prediction"] == 64

    trace_dict = response.json["trace"]
    trace = Trace.from_dict(trace_dict)
    assert trace.info.trace_id.startswith("tr-")
    assert trace.info.client_request_id == databricks_request_id
    assert trace.info.request_metadata[TRACE_SCHEMA_VERSION_KEY] == "3"
    assert len(trace.data.spans) == 3

    span_name_to_span = {span.name: span for span in trace.data.spans}
    root_span = span_name_to_span["predict"]
    assert isinstance(root_span._trace_id, str)
    assert isinstance(root_span.span_id, str)
    assert isinstance(root_span.start_time_ns, int)
    assert isinstance(root_span.end_time_ns, int)
    assert root_span.status.status_code.value == "OK"
    assert root_span.status.description == ""
    assert root_span.attributes == {
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": SpanType.UNKNOWN,
        "mlflow.spanFunctionName": "predict",
        "mlflow.spanInputs": {"x": 2, "y": 5},
        "mlflow.spanOutputs": 64,
    }
    assert root_span.events == []

    child_span_1 = span_name_to_span["custom"]
    assert child_span_1.parent_id == root_span.span_id
    assert child_span_1.attributes == {
        "delta": 1,
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": SpanType.LLM,
        "mlflow.spanFunctionName": "add_one",
        "mlflow.spanInputs": {"z": 7},
        "mlflow.spanOutputs": 8,
    }
    assert child_span_1.events == []

    child_span_2 = span_name_to_span["square"]
    assert child_span_2.parent_id == root_span.span_id
    assert child_span_2.attributes == {
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": SpanType.UNKNOWN,
    }
    assert asdict(child_span_2.events[0]) == {
        "name": "event",
        "timestamp": 0,
        "attributes": {"foo": "bar"},
    }

    # The trace should be removed from the buffer after being retrieved
    assert pop_trace(request_id=databricks_request_id) is None

    # In model serving, the traces should not be stored in the fluent API buffer
    traces = get_traces()
    assert len(traces) == 0


@skip_when_testing_trace_sdk
def test_trace_in_model_evaluation(monkeypatch, async_logging_enabled):
    from mlflow.pyfunc.context import Context, set_prediction_context

    monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "bob")
    monkeypatch.setattr(mlflow.tracking.context.default_context, "_get_source_name", lambda: "test")

    class TestModel:
        @mlflow.trace()
        def predict(self, x, y):
            return x + y

    model = TestModel()

    # mock _upload_trace_data to avoid generating trace data file
    with mlflow.start_run() as run:
        run_id = run.info.run_id
        request_id_1 = "tr-eval-123"
        with set_prediction_context(Context(request_id=request_id_1, is_evaluate=True)):
            model.predict(1, 2)

        request_id_2 = "tr-eval-456"
        with set_prediction_context(Context(request_id=request_id_2, is_evaluate=True)):
            model.predict(3, 4)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    trace = mlflow.get_trace(request_id_1)
    assert trace.info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run_id
    assert trace.info.tags[TraceTagKey.EVAL_REQUEST_ID] == request_id_1

    trace = mlflow.get_trace(request_id_2)
    assert trace.info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run_id
    assert trace.info.tags[TraceTagKey.EVAL_REQUEST_ID] == request_id_2


@pytest.mark.parametrize("sync", [True, False])
def test_trace_handle_exception_during_prediction(sync):
    # This test is to make sure that the exception raised by the main prediction
    # logic is raised properly and the trace is still logged.
    model = ErroringTestModel() if sync else ErroringAsyncTestModel()

    with pytest.raises(ValueError, match=r"Some error"):
        model.predict(2, 5) if sync else asyncio.run(model.predict(2, 5))

    # Trace should be logged even if the function fails, with status code ERROR
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    assert trace.info.trace_id is not None
    assert trace.info.state == TraceState.ERROR
    assert trace.info.request_metadata[TraceMetadataKey.INPUTS] == '{"x": 2, "y": 5}'
    assert trace.info.request_metadata[TraceMetadataKey.OUTPUTS] == ""

    assert trace.data.request == '{"x": 2, "y": 5}'
    assert trace.data.response is None
    assert len(trace.data.spans) == 2


def test_trace_handle_exception_during_streaming():
    model = ErroringStreamTestModel()

    stream = model.predict_stream(2)

    chunks = []
    with pytest.raises(ValueError, match=r"Some error"):  # noqa: PT012
        for chunk in stream:
            chunks.append(chunk)

    # The test model raises an error after the first chunk
    assert len(chunks) == 1

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.state == TraceState.ERROR
    assert trace.info.request_metadata[TraceMetadataKey.INPUTS] == '{"x": 2}'

    # The test model is expected to produce three spans
    # 1. Root span (error - inherited from the child)
    # 2. First chunk span (OK)
    # 3. Second chunk span (error)
    spans = trace.data.spans
    assert len(spans) == 3
    assert spans[0].name == "predict_stream"
    assert spans[0].status.status_code == SpanStatusCode.ERROR
    assert spans[1].name == "some_operation_raise_error"
    assert spans[1].status.status_code == SpanStatusCode.OK
    assert spans[2].name == "some_operation_raise_error"
    assert spans[2].status.status_code == SpanStatusCode.ERROR

    # One chunk event + one exception event
    assert len(spans[0].events) == 2
    assert spans[0].events[0].name == "mlflow.chunk.item.0"
    assert spans[0].events[1].name == "exception"


@pytest.mark.parametrize(
    "model",
    [
        DefaultTestModel(),
        DefaultAsyncTestModel(),
        StreamTestModel(),
        AsyncStreamTestModel(),
    ],
)
def test_trace_ignore_exception(monkeypatch, model):
    # This test is to make sure that the main prediction logic is not affected
    # by the exception raised by the tracing logic.
    def _call_model_and_assert_output(model):
        if isinstance(model, DefaultTestModel):
            output = model.predict(2, 5)
            assert output == 64
        elif isinstance(model, DefaultAsyncTestModel):
            output = asyncio.run(model.predict(2, 5))
            assert output == 64
        elif isinstance(model, StreamTestModel):
            stream = model.predict_stream(2, 5)
            assert len(list(stream)) == 21
        elif isinstance(model, AsyncStreamTestModel):
            astream = model.predict_stream(2, 5)

            async def _consume_stream():
                return [chunk async for chunk in astream]

            stream = asyncio.run(_consume_stream())
            assert len(list(stream)) == 21
        else:
            raise ValueError("Unknown model type")

    # Exception during starting span: trace should not be logged.
    with mock.patch("mlflow.tracing.provider._get_tracer", side_effect=ValueError("Some error")):
        _call_model_and_assert_output(model)

    assert get_traces() == []

    # Exception during ending span: trace should not be logged.
    tracer = _get_tracer(__name__)

    def _always_fail(*args, **kwargs):
        raise ValueError("Some error")

    monkeypatch.setattr(tracer.span_processor, "on_end", _always_fail)
    _call_model_and_assert_output(model)
    assert len(get_traces()) == 0


def test_trace_skip_resolving_unrelated_tags_to_traces():
    with mock.patch("mlflow.tracking.context.registry.DatabricksRepoRunContext") as mock_context:
        mock_context.in_context.return_value = ["unrelated tags"]

        model = DefaultTestModel()
        model.predict(2, 5)

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    assert "unrelated tags" not in trace.info.tags


# Tracing SDK doesn't have `create_experiment` support
@skip_when_testing_trace_sdk
def test_trace_with_experiment_id():
    exp_1 = mlflow.create_experiment("exp_1")
    exp_2 = mlflow.set_experiment("exp_2").experiment_id  # active experiment

    @mlflow.trace(trace_destination=MlflowExperiment(exp_1))
    def predict_1():
        with mlflow.start_span(name="child_span"):
            return

    @mlflow.trace()
    def predict_2():
        pass

    predict_1()
    traces = get_traces(experiment_id=exp_1)
    assert len(traces) == 1
    assert traces[0].info.experiment_id == exp_1
    assert len(traces[0].data.spans) == 2
    assert get_traces(experiment_id=exp_2) == []

    predict_2()
    traces = get_traces(experiment_id=exp_2)
    assert len(traces) == 1
    assert traces[0].info.experiment_id == exp_2


# Tracing SDK doesn't have `create_experiment` support
@skip_when_testing_trace_sdk
def test_trace_with_experiment_id_issue_warning_when_not_root_span():
    exp_1 = mlflow.create_experiment("exp_1")

    @mlflow.trace(trace_destination=MlflowExperiment(exp_1))
    def predict_1():
        return predict_2()

    @mlflow.trace(trace_destination=MlflowExperiment(exp_1))
    def predict_2():
        return

    with mock.patch("mlflow.tracing.provider._logger") as mock_logger:
        predict_1()

    assert mock_logger.warning.call_count == 1
    assert mock_logger.warning.call_args[0][0] == (
        "The `experiment_id` parameter can only be used for root spans, but the span "
        "`predict_2` is not a root span. The specified value `1` will be ignored."
    )


def test_start_span_context_manager(async_logging_enabled):
    datetime_now = datetime.now()

    class TestModel:
        def predict(self, x, y):
            with mlflow.start_span(name="root_span") as root_span:
                root_span.set_inputs({"x": x, "y": y})
                z = x + y

                with mlflow.start_span(name="child_span", span_type=SpanType.LLM) as child_span:
                    child_span.set_inputs(z)
                    z = z + 2
                    child_span.set_outputs(z)
                    child_span.set_attributes({"delta": 2, "time": datetime_now})

                # Ensure deterministic span order on Windows by forcing different start_time_ns
                time.sleep(0.001)
                res = self.square(z)
                root_span.set_outputs(res)
            return res

        def square(self, t):
            with mlflow.start_span(name="child_span") as span:
                span.set_inputs({"t": t})
                res = t**2
                time.sleep(0.1)
                span.set_outputs(res)
                return res

    model = TestModel()
    model.predict(1, 2)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.trace_id is not None
    assert trace.info.experiment_id == _get_experiment_id()
    assert trace.info.execution_time_ms >= 0.1 * 1e3  # at least 0.1 sec
    assert trace.info.state == TraceState.OK
    assert trace.info.request_metadata[TraceMetadataKey.INPUTS] == '{"x": 1, "y": 2}'
    assert trace.info.request_metadata[TraceMetadataKey.OUTPUTS] == "25"

    assert trace.data.request == '{"x": 1, "y": 2}'
    assert trace.data.response == "25"
    assert len(trace.data.spans) == 3

    root_span = trace.data.spans[0]
    assert root_span.name == "root_span"
    assert root_span.parent_id is None
    assert root_span.attributes == {
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": "UNKNOWN",
        "mlflow.spanInputs": {"x": 1, "y": 2},
        "mlflow.spanOutputs": 25,
    }

    child_span_1 = trace.data.spans[1]
    assert child_span_1.name == "child_span"
    assert child_span_1.parent_id == root_span.span_id
    assert child_span_1.attributes == {
        "delta": 2,
        "time": str(datetime_now),
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": "LLM",
        "mlflow.spanInputs": 3,
        "mlflow.spanOutputs": 5,
    }

    child_span_2 = trace.data.spans[2]
    assert child_span_2.name == "child_span"
    assert child_span_2.parent_id == root_span.span_id
    assert child_span_2.attributes == {
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": "UNKNOWN",
        "mlflow.spanInputs": {"t": 5},
        "mlflow.spanOutputs": 25,
    }
    assert child_span_2.start_time_ns <= child_span_2.end_time_ns - 0.1 * 1e6


def test_start_span_context_manager_with_imperative_apis(async_logging_enabled):
    # This test is to make sure that the spans created with fluent APIs and imperative APIs
    # (via MLflow client) are correctly linked together. This usage is not recommended but
    # should be supported for the advanced use cases like using LangChain callbacks as a
    # part of broader tracing.
    class TestModel:
        def predict(self, x, y):
            with mlflow.start_span(name="root_span") as root_span:
                root_span.set_inputs({"x": x, "y": y})
                z = x + y

                child_span = start_span_no_context(
                    name="child_span_1",
                    span_type=SpanType.LLM,
                    parent_span=root_span,
                )
                child_span.set_inputs(z)

                z = z + 2
                time.sleep(0.1)

                child_span.set_outputs(z)
                child_span.set_attributes({"delta": 2})
                child_span.end()

                root_span.set_outputs(z)
            return z

    model = TestModel()
    model.predict(1, 2)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.trace_id is not None
    assert trace.info.experiment_id == _get_experiment_id()
    assert trace.info.execution_time_ms >= 0.1 * 1e3  # at least 0.1 sec
    assert trace.info.state == TraceState.OK
    assert trace.info.request_metadata[TraceMetadataKey.INPUTS] == '{"x": 1, "y": 2}'
    assert trace.info.request_metadata[TraceMetadataKey.OUTPUTS] == "5"

    assert trace.data.request == '{"x": 1, "y": 2}'
    assert trace.data.response == "5"
    assert len(trace.data.spans) == 2

    span_name_to_span = {span.name: span for span in trace.data.spans}
    root_span = span_name_to_span["root_span"]
    assert root_span.parent_id is None
    assert root_span.attributes == {
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": "UNKNOWN",
        "mlflow.spanInputs": {"x": 1, "y": 2},
        "mlflow.spanOutputs": 5,
    }

    child_span_1 = span_name_to_span["child_span_1"]
    assert child_span_1.parent_id == root_span.span_id
    assert child_span_1.attributes == {
        "delta": 2,
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": "LLM",
        "mlflow.spanInputs": 3,
        "mlflow.spanOutputs": 5,
    }


def test_mlflow_trace_isolated_from_other_otel_processors():
    # Set up non-MLFlow tracer
    import opentelemetry.sdk.trace as trace_sdk
    from opentelemetry import trace

    class MockOtelExporter(trace_sdk.export.SpanExporter):
        def __init__(self):
            self.exported_spans = []

        def export(self, spans):
            self.exported_spans.extend(spans)

    other_exporter = MockOtelExporter()
    provider = trace_sdk.TracerProvider()
    processor = trace_sdk.export.SimpleSpanProcessor(other_exporter)
    provider.add_span_processor(processor)
    trace.set_tracer_provider(provider)

    # Create MLflow trace
    with mlflow.start_span(name="mlflow_span"):
        pass

    # Create non-MLflow trace
    tracer = trace.get_tracer(__name__)
    with tracer.start_as_current_span("non_mlflow_span"):
        pass

    # MLflow only processes spans created with MLflow APIs
    assert len(get_traces()) == 1
    assert (
        mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True).data.spans[0].name
        == "mlflow_span"
    )

    # Other spans are processed by the other processor
    assert len(other_exporter.exported_spans) == 1
    assert other_exporter.exported_spans[0].name == "non_mlflow_span"


def test_get_trace():
    with mock.patch("mlflow.tracing.display.get_display_handler") as mock_get_display_handler:
        model = DefaultTestModel()
        model.predict(2, 5)

        trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
        trace_id = trace.info.trace_id
        mock_get_display_handler.reset_mock()

        # Fetch trace from in-memory buffer
        trace_in_memory = mlflow.get_trace(trace_id)
        assert trace.info.trace_id == trace_in_memory.info.trace_id
        mock_get_display_handler.assert_not_called()

        # Fetch trace from backend
        trace_from_backend = mlflow.get_trace(trace.info.trace_id)
        assert trace.info.trace_id == trace_from_backend.info.trace_id
        mock_get_display_handler.assert_not_called()

    # If not found, return None with warning
    with mock.patch("mlflow.tracing.fluent._logger") as mock_logger:
        assert mlflow.get_trace("not_found") is None
        mock_logger.warning.assert_called_once()


def test_test_search_traces_empty(mock_client):
    mock_client.search_traces.return_value = PagedList([], token=None)

    traces = mlflow.search_traces()
    assert len(traces) == 0

    if not IS_TRACING_SDK_ONLY:
        default_columns = Trace.pandas_dataframe_columns()
        assert traces.columns.tolist() == default_columns

        traces = mlflow.search_traces(extract_fields=["foo.inputs.bar"])
        assert traces.columns.tolist() == [*default_columns, "foo.inputs.bar"]

    mock_client.search_traces.assert_called()


@pytest.mark.parametrize("return_type", ["pandas", "list"])
def test_search_traces(return_type, mock_client):
    if return_type == "pandas" and IS_TRACING_SDK_ONLY:
        pytest.skip("Skipping test because mlflow or mlflow-skinny is not installed.")

    mock_client.search_traces.return_value = PagedList(
        [
            Trace(
                info=create_test_trace_info(f"tr-{i}"),
                data=TraceData([]),
            )
            for i in range(10)
        ],
        token=None,
    )

    traces = mlflow.search_traces(
        locations=["1"],
        filter_string="name = 'foo'",
        max_results=10,
        order_by=["timestamp DESC"],
        return_type=return_type,
    )

    if return_type == "pandas":
        import pandas as pd

        assert isinstance(traces, pd.DataFrame)
    else:
        assert isinstance(traces, list)
        assert all(isinstance(trace, Trace) for trace in traces)

    assert len(traces) == 10
    mock_client.search_traces.assert_called_once_with(
        experiment_ids=None,
        run_id=None,
        filter_string="name = 'foo'",
        max_results=10,
        order_by=["timestamp DESC"],
        page_token=None,
        model_id=None,
        include_spans=True,
        locations=["1"],
    )


def test_search_traces_invalid_return_types(mock_client):
    with pytest.raises(MlflowException, match=r"Invalid return type"):
        mlflow.search_traces(return_type="invalid")

    with pytest.raises(MlflowException, match=r"The `extract_fields`"):
        mlflow.search_traces(extract_fields=["foo.inputs.bar"], return_type="list")


def test_search_traces_validates_experiment_ids_type():
    with pytest.raises(MlflowException, match=r"locations must be a list"):
        mlflow.search_traces(locations=4)

    with pytest.raises(MlflowException, match=r"locations must be a list"):
        mlflow.search_traces(locations="4")


def test_search_traces_with_pagination(mock_client):
    traces = [
        Trace(
            info=create_test_trace_info(f"tr-{i}"),
            data=TraceData([]),
        )
        for i in range(30)
    ]

    mock_client.search_traces.side_effect = [
        PagedList(traces[:10], token="token-1"),
        PagedList(traces[10:20], token="token-2"),
        PagedList(traces[20:], token=None),
    ]

    traces = mlflow.search_traces(locations=["1"])

    assert len(traces) == 30
    common_args = {
        "experiment_ids": None,
        "run_id": None,
        "max_results": SEARCH_TRACES_DEFAULT_MAX_RESULTS,
        "filter_string": None,
        "order_by": None,
        "include_spans": True,
        "model_id": None,
        "locations": ["1"],
    }
    mock_client.search_traces.assert_has_calls([
        mock.call(**common_args, page_token=None),
        mock.call(**common_args, page_token="token-1"),
        mock.call(**common_args, page_token="token-2"),
    ])


def test_search_traces_with_default_experiment_id(mock_client):
    mock_client.search_traces.return_value = PagedList([], token=None)
    with mock.patch("mlflow.tracking.fluent._get_experiment_id", return_value="123"):
        mlflow.search_traces()

    mock_client.search_traces.assert_called_once_with(
        experiment_ids=None,
        run_id=None,
        filter_string=None,
        max_results=SEARCH_TRACES_DEFAULT_MAX_RESULTS,
        order_by=None,
        page_token=None,
        model_id=None,
        include_spans=True,
        locations=["123"],
    )


@skip_when_testing_trace_sdk
def test_search_traces_yields_expected_dataframe_contents(monkeypatch):
    model = DefaultTestModel()
    expected_traces = []
    for _ in range(10):
        model.predict(2, 5)
        time.sleep(0.1)

        trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
        expected_traces.append(trace)

    df = mlflow.search_traces(max_results=10, order_by=["timestamp ASC"], flush=True)
    assert df.columns.tolist() == [
        "trace_id",
        "trace",
        "client_request_id",
        "state",
        "request_time",
        "execution_duration",
        "request",
        "response",
        "trace_metadata",
        "tags",
        "spans",
        "assessments",
    ]
    for idx, trace in enumerate(expected_traces):
        assert df.iloc[idx].trace_id == trace.info.trace_id
        assert Trace.from_json(df.iloc[idx].trace).info.trace_id == trace.info.trace_id
        assert df.iloc[idx].client_request_id == trace.info.client_request_id
        assert df.iloc[idx].state == trace.info.state
        assert df.iloc[idx].request_time == trace.info.request_time
        assert df.iloc[idx].execution_duration == pytest.approx(
            trace.info.execution_duration, abs=1
        )
        assert df.iloc[idx].request == json.loads(trace.data.request)
        assert df.iloc[idx].response == json.loads(trace.data.response)
        assert df.iloc[idx].trace_metadata == trace.info.trace_metadata
        assert df.iloc[idx].spans == [s.to_dict() for s in trace.data.spans]
        assert df.iloc[idx].tags == trace.info.tags
        assert df.iloc[idx].assessments == trace.info.assessments


@skip_when_testing_trace_sdk
def test_search_traces_handles_missing_response_tags_and_metadata(mock_client):
    mock_client.search_traces.return_value = PagedList(
        [
            Trace(
                info=TraceInfo(
                    trace_id="5",
                    trace_location=TraceLocation.from_experiment_id("test"),
                    request_time=1,
                    execution_duration=2,
                    state=TraceState.OK,
                ),
                data=TraceData(spans=[]),
            )
        ],
        token=None,
    )

    df = mlflow.search_traces()
    assert df["response"].isnull().all()
    assert df["tags"].tolist() == [{}]
    assert df["trace_metadata"].tolist() == [{}]


@skip_when_testing_trace_sdk
def test_search_traces_extracts_fields_as_expected():
    model = DefaultTestModel()
    model.predict(2, 5)

    df = mlflow.search_traces(
        extract_fields=["predict.inputs.x", "predict.outputs", "add_one_with_custom_name.inputs.z"],
        flush=True,
    )
    assert df["predict.inputs.x"].tolist() == [2]
    assert df["predict.outputs"].tolist() == [64]
    assert df["add_one_with_custom_name.inputs.z"].tolist() == [7]


# no spans have the input or output with name,
# some span has an input but we're looking for output,
@skip_when_testing_trace_sdk
def test_search_traces_with_input_and_no_output():
    with mlflow.start_span(name="with_input_and_no_output") as span:
        span.set_inputs({"a": 1})

    df = mlflow.search_traces(
        extract_fields=["with_input_and_no_output.inputs.a", "with_input_and_no_output.outputs"],
        flush=True,
    )
    assert df["with_input_and_no_output.inputs.a"].tolist() == [1]
    assert df["with_input_and_no_output.outputs"].isnull().all()


@skip_when_testing_trace_sdk
def test_search_traces_with_non_dict_span_inputs_outputs():
    with mlflow.start_span(name="non_dict_span") as span:
        span.set_inputs(["a", "b"])
        span.set_outputs([1, 2, 3])

    df = mlflow.search_traces(
        extract_fields=["non_dict_span.inputs", "non_dict_span.outputs", "non_dict_span.inputs.x"],
        flush=True,
    )
    assert df["non_dict_span.inputs"].tolist() == [["a", "b"]]
    assert df["non_dict_span.outputs"].tolist() == [[1, 2, 3]]
    assert df["non_dict_span.inputs.x"].isnull().all()
test_search_traces_extract_fields_preserves_standard_columns(): 1141 with mlflow.start_span(name="test_span") as span: 1142 span.set_inputs({"x": 1}) 1143 span.set_outputs({"y": 2}) 1144 1145 df = mlflow.search_traces(extract_fields=["test_span.inputs.x"], flush=True) 1146 1147 # Verify standard columns still exist 1148 assert "trace_id" in df.columns 1149 assert "spans" in df.columns 1150 assert "tags" in df.columns 1151 assert "request" in df.columns 1152 assert "response" in df.columns 1153 1154 # Verify extract field was added 1155 assert "test_span.inputs.x" in df.columns 1156 assert df["test_span.inputs.x"].tolist() == [1] 1157 1158 1159 @skip_when_testing_trace_sdk 1160 def test_search_traces_with_multiple_spans_with_same_name(): 1161 class TestModel: 1162 @mlflow.trace(name="duplicate_name") 1163 def predict(self, x, y): 1164 z = x + y 1165 z = self.add_one(z) 1166 z = mlflow.trace(self.square)(z) 1167 return z # noqa: RET504 1168 1169 @mlflow.trace(span_type=SpanType.LLM, name="duplicate_name", attributes={"delta": 1}) 1170 def add_one(self, z): 1171 return z + 1 1172 1173 def square(self, t): 1174 res = t**2 1175 time.sleep(0.1) 1176 return res 1177 1178 model = TestModel() 1179 model.predict(2, 5) 1180 1181 df = mlflow.search_traces( 1182 extract_fields=[ 1183 "duplicate_name.inputs.x", 1184 "duplicate_name.inputs.y", 1185 "duplicate_name.inputs.z", 1186 ], 1187 flush=True, 1188 ) 1189 # Duplicate spans would all be null 1190 assert df["duplicate_name.inputs.x"].isnull().all() 1191 assert df["duplicate_name.inputs.y"].isnull().all() 1192 assert df["duplicate_name.inputs.z"].tolist() == [7] 1193 1194 1195 # Test a field that doesn't exist for extraction - we shouldn't throw, just return empty column 1196 @skip_when_testing_trace_sdk 1197 def test_search_traces_with_non_existent_field(): 1198 model = DefaultTestModel() 1199 model.predict(2, 5) 1200 1201 df = mlflow.search_traces( 1202 extract_fields=[ 1203 "predict.inputs.k", 1204 "predict.inputs.x", 1205 
"predict.outputs", 1206 "add_one_with_custom_name.inputs.z", 1207 ], 1208 flush=True, 1209 ) 1210 assert df["predict.inputs.k"].isnull().all() 1211 assert df["predict.inputs.x"].tolist() == [2] 1212 assert df["predict.outputs"].tolist() == [64] 1213 assert df["add_one_with_custom_name.inputs.z"].tolist() == [7] 1214 1215 1216 @skip_when_testing_trace_sdk 1217 def test_search_traces_span_and_field_name_with_dot(): 1218 with mlflow.start_span(name="span.name") as span: 1219 span.set_inputs({"a.b": 0}) 1220 span.set_outputs({"x.y": 1}) 1221 1222 df = mlflow.search_traces( 1223 extract_fields=[ 1224 "`span.name`.inputs", 1225 "`span.name`.inputs.`a.b`", 1226 "`span.name`.outputs", 1227 "`span.name`.outputs.`x.y`", 1228 ], 1229 flush=True, 1230 ) 1231 1232 assert df["span.name.inputs"].tolist() == [{"a.b": 0}] 1233 assert df["span.name.inputs.a.b"].tolist() == [0] 1234 assert df["span.name.outputs"].tolist() == [{"x.y": 1}] 1235 assert df["span.name.outputs.x.y"].tolist() == [1] 1236 1237 1238 @skip_when_testing_trace_sdk 1239 def test_search_traces_with_run_id(): 1240 def _create_trace(name, tags=None): 1241 with mlflow.start_span(name=name) as span: 1242 for k, v in (tags or {}).items(): 1243 mlflow.set_trace_tag(trace_id=span.request_id, key=k, value=v) 1244 return span.request_id 1245 1246 def _get_names(traces): 1247 tags = traces["tags"].tolist() 1248 return [tags[i].get(TraceTagKey.TRACE_NAME) for i in range(len(tags))] 1249 1250 with mlflow.start_run() as run1: 1251 _create_trace(name="tr-1") 1252 _create_trace(name="tr-2", tags={"fruit": "apple"}) 1253 1254 with mlflow.start_run() as run2: 1255 _create_trace(name="tr-3") 1256 _create_trace(name="tr-4", tags={"fruit": "banana"}) 1257 _create_trace(name="tr-5", tags={"fruit": "apple"}) 1258 1259 traces = mlflow.search_traces(flush=True) 1260 assert set(_get_names(traces)) == {"tr-5", "tr-4", "tr-3", "tr-2", "tr-1"} 1261 1262 traces = mlflow.search_traces(run_id=run1.info.run_id, flush=True) 1263 assert 
set(_get_names(traces)) == {"tr-2", "tr-1"} 1264 1265 traces = mlflow.search_traces( 1266 run_id=run2.info.run_id, 1267 filter_string="tag.fruit = 'apple'", 1268 flush=True, 1269 ) 1270 assert _get_names(traces) == ["tr-5"] 1271 1272 with pytest.raises(MlflowException, match="You cannot filter by run_id when it is already"): 1273 mlflow.search_traces( 1274 run_id=run2.info.run_id, 1275 filter_string="metadata.mlflow.sourceRun = '123'", 1276 ) 1277 1278 with pytest.raises(MlflowException, match=f"Run {run1.info.run_id} belongs to"): 1279 mlflow.search_traces(run_id=run1.info.run_id, locations=["1"]) 1280 1281 1282 @pytest.mark.parametrize( 1283 "extract_fields", 1284 [ 1285 ["span.llm.inputs"], 1286 ["span.llm.inputs.x"], 1287 ["span.llm.outputs"], 1288 ], 1289 ) 1290 @skip_when_testing_trace_sdk 1291 def test_search_traces_invalid_extract_fields(extract_fields): 1292 with pytest.raises(MlflowException, match="Invalid field type"): 1293 mlflow.search_traces(extract_fields=extract_fields) 1294 1295 1296 def test_get_last_active_trace_id(): 1297 assert mlflow.get_last_active_trace_id() is None 1298 1299 @mlflow.trace() 1300 def predict(x, y): 1301 return x + y 1302 1303 predict(1, 2) 1304 predict(2, 5) 1305 predict(3, 6) 1306 1307 trace_id = mlflow.get_last_active_trace_id() 1308 trace = mlflow.get_trace(trace_id, flush=True) 1309 assert trace.info.trace_id is not None 1310 assert trace.data.request == '{"x": 3, "y": 6}' 1311 1312 # Mutation of the copy should not affect the original trace logged in the backend 1313 trace.info.state = TraceState.ERROR 1314 original_trace = mlflow.get_trace(trace.info.trace_id) 1315 assert original_trace.info.state == TraceState.OK 1316 1317 1318 def test_get_last_active_trace_thread_local(): 1319 assert mlflow.get_last_active_trace_id() is None 1320 1321 def run(id): 1322 @mlflow.trace(name=f"predict_{id}") 1323 def predict(x, y): 1324 return x + y 1325 1326 predict(1, 2) 1327 1328 return 
mlflow.get_last_active_trace_id(thread_local=True) 1329 1330 with ThreadPoolExecutor( 1331 max_workers=4, thread_name_prefix="test-tracing-fluent-last-active" 1332 ) as executor: 1333 futures = [executor.submit(run, i) for i in range(10)] 1334 trace_ids = [future.result() for future in futures] 1335 1336 assert len(trace_ids) == 10 1337 for i, trace_id in enumerate(trace_ids): 1338 trace = mlflow.get_trace(trace_id, flush=True) 1339 assert trace.info.state == TraceState.OK 1340 assert trace.data.spans[0].name == f"predict_{i}" 1341 1342 1343 def test_trace_with_classmethod(): 1344 class TestModel: 1345 @mlflow.trace 1346 @classmethod 1347 def predict(cls, x, y): 1348 return x + y 1349 1350 # Call the classmethod 1351 result = TestModel.predict(1, 2) 1352 assert result == 3 1353 1354 # Get the last trace and verify inputs and outputs 1355 trace_id = mlflow.get_last_active_trace_id() 1356 assert trace_id is not None 1357 1358 trace = mlflow.get_trace(trace_id, flush=True) 1359 assert trace is not None 1360 assert len(trace.data.spans) > 0 1361 1362 # The first span should be our traced function 1363 span = trace.data.spans[0] 1364 assert span.name == "predict" 1365 assert span.inputs == {"x": 1, "y": 2} 1366 assert span.outputs == 3 1367 1368 1369 def test_trace_with_classmethod_order_reversed(): 1370 class TestModel: 1371 @classmethod 1372 @mlflow.trace 1373 def predict(cls, x, y): 1374 return x + y 1375 1376 # Call the classmethod 1377 result = TestModel.predict(1, 2) 1378 assert result == 3 1379 1380 # Get the last trace and verify inputs and outputs 1381 trace_id = mlflow.get_last_active_trace_id() 1382 assert trace_id is not None 1383 1384 trace = mlflow.get_trace(trace_id, flush=True) 1385 assert trace is not None 1386 assert len(trace.data.spans) > 0 1387 1388 # The first span should be our traced function 1389 span = trace.data.spans[0] 1390 assert span.name == "predict" 1391 assert span.inputs == {"x": 1, "y": 2} 1392 assert span.outputs == 3 1393 1394 1395 
def test_trace_with_staticmethod():
    """@mlflow.trace stacked above @staticmethod records inputs/outputs correctly."""
    class TestModel:
        @mlflow.trace
        @staticmethod
        def predict(x, y):
            return x + y

    # Call the staticmethod
    result = TestModel.predict(1, 2)
    assert result == 3

    # Get the last trace and verify inputs and outputs
    trace_id = mlflow.get_last_active_trace_id()
    assert trace_id is not None

    trace = mlflow.get_trace(trace_id, flush=True)
    assert trace is not None
    assert len(trace.data.spans) > 0

    # The first span should be our traced function
    span = trace.data.spans[0]
    assert span.name == "predict"
    assert span.inputs == {"x": 1, "y": 2}
    assert span.outputs == 3


def test_trace_with_staticmethod_order_reversed():
    """@staticmethod stacked above @mlflow.trace behaves the same as the other order."""
    class TestModel:
        @staticmethod
        @mlflow.trace
        def predict(x, y):
            return x + y

    # Call the staticmethod
    result = TestModel.predict(1, 2)
    assert result == 3

    # Get the last trace and verify inputs and outputs
    trace_id = mlflow.get_last_active_trace_id()
    assert trace_id is not None

    trace = mlflow.get_trace(trace_id, flush=True)
    assert trace is not None
    assert len(trace.data.spans) > 0

    # The first span should be our traced function
    span = trace.data.spans[0]
    assert span.name == "predict"
    assert span.inputs == {"x": 1, "y": 2}
    assert span.outputs == 3


def test_update_current_trace():
    """Tags set by update_current_trace at any span depth merge onto the single trace; later values win."""
    @mlflow.trace(name="root_function")
    def f(x):
        mlflow.update_current_trace(tags={"fruit": "apple", "animal": "dog"})
        return g(x) + 1

    @mlflow.trace(name="level_1_function")
    def g(y):
        with mlflow.start_span(name="level_2_span"):
            # Overwrites the "fruit" tag set at the root level.
            mlflow.update_current_trace(tags={"fruit": "orange", "vegetable": "carrot"})
            return h(y) * 2

    @mlflow.trace(name="level_3_function")
    def h(z):
        with mlflow.start_span(name="level_4_span"):
            with mlflow.start_span(name="level_5_span"):
                mlflow.update_current_trace(tags={"depth": "deep", "level": "5"})
                return z + 10

    f(1)

    expected_tags = {
        "animal": "dog",
        "fruit": "orange",
        "vegetable": "carrot",
        "depth": "deep",
        "level": "5",
    }

    # Validate in-memory trace
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    assert trace.info.state == TraceState.OK
    tags = {k: v for k, v in trace.info.tags.items() if not k.startswith("mlflow.")}
    assert tags == expected_tags

    # Validate backend trace
    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.state == TraceState.OK
    tags = {k: v for k, v in traces[0].info.tags.items() if not k.startswith("mlflow.")}
    assert tags == expected_tags

    # Verify trace can be searched by span names (only when database backend is available)
    if not IS_TRACING_SDK_ONLY:
        trace_by_root_span = mlflow.search_traces(
            filter_string='span.name = "root_function"', return_type="list", flush=True
        )
        assert len(trace_by_root_span) == 1

        trace_by_level_2_span = mlflow.search_traces(
            filter_string='span.name = "level_2_span"', return_type="list", flush=True
        )
        assert len(trace_by_level_2_span) == 1

        trace_by_level_5_span = mlflow.search_traces(
            filter_string='span.name = "level_5_span"', return_type="list", flush=True
        )
        assert len(trace_by_level_5_span) == 1

        # All searches should return the same trace
        assert trace_by_root_span[0].info.request_id == trace.info.request_id
        assert trace_by_level_2_span[0].info.request_id == trace.info.request_id
        assert trace_by_level_5_span[0].info.request_id == trace.info.request_id


def test_update_current_trace_with_client_request_id():
    """client_request_id and tags can be set together or independently while the span is live."""
    from mlflow.tracing.trace_manager import InMemoryTraceManager

    # Test updating during span execution
    with mlflow.start_span("test_span") as span:
        # Update with both tags and client_request_id
        mlflow.update_current_trace(tags={"operation": "test"}, client_request_id="req-12345")

        # Check in-memory trace during execution
        trace_manager = InMemoryTraceManager.get_instance()
        with trace_manager.get_trace(span.trace_id) as trace:
            assert trace.info.client_request_id == "req-12345"
            tags = {k: v for k, v in trace.info.tags.items() if not k.startswith("mlflow.")}
            assert tags["operation"] == "test"

    # Test with tags only
    with mlflow.start_span("test_span_2") as span:
        mlflow.update_current_trace(tags={"operation": "tags_only"})

        trace_manager = InMemoryTraceManager.get_instance()
        with trace_manager.get_trace(span.trace_id) as trace:
            # Untouched client_request_id stays None.
            assert trace.info.client_request_id is None
            tags = {k: v for k, v in trace.info.tags.items() if not k.startswith("mlflow.")}
            assert tags["operation"] == "tags_only"

    # Test with client_request_id only
    with mlflow.start_span("test_span_3") as span:
        mlflow.update_current_trace(client_request_id="req-67890")

        trace_manager = InMemoryTraceManager.get_instance()
        with trace_manager.get_trace(span.trace_id) as trace:
            assert trace.info.client_request_id == "req-67890"


def test_update_current_trace_client_request_id_overwrites():
    """A second update_current_trace call replaces the previously set client_request_id."""
    from mlflow.tracing.trace_manager import InMemoryTraceManager

    with mlflow.start_span("overwrite_test") as span:
        # First set
        mlflow.update_current_trace(client_request_id="req-initial")

        # Overwrite with new value
        mlflow.update_current_trace(client_request_id="req-updated")

        # Check during execution
        trace_manager = InMemoryTraceManager.get_instance()
        with trace_manager.get_trace(span.trace_id) as trace:
            # Should have the updated value, not the initial one
            assert trace.info.client_request_id == "req-updated"


def test_update_current_trace_client_request_id_stringification():
    """Non-string client_request_id values are coerced via str(); None is a no-op."""
    from mlflow.tracing.trace_manager import InMemoryTraceManager

    # (input value, expected stored string)
    test_cases = [
        (123, "123"),
        (45.67, "45.67"),
        (True, "True"),
        (False, "False"),
        (None, None),  # None should remain None
        (["list", "value"], "['list', 'value']"),
        ({"dict": "value"}, "{'dict': 'value'}"),
    ]

    for input_value, expected_output in test_cases:
        with mlflow.start_span(f"stringification_test_{input_value}") as span:
            if input_value is None:
                # None should not update the client_request_id
                mlflow.update_current_trace(client_request_id=input_value)
                trace_manager = InMemoryTraceManager.get_instance()
                with trace_manager.get_trace(span.trace_id) as trace:
                    assert trace.info.client_request_id is None
            else:
                mlflow.update_current_trace(client_request_id=input_value)
                trace_manager = InMemoryTraceManager.get_instance()
                with trace_manager.get_trace(span.trace_id) as trace:
                    assert trace.info.client_request_id == expected_output
                    assert isinstance(trace.info.client_request_id, str)


def test_update_current_trace_with_metadata():
    """update_current_trace(metadata=...) stores metadata, stringifying non-string values."""
    @mlflow.trace
    def f():
        mlflow.update_current_trace(
            metadata={
                "mlflow.source.name": "inference.py",
                "mlflow.source.git.commit": "1234567890",
                "mlflow.source.git.repoURL": "https://github.com/mlflow/mlflow",
                "non-string-metadata": 123,
            },
        )

    f()

    expected_metadata = {
        "mlflow.source.name": "inference.py",
        "mlflow.source.git.commit": "1234567890",
        "mlflow.source.git.repoURL": "https://github.com/mlflow/mlflow",
        "non-string-metadata": "123",  # Should be stringified
    }

    # Validate in-memory trace
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    for k, v in expected_metadata.items():
        assert trace.info.trace_metadata[k] == v

    # Validate backend trace
    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.status == "OK"
    for k, v in expected_metadata.items():
        assert traces[0].info.trace_metadata[k] == v


@skip_when_testing_trace_sdk
def test_update_current_trace_with_model_id():
    """model_id passed to update_current_trace lands in the MODEL_ID trace metadata key."""
    with mlflow.start_span("test_span"):
        mlflow.update_current_trace(model_id="model-123")

    trace = get_traces()[0]
    assert trace.info.trace_metadata[TraceMetadataKey.MODEL_ID] == "model-123"


@skip_when_testing_trace_sdk
def test_update_current_trace_should_not_raise_during_model_logging():
    """
    Tracing is disabled while model logging. When the model includes
    `update_current_trace` call, it should be no-op.
    """

    class MyModel(mlflow.pyfunc.PythonModel):
        @mlflow.trace
        def predict(self, model_inputs):
            mlflow.update_current_trace(tags={"fruit": "apple"})
            return [model_inputs[0] + 1]

    model = MyModel()

    # Normal prediction produces a tagged trace.
    model.predict([1])
    trace = get_traces()[0]
    assert trace.info.state == "OK"
    assert trace.info.tags["fruit"] == "apple"
    purge_traces()

    model_info = mlflow.pyfunc.log_model(
        python_model=model,
        name="model",
        input_example=[0],
    )
    # Trace should not be generated while logging the model
    assert get_traces() == []

    # Signature should be inferred properly without raising any exception
    assert model_info.signature is not None
    assert model_info.signature.inputs is not None
    assert model_info.signature.outputs is not None

    # Loading back the model
    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
    loaded_model.predict([1])
    trace = get_traces()[0]
    assert trace.info.status == "OK"
    assert trace.info.tags["fruit"] == "apple"


def test_update_current_trace_with_state():
    """Trace state can be set with a TraceState enum or its string name, alone or with other fields."""
    from mlflow.tracing.trace_manager import InMemoryTraceManager

    # Test with TraceState enum
    with mlflow.start_span("test_span") as span:
        mlflow.update_current_trace(state=TraceState.ERROR)

        trace_manager = InMemoryTraceManager.get_instance()
        with trace_manager.get_trace(span.trace_id) as trace:
            assert trace.info.state == TraceState.ERROR

    # Test with string state
    with mlflow.start_span("test_span_2") as span:
        mlflow.update_current_trace(state="OK")

        trace_manager = InMemoryTraceManager.get_instance()
        with trace_manager.get_trace(span.trace_id) as trace:
            assert trace.info.state == TraceState.OK

    # Test with combined parameters
    with mlflow.start_span("test_span_3") as span:
        mlflow.update_current_trace(
            state="ERROR", tags={"error_type": "validation"}, client_request_id="req-123"
        )

        trace_manager = InMemoryTraceManager.get_instance()
        with trace_manager.get_trace(span.trace_id) as trace:
            assert trace.info.state == TraceState.ERROR
            assert trace.info.tags["error_type"] == "validation"
            assert trace.info.client_request_id == "req-123"


def test_update_current_trace_state_none():
    """state=None leaves the previously set state untouched while other fields still apply."""
    from mlflow.tracing.trace_manager import InMemoryTraceManager

    with mlflow.start_span("test_span") as span:
        # First set state to OK
        mlflow.update_current_trace(state="OK")

        # Then call with state=None - should not change state
        mlflow.update_current_trace(state=None, tags={"test": "value"})

        trace_manager = InMemoryTraceManager.get_instance()
        with trace_manager.get_trace(span.trace_id) as trace:
            assert trace.info.state == TraceState.OK
            assert trace.info.tags["test"] == "value"


def test_update_current_trace_state_validation():
    """Only 'OK' / 'ERROR' (string or enum) are accepted; anything else raises MlflowException."""
    with mlflow.start_span("test_span"):
        # Valid states should work
        mlflow.update_current_trace(state="OK")
        mlflow.update_current_trace(state="ERROR")
        mlflow.update_current_trace(state=TraceState.OK)
        mlflow.update_current_trace(state=TraceState.ERROR)

        # Invalid string state should raise an exception
        with pytest.raises(
            MlflowException, match=r"State must be either 'OK' or 'ERROR', but got 'IN_PROGRESS'"
        ):
            mlflow.update_current_trace(state="IN_PROGRESS")

        # Invalid enum state should raise an exception
        with pytest.raises(
            MlflowException,
            match=r"State must be either 'OK' or 'ERROR', but got 'STATE_UNSPECIFIED'",
        ):
            mlflow.update_current_trace(state=TraceState.STATE_UNSPECIFIED)

        # Custom invalid string should raise an exception
        with pytest.raises(
            MlflowException, match=r"State must be either 'OK' or 'ERROR', but got 'CUSTOM_STATE'"
        ):
            mlflow.update_current_trace(state="CUSTOM_STATE")

        # Invalid types should raise an exception with a proper error message
        with pytest.raises(
            MlflowException, match=r"State must be either 'OK' or 'ERROR', but got '123'"
        ):
            mlflow.update_current_trace(state=123)


def test_span_record_exception_with_string():
    """record_exception(str) marks the span ERROR and attaches an exception event with the message."""
    with mlflow.start_span("test_span") as span:
        span.record_exception("Something went wrong")

    # Check persisted trace
    trace = get_traces()[0]
    spans = trace.data.spans
    test_span = spans[0]

    # Verify span status is ERROR
    assert test_span.status.status_code == SpanStatusCode.ERROR

    # Verify exception event was added
    exception_events = [event for event in test_span.events if "exception" in event.name.lower()]
    assert len(exception_events) == 1

    # Verify exception message is in the event
    exception_event = exception_events[0]
    assert "Something went wrong" in str(exception_event.attributes)


def test_span_record_exception_with_exception():
    """record_exception(Exception) records the exception type and message on the span event."""
    test_exception = ValueError("Custom error message")

    with mlflow.start_span("test_span") as span:
        span.record_exception(test_exception)

    # Check persisted trace
    trace = get_traces()[0]
    spans = trace.data.spans
    test_span = spans[0]

    # Verify span status is ERROR
    assert test_span.status.status_code == SpanStatusCode.ERROR

    # Verify exception event was added with proper exception details
    exception_events = [event for event in test_span.events if "exception" in event.name.lower()]
    assert len(exception_events) == 1

    exception_event = exception_events[0]
    event_attrs = str(exception_event.attributes)
    assert "ValueError" in event_attrs
    assert "Custom error message" in event_attrs


def test_span_record_exception_invalid_type():
    """record_exception rejects values that are neither an Exception nor a string."""
    with mlflow.start_span("test_span") as span:
        with pytest.raises(
            MlflowException,
            match="The `exception` parameter must be an Exception instance or a string",
        ):
            span.record_exception(123)


def test_combined_state_and_record_exception():
    """Span-level record_exception and trace-level state update act independently but compose."""
    @mlflow.trace
    def test_function():
        # Get current span and record exception
        span = mlflow.get_current_active_span()
        span.record_exception("Processing failed")

        # Update trace state independently
        mlflow.update_current_trace(state="ERROR", tags={"error_source": "processing"})
        return "result"

    test_function()

    # Check the trace
    trace = get_traces()[0]

    # Verify trace state was set to ERROR
    assert trace.info.state == TraceState.ERROR
    assert trace.info.tags["error_source"] == "processing"

    # Verify span has exception event and ERROR state
    spans = trace.data.spans
    root_span = spans[0]
    assert root_span.status.status_code == SpanStatusCode.ERROR

    exception_events = [event for event in root_span.events if "exception" in event.name.lower()]
    assert len(exception_events) == 1
    assert "Processing failed" in str(exception_events[0].attributes)


def test_span_record_exception_no_op_span():
    """record_exception on a NoOpSpan neither raises nor produces a trace."""
    # This should not raise an exception
    from mlflow.entities.span import NoOpSpan

    no_op_span = NoOpSpan()
    no_op_span.record_exception("This should be ignored")

    # Should not create any traces
    assert get_traces() == []


def test_update_current_trace_state_isolation():
    """Setting the trace state to ERROR does not alter an explicitly set span status."""
    with mlflow.start_span("test_span") as span:
        # Set span status to OK explicitly
        span.set_status("OK")

        # Update trace state to ERROR
        mlflow.update_current_trace(state="ERROR")

        # Span status should still be OK
        assert span.status.status_code == SpanStatusCode.OK

    # Check the final persisted trace
    trace = get_traces()[0]
    assert trace.info.state == TraceState.ERROR

    # Verify span status remained OK despite trace state being ERROR
    spans = trace.data.spans
    test_span = spans[0]
    assert test_span.status.status_code == SpanStatusCode.OK


@skip_when_testing_trace_sdk
def test_non_ascii_characters_not_encoded_as_unicode():
    """Non-ASCII span inputs round-trip as the original characters, not \\uXXXX escapes."""
    with mlflow.start_span() as span:
        span.set_inputs({"japanese": "あ", "emoji": "👍"})

    trace = mlflow.get_trace(span.trace_id, flush=True)
    span = trace.data.spans[0]
    assert span.inputs == {"japanese": "あ", "emoji": "👍"}


# Serialized trace (schema v2) as returned by a hypothetical remote service:
# a root "remote" span with one "remote-child" span. Used by the add_trace tests below.
_SAMPLE_REMOTE_TRACE = {
    "info": {
        "request_id": "2e72d64369624e6888324462b62dc120",
        "experiment_id": "0",
        "timestamp_ms": 1726145090860,
        "execution_time_ms": 162,
        "status": "OK",
        "request_metadata": {
            "mlflow.trace_schema.version": "2",
            "mlflow.traceInputs": '{"x": 1}',
            "mlflow.traceOutputs": '{"prediction": 1}',
        },
        "tags": {
            "fruit": "apple",
            "food": "pizza",
        },
    },
    "data": {
        "spans": [
            {
                "name": "remote",
                "context": {
                    "span_id": "0x337af925d6629c01",
                    "trace_id": "0x05e82d1fc4486f3986fae6dd7b5352b1",
                },
                "parent_id": None,
                "start_time": 1726145091022155863,
                "end_time": 1726145091022572053,
                "status_code": "OK",
                "status_message": "",
                "attributes": {
                    "mlflow.traceRequestId": '"2e72d64369624e6888324462b62dc120"',
                    "mlflow.spanType": '"UNKNOWN"',
                    "mlflow.spanInputs": '{"x": 1}',
                    "mlflow.spanOutputs": '{"prediction": 1}',
                },
                "events": [
                    {"name": "event", "timestamp": 1726145091022287, "attributes": {"foo": "bar"}}
                ],
            },
            {
                "name": "remote-child",
                "context": {
                    "span_id": "0xa3dde9f2ebac1936",
                    "trace_id": "0x05e82d1fc4486f3986fae6dd7b5352b1",
                },
                "parent_id": "0x337af925d6629c01",
                "start_time": 1726145091022419340,
                "end_time": 1726145091022497944,
                "status_code": "OK",
                "status_message": "",
                "attributes": {
                    "mlflow.traceRequestId": '"2e72d64369624e6888324462b62dc120"',
                    "mlflow.spanType": '"UNKNOWN"',
                },
                "events": [],
            },
        ],
        "request": '{"x": 1}',
        "response": '{"prediction": 1}',
    },
}


def test_add_trace(mock_otel_trace_start_time):
    """add_trace merges a remote trace's spans under the currently active trace."""
    # Mimic a remote service call that returns a trace as a part of the response
    def dummy_remote_call():
        return {"prediction": 1, "trace": _SAMPLE_REMOTE_TRACE}

    @mlflow.trace
    def predict(add_trace: bool):
        resp = dummy_remote_call()

        if add_trace:
            mlflow.add_trace(resp["trace"])
        return resp["prediction"]

    # If we don't call add_trace, the trace from the remote service should be discarded
    predict(add_trace=False)
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    assert len(trace.data.spans) == 1

    # If we call add_trace, the trace from the remote service should be merged
    predict(add_trace=True)
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    trace_id = trace.info.trace_id
    assert trace_id is not None
    assert trace.data.request == '{"add_trace": true}'
    assert trace.data.response == "1"
    # Remote spans should be merged
    assert len(trace.data.spans) == 3
    assert all(span.trace_id == trace_id for span in trace.data.spans)
    parent_span, child_span, grandchild_span = trace.data.spans
    assert child_span.parent_id == parent_span.span_id
    assert child_span._trace_id == parent_span._trace_id
    assert grandchild_span.parent_id == child_span.span_id
    assert grandchild_span._trace_id == parent_span._trace_id
    # Check if span information is correctly copied
    rs = Trace.from_dict(_SAMPLE_REMOTE_TRACE).data.spans[0]
    assert child_span.name == rs.name
    assert child_span.start_time_ns == rs.start_time_ns
    assert child_span.end_time_ns == rs.end_time_ns
    assert child_span.status == rs.status
    assert child_span.span_type == rs.span_type
    assert child_span.events == rs.events
    # exclude request ID attribute from comparison
    for k in rs.attributes.keys() - {SpanAttributeKey.REQUEST_ID}:
        assert child_span.attributes[k] == rs.attributes[k]


def test_add_trace_no_current_active_trace():
    """Without an active trace, add_trace synthesizes a 'Remote Trace <...>' root span as parent."""
    # Use the remote trace without any active trace
    remote_trace = Trace.from_dict(_SAMPLE_REMOTE_TRACE)

    mlflow.add_trace(remote_trace)

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    assert len(trace.data.spans) == 3
    parent_span, child_span, grandchild_span = trace.data.spans
    assert parent_span.name == "Remote Trace <remote>"
    rs = remote_trace.data.spans[0]
    # Synthetic root starts 1ns before the remote root so ordering is preserved.
    assert parent_span.start_time_ns == rs.start_time_ns - 1
    assert parent_span.end_time_ns == rs.end_time_ns
    assert child_span.name == rs.name
    # NOTE(review): `is` compares identity, not equality — works here only if the span ids
    # are the same object; `==` would be the safer comparison. Confirm intent.
    assert child_span.parent_id is parent_span.span_id
    assert child_span.start_time_ns == rs.start_time_ns
    assert child_span.end_time_ns == rs.end_time_ns
    assert child_span.status == rs.status
    assert child_span.span_type == rs.span_type
    assert child_span.events == rs.events
    assert grandchild_span.parent_id == child_span.span_id
    # exclude request ID attribute from comparison
    for k in rs.attributes.keys() - {SpanAttributeKey.REQUEST_ID}:
        assert child_span.attributes[k] == rs.attributes[k]


def test_add_trace_specific_target_span(mock_otel_trace_start_time):
    """add_trace(target=span) attaches the remote spans under the given span instead of the active one."""
    span = start_span_no_context(name="parent")
    mlflow.add_trace(_SAMPLE_REMOTE_TRACE, target=span)
    span.end()

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    assert len(trace.data.spans) == 3
    parent_span, child_span, grandchild_span = trace.data.spans
    assert parent_span.span_id == span.span_id
    rs = Trace.from_dict(_SAMPLE_REMOTE_TRACE).data.spans[0]
    assert child_span.name == rs.name
    # NOTE(review): identity comparison (`is`) — see note in the previous test; confirm intent.
    assert child_span.parent_id is parent_span.span_id
    assert grandchild_span.parent_id == child_span.span_id


def test_add_trace_merge_tags():
    """Remote trace tags merge into the parent trace; on conflict the parent's value wins."""
    client = TracingClient()

    # Start the parent trace and merge the above trace as a child
    with mlflow.start_span(name="parent") as span:
        client.set_trace_tag(span.trace_id, "vegetable", "carrot")
        client.set_trace_tag(span.trace_id, "food", "sushi")

        mlflow.add_trace(Trace.from_dict(_SAMPLE_REMOTE_TRACE))

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    custom_tags = {k: v for k, v in trace.info.tags.items() if not k.startswith("mlflow.")}
    assert custom_tags == {
        "fruit": "apple",
        "vegetable": "carrot",
        # Tag value from the parent trace should prevail
        "food": "sushi",
    }


def test_add_trace_raise_for_invalid_trace():
    """add_trace rejects None, malformed dicts, in-progress traces, and mis-ordered spans."""
    with pytest.raises(MlflowException, match="Invalid trace object"):
        mlflow.add_trace(None)

    with pytest.raises(MlflowException, match="Failed to load a trace object"):
        mlflow.add_trace({"info": {}, "data": {}})

    in_progress_trace = Trace(
        info=TraceInfo(
            trace_id="123",
            trace_location=TraceLocation.from_experiment_id("0"),
            request_time=0,
            execution_duration=0,
            state=TraceState.IN_PROGRESS,
        ),
        data=TraceData(),
    )
    with pytest.raises(MlflowException, match="The trace must be ended"):
        mlflow.add_trace(in_progress_trace)

    # Child span listed before its parent: parent-before-child ordering is required.
    trace = Trace.from_dict(_SAMPLE_REMOTE_TRACE)
    spans = trace.data.spans
    unordered_trace = Trace(info=trace.info, data=TraceData(spans=[spans[1], spans[0]]))
    with pytest.raises(MlflowException, match="Span with ID "):
        mlflow.add_trace(unordered_trace)


@skip_when_testing_trace_sdk
def test_add_trace_in_databricks_model_serving(mock_databricks_serving_with_tracing_env):
    """In Databricks serving, merged traces are routed to the inference table keyed by request id."""
    from mlflow.pyfunc.context import Context, set_prediction_context

    # Mimic a remote service call that returns a trace as a part of the response
    def dummy_remote_call():
        return {"prediction": 1, "trace": _SAMPLE_REMOTE_TRACE}

    # The parent function that invokes the dummy remote service
    @mlflow.trace
    def predict():
        resp = dummy_remote_call()
        remote_trace = Trace.from_dict(resp["trace"])
        mlflow.add_trace(remote_trace)
        return resp["prediction"]

    db_request_id = "databricks-request-id"
    with set_prediction_context(Context(request_id=db_request_id)):
        predict()

    # Pop the trace to be written to the inference table
    trace = Trace.from_dict(pop_trace(request_id=db_request_id))

    assert trace.info.trace_id.startswith("tr-")
    assert trace.info.client_request_id == db_request_id
    assert len(trace.data.spans) == 3
    assert all(span.trace_id == trace.info.trace_id for span in trace.data.spans)
    parent_span, child_span, grandchild_span = trace.data.spans
    assert child_span.parent_id == parent_span.span_id
    assert child_span._trace_id == parent_span._trace_id
    assert grandchild_span.parent_id == child_span.span_id
    assert 
grandchild_span._trace_id == parent_span._trace_id 2115 # Check if span information is correctly copied 2116 rs = Trace.from_dict(_SAMPLE_REMOTE_TRACE).data.spans[0] 2117 assert child_span.name == rs.name 2118 assert child_span.start_time_ns == rs.start_time_ns 2119 assert child_span.end_time_ns == rs.end_time_ns 2120 2121 2122 @skip_when_testing_trace_sdk 2123 def test_add_trace_logging_model_from_code(): 2124 with mlflow.start_run(): 2125 model_info = mlflow.pyfunc.log_model( 2126 name="model", 2127 python_model="tests/tracing/sample_code/model_with_add_trace.py", 2128 input_example=[1, 2], 2129 ) 2130 2131 loaded_model = mlflow.pyfunc.load_model(model_info.model_uri) 2132 # Trace should not be logged while logging / loading 2133 assert mlflow.get_trace(mlflow.get_last_active_trace_id()) is None 2134 2135 loaded_model.predict(1) 2136 trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True) 2137 assert trace is not None 2138 assert len(trace.data.spans) == 2 2139 2140 2141 @pytest.mark.parametrize( 2142 "inputs", [{"question": "Does mlflow support tracing?"}, "Does mlflow support tracing?", None] 2143 ) 2144 @pytest.mark.parametrize("outputs", [{"answer": "Yes"}, "Yes", None]) 2145 @pytest.mark.parametrize( 2146 "intermediate_outputs", 2147 [ 2148 { 2149 "retrieved_documents": ["mlflow documentation"], 2150 "system_prompt": ["answer the question with yes or no"], 2151 }, 2152 None, 2153 ], 2154 ) 2155 def test_log_trace_success(inputs, outputs, intermediate_outputs): 2156 start_time_ms = 1736144700 2157 execution_time_ms = 5129 2158 2159 mlflow.log_trace( 2160 name="test", 2161 request=inputs, 2162 response=outputs, 2163 intermediate_outputs=intermediate_outputs, 2164 start_time_ms=start_time_ms, 2165 execution_time_ms=execution_time_ms, 2166 ) 2167 2168 trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True) 2169 if inputs is not None: 2170 assert trace.data.request == json.dumps(inputs) 2171 else: 2172 assert trace.data.request is 
None 2173 if outputs is not None: 2174 assert trace.data.response == json.dumps(outputs) 2175 else: 2176 assert trace.data.response is None 2177 if intermediate_outputs is not None: 2178 assert trace.data.intermediate_outputs == intermediate_outputs 2179 spans = trace.data.spans 2180 assert len(spans) == 1 2181 root_span = spans[0] 2182 assert root_span.name == "test" 2183 assert root_span.start_time_ns == start_time_ms * 1000000 2184 assert root_span.end_time_ns == (start_time_ms + execution_time_ms) * 1000000 2185 2186 2187 def test_set_delete_trace_tag(): 2188 with mlflow.start_span("span1") as span: 2189 trace_id = span.trace_id 2190 2191 mlflow.set_trace_tag(trace_id=trace_id, key="key1", value="value1") 2192 trace = mlflow.get_trace(trace_id=trace_id, flush=True) 2193 assert trace.info.tags["key1"] == "value1" 2194 2195 mlflow.delete_trace_tag(trace_id=trace_id, key="key1") 2196 trace = mlflow.get_trace(trace_id=trace_id, flush=True) 2197 assert "key1" not in trace.info.tags 2198 2199 # Test with request_id kwarg (backward compatibility) 2200 mlflow.set_trace_tag(request_id=trace_id, key="key3", value="value3") 2201 trace = mlflow.get_trace(request_id=trace_id, flush=True) 2202 assert trace.info.tags["key3"] == "value3" 2203 2204 mlflow.delete_trace_tag(request_id=trace_id, key="key3") 2205 trace = mlflow.get_trace(request_id=trace_id, flush=True) 2206 assert "key3" not in trace.info.tags 2207 2208 2209 @pytest.mark.parametrize("is_databricks", [True, False]) 2210 def test_search_traces_with_run_id_validates_store_filter_string(is_databricks): 2211 mock_store = mock.MagicMock() 2212 mock_store.search_traces.return_value = ([], None) 2213 mock_store.get_run.return_value = mock.MagicMock() 2214 mock_store.get_run.return_value.info.experiment_id = "test_exp_id" 2215 2216 test_run_id = "test_run_123" 2217 with ( 2218 mock.patch("mlflow.tracing.client._get_store", return_value=mock_store), 2219 mock.patch("mlflow.tracking.fluent._get_experiment_id", 
return_value="test_exp_id"), 2220 ): 2221 mlflow.search_traces(run_id=test_run_id) 2222 2223 expected_filter_string = f"attribute.run_id = '{test_run_id}'" 2224 mock_store.search_traces.assert_called() 2225 2226 call_args = mock_store.search_traces.call_args 2227 actual_filter_string = call_args[1]["filter_string"] 2228 assert actual_filter_string == expected_filter_string 2229 2230 2231 def test_search_traces_with_locations(mock_client): 2232 mock_client.search_traces.return_value = PagedList([], token=None) 2233 2234 # Test with locations 2235 mlflow.search_traces(locations=["catalog1.schema1", "catalog2.schema2"]) 2236 2237 # Verify that search_traces was called with locations 2238 mock_client.search_traces.assert_called_once() 2239 call_kwargs = mock_client.search_traces.call_args.kwargs 2240 assert call_kwargs["locations"] == ["catalog1.schema1", "catalog2.schema2"] 2241 assert call_kwargs.get("experiment_ids") is None 2242 2243 2244 @pytest.mark.filterwarnings("ignore::FutureWarning") 2245 def test_search_traces_experiment_ids_deprecation_warning(mock_client): 2246 mock_client.search_traces.return_value = PagedList([], token=None) 2247 2248 # Test that using experiment_ids shows a deprecation warning 2249 with pytest.warns(FutureWarning, match="experiment_ids.*deprecated.*use.*locations"): 2250 mlflow.search_traces(experiment_ids=["123"]) 2251 2252 # Verify that search_traces was called and experiment_ids was converted to locations 2253 mock_client.search_traces.assert_called_once() 2254 call_kwargs = mock_client.search_traces.call_args.kwargs 2255 assert call_kwargs["locations"] == ["123"] 2256 assert call_kwargs["experiment_ids"] is None 2257 2258 2259 def test_search_traces_with_sql_warehouse_id(mock_client): 2260 mock_client.search_traces.return_value = PagedList([], token=None) 2261 2262 # Test with sql_warehouse_id 2263 mlflow.search_traces(locations=["123"], sql_warehouse_id="warehouse456") 2264 2265 # Verify that search_traces was called with 
sql_warehouse_id 2266 mock_client.search_traces.assert_called_once() 2267 call_kwargs = mock_client.search_traces.call_args.kwargs 2268 assert call_kwargs["locations"] == ["123"] 2269 assert "sql_warehouse_id" not in call_kwargs 2270 assert os.environ["MLFLOW_TRACING_SQL_WAREHOUSE_ID"] == "warehouse456" 2271 2272 2273 @skip_when_testing_trace_sdk 2274 @pytest.mark.flaky(attempts=3, condition=sys.platform == "win32") 2275 @pytest.mark.parametrize("use_batch_processor", [False, True]) 2276 def test_set_destination_in_threads(async_logging_enabled, use_batch_processor, monkeypatch): 2277 monkeypatch.setenv("MLFLOW_USE_BATCH_SPAN_PROCESSOR", str(use_batch_processor)) 2278 2279 # This test makes sure `set_destination` obeys thread-local behavior. 2280 class TestModel: 2281 def predict(self, x): 2282 with mlflow.start_span(name="root_span") as root_span: 2283 2284 def child_span_thread(z): 2285 child_span = start_span_no_context( 2286 name="child_span_1", 2287 parent_span=root_span, 2288 ) 2289 child_span.set_inputs(z) 2290 time.sleep(0.5) 2291 child_span.end() 2292 2293 thread = threading.Thread( 2294 name="test-fluent-child-span", target=child_span_thread, args=(x + 1,) 2295 ) 2296 thread.start() 2297 thread.join() 2298 return x 2299 2300 model = TestModel() 2301 2302 def func(experiment_id: str | None, x: int): 2303 if experiment_id is not None: 2304 set_destination(MlflowExperiment(experiment_id), context_local=True) 2305 2306 time.sleep(0.5) 2307 model.predict(x) 2308 2309 # Main thread: global config 2310 experiment_id1 = mlflow.create_experiment(uuid.uuid4().hex) 2311 set_destination(MlflowExperiment(experiment_id1)) 2312 func(None, 3) 2313 2314 # Thread 1: context-local config 2315 experiment_id2 = mlflow.create_experiment(uuid.uuid4().hex) 2316 thread1 = threading.Thread( 2317 name="test-fluent-destination-thread1", target=func, args=(experiment_id2, 3) 2318 ) 2319 2320 # Thread 2: context-local config 2321 experiment_id3 = 
mlflow.create_experiment(uuid.uuid4().hex) 2322 thread2 = threading.Thread( 2323 name="test-fluent-destination-thread2", target=func, args=(experiment_id3, 40) 2324 ) 2325 2326 # Thread 3: no config -> fallback to global config 2327 thread3 = threading.Thread(name="test-fluent-destination-thread3", target=func, args=(None, 40)) 2328 2329 thread1.start() 2330 thread2.start() 2331 thread3.start() 2332 2333 thread1.join() 2334 thread2.join() 2335 thread3.join() 2336 2337 if async_logging_enabled: 2338 mlflow.flush_trace_async_logging(terminate=True) 2339 2340 traces = get_traces(experiment_id1) 2341 assert len(traces) == 2 # main thread + thread 3 2342 assert traces[0].info.experiment_id == experiment_id1 2343 assert len(traces[0].data.spans) == 2 2344 assert traces[1].info.experiment_id == experiment_id1 2345 assert len(traces[1].data.spans) == 2 2346 2347 for exp_id in [experiment_id2, experiment_id3]: 2348 traces = get_traces(exp_id) 2349 assert len(traces) == 1 2350 assert traces[0].info.experiment_id == exp_id 2351 assert len(traces[0].data.spans) == 2 2352 2353 2354 @pytest.mark.asyncio 2355 @skip_when_testing_trace_sdk 2356 async def test_set_destination_in_async_contexts(async_logging_enabled): 2357 class TestModel: 2358 async def predict(self, x): 2359 with mlflow.start_span(name="root_span") as root_span: 2360 2361 async def child_span_task(z): 2362 child_span = start_span_no_context( 2363 name="child_span_1", 2364 parent_span=root_span, 2365 ) 2366 child_span.set_inputs(z) 2367 await asyncio.sleep(0.5) 2368 child_span.end() 2369 2370 await child_span_task(x + 1) 2371 return x 2372 2373 model = TestModel() 2374 2375 async def async_func(experiment_id: str, x: int): 2376 set_destination(MlflowExperiment(experiment_id), context_local=True) 2377 await asyncio.sleep(0.5) 2378 await model.predict(x) 2379 2380 experiment_id1 = mlflow.create_experiment(uuid.uuid4().hex) 2381 task1 = asyncio.create_task(async_func(experiment_id1, 3)) 2382 2383 experiment_id2 = 
mlflow.create_experiment(uuid.uuid4().hex) 2384 task2 = asyncio.create_task(async_func(experiment_id2, 40)) 2385 2386 await asyncio.gather(task1, task2) 2387 2388 if async_logging_enabled: 2389 mlflow.flush_trace_async_logging(terminate=True) 2390 2391 for exp_id in [experiment_id1, experiment_id2]: 2392 traces = get_traces(exp_id) 2393 assert len(traces) == 1 2394 assert traces[0].info.experiment_id == exp_id 2395 assert len(traces[0].data.spans) == 2 2396 2397 2398 def test_set_destination_from_env_var_databricks_uc(monkeypatch): 2399 monkeypatch.setenv("MLFLOW_TRACING_DESTINATION", "catalog.schema") 2400 destination = _MLFLOW_TRACE_USER_DESTINATION.get() 2401 assert isinstance(destination, UCSchemaLocation) 2402 assert destination.catalog_name == "catalog" 2403 assert destination.schema_name == "schema" 2404 assert mlflow.get_tracking_uri() == "databricks" 2405 2406 2407 @skip_when_testing_trace_sdk 2408 def test_traces_can_be_searched_by_span_properties(async_logging_enabled): 2409 @mlflow.trace(name="test_span") 2410 def test_function(): 2411 return "result" 2412 2413 test_function() 2414 2415 if async_logging_enabled: 2416 mlflow.flush_trace_async_logging(terminate=True) 2417 2418 traces = mlflow.search_traces(filter_string='span.name = "test_span"', return_type="list") 2419 assert len(traces) == 1, "Should find exactly one trace with span name 'test_span'" 2420 found_span_names = [span.name for span in traces[0].data.spans] 2421 assert "test_span" in found_span_names 2422 2423 2424 @pytest.mark.skipif( 2425 IS_TRACING_SDK_ONLY, reason="Skipping test because mlflow or mlflow-skinny is not installed." 
2426 ) 2427 def test_search_traces_with_full_text(): 2428 with mlflow.start_span(name="test_span") as span: 2429 span.set_attribute("llm.inputs", "How's the result?") 2430 span.set_attribute("llm.outputs", "the number increased 90%") 2431 trace_id_1 = span.trace_id 2432 2433 with mlflow.start_span(name="test_span") as span: 2434 span.set_outputs({"outputs": 1234567}) 2435 span.set_attribute("test", "the number increased") 2436 trace_id_2 = span.trace_id 2437 2438 with mlflow.start_span(name="test_span") as span: 2439 span.set_attribute("test", "result including 'single quotes'") 2440 trace_id_3 = span.trace_id 2441 2442 traces = mlflow.search_traces( 2443 filter_string='trace.text LIKE "%How\'s the result?%"', return_type="list", flush=True 2444 ) 2445 assert len(traces) == 1 2446 assert traces[0].info.trace_id == trace_id_1 2447 2448 traces = mlflow.search_traces( 2449 filter_string='trace.text LIKE "%1234567%"', return_type="list", flush=True 2450 ) 2451 assert len(traces) == 1 2452 assert traces[0].info.trace_id == trace_id_2 2453 2454 traces = mlflow.search_traces( 2455 filter_string="trace.text LIKE \"%result including 'single quotes'%\"", 2456 return_type="list", 2457 flush=True, 2458 ) 2459 assert len(traces) == 1 2460 assert traces[0].info.trace_id == trace_id_3 2461 2462 traces = mlflow.search_traces( 2463 filter_string='trace.text LIKE "%increased 90%%"', return_type="list", flush=True 2464 ) 2465 assert len(traces) == 1 2466 assert traces[0].info.trace_id == trace_id_1 2467 2468 2469 def _create_trace_with_session(session_id: str, name: str = "test_span") -> str: 2470 with mlflow.start_span(name=name) as span: 2471 mlflow.update_current_trace(metadata={TraceMetadataKey.TRACE_SESSION: session_id}) 2472 span.set_inputs({"input": "test"}) 2473 span.set_outputs({"output": "test"}) 2474 mlflow.flush_trace_async_logging() 2475 return span.trace_id 2476 2477 2478 def _create_trace_without_session(name: str = "test_span") -> str: 2479 with 
mlflow.start_span(name=name) as span: 2480 span.set_inputs({"input": "test"}) 2481 span.set_outputs({"output": "test"}) 2482 mlflow.flush_trace_async_logging() 2483 return span.trace_id 2484 2485 2486 @pytest.mark.skipif( 2487 IS_TRACING_SDK_ONLY, reason="Skipping test because mlflow or mlflow-skinny is not installed." 2488 ) 2489 def test_search_sessions_empty(): 2490 # Create a trace without a session ID - should result in no sessions 2491 _create_trace_without_session() 2492 sessions = mlflow.search_sessions() 2493 assert sessions == [] 2494 2495 2496 @pytest.mark.skipif( 2497 IS_TRACING_SDK_ONLY, reason="Skipping test because mlflow or mlflow-skinny is not installed." 2498 ) 2499 def test_search_sessions_returns_grouped_traces(): 2500 session_id_1 = f"session-1-{uuid.uuid4().hex[:8]}" 2501 session_id_2 = f"session-2-{uuid.uuid4().hex[:8]}" 2502 2503 # Create traces for session 1 2504 trace_id_1 = _create_trace_with_session(session_id_1, "session1_trace1") 2505 trace_id_2 = _create_trace_with_session(session_id_1, "session1_trace2") 2506 2507 # Create trace for session 2 2508 trace_id_3 = _create_trace_with_session(session_id_2, "session2_trace1") 2509 2510 sessions = mlflow.search_sessions() 2511 2512 assert len(sessions) == 2 2513 2514 # Convert to dict keyed by session.id for easier assertions 2515 sessions_by_id = {s.id: s for s in sessions} 2516 2517 assert len(sessions_by_id[session_id_1]) == 2 2518 assert len(sessions_by_id[session_id_2]) == 1 2519 2520 # Verify trace IDs 2521 session_1_trace_ids = {t.info.trace_id for t in sessions_by_id[session_id_1]} 2522 assert trace_id_1 in session_1_trace_ids 2523 assert trace_id_2 in session_1_trace_ids 2524 assert sessions_by_id[session_id_2][0].info.trace_id == trace_id_3 2525 2526 2527 @pytest.mark.skipif( 2528 IS_TRACING_SDK_ONLY, reason="Skipping test because mlflow or mlflow-skinny is not installed." 
2529 ) 2530 def test_search_sessions_respects_max_results(): 2531 session_ids = [f"session-{i}-{uuid.uuid4().hex[:8]}" for i in range(3)] 2532 2533 # Create one trace per session 2534 for session_id in session_ids: 2535 _create_trace_with_session(session_id) 2536 2537 sessions = mlflow.search_sessions(max_results=2) 2538 2539 assert len(sessions) == 2 2540 2541 2542 @pytest.mark.skipif( 2543 IS_TRACING_SDK_ONLY, reason="Skipping test because mlflow or mlflow-skinny is not installed." 2544 ) 2545 def test_search_sessions_skips_traces_without_session_id(): 2546 session_id = f"session-{uuid.uuid4().hex[:8]}" 2547 2548 # Create trace without session 2549 _create_trace_without_session("no_session_trace") 2550 2551 # Create trace with session 2552 trace_id = _create_trace_with_session(session_id, "with_session_trace") 2553 2554 sessions = mlflow.search_sessions() 2555 2556 assert len(sessions) == 1 2557 assert len(sessions[0]) == 1 2558 assert sessions[0][0].info.trace_id == trace_id 2559 2560 2561 def test_search_sessions_validates_locations_type(): 2562 with pytest.raises(MlflowException, match=r"locations must be a list"): 2563 mlflow.search_sessions(locations=4) 2564 2565 with pytest.raises(MlflowException, match=r"locations must be a list"): 2566 mlflow.search_sessions(locations="4") 2567 2568 2569 @pytest.mark.skipif( 2570 IS_TRACING_SDK_ONLY, reason="Skipping test because mlflow or mlflow-skinny is not installed." 
2571 ) 2572 def test_search_sessions_with_default_experiment_id(): 2573 session_id = f"session-{uuid.uuid4().hex[:8]}" 2574 _create_trace_with_session(session_id) 2575 2576 # search_sessions should use the default experiment 2577 sessions = mlflow.search_sessions() 2578 assert len(sessions) == 1 2579 2580 2581 def test_search_sessions_raises_without_experiment(): 2582 with mock.patch("mlflow.tracking.fluent._get_experiment_id", return_value=None): 2583 with pytest.raises(MlflowException, match=r"No active experiment found"): 2584 mlflow.search_sessions() 2585 2586 2587 @pytest.mark.skipif( 2588 IS_TRACING_SDK_ONLY, reason="Skipping test because mlflow or mlflow-skinny is not installed." 2589 ) 2590 def test_search_sessions_include_spans_true(): 2591 session_id = f"session-{uuid.uuid4().hex[:8]}" 2592 _create_trace_with_session(session_id) 2593 2594 sessions = mlflow.search_sessions(include_spans=True) 2595 2596 assert len(sessions) == 1 2597 assert len(sessions[0]) == 1 2598 # When include_spans=True, spans should be populated 2599 assert len(sessions[0][0].data.spans) > 0 2600 2601 2602 @pytest.mark.skipif( 2603 IS_TRACING_SDK_ONLY, reason="Skipping test because mlflow or mlflow-skinny is not installed." 
2604 ) 2605 def test_search_sessions_include_spans_false(): 2606 session_id = f"session-{uuid.uuid4().hex[:8]}" 2607 _create_trace_with_session(session_id) 2608 2609 sessions = mlflow.search_sessions(include_spans=False) 2610 2611 assert len(sessions) == 1 2612 assert len(sessions[0]) == 1 2613 # When include_spans=False, spans should be empty 2614 assert len(sessions[0][0].data.spans) == 0 2615 2616 2617 @pytest.mark.parametrize("invalid_ratio", [-0.1, 1.1, -1, 2, 100]) 2618 def test_trace_decorator_sampling_ratio_validation(invalid_ratio: float): 2619 with pytest.raises( 2620 MlflowException, match=r"sampling_ratio_override must be between 0\.0 and 1\.0" 2621 ): 2622 mlflow.trace(sampling_ratio_override=invalid_ratio) 2623 2624 2625 @pytest.mark.parametrize( 2626 ("sampling_ratio", "num_calls", "expected_min", "expected_max"), 2627 [ 2628 (0.0, 10, 0, 0), 2629 (0.5, 100, 30, 70), 2630 (1.0, 10, 10, 10), 2631 ], 2632 ) 2633 def test_trace_decorator_sampling_ratio( 2634 sampling_ratio: float, num_calls: int, expected_min: int, expected_max: int 2635 ): 2636 trace_ids: list[str] = [] 2637 2638 @mlflow.trace(sampling_ratio_override=sampling_ratio) 2639 def traced_func(): 2640 if trace_id := mlflow.get_active_trace_id(): 2641 trace_ids.append(trace_id) 2642 return "result" 2643 2644 for _ in range(num_calls): 2645 assert traced_func() == "result" 2646 2647 assert expected_min <= len(trace_ids) <= expected_max 2648 2649 2650 @pytest.mark.parametrize( 2651 ("outer_ratio", "inner_ratio", "expected_outer", "expected_inner"), 2652 [ 2653 (1.0, 0.0, 5, 5), # Parent sampled -> child also sampled (inner ratio ignored) 2654 (0.0, 1.0, 0, 0), # Parent not sampled -> child also dropped (follows parent) 2655 ], 2656 ) 2657 def test_trace_decorator_sampling_ratio_nested( 2658 outer_ratio: float, inner_ratio: float, expected_outer: int, expected_inner: int 2659 ): 2660 outer_trace_ids: list[str] = [] 2661 inner_trace_ids: list[str] = [] 2662 2663 
@mlflow.trace(sampling_ratio_override=outer_ratio) 2664 def outer(): 2665 if trace_id := mlflow.get_active_trace_id(): 2666 outer_trace_ids.append(trace_id) 2667 return inner() 2668 2669 @mlflow.trace(sampling_ratio_override=inner_ratio) 2670 def inner(): 2671 if trace_id := mlflow.get_active_trace_id(): 2672 inner_trace_ids.append(trace_id) 2673 return "inner result" 2674 2675 for _ in range(5): 2676 assert outer() == "inner result" 2677 2678 assert len(outer_trace_ids) == expected_outer 2679 assert len(inner_trace_ids) == expected_inner 2680 2681 2682 def test_global_sampling_ratio_nested(monkeypatch): 2683 monkeypatch.setenv(MLFLOW_TRACE_SAMPLING_RATIO.name, "0.0") 2684 mlflow.tracing.reset() 2685 2686 inner_trace_ids: list[str] = [] 2687 2688 @mlflow.trace 2689 def outer(): 2690 return inner() 2691 2692 # Inner uses sampling_ratio_override=1.0 so it would create a sampled 2693 # root trace if the dropped parent context were not propagated. 2694 @mlflow.trace(sampling_ratio_override=1.0) 2695 def inner(): 2696 if trace_id := mlflow.get_active_trace_id(): 2697 inner_trace_ids.append(trace_id) 2698 return "result" 2699 2700 for _ in range(5): 2701 assert outer() == "result" 2702 2703 assert len(inner_trace_ids) == 0 2704 2705 2706 def test_start_span_no_context_preserves_dropped_parent_context(monkeypatch): 2707 monkeypatch.setenv(MLFLOW_TRACE_SAMPLING_RATIO.name, "0.0") 2708 mlflow.tracing.reset() 2709 2710 trace_ids: list[str] = [] 2711 2712 @mlflow.trace(sampling_ratio_override=1.0) 2713 def child(): 2714 if trace_id := mlflow.get_active_trace_id(): 2715 trace_ids.append(trace_id) 2716 return "result" 2717 2718 root = start_span_no_context("root") 2719 nested_noop = start_span_no_context("nested_noop", parent_span=root) 2720 2721 with safe_set_span_in_context(nested_noop): 2722 assert child() == "result" 2723 2724 assert len(trace_ids) == 0 2725 2726 2727 @pytest.mark.parametrize( 2728 ("sampling_ratio", "expected_count"), 2729 [ 2730 (0.0, 0), 2731 (1.0, 2), 
2732 ], 2733 ) 2734 def test_trace_decorator_sampling_ratio_generator(sampling_ratio: float, expected_count: int): 2735 trace_ids: list[str] = [] 2736 2737 @mlflow.trace(sampling_ratio_override=sampling_ratio) 2738 def gen(): 2739 if trace_id := mlflow.get_active_trace_id(): 2740 trace_ids.append(trace_id) 2741 for i in range(3): 2742 yield i 2743 2744 assert list(gen()) == [0, 1, 2] 2745 assert list(gen()) == [0, 1, 2] 2746 assert len(trace_ids) == expected_count 2747 2748 2749 @pytest.mark.parametrize( 2750 ("sampling_ratio", "expected_child_count"), 2751 [ 2752 (0.0, 0), 2753 (1.0, 6), 2754 ], 2755 ) 2756 def test_trace_decorator_sampling_ratio_generator_with_child_spans( 2757 sampling_ratio: float, expected_child_count: int 2758 ): 2759 child_trace_ids: list[str] = [] 2760 2761 @mlflow.trace 2762 def child_func(value): 2763 if trace_id := mlflow.get_active_trace_id(): 2764 child_trace_ids.append(trace_id) 2765 return value * 2 2766 2767 @mlflow.trace(sampling_ratio_override=sampling_ratio) 2768 def gen(): 2769 for i in range(3): 2770 yield child_func(i) 2771 2772 assert list(gen()) == [0, 2, 4] 2773 assert list(gen()) == [0, 2, 4] 2774 assert len(child_trace_ids) == expected_child_count 2775 2776 2777 @pytest.mark.asyncio 2778 @pytest.mark.parametrize( 2779 ("sampling_ratio", "num_calls", "expected_min", "expected_max"), 2780 [ 2781 (0.0, 10, 0, 0), 2782 (0.5, 100, 30, 70), 2783 (1.0, 10, 10, 10), 2784 ], 2785 ) 2786 async def test_trace_decorator_sampling_ratio_async( 2787 sampling_ratio: float, num_calls: int, expected_min: int, expected_max: int 2788 ): 2789 trace_ids: list[str] = [] 2790 2791 @mlflow.trace(sampling_ratio_override=sampling_ratio) 2792 async def traced_func(): 2793 if trace_id := mlflow.get_active_trace_id(): 2794 trace_ids.append(trace_id) 2795 return "result" 2796 2797 for _ in range(num_calls): 2798 assert await traced_func() == "result" 2799 2800 assert expected_min <= len(trace_ids) <= expected_max 2801 2802 2803 @pytest.mark.asyncio 
2804 @pytest.mark.parametrize( 2805 ("sampling_ratio", "expected_count"), 2806 [ 2807 (0.0, 0), 2808 (1.0, 2), 2809 ], 2810 ) 2811 async def test_trace_decorator_sampling_ratio_async_generator( 2812 sampling_ratio: float, expected_count: int 2813 ): 2814 trace_ids: list[str] = [] 2815 2816 @mlflow.trace(sampling_ratio_override=sampling_ratio) 2817 async def gen(): 2818 if trace_id := mlflow.get_active_trace_id(): 2819 trace_ids.append(trace_id) 2820 for i in range(3): 2821 yield i 2822 2823 assert [item async for item in gen()] == [0, 1, 2] 2824 assert [item async for item in gen()] == [0, 1, 2] 2825 assert len(trace_ids) == expected_count 2826 2827 2828 @pytest.mark.asyncio 2829 @pytest.mark.parametrize( 2830 ("sampling_ratio", "expected_child_count"), 2831 [ 2832 (0.0, 0), 2833 (1.0, 6), 2834 ], 2835 ) 2836 async def test_trace_decorator_sampling_ratio_async_generator_with_child_spans( 2837 sampling_ratio: float, expected_child_count: int 2838 ): 2839 child_trace_ids: list[str] = [] 2840 2841 @mlflow.trace 2842 async def child_func(value): 2843 if trace_id := mlflow.get_active_trace_id(): 2844 child_trace_ids.append(trace_id) 2845 return value * 2 2846 2847 @mlflow.trace(sampling_ratio_override=sampling_ratio) 2848 async def gen(): 2849 for i in range(3): 2850 yield await child_func(i) 2851 2852 assert [i async for i in gen()] == [0, 2, 4] 2853 assert [i async for i in gen()] == [0, 2, 4] 2854 assert len(child_trace_ids) == expected_child_count 2855 2856 2857 @skip_when_testing_trace_sdk 2858 def test_trace_decorator_sampling_ratio_overrides_global(): 2859 code = """ 2860 import mlflow 2861 2862 trace_ids: list[str] = [] 2863 2864 2865 @mlflow.trace # Should respect global 0.0 2866 def not_traced(): 2867 if trace_id := mlflow.get_active_trace_id(): 2868 trace_ids.append(trace_id) 2869 return "not traced" 2870 2871 2872 for _ in range(5): 2873 assert not_traced() == "not traced" 2874 2875 assert len(trace_ids) == 0 2876 2877 2878 
@mlflow.trace(sampling_ratio_override=1.0) # Should override global 0.0 2879 def traced(): 2880 if trace_id := mlflow.get_active_trace_id(): 2881 trace_ids.append(trace_id) 2882 return "traced" 2883 2884 2885 for _ in range(5): 2886 assert traced() == "traced" 2887 2888 assert len(trace_ids) == 5 2889 """ 2890 subprocess.check_call( 2891 [sys.executable, "-c", code], 2892 env={ 2893 **os.environ, 2894 "MLFLOW_TRACE_SAMPLING_RATIO": "0.0", 2895 }, 2896 ) 2897 2898 2899 @mlflow.trace 2900 def my_func(): 2901 return "hello" 2902 2903 2904 def test_tracing_context_injects_metadata_and_tags(): 2905 with mlflow.tracing.context( 2906 metadata={"custom_key": "custom_value"}, 2907 tags={"my_tag": "tag_value"}, 2908 ): 2909 my_func() 2910 2911 trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True) 2912 assert trace.info.request_metadata["custom_key"] == "custom_value" 2913 assert trace.info.tags["my_tag"] == "tag_value" 2914 2915 # Trace created outside the block should NOT have the metadata 2916 my_func() 2917 trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True) 2918 assert "session" not in trace.info.request_metadata 2919 2920 2921 def test_tracing_context_session_id_and_user(): 2922 with mlflow.tracing.context(session_id="sess-123", user="user-456"): 2923 my_func() 2924 2925 trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True) 2926 assert trace.info.request_metadata["mlflow.trace.session"] == "sess-123" 2927 assert trace.info.request_metadata["mlflow.trace.user"] == "user-456" 2928 2929 # session_id and user can coexist with explicit metadata 2930 with mlflow.tracing.context( 2931 session_id="sess-abc", 2932 user="user-xyz", 2933 metadata={"custom_key": "custom_value"}, 2934 ): 2935 my_func() 2936 2937 trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True) 2938 assert trace.info.request_metadata["mlflow.trace.session"] == "sess-abc" 2939 assert trace.info.request_metadata["mlflow.trace.user"] == 
"user-xyz" 2940 assert trace.info.request_metadata["custom_key"] == "custom_value" 2941 2942 2943 def test_tracing_context_session_id_and_user_nesting(): 2944 with mlflow.tracing.context(session_id="outer-sess", user="outer-user"): 2945 with mlflow.tracing.context(session_id="inner-sess"): 2946 my_func() 2947 2948 trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True) 2949 # Inner session_id overrides outer 2950 assert trace.info.request_metadata["mlflow.trace.session"] == "inner-sess" 2951 # Outer user is inherited 2952 assert trace.info.request_metadata["mlflow.trace.user"] == "outer-user" 2953 2954 2955 def test_tracing_context_nesting_merges(): 2956 with mlflow.tracing.context( 2957 metadata={"outer_key": "outer_val", "shared": "outer"}, 2958 tags={"outer_tag": "outer"}, 2959 ): 2960 with mlflow.tracing.context( 2961 metadata={"inner_key": "inner_val", "shared": "inner"}, 2962 tags={"inner_tag": "inner"}, 2963 ): 2964 my_func() 2965 2966 trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True) 2967 # Both outer and inner metadata present 2968 assert trace.info.request_metadata["outer_key"] == "outer_val" 2969 assert trace.info.request_metadata["inner_key"] == "inner_val" 2970 # Inner wins on conflict 2971 assert trace.info.request_metadata["shared"] == "inner" 2972 # Both tags present 2973 assert trace.info.tags["outer_tag"] == "outer" 2974 assert trace.info.tags["inner_tag"] == "inner" 2975 2976 2977 def test_tracing_context_enabled_false_suppresses_traces(): 2978 with mlflow.tracing.context(enabled=False): 2979 my_func() 2980 2981 # Child context should inherit the enabled=False from the parent 2982 with mlflow.tracing.context(metadata={"k": "v"}): 2983 my_func() 2984 2985 # Start trace with start_trace_no_context (used in autologging) 2986 span = mlflow.start_span_no_context("test") 2987 span.end() 2988 2989 assert mlflow.get_last_active_trace_id() is None 2990 2991 # After exiting, tracing should work normally 2992 my_func() 
2993 assert mlflow.get_last_active_trace_id() is not None 2994 2995 2996 def test_tracing_context_enabled_is_thread_safe(): 2997 def run_with_context(enabled): 2998 with mlflow.tracing.context(enabled=enabled): 2999 my_func() 3000 return mlflow.get_last_active_trace_id(thread_local=True) 3001 3002 with ThreadPoolExecutor( 3003 max_workers=10, thread_name_prefix="test-fluent-tracing-context" 3004 ) as pool: 3005 futures = { 3006 pool.submit(run_with_context, enabled=(i % 2 == 0)): (i % 2 == 0) for i in range(10) 3007 } 3008 for future in as_completed(futures): 3009 enabled = futures[future] 3010 trace_id = future.result() 3011 assert (trace_id is not None) == enabled 3012 3013 3014 def test_flush_trace_async_logging_calls_flush_when_async_queue_exists(): 3015 mock_exporter = mock.MagicMock() 3016 with mock.patch("mlflow.tracking.fluent._get_trace_exporter", return_value=mock_exporter): 3017 mlflow.flush_trace_async_logging(terminate=False) 3018 mock_exporter._async_queue.flush.assert_called_once_with(terminate=False) 3019 3020 3021 def test_flush_trace_async_logging_skips_when_async_queue_missing(): 3022 # A bare SpanExporter (as used by StrandsSpanProcessor, mlflow/strands/autolog.py:40) 3023 # has no _async_queue attribute. flush_trace_async_logging(terminate=True) should return without 3024 # reaching the error handler. 
3025 exporter = SpanExporter() 3026 assert not hasattr(exporter, "_async_queue") 3027 with ( 3028 mock.patch("mlflow.tracking.fluent._get_trace_exporter", return_value=exporter), 3029 mock.patch( 3030 "mlflow.tracking.fluent._logger.error", 3031 side_effect=AssertionError("flush should not reach error handler"), 3032 ), 3033 ): 3034 mlflow.flush_trace_async_logging(terminate=False) 3035 3036 3037 def test_flush_trace_async_logging_no_spurious_error_when_tracing_disabled(): 3038 mlflow.tracing.disable() 3039 with mock.patch("mlflow.tracking.fluent._logger") as mock_logger: 3040 mlflow.flush_trace_async_logging(terminate=True) 3041 mock_logger.error.assert_not_called()