main.py
from pathlib import Path
import warnings
warnings.filterwarnings("ignore", message=".*Accessing the 'model_fields' attribute on the instance is deprecated.*")

from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, HTTPException, Request, Depends, status, Response
from fastapi import Path as PathParam
import logging
import sys

from fastmcp import FastMCP
from restai import config
import sentry_sdk
from contextlib import asynccontextmanager
from restai.database import get_db_wrapper
from restai.oauth import OAuthManager
from starlette.middleware.sessions import SessionMiddleware
from restai.config import (
    OAUTH_PROVIDERS,
    SSO_SECRET_KEY,
    SESSION_COOKIE_SAME_SITE,
    SESSION_COOKIE_SECURE,
    RESTAI_AUTH_SECRET,
    RESTAI_URL
)
from restai.utils.version import get_version_from_pyproject

PROJECT_ROOT = Path(__file__).parent.parent
FRONTEND_BUILD_DIR = PROJECT_ROOT / "frontend" / "build"

# When installed from PyPI, frontend may be bundled inside the package.
# The fallback must therefore look next to this module; probing the
# project-root path a second time could never succeed.
if not FRONTEND_BUILD_DIR.exists():
    _pkg_frontend = Path(__file__).parent / "frontend" / "build"
    if _pkg_frontend.exists():
        FRONTEND_BUILD_DIR = _pkg_frontend


