/ tests / test_mcp.py
test_mcp.py
  1  """Tests for the internal MCP server (restai/mcp.py)."""
  2  
  3  import asyncio
  4  import json
  5  import random
  6  import pytest
  7  from unittest.mock import patch, MagicMock
  8  from fastapi.testclient import TestClient
  9  
 10  from restai.config import RESTAI_DEFAULT_PASSWORD
 11  from restai.main import app
 12  
 13  ADMIN = ("admin", RESTAI_DEFAULT_PASSWORD)
 14  
 15  suffix = str(random.randint(0, 10000000))
 16  user_name = f"mcp_user_{suffix}"
 17  user_pass = "mcp_pass_123"
 18  team_name = f"mcp_team_{suffix}"
 19  llm_name = f"mcp_llm_{suffix}"
 20  project_name = f"mcp-proj-{suffix}"
 21  
 22  team_id = None
 23  project_id = None
 24  api_key = None
 25  admin_api_key = None
 26  
 27  
 28  @pytest.fixture(scope="module")
 29  def client():
 30      with TestClient(app) as c:
 31          yield c
 32  
 33  
 34  def test_mcp_setup(client):
 35      """Create user, team, LLM, project, and API keys for MCP tests."""
 36      global team_id, project_id, api_key, admin_api_key
 37  
 38      # Create user
 39      r = client.post(
 40          "/users",
 41          json={"username": user_name, "password": user_pass, "is_admin": False, "is_private": False},
 42          auth=ADMIN,
 43      )
 44      assert r.status_code == 201, f"Failed to create user: {r.text}"
 45  
 46      # Create LLM
 47      r = client.post(
 48          "/llms",
 49          json={
 50              "name": llm_name,
 51              "class_name": "OpenAI",
 52              "options": {"model": "gpt-test", "api_key": "sk-fake"},
 53              "privacy": "public",
 54          },
 55          auth=ADMIN,
 56      )
 57      assert r.status_code in (200, 201), f"Failed to create LLM: {r.text}"
 58  
 59      # Create team with user and LLM
 60      r = client.post(
 61          "/teams",
 62          json={
 63              "name": team_name,
 64              "users": [user_name],
 65              "admins": [],
 66              "llms": [llm_name],
 67          },
 68          auth=ADMIN,
 69      )
 70      assert r.status_code == 201, f"Failed to create team: {r.text}"
 71      team_id = r.json()["id"]
 72  
 73      # Create project
 74      r = client.post(
 75          "/projects",
 76          json={
 77              "name": project_name,
 78              "llm": llm_name,
 79              "type": "agent",
 80              "team_id": team_id,
 81              "human_description": "Test project for MCP",
 82          },
 83          auth=ADMIN,
 84      )
 85      assert r.status_code == 201, f"Failed to create project: {r.text}"
 86      project_id = r.json()["project"]
 87  
 88      # Assign project to user
 89      r = client.patch(
 90          f"/projects/{project_id}",
 91          json={"users": [user_name]},
 92          auth=ADMIN,
 93      )
 94      assert r.status_code == 200, f"Failed to assign project: {r.text}"
 95  
 96      # Create API key for user
 97      r = client.post(
 98          f"/users/{user_name}/apikeys",
 99          json={"description": "mcp test key"},
100          auth=(user_name, user_pass),
101      )
102      assert r.status_code == 201, f"Failed to create API key: {r.text}"
103      api_key = r.json()["api_key"]
104  
105      # Create API key for admin
106      r = client.post(
107          "/users/admin/apikeys",
108          json={"description": "mcp admin key"},
109          auth=ADMIN,
110      )
111      assert r.status_code == 201, f"Failed to create admin API key: {r.text}"
112      admin_api_key = r.json()["api_key"]
113  
114  
115  # ── Authentication tests ─────────────────────────────────────────────────
116  
117  
def test_mcp_auth_missing_header():
    """MCP auth with no Authorization header should fail.

    _authenticate must raise PermissionError, with a message that mentions
    the Bearer scheme.
    """
    from restai.mcp import _authenticate

    mock_request = MagicMock()
    mock_request.headers = {}

    # pytest.raises fails the test when nothing is raised. The previous
    # `assert False` sentinel is stripped under `python -O`, which made the
    # test pass silently if _authenticate accepted the request.
    with patch("restai.mcp.get_http_request", return_value=mock_request):
        with pytest.raises(PermissionError, match="Bearer"):
            _authenticate()
131  
132  
def test_mcp_auth_basic_rejected():
    """MCP auth with Basic auth should be rejected.

    Only Bearer API keys are accepted. A Basic header must raise
    PermissionError, with a message that mentions Bearer.
    """
    from restai.mcp import _authenticate

    mock_request = MagicMock()
    mock_request.headers = {"authorization": "Basic dXNlcjpwYXNz"}

    # pytest.raises instead of a try/`assert False` sentinel, which would be
    # stripped under `python -O` and let a non-raising call pass.
    with patch("restai.mcp.get_http_request", return_value=mock_request):
        with pytest.raises(PermissionError, match="Bearer"):
            _authenticate()
