/ tests / test_project_embeddings.py
test_project_embeddings.py
  1  """Tests for project embeddings endpoints on non-RAG projects."""
  2  
  3  import random
  4  import pytest
  5  from fastapi.testclient import TestClient
  6  
  7  from restai.config import RESTAI_DEFAULT_PASSWORD
  8  from restai.main import app
  9  
 10  ADMIN = ("admin", RESTAI_DEFAULT_PASSWORD)
 11  
 12  suffix = str(random.randint(0, 999999))
 13  user_name = f"emb_user_{suffix}"
 14  user_pass = "emb_pass_123"
 15  team_name = f"emb_team_{suffix}"
 16  llm_name = f"emb_llm_{suffix}"
 17  project_name = f"emb-proj-{suffix}"
 18  
 19  team_id = None
 20  project_id = None
 21  
 22  
 23  @pytest.fixture(scope="module")
 24  def client():
 25      with TestClient(app) as c:
 26          yield c
 27  
 28  
 29  def test_setup(client):
 30      """Create user, team, LLM, and block project for embeddings tests."""
 31      global team_id, project_id
 32  
 33      r = client.post(
 34          "/users",
 35          json={"username": user_name, "password": user_pass, "is_admin": False, "is_private": False},
 36          auth=ADMIN,
 37      )
 38      assert r.status_code == 201
 39  
 40      r = client.post(
 41          "/llms",
 42          json={
 43              "name": llm_name,
 44              "class_name": "OpenAI",
 45              "options": {"model": "gpt-test", "api_key": "sk-fake"},
 46              "privacy": "public",
 47          },
 48          auth=ADMIN,
 49      )
 50      assert r.status_code in (200, 201)
 51  
 52      r = client.post(
 53          "/teams",
 54          json={"name": team_name, "users": [user_name], "llms": [llm_name]},
 55          auth=ADMIN,
 56      )
 57      assert r.status_code == 201
 58      team_id = r.json()["id"]
 59  
 60      r = client.post(
 61          "/projects",
 62          json={
 63              "name": project_name,
 64              "llm": llm_name,
 65              "type": "block",
 66              "team_id": team_id,
 67          },
 68          auth=ADMIN,
 69      )
 70      assert r.status_code == 201
 71      project_id = r.json()["project"]
 72  
 73      # Assign user to project
 74      r = client.patch(
 75          f"/projects/{project_id}",
 76          json={"users": [user_name]},
 77          auth=ADMIN,
 78      )
 79      assert r.status_code == 200
 80  
 81  
 82  def test_list_embeddings_non_rag(client):
 83      """GET /projects/{id}/embeddings on a block project should return 400."""
 84      r = client.get(
 85          f"/projects/{project_id}/embeddings",
 86          auth=(user_name, user_pass),
 87      )
 88      assert r.status_code == 400, f"Expected 400, got {r.status_code}"
 89      assert "RAG" in r.json().get("detail", "")
 90  
 91  
 92  def test_search_non_rag(client):
 93      """POST /projects/{id}/embeddings/search on a block project should return 400."""
 94      r = client.post(
 95          f"/projects/{project_id}/embeddings/search",
 96          json={"text": "test query"},
 97          auth=(user_name, user_pass),
 98      )
 99      assert r.status_code == 400, f"Expected 400, got {r.status_code}"
100      assert "RAG" in r.json().get("detail", "")
101  
102  
def test_reset_non_rag(client):
    """Resetting embeddings of a block project must be rejected with HTTP 400."""
    credentials = (user_name, user_pass)
    url = f"/projects/{project_id}/embeddings/reset"

    resp = client.post(url, auth=credentials)

    assert resp.status_code == 400, f"Expected 400, got {resp.status_code}"
    # The error detail is expected to mention RAG.
    detail = resp.json().get("detail", "")
    assert "RAG" in detail
111  
112  
def test_cleanup(client):
    """Delete the project, team, user and LLM created for the embeddings tests.

    Response statuses are not checked. The project and team are only deleted
    when their identifiers were recorded by test_setup.
    """
    paths = []
    if project_id:
        paths.append(f"/projects/{project_id}")
    if team_id:
        paths.append(f"/teams/{team_id}")
    # User and LLM are addressed by name, so they are always attempted.
    paths.append(f"/users/{user_name}")
    paths.append(f"/llms/{llm_name}")

    for path in paths:
        client.delete(path, auth=ADMIN)
120      client.delete(f"/llms/{llm_name}", auth=ADMIN)