/ tests / test_teams.py
test_teams.py
  1  import random
  2  import pytest
  3  from fastapi.testclient import TestClient
  4  
  5  from restai.config import RESTAI_DEFAULT_PASSWORD
  6  from restai.main import app
  7  
  8  
  9  @pytest.fixture(scope="module")
 10  def client():
 11      with TestClient(app) as c:
 12          yield c
 13  
 14  
 15  team_id = None
 16  team_name = "test_team_" + str(random.randint(0, 1000000))
 17  test_user1 = "test_team_user1_" + str(random.randint(0, 1000000))
 18  test_user2 = "test_team_user2_" + str(random.randint(0, 1000000))
 19  test_llm_name = "test_team_llm_" + str(random.randint(0, 1000000))
 20  test_llm_id = None
 21  test_embedding_name = "test_team_emb_" + str(random.randint(0, 1000000))
 22  test_embedding_id = None
 23  
 24  
 25  def test_setup_dependencies(client):
 26      global test_llm_id, test_embedding_id
 27      # Create two test users
 28      for username in (test_user1, test_user2):
 29          response = client.post(
 30              "/users",
 31              json={
 32                  "username": username,
 33                  "password": "testpass",
 34                  "admin": False,
 35                  "private": False,
 36              },
 37              auth=("admin", RESTAI_DEFAULT_PASSWORD),
 38          )
 39          assert response.status_code == 201
 40  
 41      # Create test LLM
 42      response = client.post(
 43          "/llms",
 44          json={
 45              "name": test_llm_name,
 46              "class_name": "OpenAI",
 47              "options": {"model": "gpt-test", "api_key": "sk-fake"},
 48              "privacy": "public",
 49          },
 50          auth=("admin", RESTAI_DEFAULT_PASSWORD),
 51      )
 52      assert response.status_code == 201
 53      test_llm_id = response.json()["id"]
 54  
 55      # Create test embedding
 56      response = client.post(
 57          "/embeddings",
 58          json={
 59              "name": test_embedding_name,
 60              "class_name": "Ollama",
 61              "options": "{}",
 62              "privacy": "public",
 63              "dimension": 768,
 64          },
 65          auth=("admin", RESTAI_DEFAULT_PASSWORD),
 66      )
 67      assert response.status_code == 201
 68      test_embedding_id = response.json()["id"]
 69  
 70  
 71  def test_create_team(client):
 72      global team_id
 73      response = client.post(
 74          "/teams",
 75          json={"name": team_name},
 76          auth=("admin", RESTAI_DEFAULT_PASSWORD),
 77      )
 78      assert response.status_code == 201
 79      data = response.json()
 80      assert data["name"] == team_name
 81      assert "id" in data
 82      team_id = data["id"]
 83      assert data["users"] == []
 84      assert data["admins"] == []
 85  
 86  
 87  def test_create_team_non_admin(client):
 88      response = client.post(
 89          "/teams",
 90          json={"name": "should_fail_team"},
 91          auth=(test_user1, "testpass"),
 92      )
 93      assert response.status_code == 403
 94  
 95  
 96  def test_get_teams(client):
 97      response = client.get("/teams", auth=("admin", RESTAI_DEFAULT_PASSWORD))
 98      assert response.status_code == 200
 99      data = response.json()
100      assert "teams" in data
101      team_names = [t["name"] for t in data["teams"]]
102      assert team_name in team_names
103  
104  
def test_get_team(client):
    """An admin can fetch the team by id and gets back the expected id and name."""
    fetched = client.get(f"/teams/{team_id}", auth=("admin", RESTAI_DEFAULT_PASSWORD))
    assert fetched.status_code == 200
    team = fetched.json()
    assert team["id"] == team_id
    assert team["name"] == team_name
114  
115  
def test_update_team(client):
    """PATCH the team's description, then read the team back to confirm it persisted."""
    admin_auth = ("admin", RESTAI_DEFAULT_PASSWORD)
    team_url = f"/teams/{team_id}"
    new_description = "Updated team description"

    patched = client.patch(
        team_url, json={"description": new_description}, auth=admin_auth
    )
    assert patched.status_code == 200

    # Read back to make sure the change was stored, not just acknowledged.
    fetched = client.get(team_url, auth=admin_auth)
    assert fetched.status_code == 200
    assert fetched.json()["description"] == new_description
131  
132  
def test_add_user_to_team(client):
    """Adding test_user1 as a team member returns its username under "added"."""
    endpoint = f"/teams/{team_id}/users/{test_user1}"
    result = client.post(endpoint, auth=("admin", RESTAI_DEFAULT_PASSWORD))
    assert result.status_code == 200
    assert result.json()["added"] == test_user1
140  
141  
def test_add_admin_to_team(client):
    """Adding test_user2 as a team admin returns its username under "added_admin"."""
    endpoint = f"/teams/{team_id}/admins/{test_user2}"
    result = client.post(endpoint, auth=("admin", RESTAI_DEFAULT_PASSWORD))
    assert result.status_code == 200
    assert result.json()["added_admin"] == test_user2
