main.py
1 from __future__ import annotations 2 3 import logging 4 import os 5 from pathlib import Path 6 7 import httpx 8 import sqlalchemy as sa 9 from alembic import command 10 from alembic.config import Config 11 from fastapi import FastAPI, HTTPException 12 from fastapi import Request as FastAPIRequest 13 from fastapi import status 14 from fastapi.exceptions import RequestValidationError 15 from fastapi.middleware.cors import CORSMiddleware 16 from fastapi.responses import JSONResponse 17 from starlette.middleware.sessions import SessionMiddleware 18 from starlette.staticfiles import StaticFiles 19 from typing import TYPE_CHECKING 20 21 from a2a.types import InternalError, InvalidRequestError, JSONRPCError 22 from a2a.types import JSONRPCResponse as A2AJSONRPCResponse 23 24 from ...common import a2a 25 from ...gateway.http_sse import dependencies 26 from ...shared.auth.middleware import create_oauth_middleware 27 from .routers import ( 28 agent_cards, 29 artifacts, 30 auth, 31 config, 32 document_conversion, 33 feature_flags, 34 feedback, 35 people, 36 sse, 37 share, 38 speech, 39 version, 40 visualization, 41 projects, 42 prompts, 43 ) 44 from .routers.sessions import router as session_router 45 from .routers.tasks import router as task_router 46 from .routers.users import router as user_router 47 48 49 if TYPE_CHECKING: 50 from .component import WebUIBackendComponent 51 52 log = logging.getLogger(__name__) 53 54 # Import scheduled_tasks separately with error handling 55 try: 56 from .routers import scheduled_tasks 57 _scheduled_tasks_available = True 58 except Exception as e: 59 log.warning("Scheduled tasks router not available: %s", e) 60 scheduled_tasks = None 61 _scheduled_tasks_available = False 62 63 64 # OAuth helper functions - delegate to enterprise package if available 65 async def _validate_token( 66 auth_service_url: str, 67 auth_provider: str, 68 access_token: str, 69 ) -> bool: 70 """ 71 Validate an access token against SAM's OAuth2 service. 72 73 This function delegates to the enterprise package's OAuth utilities. 74 75 Args: 76 auth_service_url: Base URL of the OAuth2 service 77 auth_provider: Provider name configured in OAuth2 service 78 access_token: The access token to validate 79 80 Returns: 81 True if token is valid, False otherwise 82 """ 83 try: 84 from solace_agent_mesh_enterprise.gateway.auth.internal.oauth_utils import ( 85 validate_token_with_oauth_service, 86 ) 87 return await validate_token_with_oauth_service( 88 auth_service_url, auth_provider, access_token 89 ) 90 except ImportError: 91 log.error("Enterprise package not available for OAuth token validation") 92 return False 93 94 95 async def _get_user_info( 96 auth_service_url: str, 97 auth_provider: str, 98 access_token: str, 99 ) -> dict | None: 100 """ 101 Retrieve user information from SAM's OAuth2 service. 102 103 This function delegates to the enterprise package's OAuth utilities. 104 105 Args: 106 auth_service_url: Base URL of the OAuth2 service 107 auth_provider: Provider name configured in OAuth2 service 108 access_token: The validated access token 109 110 Returns: 111 Dictionary containing user claims, or None if request failed 112 """ 113 try: 114 from solace_agent_mesh_enterprise.gateway.auth.internal.oauth_utils import ( 115 get_user_info_from_oauth_service, 116 ) 117 return await get_user_info_from_oauth_service( 118 auth_service_url, auth_provider, access_token 119 ) 120 except ImportError: 121 log.error("Enterprise package not available for OAuth user info retrieval") 122 return None 123 124 125 def _extract_user_identifier(user_info: dict, preferred_claim: str | None = None) -> str | None: 126 """ 127 Extract the primary user identifier from OAuth user info. 128 129 This function delegates to the enterprise package's OAuth utilities, 130 with a fallback to "sam_dev_user" for development when identifier is invalid. 131 132 Args: 133 user_info: Dictionary of user claims from OAuth provider 134 preferred_claim: OAuth claim to prioritize as user ID 135 136 Returns: 137 The user's primary identifier, or "sam_dev_user" if not found/invalid 138 """ 139 try: 140 from solace_agent_mesh_enterprise.gateway.auth.internal.oauth_utils import ( 141 extract_user_identifier, 142 ) 143 # Only pass preferred_claim if it's not None to match test expectations 144 if preferred_claim is not None: 145 result = extract_user_identifier(user_info, preferred_claim) 146 else: 147 result = extract_user_identifier(user_info) 148 # Fallback to sam_dev_user if enterprise returns None (invalid/unknown identifier) 149 if result is None: 150 return "sam_dev_user" 151 return result 152 except ImportError: 153 log.error("Enterprise package not available for user identifier extraction") 154 return "sam_dev_user" 155 156 157 app = FastAPI( 158 title="A2A Web UI Backend", 159 version="1.0.0", # Updated to reflect simplified architecture 160 description="Backend API and SSE server for the A2A Web UI, hosted by Solace AI Connector.", 161 ) 162 163 164 165 166 def _setup_alembic_config(database_url: str) -> Config: 167 alembic_cfg = Config() 168 alembic_cfg.set_main_option( 169 "script_location", 170 os.path.join(os.path.dirname(__file__), "alembic"), 171 ) 172 alembic_cfg.set_main_option("sqlalchemy.url", database_url) 173 return alembic_cfg 174 175 176 def _run_community_migrations(database_url: str) -> None: 177 """ 178 Run Alembic migrations for the community database schema. 179 This includes sessions, chat_messages tables and their indexes. 180 """ 181 from solace_agent_mesh.shared.database.sqlite_version_check import check_sqlite_version 182 183 # Verify SQLite version before running migrations 184 # This will raise RuntimeError if version is incompatible 185 check_sqlite_version(database_url, "WebUI Gateway") 186 187 try: 188 from sqlalchemy import create_engine 189 190 log.info("[WebUI Gateway] Starting community migrations...") 191 engine = create_engine(database_url) 192 inspector = sa.inspect(engine) 193 existing_tables = inspector.get_table_names() 194 195 alembic_cfg = _setup_alembic_config(database_url) 196 if not existing_tables or "sessions" not in existing_tables: 197 log.info("[WebUI Gateway] Running initial database setup") 198 else: 199 log.info("[WebUI Gateway] Checking for schema updates") 200 201 command.upgrade(alembic_cfg, "head") 202 log.info("[WebUI Gateway] Community migrations completed") 203 except Exception as e: 204 log.warning("[WebUI Gateway] Migration check failed: %s - attempting to run migrations", e) 205 try: 206 alembic_cfg = _setup_alembic_config(database_url) 207 command.upgrade(alembic_cfg, "head") 208 log.info("[WebUI Gateway] Community migrations completed") 209 except Exception as migration_error: 210 log.error("[WebUI Gateway] Migration failed: %s", migration_error) 211 log.error("[WebUI Gateway] Check database connectivity and permissions") 212 raise RuntimeError( 213 f"Community database migration failed: {migration_error}" 214 ) from migration_error 215 216 217 218 219 def _setup_database(database_url: str) -> None: 220 """Initialize database and run migrations.""" 221 from ...common.middleware.registry import MiddlewareRegistry 222 223 dependencies.init_database(database_url) 224 log.info("[WebUI Gateway] Running community database migrations...") 225 _run_community_migrations(database_url) 226 log.info("[WebUI Gateway] Community migrations completed") 227 228 # Run any registered post-migration hooks (e.g., enterprise migrations) 229 MiddlewareRegistry.run_post_migration_hooks(database_url) 230 log.info("[WebUI Gateway] Database setup complete") 231 232 233 def _get_app_config(component: "WebUIBackendComponent") -> dict: 234 webui_app = component.get_app() 235 app_config = {} 236 if webui_app: 237 app_config = getattr(webui_app, "app_config", {}) 238 if app_config is None: 239 log.warning("webui_app.app_config is None, using empty dict.") 240 app_config = {} 241 else: 242 log.warning("Could not get webui_app from component. Using empty app_config.") 243 return app_config 244 245 246 def _create_api_config(app_config: dict, database_url: str) -> dict: 247 return { 248 "external_auth_service_url": app_config.get( 249 "external_auth_service_url", "http://localhost:8080" 250 ), 251 "external_auth_callback_uri": app_config.get( 252 "external_auth_callback_uri", "http://localhost:8000/api/v1/auth/callback" 253 ), 254 "external_auth_provider": app_config.get("external_auth_provider", "azure"), 255 "frontend_use_authorization": app_config.get( 256 "frontend_use_authorization", False 257 ), 258 "frontend_redirect_url": app_config.get( 259 "frontend_redirect_url", "http://localhost:3000" 260 ), 261 "persistence_enabled": database_url is not None, 262 } 263 264 265 def setup_dependencies(component: "WebUIBackendComponent"): 266 """ 267 Initialize FastAPI dependencies (middleware, routers, static files). 268 Database migrations are handled in component.__init__(). 269 270 Args: 271 component: WebUIBackendComponent instance 272 """ 273 dependencies.set_component_instance(component) 274 275 app_config = _get_app_config(component) 276 api_config_dict = _create_api_config(app_config, component.database_url) 277 dependencies.set_api_config(api_config_dict) 278 279 _setup_middleware(component) 280 _setup_routers() 281 _setup_static_files() 282 283 284 def _setup_middleware(component: "WebUIBackendComponent") -> None: 285 allowed_origins = component.get_cors_origins() 286 cors_origin_regex = component.get_cors_origin_regex() 287 app.add_middleware( 288 CORSMiddleware, 289 allow_origins=allowed_origins, 290 allow_origin_regex=cors_origin_regex if cors_origin_regex else None, 291 allow_credentials=True, 292 allow_methods=["*"], 293 allow_headers=["*"], 294 ) 295 log.info("CORSMiddleware added with origins: %s", allowed_origins) 296 if cors_origin_regex: 297 log.info("CORS origin regex pattern: %s", cors_origin_regex) 298 299 session_manager = component.get_session_manager() 300 app.add_middleware(SessionMiddleware, secret_key=session_manager.secret_key) 301 log.info("SessionMiddleware added.") 302 303 auth_middleware_class = create_oauth_middleware(component) 304 app.add_middleware(auth_middleware_class, component=component) 305 306 api_config = dependencies.get_api_config() 307 use_auth = api_config.get("frontend_use_authorization", False) if api_config else False 308 if use_auth: 309 log.info("OAuth middleware added (real token validation enabled)") 310 else: 311 log.info("OAuth middleware added (development mode - community/dev user)") 312 313 from .middleware.observability import GatewayObservabilityMiddleware 314 app.add_middleware(GatewayObservabilityMiddleware) 315 log.info("Gateway observability middleware added (monitoring: tasks, sessions, sse, artifacts, messages)") 316 317 def _setup_routers() -> None: 318 api_prefix = "/api/v1" 319 320 app.include_router(session_router, prefix=api_prefix, tags=["Sessions"]) 321 app.include_router(user_router, prefix=f"{api_prefix}/users", tags=["Users"]) 322 app.include_router(config.router, prefix=api_prefix, tags=["Config"]) 323 app.include_router(version.router, prefix=api_prefix, tags=["Version"]) 324 app.include_router(feature_flags.router, prefix=api_prefix, tags=["Config"]) 325 app.include_router(agent_cards.router, prefix=api_prefix, tags=["Agent Cards"]) 326 app.include_router(task_router, prefix=api_prefix, tags=["Tasks"]) 327 app.include_router(sse.router, prefix=f"{api_prefix}/sse", tags=["SSE"]) 328 app.include_router( 329 artifacts.router, prefix=f"{api_prefix}/artifacts", tags=["Artifacts"] 330 ) 331 app.include_router( 332 visualization.router, 333 prefix=f"{api_prefix}/visualization", 334 tags=["Visualization"], 335 ) 336 app.include_router(people.router, prefix=api_prefix, tags=["People"]) 337 app.include_router(auth.router, prefix=api_prefix, tags=["Auth"]) 338 app.include_router(projects.router, prefix=api_prefix, tags=["Projects"]) 339 app.include_router(feedback.router, prefix=api_prefix, tags=["Feedback"]) 340 app.include_router(prompts.router, prefix=f"{api_prefix}/prompts", tags=["Prompts"]) 341 app.include_router(speech.router, prefix=f"{api_prefix}/speech", tags=["Speech"]) 342 app.include_router( 343 document_conversion.router, 344 prefix=f"{api_prefix}/document-conversion", 345 tags=["Document Conversion"], 346 ) 347 app.include_router(share.router, prefix=api_prefix, tags=["Share"]) 348 349 # Mount scheduled tasks router if available 350 if _scheduled_tasks_available and scheduled_tasks: 351 try: 352 app.include_router(scheduled_tasks.router, prefix=api_prefix, tags=["Scheduled Tasks"]) 353 log.info("Scheduled tasks router mounted successfully") 354 except Exception as e: 355 log.error("Failed to mount scheduled tasks router: %s", e, exc_info=True) 356 357 log.info("Legacy routers mounted for endpoints not yet migrated") 358 359 # Register shared exception handlers 360 from solace_agent_mesh.shared.exceptions.exception_handlers import register_exception_handlers 361 362 register_exception_handlers(app) 363 log.info("Registered shared exception handlers") 364 365 366 def _setup_static_files() -> None: 367 current_dir = os.path.dirname(os.path.abspath(__file__)) 368 root_dir = Path(os.path.normpath(os.path.join(current_dir, "..", ".."))) 369 static_files_dir = Path.joinpath(root_dir, "client", "webui", "frontend", "static") 370 371 if not os.path.isdir(static_files_dir): 372 log.warning( 373 "Static files directory '%s' not found. Frontend may not be served.", 374 static_files_dir, 375 ) 376 # try to mount static files directory anyways, might work for enterprise 377 try: 378 app.mount( 379 "/", StaticFiles(directory=static_files_dir, html=True), name="static" 380 ) 381 log.info("Mounted static files directory '%s' at '/'", static_files_dir) 382 except Exception as static_mount_err: 383 log.error( 384 "Failed to mount static files directory '%s': %s", 385 static_files_dir, 386 static_mount_err, 387 ) 388 389 390 @app.exception_handler(HTTPException) 391 async def http_exception_handler(request: FastAPIRequest, exc: HTTPException): 392 """ 393 HTTP exception handler with automatic format detection. 394 Returns JSON-RPC format for tasks/SSE endpoints, REST format for others. 395 """ 396 log.warning( 397 "HTTP Exception Handler triggered: Status=%s, Detail=%s, Request: %s %s", 398 exc.status_code, 399 exc.detail, 400 request.method, 401 request.url, 402 ) 403 404 # Check if this is a JSON-RPC endpoint (tasks and SSE endpoints use JSON-RPC) 405 is_jsonrpc_endpoint = request.url.path.startswith( 406 "/api/v1/tasks" 407 ) or request.url.path.startswith("/api/v1/sse") 408 409 if is_jsonrpc_endpoint: 410 # Use JSON-RPC format for tasks and SSE endpoints 411 error_data = None 412 error_code = InternalError().code 413 error_message = str(exc.detail) 414 415 if isinstance(exc.detail, dict): 416 if "code" in exc.detail and "message" in exc.detail: 417 error_code = exc.detail["code"] 418 error_message = exc.detail["message"] 419 error_data = exc.detail.get("data") 420 else: 421 error_data = exc.detail 422 elif isinstance(exc.detail, str): 423 if exc.status_code == status.HTTP_400_BAD_REQUEST: 424 error_code = -32600 425 elif exc.status_code == status.HTTP_404_NOT_FOUND: 426 error_code = -32601 427 error_message = "Resource not found" 428 429 error_obj = JSONRPCError( 430 code=error_code, message=error_message, data=error_data 431 ) 432 response = A2AJSONRPCResponse(error=error_obj) 433 return JSONResponse( 434 status_code=exc.status_code, content=response.model_dump(exclude_none=True) 435 ) 436 else: 437 # Use standard REST format for sessions and other REST endpoints 438 if isinstance(exc.detail, dict): 439 error_response = exc.detail 440 elif isinstance(exc.detail, str): 441 error_response = {"detail": exc.detail} 442 else: 443 error_response = {"detail": str(exc.detail)} 444 445 return JSONResponse(status_code=exc.status_code, content=error_response) 446 447 448 @app.exception_handler(RequestValidationError) 449 async def validation_exception_handler( 450 request: FastAPIRequest, exc: RequestValidationError 451 ): 452 """ 453 Handles Pydantic validation errors with format detection. 454 """ 455 log.warning( 456 "Validation Exception Handler triggered: %s, Request: %s %s", 457 exc.errors(), 458 request.method, 459 request.url, 460 ) 461 response = a2a.create_invalid_request_error_response( 462 message="Invalid request parameters", data=exc.errors(), request_id=None 463 ) 464 return JSONResponse( 465 status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, 466 content=response.model_dump(exclude_none=True), 467 ) 468 469 470 @app.exception_handler(Exception) 471 async def generic_exception_handler(request: FastAPIRequest, exc: Exception): 472 """ 473 Handles any other unexpected exceptions with format detection. 474 """ 475 log.exception( 476 "Generic Exception Handler triggered: %s, Request: %s %s", 477 exc, 478 request.method, 479 request.url, 480 ) 481 error_obj = a2a.create_internal_error( 482 message="An unexpected server error occurred: %s" % type(exc).__name__ 483 ) 484 response = a2a.create_error_response(error=error_obj, request_id=None) 485 return JSONResponse( 486 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 487 content=response.model_dump(exclude_none=True), 488 ) 489 490 491 @app.get("/health", tags=["Health"]) 492 async def read_root(): 493 """Basic health check endpoint.""" 494 log.debug("Health check endpoint '/health' called") 495 return {"status": "A2A Web UI Backend is running"}