/ src / api / main.py
main.py
  1  """
  2  FastAPI application for Ag3ntum API.
  3  
  4  Main entry point that configures the FastAPI app with:
  5  - Security headers middleware
  6  - CORS middleware (origins derived from server.hostname)
  7  - Host header validation
  8  - Database initialization
  9  - Route registration
 10  - Lifespan management
 11  - Dual logging (console with colors + file)
 12  """
import logging
import os
import re
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, AsyncGenerator, Awaitable, Callable

import yaml
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy import text

from ..config import CONFIG_DIR, ConfigNotFoundError, ConfigValidationError
from ..services.session_service import InvalidSessionIdError, SessionNotFoundError
from ..core.logging_config import setup_backend_logging
from ..core.subagent_manager import get_subagent_manager
from ..db.database import engine, init_db, DATABASE_PATH
from .routes import admin_router, auth_router, config_router, files_router, health_router, llm_proxy_router, llm_proxy_session_router, queue_router, reseller_router, sessions_router, skills_router, ssh_profiles_router
from .metrics import setup_metrics
from .waf_filter import WAFMiddleware
from .security_middleware import (
    build_allowed_origins,
    build_allowed_hosts,
    SecurityHeadersMiddleware,
    HostValidationMiddleware,
    TrustedProxyMiddleware,
)
 42  
 43  logger = logging.getLogger(__name__)
 44  
 45  # API configuration file
 46  API_CONFIG_FILE: Path = CONFIG_DIR / "api.yaml"
 47  
 48  # Required fields in api.yaml (cors_origins now derived from server.hostname)
 49  REQUIRED_API_FIELDS = ["host", "port"]
 50  
 51  # Patterns for sensitive field names (case-insensitive)
 52  SENSITIVE_PATTERNS = re.compile(
 53      r"(secret|key|password|token|credential|auth)", re.IGNORECASE
 54  )
 55  
 56  
 57  # =============================================================================
 58  # Configuration Utilities
 59  # =============================================================================
 60  
 61  def mask_sensitive_value(value: str, visible_chars: int = 4) -> str:
 62      """
 63      Mask a sensitive value, showing only first and last few characters.
 64  
 65      Args:
 66          value: The sensitive string to mask.
 67          visible_chars: Number of characters to show at start and end.
 68  
 69      Returns:
 70          Masked string like "sk-a...xyz" or "****" if too short.
 71      """
 72      if not isinstance(value, str):
 73          return "****"
 74      if len(value) <= visible_chars * 2:
 75          return "*" * len(value)
 76      return f"{value[:visible_chars]}...{value[-visible_chars:]}"
 77  
 78  
 79  def format_config_value(key: str, value: Any, indent: int = 0) -> list[str]:
 80      """
 81      Format a configuration value for logging, masking sensitive values.
 82  
 83      Args:
 84          key: The configuration key name.
 85          value: The configuration value.
 86          indent: Current indentation level.
 87  
 88      Returns:
 89          List of formatted log lines.
 90      """
 91      prefix = "  " * indent
 92      lines = []
 93  
 94      if isinstance(value, dict):
 95          lines.append(f"{prefix}{key}:")
 96          for k, v in value.items():
 97              lines.extend(format_config_value(k, v, indent + 1))
 98      elif isinstance(value, list):
 99          lines.append(f"{prefix}{key}:")
100          for item in value:
101              if isinstance(item, dict):
102                  lines.append(f"{prefix}  -")
103                  for k, v in item.items():
104                      lines.extend(format_config_value(k, v, indent + 2))
105              else:
106                  lines.append(f"{prefix}  - {item}")
107      else:
108          # Check if key matches sensitive patterns
109          if SENSITIVE_PATTERNS.search(key) and value:
110              display_value = mask_sensitive_value(str(value))
111          else:
112              display_value = value
113          lines.append(f"{prefix}{key}: {display_value}")
114  
115      return lines
116  
117  
def log_configuration(config: dict[str, Any]) -> None:
    """
    Write the loaded configuration to the log, one INFO record per line.

    The output is framed by "=" banners and starts with the config file and
    database paths. Every top-level entry is rendered through
    format_config_value, which masks values under sensitive-looking keys.

    Args:
        config: The full configuration dictionary.
    """
    banner = "=" * 60

    for header_line in (banner, "AG3NTUM API CONFIGURATION", banner):
        logger.info(header_line)

    # Where the configuration and the database live
    logger.info(f"Config file: {API_CONFIG_FILE}")
    logger.info(f"Database: {DATABASE_PATH}")
    logger.info("-" * 60)

    rendered_lines = (
        line
        for section, contents in config.items()
        for line in format_config_value(section, contents)
    )
    for line in rendered_lines:
        logger.info(line)

    logger.info(banner)
