/ restai / main.py
main.py
  1  from pathlib import Path
  2  import warnings
  3  warnings.filterwarnings("ignore", message=".*Accessing the 'model_fields' attribute on the instance is deprecated.*")
  4  
  5  from fastapi.exceptions import RequestValidationError
  6  from fastapi.responses import JSONResponse, FileResponse
  7  from fastapi.staticfiles import StaticFiles
  8  from fastapi import FastAPI, HTTPException, Request, Depends, status, Response
  9  from fastapi import Path as PathParam
 10  import logging
 11  import sys
 12  
 13  from fastmcp import FastMCP
 14  from restai import config
 15  import sentry_sdk
 16  from contextlib import asynccontextmanager
 17  from restai.database import get_db_wrapper
 18  from restai.oauth import OAuthManager
 19  from starlette.middleware.sessions import SessionMiddleware
 20  from restai.config import (
 21      OAUTH_PROVIDERS,
 22      SSO_SECRET_KEY,
 23      SESSION_COOKIE_SAME_SITE,
 24      SESSION_COOKIE_SECURE,
 25      RESTAI_AUTH_SECRET,
 26      RESTAI_URL
 27  )
 28  from restai.utils.version import get_version_from_pyproject
 29  
 30  PROJECT_ROOT = Path(__file__).parent.parent
 31  FRONTEND_BUILD_DIR = PROJECT_ROOT / "frontend" / "build"
 32  
 33  # When installed from PyPI, frontend may be bundled inside the package
 34  if not FRONTEND_BUILD_DIR.exists():
 35      _pkg_frontend = Path(__file__).parent.parent / "frontend" / "build"
 36      if _pkg_frontend.exists():
 37          FRONTEND_BUILD_DIR = _pkg_frontend
 38  
 39  @asynccontextmanager
 40  async def lifespan(fs_app: FastAPI):
 41      print(
 42          r"""
 43         ___ ___ ___ _____ _   ___      _.--'"'.
 44        | _ \ __/ __|_   _/_\ |_ _|    (  ( (   )
 45        |   / _|\__ \ | |/ _ \ | |     (o)_    ) )
 46        |_|_\___|___/ |_/_/ \_\___|        (o)_.'
 47                                    
 48          """
 49      )
 50      from restai.brain import Brain
 51      from restai.database import get_db_wrapper, DBWrapper
 52      from restai.auth import get_current_username
 53      from restai.routers import (
 54          llms,
 55          projects,
 56          tools,
 57          users,
 58          image,
 59          audio,
 60          embeddings,
 61          proxy,
 62          statistics,
 63          auth,
 64          teams,
 65          settings,
 66          direct,
 67          widgets,
 68          search,
 69      )
 70      from restai.models.models import User
 71      from restai.models.databasemodels import ProjectDatabase
 72      from restai.multiprocessing import get_manager
 73      from modules.loaders import LOADERS
 74  
 75      try:
 76          fs_app.state.manager = get_manager()
 77      except Exception:
 78          fs_app.state.manager = None
 79  
 80      from restai.settings import ensure_settings_table, seed_defaults
 81      from restai.database import engine as db_engine
 82      ensure_settings_table(db_engine)
 83  
 84      # Auto-create new association tables for generators, eval tables, and migrate output table
 85      from restai.models.databasemodels import TeamImageGeneratorDatabase, TeamAudioGeneratorDatabase, EvalDatasetDatabase, EvalTestCaseDatabase, EvalRunDatabase, EvalResultDatabase, PromptVersionDatabase, GuardEventDatabase, RetrievalEventDatabase, AuditLogDatabase, TeamInvitationDatabase, ImageGeneratorDatabase, SpeechToTextDatabase, ProjectSecretDatabase, ProjectTemplateDatabase, BulkIngestJobDatabase, RoutineExecutionLogDatabase
 86      TeamImageGeneratorDatabase.__table__.create(db_engine, checkfirst=True)
 87      TeamAudioGeneratorDatabase.__table__.create(db_engine, checkfirst=True)
 88      ImageGeneratorDatabase.__table__.create(db_engine, checkfirst=True)
 89      SpeechToTextDatabase.__table__.create(db_engine, checkfirst=True)
 90      ProjectSecretDatabase.__table__.create(db_engine, checkfirst=True)
 91      ProjectTemplateDatabase.__table__.create(db_engine, checkfirst=True)
 92      BulkIngestJobDatabase.__table__.create(db_engine, checkfirst=True)
 93      RoutineExecutionLogDatabase.__table__.create(db_engine, checkfirst=True)
 94      EvalDatasetDatabase.__table__.create(db_engine, checkfirst=True)
 95      EvalTestCaseDatabase.__table__.create(db_engine, checkfirst=True)
 96      EvalRunDatabase.__table__.create(db_engine, checkfirst=True)
 97      EvalResultDatabase.__table__.create(db_engine, checkfirst=True)
 98      PromptVersionDatabase.__table__.create(db_engine, checkfirst=True)
 99      GuardEventDatabase.__table__.create(db_engine, checkfirst=True)
