flask_tracing_server.py
"""Flask server for distributed tracing tests.

Each ``/handle*`` endpoint restores the caller's tracing context from the
incoming HTTP headers, opens an MLflow span, and reports that span's trace,
span, and parent IDs as JSON so a test can check how the spans were linked.
"""

import sys

import requests
from flask import Flask, jsonify, request

import mlflow
from mlflow.tracing.distributed import (
    get_tracing_context_headers_for_http_request,
    set_tracing_context_from_http_request_headers,
)

# Timeout (seconds) passed to requests for the nested /handle2 call made by /handle1.
REQUEST_TIMEOUT = 20

app = Flask(__name__)


def _span_info(span):
    """Return the trace, span, and parent IDs of *span* as a JSON-ready dict."""
    return {
        "trace_id": span.trace_id,
        "span_id": span.span_id,
        "parent_id": span.parent_id,
    }


@app.get("/health")
def health():
    """Return a plain ``ok`` so callers can poll until the server is up."""
    return "ok", 200


@app.post("/handle")
def handle():
    """Open a span under the caller's propagated tracing context and report its IDs."""
    headers = dict(request.headers)
    with set_tracing_context_from_http_request_headers(headers):
        with mlflow.start_span("server-handler") as span:
            return jsonify(_span_info(span))


@app.post("/handle1")
def handle1():
    """Open a span, then call ``/handle2`` on a second server with its context.

    The second server's base URL is read from the ``second_server_url`` query
    parameter. On success the response holds this span's IDs plus the nested
    server's JSON payload under ``nested_call_resp``.

    Returns 400 if ``second_server_url`` is missing. Returns 502 if the nested
    call fails: a transport error, an error status (``Response.ok`` is false),
    or a body that is not JSON.
    """
    headers = dict(request.headers)
    with set_tracing_context_from_http_request_headers(headers):
        with mlflow.start_span("server-handler1") as span:
            # The nested server's base URL is supplied as a query parameter.
            second_server_url = request.args.get("second_server_url")
            if not second_server_url:
                return jsonify({"error": "second_server_url parameter required"}), 400

            # Captured while this span is active, presumably so the span opened
            # by /handle2 is linked beneath it.
            headers2 = get_tracing_context_headers_for_http_request()
            try:
                resp2 = requests.post(
                    f"{second_server_url}/handle2", headers=headers2, timeout=REQUEST_TIMEOUT
                )
            except requests.RequestException as exc:
                # Connection errors and timeouts become a 502 instead of an
                # unhandled 500.
                return jsonify({"error": f"Nested call failed: {exc}"}), 502
            if not resp2.ok:
                return jsonify({"error": f"Nested call failed: {resp2.status_code}"}), 502

            try:
                payload2 = resp2.json()
            except ValueError:
                # requests raises a ValueError subclass when the body is not valid JSON.
                return jsonify({"error": "Nested call returned a non-JSON body"}), 502

            return jsonify({**_span_info(span), "nested_call_resp": payload2})


@app.post("/handle2")
def handle2():
    """Open a span under the propagated context (typically sent by /handle1) and report its IDs."""
    headers = dict(request.headers)
    with set_tracing_context_from_http_request_headers(headers):
        with mlflow.start_span("server-handler2") as span:
            return jsonify(_span_info(span))
74 75 if __name__ == "__main__": 76 if len(sys.argv) < 2: 77 raise SystemExit("Usage: flask_tracing_server.py <port>") 78 79 port = int(sys.argv[1]) 80 app.run(host="127.0.0.1", port=port, debug=False, use_reloader=False)