/ tests / tracing / conftest.py
conftest.py
  1  import random
  2  import subprocess
  3  import tempfile
  4  import time
  5  from unittest import mock
  6  
  7  import pytest
  8  
  9  import mlflow
 10  from mlflow.environment_variables import (
 11      MLFLOW_ENABLE_ASYNC_LOGGING,
 12      MLFLOW_ENABLE_ASYNC_TRACE_LOGGING,
 13  )
 14  from mlflow.tracing.fluent import _flush_pending_async_trace_writes
 15  
 16  
 17  @pytest.fixture(autouse=True)
 18  def enable_async_trace_logging(monkeypatch):
 19      """Enable async trace logging for all tests in tests/tracing/ to exercise the async path.
 20  
 21      Overrides the global disable_async_trace_logging fixture from tests/conftest.py.
 22      Tests that need both sync and async coverage use the async_logging_enabled fixture.
 23      Terminates async queues on teardown to prevent thread leaks between tests.
 24      """
 25      monkeypatch.setenv(MLFLOW_ENABLE_ASYNC_TRACE_LOGGING.name, "true")
 26      monkeypatch.setenv(MLFLOW_ENABLE_ASYNC_LOGGING.name, "true")
 27  
 28      yield
 29  
 30      _flush_pending_async_trace_writes(terminate=True)
 31  
 32  
 33  @pytest.fixture(autouse=True)
 34  def reset_active_experiment():
 35      yield
 36      mlflow.tracking.fluent._active_experiment_id = None
 37  
 38  
 39  @pytest.fixture(autouse=True)
 40  def reset_tracking_uri():
 41      # Some API like set_destination("databricks") updates the tracking URI,
 42      # we should reset it between tests
 43      original_tracking_uri = mlflow.get_tracking_uri()
 44  
 45      yield
 46  
 47      mlflow.set_tracking_uri(original_tracking_uri)
 48  
 49  
 50  @pytest.fixture
 51  def databricks_tracking_uri():
 52      with mock.patch("mlflow.get_tracking_uri", return_value="databricks"):
 53          yield
 54  
 55  
 56  # Fixture to run the test case with and without async logging enabled.
 57  # When async logging is enabled, the batch span processor is also active (the default),
 58  # so tests exercise the full production pipeline.
 59  @pytest.fixture(params=[True, False])
 60  def async_logging_enabled(request, monkeypatch):
 61      monkeypatch.setenv(MLFLOW_ENABLE_ASYNC_TRACE_LOGGING.name, str(request.param))
 62      # TODO: V2 Trace depends on this env var rather than MLFLOW_ENABLE_ASYNC_TRACE_LOGGING
 63      # We should remove this once the backend is fully migrated to V3
 64      monkeypatch.setenv(MLFLOW_ENABLE_ASYNC_LOGGING.name, str(request.param))
 65      return request.param
 66  
 67  
 68  @pytest.fixture
 69  def otel_collector():
 70      """Start an OpenTelemetry collector in a Docker container."""
 71      subprocess.check_call(["docker", "pull", "otel/opentelemetry-collector"])
 72  
 73      # Use a random port to avoid conflicts
 74      port = random.randint(20000, 30000)
 75  
 76      docker_collector_config = """receivers:
 77    otlp:
 78      protocols:
 79        grpc:
 80          endpoint: 0.0.0.0:4317
 81  
 82  exporters:
 83    debug:
 84      verbosity: detailed
 85      sampling_initial: 5
 86      sampling_thereafter: 1
 87  
 88  service:
 89    pipelines:
 90      traces:
 91        receivers: [otlp]
 92        exporters: [debug]"""
 93  
 94      with tempfile.NamedTemporaryFile() as output_file:
 95          # Use echo to pipe config to Docker stdin
 96          docker_cmd = [
 97              "bash",
 98              "-c",
 99              f'echo "{docker_collector_config}" | '
100              f"docker run --rm -p 127.0.0.1:{port}:4317 -i "
101              f"otel/opentelemetry-collector --config=/dev/stdin",
102          ]
103  
104          process = subprocess.Popen(
105              docker_cmd,
106              stdout=output_file,
107              stderr=subprocess.STDOUT,
108              text=True,
109          )
110  
111          # Wait for the collector to start
112          time.sleep(5)
113  
114          yield process, output_file.name, port
115  
116          # Stop the collector
117          process.terminate()
118          try:
119              process.wait(timeout=5)
120          except subprocess.TimeoutExpired:
121              process.kill()
122              process.wait()