100      RetrievalEventDatabase.__table__.create(db_engine, checkfirst=True)
101      AuditLogDatabase.__table__.create(db_engine, checkfirst=True)
102      TeamInvitationDatabase.__table__.create(db_engine, checkfirst=True)
103      settings_db_wrapper = get_db_wrapper()
104      seed_defaults(settings_db_wrapper)
105  
106      fs_app.state.brain = Brain()
107  
108      # Auto-seed image-generator registry rows for every worker module under
109      # restai/image/workers/*.py. Idempotent — existing rows keep their
110      # admin-applied state (enabled flag, description, team grants).
111      try:
112          from restai.image.registry import seed_local_generators
113          seeded = seed_local_generators(settings_db_wrapper)
114          if seeded:
115              logging.info("Seeded %d local image generator(s)", seeded)
116      except Exception as e:
117          logging.warning("Failed to seed local image generators: %s", e)
118  
119      # Same pattern for speech-to-text models — auto-seed from
120      # restai/audio/workers/*.py so admins can manage them via the new page.
121      try:
122          from restai.speech_to_text.registry import seed_local_stt_models
123          seeded = seed_local_stt_models(settings_db_wrapper)
124          if seeded:
125              logging.info("Seeded %d local speech-to-text model(s)", seeded)
126      except Exception as e:
127          logging.warning("Failed to seed local speech-to-text models: %s", e)
128  
129      from restai.oauth import OAuthManager
130      config.load_oauth_providers()
131      fs_app.state.oauth_manager = OAuthManager(fs_app, db_wrapper=get_db_wrapper())
132  
133      # Run data retention cleanup on startup
134      from restai.retention import run_retention_cleanup
135      run_retention_cleanup(settings_db_wrapper)
136  
137      # Anonymized telemetry
138      import os as _os
139      if _os.environ.get("ANONYMIZED_TELEMETRY", "True").lower() == "true":
140          print("Anonymized telemetry is enabled. To opt out, set ANONYMIZED_TELEMETRY=false.")
141          import asyncio
142          from restai.telemetry import telemetry_loop
143          asyncio.create_task(telemetry_loop())
144  
145      if not RESTAI_URL:
146        logging.warning("RESTAI_URL env var missing. OAUTH auth schemes may not work properly.")
147  
148      # JWT signing secret strength check. Loud warning so a legacy install
149      # or a copy-pasted dev .env doesn't silently sign tokens with a
150      # guessable secret. Defaults written by `_ensure_env_secret` are
151      # 64 url-safe base64 chars, so a healthy install will pass.
152      _weak_secrets = {"secret", "changeme", "change-me", "default", "password", "restai", "dev"}
153      secret_val = (RESTAI_AUTH_SECRET or "").strip()
154      fs_app.state.auth_secret_weak = False
155      if not secret_val:
156          logging.error(
157              "SECURITY: RESTAI_AUTH_SECRET is empty. JWTs will fail to sign. "
158              "Set a long random value in .env (at least 32 bytes)."
159          )
160          fs_app.state.auth_secret_weak = True
161      elif len(secret_val) < 32:
162          logging.warning(
163              "SECURITY: RESTAI_AUTH_SECRET is %d chars (recommended ≥32). "
164              "Generate a stronger value with `python -c \"import secrets; print(secrets.token_urlsafe(64))\"` "
165              "and rotate it.", len(secret_val),
166          )
167          fs_app.state.auth_secret_weak = True
168      elif secret_val.lower() in _weak_secrets:
169          logging.warning(
170              "SECURITY: RESTAI_AUTH_SECRET matches a known-weak default (%r). "
171              "Rotate to a long random value before going to production.", secret_val,
172          )
173          fs_app.state.auth_secret_weak = True
174  
175      @fs_app.get("/", tags=["Health"])
176      async def get():
177          """Root endpoint — redirect to admin UI."""
178          from starlette.responses import RedirectResponse
179          return RedirectResponse(url="/admin")
180  
181      @fs_app.get("/version", tags=["Health"])
182      async def get_version(_: User = Depends(get_current_username)):
183          """Get the current RESTai version."""
184          return {
185              "version": fs_app.version,
186              "telemetry": _os.environ.get("ANONYMIZED_TELEMETRY", "True").lower() == "true",
187          }
188  
189      _update_cache = {"data": None, "ts": 0}
190  
191      @fs_app.get("/version/check", tags=["Health"])
192      async def check_for_update(_: User = Depends(get_current_username)):
193          """Check GitHub for a newer release. Cached for 1 hour."""
194          import time as _time
195          import httpx
196          from packaging.version import parse as parse_version
197  
198          current = fs_app.version
199          now = _time.time()
200  
201          # Return cached result if fresh (1 hour)
202          if _update_cache["data"] and (now - _update_cache["ts"]) < 3600:
203              return _update_cache["data"]
204  
205          result = {
206              "current": current,
207              "latest": current,
208              "update_available": False,
209              "latest_url": "https://github.com/apocas/restai/releases",
210          }
211  
212          try:
213              async with httpx.AsyncClient(timeout=10) as client:
214                  resp = await client.get(
215                      "https://api.github.com/repos/apocas/restai/releases/latest",
216                      headers={"Accept": "application/vnd.github+json"},
217                  )
218                  if resp.status_code == 200:
219                      data = resp.json()
220                      tag = data.get("tag_name", "").lstrip("v")
221                      if tag:
222                          result["latest"] = tag
223                          result["latest_url"] = data.get("html_url", result["latest_url"])
224                          result["update_available"] = parse_version(tag) > parse_version(current)
225          except Exception:
226              pass
227  
228          _update_cache["data"] = result
229          _update_cache["ts"] = now
230          return result
231  
232      @fs_app.get("/health/live", tags=["Health"])
233      async def health_live():
234          """Liveness probe. Returns 200 if the service is running."""
235          return {"status": "ok"}
236  
237      @fs_app.get("/health/ready", tags=["Health"])
238      async def health_ready():
239          """Readiness probe. Checks database and Redis connectivity."""
240          health = {"status": "ok"}
241          try:
242              from sqlalchemy import text
243              db_check = get_db_wrapper()
244              db_check.db.execute(text("SELECT 1"))
245              db_check.db.close()
246              health["database"] = "ok"
247          except Exception:
248              health["database"] = "error"
249              health["status"] = "degraded"
250  
251          if config.REDIS_HOST:
252              try:
253                  import redis
254                  r = redis.Redis(
255                      host=config.REDIS_HOST,
256                      port=int(config.REDIS_PORT or 6379),
257                      socket_connect_timeout=2,
258                  )
259                  r.ping()
260                  r.close()
261                  health["redis"] = "ok"
262              except Exception:
263                  health["redis"] = "error"
264                  health["status"] = "degraded"
265  
266          if health["status"] != "ok":
267              return JSONResponse(content=health, status_code=503)
268          return health
269  
270      @fs_app.get("/setup", tags=["Health"])
271      async def get_setup(
272          db_wrapper: DBWrapper = Depends(get_db_wrapper),
273      ):
274          """Get platform setup information including SSO providers and feature flags."""
275          sso_list = []
276          if isinstance(config.OAUTH_PROVIDERS, dict):
277              sso_list = list(config.OAUTH_PROVIDERS.keys())
278          elif isinstance(config.OAUTH_PROVIDERS, (list, tuple)):
279              sso_list = list(config.OAUTH_PROVIDERS)
280          else:
281              sso_list = []
282          sso_provider_names = {}
283          for provider in sso_list:
284              if provider == "oidc":
285                  sso_provider_names[provider] = config.OAUTH_PROVIDER_NAME or "SSO"
286              else:
287                  sso_provider_names[provider] = provider.capitalize()
288  
289          _sv = db_wrapper.get_setting_value
290          return {
291              "sso": sso_list,
292              "sso_provider_names": sso_provider_names,
293              "proxy": bool(_sv("proxy_url")),
294              "gpu": config.RESTAI_GPU,
295              "app_name": _sv("app_name", "RESTai"),
296              "hide_branding": _sv("hide_branding", "false").lower() in ("true", "1"),
297              "proxy_url": _sv("proxy_url", ""),
298              "currency": _sv("currency", "EUR"),
299              "auth_disable_local": _sv("auth_disable_local", "false").lower() in ("true", "1"),
300              "mcp": config.RESTAI_MCP,
301              "enforce_2fa": _sv("enforce_2fa", "false").lower() in ("true", "1"),
302              # Intentionally NOT exposing `auth_secret_weak` here — this
303              # endpoint is unauthenticated (used by the pre-login UI to
304              # show SSO providers) and the weak-secret signal is a
305              # reconnaissance aid for an attacker. Admins read it from
306              # the authenticated /info endpoint instead.
307          }
308  
309      @fs_app.get("/info", tags=["Health"])
310      async def get_info(
311          user: User = Depends(get_current_username),
312          db_wrapper: DBWrapper = Depends(get_db_wrapper),
313      ):
314          """Get platform information including available LLMs, embeddings, and loaders."""
315          from restai.vectordb.tools import get_available_vectorstores
316  
317          output = {
318              "version": fs_app.version,
319              "loaders": list(LOADERS.keys()),
320              "embeddings": [],
321              "llms": [],
322              "vectorstores": get_available_vectorstores(),
323              "system_llm_configured": bool(getattr(db_wrapper.get_setting("system_llm"), "value", None)),
324              # Admin-only security signal. Non-admins always see False here
325              # — they shouldn't know whether the instance is misconfigured.
326              "auth_secret_weak": bool(getattr(fs_app.state, "auth_secret_weak", False)) if user.is_admin else False,
327          }
328  
329          # Filter LLMs and embeddings by team access for non-admin users
330          allowed_llm_names = None
331          allowed_emb_names = None
332          if not user.is_admin:
333              allowed_llm_names = set()
334              allowed_emb_names = set()
335              for team in user.teams:
336                  for llm in (team.llms if hasattr(team, 'llms') and team.llms else []):
337                      allowed_llm_names.add(llm.name if hasattr(llm, 'name') else llm)
338                  for emb in (team.embeddings if hasattr(team, 'embeddings') and team.embeddings else []):
339                      allowed_emb_names.add(emb.name if hasattr(emb, 'name') else emb)
340  
341          db_llms = db_wrapper.get_llms()
342          for llm in db_llms:
343              if allowed_llm_names is not None and llm.name not in allowed_llm_names:
344                  continue
345              output["llms"].append(
346                  {
347                      "id": llm.id,
348                      "name": llm.name,
349                      "privacy": llm.privacy,
350                      "description": llm.description,
351                  }
352              )
353  
354          db_embeddings = db_wrapper.get_embeddings()
355          for embedding in db_embeddings:
356              if allowed_emb_names is not None and embedding.name not in allowed_emb_names:
357                  continue
358              output["embeddings"].append(
359                  {
360                      "id": embedding.id,
361                      "name": embedding.name,
362                      "privacy": embedding.privacy,
363                      "description": embedding.description,
364                  }
365              )
366          return output
367  
368      try:
369              fs_app.mount(
370                  "/admin/static",
371                  StaticFiles(directory=str(FRONTEND_BUILD_DIR / "static")),
372                  name="static_assets",
373              )
374              fs_app.mount(
375                  "/admin/assets",
376                  StaticFiles(directory=str(FRONTEND_BUILD_DIR / "assets")),
377                  name="static_images",
378              )
379  
380              # SPA catch-all route for /admin/* - must be defined after static mounts
381              @fs_app.get("/admin/{full_path:path}")
382              async def serve_spa(full_path: str):
383                  """Serve static files if they exist, otherwise index.html for SPA routing."""
384                  # Serve actual files (manifest.json, favicon.ico, etc.)
385                  file_path = (FRONTEND_BUILD_DIR / full_path).resolve()
386                  build_root = FRONTEND_BUILD_DIR.resolve()
387                  # Prevent directory traversal — resolved path must stay inside build dir
388                  if full_path and str(file_path).startswith(str(build_root) + "/") and file_path.is_file():
389                      return FileResponse(str(file_path))
390                  index_file = build_root / "index.html"
391                  if index_file.exists():
392                      return FileResponse(str(index_file))
393                  return JSONResponse(status_code=404, content={"detail": "Frontend not found"})
394      except Exception as e:
395          print(e)
396          print("Admin frontend not available.")
397  
398      # Widget JS endpoint
399      WIDGET_DIR = Path(__file__).parent / "widget"
400      if WIDGET_DIR.exists():
401          @fs_app.get("/widget/chat.js")
402          async def serve_widget_js():
403              widget_file = WIDGET_DIR / "chat.js"
404              if widget_file.exists():
405                  return FileResponse(str(widget_file), media_type="application/javascript")
406              return JSONResponse(status_code=404, content={"detail": "Widget not found"})
407  
408      fs_app.include_router(llms.router, tags=["LLMs"])
409      fs_app.include_router(embeddings.router, tags=["Embeddings"])
410      from restai.routers import image_generators, speech_to_text, secrets, whatsapp_webhook, webhooks as webhooks_router, templates as templates_router, bulk_ingest as bulk_ingest_router
411      fs_app.include_router(image_generators.router, tags=["Image Generators"])
412      fs_app.include_router(speech_to_text.router, tags=["Speech-to-Text"])
413      fs_app.include_router(secrets.router, tags=["Project Secrets"])
414      fs_app.include_router(whatsapp_webhook.router, tags=["WhatsApp"])
415      fs_app.include_router(webhooks_router.router, tags=["Webhooks"])
416      fs_app.include_router(templates_router.router)
417      fs_app.include_router(bulk_ingest_router.router)
418      fs_app.include_router(projects.router)
419      fs_app.include_router(tools.router, tags=["Tools"])
420      fs_app.include_router(users.router, tags=["Users"])
421      fs_app.include_router(proxy.router, tags=["Proxy"])
422      fs_app.include_router(statistics.router, tags=["Statistics"])
423      fs_app.include_router(auth.router, tags=["Auth"])
424      fs_app.include_router(teams.router, tags=["Teams"])
425      fs_app.include_router(settings.router, tags=["Settings"])
426      fs_app.include_router(direct.router, tags=["Direct Access"])
427      fs_app.include_router(widgets.router, tags=["Widget"])
428      fs_app.include_router(search.router, tags=["Search"])
429  
430      from restai.routers import evals
431      fs_app.include_router(evals.router)
432  
433      # Image cache endpoint is always mounted (used by the `draw_image` builtin
434      # tool, which works on non-GPU deployments via OpenAI/Google generators).
435      from restai.routers import image_cache
436      fs_app.include_router(image_cache.router, tags=["Image"])
437  
438      # Image + audio routers are always mounted now that external providers
439      # (OpenAI, Google, Deepgram, AssemblyAI) live in the registry alongside
440      # local workers — non-GPU instances can still dispatch to remote
441      # generators / STT models. Local-worker calls fail cleanly with a
442      # "GPU required" error from the dispatch helper.
443      fs_app.include_router(image.router, tags=["Image"])
444      fs_app.include_router(audio.router, tags=["Audio"])
445  
446      if config.RESTAI_MCP:
447          from restai.mcp import create_mcp_server
448          mcp_server = create_mcp_server(fs_app)
449          fs_app.mount("/mcp", mcp_server.http_app(transport="sse"))
450          logging.info("MCP server enabled at /mcp/sse")
451  
452      yield
453  
454      # Shutdown: clean up Docker containers
455      fs_app.state.brain.shutdown_docker_manager()
456      fs_app.state.brain.shutdown_browser_manager()
457  
458  
logging.basicConfig(level=config.LOG_LEVEL)

