test_distributed.py
import os
import re
import subprocess
import sys
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator

import requests

import mlflow
from mlflow.tracing.distributed import (
    get_tracing_context_headers_for_http_request,
    set_tracing_context_from_http_request_headers,
)

from tests.helper_functions import get_safe_port
from tests.tracing.helper import skip_when_testing_trace_sdk

REQUEST_TIMEOUT = 30
_TRACE_POLL_TIMEOUT = 10  # seconds to wait for subprocess spans to be exported
_SERVER_SHUTDOWN_TIMEOUT = 10  # seconds to wait after terminate() before kill()


def _wait_for_trace(trace_id: str, expected_span_count: int) -> "mlflow.entities.Trace":
    """
    Poll until the trace contains the expected number of spans or timeout expires.

    Args:
        trace_id: ID of the trace to fetch with ``mlflow.get_trace``.
        expected_span_count: Minimum number of spans the trace must contain.

    Returns:
        The first fetched trace holding at least ``expected_span_count`` spans.

    Raises:
        AssertionError: If the span count is not reached within
            ``_TRACE_POLL_TIMEOUT`` seconds.
    """
    # Monotonic clock: the deadline must not move if the wall clock is adjusted.
    deadline = time.monotonic() + _TRACE_POLL_TIMEOUT
    trace = None
    while time.monotonic() < deadline:
        # Flush this process's pending async trace logging before each lookup.
        mlflow.flush_trace_async_logging()
        trace = mlflow.get_trace(trace_id)
        if trace is not None and len(trace.data.spans) >= expected_span_count:
            return trace
        time.sleep(0.1)
    spans_found = len(trace.data.spans) if trace else 0
    raise AssertionError(
        f"Expected {expected_span_count} spans in trace {trace_id} within "
        f"{_TRACE_POLL_TIMEOUT}s, got {spans_found}"
    )


@contextmanager
def flask_server(
    server_script_path: Path,
    port: int,
    *,
    wait_timeout: int = 30,
    health_endpoint: str = "/health",
) -> Iterator[str]:
    """
    Context manager to run a Flask server in a subprocess.

    The script is launched with the current interpreter and receives the port as
    its only argument. The context is entered once ``health_endpoint`` answers
    with a successful status; the process is stopped on exit.

    Args:
        server_script_path: Path of the server script to execute.
        port: Port passed to the script and used to build the base URL.
        wait_timeout: Maximum number of seconds to wait for the health check.
        health_endpoint: URL path polled to detect readiness.

    Yields:
        The base URL of the running server (``http://127.0.0.1:<port>``).

    Raises:
        RuntimeError: If the subprocess exits before becoming ready, or the
            health check does not succeed within ``wait_timeout`` seconds.
    """
    # The variable name suggests this turns off async trace logging in the server.
    server_env = {**os.environ, "MLFLOW_ENABLE_ASYNC_TRACE_LOGGING": "false"}
    with subprocess.Popen(
        [sys.executable, str(server_script_path), str(port)], env=server_env
    ) as proc:
        base_url = f"http://127.0.0.1:{port}"
        health_url = f"{base_url}{health_endpoint}"

        try:
            # Use a real deadline: a refused connection fails almost instantly, so
            # counting attempts would give the server far less than wait_timeout.
            deadline = time.monotonic() + wait_timeout
            while True:
                exit_code = proc.poll()
                if exit_code is not None:
                    # The server crashed on startup; polling further is pointless.
                    raise RuntimeError(
                        f"Flask server exited with code {exit_code} before becoming ready"
                    )
                try:
                    if requests.get(health_url, timeout=1.0).ok:
                        break
                except requests.exceptions.RequestException:
                    # Server is not accepting connections yet; retry below.
                    pass
                if time.monotonic() >= deadline:
                    raise RuntimeError(
                        f"Flask server failed to start within {wait_timeout} seconds"
                    )
                time.sleep(0.2)

            yield base_url
        finally:
            proc.terminate()
            try:
                # Popen.__exit__ waits without a timeout, so bound the wait here
                # and escalate if the server ignores the termination request.
                proc.wait(timeout=_SERVER_SHUTDOWN_TIMEOUT)
            except subprocess.TimeoutExpired:
                proc.kill()


def _parse_traceparent(header_value: str) -> tuple[int, int]:
    """
    Parse W3C traceparent header into (trace_id_int, span_id_int).
    Format: version-traceid-spanid-flags (all lowercase hex, no 0x prefix).

    Args:
        header_value: Raw value of the ``traceparent`` header.

    Returns:
        A ``(trace_id, span_id)`` tuple with both IDs converted to integers.

    Raises:
        AssertionError: If the header does not have four ``-`` separated fields
            or any field has an unexpected length or character set.
    """
    parts = header_value.split("-")
    assert len(parts) == 4, f"Invalid traceparent format: {header_value}"
    version, trace_id_hex, span_id_hex, flags = parts
    assert re.fullmatch(r"[0-9a-f]{2}", version), f"Invalid version: {version}"
    assert re.fullmatch(r"[0-9a-f]{32}", trace_id_hex), f"Invalid trace id: {trace_id_hex}"
    assert re.fullmatch(r"[0-9a-f]{16}", span_id_hex), f"Invalid span id: {span_id_hex}"
    assert re.fullmatch(r"[0-9a-f]{2}", flags), f"Invalid flags: {flags}"
    return int(trace_id_hex, 16), int(span_id_hex, 16)