@asynccontextmanager
async def lifespan(fs_app: FastAPI):
    """Application lifespan: bootstrap state and routes, then clean up.

    Startup (before ``yield``) creates missing tables, seeds default
    settings and the generator / speech-to-text registries, builds the
    ``Brain`` and OAuth manager, runs retention cleanup, optionally starts
    the telemetry task, evaluates the JWT secret strength, registers the
    health/info endpoints, static mounts and routers, and optionally
    mounts the MCP server.

    Shutdown (after ``yield``) cancels the telemetry task and shuts down
    the Docker and browser managers owned by the ``Brain``.

    Args:
        fs_app: The FastAPI application being started; state is stored on
            ``fs_app.state`` and routes are registered on it directly.
    """
    print(
        r"""
___ ___ ___ _____ _ ___ _.--'"'.
| _ \ __/ __|_ _/_\ |_ _| ( ( ( )
| / _|\__ \ | |/ _ \ | | (o)_ ) )
|_|_\___|___/ |_/_/ \_\___| (o)_.'

"""
    )
    from restai.brain import Brain
    from restai.database import get_db_wrapper, DBWrapper
    from restai.auth import get_current_username
    from restai.routers import (
        llms,
        projects,
        tools,
        users,
        image,
        audio,
        embeddings,
        proxy,
        statistics,
        auth,
        teams,
        settings,
        direct,
        widgets,
        search,
    )
    from restai.models.models import User
    from restai.models.databasemodels import ProjectDatabase
    from restai.multiprocessing import get_manager
    from modules.loaders import LOADERS

    # The multiprocessing manager is optional; run without it if it cannot start.
    try:
        fs_app.state.manager = get_manager()
    except Exception:
        fs_app.state.manager = None

    from restai.settings import ensure_settings_table, seed_defaults
    from restai.database import engine as db_engine
    ensure_settings_table(db_engine)

    # Auto-create new association tables for generators, eval tables, and migrate output table
    from restai.models.databasemodels import (
        TeamImageGeneratorDatabase,
        TeamAudioGeneratorDatabase,
        EvalDatasetDatabase,
        EvalTestCaseDatabase,
        EvalRunDatabase,
        EvalResultDatabase,
        PromptVersionDatabase,
        GuardEventDatabase,
        RetrievalEventDatabase,
        AuditLogDatabase,
        TeamInvitationDatabase,
        ImageGeneratorDatabase,
        SpeechToTextDatabase,
        ProjectSecretDatabase,
        ProjectTemplateDatabase,
        BulkIngestJobDatabase,
        RoutineExecutionLogDatabase,
    )
    # Creation order is preserved from the original sequence of calls.
    auto_created_models = (
        TeamImageGeneratorDatabase,
        TeamAudioGeneratorDatabase,
        ImageGeneratorDatabase,
        SpeechToTextDatabase,
        ProjectSecretDatabase,
        ProjectTemplateDatabase,
        BulkIngestJobDatabase,
        RoutineExecutionLogDatabase,
        EvalDatasetDatabase,
        EvalTestCaseDatabase,
        EvalRunDatabase,
        EvalResultDatabase,
        PromptVersionDatabase,
        GuardEventDatabase,
        RetrievalEventDatabase,
        AuditLogDatabase,
        TeamInvitationDatabase,
    )
    for model in auto_created_models:
        model.__table__.create(db_engine, checkfirst=True)

    settings_db_wrapper = get_db_wrapper()
    seed_defaults(settings_db_wrapper)

    fs_app.state.brain = Brain()

    # Auto-seed image-generator registry rows for every worker module under
    # restai/image/workers/*.py. Idempotent — existing rows keep their
    # admin-applied state (enabled flag, description, team grants).
    try:
        from restai.image.registry import seed_local_generators
        seeded = seed_local_generators(settings_db_wrapper)
        if seeded:
            logging.info("Seeded %d local image generator(s)", seeded)
    except Exception as e:
        logging.warning("Failed to seed local image generators: %s", e)

    # Same pattern for speech-to-text models — auto-seed from
    # restai/audio/workers/*.py so admins can manage them via the new page.
    try:
        from restai.speech_to_text.registry import seed_local_stt_models
        seeded = seed_local_stt_models(settings_db_wrapper)
        if seeded:
            logging.info("Seeded %d local speech-to-text model(s)", seeded)
    except Exception as e:
        logging.warning("Failed to seed local speech-to-text models: %s", e)

    from restai.oauth import OAuthManager
    config.load_oauth_providers()
    fs_app.state.oauth_manager = OAuthManager(fs_app, db_wrapper=get_db_wrapper())

    # Run data retention cleanup on startup
    from restai.retention import run_retention_cleanup
    run_retention_cleanup(settings_db_wrapper)

    # Anonymized telemetry
    import os as _os
    fs_app.state.telemetry_task = None
    if _os.environ.get("ANONYMIZED_TELEMETRY", "True").lower() == "true":
        print("Anonymized telemetry is enabled. To opt out, set ANONYMIZED_TELEMETRY=false.")
        import asyncio
        from restai.telemetry import telemetry_loop
        # Keep a strong reference: the event loop only holds weak references
        # to tasks, so an unreferenced task may be garbage-collected mid-run.
        fs_app.state.telemetry_task = asyncio.create_task(telemetry_loop())

    if not RESTAI_URL:
        logging.warning("RESTAI_URL env var missing. OAUTH auth schemes may not work properly.")

    # JWT signing secret strength check. Loud warning so a legacy install
    # or a copy-pasted dev .env doesn't silently sign tokens with a
    # guessable secret. Defaults written by `_ensure_env_secret` are
    # 64 url-safe base64 chars, so a healthy install will pass.
    _weak_secrets = {"secret", "changeme", "change-me", "default", "password", "restai", "dev"}
    secret_val = (RESTAI_AUTH_SECRET or "").strip()
    fs_app.state.auth_secret_weak = False
    if not secret_val:
        logging.error(
            "SECURITY: RESTAI_AUTH_SECRET is empty. JWTs will fail to sign. "
            "Set a long random value in .env (at least 32 bytes)."
        )
        fs_app.state.auth_secret_weak = True
    elif secret_val.lower() in _weak_secrets:
        # Must be tested before the length check: every known-weak default is
        # shorter than 32 chars, so the length branch would otherwise shadow
        # this more specific warning.
        logging.warning(
            "SECURITY: RESTAI_AUTH_SECRET matches a known-weak default (%r). "
            "Rotate to a long random value before going to production.", secret_val,
        )
        fs_app.state.auth_secret_weak = True
    elif len(secret_val) < 32:
        logging.warning(
            "SECURITY: RESTAI_AUTH_SECRET is %d chars (recommended ≥32). "
            "Generate a stronger value with `python -c \"import secrets; print(secrets.token_urlsafe(64))\"` "
            "and rotate it.", len(secret_val),
        )
        fs_app.state.auth_secret_weak = True

    @fs_app.get("/", tags=["Health"])
    async def get():
        """Root endpoint — redirect to admin UI."""
        from starlette.responses import RedirectResponse
        return RedirectResponse(url="/admin")

    @fs_app.get("/version", tags=["Health"])
    async def get_version(_: User = Depends(get_current_username)):
        """Get the current RESTai version."""
        return {
            "version": fs_app.version,
            "telemetry": _os.environ.get("ANONYMIZED_TELEMETRY", "True").lower() == "true",
        }

    # Shared by every /version/check call: last result and its timestamp.
    _update_cache = {"data": None, "ts": 0}

    @fs_app.get("/version/check", tags=["Health"])
    async def check_for_update(_: User = Depends(get_current_username)):
        """Check GitHub for a newer release. Cached for 1 hour."""
        import time as _time
        import httpx
        from packaging.version import parse as parse_version

        current = fs_app.version
        now = _time.time()

        # Return cached result if fresh (1 hour)
        if _update_cache["data"] and (now - _update_cache["ts"]) < 3600:
            return _update_cache["data"]

        result = {
            "current": current,
            "latest": current,
            "update_available": False,
            "latest_url": "https://github.com/apocas/restai/releases",
        }

        try:
            async with httpx.AsyncClient(timeout=10) as client:
                resp = await client.get(
                    "https://api.github.com/repos/apocas/restai/releases/latest",
                    headers={"Accept": "application/vnd.github+json"},
                )
                if resp.status_code == 200:
                    data = resp.json()
                    tag = data.get("tag_name", "").lstrip("v")
                    if tag:
                        result["latest"] = tag
                        result["latest_url"] = data.get("html_url", result["latest_url"])
                        result["update_available"] = parse_version(tag) > parse_version(current)
        except Exception:
            # Best effort: an unreachable GitHub or an unparsable tag must not
            # break the endpoint, but leave a trace for troubleshooting.
            logging.debug("Update check against GitHub failed", exc_info=True)

        # The result is cached even after a failure, so a broken upstream is
        # retried at most once per hour.
        _update_cache["data"] = result
        _update_cache["ts"] = now
        return result

    @fs_app.get("/health/live", tags=["Health"])
    async def health_live():
        """Liveness probe. Returns 200 if the service is running."""
        return {"status": "ok"}

    @fs_app.get("/health/ready", tags=["Health"])
    async def health_ready():
        """Readiness probe. Checks database and Redis connectivity."""
        health = {"status": "ok"}
        try:
            from sqlalchemy import text
            db_check = get_db_wrapper()
            try:
                db_check.db.execute(text("SELECT 1"))
            finally:
                # Release the session even when the probe query fails.
                db_check.db.close()
            health["database"] = "ok"
        except Exception:
            health["database"] = "error"
            health["status"] = "degraded"

        if config.REDIS_HOST:
            try:
                import redis
                r = redis.Redis(
                    host=config.REDIS_HOST,
                    port=int(config.REDIS_PORT or 6379),
                    socket_connect_timeout=2,
                )
                try:
                    r.ping()
                finally:
                    # Release the client even when the ping fails.
                    r.close()
                health["redis"] = "ok"
            except Exception:
                health["redis"] = "error"
                health["status"] = "degraded"

        if health["status"] != "ok":
            return JSONResponse(content=health, status_code=503)
        return health

    @fs_app.get("/setup", tags=["Health"])
    async def get_setup(
        db_wrapper: DBWrapper = Depends(get_db_wrapper),
    ):
        """Get platform setup information including SSO providers and feature flags."""
        providers = config.OAUTH_PROVIDERS
        if isinstance(providers, dict):
            sso_list = list(providers.keys())
        elif isinstance(providers, (list, tuple)):
            sso_list = list(providers)
        else:
            sso_list = []

        sso_provider_names = {}
        for provider in sso_list:
            if provider == "oidc":
                sso_provider_names[provider] = config.OAUTH_PROVIDER_NAME or "SSO"
            else:
                sso_provider_names[provider] = provider.capitalize()

        _sv = db_wrapper.get_setting_value
        return {
            "sso": sso_list,
            "sso_provider_names": sso_provider_names,
            "proxy": bool(_sv("proxy_url")),
            "gpu": config.RESTAI_GPU,
            "app_name": _sv("app_name", "RESTai"),
            "hide_branding": _sv("hide_branding", "false").lower() in ("true", "1"),
            "proxy_url": _sv("proxy_url", ""),
            "currency": _sv("currency", "EUR"),
            "auth_disable_local": _sv("auth_disable_local", "false").lower() in ("true", "1"),
            "mcp": config.RESTAI_MCP,
            "enforce_2fa": _sv("enforce_2fa", "false").lower() in ("true", "1"),
            # Intentionally NOT exposing `auth_secret_weak` here — this
            # endpoint is unauthenticated (used by the pre-login UI to
            # show SSO providers) and the weak-secret signal is a
            # reconnaissance aid for an attacker. Admins read it from
            # the authenticated /info endpoint instead.
        }

    @fs_app.get("/info", tags=["Health"])
    async def get_info(
        user: User = Depends(get_current_username),
        db_wrapper: DBWrapper = Depends(get_db_wrapper),
    ):
        """Get platform information including available LLMs, embeddings, and loaders."""
        from restai.vectordb.tools import get_available_vectorstores

        output = {
            "version": fs_app.version,
            "loaders": list(LOADERS.keys()),
            "embeddings": [],
            "llms": [],
            "vectorstores": get_available_vectorstores(),
            "system_llm_configured": bool(getattr(db_wrapper.get_setting("system_llm"), "value", None)),
            # Admin-only security signal. Non-admins always see False here
            # — they shouldn't know whether the instance is misconfigured.
            "auth_secret_weak": bool(getattr(fs_app.state, "auth_secret_weak", False)) if user.is_admin else False,
        }

        # Filter LLMs and embeddings by team access for non-admin users.
        # None means "no filtering" (admin); a set restricts to its members.
        allowed_llm_names = None
        allowed_emb_names = None
        if not user.is_admin:
            allowed_llm_names = set()
            allowed_emb_names = set()
            for team in user.teams:
                # Entries may be objects exposing `.name` or plain names.
                for llm in (getattr(team, "llms", None) or []):
                    allowed_llm_names.add(getattr(llm, "name", llm))
                for emb in (getattr(team, "embeddings", None) or []):
                    allowed_emb_names.add(getattr(emb, "name", emb))

        for llm in db_wrapper.get_llms():
            if allowed_llm_names is not None and llm.name not in allowed_llm_names:
                continue
            output["llms"].append(
                {
                    "id": llm.id,
                    "name": llm.name,
                    "privacy": llm.privacy,
                    "description": llm.description,
                }
            )

        for embedding in db_wrapper.get_embeddings():
            if allowed_emb_names is not None and embedding.name not in allowed_emb_names:
                continue
            output["embeddings"].append(
                {
                    "id": embedding.id,
                    "name": embedding.name,
                    "privacy": embedding.privacy,
                    "description": embedding.description,
                }
            )
        return output

    try:
        fs_app.mount(
            "/admin/static",
            StaticFiles(directory=str(FRONTEND_BUILD_DIR / "static")),
            name="static_assets",
        )
        fs_app.mount(
            "/admin/assets",
            StaticFiles(directory=str(FRONTEND_BUILD_DIR / "assets")),
            name="static_images",
        )

        # SPA catch-all route for /admin/* - must be defined after static mounts
        @fs_app.get("/admin/{full_path:path}")
        async def serve_spa(full_path: str):
            """Serve static files if they exist, otherwise index.html for SPA routing."""
            # Serve actual files (manifest.json, favicon.ico, etc.)
            build_root = FRONTEND_BUILD_DIR.resolve()
            file_path = (FRONTEND_BUILD_DIR / full_path).resolve()
            # Prevent directory traversal — resolved path must stay inside build dir.
            # Comparing path components (not string prefixes) keeps the check
            # independent of the platform's path separator.
            if full_path and build_root in file_path.parents and file_path.is_file():
                return FileResponse(str(file_path))
            index_file = build_root / "index.html"
            if index_file.exists():
                return FileResponse(str(index_file))
            return JSONResponse(status_code=404, content={"detail": "Frontend not found"})
    except Exception as e:
        print(e)
        print("Admin frontend not available.")

    # Widget JS endpoint
    WIDGET_DIR = Path(__file__).parent / "widget"
    if WIDGET_DIR.exists():
        @fs_app.get("/widget/chat.js")
        async def serve_widget_js():
            widget_file = WIDGET_DIR / "chat.js"
            if widget_file.exists():
                return FileResponse(str(widget_file), media_type="application/javascript")
            return JSONResponse(status_code=404, content={"detail": "Widget not found"})

    fs_app.include_router(llms.router, tags=["LLMs"])
    fs_app.include_router(embeddings.router, tags=["Embeddings"])
    from restai.routers import (
        image_generators,
        speech_to_text,
        secrets,
        whatsapp_webhook,
        webhooks as webhooks_router,
        templates as templates_router,
        bulk_ingest as bulk_ingest_router,
    )
    fs_app.include_router(image_generators.router, tags=["Image Generators"])
    fs_app.include_router(speech_to_text.router, tags=["Speech-to-Text"])
    fs_app.include_router(secrets.router, tags=["Project Secrets"])
    fs_app.include_router(whatsapp_webhook.router, tags=["WhatsApp"])
    fs_app.include_router(webhooks_router.router, tags=["Webhooks"])
    fs_app.include_router(templates_router.router)
    fs_app.include_router(bulk_ingest_router.router)
    fs_app.include_router(projects.router)
    fs_app.include_router(tools.router, tags=["Tools"])
    fs_app.include_router(users.router, tags=["Users"])
    fs_app.include_router(proxy.router, tags=["Proxy"])
    fs_app.include_router(statistics.router, tags=["Statistics"])
    fs_app.include_router(auth.router, tags=["Auth"])
    fs_app.include_router(teams.router, tags=["Teams"])
    fs_app.include_router(settings.router, tags=["Settings"])
    fs_app.include_router(direct.router, tags=["Direct Access"])
    fs_app.include_router(widgets.router, tags=["Widget"])
    fs_app.include_router(search.router, tags=["Search"])

    from restai.routers import evals
    fs_app.include_router(evals.router)

    # Image cache endpoint is always mounted (used by the `draw_image` builtin
    # tool, which works on non-GPU deployments via OpenAI/Google generators).
    from restai.routers import image_cache
    fs_app.include_router(image_cache.router, tags=["Image"])

    # Image + audio routers are always mounted now that external providers
    # (OpenAI, Google, Deepgram, AssemblyAI) live in the registry alongside
    # local workers — non-GPU instances can still dispatch to remote
    # generators / STT models. Local-worker calls fail cleanly with a
    # "GPU required" error from the dispatch helper.
    fs_app.include_router(image.router, tags=["Image"])
    fs_app.include_router(audio.router, tags=["Audio"])

    if config.RESTAI_MCP:
        from restai.mcp import create_mcp_server
        mcp_server = create_mcp_server(fs_app)
        fs_app.mount("/mcp", mcp_server.http_app(transport="sse"))
        logging.info("MCP server enabled at /mcp/sse")

    yield

    # Shutdown: stop the background telemetry task, if one was started.
    telemetry_task = getattr(fs_app.state, "telemetry_task", None)
    if telemetry_task is not None:
        telemetry_task.cancel()

    # Shutdown: clean up Docker containers and browser sessions. Each step is
    # isolated so a failure in one does not skip the other.
    brain = fs_app.state.brain
    for shutdown_step in (brain.shutdown_docker_manager, brain.shutdown_browser_manager):
        try:
            shutdown_step()
        except Exception:
            logging.exception("Shutdown step %s failed", getattr(shutdown_step, "__name__", shutdown_step))


