app.py
import base64
import hashlib
import hmac
import itertools
import json
import sys
from pathlib import Path
from typing import NoReturn

import fastapi
import uvicorn
from fastapi import HTTPException, Request

from mlflow.webhooks.constants import (
    WEBHOOK_DELIVERY_ID_HEADER,
    WEBHOOK_SIGNATURE_HEADER,
    WEBHOOK_SIGNATURE_VERSION,
    WEBHOOK_TIMESTAMP_HEADER,
)

# Every webhook delivery (accepted or rejected) is appended here as one JSON
# object per line; tests read it back through GET /logs.
LOG_FILE = Path("logs.jsonl")

# Secret key for HMAC verification (in real world, this would be stored securely)
WEBHOOK_SECRET = "test-secret-key"

# Separate 1-based attempt counters per endpoint; replaced by POST /reset.
flaky_counter = itertools.count(1)
rate_limited_counter = itertools.count(1)

app = fastapi.FastAPI()


def _append_log(entry: dict) -> None:
    """Append *entry* to ``LOG_FILE`` as a single JSON line."""
    with LOG_FILE.open("a", encoding="utf-8") as f:
        f.write(json.dumps(entry) + "\n")


def _extract_data(payload):
    """Return the ``data`` field of a webhook payload, or the payload itself.

    Non-dict JSON bodies (lists, scalars) are returned unchanged instead of
    raising ``AttributeError``.
    """
    if isinstance(payload, dict):
        return payload.get("data", payload)
    return payload


def _make_entry(
    endpoint: str, request: Request, payload, status_code: int, error=None, **extra
) -> dict:
    """Build a log entry carrying the fields shared by every endpoint.

    Every entry has ``endpoint``, ``payload``, ``headers``, ``status_code`` and
    ``error`` keys, so log consumers can index them uniformly; *extra* adds
    endpoint-specific fields such as ``attempt``.
    """
    return {
        "endpoint": endpoint,
        "payload": payload,
        "headers": dict(request.headers),
        "status_code": status_code,
        "error": error,
        **extra,
    }


@app.get("/health")
async def health_check():
    """Report that the server is up."""
    return {"status": "ok"}


@app.post("/insecure-webhook")
async def insecure_webhook(request: Request):
    """Accept any JSON delivery without signature verification and log it."""
    payload = _extract_data(await request.json())
    _append_log(_make_entry("/insecure-webhook", request, payload, 200))
    return {"status": "received"}


@app.post("/reset")
async def reset():
    """Reset both logs and counters for testing."""
    global flaky_counter, rate_limited_counter

    # Clear logs (truncate in place; a missing file is left missing).
    if LOG_FILE.exists():
        LOG_FILE.write_text("", encoding="utf-8")

    # Reset all counters
    flaky_counter = itertools.count(1)
    rate_limited_counter = itertools.count(1)

    return {"status": "reset complete", "logs": "cleared", "counters": "reset"}


@app.get("/logs")
async def get_logs():
    """Return every logged entry in order, skipping blank lines."""
    if not LOG_FILE.exists():
        return {"logs": []}

    with LOG_FILE.open("r", encoding="utf-8") as f:
        logs = [json.loads(s) for line in f if (s := line.strip())]
    return {"logs": logs}


