# middleware/security_middleware.py
import re
import logging
import secrets
from typing import Dict, List, Callable, Optional
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

# Logging configuration
logger = logging.getLogger("security_middleware")


class SecurityMiddleware(BaseHTTPMiddleware):
    """Middleware implementing various security protections.

    Adds a configurable set of security response headers (X-XSS-Protection,
    Strict-Transport-Security, X-Content-Type-Options, X-Frame-Options,
    Referrer-Policy and a nonce-based Content-Security-Policy) and can reject
    non-preflight requests whose ``Origin`` header is not allowed.
    """

    def __init__(
        self,
        app: ASGIApp,
        enable_xss_protection: bool = True,
        enable_hsts: bool = True,
        enable_content_type_options: bool = True,
        enable_frame_options: bool = True,
        enable_referrer_policy: bool = True,
        enable_csp: bool = True,
        enable_cors_protection: bool = True,
        csp_directives: Optional[Dict[str, str]] = None,
        allowed_origins: Optional[List[str]] = None,
        allowed_methods: Optional[List[str]] = None,
    ):
        """
        Initializes the security middleware.

        Args:
            app: ASGI Application
            enable_xss_protection: Enable XSS protection
            enable_hsts: Enable HSTS (HTTP Strict Transport Security)
            enable_content_type_options: Enable X-Content-Type-Options
            enable_frame_options: Enable X-Frame-Options
            enable_referrer_policy: Enable Referrer-Policy
            enable_csp: Enable Content-Security-Policy
            enable_cors_protection: Reject non-OPTIONS requests whose Origin
                header is not allowed
            csp_directives: Custom CSP directives (directive -> value). When
                omitted, a restrictive 'self'-only default policy is used.
            allowed_origins: Allowed origins. Entries are exact origins
                (e.g. "https://app.example.com"), "*" (allow everything), or
                "*.<domain>" (any subdomain of <domain>). Defaults to ["*"],
                which makes the origin check a no-op.
            allowed_methods: Allowed HTTP methods for CORS. NOTE: stored for
                callers but not enforced by this middleware.
        """
        super().__init__(app)
        self.enable_xss_protection = enable_xss_protection
        self.enable_hsts = enable_hsts
        self.enable_content_type_options = enable_content_type_options
        self.enable_frame_options = enable_frame_options
        self.enable_referrer_policy = enable_referrer_policy
        self.enable_csp = enable_csp
        self.enable_cors_protection = enable_cors_protection

        # CSP parameters
        self.csp_directives = csp_directives or {
            "default-src": "'self'",
            "script-src": "'self'",
            "style-src": "'self'",
            "img-src": "'self' data:",
            "font-src": "'self'",
            "connect-src": "'self'",
            "frame-ancestors": "'none'",
            "form-action": "'self'",
            "base-uri": "'self'",
            "object-src": "'none'",
        }

        # CORS parameters. The ["*"] default allows every origin, so origin
        # enforcement only happens when callers pass an explicit list.
        self.allowed_origins = allowed_origins or ["*"]
        self.allowed_methods = allowed_methods or ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"]

        # Retained for backward compatibility only. A CSP nonce must be
        # unpredictable and unique per response, so the headers use a fresh
        # nonce generated for every request (see dispatch()).
        self.csp_nonce = secrets.token_urlsafe(16)

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Check the request origin, run the app, and add security headers.

        When CSP applies to the path, a fresh nonce is stored on
        ``request.state.csp_nonce`` *before* the handler runs, so templates
        rendered by the handler can embed it. Non-preflight requests carrying a
        disallowed ``Origin`` header get a 403 JSON response (which also
        receives the security headers).
        """
        if self.enable_csp and not self._is_api_endpoint(request.url.path):
            # Per-request nonce, exposed before the handler executes.
            request.state.csp_nonce = secrets.token_urlsafe(16)

        # Check origin for advanced CORS protection (preflights are exempt).
        if self.enable_cors_protection and request.method != "OPTIONS":
            origin = request.headers.get("Origin")
            if origin and not self._is_origin_allowed(origin):
                # %r keeps attacker-controlled header text from forging log lines.
                logger.warning("Rejected request from disallowed origin %r", origin)
                rejection = JSONResponse(
                    status_code=403,
                    content={"detail": "Origin not allowed"},
                )
                self._add_security_headers(rejection, request)
                return rejection

        # Process the request normally
        response = await call_next(request)

        # Add security headers
        self._add_security_headers(response, request)

        return response

    def _add_security_headers(self, response: Response, request: Request) -> None:
        """Set the enabled security headers on ``response`` (overwriting any
        existing values with the same names)."""
        # X-XSS-Protection
        # NOTE(review): this header is deprecated and OWASP now recommends "0";
        # the legacy value is kept to avoid changing behavior for callers.
        if self.enable_xss_protection:
            response.headers["X-XSS-Protection"] = "1; mode=block"

        # Strict-Transport-Security (HSTS)
        if self.enable_hsts:
            response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains; preload"

        # X-Content-Type-Options
        if self.enable_content_type_options:
            response.headers["X-Content-Type-Options"] = "nosniff"

        # X-Frame-Options
        if self.enable_frame_options:
            response.headers["X-Frame-Options"] = "DENY"

        # Referrer-Policy
        if self.enable_referrer_policy:
            response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"

        # Content-Security-Policy (skipped for JSON API paths)
        if self.enable_csp and not self._is_api_endpoint(request.url.path):
            # Reuse the nonce dispatch() exposed to the handler; generate one
            # if this method is reached without it (keeps header and state in sync).
            nonce = getattr(request.state, "csp_nonce", None)
            if nonce is None:
                nonce = secrets.token_urlsafe(16)
                request.state.csp_nonce = nonce

            csp_parts = []
            for directive, value in self.csp_directives.items():
                # A nonce makes CSP2+ browsers ignore 'unsafe-inline', so only
                # add it when the policy does not explicitly allow inline code.
                if directive in ("script-src", "style-src") and "'unsafe-inline'" not in value:
                    value = f"{value} 'nonce-{nonce}'"
                csp_parts.append(f"{directive} {value}")

            response.headers["Content-Security-Policy"] = "; ".join(csp_parts)

    def _is_origin_allowed(self, origin: str) -> bool:
        """Return True if ``origin`` is allowed by ``allowed_origins``."""
        if "*" in self.allowed_origins:
            return True

        return origin in self.allowed_origins or self._matches_wildcard_origin(origin)

    def _matches_wildcard_origin(self, origin: str) -> bool:
        """Check ``origin`` against ``*.<domain>`` entries in ``allowed_origins``.

        ``*.example.com`` matches any subdomain of ``example.com`` with an
        optional scheme and port (e.g. ``https://app.example.com`` or
        ``http://a.b.example.com:8443``). The whole origin must match, so
        suffix attacks such as ``https://app.example.com.evil.net`` and
        look-alikes such as ``https://appXexampleYcom`` are rejected. The bare
        apex ``example.com`` is not matched.
        """
        for allowed_origin in self.allowed_origins:
            if not allowed_origin.startswith("*."):
                continue
            domain = allowed_origin[2:]
            pattern = (
                r"(?:[A-Za-z][A-Za-z0-9+.\-]*://)?"  # optional scheme
                r"(?:[A-Za-z0-9_\-]+\.)+"  # one or more subdomain labels
                + re.escape(domain)  # literal domain: dots are not wildcards
                + r"(?::\d+)?"  # optional port
            )
            # fullmatch anchors both ends; re.match only anchored the start.
            if re.fullmatch(pattern, origin):
                return True
        return False

    def _is_api_endpoint(self, path: str) -> bool:
        """Checks if the path is an API endpoint (to avoid CSP on JSON APIs)."""
        return path.startswith(("/api/", "/auth/"))