/ tests / test_projects.py
test_projects.py
  1  import random
  2  import json
  3  import base64
  4  import pytest
  5  from fastapi.testclient import TestClient
  6  from unittest.mock import patch, MagicMock
  7  
  8  from restai.config import RESTAI_DEFAULT_PASSWORD
  9  from restai.main import app
 10  from restai.models.models import ProjectModelCreate, ProjectModelUpdate, FindModel, TextIngestModel, URLIngestModel, ChatModel, QuestionModel
 11  
 12  project_id = None
 13  test_team_id = None
 14  test_llm = None
 15  test_embedding = None
 16  test_project_name = "test_project_" + str(random.randint(0, 1000000))
 17  test_team_name = "test_team_" + str(random.randint(0, 1000000))
 18  
 19  
 20  @pytest.fixture(scope="module")
 21  def client():
 22      with TestClient(app) as c:
 23          yield c
 24  
 25  
 26  def test_create_project(client):
 27      # Discover available LLM and embedding
 28      llms_resp = client.get("/llms", auth=("admin", RESTAI_DEFAULT_PASSWORD))
 29      assert llms_resp.status_code == 200
 30      global test_llm
 31      test_llm = llms_resp.json()[0]["name"]
 32  
 33      embeddings_resp = client.get("/embeddings", auth=("admin", RESTAI_DEFAULT_PASSWORD))
 34      assert embeddings_resp.status_code == 200
 35      global test_embedding
 36      test_embedding = embeddings_resp.json()[0]["name"]
 37  
 38      # Create a team first
 39      team_response = client.post(
 40          "/teams",
 41          json={"name": test_team_name, "llms": [test_llm], "embeddings": [test_embedding]},
 42          auth=("admin", RESTAI_DEFAULT_PASSWORD)
 43      )
 44      assert team_response.status_code == 201
 45      global test_team_id
 46      test_team_id = team_response.json()["id"]
 47  
 48      response = client.post(
 49          "/projects",
 50          json={
 51              "name": test_project_name,
 52              "llm": test_llm,
 53              "embeddings": test_embedding,
 54              "vectorstore": "chroma",
 55              "type": "rag",
 56              "human_name": "Test Project",
 57              "human_description": "A test project",
 58              "censorship": False,
 59              "guard": False,
 60              "public": False,
 61              "options": {},
 62              "team_id": test_team_id
 63          },
 64          auth=("admin", RESTAI_DEFAULT_PASSWORD)
 65      )
 66      assert response.status_code == 201
 67      global project_id
 68      project_id = response.json()["project"]
 69  
 70  def test_get_projects(client):
 71      # Test getting all projects
 72      response = client.get("/projects", auth=("admin", RESTAI_DEFAULT_PASSWORD))
 73      assert response.status_code == 200
 74      assert len(response.json()["projects"]) > 0
 75  
 76      # Test filtering public projects
 77      response = client.get("/projects?filter=public", auth=("admin", RESTAI_DEFAULT_PASSWORD))
 78      assert response.status_code == 200
 79  
 80      # Test pagination
 81      response = client.get("/projects?start=0&end=5", auth=("admin", RESTAI_DEFAULT_PASSWORD))
 82      assert response.status_code == 200
 83      assert len(response.json()["projects"]) <= 5
 84  
 85  def test_get_project(client):
 86      response = client.get(f"/projects/{project_id}", auth=("admin", RESTAI_DEFAULT_PASSWORD))
 87      assert response.status_code == 200
 88      data = response.json()
 89      assert data["id"] == project_id
 90      assert data["name"] == test_project_name
 91      assert data["type"] == "rag"
 92      assert data["llm"] == test_llm
 93      assert data["embeddings"] == test_embedding
 94      assert data["human_name"] == "Test Project"
 95      assert data["human_description"] == "A test project"
 96      assert data["public"] == False
 97  
 98  def test_edit_project(client):
 99      updated_name = "updated_" + test_project_name
