main.py
  1  from __future__ import annotations
  2  
  3  import logging
  4  import os
  5  from pathlib import Path
  6  
  7  import httpx
  8  import sqlalchemy as sa
  9  from alembic import command
 10  from alembic.config import Config
 11  from fastapi import FastAPI, HTTPException
 12  from fastapi import Request as FastAPIRequest
 13  from fastapi import status
 14  from fastapi.exceptions import RequestValidationError
 15  from fastapi.middleware.cors import CORSMiddleware
 16  from fastapi.responses import JSONResponse
 17  from starlette.middleware.sessions import SessionMiddleware
 18  from starlette.staticfiles import StaticFiles
 19  from typing import TYPE_CHECKING
 20  
 21  from a2a.types import InternalError, InvalidRequestError, JSONRPCError
 22  from a2a.types import JSONRPCResponse as A2AJSONRPCResponse
 23  
 24  from ...common import a2a
 25  from ...gateway.http_sse import dependencies
 26  from ...shared.auth.middleware import create_oauth_middleware
 27  from .routers import (
 28      agent_cards,
 29      artifacts,
 30      auth,
 31      config,
 32      document_conversion,
 33      feature_flags,
 34      feedback,
 35      people,
 36      sse,
 37      share,
 38      speech,
 39      version,
 40      visualization,
 41      projects,
 42      prompts,
 43  )
 44  from .routers.sessions import router as session_router
 45  from .routers.tasks import router as task_router
 46  from .routers.users import router as user_router
 47  
 48  
 49  if TYPE_CHECKING:
 50      from .component import WebUIBackendComponent
 51  
 52  log = logging.getLogger(__name__)
 53  
 54  # Import scheduled_tasks separately with error handling
 55  try:
 56      from .routers import scheduled_tasks
 57      _scheduled_tasks_available = True
 58  except Exception as e:
 59      log.warning("Scheduled tasks router not available: %s", e)
 60      scheduled_tasks = None
 61      _scheduled_tasks_available = False
 62  
 63  
 64  # OAuth helper functions - delegate to enterprise package if available
 65  async def _validate_token(
 66      auth_service_url: str,
 67      auth_provider: str,
 68      access_token: str,
 69  ) -> bool:
 70      """
 71      Validate an access token against SAM's OAuth2 service.
 72  
 73      This function delegates to the enterprise package's OAuth utilities.
 74  
 75      Args:
 76          auth_service_url: Base URL of the OAuth2 service
 77          auth_provider: Provider name configured in OAuth2 service
 78          access_token: The access token to validate
 79  
 80      Returns:
 81          True if token is valid, False otherwise
 82      """
 83      try:
 84          from solace_agent_mesh_enterprise.gateway.auth.internal.oauth_utils import (
 85              validate_token_with_oauth_service,
 86          )
 87          return await validate_token_with_oauth_service(
 88              auth_service_url, auth_provider, access_token
 89          )
 90      except ImportError:
 91          log.error("Enterprise package not available for OAuth token validation")
 92          return False
 93  
 94  
 95  async def _get_user_info(
 96      auth_service_url: str,
 97      auth_provider: str,
 98      access_token: str,
 99  ) -> dict | None:
100      """
101      Retrieve user information from SAM's OAuth2 service.
102  
103      This function delegates to the enterprise package's OAuth utilities.
104  
105      Args:
106          auth_service_url: Base URL of the OAuth2 service
107          auth_provider: Provider name configured in OAuth2 service
108          access_token: The validated access token
109  
110      Returns:
111          Dictionary containing user claims, or None if request failed
112      """
113      try:
114          from solace_agent_mesh_enterprise.gateway.auth.internal.oauth_utils import (
115              get_user_info_from_oauth_service,
116          )
117          return await get_user_info_from_oauth_service(
118              auth_service_url, auth_provider, access_token
119          )
120      except ImportError:
121          log.error("Enterprise package not available for OAuth user info retrieval")
122          return None
123  
124  
def _extract_user_identifier(user_info: dict, preferred_claim: str | None = None) -> str | None:
    """
    Resolve the primary user identifier from OAuth user info.

    Delegates to the enterprise package's ``extract_user_identifier``. The
    literal "sam_dev_user" is returned both when the enterprise helper yields
    None and when the enterprise package cannot be imported.

    Args:
        user_info: Dictionary of user claims from OAuth provider
        preferred_claim: OAuth claim to prioritize as user ID; only forwarded
            to the enterprise helper when it is not None

    Returns:
        The identifier produced by the enterprise helper, or "sam_dev_user"
    """
    try:
        from solace_agent_mesh_enterprise.gateway.auth.internal.oauth_utils import (
            extract_user_identifier,
        )

        # Omit preferred_claim entirely when unset so the helper is called
        # with a single positional argument (matches test expectations).
        if preferred_claim is None:
            call_args = (user_info,)
        else:
            call_args = (user_info, preferred_claim)
        identifier = extract_user_identifier(*call_args)
    except ImportError:
        log.error("Enterprise package not available for user identifier extraction")
        return "sam_dev_user"

    # The enterprise helper signals an invalid/unknown identifier with None.
    return "sam_dev_user" if identifier is None else identifier
155  
156  
# Module-level FastAPI application. The setup helpers in this module attach
# middleware, routers and the static-file mount to this object, and the
# exception handlers and /health route below are registered on it via decorators.
app = FastAPI(
    title="A2A Web UI Backend",
    version="1.0.0",  # Updated to reflect simplified architecture
    description="Backend API and SSE server for the A2A Web UI, hosted by Solace AI Connector.",
)
162  
163  
164  
165  
def _setup_alembic_config(database_url: str) -> Config:
    """
    Build an in-memory Alembic configuration for the given database.

    Args:
        database_url: Value stored under the ``sqlalchemy.url`` main option

    Returns:
        A Config whose ``script_location`` is the ``alembic`` directory that
        sits next to this module
    """
    module_dir = os.path.dirname(__file__)
    script_location = os.path.join(module_dir, "alembic")

    cfg = Config()
    cfg.set_main_option("script_location", script_location)
    cfg.set_main_option("sqlalchemy.url", database_url)
    return cfg
174  
175  
def _run_community_migrations(database_url: str) -> None:
    """
    Run Alembic migrations for the community database schema.
    This includes sessions, chat_messages tables and their indexes.

    The existing tables are inspected first, only to choose an informative log
    message, and then the schema is upgraded to ``head``. If anything in that
    first pass fails, the upgrade is attempted once more with a fresh Alembic
    config before giving up.

    Args:
        database_url: SQLAlchemy URL of the database to migrate

    Raises:
        RuntimeError: if the second upgrade attempt also fails (chained from
            the underlying error). Errors raised by ``check_sqlite_version``
            propagate unchanged.
    """
    from solace_agent_mesh.shared.database.sqlite_version_check import check_sqlite_version

    # Verify SQLite version before running migrations
    # This will raise RuntimeError if version is incompatible
    check_sqlite_version(database_url, "WebUI Gateway")

    try:
        log.info("[WebUI Gateway] Starting community migrations...")

        # The engine exists only to look at the current table list; dispose of
        # it right away so its connection pool does not outlive this function.
        engine = sa.create_engine(database_url)
        try:
            existing_tables = sa.inspect(engine).get_table_names()
        finally:
            engine.dispose()

        alembic_cfg = _setup_alembic_config(database_url)
        # An empty table list also lacks "sessions", so one membership test
        # covers both the fresh-database and the missing-table case.
        if "sessions" not in existing_tables:
            log.info("[WebUI Gateway] Running initial database setup")
        else:
            log.info("[WebUI Gateway] Checking for schema updates")

        command.upgrade(alembic_cfg, "head")
        log.info("[WebUI Gateway] Community migrations completed")
    except Exception as e:
        # Best effort: the inspection (or the first upgrade) failed, so try
        # the upgrade once more without the inspection step.
        log.warning("[WebUI Gateway] Migration check failed: %s - attempting to run migrations", e)
        try:
            alembic_cfg = _setup_alembic_config(database_url)
            command.upgrade(alembic_cfg, "head")
            log.info("[WebUI Gateway] Community migrations completed")
        except Exception as migration_error:
            log.error("[WebUI Gateway] Migration failed: %s", migration_error)
            log.error("[WebUI Gateway] Check database connectivity and permissions")
            raise RuntimeError(
                f"Community database migration failed: {migration_error}"
            ) from migration_error