def test_get_tracing_context_headers_for_http_request_in_active_span():
    """Headers generated inside an active span encode that span's trace and span IDs."""
    with mlflow.start_span("client-span"):
        # `_span` is the wrapped span object; its span context exposes integer IDs.
        current_span = mlflow.get_current_active_span()._span
        assert current_span.get_span_context().is_valid
        client_trace_id = current_span.get_span_context().trace_id
        client_span_id = current_span.get_span_context().span_id

        headers = get_tracing_context_headers_for_http_request()
        assert isinstance(headers, dict)
        assert "traceparent" in headers

        # Validate that the header encodes the same trace and span IDs
        header_trace_id, header_span_id = _parse_traceparent(headers["traceparent"])
        assert header_trace_id == client_trace_id
        assert header_span_id == client_span_id


def test_get_tracing_context_headers_for_http_request_without_active_span():
    """Without an active span, no tracing headers are produced."""
    headers = get_tracing_context_headers_for_http_request()
    assert headers == {}


def test_set_tracing_context_from_http_request_headers():
    """A span started under the restored context is a child of the header's span."""
    # Create headers from a client context first
    with mlflow.start_span("client-to-generate-headers") as client_span:
        client_headers = get_tracing_context_headers_for_http_request()
        client_trace_id = client_span.trace_id
        client_span_id = client_span.span_id

    assert mlflow.get_current_active_span() is None

    # Attach the context from headers and verify it becomes current inside the block
    with set_tracing_context_from_http_request_headers(client_headers):
        # get_current_active_span returns None because it is a `NonRecordingSpan`
        assert mlflow.get_current_active_span() is None

        with mlflow.start_span("child-span") as child_span:
            assert child_span.parent_id == client_span_id
            assert child_span.trace_id == client_trace_id


@skip_when_testing_trace_sdk
def test_distributed_tracing_e2e(tmp_path):
    """A server-side span created from propagated headers joins the client's trace."""
    # Path to the Flask server script
    server_path = Path(__file__).parent / "fixtures" / "flask_tracing_server.py"
    port = get_safe_port()

    # Start Flask server using the context manager
    with flask_server(server_path, port) as base_url:
        # Client side: create a span and send headers to server
        with mlflow.start_span("client-root") as client_span:
            headers = get_tracing_context_headers_for_http_request()
            resp = requests.post(f"{base_url}/handle", headers=headers, timeout=REQUEST_TIMEOUT)
            assert resp.ok, f"Server returned {resp.status_code}: {resp.text}"
            payload = resp.json()

            # Validate server span is a child in the same trace
            assert payload["trace_id"] == client_span.trace_id
            assert payload["parent_id"] == client_span.span_id

        # Poll until the server subprocess's BatchSpanProcessor has exported its span.
        # Done after the client span has ended and while the server is still running.
        trace = _wait_for_trace(client_span.trace_id, expected_span_count=2)

    spans = trace.data.spans
    assert len(spans) == 2

    # Identify root and child
    root_span = next(s for s in spans if s.parent_id is None)
    child_span = next(s for s in spans if s.parent_id is not None)

    assert root_span.name == "client-root"
    assert child_span.name == "server-handler"
    assert child_span.parent_id == root_span.span_id


@skip_when_testing_trace_sdk
def test_distributed_tracing_e2e_nested_call(tmp_path):
    """Context propagates client -> server 1 -> server 2 within a single trace."""
    # Path to the Flask server script
    server_path = Path(__file__).parent / "fixtures" / "flask_tracing_server.py"

    # Start first Flask server, then get port for second server to avoid port conflicts
    port = get_safe_port()
    with flask_server(server_path, port) as base_url:
        port2 = get_safe_port()
        with flask_server(server_path, port2) as base_url2:
            # Client side: create a span and send headers to server
            with mlflow.start_span("client-root") as client_span:
                headers = get_tracing_context_headers_for_http_request()
                # Pass the second server URL as a query parameter
                resp = requests.post(
                    f"{base_url}/handle1",
                    headers=headers,
                    params={"second_server_url": base_url2},
                    timeout=REQUEST_TIMEOUT,
                )
                assert resp.ok, f"Server returned {resp.status_code}: {resp.text}"
                payload = resp.json()

                # Validate server span is a child in the same trace
                assert payload["trace_id"] == client_span.trace_id
                assert payload["parent_id"] == client_span.span_id
                child_span1_id = payload["span_id"]
                assert payload["nested_call_resp"]["trace_id"] == client_span.trace_id
                assert payload["nested_call_resp"]["parent_id"] == child_span1_id
                child_span2_id = payload["nested_call_resp"]["span_id"]

            # Poll after the client span has ended, while both servers are still running.
            trace = _wait_for_trace(client_span.trace_id, expected_span_count=3)

    spans = trace.data.spans
    assert len(spans) == 3

    # Identify root and child
    root_span = next(s for s in spans if s.parent_id is None)
    child_span1 = next(s for s in spans if s.parent_id == root_span.span_id)
    child_span2 = next(s for s in spans if s.parent_id == child_span1.span_id)

    assert root_span.name == "client-root"
    assert child_span1.name == "server-handler1"
    assert child_span2.name == "server-handler2"
    assert child_span1.span_id == child_span1_id
    assert child_span2.span_id == child_span2_id