# test_project_embeddings.py
"""Tests for project embeddings endpoints on non-RAG projects.

The tests in this module share state through module-level globals and rely on
pytest running them in file order: ``test_setup`` creates the resources, the
endpoint tests exercise ``/projects/{id}/embeddings*`` as a non-admin user,
and ``test_cleanup`` deletes everything that was created.
"""

import random

import pytest
from fastapi.testclient import TestClient

from restai.config import RESTAI_DEFAULT_PASSWORD
from restai.main import app

ADMIN = ("admin", RESTAI_DEFAULT_PASSWORD)

# Random suffix so resource names are unlikely to collide with leftovers from
# earlier runs against the same database.
suffix = str(random.randint(0, 999999))
user_name = f"emb_user_{suffix}"
user_pass = "emb_pass_123"
team_name = f"emb_team_{suffix}"
llm_name = f"emb_llm_{suffix}"
project_name = f"emb-proj-{suffix}"

# Populated by test_setup; read by the endpoint tests and by test_cleanup.
team_id = None
project_id = None


@pytest.fixture(scope="module")
def client():
    """Yield one TestClient shared by every test in this module."""
    with TestClient(app) as c:
        yield c


def _user_auth():
    """Return the (username, password) pair of the non-admin test user."""
    return (user_name, user_pass)


def _require_project():
    """Skip the calling test when test_setup did not create the project.

    Without this guard a failed setup would make every later test request
    ``/projects/None/...`` and fail with an unrelated status code.
    """
    if project_id is None:
        pytest.skip("block project was not created by test_setup")


def _assert_rag_only_error(r):
    """Assert that *r* is the 400 response rejecting a non-RAG project.

    The response body is included in failure messages to make failures
    diagnosable; ``detail`` is stringified so a missing/None/non-string value
    produces an assertion failure rather than a TypeError.
    """
    assert r.status_code == 400, f"Expected 400, got {r.status_code}: {r.text}"
    detail = r.json().get("detail", "")
    assert "RAG" in str(detail), f"Unexpected detail: {detail!r}"


def test_setup(client):
    """Create user, team, LLM, and block project for embeddings tests.

    Stores the created ids in the module globals ``team_id`` and
    ``project_id`` for the tests that follow.
    """
    global team_id, project_id

    r = client.post(
        "/users",
        json={"username": user_name, "password": user_pass, "is_admin": False, "is_private": False},
        auth=ADMIN,
    )
    assert r.status_code == 201, r.text

    r = client.post(
        "/llms",
        json={
            "name": llm_name,
            "class_name": "OpenAI",
            # Placeholder credentials; presumably never used, since every test
            # here expects the request to be rejected with a 400.
            "options": {"model": "gpt-test", "api_key": "sk-fake"},
            "privacy": "public",
        },
        auth=ADMIN,
    )
    assert r.status_code in (200, 201), r.text

    r = client.post(
        "/teams",
        json={"name": team_name, "users": [user_name], "llms": [llm_name]},
        auth=ADMIN,
    )
    assert r.status_code == 201, r.text
    team_id = r.json()["id"]

    r = client.post(
        "/projects",
        json={
            "name": project_name,
            "llm": llm_name,
            "type": "block",
            "team_id": team_id,
        },
        auth=ADMIN,
    )
    assert r.status_code == 201, r.text
    project_id = r.json()["project"]

    # Assign user to project so the endpoint tests run as a non-admin member.
    r = client.patch(
        f"/projects/{project_id}",
        json={"users": [user_name]},
        auth=ADMIN,
    )
    assert r.status_code == 200, r.text


def test_list_embeddings_non_rag(client):
    """GET /projects/{id}/embeddings on a block project should return 400."""
    _require_project()
    r = client.get(
        f"/projects/{project_id}/embeddings",
        auth=_user_auth(),
    )
    _assert_rag_only_error(r)


def test_search_non_rag(client):
    """POST /projects/{id}/embeddings/search on a block project should return 400."""
    _require_project()
    r = client.post(
        f"/projects/{project_id}/embeddings/search",
        json={"text": "test query"},
        auth=_user_auth(),
    )
    _assert_rag_only_error(r)


def test_reset_non_rag(client):
    """POST /projects/{id}/embeddings/reset on a block project should return 400."""
    _require_project()
    r = client.post(
        f"/projects/{project_id}/embeddings/reset",
        auth=_user_auth(),
    )
    _assert_rag_only_error(r)


def test_cleanup(client):
    """Clean up resources created for embeddings tests.

    Best-effort: responses are not checked, so cleanup proceeds even if
    test_setup stopped part-way. Ids are compared against None (not
    truthiness) so a falsy id such as 0 is still deleted.
    """
    if project_id is not None:
        client.delete(f"/projects/{project_id}", auth=ADMIN)
    if team_id is not None:
        client.delete(f"/teams/{team_id}", auth=ADMIN)
    client.delete(f"/users/{user_name}", auth=ADMIN)
    client.delete(f"/llms/{llm_name}", auth=ADMIN)