# test_project_prompts.py
"""Integration tests for project prompt versioning.

The tests in this module run in file order and share state through the
module-level ``team_id``, ``project_id`` and ``version_id`` globals:
``test_setup`` creates the resources, the middle tests exercise the
``/projects/{id}/prompts`` endpoints, and ``test_cleanup`` deletes the
created resources again.
"""

import random

import pytest
from fastapi.testclient import TestClient

from restai.config import RESTAI_DEFAULT_PASSWORD
from restai.main import app

# Random suffix keeps resource names unique between runs (e.g. when a
# previous run's cleanup did not complete).
suffix = str(random.randint(0, 10000000))
team_name = f"prompts_team_{suffix}"
llm_name = f"prompts_llm_{suffix}"
project_name = f"prompts_proj_{suffix}"

# Filled in by the tests below and read by later tests and the cleanup step.
team_id = None
project_id = None
version_id = None

# Credentials passed as ``auth=`` on every request: the admin user with the
# default password from restai.config.
ADMIN = ("admin", RESTAI_DEFAULT_PASSWORD)


@pytest.fixture(scope="module")
def client():
    """Yield one TestClient shared by every test in this module.

    Entering the client as a context manager runs the application's
    startup/shutdown handlers around the whole module.
    """
    with TestClient(app) as c:
        yield c


def test_setup(client):
    """Create the LLM, team, and block project used by the prompt-version tests."""
    global team_id, project_id

    # Create the LLM first so the team below can reference it by name.
    resp = client.post(
        "/llms",
        json={
            "name": llm_name,
            "class_name": "OpenAI",
            "options": {"model": "gpt-test", "api_key": "sk-fake"},
            "privacy": "public",
        },
        auth=ADMIN,
    )
    # Fail here with the server's message instead of later at team creation.
    assert 200 <= resp.status_code < 300, resp.text

    # Create the team that owns the project.
    resp = client.post(
        "/teams",
        json={"name": team_name, "users": [], "admins": [], "llms": [llm_name]},
        auth=ADMIN,
    )
    assert resp.status_code == 201
    team_id = resp.json()["id"]

    # Create the block project whose prompt versions are exercised below.
    resp = client.post(
        "/projects",
        json={"name": project_name, "type": "block", "team_id": team_id},
        auth=ADMIN,
    )
    assert resp.status_code == 201
    project_id = resp.json()["project"]


def test_list_prompts_initial(client):
    """Listing prompt versions of a freshly created project returns a list."""
    resp = client.get(f"/projects/{project_id}/prompts", auth=ADMIN)
    assert resp.status_code == 200
    assert isinstance(resp.json(), list)


def test_auto_version_on_edit(client):
    """Editing the system prompt creates a prompt version automatically."""
    global version_id

    # Set a system prompt; this edit is expected to record a version.
    resp = client.patch(
        f"/projects/{project_id}",
        json={"system": "First version of the prompt."},
        auth=ADMIN,
    )
    assert resp.status_code == 200

    # List prompt versions.
    resp = client.get(f"/projects/{project_id}/prompts", auth=ADMIN)
    assert resp.status_code == 200
    versions = resp.json()
    assert len(versions) >= 1
    # NOTE(review): assumes the endpoint returns the newest version first —
    # confirm against the prompts route's ordering.
    version_id = versions[0]["id"]
    assert versions[0]["system_prompt"] == "First version of the prompt."


def test_get_version(client):
    """Retrieve a specific prompt version by ID."""
    # Without this guard a failed earlier test would request ".../prompts/None".
    assert version_id is not None, "test_auto_version_on_edit did not record a version"

    resp = client.get(
        f"/projects/{project_id}/prompts/{version_id}", auth=ADMIN
    )
    assert resp.status_code == 200
    data = resp.json()
    assert data["id"] == version_id
    assert data["project_id"] == project_id
    assert data["system_prompt"] == "First version of the prompt."
    assert "version" in data
    assert "created_at" in data


def test_activate_version(client):
    """Update the prompt, then activate the old version to restore it."""
    assert version_id is not None, "test_auto_version_on_edit did not record a version"

    # Change the prompt to something new.
    resp = client.patch(
        f"/projects/{project_id}",
        json={"system": "Second version of the prompt."},
        auth=ADMIN,
    )
    assert resp.status_code == 200

    # Verify the project now has the new prompt.
    resp = client.get(f"/projects/{project_id}", auth=ADMIN)
    assert resp.status_code == 200
    assert resp.json()["system"] == "Second version of the prompt."

    # Activate the first version.
    resp = client.post(
        f"/projects/{project_id}/prompts/{version_id}/activate",
        auth=ADMIN,
    )
    assert resp.status_code == 200

    # Verify the prompt was restored.
    resp = client.get(f"/projects/{project_id}", auth=ADMIN)
    assert resp.status_code == 200
    assert resp.json()["system"] == "First version of the prompt."


def test_cleanup(client):
    """Remove all test resources, skipping ones that were never created.

    Deletion responses are intentionally not checked: this is best-effort
    teardown.
    """
    if project_id is not None:
        client.delete(f"/projects/{project_id}", auth=ADMIN)
    if team_id is not None:
        client.delete(f"/teams/{team_id}", auth=ADMIN)
    client.delete(f"/llms/{llm_name}", auth=ADMIN)