# test_projects.py
import random
import json
import base64
import pytest
from fastapi.testclient import TestClient
from unittest.mock import patch, MagicMock

from restai.config import RESTAI_DEFAULT_PASSWORD
from restai.main import app
from restai.models.models import (
    ProjectModelCreate,
    ProjectModelUpdate,
    FindModel,
    TextIngestModel,
    URLIngestModel,
    ChatModel,
    QuestionModel,
)

# Integration tests for the /projects endpoints.
#
# The tests share state through the module-level globals below and are
# order-dependent: test_create_project populates the globals, the following
# tests operate on the created project, and test_delete_project removes the
# project and the team again.

# HTTP basic-auth credentials used by every request in this module.
ADMIN_AUTH = ("admin", RESTAI_DEFAULT_PASSWORD)

# Filled in by test_create_project and read by the later tests.
project_id = None
test_team_id = None
test_llm = None
test_embedding = None

# Randomised names so repeated runs are unlikely to collide with leftovers.
test_project_name = "test_project_" + str(random.randint(0, 1000000))
test_team_name = "test_team_" + str(random.randint(0, 1000000))


@pytest.fixture(scope="module")
def client():
    """Yield one TestClient for the app, shared by all tests in this module."""
    with TestClient(app) as c:
        yield c


def _project_url(suffix=""):
    """Return the URL of the shared test project, optionally with a sub-path.

    Fails with an explicit message when no project has been created yet,
    instead of letting a test request "/projects/None".
    """
    assert project_id is not None, (
        "project_id is not set; test_create_project must run and pass first"
    )
    return f"/projects/{project_id}{suffix}"


def test_create_project(client):
    """Create a team and a RAG project in it, storing their ids in globals."""
    global test_llm, test_embedding, test_team_id, project_id

    # Discover available LLM and embedding
    llms_resp = client.get("/llms", auth=ADMIN_AUTH)
    assert llms_resp.status_code == 200
    llms = llms_resp.json()
    assert llms, "GET /llms returned no LLMs; at least one is required"
    test_llm = llms[0]["name"]

    embeddings_resp = client.get("/embeddings", auth=ADMIN_AUTH)
    assert embeddings_resp.status_code == 200
    embeddings = embeddings_resp.json()
    assert embeddings, (
        "GET /embeddings returned no embeddings; at least one is required"
    )
    test_embedding = embeddings[0]["name"]

    # Create a team first
    team_response = client.post(
        "/teams",
        json={
            "name": test_team_name,
            "llms": [test_llm],
            "embeddings": [test_embedding],
        },
        auth=ADMIN_AUTH,
    )
    assert team_response.status_code == 201
    test_team_id = team_response.json()["id"]

    response = client.post(
        "/projects",
        json={
            "name": test_project_name,
            "llm": test_llm,
            "embeddings": test_embedding,
            "vectorstore": "chroma",
            "type": "rag",
            "human_name": "Test Project",
            "human_description": "A test project",
            "censorship": False,
            "guard": False,
            "public": False,
            "options": {},
            "team_id": test_team_id,
        },
        auth=ADMIN_AUTH,
    )
    assert response.status_code == 201
    project_id = response.json()["project"]


def test_get_projects(client):
    """List projects: unfiltered, filtered by public, and paginated."""
    # Test getting all projects
    response = client.get("/projects", auth=ADMIN_AUTH)
    assert response.status_code == 200
    assert len(response.json()["projects"]) > 0

    # Test filtering public projects
    response = client.get("/projects?filter=public", auth=ADMIN_AUTH)
    assert response.status_code == 200

    # Test pagination
    response = client.get("/projects?start=0&end=5", auth=ADMIN_AUTH)
    assert response.status_code == 200
    assert len(response.json()["projects"]) <= 5


def test_get_project(client):
    """Fetch the created project and check the fields sent at creation."""
    response = client.get(_project_url(), auth=ADMIN_AUTH)
    assert response.status_code == 200
    data = response.json()
    assert data["id"] == project_id
    assert data["name"] == test_project_name
    assert data["type"] == "rag"
    assert data["llm"] == test_llm
    assert data["embeddings"] == test_embedding
    assert data["human_name"] == "Test Project"
    assert data["human_description"] == "A test project"
    assert data["public"] is False


