/ middleware / rate_limit_middleware.py
rate_limit_middleware.py
"""
Rate limiting middleware.
Protects the API against abuse by limiting the number of requests per IP or API key.
"""

import time
import logging
from typing import Dict, Optional, List, Callable, Tuple
from collections import defaultdict

from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

# Logging configuration
logger = logging.getLogger("rate_limit_middleware")


class RateLimiter:
    """Per-identifier rate limiter based on a sliding window of request timestamps.

    Each identifier (e.g. an IP address or an API key) has its own window. A
    request is allowed when fewer than ``max_requests`` requests were recorded
    for that identifier during the last ``window_size`` seconds.

    State lives in this instance's in-memory dict only.
    """

    def __init__(
        self,
        window_size: int = 60,
        max_requests: int = 100,
        clock: Callable[[], float] = time.monotonic,
        sweep_interval: Optional[float] = None,
    ):
        """
        Initializes the rate limiter.

        Args:
            window_size: Window size in seconds
            max_requests: Maximum number of requests allowed in the window
            clock: Zero-argument callable returning the current time in
                seconds. Defaults to ``time.monotonic`` so wall-clock
                adjustments (NTP corrections, manual changes) cannot shrink or
                stretch the window.
            sweep_interval: Minimum number of seconds between purges of
                identifiers whose records have all expired. Defaults to
                ``window_size``.
        """
        self.window_size = window_size
        self.max_requests = max_requests
        self._clock = clock
        self._sweep_interval = window_size if sweep_interval is None else sweep_interval
        self._last_sweep = clock()
        # Structure: {identifier: [(timestamp1, 1), (timestamp2, 1), ...]}
        # Timestamps come from ``clock``.
        self.request_records: Dict[str, List[Tuple[float, int]]] = defaultdict(list)

    def _prune(self, identifier: str, now: float) -> List[Tuple[float, int]]:
        """Drop expired records for *identifier* and return the live ones.

        An identifier left without live records is removed from
        ``request_records`` so idle clients do not keep consuming memory.
        """
        cutoff = now - self.window_size
        # .get() rather than [] so a lookup never creates an entry in the defaultdict.
        records = [
            (ts, count)
            for ts, count in self.request_records.get(identifier, ())
            if ts > cutoff
        ]
        if records:
            self.request_records[identifier] = records
        else:
            self.request_records.pop(identifier, None)
        return records

    def _sweep(self, now: float) -> None:
        """Purge every identifier whose records have all expired.

        Runs at most once per sweep interval. Without it, each distinct IP or
        API key ever seen would stay in memory forever, letting a client grow
        the dict without bound by rotating identifiers.
        """
        if now - self._last_sweep < self._sweep_interval:
            return
        self._last_sweep = now
        cutoff = now - self.window_size
        stale = [
            ident
            for ident, records in self.request_records.items()
            if not records or max(ts for ts, _ in records) <= cutoff
        ]
        for ident in stale:
            del self.request_records[ident]

    def _evaluate(self, identifier: str, now: float) -> Tuple[bool, int, int]:
        """Compute (is_limited, remaining, wait_time) for *identifier* at *now*."""
        self._sweep(now)
        records = self._prune(identifier, now)

        # Calculate the total number of requests in the window
        total_requests = sum(count for _, count in records)

        if total_requests >= self.max_requests:
            if records:
                oldest = min(ts for ts, _ in records)
                # Seconds until the oldest record leaves the window and frees a slot.
                wait_time = int(self.window_size - (now - oldest)) + 1
                return True, 0, wait_time
            # Only reachable when max_requests <= 0: no slot will ever free up.
            return True, 0, self.window_size

        # Remaining count already accounts for the request being evaluated.
        return False, self.max_requests - total_requests - 1, 0

    def check(self, identifier: str) -> Tuple[bool, int, int]:
        """
        Checks the limit for the identifier WITHOUT recording a request.

        Args:
            identifier: Unique identifier (IP or API key)

        Returns:
            (is_limited, remaining_requests, wait_time_in_seconds), with
            ``remaining_requests`` computed as if this request were accepted.
        """
        return self._evaluate(identifier, self._clock())

    def record(self, identifier: str, now: Optional[float] = None) -> None:
        """
        Counts one request for the identifier.

        Args:
            identifier: Unique identifier (IP or API key)
            now: Timestamp from this limiter's clock; defaults to the current time.
        """
        timestamp = self._clock() if now is None else now
        self.request_records[identifier].append((timestamp, 1))

    def is_rate_limited(self, identifier: str) -> Tuple[bool, int, int]:
        """
        Checks if the identifier has exceeded its request limit and, if not,
        records the current request.

        Args:
            identifier: Unique identifier (IP or API key)

        Returns:
            (is_limited, remaining_requests, wait_time_in_seconds)
        """
        now = self._clock()
        result = self._evaluate(identifier, now)
        if not result[0]:
            self.record(identifier, now)
        return result


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Middleware enforcing a global, a per-IP and a per-API-key request limit."""

    def __init__(
        self,
        app: ASGIApp,
        global_rate_limit: int = 1000,
        ip_rate_limit: int = 100,
        api_key_rate_limit: int = 200,
        window_size: int = 60,
        exclude_paths: Optional[List[str]] = None,
        exclude_prefixes: Optional[List[str]] = None,
    ):
        """
        Initializes the rate limiting middleware.

        Args:
            app: ASGI Application
            global_rate_limit: Limit of requests per window across all clients
            ip_rate_limit: Limit of requests per client IP per window
            api_key_rate_limit: Limit of requests per API key per window
            window_size: Window size in seconds
            exclude_paths: Paths excluded from rate limiting. ``None`` selects
                the defaults; an explicit empty list excludes nothing.
            exclude_prefixes: Path prefixes excluded from rate limiting.
                ``None`` selects the defaults; an explicit empty list excludes
                nothing.
        """
        super().__init__(app)
        self.global_limiter = RateLimiter(window_size, global_rate_limit)
        self.ip_limiter = RateLimiter(window_size, ip_rate_limit)
        self.api_key_limiter = RateLimiter(window_size, api_key_rate_limit)
        self.exclude_paths = (
            ["/api/health", "/api/docs", "/api/redoc", "/api/openapi.json"]
            if exclude_paths is None
            else exclude_paths
        )
        self.exclude_prefixes = (
            ["/static/", "/docs/"] if exclude_prefixes is None else exclude_prefixes
        )

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Processes requests with rate limiting.

        Every applicable limiter is consulted before any of them records the
        request. A request rejected by one limiter therefore consumes no quota
        in the others. Otherwise a single client hitting its per-IP limit would
        keep draining the shared global quota and lock out everyone else.
        """
        # Check if this path is excluded
        if self._is_excluded_path(request.url.path):
            return await call_next(request)

        # Get the request identifiers (IP and optional API key).
        # Clients without connection info all share the "unknown" bucket.
        client_ip = request.client.host if request.client else "unknown"
        api_key = request.headers.get("X-API-Key")

        # There is no await between these checks and the records below, so
        # other requests on the same event loop cannot interleave with them.
        global_limited, _, global_wait = self.global_limiter.check("global")
        if global_limited:
            return self._create_rate_limit_response(global_wait, 0, "global")

        ip_limited, ip_remaining, ip_wait = self.ip_limiter.check(client_ip)
        if ip_limited:
            return self._create_rate_limit_response(ip_wait, 0, "ip", client_ip)

        key_remaining = 0
        if api_key:
            key_limited, key_remaining, key_wait = self.api_key_limiter.check(api_key)
            if key_limited:
                return self._create_rate_limit_response(key_wait, 0, "api_key")

        # Every limiter accepted the request: count it in each of them.
        self.global_limiter.record("global")
        self.ip_limiter.record(client_ip)
        if api_key:
            self.api_key_limiter.record(api_key)
            header_limiter, remaining = self.api_key_limiter, key_remaining
        else:
            header_limiter, remaining = self.ip_limiter, ip_remaining

        # Report the quota of the most specific limiter that applies.
        response = await call_next(request)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        response.headers["X-RateLimit-Limit"] = str(header_limiter.max_requests)
        response.headers["X-RateLimit-Reset"] = str(int(time.time()) + header_limiter.window_size)
        return response

    def _is_excluded_path(self, path: str) -> bool:
        """Checks if the path is excluded from rate limiting."""
        return path in self.exclude_paths or path.startswith(tuple(self.exclude_prefixes))

    def _create_rate_limit_response(
        self,
        wait_time: int,
        remaining: int,
        limiter_type: str,
        identifier: Optional[str] = None,
    ) -> Response:
        """
        Creates a 429 response for a rate-limited request.

        Args:
            wait_time: Seconds the client should wait before retrying
            remaining: Value for the X-RateLimit-Remaining header
            limiter_type: Which limiter rejected the request ("global", "ip", "api_key")
            identifier: Client identifier; logged only for IP-based rejections
        """
        detail = f"Rate limit exceeded. Please try again in {wait_time} seconds."
        if identifier and limiter_type == "ip":
            logger.warning("Rate limit exceeded for IP %s", identifier)

        return JSONResponse(
            status_code=429,
            content={
                "detail": detail,
                "type": "rate_limit_exceeded",
                "limiter": limiter_type,
            },
            headers={
                "Retry-After": str(wait_time),
                "X-RateLimit-Remaining": str(remaining),
                "X-RateLimit-Reset": str(int(time.time()) + wait_time),
            },
        )