146  
147  
def test_mcp_auth_invalid_key():
    """MCP auth with an unknown Bearer key should fail.

    The header is well-formed but the key does not exist. _authenticate must
    raise PermissionError, with a message that mentions "Invalid".
    """
    from restai.mcp import _authenticate

    mock_request = MagicMock()
    mock_request.headers = {"authorization": "Bearer invalid-key-12345"}

    # pytest.raises instead of a try/`assert False` sentinel, which would be
    # stripped under `python -O` and let a non-raising call pass.
    with patch("restai.mcp.get_http_request", return_value=mock_request):
        with pytest.raises(PermissionError, match="Invalid"):
            _authenticate()
161  
162  
def test_mcp_auth_valid_user_key():
    """A valid user API key authenticates as that non-admin user.

    _authenticate returns (user, db_wrapper). The wrapper's `.db` is closed
    here so the test does not leave the connection open.
    """
    from restai.mcp import _authenticate

    fake_request = MagicMock(headers={"authorization": f"Bearer {api_key}"})

    with patch("restai.mcp.get_http_request", return_value=fake_request):
        authed_user, db_handle = _authenticate()
        try:
            assert authed_user.username == user_name
            assert not authed_user.is_admin
        finally:
            db_handle.db.close()
177  
178  
def test_mcp_auth_valid_admin_key():
    """A valid admin API key authenticates as the admin user.

    The `.db` of the returned wrapper is closed once the checks finish.
    """
    from restai.mcp import _authenticate

    fake_request = MagicMock(headers={"authorization": f"Bearer {admin_api_key}"})

    with patch("restai.mcp.get_http_request", return_value=fake_request):
        authed_user, db_handle = _authenticate()
        try:
            assert authed_user.username == "admin"
            assert authed_user.is_admin
        finally:
            db_handle.db.close()
193  
194  
195  # ── Access control tests ─────────────────────────────────────────────────
196  
197  
def test_mcp_user_has_project_access():
    """The user authenticated by API key can access the project assigned in setup."""
    from restai.mcp import _authenticate

    fake_request = MagicMock(headers={"authorization": f"Bearer {api_key}"})

    with patch("restai.mcp.get_http_request", return_value=fake_request):
        authed_user, db_handle = _authenticate()
        try:
            assert authed_user.has_project_access(project_id)
        finally:
            db_handle.db.close()
211  
212  
def test_mcp_user_no_access_to_unassigned():
    """A non-admin user is denied access to a project they were not assigned."""
    from restai.mcp import _authenticate

    fake_request = MagicMock(headers={"authorization": f"Bearer {api_key}"})

    with patch("restai.mcp.get_http_request", return_value=fake_request):
        authed_user, db_handle = _authenticate()
        try:
            # NOTE(review): 999999 is assumed not to be a project assigned to
            # this user; on a long-lived database that assumption may not hold.
            assert not authed_user.has_project_access(999999)
        finally:
            db_handle.db.close()
226  
227  
def test_mcp_admin_has_access_to_all():
    """Admin passes the project access check for both known and arbitrary ids."""
    from restai.mcp import _authenticate

    fake_request = MagicMock(headers={"authorization": f"Bearer {admin_api_key}"})

    with patch("restai.mcp.get_http_request", return_value=fake_request):
        authed_user, db_handle = _authenticate()
        try:
            assert authed_user.has_project_access(project_id)
            # Admin is expected to bypass the per-project check entirely.
            assert authed_user.has_project_access(999999)
        finally:
            db_handle.db.close()
242  
243  
244  # ── Server creation tests ────────────────────────────────────────────────
245  
246  
def test_mcp_server_has_tools():
    """The MCP server registers both the list_projects and query_project tools."""
    from restai.mcp import create_mcp_server

    server = create_mcp_server(MagicMock())
    # list_tools is a coroutine, so drive it with a fresh event loop.
    registered = {tool.name for tool in asyncio.run(server.list_tools())}
    for expected in ("list_projects", "query_project"):
        assert expected in registered
256  
257  
def test_mcp_server_name():
    """The server built by create_mcp_server reports the name "RESTai"."""
    from restai.mcp import create_mcp_server

    assert create_mcp_server(MagicMock()).name == "RESTai"
264  
265  
def test_mcp_server_produces_sse_app():
    """Asking the server for an SSE-transport HTTP app yields a non-None object."""
    from restai.mcp import create_mcp_server

    server = create_mcp_server(MagicMock())
    assert server.http_app(transport="sse") is not None
273  
274  
275  # ── Teardown ─────────────────────────────────────────────────────────────
276  
277  
def test_mcp_teardown(client):
    """Best-effort removal of the resources created by test_mcp_setup.

    Response codes are deliberately not checked, so teardown never fails the run.
    """
    # Deletion order matters: the project goes before its team to avoid
    # orphaned records. Ids are skipped if setup never assigned them.
    doomed = []
    if project_id:
        doomed.append(f"/projects/{project_id}")
    if team_id:
        doomed.append(f"/teams/{team_id}")
    doomed += [f"/users/{user_name}", f"/llms/{llm_name}"]

    # NOTE(review): the admin API key created in test_mcp_setup is never
    # revoked here, so each run leaves a live admin key behind. Confirm the
    # key-deletion endpoint and remove it as well.
    for url in doomed:
        client.delete(url, auth=ADMIN)