/ middleware / rate_limit_middleware.py
rate_limit_middleware.py
  1  """
  2  Rate limiting middleware.
  3  Protects the API against abuse by limiting the number of requests per IP or API key.
  4  """
  5  
  6  import time
  7  import logging
  8  from typing import Dict, Optional, List, Callable, Tuple
  9  from collections import defaultdict
 10  from fastapi import Request, Response
 11  from fastapi.responses import JSONResponse
 12  from starlette.middleware.base import BaseHTTPMiddleware
 13  from starlette.types import ASGIApp
 14  
 15  # Logging configuration
 16  logger = logging.getLogger("rate_limit_middleware")
 17  
 18  class RateLimiter:
 19      """Rate limiter manager based on a sliding window algorithm."""
 20      
 21      def __init__(self, window_size: int = 60, max_requests: int = 100):
 22          """
 23          Initializes the rate limiter.
 24          
 25          Args:
 26              window_size: Window size in seconds
 27              max_requests: Maximum number of requests allowed in the window
 28          """
 29          self.window_size = window_size
 30          self.max_requests = max_requests
 31          # Structure: {identifier: [(timestamp1, 1), (timestamp2, 1), ...]}
 32          self.request_records = defaultdict(list)
 33      
 34      def is_rate_limited(self, identifier: str) -> Tuple[bool, int, int]:
 35          """
 36          Checks if the identifier has exceeded its request limit.
 37          
 38          Args:
 39              identifier: Unique identifier (IP or API key)
 40              
 41          Returns:
 42              (is_limited, remaining_requests, wait_time_in_seconds)
 43          """
 44          now = time.time()
 45          records = self.request_records[identifier]
 46          
 47          # Remove records that are too old
 48          cutoff = now - self.window_size
 49          records = [(ts, count) for ts, count in records if ts > cutoff]
 50          self.request_records[identifier] = records
 51          
 52          # Calculate the total number of requests in the window
 53          total_requests = sum(count for _, count in records)
 54          
 55          # Check if the limit is reached
 56          if total_requests >= self.max_requests:
 57              # Calculate the wait time
 58              if records:
 59                  oldest = records[0][0]
 60                  wait_time = int(self.window_size - (now - oldest)) + 1
 61                  return True, 0, wait_time
 62              return True, 0, self.window_size
 63          
 64          # Add this request
 65          records.append((now, 1))
 66          self.request_records[identifier] = records
 67          
 68          # Return the number of remaining requests
 69          remaining = self.max_requests - total_requests - 1
 70          return False, remaining, 0
 71  
 72  class RateLimitMiddleware(BaseHTTPMiddleware):
 73      """Middleware to limit request rate."""
 74      
 75      def __init__(
 76          self,
 77          app: ASGIApp,
 78          global_rate_limit: int = 1000,
 79          ip_rate_limit: int = 100,
 80          api_key_rate_limit: int = 200,
 81          window_size: int = 60,
 82          exclude_paths: List[str] = None,
 83          exclude_prefixes: List[str] = None,
 84      ):
 85          """
 86          Initializes the rate limiting middleware.
 87          
 88          Args:
 89              app: ASGI Application
 90              global_rate_limit: Global limit of requests per minute
 91              ip_rate_limit: Limit of requests per IP per minute
 92              api_key_rate_limit: Limit of requests per API key per minute
 93              window_size: Window size in seconds
 94              exclude_paths: Paths excluded from rate limiting
 95              exclude_prefixes: Path prefixes excluded from rate limiting
 96          """
 97          super().__init__(app)
 98          self.global_limiter = RateLimiter(window_size, global_rate_limit)
 99          self.ip_limiter = RateLimiter(window_size, ip_rate_limit)
100          self.api_key_limiter = RateLimiter(window_size, api_key_rate_limit)
101          self.exclude_paths = exclude_paths or ["/api/health", "/api/docs", "/api/redoc", "/api/openapi.json"]
102          self.exclude_prefixes = exclude_prefixes or ["/static/", "/docs/"]
103      
104      async def dispatch(self, request: Request, call_next: Callable) -> Response:
105          """Processes requests with rate limiting."""
106          # Check if this path is excluded
107          if self._is_excluded_path(request.url.path):
108              return await call_next(request)
109          
110          # Check global limits
111          global_limited, global_remaining, global_wait = self.global_limiter.is_rate_limited("global")
112          if global_limited:
113              return self._create_rate_limit_response(global_wait, 0, "global")
114          
115          # Get the request identifier (IP or API key)
116          client_ip = request.client.host if request.client else "unknown"
117          api_key = request.headers.get("X-API-Key")
118          
119          # Check IP-based limits
120          ip_limited, ip_remaining, ip_wait = self.ip_limiter.is_rate_limited(client_ip)
121          if ip_limited:
122              return self._create_rate_limit_response(ip_wait, 0, "ip", client_ip)
123          
124          # Check API key-based limits
125          if api_key:
126              key_limited, key_remaining, key_wait = self.api_key_limiter.is_rate_limited(api_key)
127              if key_limited:
128                  return self._create_rate_limit_response(key_wait, 0, "api_key")
129              
130              # Add the number of remaining requests to the response header
131              response = await call_next(request)
132              response.headers["X-RateLimit-Remaining"] = str(key_remaining)
133              response.headers["X-RateLimit-Limit"] = str(self.api_key_limiter.max_requests)
134              response.headers["X-RateLimit-Reset"] = str(int(time.time()) + self.api_key_limiter.window_size)
135              return response
136          
137          # Add the number of remaining requests to the response header (based on IP)
138          response = await call_next(request)
139          response.headers["X-RateLimit-Remaining"] = str(ip_remaining)
140          response.headers["X-RateLimit-Limit"] = str(self.ip_limiter.max_requests)
141          response.headers["X-RateLimit-Reset"] = str(int(time.time()) + self.ip_limiter.window_size)
142          return response
143      
144      def _is_excluded_path(self, path: str) -> bool:
145          """Checks if the path is excluded from rate limiting."""
146          if path in self.exclude_paths:
147              return True
148          
149          for prefix in self.exclude_prefixes:
150              if path.startswith(prefix):
151                  return True
152          
153          return False
154      
155      def _create_rate_limit_response(self, wait_time: int, remaining: int, limiter_type: str, identifier: str = None) -> Response:
156          """Creates a response for a rate-limited request."""
157          detail = f"Rate limit exceeded. Please try again in {wait_time} seconds."
158          if identifier and limiter_type == "ip":
159              logger.warning(f"Rate limit exceeded for IP {identifier}")
160              
161          return JSONResponse(
162              status_code=429,
163              content={
164                  "detail": detail,
165                  "type": "rate_limit_exceeded",
166                  "limiter": limiter_type
167              },
168              headers={
169                  "Retry-After": str(wait_time),
170                  "X-RateLimit-Remaining": "0",
171                  "X-RateLimit-Reset": str(int(time.time()) + wait_time)
172              }
173          )