# Sentry reporting is opt-in: nothing is initialised unless a DSN is
# configured. When it is, every trace and every profile is sampled (1.0).
if config.SENTRY_DSN:
    sentry_sdk.init(
        dsn=config.SENTRY_DSN,
        enable_tracing=True,
        traces_sample_rate=1.0,
        profiles_sample_rate=1.0,
    )
468  
# Tag metadata for the generated OpenAPI / Swagger docs. FastAPI lists these
# tags first, in this order. Every tag passed explicitly to include_router at
# startup has an entry here, so none of them appears without a description.
OPENAPI_TAGS = [
    {"name": "Projects", "description": "Create and manage AI projects (RAG, agent, block)"},
    {"name": "Knowledge", "description": "Manage embeddings and knowledge base for RAG projects"},
    {"name": "Chat", "description": "Chat and question endpoints for interacting with projects"},
    {"name": "Teams", "description": "Manage teams, members, admins, and resource access"},
    {"name": "Users", "description": "User management and API key management"},
    {"name": "LLMs", "description": "Register and configure Large Language Model providers"},
    {"name": "Embeddings", "description": "Register and configure embedding model providers"},
    {"name": "Tools", "description": "Text classification, MCP server probing, Ollama model management"},
    {"name": "Proxy", "description": "LiteLLM proxy key management"},
    {"name": "Statistics", "description": "Platform usage statistics and analytics"},
    {"name": "Auth", "description": "Authentication, login, logout, and session management"},
    {"name": "Settings", "description": "Platform settings (admin only)"},
    {"name": "Image", "description": "Image generation via local GPU workers or external providers"},
    {"name": "Audio", "description": "Audio transcription via local GPU workers or external speech-to-text providers"},
    {"name": "Image Generators", "description": "Manage the image generator registry (local workers and external providers)"},
    {"name": "Speech-to-Text", "description": "Manage the speech-to-text model registry (local workers and external providers)"},
    {"name": "Direct Access", "description": "OpenAI-compatible direct access to LLMs, image and audio generators"},
    {"name": "Project Secrets", "description": "Manage project secrets"},
    {"name": "Webhooks", "description": "Webhook integrations"},
    {"name": "WhatsApp", "description": "WhatsApp webhook integration"},
    {"name": "Widget", "description": "Embeddable chat widgets for third-party sites"},
    {"name": "Search", "description": "Search endpoints"},
    {"name": "Health", "description": "Health checks and system information"},
]
487  
# The FastAPI application. Routes, static mounts and routers are registered
# from ``lifespan`` at startup; module-level middleware and exception
# handlers are attached below.
app = FastAPI(
    title=config.RESTAI_NAME,
    description="""RESTai is an AIaaS (AI as a Service) platform. Create AI projects and consume them via REST API.

Supports multiple project types: **RAG**, **Agent**, and **Block**.

## Authentication

Endpoints require authentication via one of:
- **JWT Cookie** (`restai_token`)
- **Bearer API Key** (`Authorization: Bearer <key>`)
- **Basic Auth** (`Authorization: Basic <credentials>`)

A few endpoints are public, e.g. the health probes (`/health/live`, `/health/ready`), `/setup`, and the OAuth login/callback routes.
""",
    version=get_version_from_pyproject(),
    lifespan=lifespan,
    openapi_tags=OPENAPI_TAGS,
    contact={"name": "RESTai", "url": "https://github.com/apocas/restai"},
    license_info={"name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0"},
)
507  
# Audit log middleware — records all mutation requests.
from restai.audit import AuditMiddleware

