mcp.py
"""Internal MCP server exposing RESTai projects as tools."""

import json
import logging
from typing import Optional

from fastmcp import FastMCP
from fastmcp.server.dependencies import get_http_request

from restai.database import get_db_wrapper
from restai.models.models import QuestionModel, User
from restai.models.databasemodels import ProjectDatabase

logger = logging.getLogger(__name__)


def _authenticate():
    """Authenticate the current MCP request via Bearer API key.

    Uses FastMCP's get_http_request() to access the underlying HTTP request,
    then validates the API key against the database.

    Returns:
        Tuple of (User, DBWrapper) on success. The caller owns the returned
        DBWrapper and is responsible for closing ``db_wrapper.db``.

    Raises:
        PermissionError: if the Authorization header is missing or malformed,
            the API key is unknown, or the key's ``allowed_projects``
            restriction cannot be parsed.
    """
    request = get_http_request()
    auth_header = request.headers.get("authorization", "")

    # The auth-scheme token is case-insensitive (RFC 7235 section 2.1), so
    # "bearer <key>" is accepted as well as "Bearer <key>".
    scheme, _, token = auth_header.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        raise PermissionError("Missing or invalid Authorization header. Use: Bearer <api_key>")

    db_wrapper = get_db_wrapper()
    try:
        user_db, api_key_row = db_wrapper.get_user_by_apikey(token)
        if user_db is None:
            raise PermissionError("Invalid API key")

        user = User.model_validate(user_db)

        if api_key_row is not None:
            if api_key_row.allowed_projects:
                try:
                    user.api_key_allowed_projects = json.loads(api_key_row.allowed_projects)
                except (json.JSONDecodeError, TypeError) as err:
                    # Fail closed: silently ignoring a broken restriction would
                    # leave a key that was meant to be scoped unrestricted.
                    logger.warning("API key has malformed allowed_projects; rejecting request")
                    raise PermissionError("API key has an invalid project restriction") from err
            user.api_key_read_only = api_key_row.read_only or False
    except BaseException:
        # Any failure before ownership passes to the caller must not leak
        # the DB session opened above.
        db_wrapper.db.close()
        raise

    return user, db_wrapper


def _format_answer(result) -> str:
    """Render a ``question_main`` result as the text returned by the MCP tool.

    Dict results yield their ``"answer"`` value (stringified). If there is no
    non-None answer, the whole dict is dumped as JSON for diagnosis. Any other
    result is converted with ``str()``.
    """
    if isinstance(result, dict):
        answer = result.get("answer")
        if answer is not None:
            return str(answer)
        # default=str so a non-JSON-serialisable payload cannot turn this
        # diagnostic fallback into an error of its own.
        return json.dumps(result, default=str)
    return str(result)


def create_mcp_server(app_ref) -> FastMCP:
    """Create the MCP server with project tools.

    Registers two tools: ``list_projects`` and ``query_project``. Each tool
    call authenticates independently via ``_authenticate`` and closes its DB
    session before returning.

    Args:
        app_ref: FastAPI app instance (for accessing brain via app.state.brain).

    Returns:
        The configured FastMCP server.
    """
    mcp = FastMCP(
        name="RESTai",
        instructions=(
            "RESTai MCP Server. Use list_projects to discover available AI projects, "
            "then use query_project to interact with them."
        ),
    )

    # NOTE(review): the tools below are async but the DB calls are not
    # awaited, so they run synchronously on the event loop.

    @mcp.tool()
    async def list_projects() -> str:
        """List all AI projects you have access to.

        Returns a JSON list of projects with name, type, and description.
        Use the project name with query_project to send questions.
        """
        user, db_wrapper = _authenticate()
        try:
            query = db_wrapper.db.query(ProjectDatabase)
            if not user.is_admin:
                query = query.filter(ProjectDatabase.id.in_(user.get_project_ids()))
            result = []
            for p in query.all():
                # Apply the same check query_project enforces, so that every
                # listed project is one the caller can actually query.
                if not user.has_project_access(p.id):
                    continue
                entry = {"name": p.name, "type": p.type}
                if p.human_name:
                    entry["human_name"] = p.human_name
                if p.human_description:
                    entry["description"] = p.human_description
                result.append(entry)
            return json.dumps(result, indent=2)
        finally:
            db_wrapper.db.close()

    @mcp.tool()
    async def query_project(
        project_name: str,
        question: str = "",
        image: Optional[str] = None,
    ) -> str:
        """Send a question to an AI project and get the response.

        Args:
            project_name: Name of the project (from list_projects).
            question: The question or prompt to send.
            image: Optional base64-encoded image for vision-capable projects.
        """
        # Kept as function-level imports as in the original; presumably this
        # avoids import cycles with the app modules (NOTE(review): confirm).
        from fastapi import BackgroundTasks
        from restai.helper import question_main
        from restai.brain import Brain

        user, db_wrapper = _authenticate()
        try:
            project_db = db_wrapper.get_project_by_name(project_name)
            if project_db is None:
                return f"Error: Project '{project_name}' not found."

            if not user.has_project_access(project_db.id):
                return f"Error: Access denied to project '{project_name}'."

            brain: Brain = app_ref.state.brain
            project = brain.find_project(project_db.id, db_wrapper)
            if project is None:
                return f"Error: Could not load project '{project_name}'."

            q_input = QuestionModel(question=question, image=image, stream=False)
            background_tasks = BackgroundTasks()
            http_request = get_http_request()

            result = await question_main(
                http_request,
                brain,
                project,
                q_input,
                user,
                db_wrapper,
                background_tasks,
            )
            answer = _format_answer(result)

            # FastAPI only runs BackgroundTasks attached to a Response it
            # sends. No Response is produced here, so run the queued tasks
            # explicitly, before the finally block closes the DB session
            # they may rely on. A failing task must not discard the answer.
            try:
                await background_tasks()
            except Exception:
                logger.exception(
                    "Background task failed after querying project '%s'", project_name
                )

            return answer
        except PermissionError:
            raise
        except Exception as e:
            logger.exception("Error querying project '%s': %s", project_name, e)
            return f"Error: {e}"
        finally:
            db_wrapper.db.close()

    return mcp