# test_timeout.py
import time
from concurrent.futures import ThreadPoolExecutor
from unittest import mock

import pytest

import mlflow
from mlflow.entities.span_event import SpanEvent
from mlflow.entities.span_status import SpanStatusCode
from mlflow.tracing.export.inference_table import _TRACE_BUFFER, pop_trace
from mlflow.tracing.trace_manager import _Trace
from mlflow.tracing.utils.timeout import MlflowTraceTimeoutCache

from tests.tracing.helper import get_traces, skip_when_testing_trace_sdk


def _mock_span(span_id, parent_id=None):
    """Return a ``mock.Mock`` standing in for a span with the given ids.

    Only ``span_id`` and ``parent_id`` are set explicitly. Every other attribute
    (``end``, ``set_status``, ``add_event``, ...) is an auto-created child mock,
    so tests can assert on how the cache interacted with the span.
    """
    span = mock.Mock()
    span.span_id = span_id
    span.parent_id = parent_id
    return span


@pytest.fixture
def cache():
    """Yield an ``MlflowTraceTimeoutCache`` (timeout=1, maxsize=10); clear it after the test."""
    timeout_cache = MlflowTraceTimeoutCache(timeout=1, maxsize=10)
    yield timeout_cache
    timeout_cache.clear()


def test_expire_traces(cache):
    span_1_1 = _mock_span("span_1")
    span_1_2 = _mock_span("span_2", parent_id="span_1")
    cache["tr_1"] = _Trace(None, span_dict={"span_1": span_1_1, "span_2": span_1_2})

    # Poll once per second; the for/else fails only if the loop never breaks,
    # i.e. the trace is still in the cache after 5 checks.
    for _ in range(5):
        if "tr_1" not in cache:
            break
        time.sleep(1)
    else:
        pytest.fail("Trace should be expired within 5 seconds")

    # The root span is ended with an ERROR status and an "exception" event.
    span_1_1.end.assert_called_once()
    span_1_1.set_status.assert_called_once_with(SpanStatusCode.ERROR)
    span_1_1.add_event.assert_called_once()
    event = span_1_1.add_event.call_args[0][0]
    assert isinstance(event, SpanEvent)
    assert event.name == "exception"
    assert event.attributes["exception.message"].startswith("Trace tr_1 is timed out")

    # Non-root span should not be touched. Assert on its methods: calling
    # ``assert_not_called()`` on the span mock itself would only check that the
    # span object was never invoked like a function, which is always true here.
    span_1_2.end.assert_not_called()
    span_1_2.set_status.assert_not_called()
    span_1_2.add_event.assert_not_called()


class _SlowModel:
    """Toy model: traced ``predict(x)`` makes ``x`` traced calls that each sleep 1 second."""

    @mlflow.trace
    def predict(self, x):
        for _ in range(x):
            self.slow_function()
        return

    @mlflow.trace
    def slow_function(self):
        time.sleep(1)


@pytest.mark.skip(
    reason="batch_get_traces only return full traces for now, re-enable this test "
    "when batch_get_traces is updated to support partial traces"
)
def test_trace_halted_after_timeout(monkeypatch):
    # When MLFLOW_TRACE_TIMEOUT_SECONDS is set, MLflow should halt the trace after
    # the timeout and log it to the backend with an error status
    monkeypatch.setenv("MLFLOW_TRACE_TIMEOUT_SECONDS", "3")

    _SlowModel().predict(5)  # takes 5 seconds

    traces = get_traces()
    assert len(traces) == 1
    trace = traces[0]
    assert trace.info.execution_time_ms >= 2900  # Some margin for windows
    assert trace.info.status == SpanStatusCode.ERROR
    assert len(trace.data.spans) >= 3

    root_span = trace.data.spans[0]
    assert root_span.name == "predict"
    assert root_span.status.status_code == SpanStatusCode.ERROR
    assert root_span.events[0].name == "exception"
    assert (
        root_span
        .events[0]
        .attributes["exception.message"]
        .startswith(f"Trace {trace.info.request_id} is timed out")
    )

    first_span = trace.data.spans[1]
    assert first_span.name == "slow_function"
    assert first_span.status.status_code == SpanStatusCode.OK

    # The rest of the spans should not be logged to the backend.
    in_progress_traces = mlflow.search_traces(
        filter_string="status = 'IN_PROGRESS'",
        return_type="list",
    )
    assert len(in_progress_traces) == 0


@skip_when_testing_trace_sdk
def test_trace_halted_after_timeout_in_model_serving(
    monkeypatch, mock_databricks_serving_with_tracing_env
):
    from mlflow.pyfunc.context import Context, set_prediction_context

    monkeypatch.setenv("MLFLOW_TRACE_TIMEOUT_SECONDS", "3")

    # Simulate model serving env where multiple requests are processed concurrently
    def _run_single(request_id, seconds):
        with set_prediction_context(Context(request_id=request_id)):
            _SlowModel().predict(seconds)

    with ThreadPoolExecutor(max_workers=2, thread_name_prefix="test-tracing-timeout") as executor:
        # Consume the results: Executor.map re-raises a worker's exception only when
        # the corresponding result is retrieved, so leaving the iterator unconsumed
        # would silently swallow any failure inside _run_single.
        list(
            executor.map(
                _run_single, ["request-id-1", "request-id-2", "request-id-3"], [5, 6, 1]
            )
        )

    # All traces should be logged
    assert len(_TRACE_BUFFER) == 3

    # Long operation should be halted
    assert pop_trace(request_id="request-id-1")["info"]["state"] == SpanStatusCode.ERROR
    assert pop_trace(request_id="request-id-2")["info"]["state"] == SpanStatusCode.ERROR

    # Short operation should complete successfully
    assert pop_trace(request_id="request-id-3")["info"]["state"] == SpanStatusCode.OK


@pytest.mark.skip(
    reason="batch_get_traces only return full traces for now, re-enable this test "
    "when batch_get_traces is updated to support partial traces"
)
def test_handle_timeout_update(monkeypatch):
    # Create a first trace. At this moment, there is no timeout set
    _SlowModel().predict(3)

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.status == SpanStatusCode.OK

    # Update timeout env var after cache creation
    monkeypatch.setenv("MLFLOW_TRACE_TIMEOUT_SECONDS", "1")

    # Create a second trace. This should use the new timeout
    _SlowModel().predict(3)

    traces = get_traces()
    assert len(traces) == 2
    assert traces[0].info.status == SpanStatusCode.ERROR

    # Update timeout to a larger value. Trace should complete successfully
    monkeypatch.setenv("MLFLOW_TRACE_TIMEOUT_SECONDS", "100")
    _SlowModel().predict(3)

    traces = get_traces()
    assert len(traces) == 3
    assert traces[0].info.status == SpanStatusCode.OK