logging.basicConfig(level=config.LOG_LEVEL)

if config.SENTRY_DSN:
    sentry_sdk.init(
        dsn=config.SENTRY_DSN,
        traces_sample_rate=1.0,
        enable_tracing=True,
        profiles_sample_rate=1.0,
    )

OPENAPI_TAGS = [
    {"name": "Projects", "description": "Create and manage AI projects (RAG, agent, block)"},
    {"name": "Knowledge", "description": "Manage embeddings and knowledge base for RAG projects"},
    {"name": "Chat", "description": "Chat and question endpoints for interacting with projects"},
    {"name": "Teams", "description": "Manage teams, members, admins, and resource access"},
    {"name": "Users", "description": "User management and API key management"},
    {"name": "LLMs", "description": "Register and configure Large Language Model providers"},
    {"name": "Embeddings", "description": "Register and configure embedding model providers"},
    {"name": "Tools", "description": "Text classification, MCP server probing, Ollama model management"},
    {"name": "Proxy", "description": "LiteLLM proxy key management"},
    {"name": "Statistics", "description": "Platform usage statistics and analytics"},
    {"name": "Auth", "description": "Authentication, login, logout, and session management"},
    {"name": "Settings", "description": "Platform settings (admin only)"},
    {"name": "Image", "description": "GPU-accelerated image generation"},
    {"name": "Audio", "description": "GPU-accelerated audio transcription"},
    {"name": "Direct Access", "description": "OpenAI-compatible direct access to LLMs, image and audio generators"},
    {"name": "Health", "description": "Health checks and system information"},
]

