tests/tracing/test_distributed.py
test_distributed.py
  1  import os
  2  import re
  3  import subprocess
  4  import sys
  5  import time
  6  from contextlib import contextmanager
  7  from pathlib import Path
  8  from typing import Iterator
  9  
 10  import requests
 11  
 12  import mlflow
 13  from mlflow.tracing.distributed import (
 14      get_tracing_context_headers_for_http_request,
 15      set_tracing_context_from_http_request_headers,
 16  )
 17  
 18  from tests.helper_functions import get_safe_port
 19  from tests.tracing.helper import skip_when_testing_trace_sdk
 20  
 21  REQUEST_TIMEOUT = 30
 22  _TRACE_POLL_TIMEOUT = 10  # seconds to wait for subprocess spans to be exported
 23  
 24  
 25  def _wait_for_trace(trace_id: str, expected_span_count: int) -> "mlflow.entities.Trace":
 26      """Poll until the trace contains the expected number of spans or timeout expires."""
 27      deadline = time.time() + _TRACE_POLL_TIMEOUT
 28      trace = None
 29      while time.time() < deadline:
 30          mlflow.flush_trace_async_logging()
 31          trace = mlflow.get_trace(trace_id)
 32          if trace is not None and len(trace.data.spans) >= expected_span_count:
 33              return trace
 34          time.sleep(0.1)
 35      spans_found = len(trace.data.spans) if trace else 0
 36      raise AssertionError(
 37          f"Expected {expected_span_count} spans in trace {trace_id} within "
 38          f"{_TRACE_POLL_TIMEOUT}s, got {spans_found}"
 39      )
 40  
 41  
 42  @contextmanager
 43  def flask_server(
 44      server_script_path: Path,
 45      port: int,
 46      *,
 47      wait_timeout: int = 30,
 48      health_endpoint: str = "/health",
 49  ) -> Iterator[str]:
 50      """Context manager to run a Flask server in a subprocess."""
 51      server_env = {**os.environ, "MLFLOW_ENABLE_ASYNC_TRACE_LOGGING": "false"}
 52      with subprocess.Popen(
 53          [sys.executable, str(server_script_path), str(port)], env=server_env
 54      ) as proc:
 55          base_url = f"http://127.0.0.1:{port}"
 56  
 57          try:
 58              # Wait for server to be ready
 59              for _ in range(wait_timeout):
 60                  try:
 61                      response = requests.get(f"{base_url}{health_endpoint}", timeout=1.0)
 62                      if response.ok:
 63                          break
 64                  except requests.exceptions.RequestException:
 65                      time.sleep(0.2)
 66              else:
 67                  raise RuntimeError(f"Flask server failed to start within {wait_timeout} seconds")
 68  
 69              yield base_url
 70          finally:
 71              proc.terminate()
 72  
 73  
 74  def _parse_traceparent(header_value: str) -> tuple[int, int]:
 75      """
 76      Parse W3C traceparent header into (trace_id_int, span_id_int).
 77      Format: version-traceid-spanid-flags (all lowercase hex, no 0x prefix).
 78      """
 79      parts = header_value.split("-")
 80      assert len(parts) == 4, f"Invalid traceparent format: {header_value}"
 81      version, trace_id_hex, span_id_hex, flags = parts
 82      assert re.fullmatch(r"[0-9a-f]{2}", version), f"Invalid version: {version}"
 83      assert re.fullmatch(r"[0-9a-f]{32}", trace_id_hex), f"Invalid trace id: {trace_id_hex}"
 84      assert re.fullmatch(r"[0-9a-f]{16}", span_id_hex), f"Invalid span id: {span_id_hex}"
 85      assert re.fullmatch(r"[0-9a-f]{2}", flags), f"Invalid flags: {flags}"
 86      return int(trace_id_hex, 16), int(span_id_hex, 16)
 87  
 88  
 89  def test_get_tracing_context_headers_for_http_request_in_active_span():
 90      with mlflow.start_span("client-span"):
 91          current_span = mlflow.get_current_active_span()._span
 92          assert current_span.get_span_context().is_valid
 93          client_trace_id = current_span.get_span_context().trace_id
 94          client_span_id = current_span.get_span_context().span_id
 95  
 96          headers = get_tracing_context_headers_for_http_request()
 97          assert isinstance(headers, dict)
 98          assert "traceparent" in headers
 99  
