/ tests / tracing / utils / test_timeout.py
test_timeout.py
  1  import time
  2  from concurrent.futures import ThreadPoolExecutor
  3  from unittest import mock
  4  
  5  import pytest
  6  
  7  import mlflow
  8  from mlflow.entities.span_event import SpanEvent
  9  from mlflow.entities.span_status import SpanStatusCode
 10  from mlflow.tracing.export.inference_table import _TRACE_BUFFER, pop_trace
 11  from mlflow.tracing.trace_manager import _Trace
 12  from mlflow.tracing.utils.timeout import MlflowTraceTimeoutCache
 13  
 14  from tests.tracing.helper import get_traces, skip_when_testing_trace_sdk
 15  
 16  
 17  def _mock_span(span_id, parent_id=None):
 18      span = mock.Mock()
 19      span.span_id = span_id
 20      span.parent_id = parent_id
 21      return span
 22  
 23  
 24  @pytest.fixture
 25  def cache():
 26      timeout_cache = MlflowTraceTimeoutCache(timeout=1, maxsize=10)
 27      yield timeout_cache
 28      timeout_cache.clear()
 29  
 30  
 31  def test_expire_traces(cache):
 32      span_1_1 = _mock_span("span_1")
 33      span_1_2 = _mock_span("span_2", parent_id="span_1")
 34      cache["tr_1"] = _Trace(None, span_dict={"span_1": span_1_1, "span_2": span_1_2})
 35      for _ in range(5):
 36          if "tr_1" not in cache:
 37              break
 38          time.sleep(1)
 39      else:
 40          pytest.fail("Trace should be expired within 5 seconds")
 41  
 42      span_1_1.end.assert_called_once()
 43      span_1_1.set_status.assert_called_once_with(SpanStatusCode.ERROR)
 44      span_1_1.add_event.assert_called_once()
 45      event = span_1_1.add_event.call_args[0][0]
 46      assert isinstance(event, SpanEvent)
 47      assert event.name == "exception"
 48      assert event.attributes["exception.message"].startswith("Trace tr_1 is timed out")
 49  
 50      # Non-root span should not be touched
 51      span_1_2.assert_not_called()
 52  
 53  
 54  class _SlowModel:
 55      @mlflow.trace
 56      def predict(self, x):
 57          for _ in range(x):
 58              self.slow_function()
 59          return
 60  
 61      @mlflow.trace
 62      def slow_function(self):
 63          time.sleep(1)
 64  
 65  
 66  @pytest.mark.skip(
 67      reason="batch_get_traces only return full traces for now, re-enable this test "
 68      "when batch_get_traces is updated to support partial traces"
 69  )
 70  def test_trace_halted_after_timeout(monkeypatch):
 71      # When MLFLOW_TRACE_TIMEOUT_SECONDS is set, MLflow should halt the trace after
 72      # the timeout and log it to the backend with an error status
 73      monkeypatch.setenv("MLFLOW_TRACE_TIMEOUT_SECONDS", "3")
 74  
 75      _SlowModel().predict(5)  # takes 5 seconds
 76  
 77      traces = get_traces()
 78      assert len(traces) == 1
 79      trace = traces[0]
 80      assert trace.info.execution_time_ms >= 2900  # Some margin for windows
 81      assert trace.info.status == SpanStatusCode.ERROR
 82      assert len(trace.data.spans) >= 3
 83  
 84      root_span = trace.data.spans[0]
 85      assert root_span.name == "predict"
 86      assert root_span.status.status_code == SpanStatusCode.ERROR
 87      assert root_span.events[0].name == "exception"
 88      assert (
 89          root_span
 90          .events[0]
 91          .attributes["exception.message"]
 92          .startswith(f"Trace {trace.info.request_id} is timed out")
 93      )
 94  
 95      first_span = trace.data.spans[1]
 96      assert first_span.name == "slow_function"
 97      assert first_span.status.status_code == SpanStatusCode.OK
 98  
 99      # The rest of the spans should not be logged to the backend.
100      in_progress_traces = mlflow.search_traces(
101          filter_string="status = 'IN_PROGRESS'",
102          return_type="list",
103      )
104      assert len(in_progress_traces) == 0
105  
106  
@skip_when_testing_trace_sdk
def test_trace_halted_after_timeout_in_model_serving(
    monkeypatch, mock_databricks_serving_with_tracing_env
):
    """Check timeout handling when several requests are served concurrently.

    Three predictions (5, 6 and 1 iterations of a one-second step) are run on a
    two-worker thread pool with ``MLFLOW_TRACE_TIMEOUT_SECONDS`` set to 3. All
    three traces must land in ``_TRACE_BUFFER``; the two long requests must end
    in the ERROR state and the short one in the OK state.
    """
    from mlflow.pyfunc.context import Context, set_prediction_context

    monkeypatch.setenv("MLFLOW_TRACE_TIMEOUT_SECONDS", "3")

    # Simulate model serving env where multiple requests are processed concurrently
    def _run_single(request_id, seconds):
        with set_prediction_context(Context(request_id=request_id)):
            _SlowModel().predict(seconds)

    with ThreadPoolExecutor(max_workers=2, thread_name_prefix="test-tracing-timeout") as executor:
        # Drain the result iterator: an exception raised inside a worker is only
        # re-raised when its result is retrieved, so leaving the iterator
        # unconsumed would silently hide worker failures.
        list(executor.map(_run_single, ["request-id-1", "request-id-2", "request-id-3"], [5, 6, 1]))

    # All traces should be logged
    assert len(_TRACE_BUFFER) == 3

    # Long operation should be halted
    assert pop_trace(request_id="request-id-1")["info"]["state"] == SpanStatusCode.ERROR
    assert pop_trace(request_id="request-id-2")["info"]["state"] == SpanStatusCode.ERROR

    # Short operation should complete successfully
    assert pop_trace(request_id="request-id-3")["info"]["state"] == SpanStatusCode.OK
132  
133  
@pytest.mark.skip(
    reason="batch_get_traces only return full traces for now, re-enable this test "
    "when batch_get_traces is updated to support partial traces"
)
def test_handle_timeout_update(monkeypatch):
    """Check that changes to ``MLFLOW_TRACE_TIMEOUT_SECONDS`` apply to later traces.

    Three 3-iteration predictions are run: with no timeout set, with the timeout
    set to 1, and with the timeout set to 100. The expected statuses of the
    trace at index 0 are OK, ERROR and OK respectively.
    """

    def _predict_and_fetch_first(expected_total):
        # Run one prediction, confirm the running trace count, and hand back the
        # trace at index 0 (presumably the most recent one - confirm against
        # `get_traces`).
        _SlowModel().predict(3)
        logged = get_traces()
        assert len(logged) == expected_total
        return logged[0]

    # Create a first trace. At this moment, there is no timeout set
    assert _predict_and_fetch_first(1).info.status == SpanStatusCode.OK

    # Update timeout env var after cache creation
    monkeypatch.setenv("MLFLOW_TRACE_TIMEOUT_SECONDS", "1")

    # Create a second trace. This should use the new timeout
    assert _predict_and_fetch_first(2).info.status == SpanStatusCode.ERROR

    # Update timeout to a larger value. Trace should complete successfully
    monkeypatch.setenv("MLFLOW_TRACE_TIMEOUT_SECONDS", "100")
    assert _predict_and_fetch_first(3).info.status == SpanStatusCode.OK