def verify_webhook_signature(
    payload: str, signature: str, delivery_id: str, timestamp: str
) -> bool:
    """Check an HMAC-SHA256 webhook signature.

    Args:
        payload: Raw request body decoded as text.
        signature: Header value of the form ``"<version>,<base64 digest>"``.
        delivery_id: Delivery ID header value.
        timestamp: Timestamp header value.

    Returns:
        True if *signature* carries the expected version prefix and its digest
        matches HMAC-SHA256(WEBHOOK_SECRET, "delivery_id.timestamp.payload").

    NOTE(review): the timestamp is only bound into the signature; it is not
    checked against a freshness window, so replayed deliveries still verify.
    """
    prefix = f"{WEBHOOK_SIGNATURE_VERSION},"
    if not signature or not signature.startswith(prefix):
        return False

    # Signature format: delivery_id.timestamp.payload
    signed_content = f"{delivery_id}.{timestamp}.{payload}"
    expected_digest = hmac.new(
        WEBHOOK_SECRET.encode("utf-8"), signed_content.encode("utf-8"), hashlib.sha256
    ).digest()
    expected_b64 = base64.b64encode(expected_digest)

    # Compare as bytes: compare_digest raises TypeError for non-ASCII str
    # arguments, which would turn a malformed header into a 500.
    provided = signature.removeprefix(prefix).encode("utf-8")
    return hmac.compare_digest(expected_b64, provided)


def _reject(request: Request, status_code: int, message: str) -> NoReturn:
    """Log a rejected /secure-webhook delivery, then abort with *status_code*."""
    _append_log(_make_entry("/secure-webhook", request, None, status_code, message))
    raise HTTPException(status_code=status_code, detail=message)


@app.post("/secure-webhook")
async def secure_webhook(request: Request):
    """Accept a delivery only if its headers are present and its HMAC verifies.

    Raises:
        HTTPException: 400 for a missing header or non-UTF-8 body,
            401 for a signature mismatch.
    """
    body = await request.body()
    signature = request.headers.get(WEBHOOK_SIGNATURE_HEADER)
    timestamp = request.headers.get(WEBHOOK_TIMESTAMP_HEADER)
    delivery_id = request.headers.get(WEBHOOK_DELIVERY_ID_HEADER)

    if not signature:
        _reject(request, 400, "Missing signature header")
    if not timestamp:
        _reject(request, 400, "Missing timestamp header")
    if not delivery_id:
        _reject(request, 400, "Missing delivery ID header")

    try:
        body_text = body.decode("utf-8")
    except UnicodeDecodeError:
        _reject(request, 400, "Payload is not valid UTF-8")

    if not verify_webhook_signature(body_text, signature, delivery_id, timestamp):
        _reject(request, 401, "Invalid signature")

    payload = _extract_data(json.loads(body))
    _append_log(_make_entry("/secure-webhook", request, payload, 200))
    return {"status": "received", "signature": "verified"}


@app.post("/flaky-webhook")
async def flaky_webhook(request: Request):
    """Endpoint that fails initially but succeeds after retries.

    Attempts 1 and 2 (since the last reset) get a 500; attempt 3 onward succeed.
    """
    attempt = next(flaky_counter)
    payload = _extract_data(await request.json())

    # Fail on first two attempts with 500 error
    if attempt <= 2:
        _append_log(
            _make_entry(
                "/flaky-webhook",
                request,
                payload,
                500,
                "Server error (will retry)",
                attempt=attempt,
            )
        )
        raise HTTPException(status_code=500, detail="Internal server error")

    _append_log(_make_entry("/flaky-webhook", request, payload, 200, attempt=attempt))
    return {"status": "received", "attempt": attempt}


@app.post("/rate-limited-webhook")
async def rate_limited_webhook(request: Request):
    """Endpoint that returns 429 with Retry-After header.

    Attempt 1 (since the last reset) is rate limited; later attempts succeed.
    """
    attempt = next(rate_limited_counter)
    payload = _extract_data(await request.json())

    if attempt == 1:
        _append_log(
            _make_entry(
                "/rate-limited-webhook", request, payload, 429, "Rate limited", attempt=attempt
            )
        )
        return fastapi.Response(
            content="Rate limited", status_code=429, headers={"Retry-After": "2"}
        )

    _append_log(
        _make_entry("/rate-limited-webhook", request, payload, 200, attempt=attempt)
    )
    return {"status": "received", "attempt": attempt}


if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} PORT")
    uvicorn.run(app, host="0.0.0.0", port=int(sys.argv[1]))