100          # Validate that the header encodes the same trace and span IDs
101          header_trace_id, header_span_id = _parse_traceparent(headers["traceparent"])
102          assert header_trace_id == client_trace_id
103          assert header_span_id == client_span_id
104  
105  
def test_get_tracing_context_headers_for_http_request_without_active_span():
    """With no active span there is no context to propagate, so no headers are produced."""
    assert get_tracing_context_headers_for_http_request() == {}
109  
110  
def test_set_tracing_context_from_http_request_headers():
    """A span started under propagated headers must join the remote trace as its child."""
    # Produce propagation headers from a short-lived "remote" client span.
    with mlflow.start_span("client-to-generate-headers") as remote_parent:
        propagated = get_tracing_context_headers_for_http_request()
        expected_trace_id = remote_parent.trace_id
        expected_parent_id = remote_parent.span_id

    # The client span has ended, so nothing is active any more.
    assert mlflow.get_current_active_span() is None

    with set_tracing_context_from_http_request_headers(propagated):
        # The attached remote context is a `NonRecordingSpan`, which
        # get_current_active_span does not surface.
        assert mlflow.get_current_active_span() is None

        with mlflow.start_span("child-span") as child:
            assert child.parent_id == expected_parent_id
            assert child.trace_id == expected_trace_id
128  
129  
@skip_when_testing_trace_sdk
def test_distributed_tracing_e2e(tmp_path):
    """Client span context sent over HTTP must parent the server's span within one trace."""
    script = Path(__file__).parent.joinpath("fixtures", "flask_tracing_server.py")

    with flask_server(script, get_safe_port()) as server_url:
        with mlflow.start_span("client-root") as client_span:
            response = requests.post(
                f"{server_url}/handle",
                headers=get_tracing_context_headers_for_http_request(),
                timeout=REQUEST_TIMEOUT,
            )
            assert response.ok, f"Server returned {response.status_code}: {response.text}"
            body = response.json()

            # The server reports the span it created: same trace, parented to the client span.
            assert body["trace_id"] == client_span.trace_id
            assert body["parent_id"] == client_span.span_id

    # The server's span is exported from another process, so wait until it is visible.
    spans = _wait_for_trace(client_span.trace_id, expected_span_count=2).data.spans
    assert len(spans) == 2

    root = next(span for span in spans if span.parent_id is None)
    server_side = next(span for span in spans if span.parent_id is not None)

    assert root.name == "client-root"
    assert server_side.name == "server-handler"
    assert server_side.parent_id == root.span_id
161  
162  
@skip_when_testing_trace_sdk
def test_distributed_tracing_e2e_nested_call(tmp_path):
    """Context must propagate client -> server 1 -> server 2, producing a three-level trace."""
    script = Path(__file__).parent.joinpath("fixtures", "flask_tracing_server.py")

    # The second port is reserved only after the first server is up, so get_safe_port cannot
    # hand out the same port twice.
    with flask_server(script, get_safe_port()) as first_url:
        with flask_server(script, get_safe_port()) as second_url:
            with mlflow.start_span("client-root") as client_span:
                # Server 1 learns where server 2 lives from a query parameter.
                response = requests.post(
                    f"{first_url}/handle1",
                    headers=get_tracing_context_headers_for_http_request(),
                    params={"second_server_url": second_url},
                    timeout=REQUEST_TIMEOUT,
                )
                assert response.ok, f"Server returned {response.status_code}: {response.text}"
                outer = response.json()

                # Server 1's span hangs off the client span...
                assert outer["trace_id"] == client_span.trace_id
                assert outer["parent_id"] == client_span.span_id
                handler1_id = outer["span_id"]
                # ...and server 2's span hangs off server 1's span.
                inner = outer["nested_call_resp"]
                assert inner["trace_id"] == client_span.trace_id
                assert inner["parent_id"] == handler1_id
                handler2_id = inner["span_id"]

    spans = _wait_for_trace(client_span.trace_id, expected_span_count=3).data.spans
    assert len(spans) == 3

    # Walk the parent chain from the root downwards.
    root = next(span for span in spans if span.parent_id is None)
    handler1 = next(span for span in spans if span.parent_id == root.span_id)
    handler2 = next(span for span in spans if span.parent_id == handler1.span_id)

    assert root.name == "client-root"
    assert handler1.name == "server-handler1"
    assert handler2.name == "server-handler2"
    assert handler1.span_id == handler1_id
    assert handler2.span_id == handler2_id