test_mcp.py
"""Tests for the internal MCP server (restai/mcp.py)."""

import asyncio
import json
import random
from contextlib import contextmanager
from unittest.mock import patch, MagicMock

import pytest
from fastapi.testclient import TestClient

from restai.config import RESTAI_DEFAULT_PASSWORD
from restai.main import app

ADMIN = ("admin", RESTAI_DEFAULT_PASSWORD)

# Random suffix so repeated runs against the same database don't collide on names.
suffix = str(random.randint(0, 10000000))
user_name = f"mcp_user_{suffix}"
user_pass = "mcp_pass_123"
team_name = f"mcp_team_{suffix}"
llm_name = f"mcp_llm_{suffix}"
project_name = f"mcp-proj-{suffix}"

# Populated by test_mcp_setup and consumed by the tests below; these tests
# therefore rely on running in file order after setup succeeds.
team_id = None
project_id = None
api_key = None
admin_api_key = None


@pytest.fixture(scope="module")
def client():
    """Yield a TestClient for the RESTai app, shared across this module."""
    with TestClient(app) as c:
        yield c


def test_mcp_setup(client):
    """Create user, team, LLM, project, and API keys for MCP tests."""
    global team_id, project_id, api_key, admin_api_key

    # Create user
    r = client.post(
        "/users",
        json={"username": user_name, "password": user_pass, "is_admin": False, "is_private": False},
        auth=ADMIN,
    )
    assert r.status_code == 201, f"Failed to create user: {r.text}"

    # Create LLM
    r = client.post(
        "/llms",
        json={
            "name": llm_name,
            "class_name": "OpenAI",
            "options": {"model": "gpt-test", "api_key": "sk-fake"},
            "privacy": "public",
        },
        auth=ADMIN,
    )
    assert r.status_code in (200, 201), f"Failed to create LLM: {r.text}"

    # Create team with user and LLM
    r = client.post(
        "/teams",
        json={
            "name": team_name,
            "users": [user_name],
            "admins": [],
            "llms": [llm_name],
        },
        auth=ADMIN,
    )
    assert r.status_code == 201, f"Failed to create team: {r.text}"
    team_id = r.json()["id"]

    # Create project
    r = client.post(
        "/projects",
        json={
            "name": project_name,
            "llm": llm_name,
            "type": "agent",
            "team_id": team_id,
            "human_description": "Test project for MCP",
        },
        auth=ADMIN,
    )
    assert r.status_code == 201, f"Failed to create project: {r.text}"
    project_id = r.json()["project"]

    # Assign project to user
    r = client.patch(
        f"/projects/{project_id}",
        json={"users": [user_name]},
        auth=ADMIN,
    )
    assert r.status_code == 200, f"Failed to assign project: {r.text}"

    # Create API key for user
    r = client.post(
        f"/users/{user_name}/apikeys",
        json={"description": "mcp test key"},
        auth=(user_name, user_pass),
    )
    assert r.status_code == 201, f"Failed to create API key: {r.text}"
    api_key = r.json()["api_key"]

    # Create API key for admin
    r = client.post(
        "/users/admin/apikeys",
        json={"description": "mcp admin key"},
        auth=ADMIN,
    )
    assert r.status_code == 201, f"Failed to create admin API key: {r.text}"
    admin_api_key = r.json()["api_key"]


# ── Helpers ──────────────────────────────────────────────────────────────


@contextmanager
def _request_with_headers(headers):
    """Patch ``restai.mcp.get_http_request`` to return a mock request.

    Args:
        headers: Mapping assigned to the mock request's ``headers`` attribute.
    """
    mock_request = MagicMock()
    mock_request.headers = headers
    with patch("restai.mcp.get_http_request", return_value=mock_request):
        yield mock_request