# Starlette wraps each newly added middleware around the ones added before
# it, so AuditMiddleware (registered second) runs outside SessionMiddleware.
#
# SessionMiddleware is installed unconditionally so SSO can be enabled at
# runtime via settings.
app.add_middleware(
    SessionMiddleware,
    secret_key=SSO_SECRET_KEY,
    session_cookie="oui-session",
    same_site=SESSION_COOKIE_SAME_SITE,
    https_only=SESSION_COOKIE_SECURE,
)
app.add_middleware(AuditMiddleware)
520  
521  
@app.get("/oauth/{provider}/login", tags=["Auth"])
async def oauth_login(provider: str = PathParam(description="OAuth provider name"), request: Request = ...):
    """Initiate OAuth login flow for the specified provider."""
    # The manager is created during application startup (see lifespan).
    manager = request.app.state.oauth_manager
    return await manager.handle_login(request, provider)
526  
527  
@app.get("/oauth/{provider}/callback", tags=["Auth"])
async def oauth_callback(provider: str = PathParam(description="OAuth provider name"), request: Request = ..., response: Response = ...):
    """Handle OAuth callback from the specified provider."""
    # Delegates to the OAuth manager stored on app.state at startup.
    manager = request.app.state.oauth_manager
    return await manager.handle_callback(request, provider, response)