app = FastAPI(
    title=config.RESTAI_NAME,
    description="""RESTai is an AIaaS (AI as a Service) platform. Create AI projects and consume them via REST API.

Supports multiple project types: **RAG**, **Agent**, and **Block**.

## Authentication

All endpoints require authentication via one of:
- **JWT Cookie** (`restai_token`)
- **Bearer API Key** (`Authorization: Bearer <key>`)
- **Basic Auth** (`Authorization: Basic <credentials>`)
""",
    version=get_version_from_pyproject(),
    lifespan=lifespan,
    openapi_tags=OPENAPI_TAGS,
    contact={"name": "RESTai", "url": "https://github.com/apocas/restai"},
    license_info={"name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0"},
)

# Always add SessionMiddleware so SSO can be enabled at runtime via settings
app.add_middleware(
    SessionMiddleware,
    secret_key=SSO_SECRET_KEY,
    session_cookie="oui-session",
    same_site=SESSION_COOKIE_SAME_SITE,
    https_only=SESSION_COOKIE_SECURE,
)

# Audit log middleware — records all mutation requests
from restai.audit import AuditMiddleware
app.add_middleware(AuditMiddleware)


@app.get("/oauth/{provider}/login", tags=["Auth"])
async def oauth_login(provider: str = PathParam(description="OAuth provider name"), request: Request = ...):
    """Initiate OAuth login flow for the specified provider."""
    return await request.app.state.oauth_manager.handle_login(request, provider)