140  
141  
def load_api_config() -> dict[str, Any]:
    """
    Load API configuration from api.yaml.

    Returns:
        The full parsed YAML document (a mapping). Its ``api`` section is a
        mapping containing every key listed in REQUIRED_API_FIELDS.

    Raises:
        ConfigNotFoundError: If api.yaml doesn't exist.
        ConfigValidationError: If the file cannot be parsed, is empty, is not
            a mapping, or its ``api`` section is missing, not a mapping, or
            lacks required fields.
    """
    if not API_CONFIG_FILE.exists():
        raise ConfigNotFoundError(
            f"API configuration not found: {API_CONFIG_FILE}\n"
            f"Create config/api.yaml with required fields: {', '.join(REQUIRED_API_FIELDS)}"
        )

    try:
        with API_CONFIG_FILE.open("r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
    except yaml.YAMLError as e:
        # Chain the parser error so its traceback (line/column) is preserved.
        raise ConfigValidationError(
            f"Failed to parse api.yaml: {e}"
        ) from e

    if config is None:
        raise ConfigValidationError(
            f"API configuration file is empty: {API_CONFIG_FILE}"
        )

    # A top-level list or scalar would otherwise fail below with an
    # AttributeError instead of a readable configuration error.
    if not isinstance(config, dict):
        raise ConfigValidationError(
            f"API configuration must be a mapping at the top level: {API_CONFIG_FILE}"
        )

    api_config = config.get("api")
    if not api_config:
        raise ConfigValidationError(
            f"No 'api' section found in {API_CONFIG_FILE}"
        )

    # Without this check a string section would turn "field not in api_config"
    # into a substring test (e.g. "host" in "localhost" is True).
    if not isinstance(api_config, dict):
        raise ConfigValidationError(
            f"The 'api' section in {API_CONFIG_FILE} must be a mapping"
        )

    missing = [field for field in REQUIRED_API_FIELDS if field not in api_config]
    if missing:
        raise ConfigValidationError(
            f"Missing required fields in {API_CONFIG_FILE}:\n"
            f"  {', '.join(missing)}\n"
            "All fields must be explicitly defined - no default values."
        )

    return config
184  
185  
186  # =============================================================================
187  # Application Lifespan
188  # =============================================================================
189  
async def _shutdown_step(description: str, stop: Callable[[], Awaitable[Any]]) -> bool:
    """
    Await one shutdown callable, logging (not propagating) any failure.

    This keeps a single component that fails to stop from preventing the
    remaining components from being stopped.

    Args:
        description: Action for the error log, e.g. "stopping RetentionProcessor".
        stop: Zero-argument callable returning an awaitable, typically a bound
            ``stop``/``shutdown`` coroutine method.

    Returns:
        True if the step completed without raising, False otherwise.
    """
    try:
        await stop()
    except Exception:
        logger.exception("Error while %s", description)
        return False
    return True


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """
    FastAPI lifespan context manager.

    Startup, in order:
    - Initialize the database and verify it is writable (exits if it is not)
    - Load platform defaults into the FeatureFlagService
    - Start the WebhookProcessor and RetentionProcessor
    - Initialize the SSH service manager and the SubagentManager singleton
    - Initialize and start the task queue system (unless disabled in api.yaml)

    Optional components (feature flags, webhook/retention processors, SSH,
    task queue) only log a warning on failure, and the API starts without them.

    Shutdown runs in a ``finally`` block. It therefore also runs when startup
    fails part-way after background services were started, or when an
    exception is thrown in at ``yield``. Each step is isolated so one failure
    does not skip the others.
    """
    # Startup
    logger.info("Starting Ag3ntum API...")
    # In production, Alembic migrations run in entrypoint-api.sh before this
    # process starts. init_db() is kept as an idempotent fallback for cases
    # where the entrypoint doesn't run (E2E tests, direct uvicorn invocation).
    await init_db()
    logger.info("Database initialized")

    # Verify database is writable (fail fast instead of silent 500 on every write)
    try:
        async with engine.begin() as conn:
            await conn.execute(text(
                "CREATE TABLE IF NOT EXISTS _db_write_check (id INTEGER PRIMARY KEY)"
            ))
            await conn.execute(text("DROP TABLE IF EXISTS _db_write_check"))
        logger.info("Database write check passed")
    except Exception as e:
        logger.critical(
            f"DATABASE IS NOT WRITABLE: {e}\n"
            f"  Database path: {DATABASE_PATH}\n"
            f"  Running as UID: {os.getuid()}\n"
            f"  Fix: ensure ./data/ is owned by UID 45045 "
            f"(run: sudo chown -R 45045:45045 ./data/)"
        )
        raise SystemExit(
            f"Fatal: database at {DATABASE_PATH} is not writable by UID "
            f"{os.getuid()}. Run: sudo chown -R 45045:45045 ./data/"
        ) from e

    # NOTE: Linux user sync happens in entrypoint-api.sh BEFORE this process starts.
    # This ensures the process inherits correct supplementary groups (shared GID model).

    # Declared before the try so the finally block can always inspect them.
    webhook_processor = None
    retention_processor = None
    queue_processor = None
    try:
        # Load platform defaults from DB into FeatureFlagService
        try:
            from ..services.feature_flag_service import feature_flag_service
            from ..db.database import AsyncSessionLocal
            async with AsyncSessionLocal() as db:
                await feature_flag_service.load_platform_defaults(db)
        except Exception as e:
            logger.warning("Could not load platform defaults: %s", e)

        # Start WebhookProcessor for retry deliveries
        try:
            from ..services.webhook_processor import WebhookProcessor
            webhook_processor = WebhookProcessor(interval_seconds=30)
            await webhook_processor.start()
        except Exception as e:
            logger.warning("Could not start WebhookProcessor: %s", e)

        # Start RetentionProcessor for daily data purge
        try:
            from ..services.retention_processor import RetentionProcessor
            retention_processor = RetentionProcessor()
            await retention_processor.start()
        except Exception as e:
            logger.warning("Could not start RetentionProcessor: %s", e)

        # Initialize SSH Service Manager (agent integration)
        try:
            from ..services.ssh_service_manager import ssh_service_manager
            await ssh_service_manager.initialize()
            app.state.ssh_manager = ssh_service_manager
        except Exception as e:
            logger.warning("Could not initialize SSH service manager: %s", e)

        # Initialize SubagentManager singleton
        # This loads config/subagents.yaml and renders all prompt templates ONCE.
        # The same subagent definitions are shared across ALL users and sessions.
        # See src/core/subagent_manager.py for architecture details.
        # Not wrapped in try: a failure here aborts startup (the finally block
        # still stops the processors started above).
        subagent_manager = get_subagent_manager()
        logger.info(
            f"SubagentManager initialized: {subagent_manager.agent_count} subagents "
            f"({subagent_manager.enabled_count} enabled, "
            f"{subagent_manager.disabled_count} disabled)"
        )

        # Initialize task queue system
        try:
            config = load_api_config()
            # `or {}` also covers sections present in YAML but left empty (null).
            queue_config = config.get("task_queue") or {}

            # Check if queue system is enabled
            if (queue_config.get("queue") or {}).get("enabled", True):
                from ..services.task_queue import TaskQueue
                from ..services.quota_manager import QuotaManager
                from ..services.queue_processor import QueueProcessor
                from ..services.auto_resume import AutoResumeService
                from ..services.queue_config import load_queue_config
                from ..services.agent_runner import agent_runner
                from ..db.database import AsyncSessionLocal

                redis_url = (config.get("redis") or {}).get("url", "redis://redis:6379/0")
                qc = load_queue_config(queue_config)

                # Initialize queue components
                task_queue = TaskQueue(redis_url, max_queue_size=qc.queue.max_queue_size)
                quota_manager = QuotaManager(task_queue, qc.quotas)
                processor = QueueProcessor(
                    task_queue,
                    quota_manager,
                    qc.queue.processing_interval_ms,
                    redis_url,
                    qc.queue.task_timeout_minutes,
                )
                auto_resume_service = AutoResumeService(task_queue, qc.auto_resume)

                # Register completion callback with AgentRunner
                agent_runner.register_completion_callback(processor.on_task_complete)

                # Recover interrupted sessions (auto-resume)
                async with AsyncSessionLocal() as db:
                    stats = await auto_resume_service.recover_on_startup(db)
                    logger.info(f"Auto-resume recovery: {stats}")

                # Track the processor for shutdown only once start() is attempted:
                # if recovery above failed, start() never ran and there is
                # nothing to stop.
                queue_processor = processor
                await processor.start()

                # Store in app.state for route access
                app.state.task_queue = task_queue
                app.state.quota_manager = quota_manager
                app.state.queue_processor = processor

                logger.info("Task queue system initialized")
            else:
                logger.info("Task queue system disabled in configuration")

        except Exception as e:
            logger.warning(f"Failed to initialize task queue system: {e}")
            # Continue without queue system - tasks will start immediately

        yield

    finally:
        # Shutdown
        logger.info("Shutting down Ag3ntum API...")

        # Close SSH connection pool
        ssh_mgr = getattr(app.state, "ssh_manager", None)
        if ssh_mgr:
            await _shutdown_step("closing SSH connection pool", ssh_mgr.shutdown)

        if retention_processor:
            if await _shutdown_step("stopping RetentionProcessor", retention_processor.stop):
                logger.info("RetentionProcessor stopped")

        if webhook_processor:
            if await _shutdown_step("stopping WebhookProcessor", webhook_processor.stop):
                logger.info("WebhookProcessor stopped")

        # Close the httpx connection pool used by webhook delivery. Best effort,
        # as before, but the reason is recorded at debug level instead of discarded.
        try:
            from ..services.webhook_service import webhook_service
            await webhook_service.close()
        except Exception:
            logger.debug("Could not close webhook_service HTTP client", exc_info=True)

        if queue_processor:
            if await _shutdown_step("stopping queue processor", queue_processor.stop):
                logger.info("Queue processor stopped")
360  
361  
362  # =============================================================================
363  # Application Factory
364  # =============================================================================
365  
def create_app() -> FastAPI:
    """
    Create and configure the FastAPI application.

    Sets up dual logging, then loads and logs api.yaml. After that it
    registers middleware, routers (under /api/v1; the LLM proxy routers under
    /api), Prometheus metrics and exception handlers.

    Middleware order: Starlette's ``add_middleware`` wraps each new middleware
    AROUND the ones already added. For the middleware added here, requests
    therefore flow WAF -> TrustedProxy (if configured) -> HostValidation (if
    enabled) -> SecurityHeaders (if enabled) -> CORS -> routes. A response
    produced early by a middleware outside CORS does not get CORS headers.

    Returns:
        Configured FastAPI app instance.

    Raises:
        ConfigNotFoundError: If api.yaml doesn't exist.
        ConfigValidationError: If required fields are missing.
    """
    # Configure dual logging (console with colors + file)
    setup_backend_logging()

    # load_api_config() already validates the required ``api`` fields. The
    # ``api`` section (host/port/reload) is consumed by the __main__ runner.
    config = load_api_config()

    # Log all loaded configuration
    log_configuration(config)

    app = FastAPI(
        title="Ag3ntum API",
        description="REST API for Ag3ntum - Self-Improving Agent",
        version="1.0.0",
        lifespan=lifespan,
        docs_url="/api/docs",
        redoc_url="/api/redoc",
        openapi_url="/api/openapi.json",
    )

    # Optional sections; `or {}` also tolerates a section left empty (null) in YAML.
    security_config = config.get("security") or {}
    server_config = config.get("server") or {}

    # Build CORS origins from server.hostname configuration
    cors_origins = build_allowed_origins(config)

    # CORS middleware - origins derived from server.hostname
    app.add_middleware(
        CORSMiddleware,
        allow_origins=cors_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Security headers middleware (X-Content-Type-Options, X-Frame-Options, etc.)
    if security_config.get("enable_security_headers", True):
        app.add_middleware(SecurityHeadersMiddleware, config=config)
        logger.info("Security headers middleware enabled")

    # Host header validation middleware (prevents host header injection)
    if security_config.get("validate_host_header", True):
        allowed_hosts = build_allowed_hosts(config)
        app.add_middleware(HostValidationMiddleware, allowed_hosts=allowed_hosts)
        logger.info("Host header validation enabled")

    # Trusted proxy middleware (for X-Forwarded-* headers). Added after host
    # validation, so it wraps it and runs first on each request.
    trusted_proxies = server_config.get("trusted_proxies", [])
    if trusted_proxies:
        app.add_middleware(TrustedProxyMiddleware, trusted_proxies=trusted_proxies)
        logger.info(f"Trusted proxy middleware enabled: {trusted_proxies}")

    # WAF Filter Middleware - validates request sizes before processing.
    # Implemented as a class middleware (not @app.middleware) to avoid the
    # BaseHTTPMiddleware bug where exceptions re-raised after the exception
    # handler responds, stripping CORS headers from error responses.
    app.add_middleware(WAFMiddleware)

    # Register routes under /api/v1 prefix (LLM proxy routers under /api)
    app.include_router(admin_router, prefix="/api/v1")
    app.include_router(health_router, prefix="/api/v1")
    app.include_router(auth_router, prefix="/api/v1")
    app.include_router(sessions_router, prefix="/api/v1")
    app.include_router(files_router, prefix="/api/v1")
    app.include_router(queue_router, prefix="/api/v1")
    app.include_router(llm_proxy_router, prefix="/api")
    app.include_router(llm_proxy_session_router, prefix="/api")
    app.include_router(skills_router, prefix="/api/v1")
    app.include_router(config_router, prefix="/api/v1")
    app.include_router(reseller_router, prefix="/api/v1")
    app.include_router(ssh_profiles_router, prefix="/api/v1")

    # Prometheus metrics (optional — requires prometheus-fastapi-instrumentator)
    setup_metrics(app)

    # Exception handlers for session-related errors
    @app.exception_handler(InvalidSessionIdError)
    async def invalid_session_id_handler(
        request: Request, exc: InvalidSessionIdError
    ) -> JSONResponse:
        """Convert InvalidSessionIdError to 404 response."""
        # The exception message is returned verbatim as the response detail.
        return JSONResponse(
            status_code=404,
            content={"detail": str(exc)},
        )

    @app.exception_handler(SessionNotFoundError)
    async def session_not_found_handler(
        request: Request, exc: SessionNotFoundError
    ) -> JSONResponse:
        """Convert SessionNotFoundError to 404 response."""
        return JSONResponse(
            status_code=404,
            content={"detail": str(exc)},
        )

    @app.exception_handler(PermissionError)
    async def permission_error_handler(
        request: Request, exc: PermissionError
    ) -> JSONResponse:
        """
        Handle PermissionError explicitly to prevent 500 errors without CORS headers.

        This typically happens when the API cannot access user directories due to
        misconfigured permissions. The error is logged but the user gets a
        generic message without internal path details.
        """
        logger.error(f"PermissionError during request: {exc}")
        return JSONResponse(
            status_code=500,
            content={
                "detail": "Server configuration error: insufficient permissions. "
                "Please contact administrator."
            },
        )

    @app.exception_handler(Exception)
    async def unhandled_exception_handler(
        request: Request, exc: Exception
    ) -> JSONResponse:
        """
        Catch-all handler for unhandled exceptions.

        Returns a generic JSON 500 body instead of a bare plain-text error.

        Caveat: Starlette dispatches ``Exception``/500 handlers from
        ServerErrorMiddleware, which wraps ALL user middleware (including
        CORSMiddleware). This response is therefore NOT given CORS headers,
        and Starlette re-raises the exception after sending it. Errors that
        must stay CORS-visible need a dedicated handler (like the
        PermissionError handler), because those run inside the middleware stack.
        """
        # Log the path only: the full URL's query string may carry credentials
        # (e.g. tokens). exc_info keeps the traceback in the application log.
        logger.error(
            f"Unhandled exception during {request.method} {request.url.path}: {exc}",
            exc_info=exc,
        )
        return JSONResponse(
            status_code=500,
            content={"detail": "Internal server error"},
        )

    return app
511  
512  
# Create the app instance (this is what uvicorn loads as "src.api.main:app")
app = create_app()


if __name__ == "__main__":
    import uvicorn

    config = load_api_config()
    api_config = config["api"]

    if api_config.get("reload", False):
        # Auto-reload requires an import string so the reloader can re-import
        # the app in its worker process. reload_dirs only applies in this mode;
        # uvicorn warns if it is passed while reload is off.
        uvicorn.run(
            "src.api.main:app",
            host=api_config["host"],
            port=api_config["port"],
            reload=True,
            reload_dirs=["src"],
        )
    else:
        # Serve the instance created above. Passing the "src.api.main:app"
        # string here would make uvicorn import this file a second time (it is
        # currently executing as __main__), running create_app() twice.
        uvicorn.run(
            app,
            host=api_config["host"],
            port=api_config["port"],
        )