532  
533      
534  
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an HTTPException as ``{"detail": ...}`` JSON.

    Any headers attached to the exception (``exc.headers``, e.g.
    ``WWW-Authenticate`` on a 401) are forwarded, as FastAPI's built-in
    handler does. Overriding that handler without doing the same silently
    dropped them. On a 401 the ``restai_token`` cookie is also deleted.
    """
    response = JSONResponse(
        content={"detail": exc.detail},
        status_code=exc.status_code,
        headers=getattr(exc, "headers", None),
    )
    if exc.status_code == 401:
        response.delete_cookie(key="restai_token")
    return response
541  
542  
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Return a 422 whose ``detail`` is a readable summary of the validation errors.

    Each error's ``msg`` has a leading ``"Value error, "`` stripped (the
    prefix pydantic adds for ValueErrors raised in validators). Non-empty
    messages are joined with ``"; "``. When there are none, the flattened
    exception text is used instead.
    """
    exc_str = f"{exc}".replace("\n", " ").replace("   ", " ")
    # Log method + URL: formatting the Request object itself only produced
    # an unhelpful "<... Request object at 0x...>" repr.
    logging.error("%s %s: %s", request.method, request.url, exc_str)

    prefix = "Value error, "
    messages = []
    for err in exc.errors():
        msg = err.get("msg", "")
        if msg.startswith(prefix):
            msg = msg[len(prefix):]
        if msg:
            messages.append(msg)
    detail = "; ".join(messages) if messages else exc_str
    # Literal 422: newer Starlette releases deprecate the
    # HTTP_422_UNPROCESSABLE_ENTITY constant name (RFC 9110 calls 422
    # "Unprocessable Content").
    return JSONResponse(content={"detail": detail}, status_code=422)
