/ tests / webhooks / app.py
app.py
  1  import base64
  2  import hashlib
  3  import hmac
  4  import itertools
  5  import json
  6  import sys
  7  from pathlib import Path
  8  
  9  import fastapi
 10  import uvicorn
 11  from fastapi import HTTPException, Request
 12  
 13  from mlflow.webhooks.constants import (
 14      WEBHOOK_DELIVERY_ID_HEADER,
 15      WEBHOOK_SIGNATURE_HEADER,
 16      WEBHOOK_SIGNATURE_VERSION,
 17      WEBHOOK_TIMESTAMP_HEADER,
 18  )
 19  
 20  LOG_FILE = Path("logs.jsonl")
 21  
 22  app = fastapi.FastAPI()
 23  
 24  
 25  @app.get("/health")
 26  async def health_check():
 27      return {"status": "ok"}
 28  
 29  
 30  @app.post("/insecure-webhook")
 31  async def insecure_webhook(request: Request):
 32      payload = await request.json()
 33      # Extract the data field from webhook payload
 34      actual_payload = payload.get("data", payload)
 35      webhook_data = {
 36          "endpoint": "/insecure-webhook",
 37          "payload": actual_payload,
 38          "headers": dict(request.headers),
 39          "status_code": 200,
 40          "error": None,
 41      }
 42      with LOG_FILE.open("a") as f:
 43          f.write(json.dumps(webhook_data) + "\n")
 44  
 45      return {"status": "received"}
 46  
 47  
 48  @app.post("/reset")
 49  async def reset():
 50      """Reset both logs and counters for testing"""
 51      global flaky_counter, rate_limited_counter
 52  
 53      # Clear logs
 54      if LOG_FILE.exists():
 55          LOG_FILE.open("w").close()
 56  
 57      # Reset all counters
 58      flaky_counter = itertools.count(1)
 59      rate_limited_counter = itertools.count(1)
 60  
 61      return {"status": "reset complete", "logs": "cleared", "counters": "reset"}
 62  
 63  
 64  @app.get("/logs")
 65  async def get_logs():
 66      if not LOG_FILE.exists():
 67          return {"logs": []}
 68  
 69      with LOG_FILE.open("r") as f:
 70          logs = [json.loads(s) for line in f if (s := line.strip())]
 71          return {"logs": logs}
 72  
 73  
 74  # Secret key for HMAC verification (in real world, this would be stored securely)
 75  WEBHOOK_SECRET = "test-secret-key"
 76  
 77  
 78  def verify_webhook_signature(
 79      payload: str, signature: str, delivery_id: str, timestamp: str
 80  ) -> bool:
 81      if not signature or not signature.startswith(f"{WEBHOOK_SIGNATURE_VERSION},"):
 82          return False
 83  
 84      # Signature format: delivery_id.timestamp.payload
 85      signed_content = f"{delivery_id}.{timestamp}.{payload}"
 86      expected_signature = hmac.new(
 87          WEBHOOK_SECRET.encode("utf-8"), signed_content.encode("utf-8"), hashlib.sha256
 88      ).digest()
 89      expected_signature_b64 = base64.b64encode(expected_signature).decode("utf-8")
 90  
 91      provided_signature = signature.removeprefix(f"{WEBHOOK_SIGNATURE_VERSION},")
 92      return hmac.compare_digest(expected_signature_b64, provided_signature)
 93  
 94  
 95  @app.post("/secure-webhook")
 96  async def secure_webhook(request: Request):
 97      body = await request.body()
 98      signature = request.headers.get(WEBHOOK_SIGNATURE_HEADER)
 99      timestamp = request.headers.get(WEBHOOK_TIMESTAMP_HEADER)