215  
216  
217  
218  
def _setup_database(database_url: str) -> None:
    """
    Initialize the database layer, then migrate it.

    Steps, in order: initialize the shared database dependency, run the
    community Alembic migrations, then invoke the post-migration hooks
    registered on ``MiddlewareRegistry``.

    Args:
        database_url: Database URL passed unchanged to every step
    """
    from ...common.middleware.registry import MiddlewareRegistry as _Registry

    dependencies.init_database(database_url)

    log.info("[WebUI Gateway] Running community database migrations...")
    _run_community_migrations(database_url)
    log.info("[WebUI Gateway] Community migrations completed")

    # Post-migration hooks run last (e.g., enterprise migrations).
    _Registry.run_post_migration_hooks(database_url)

    log.info("[WebUI Gateway] Database setup complete")
231  
232  
233  def _get_app_config(component: "WebUIBackendComponent") -> dict:
234      webui_app = component.get_app()
235      app_config = {}
236      if webui_app:
237          app_config = getattr(webui_app, "app_config", {})
238          if app_config is None:
239              log.warning("webui_app.app_config is None, using empty dict.")
240              app_config = {}
241      else:
242          log.warning("Could not get webui_app from component. Using empty app_config.")
243      return app_config
244  
245  
246  def _create_api_config(app_config: dict, database_url: str) -> dict:
247      return {
248          "external_auth_service_url": app_config.get(
249              "external_auth_service_url", "http://localhost:8080"
250          ),
251          "external_auth_callback_uri": app_config.get(
252              "external_auth_callback_uri", "http://localhost:8000/api/v1/auth/callback"
253          ),
254          "external_auth_provider": app_config.get("external_auth_provider", "azure"),
255          "frontend_use_authorization": app_config.get(
256              "frontend_use_authorization", False
257          ),
258          "frontend_redirect_url": app_config.get(
259              "frontend_redirect_url", "http://localhost:3000"
260          ),
261          "persistence_enabled": database_url is not None,
262      }
263  
264  
def setup_dependencies(component: "WebUIBackendComponent"):
    """
    Wire the FastAPI app to a component: dependencies, middleware, routers, static files.

    No migrations are run here; per the module's existing note they are
    handled in the component's ``__init__`` (not visible in this file).

    Args:
        component: WebUIBackendComponent instance
    """
    # Publish the component and the derived API config before anything that
    # reads them (the middleware setup looks the API config up again).
    dependencies.set_component_instance(component)
    dependencies.set_api_config(
        _create_api_config(_get_app_config(component), component.database_url)
    )

    _setup_middleware(component)
    _setup_routers()
    _setup_static_files()
282  
283  
def _setup_middleware(component: "WebUIBackendComponent") -> None:
    """
    Register the app's middleware: CORS, session, OAuth, then observability.

    Args:
        component: Supplies the CORS origins/regex and the session manager,
            and is handed to the OAuth middleware factory
    """
    # --- CORS ---
    origins = component.get_cors_origins()
    origin_regex = component.get_cors_origin_regex()
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        # Any falsy regex (e.g. empty string) is passed on as None.
        allow_origin_regex=origin_regex or None,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    log.info("CORSMiddleware added with origins: %s", origins)
    if origin_regex:
        log.info("CORS origin regex pattern: %s", origin_regex)

    # --- Sessions ---
    secret_key = component.get_session_manager().secret_key
    app.add_middleware(SessionMiddleware, secret_key=secret_key)
    log.info("SessionMiddleware added.")

    # --- OAuth ---
    oauth_middleware_cls = create_oauth_middleware(component)
    app.add_middleware(oauth_middleware_cls, component=component)

    # The API config only decides which log message describes the OAuth mode.
    api_config = dependencies.get_api_config()
    if api_config and api_config.get("frontend_use_authorization", False):
        log.info("OAuth middleware added (real token validation enabled)")
    else:
        log.info("OAuth middleware added (development mode - community/dev user)")

    # --- Observability ---
    from .middleware.observability import GatewayObservabilityMiddleware

    app.add_middleware(GatewayObservabilityMiddleware)
    log.info("Gateway observability middleware added (monitoring: tasks, sessions, sse, artifacts, messages)")