@contextmanager
def _authenticated_as(key):
    """Authenticate through ``restai.mcp._authenticate`` with a Bearer key.

    Yields the authenticated user while the request patch is still active, and
    always closes the DB session returned by ``_authenticate`` afterwards.

    Args:
        key: API key placed in the ``authorization`` header as ``Bearer <key>``.
    """
    # Imported lazily, as the original tests did, so collecting this module
    # does not import restai.mcp.
    from restai.mcp import _authenticate

    with _request_with_headers({"authorization": f"Bearer {key}"}):
        user, db_wrapper = _authenticate()
        try:
            yield user
        finally:
            db_wrapper.db.close()


# ── Authentication tests ─────────────────────────────────────────────────


def test_mcp_auth_missing_header():
    """MCP auth with no Authorization header should fail."""
    from restai.mcp import _authenticate

    with _request_with_headers({}), pytest.raises(PermissionError, match="Bearer"):
        _authenticate()


def test_mcp_auth_basic_rejected():
    """MCP auth with Basic auth should be rejected."""
    from restai.mcp import _authenticate

    headers = {"authorization": "Basic dXNlcjpwYXNz"}
    with _request_with_headers(headers), pytest.raises(PermissionError, match="Bearer"):
        _authenticate()


def test_mcp_auth_invalid_key():
    """MCP auth with invalid Bearer key should fail."""
    from restai.mcp import _authenticate

    headers = {"authorization": "Bearer invalid-key-12345"}
    with _request_with_headers(headers), pytest.raises(PermissionError, match="Invalid"):
        _authenticate()


def test_mcp_auth_valid_user_key():
    """MCP auth with valid user API key should return the user."""
    with _authenticated_as(api_key) as user:
        assert user.username == user_name
        assert not user.is_admin


def test_mcp_auth_valid_admin_key():
    """MCP auth with valid admin API key should return admin user."""
    with _authenticated_as(admin_api_key) as user:
        assert user.username == "admin"
        assert user.is_admin


# ── Access control tests ─────────────────────────────────────────────────


def test_mcp_user_has_project_access():
    """User should have access to their assigned project."""
    with _authenticated_as(api_key) as user:
        assert user.has_project_access(project_id)


def test_mcp_user_no_access_to_unassigned():
    """User should not have access to unassigned projects."""
    with _authenticated_as(api_key) as user:
        assert not user.has_project_access(999999)


def test_mcp_admin_has_access_to_all():
    """Admin should have access to any project."""
    with _authenticated_as(admin_api_key) as user:
        assert user.has_project_access(project_id)
        assert user.has_project_access(999999)  # Admin bypasses check


# ── Server creation tests ────────────────────────────────────────────────


def test_mcp_server_has_tools():
    """MCP server should have list_projects and query_project tools."""
    from restai.mcp import create_mcp_server

    mcp = create_mcp_server(MagicMock())
    tools = asyncio.run(mcp.list_tools())
    tool_names = {t.name for t in tools}
    assert "list_projects" in tool_names
    assert "query_project" in tool_names


def test_mcp_server_name():
    """MCP server should be named RESTai."""
    from restai.mcp import create_mcp_server

    mcp = create_mcp_server(MagicMock())
    assert mcp.name == "RESTai"


def test_mcp_server_produces_sse_app():
    """MCP server should produce a valid SSE ASGI app."""
    from restai.mcp import create_mcp_server

    mcp = create_mcp_server(MagicMock())
    sse_app = mcp.http_app(transport="sse")
    assert sse_app is not None


# ── Teardown ─────────────────────────────────────────────────────────────


def test_mcp_teardown(client):
    """Clean up resources created for MCP tests (best effort; statuses ignored)."""
    # Delete project before team to avoid orphaned records
    if project_id:
        client.delete(f"/projects/{project_id}", auth=ADMIN)
    if team_id:
        client.delete(f"/teams/{team_id}", auth=ADMIN)
    client.delete(f"/users/{user_name}", auth=ADMIN)
    client.delete(f"/llms/{llm_name}", auth=ADMIN)
    # NOTE(review): the admin API key created in test_mcp_setup is never
    # revoked here, so it outlives the test run. Consider deleting it once the
    # API-key deletion endpoint/shape is confirmed.