149  
150  
def test_verify_members(client):
    """The team detail lists test_user1 among users and test_user2 among admins."""
    detail = client.get(f"/teams/{team_id}", auth=("admin", RESTAI_DEFAULT_PASSWORD))
    assert detail.status_code == 200
    team = detail.json()
    members = [entry["username"] for entry in team["users"]]
    admins = [entry["username"] for entry in team["admins"]]
    assert test_user1 in members
    assert test_user2 in admins
162  
163  
def test_get_team_as_member(client):
    """A regular team member (test_user1) is allowed to read the team."""
    member_auth = (test_user1, "testpass")
    result = client.get(f"/teams/{team_id}", auth=member_auth)
    assert result.status_code == 200
    assert result.json()["id"] == team_id
171  
172  
def test_get_team_as_non_member(client):
    """A user who is not in the team gets 403 when fetching it.

    A throwaway user is created for the check. It is deleted in a ``finally``
    block so it is removed even when the access assertion fails.
    """
    admin_auth = ("admin", RESTAI_DEFAULT_PASSWORD)
    outsider = "test_outsider_" + str(random.randint(0, 1000000))
    created = client.post(
        "/users",
        json={
            "username": outsider,
            "password": "testpass",
            "admin": False,
            "private": False,
        },
        auth=admin_auth,
    )
    # Report a setup failure directly instead of letting it surface as a
    # confusing status mismatch (or a false pass) in the access check below.
    assert created.status_code == 201

    try:
        response = client.get(
            f"/teams/{team_id}",
            auth=(outsider, "testpass"),
        )
        assert response.status_code == 403
    finally:
        # Always remove the throwaway user, pass or fail.
        client.delete(
            f"/users/{outsider}",
            auth=admin_auth,
        )
198  
199  
def test_add_llm_to_team(client):
    """Attaching the test LLM to the team echoes the LLM's name under "added_llm"."""
    endpoint = f"/teams/{team_id}/llms/{test_llm_id}"
    result = client.post(endpoint, auth=("admin", RESTAI_DEFAULT_PASSWORD))
    assert result.status_code == 200
    assert result.json()["added_llm"] == test_llm_name
207  
208  
def test_add_embedding_to_team(client):
    """Attaching the test embedding echoes its name under "added_embedding"."""
    endpoint = f"/teams/{team_id}/embeddings/{test_embedding_id}"
    result = client.post(endpoint, auth=("admin", RESTAI_DEFAULT_PASSWORD))
    assert result.status_code == 200
    assert result.json()["added_embedding"] == test_embedding_name
216  
217  
def test_remove_embedding_from_team(client):
    """Detaching the test embedding echoes its name under "removed_embedding"."""
    endpoint = f"/teams/{team_id}/embeddings/{test_embedding_id}"
    result = client.delete(endpoint, auth=("admin", RESTAI_DEFAULT_PASSWORD))
    assert result.status_code == 200
    assert result.json()["removed_embedding"] == test_embedding_name
225  
226  
def test_remove_llm_from_team(client):
    """Detaching the test LLM echoes its name under "removed_llm"."""
    endpoint = f"/teams/{team_id}/llms/{test_llm_id}"
    result = client.delete(endpoint, auth=("admin", RESTAI_DEFAULT_PASSWORD))
    assert result.status_code == 200
    assert result.json()["removed_llm"] == test_llm_name
234  
235  
def test_remove_user_from_team(client):
    """Removing member test_user1 echoes its username under "removed"."""
    endpoint = f"/teams/{team_id}/users/{test_user1}"
    result = client.delete(endpoint, auth=("admin", RESTAI_DEFAULT_PASSWORD))
    assert result.status_code == 200
    assert result.json()["removed"] == test_user1
243  
244  
def test_remove_admin_from_team(client):
    """Removing admin test_user2 echoes its username under "removed_admin"."""
    endpoint = f"/teams/{team_id}/admins/{test_user2}"
    result = client.delete(endpoint, auth=("admin", RESTAI_DEFAULT_PASSWORD))
    assert result.status_code == 200
    assert result.json()["removed_admin"] == test_user2
252  
253  
def test_delete_team(client):
    """Deleting the team succeeds, and a follow-up fetch returns 404."""
    admin_auth = ("admin", RESTAI_DEFAULT_PASSWORD)
    team_url = f"/teams/{team_id}"

    deleted = client.delete(team_url, auth=admin_auth)
    assert deleted.status_code == 200

    # The team must no longer be retrievable.
    lookup = client.get(team_url, auth=admin_auth)
    assert lookup.status_code == 404
267  
268  
def test_cleanup_dependencies(client):
    """Delete the users, LLM and embedding created by ``test_setup_dependencies``.

    Every deletion is attempted before anything is asserted. A single failing
    delete therefore no longer leaves the remaining resources behind. Any
    non-200 responses are reported together in one assertion.
    """
    admin_auth = ("admin", RESTAI_DEFAULT_PASSWORD)
    resource_paths = [f"/users/{username}" for username in (test_user1, test_user2)]
    resource_paths.append(f"/llms/{test_llm_id}")
    resource_paths.append(f"/embeddings/{test_embedding_id}")

    failures = {}
    for path in resource_paths:
        response = client.delete(path, auth=admin_auth)
        if response.status_code != 200:
            failures[path] = response.status_code

    assert not failures, f"cleanup deletions failed: {failures}"