/ tests / tracing / fixtures / flask_tracing_server.py
flask_tracing_server.py
 1  """Flask server for distributed tracing tests."""
 2  
 3  import sys
 4  
 5  import requests
 6  from flask import Flask, jsonify, request
 7  
 8  import mlflow
 9  from mlflow.tracing.distributed import (
10      get_tracing_context_headers_for_http_request,
11      set_tracing_context_from_http_request_headers,
12  )
13  
# Timeout, in seconds, passed to ``requests.post`` for the nested
# server-to-server call made by the /handle1 route.
REQUEST_TIMEOUT = 20

# Flask application serving the tracing test routes defined below.
app = Flask(__name__)
17  
18  
@app.get("/health")
def health():
    """Report that the server is up and accepting requests.

    Returns:
        The plain-text body ``"ok"`` with HTTP status 200.
    """
    status_code = 200
    body = "ok"
    return body, status_code
22  
23  
@app.post("/handle")
def handle():
    """Open a ``server-handler`` span under the trace context sent by the caller.

    The incoming HTTP headers are handed to
    ``set_tracing_context_from_http_request_headers`` so the new span can be
    attached to the caller's trace.

    Returns:
        JSON object with the span's ``trace_id``, ``span_id`` and ``parent_id``,
        letting the test client check how the span was linked.
    """
    propagated_ctx = set_tracing_context_from_http_request_headers(dict(request.headers))
    with propagated_ctx, mlflow.start_span("server-handler") as handler_span:
        span_info = {
            "trace_id": handler_span.trace_id,
            "span_id": handler_span.span_id,
            "parent_id": handler_span.parent_id,
        }
        return jsonify(span_info)
34  
35  
@app.post("/handle1")
def handle1():
    """Open a ``server-handler1`` span and make a nested traced call to ``/handle2``.

    The caller's trace context is restored from the incoming HTTP headers.
    Inside the new span, the current tracing context is serialized into
    outbound headers and sent with a POST to ``<second_server_url>/handle2``.

    Query parameters:
        second_server_url: Base URL of the server hosting ``/handle2``
            (required). A trailing slash is tolerated.

    Returns:
        - 200: JSON with this span's ``trace_id``, ``span_id`` and ``parent_id``,
          plus the nested server's JSON response under ``nested_call_resp``.
        - 400: ``second_server_url`` was not supplied.
        - 502: the nested call could not be made, returned a non-2xx status,
          or returned a body that is not valid JSON.
    """
    # The target URL comes from the query string (not the environment).
    # Validate it before touching any tracing state so a malformed request
    # does not emit a stray span.
    second_server_url = request.args.get("second_server_url")
    if not second_server_url:
        return jsonify({"error": "second_server_url parameter required"}), 400
    nested_url = f"{second_server_url.rstrip('/')}/handle2"

    headers = dict(request.headers)
    with set_tracing_context_from_http_request_headers(headers):
        with mlflow.start_span("server-handler1") as span:
            # Capture the context that is active inside this span for forwarding.
            # Presumably this makes /handle2's span a child of server-handler1;
            # confirm against mlflow.tracing.distributed.
            headers2 = get_tracing_context_headers_for_http_request()
            try:
                resp2 = requests.post(nested_url, headers=headers2, timeout=REQUEST_TIMEOUT)
            except requests.RequestException as exc:
                # Connection errors and timeouts are upstream failures, so
                # report 502 rather than letting them surface as a 500.
                return jsonify({"error": f"Nested call failed: {exc}"}), 502
            if not resp2.ok:
                return jsonify({"error": f"Nested call failed: {resp2.status_code}"}), 502

            try:
                payload2 = resp2.json()
            except ValueError:
                # requests' JSON decode errors subclass ValueError.
                return jsonify({"error": "Nested call returned invalid JSON"}), 502

            return jsonify({
                "trace_id": span.trace_id,
                "span_id": span.span_id,
                "parent_id": span.parent_id,
                "nested_call_resp": payload2,
            })
61  
62  
@app.post("/handle2")
def handle2():
    """Open a ``server-handler2`` span, the target of the nested call from /handle1.

    The trace context propagated in the request headers is restored before
    the span is started, so the span can be attached to the caller's trace.

    Returns:
        JSON object containing the span's ``trace_id``, ``span_id`` and
        ``parent_id``.
    """
    inbound_headers = dict(request.headers)
    restored_ctx = set_tracing_context_from_http_request_headers(inbound_headers)
    with restored_ctx, mlflow.start_span("server-handler2") as current:
        return jsonify(
            dict(
                trace_id=current.trace_id,
                span_id=current.span_id,
                parent_id=current.parent_id,
            )
        )
73  
74  
if __name__ == "__main__":
    # The first positional argument is the port to listen on; any further
    # arguments are ignored.
    cli_args = sys.argv[1:]
    if not cli_args:
        raise SystemExit("Usage: flask_tracing_server.py <port>")

    # Bind to loopback only. Debugger and auto-reloader stay off: the
    # reloader would otherwise run the app in a separate child process.
    app.run(
        host="127.0.0.1",
        port=int(cli_args[0]),
        debug=False,
        use_reloader=False,
    )