316  
def _setup_routers() -> None:
    """
    Mount every API router on the app and register the shared exception handlers.

    The scheduled-tasks router is mounted only when its import succeeded at
    module load; a failure while mounting it is logged and otherwise ignored.
    """
    api_prefix = "/api/v1"

    # (router, prefix, tags) triples, listed in mount order.
    route_table = (
        (session_router, api_prefix, ["Sessions"]),
        (user_router, f"{api_prefix}/users", ["Users"]),
        (config.router, api_prefix, ["Config"]),
        (version.router, api_prefix, ["Version"]),
        (feature_flags.router, api_prefix, ["Config"]),
        (agent_cards.router, api_prefix, ["Agent Cards"]),
        (task_router, api_prefix, ["Tasks"]),
        (sse.router, f"{api_prefix}/sse", ["SSE"]),
        (artifacts.router, f"{api_prefix}/artifacts", ["Artifacts"]),
        (visualization.router, f"{api_prefix}/visualization", ["Visualization"]),
        (people.router, api_prefix, ["People"]),
        (auth.router, api_prefix, ["Auth"]),
        (projects.router, api_prefix, ["Projects"]),
        (feedback.router, api_prefix, ["Feedback"]),
        (prompts.router, f"{api_prefix}/prompts", ["Prompts"]),
        (speech.router, f"{api_prefix}/speech", ["Speech"]),
        (
            document_conversion.router,
            f"{api_prefix}/document-conversion",
            ["Document Conversion"],
        ),
        (share.router, api_prefix, ["Share"]),
    )
    for router, prefix, tags in route_table:
        app.include_router(router, prefix=prefix, tags=tags)

    # Mount scheduled tasks router if available
    if _scheduled_tasks_available and scheduled_tasks:
        try:
            app.include_router(scheduled_tasks.router, prefix=api_prefix, tags=["Scheduled Tasks"])
            log.info("Scheduled tasks router mounted successfully")
        except Exception as mount_err:
            log.error("Failed to mount scheduled tasks router: %s", mount_err, exc_info=True)

    log.info("Legacy routers mounted for endpoints not yet migrated")

    # Register shared exception handlers
    from solace_agent_mesh.shared.exceptions.exception_handlers import register_exception_handlers

    register_exception_handlers(app)
    log.info("Registered shared exception handlers")
364  
365  
def _setup_static_files() -> None:
    """
    Mount the frontend's static files at "/".

    The directory is ``client/webui/frontend/static`` under the directory two
    levels above this module's directory. A missing directory only produces a
    warning; the mount is still attempted and any failure is logged.
    """
    module_dir = Path(os.path.abspath(__file__)).parent
    static_files_dir = module_dir.parent.parent / "client" / "webui" / "frontend" / "static"

    if not static_files_dir.is_dir():
        log.warning(
            "Static files directory '%s' not found. Frontend may not be served.",
            static_files_dir,
        )

    # try to mount static files directory anyways, might work for enterprise
    try:
        static_app = StaticFiles(directory=static_files_dir, html=True)
        app.mount("/", static_app, name="static")
        log.info("Mounted static files directory '%s' at '/'", static_files_dir)
    except Exception as static_mount_err:
        log.error(
            "Failed to mount static files directory '%s': %s",
            static_files_dir,
            static_mount_err,
        )