558  
559  
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log the full traceback, return a generic 500.

    The client never sees exception details; they only go to the log.
    """
    logging.exception(f"Unhandled exception on {request.method} {request.url}: {exc}")
    body = {"detail": "Internal server error"}
    return JSONResponse(status_code=500, content=body)
567  
# Startup notice for development mode. The explicit ``== True`` comparison is
# deliberate-looking and kept as-is: RESTAI_DEV's type is not visible here, and
# a plain truthiness test could behave differently (e.g. for a non-empty
# string value). NOTE(review): confirm config.RESTAI_DEV is a bool.
if config.RESTAI_DEV == True:
    print("Running in development mode!")
570  
# CORS policy: only the widget endpoints are cross-origin. Requests under
# these path prefixes get the caller's Origin reflected back (embedded chat
# widgets call the API from third-party sites). Every other endpoint stays
# same-origin only.
_WIDGET_CORS_PATHS = ("/widget/",)
_CORS_HEADERS = dict(
    [
        ("Access-Control-Allow-Methods", "GET, POST, OPTIONS"),
        ("Access-Control-Allow-Headers", "Content-Type, Accept, X-Widget-Key, X-Widget-Context"),
        # Browsers may cache the preflight result for 86400 s (24 h).
        ("Access-Control-Max-Age", "86400"),
    ]
)


# Hardening headers added to every response.
_SECURITY_HEADERS = dict(
    [
        ("X-Content-Type-Options", "nosniff"),
        ("Referrer-Policy", "strict-origin-when-cross-origin"),
    ]
)

# Extra headers for non-widget responses: forbid framing entirely, plus a
# Content-Security-Policy assembled from its individual directives.
_ADMIN_SECURITY_HEADERS = {
    "X-Frame-Options": "DENY",
    "Content-Security-Policy": "; ".join(
        [
            "default-src 'self'",
            "img-src 'self' data: blob: https:",
            "script-src 'self' 'unsafe-inline' 'unsafe-eval'",
            "style-src 'self' 'unsafe-inline' https://fonts.googleapis.com",
            "font-src 'self' data: https://fonts.gstatic.com",
            "connect-src 'self'",
            "frame-ancestors 'none'",
        ]
    ),
}
598  
599  
@app.middleware("http")
async def cors_middleware(request: Request, call_next):
    """Apply widget-only CORS headers and security headers to responses.

    * Widget paths (prefixes in ``_WIDGET_CORS_PATHS``) that carry an
      ``Origin`` header get that origin echoed in
      ``Access-Control-Allow-Origin`` plus ``_CORS_HEADERS``. Their OPTIONS
      preflights are answered here with an empty 204 and never reach a route.
    * Because the allowed origin is reflected per request, widget responses
      also carry ``Vary: Origin``. Without it, a shared cache could replay one
      site's CORS response (or a response without CORS headers) to another
      site.
    * ``_SECURITY_HEADERS`` are added to every routed response, and
      ``_ADMIN_SECURITY_HEADERS`` to non-widget paths. ``setdefault`` lets a
      value set by the route itself win.
    """
    path = request.url.path
    is_widget = any(path.startswith(p) for p in _WIDGET_CORS_PATHS)
    origin = request.headers.get("origin")

    # Handle preflight OPTIONS
    if request.method == "OPTIONS" and is_widget and origin:
        return Response(
            status_code=204,
            headers={
                "Access-Control-Allow-Origin": origin,
                "Vary": "Origin",
                **_CORS_HEADERS,
            },
        )

    response = await call_next(request)

    # Add CORS headers only for widget endpoints
    if is_widget:
        # Merges with any existing Vary value. Added even when this request
        # had no Origin, because the cached copy may later be served to a
        # cross-origin caller.
        response.headers.add_vary_header("Origin")
        if origin:
            response.headers["Access-Control-Allow-Origin"] = origin
            for k, v in _CORS_HEADERS.items():
                response.headers[k] = v

    # Security headers — applied to all responses, with stricter rules for non-widget paths
    for k, v in _SECURITY_HEADERS.items():
        response.headers.setdefault(k, v)
    if not is_widget:
        for k, v in _ADMIN_SECURITY_HEADERS.items():
            response.headers.setdefault(k, v)

    return response