100      delivery_id = request.headers.get(WEBHOOK_DELIVERY_ID_HEADER)
101  
102      if not signature:
103          error_data = {
104              "endpoint": "/secure-webhook",
105              "headers": dict(request.headers),
106              "status_code": 400,
107              "payload": None,
108              "error": "Missing signature header",
109          }
110          with LOG_FILE.open("a") as f:
111              f.write(json.dumps(error_data) + "\n")
112          raise HTTPException(status_code=400, detail="Missing signature header")
113  
114      if not timestamp:
115          error_data = {
116              "endpoint": "/secure-webhook",
117              "error": "Missing timestamp header",
118              "status_code": 400,
119              "headers": dict(request.headers),
120          }
121          with LOG_FILE.open("a") as f:
122              f.write(json.dumps(error_data) + "\n")
123          raise HTTPException(status_code=400, detail="Missing timestamp header")
124  
125      if not delivery_id:
126          error_data = {
127              "endpoint": "/secure-webhook",
128              "error": "Missing delivery ID header",
129              "status_code": 400,
130              "headers": dict(request.headers),
131          }
132          with LOG_FILE.open("a") as f:
133              f.write(json.dumps(error_data) + "\n")
134          raise HTTPException(status_code=400, detail="Missing delivery ID header")
135  
136      if not verify_webhook_signature(body.decode("utf-8"), signature, delivery_id, timestamp):
137          error_data = {
138              "endpoint": "/secure-webhook",
139              "headers": dict(request.headers),
140              "status_code": 401,
141              "payload": None,
142              "error": "Invalid signature",
143          }
144          with LOG_FILE.open("a") as f:
145              f.write(json.dumps(error_data) + "\n")
146          raise HTTPException(status_code=401, detail="Invalid signature")
147  
148      payload = json.loads(body)
149      # Extract the data field from webhook payload
150      actual_payload = payload.get("data", payload)
151      webhook_data = {
152          "endpoint": "/secure-webhook",
153          "payload": actual_payload,
154          "headers": dict(request.headers),
155          "status_code": 200,
156          "error": None,
157      }
158  
159      with LOG_FILE.open("a") as f:
160          f.write(json.dumps(webhook_data) + "\n")
161  
162      return {"status": "received", "signature": "verified"}
163  
164  
# Independent attempt counters for /flaky-webhook and /rate-limited-webhook,
# each yielding 1, 2, 3, ...; kept as module globals so POST /reset can rebind
# them to fresh counters.
flaky_counter, rate_limited_counter = itertools.count(1), itertools.count(1)
168  
169  
@app.post("/flaky-webhook")
async def flaky_webhook(request: Request):
    """Endpoint that fails initially but succeeds after retries.

    Attempts 1 and 2, counted since startup or the last POST /reset, are
    logged and answered with HTTP 500. Every later attempt is logged and
    accepted.
    """
    attempt = next(flaky_counter)

    body = await request.json()
    record = dict(
        endpoint="/flaky-webhook",
        payload=body.get("data", body),
        headers=dict(request.headers),
        attempt=attempt,
        error=None,
    )

    failing = attempt <= 2
    if failing:
        # status_code is a new key (appended); error is overwritten in place.
        record.update(status_code=500, error="Server error (will retry)")
    else:
        record["status_code"] = 200

    with LOG_FILE.open("a") as log:
        log.write(json.dumps(record) + "\n")

    if failing:
        raise HTTPException(status_code=500, detail="Internal server error")
    return {"status": "received", "attempt": attempt}
201  
202  
@app.post("/rate-limited-webhook")
async def rate_limited_webhook(request: Request):
    """Endpoint that returns 429 with Retry-After header.

    The first attempt, counted since startup or the last POST /reset, is
    logged and answered with HTTP 429 and ``Retry-After: 2``. Every later
    attempt is logged and accepted.
    """
    attempt = next(rate_limited_counter)

    body = await request.json()
    record = dict(
        endpoint="/rate-limited-webhook",
        payload=body.get("data", body),
        headers=dict(request.headers),
        attempt=attempt,
        error=None,
    )

    throttled = attempt == 1
    if throttled:
        record.update(status_code=429, error="Rate limited")
    else:
        record["status_code"] = 200

    with LOG_FILE.open("a") as log:
        log.write(json.dumps(record) + "\n")

    if not throttled:
        return {"status": "received", "attempt": attempt}

    throttle_response = fastapi.Response(content="Rate limited", status_code=429)
    throttle_response.headers["Retry-After"] = "2"
    return throttle_response
237  
238  
if __name__ == "__main__":
    # Usage: python app.py <port>
    # A missing or non-integer port previously died with a bare IndexError or
    # ValueError traceback; exit with a usage message instead.
    try:
        port = int(sys.argv[1])
    except (IndexError, ValueError):
        sys.exit(f"usage: {sys.argv[0]} <port>")
    uvicorn.run(app, host="0.0.0.0", port=port)