@app.get("/oauth/{provider}/callback", tags=["Auth"])
async def oauth_callback(provider: str = PathParam(description="OAuth provider name"), request: Request = ..., response: Response = ...):
    """Handle OAuth callback from the specified provider."""
    return await request.app.state.oauth_manager.handle_callback(request, provider, response)


@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an HTTPException as JSON; on 401 also clear the auth cookie."""
    response = JSONResponse(content={"detail": exc.detail}, status_code=exc.status_code)
    if exc.status_code == 401:
        # Drop the `restai_token` cookie whenever a request is rejected as unauthenticated.
        response.delete_cookie(key="restai_token")
    return response


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Log a validation failure and return a compact 422 JSON response.

    The response ``detail`` joins the per-field messages with ``"; "``,
    stripping pydantic's ``"Value error, "`` prefix. When no messages are
    available it falls back to the whitespace-normalised exception text.
    """
    # Collapse newlines and runs of whitespace so the log entry is one line.
    exc_str = " ".join(f"{exc}".split())
    logging.error("%s: %s", request, exc_str)
    # Extract clean user-facing messages from validation errors
    prefix = "Value error, "
    messages = []
    for err in exc.errors():
        msg = err.get("msg", "")
        if msg.startswith(prefix):
            msg = msg[len(prefix):]
        messages.append(msg)
    detail = "; ".join(messages) if messages else exc_str
    return JSONResponse(
        content={"detail": detail}, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
    )


