/ tests / test_evals.py
test_evals.py
  1  import random
  2  import pytest
  3  from fastapi.testclient import TestClient
  4  
  5  from restai.config import RESTAI_DEFAULT_PASSWORD
  6  from restai.main import app
  7  
  8  suffix = str(random.randint(0, 10000000))
  9  team_name = f"evals_team_{suffix}"
 10  llm_name = f"evals_llm_{suffix}"
 11  project_name = f"evals_proj_{suffix}"
 12  
 13  team_id = None
 14  project_id = None
 15  dataset_id = None
 16  case_id = None
 17  
 18  ADMIN = ("admin", RESTAI_DEFAULT_PASSWORD)
 19  
 20  
 21  @pytest.fixture(scope="module")
 22  def client():
 23      with TestClient(app) as c:
 24          yield c
 25  
 26  
 27  def test_setup(client):
 28      """Create team, LLM, and block project for evaluation tests."""
 29      global team_id, project_id
 30      # Create LLM
 31      client.post(
 32          "/llms",
 33          json={
 34              "name": llm_name,
 35              "class_name": "OpenAI",
 36              "options": {"model": "gpt-test", "api_key": "sk-fake"},
 37              "privacy": "public",
 38          },
 39          auth=ADMIN,
 40      )
 41  
 42      # Create team
 43      resp = client.post(
 44          "/teams",
 45          json={"name": team_name, "users": [], "admins": [], "llms": [llm_name]},
 46          auth=ADMIN,
 47      )
 48      assert resp.status_code == 201
 49      team_id = resp.json()["id"]
 50  
 51      # Create block project
 52      resp = client.post(
 53          "/projects",
 54          json={"name": project_name, "type": "block", "team_id": team_id},
 55          auth=ADMIN,
 56      )
 57      assert resp.status_code == 201
 58      project_id = resp.json()["project"]
 59  
 60  
 61  def test_create_dataset(client):
 62      """Create an evaluation dataset."""
 63      global dataset_id
 64      resp = client.post(
 65          f"/projects/{project_id}/evals/datasets",
 66          json={"name": "test-dataset"},
 67          auth=ADMIN,
 68      )
 69      assert resp.status_code == 201
 70      data = resp.json()
 71      assert data["name"] == "test-dataset"
 72      assert data["project_id"] == project_id
 73      assert "id" in data
 74      dataset_id = data["id"]
 75  
 76  
 77  def test_list_datasets(client):
 78      """List datasets for the project includes the created dataset."""
 79      resp = client.get(
 80          f"/projects/{project_id}/evals/datasets", auth=ADMIN
 81      )
 82      assert resp.status_code == 200
 83      datasets = resp.json()
 84      assert any(d["id"] == dataset_id for d in datasets)
 85  
 86  
 87  def test_get_dataset(client):
 88      """Get dataset details by ID."""
 89      resp = client.get(
 90          f"/projects/{project_id}/evals/datasets/{dataset_id}", auth=ADMIN
 91      )
 92      assert resp.status_code == 200
 93      data = resp.json()
 94      assert data["id"] == dataset_id
 95      assert data["name"] == "test-dataset"
 96      assert "test_cases" in data
 97      assert isinstance(data["test_cases"], list)
 98  
 99  
def test_add_test_case(client):
    """POST a test case to the dataset and remember its id."""
    global case_id

    url = f"/projects/{project_id}/evals/datasets/{dataset_id}/cases"
    payload = {"question": "What is 2+2?", "expected_answer": "4"}
    response = client.post(url, json=payload, auth=ADMIN)
    assert response.status_code == 201

    created = response.json()
    # The response must echo back both submitted fields.
    assert created["question"] == "What is 2+2?"
    assert created["expected_answer"] == "4"
    assert "id" in created

    # Kept at module level so the delete test can target this case.
    case_id = created["id"]
114  
115  
def test_update_dataset(client):
    """PATCH the dataset with a new name and check the response reflects it."""
    url = f"/projects/{project_id}/evals/datasets/{dataset_id}"
    response = client.patch(url, json={"name": "updated-dataset"}, auth=ADMIN)
    assert response.status_code == 200

    updated = response.json()
    assert updated["name"] == "updated-dataset"
125  
126  
def test_list_runs_empty(client):
    """The run listing for the project is an empty list at this point."""
    url = f"/projects/{project_id}/evals/runs"
    response = client.get(url, auth=ADMIN)
    assert response.status_code == 200

    runs = response.json()
    assert runs == []
134  
135  
def test_get_nonexistent_dataset(client):
    """Requesting dataset id 999999 yields a 404 response."""
    missing_url = f"/projects/{project_id}/evals/datasets/999999"
    response = client.get(missing_url, auth=ADMIN)
    assert response.status_code == 404
142  
143  
def test_delete_test_case(client):
    """Delete the test case and confirm the dataset no longer counts it."""
    resp = client.delete(
        f"/projects/{project_id}/evals/datasets/{dataset_id}/cases/{case_id}",
        auth=ADMIN,
    )
    assert resp.status_code == 200
    assert resp.json()["deleted"] is True

    # Verify test case is gone. The status is checked first so a failed fetch
    # is reported as such rather than as a KeyError on an error payload.
    resp = client.get(
        f"/projects/{project_id}/evals/datasets/{dataset_id}", auth=ADMIN
    )
    assert resp.status_code == 200
    assert resp.json()["test_case_count"] == 0
158  
159  
def test_delete_dataset(client):
    """Delete the dataset, then check that fetching it returns 404."""
    url = f"/projects/{project_id}/evals/datasets/{dataset_id}"

    deletion = client.delete(url, auth=ADMIN)
    assert deletion.status_code == 200
    assert deletion.json()["deleted"] is True

    # A follow-up GET on the same URL must no longer find the dataset.
    lookup = client.get(url, auth=ADMIN)
    assert lookup.status_code == 404
173  
174  
def test_cleanup(client):
    """Issue DELETE requests for the project, team, and LLM made in setup.

    The project and team are only deleted when their ids were recorded; the
    LLM is always deleted by name. Responses are not inspected.
    """
    targets = []
    if project_id:
        targets.append(f"/projects/{project_id}")
    if team_id:
        targets.append(f"/teams/{team_id}")
    targets.append(f"/llms/{llm_name}")

    # Same order as creation in reverse: project, then team, then LLM.
    for path in targets:
        client.delete(path, auth=ADMIN)