/ middleware / security_middleware.py
security_middleware.py
  1  import re
  2  import logging
  3  import secrets
  4  from typing import Dict, List, Callable, Optional
  5  from fastapi import Request, Response
  6  from fastapi.responses import JSONResponse
  7  from starlette.middleware.base import BaseHTTPMiddleware
  8  from starlette.types import ASGIApp
  9  
 10  # Logging configuration
 11  logger = logging.getLogger("security_middleware")
 12  
 13  class SecurityMiddleware(BaseHTTPMiddleware):
 14      """Middleware implementing various security protections."""
 15      
 16      def __init__(
 17          self,
 18          app: ASGIApp,
 19          enable_xss_protection: bool = True,
 20          enable_hsts: bool = True,
 21          enable_content_type_options: bool = True,
 22          enable_frame_options: bool = True,
 23          enable_referrer_policy: bool = True,
 24          enable_csp: bool = True,
 25          enable_cors_protection: bool = True,
 26          csp_directives: Optional[Dict[str, str]] = None,
 27          allowed_origins: List[str] = None,
 28          allowed_methods: List[str] = None,
 29      ):
 30          """
 31          Initializes the security middleware.
 32          
 33          Args:
 34              app: ASGI Application
 35              enable_xss_protection: Enable XSS protection
 36              enable_hsts: Enable HSTS (HTTP Strict Transport Security)
 37              enable_content_type_options: Enable X-Content-Type-Options
 38              enable_frame_options: Enable X-Frame-Options
 39              enable_referrer_policy: Enable Referrer-Policy
 40              enable_csp: Enable Content-Security-Policy
 41              enable_cors_protection: Enable advanced CORS protection
 42              csp_directives: Custom CSP directives
 43              allowed_origins: Allowed origins for CORS
 44              allowed_methods: Allowed HTTP methods for CORS
 45          """
 46          super().__init__(app)
 47          self.enable_xss_protection = enable_xss_protection
 48          self.enable_hsts = enable_hsts
 49          self.enable_content_type_options = enable_content_type_options
 50          self.enable_frame_options = enable_frame_options
 51          self.enable_referrer_policy = enable_referrer_policy
 52          self.enable_csp = enable_csp
 53          self.enable_cors_protection = enable_cors_protection
 54          
 55          # CSP parameters
 56          self.csp_directives = csp_directives or {
 57              "default-src": "'self'",
 58              "script-src": "'self'",
 59              "style-src": "'self'",
 60              "img-src": "'self' data:",
 61              "font-src": "'self'",
 62              "connect-src": "'self'",
 63              "frame-ancestors": "'none'",
 64              "form-action": "'self'",
 65              "base-uri": "'self'",
 66              "object-src": "'none'"
 67          }
 68          
 69          # CORS parameters
 70          self.allowed_origins = allowed_origins or ["*"]
 71          self.allowed_methods = allowed_methods or ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"]
 72          
 73          # Generate a nonce for CSP (could be regenerated for each request)
 74          self.csp_nonce = secrets.token_urlsafe(16)
 75      
 76      async def dispatch(self, request: Request, call_next: Callable) -> Response:
 77          """Adds security headers to the response."""
 78          # Check origin for advanced CORS protection
 79          if self.enable_cors_protection and request.method != "OPTIONS":
 80              origin = request.headers.get("Origin")
 81              if origin and not self._is_origin_allowed(origin):
 82                  return JSONResponse(
 83                      status_code=403,
 84                      content={"detail": "Origin not allowed"}
 85                  )
 86          
 87          # Process the request normally
 88          response = await call_next(request)
 89          
 90          # Add security headers
 91          self._add_security_headers(response, request)
 92          
 93          return response
 94      
 95      def _add_security_headers(self, response: Response, request: Request) -> None:
 96          """Adds security headers to the response."""
 97          # X-XSS-Protection
 98          if self.enable_xss_protection:
 99              response.headers["X-XSS-Protection"] = "1; mode=block"
100          
101          # Strict-Transport-Security (HSTS)
102          if self.enable_hsts:
103              response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains; preload"
104          
105          # X-Content-Type-Options
106          if self.enable_content_type_options:
107              response.headers["X-Content-Type-Options"] = "nosniff"
108          
109          # X-Frame-Options
110          if self.enable_frame_options:
111              response.headers["X-Frame-Options"] = "DENY"
112          
113          # Referrer-Policy
114          if self.enable_referrer_policy:
115              response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
116          
117          # Content-Security-Policy
118          if self.enable_csp and not self._is_api_endpoint(request.url.path):
119              # Generate CSP policy
120              csp_parts = []
121              for directive, value in self.csp_directives.items():
122                  # Add nonce for script-src and style-src directives
123                  if directive in ["script-src", "style-src"] and "'unsafe-inline'" not in value:
124                      value = f"{value} 'nonce-{self.csp_nonce}'"
125                  csp_parts.append(f"{directive} {value}")
126              
127              response.headers["Content-Security-Policy"] = "; ".join(csp_parts)
128              # Add nonce to response context for use in templates
129              request.state.csp_nonce = self.csp_nonce
130      
131      def _is_origin_allowed(self, origin: str) -> bool:
132          """Checks if the origin is allowed."""
133          if "*" in self.allowed_origins:
134              return True
135          
136          return origin in self.allowed_origins or self._matches_wildcard_origin(origin)
137      
138      def _matches_wildcard_origin(self, origin: str) -> bool:
139          """Checks if the origin matches a wildcard pattern."""
140          for allowed_origin in self.allowed_origins:
141              if allowed_origin.startswith("*."):
142                  pattern = allowed_origin.replace("*.", ".*\\.")
143                  if re.match(pattern, origin):
144                      return True
145          return False
146      
147      def _is_api_endpoint(self, path: str) -> bool:
148          """Checks if the path is an API endpoint (to avoid CSP on JSON APIs)."""
149          return path.startswith("/api/") or path.startswith("/auth/")