@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Log any otherwise-unhandled exception and return a generic 500 body."""
    logging.exception("Unhandled exception on %s %s: %s", request.method, request.url, exc)
    return JSONResponse(
        content={"detail": "Internal server error"},
        status_code=500
    )

if config.RESTAI_DEV == True:
    print("Running in development mode!")

# CORS — only allow wildcard origins for widget endpoints (cross-origin
# chat API calls from embedded widgets on third-party sites). All other
# endpoints use same-origin only.
_WIDGET_CORS_PATHS = ("/widget/",)
_CORS_HEADERS = {
    "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
    "Access-Control-Allow-Headers": "Content-Type, Accept, X-Widget-Key, X-Widget-Context",
    "Access-Control-Max-Age": "86400",
}


_SECURITY_HEADERS = {
    "X-Content-Type-Options": "nosniff",
    "Referrer-Policy": "strict-origin-when-cross-origin",
}
_ADMIN_SECURITY_HEADERS = {
    "X-Frame-Options": "DENY",
    "Content-Security-Policy": (
        "default-src 'self'; "
        "img-src 'self' data: blob: https:; "
        "script-src 'self' 'unsafe-inline' 'unsafe-eval'; "
        "style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; "
        "font-src 'self' data: https://fonts.gstatic.com; "
        "connect-src 'self'; "
        "frame-ancestors 'none'"
    ),
}


@app.middleware("http")
async def cors_middleware(request: Request, call_next):
    """Apply widget-only CORS headers and baseline security headers.

    Requests whose path starts with one of ``_WIDGET_CORS_PATHS`` and that
    carry an ``Origin`` header get that origin echoed back in
    ``Access-Control-Allow-Origin`` together with ``_CORS_HEADERS``; a
    matching ``OPTIONS`` preflight is answered directly with 204 and is not
    forwarded to the application. Every forwarded response receives
    ``_SECURITY_HEADERS``, and non-widget responses additionally receive
    ``_ADMIN_SECURITY_HEADERS``; both sets are only applied where the
    response has not already set the header.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request down the stack.

    Returns:
        The preflight response, or the downstream response with the extra
        headers applied.
    """
    # str.startswith accepts a tuple of prefixes, so no explicit loop is needed.
    is_widget = request.url.path.startswith(_WIDGET_CORS_PATHS)
    origin = request.headers.get("origin")
    allow_cors = bool(is_widget and origin)

    # Handle preflight OPTIONS
    if allow_cors and request.method == "OPTIONS":
        return Response(
            status_code=204,
            headers={
                "Access-Control-Allow-Origin": origin,
                **_CORS_HEADERS,
                # The allowed origin mirrors the request, so caches must key on it.
                "Vary": "Origin",
            },
        )

    response = await call_next(request)

    # Add CORS headers only for widget endpoints
    if allow_cors:
        response.headers["Access-Control-Allow-Origin"] = origin
        for header, value in _CORS_HEADERS.items():
            response.headers[header] = value
        # Without `Vary: Origin` a shared cache could replay a response that
        # carries another site's Access-Control-Allow-Origin value.
        response.headers.add_vary_header("Origin")

    # Security headers — applied to all responses, with stricter rules for non-widget paths
    for header, value in _SECURITY_HEADERS.items():
        response.headers.setdefault(header, value)
    if not is_widget:
        for header, value in _ADMIN_SECURITY_HEADERS.items():
            response.headers.setdefault(header, value)

    return response