100      response = client.patch(
101          f"/projects/{project_id}",
102          json={
103              "name": updated_name,
104              "human_name": "Updated Test Project",
105              "human_description": "An updated test project",
106              "public": True
107          },
108          auth=("admin", RESTAI_DEFAULT_PASSWORD)
109      )
110      assert response.status_code == 200
111      assert response.json()["project"] == project_id
112  
113      # Verify changes
114      response = client.get(f"/projects/{project_id}", auth=("admin", RESTAI_DEFAULT_PASSWORD))
115      assert response.status_code == 200
116      data = response.json()
117      assert data["name"] == updated_name
118      assert data["human_name"] == "Updated Test Project"
119      assert data["human_description"] == "An updated test project"
120      assert data["public"] == True
121  
def test_embeddings_endpoints(client):
    """Walk through the project's embeddings API, expecting HTTP 200 at every step.

    The steps, in order, are: reset, search, ingest a text document, ingest a
    URL, list the embeddings, look up the text document by its source, and
    delete that source.
    """
    admin = ("admin", RESTAI_DEFAULT_PASSWORD)
    base = f"/projects/{project_id}/embeddings"
    # In the URL path, sources are identified by their base64-encoded name.
    doc_source = base64.b64encode(b"test_doc").decode("utf-8")

    reset = client.post(f"{base}/reset", auth=admin)
    assert reset.status_code == 200

    # Only the status of the search is checked, not the hits it returns.
    search = client.post(f"{base}/search", json={"query": "test query", "k": 5}, auth=admin)
    assert search.status_code == 200

    ingest_text = client.post(
        f"{base}/ingest/text",
        json={"text": "This is a test document for embedding.", "source": "test_doc"},
        auth=admin,
    )
    assert ingest_text.status_code == 200

    # NOTE(review): the server presumably fetches this URL itself, so this
    # step likely needs outbound network access. Confirm before relying on it.
    ingest_url = client.post(
        f"{base}/ingest/url",
        json={"url": "http://info.cern.ch/", "source": "example"},
        auth=admin,
    )
    assert ingest_url.status_code == 200

    listing = client.get(base, auth=admin)
    assert listing.status_code == 200

    by_source = client.get(f"{base}/source/{doc_source}", auth=admin)
    assert by_source.status_code == 200

    removed = client.delete(f"{base}/{doc_source}", auth=admin)
    assert removed.status_code == 200
174  
def test_chat_and_question_endpoints(client):
    """Send one prompt each to the project's chat and question endpoints.

    Both endpoints receive the same payload shape, {"question": ...}, and must
    answer with HTTP 200.
    """
    prompts = (
        ("chat", "Hello, how are you?"),
        ("question", "What is this project about?"),
    )
    for endpoint, text in prompts:
        reply = client.post(
            f"/projects/{project_id}/{endpoint}",
            json={"question": text},
            auth=("admin", RESTAI_DEFAULT_PASSWORD),
        )
        assert reply.status_code == 200
191  
def test_logs_endpoints(client):
    """Check that the project's logs and daily token-usage endpoints return 200."""
    auth = ("admin", RESTAI_DEFAULT_PASSWORD)

    # Project logs.
    response = client.get(f"/projects/{project_id}/logs", auth=auth)
    assert response.status_code == 200, response.text

    # Daily token consumption, with no year/month given.
    response = client.get(f"/projects/{project_id}/tokens/daily", auth=auth)
    assert response.status_code == 200, response.text

    # Daily token consumption for an explicit month (December 2023). This
    # still calls /tokens/daily; year/month only select the period.
    response = client.get(
        f"/projects/{project_id}/tokens/daily?year=2023&month=12",
        auth=auth,
    )
    assert response.status_code == 200, response.text
213  
def test_delete_project(client):
    """Delete the test project, confirm it is gone, and then remove the test team.

    The team cleanup runs in ``finally``, so it still happens when one of the
    project assertions fails. Otherwise the team created by test_create_project
    would be left behind.
    """
    auth = ("admin", RESTAI_DEFAULT_PASSWORD)
    try:
        response = client.delete(f"/projects/{project_id}", auth=auth)
        assert response.status_code == 200, response.text
        assert response.json()["project"] == project_id

        # Verify the project is deleted.
        response = client.get(f"/projects/{project_id}", auth=auth)
        assert response.status_code == 404
    finally:
        # Best-effort cleanup: the response is intentionally not checked.
        # The call is skipped when no team was ever created.
        if test_team_id is not None:
            client.delete(f"/teams/{test_team_id}", auth=auth)
223      # Cleanup: delete the team
224      client.delete(f"/teams/{test_team_id}", auth=("admin", RESTAI_DEFAULT_PASSWORD))