388  
389  
@app.exception_handler(HTTPException)
async def http_exception_handler(request: FastAPIRequest, exc: HTTPException):
    """
    HTTP exception handler with automatic format detection.

    Requests whose path starts with ``/api/v1/tasks`` or ``/api/v1/sse`` get a
    JSON-RPC error envelope; every other path gets a plain REST body. In both
    cases the response keeps the exception's status code and forwards any
    headers attached to the exception (e.g. ``WWW-Authenticate`` on a 401).

    Args:
        request: The request being handled (used for logging and path detection)
        exc: The raised HTTPException

    Returns:
        A JSONResponse carrying ``exc.status_code``
    """
    log.warning(
        "HTTP Exception Handler triggered: Status=%s, Detail=%s, Request: %s %s",
        exc.status_code,
        exc.detail,
        request.method,
        request.url,
    )

    # HTTPException may carry response headers; forward them instead of
    # silently dropping them, as FastAPI's default handler does.
    response_headers = getattr(exc, "headers", None)

    # Tasks and SSE endpoints use JSON-RPC; everything else is plain REST.
    is_jsonrpc_endpoint = request.url.path.startswith(
        ("/api/v1/tasks", "/api/v1/sse")
    )

    if is_jsonrpc_endpoint:
        error_data = None
        error_code = InternalError().code
        error_message = str(exc.detail)

        if isinstance(exc.detail, dict):
            if "code" in exc.detail and "message" in exc.detail:
                # The detail is already shaped like a JSON-RPC error object.
                error_code = exc.detail["code"]
                error_message = exc.detail["message"]
                error_data = exc.detail.get("data")
            else:
                # Arbitrary dict: keep it as the error's data payload.
                error_data = exc.detail
        elif isinstance(exc.detail, str):
            if exc.status_code == status.HTTP_400_BAD_REQUEST:
                error_code = -32600  # JSON-RPC "Invalid Request"
            elif exc.status_code == status.HTTP_404_NOT_FOUND:
                error_code = -32601  # JSON-RPC "Method not found"
                error_message = "Resource not found"

        error_obj = JSONRPCError(
            code=error_code, message=error_message, data=error_data
        )
        response = A2AJSONRPCResponse(error=error_obj)
        return JSONResponse(
            status_code=exc.status_code,
            content=response.model_dump(exclude_none=True),
            headers=response_headers,
        )

    # Standard REST format for sessions and other REST endpoints: a dict
    # detail is returned as-is, anything else is wrapped under "detail".
    if isinstance(exc.detail, dict):
        error_response = exc.detail
    elif isinstance(exc.detail, str):
        error_response = {"detail": exc.detail}
    else:
        error_response = {"detail": str(exc.detail)}

    return JSONResponse(
        status_code=exc.status_code,
        content=error_response,
        headers=response_headers,
    )
446  
447  
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
    request: FastAPIRequest, exc: RequestValidationError
):
    """
    Turn a request validation failure into an HTTP 422 response.

    The body is always the A2A invalid-request error response (message
    "Invalid request parameters", the validation errors as data, no request
    id); no per-endpoint format detection happens here.
    """
    log.warning(
        "Validation Exception Handler triggered: %s, Request: %s %s",
        exc.errors(),
        request.method,
        request.url,
    )

    error_response = a2a.create_invalid_request_error_response(
        message="Invalid request parameters",
        data=exc.errors(),
        request_id=None,
    )
    payload = error_response.model_dump(exclude_none=True)
    return JSONResponse(
        content=payload,
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
    )
468  
469  
@app.exception_handler(Exception)
async def generic_exception_handler(request: FastAPIRequest, exc: Exception):
    """
    Catch-all handler: log the exception with traceback and answer HTTP 500.

    The body is always an A2A internal-error response whose message exposes
    only the exception's class name; no per-endpoint format detection
    happens here.
    """
    log.exception(
        "Generic Exception Handler triggered: %s, Request: %s %s",
        exc,
        request.method,
        request.url,
    )

    exc_name = type(exc).__name__
    internal_error = a2a.create_internal_error(
        message="An unexpected server error occurred: %s" % exc_name
    )
    error_response = a2a.create_error_response(error=internal_error, request_id=None)
    payload = error_response.model_dump(exclude_none=True)
    return JSONResponse(
        content=payload,
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
    )
489  
490  
@app.get("/health", tags=["Health"])
async def read_root():
    """Liveness probe: always reports that the backend is running."""
    log.debug("Health check endpoint '/health' called")
    health_payload = {"status": "A2A Web UI Backend is running"}
    return health_payload
495      return {"status": "A2A Web UI Backend is running"}