/ tests / test_project_prompts.py
test_project_prompts.py
  1  import random
  2  import pytest
  3  from fastapi.testclient import TestClient
  4  
  5  from restai.config import RESTAI_DEFAULT_PASSWORD
  6  from restai.main import app
  7  
  8  suffix = str(random.randint(0, 10000000))
  9  team_name = f"prompts_team_{suffix}"
 10  llm_name = f"prompts_llm_{suffix}"
 11  project_name = f"prompts_proj_{suffix}"
 12  
 13  team_id = None
 14  project_id = None
 15  version_id = None
 16  
 17  ADMIN = ("admin", RESTAI_DEFAULT_PASSWORD)
 18  
 19  
 20  @pytest.fixture(scope="module")
 21  def client():
 22      with TestClient(app) as c:
 23          yield c
 24  
 25  
 26  def test_setup(client):
 27      """Create team, LLM, and block project for prompt version tests."""
 28      global team_id, project_id
 29      # Create LLM
 30      client.post(
 31          "/llms",
 32          json={
 33              "name": llm_name,
 34              "class_name": "OpenAI",
 35              "options": {"model": "gpt-test", "api_key": "sk-fake"},
 36              "privacy": "public",
 37          },
 38          auth=ADMIN,
 39      )
 40  
 41      # Create team
 42      resp = client.post(
 43          "/teams",
 44          json={"name": team_name, "users": [], "admins": [], "llms": [llm_name]},
 45          auth=ADMIN,
 46      )
 47      assert resp.status_code == 201
 48      team_id = resp.json()["id"]
 49  
 50      # Create block project
 51      resp = client.post(
 52          "/projects",
 53          json={"name": project_name, "type": "block", "team_id": team_id},
 54          auth=ADMIN,
 55      )
 56      assert resp.status_code == 201
 57      project_id = resp.json()["project"]
 58  
 59  
 60  def test_list_prompts_initial(client):
 61      """A new project with a system prompt has an initial version."""
 62      resp = client.get(f"/projects/{project_id}/prompts", auth=ADMIN)
 63      assert resp.status_code == 200
 64      assert isinstance(resp.json(), list)
 65  
 66  
 67  def test_auto_version_on_edit(client):
 68      """Editing the system prompt creates a prompt version automatically."""
 69      global version_id
 70      # Set a system prompt
 71      resp = client.patch(
 72          f"/projects/{project_id}",
 73          json={"system": "First version of the prompt."},
 74          auth=ADMIN,
 75      )
 76      assert resp.status_code == 200
 77  
 78      # List prompt versions
 79      resp = client.get(f"/projects/{project_id}/prompts", auth=ADMIN)
 80      assert resp.status_code == 200
 81      versions = resp.json()
 82      assert len(versions) >= 1
 83      version_id = versions[0]["id"]
 84      assert versions[0]["system_prompt"] == "First version of the prompt."
 85  
 86  
 87  def test_get_version(client):
 88      """Retrieve a specific prompt version by ID."""
 89      resp = client.get(
 90          f"/projects/{project_id}/prompts/{version_id}", auth=ADMIN
 91      )
 92      assert resp.status_code == 200
 93      data = resp.json()
 94      assert data["id"] == version_id
 95      assert data["project_id"] == project_id
 96      assert data["system_prompt"] == "First version of the prompt."
 97      assert "version" in data
 98      assert "created_at" in data
 99  
100  
def test_activate_version(client):
    """Replace the prompt, then activate the stored version to restore it.

    Reads the module-level ``project_id`` and ``version_id`` globals, which
    are assigned by earlier tests in this module. Every response status is
    asserted before its body is read, so a failed request is reported as a
    status mismatch rather than a ``KeyError`` or JSON decoding error.
    """
    # Fail with a clear message if the test that records the version id
    # did not run (or did not get far enough to set it).
    assert version_id is not None, "version_id was not set by an earlier test"

    project_url = f"/projects/{project_id}"

    # Change the prompt to something new.
    resp = client.patch(
        project_url,
        json={"system": "Second version of the prompt."},
        auth=ADMIN,
    )
    assert resp.status_code == 200

    # Verify the project now has the new prompt.
    resp = client.get(project_url, auth=ADMIN)
    assert resp.status_code == 200
    assert resp.json()["system"] == "Second version of the prompt."

    # Activate the first version.
    resp = client.post(
        f"{project_url}/prompts/{version_id}/activate",
        auth=ADMIN,
    )
    assert resp.status_code == 200

    # Verify the prompt was restored.
    resp = client.get(project_url, auth=ADMIN)
    assert resp.status_code == 200
    assert resp.json()["system"] == "First version of the prompt."
125  
126  
def test_cleanup(client):
    """Delete the project, team and LLM created by this module.

    Project and team are only deleted when their id globals are truthy; the
    LLM delete is always issued. Responses are not checked.
    """
    targets = []
    if project_id:
        targets.append(f"/projects/{project_id}")
    if team_id:
        targets.append(f"/teams/{team_id}")
    targets.append(f"/llms/{llm_name}")

    # Same order as creation in reverse: project, then team, then LLM.
    for url in targets:
        client.delete(url, auth=ADMIN)