def test_edit_project(client):
    """Patch name, human-readable fields and visibility, then re-read them."""
    updated_name = "updated_" + test_project_name
    response = client.patch(
        _project_url(),
        json={
            "name": updated_name,
            "human_name": "Updated Test Project",
            "human_description": "An updated test project",
            "public": True,
        },
        auth=ADMIN_AUTH,
    )
    assert response.status_code == 200
    assert response.json()["project"] == project_id

    # Verify changes
    response = client.get(_project_url(), auth=ADMIN_AUTH)
    assert response.status_code == 200
    data = response.json()
    assert data["name"] == updated_name
    assert data["human_name"] == "Updated Test Project"
    assert data["human_description"] == "An updated test project"
    assert data["public"] is True


def test_embeddings_endpoints(client):
    """Exercise the embeddings endpoints: reset, search, ingest, list, delete."""
    # The source name is passed base64-encoded in the URL path.
    encoded_source = base64.b64encode(b"test_doc").decode("utf-8")

    # Test reset embeddings
    response = client.post(_project_url("/embeddings/reset"), auth=ADMIN_AUTH)
    assert response.status_code == 200

    # Test search embeddings
    response = client.post(
        _project_url("/embeddings/search"),
        json={"query": "test query", "k": 5},
        auth=ADMIN_AUTH,
    )
    assert response.status_code == 200

    # Test ingest text
    response = client.post(
        _project_url("/embeddings/ingest/text"),
        json={
            "text": "This is a test document for embedding.",
            "source": "test_doc",
        },
        auth=ADMIN_AUTH,
    )
    assert response.status_code == 200

    # Test ingest URL
    # NOTE(review): this presumably needs outbound network access to the
    # given URL - confirm it is reachable where the suite runs.
    response = client.post(
        _project_url("/embeddings/ingest/url"),
        json={"url": "http://info.cern.ch/", "source": "example"},
        auth=ADMIN_AUTH,
    )
    assert response.status_code == 200

    # Test get embeddings
    response = client.get(_project_url("/embeddings"), auth=ADMIN_AUTH)
    assert response.status_code == 200

    # Test get embedding by source
    response = client.get(
        _project_url("/embeddings/source/" + encoded_source),
        auth=ADMIN_AUTH,
    )
    assert response.status_code == 200

    # Test delete embedding
    response = client.delete(
        _project_url("/embeddings/" + encoded_source),
        auth=ADMIN_AUTH,
    )
    assert response.status_code == 200


def test_chat_and_question_endpoints(client):
    """Post one question to the chat endpoint and one to the question endpoint."""
    # Test chat endpoint
    response = client.post(
        _project_url("/chat"),
        json={"question": "Hello, how are you?"},
        auth=ADMIN_AUTH,
    )
    assert response.status_code == 200

    # Test question endpoint
    response = client.post(
        _project_url("/question"),
        json={"question": "What is this project about?"},
        auth=ADMIN_AUTH,
    )
    assert response.status_code == 200


def test_logs_endpoints(client):
    """Read the project's logs and its daily token consumption."""
    # Test get logs
    response = client.get(_project_url("/logs"), auth=ADMIN_AUTH)
    assert response.status_code == 200

    # Test get daily token consumption
    response = client.get(_project_url("/tokens/daily"), auth=ADMIN_AUTH)
    assert response.status_code == 200

    # Test get daily token consumption for a specific year and month
    response = client.get(
        _project_url("/tokens/daily?year=2023&month=12"),
        auth=ADMIN_AUTH,
    )
    assert response.status_code == 200


def test_delete_project(client):
    """Delete the project, confirm it is gone, then delete the team."""
    response = client.delete(_project_url(), auth=ADMIN_AUTH)
    assert response.status_code == 200
    assert response.json()["project"] == project_id

    # Verify project is deleted
    response = client.get(_project_url(), auth=ADMIN_AUTH)
    assert response.status_code == 404

    # Cleanup: delete the team (the response status is not asserted)
    client.delete(f"/teams/{test_team_id}", auth=ADMIN_AUTH)