/ tests / test_settings.py
test_settings.py
  1  import random
  2  import pytest
  3  from fastapi.testclient import TestClient
  4  
  5  from restai import config
  6  from restai.config import RESTAI_DEFAULT_PASSWORD
  7  from restai.main import app
  8  
  9  
 10  @pytest.fixture(scope="module")
 11  def client():
 12      with TestClient(app) as c:
 13          yield c
 14  
 15  
 16  test_user = "test_settings_user_" + str(random.randint(0, 1000000))
 17  
 18  
 19  def test_get_settings(client):
 20      response = client.get("/settings", auth=("admin", RESTAI_DEFAULT_PASSWORD))
 21      assert response.status_code == 200
 22      data = response.json()
 23      for key in (
 24          "app_name",
 25          "hide_branding",
 26          "proxy_enabled",
 27          "max_audio_upload_size",
 28      ):
 29          assert key in data
 30  
 31  
 32  def test_get_settings_non_admin(client):
 33      # Create non-admin user
 34      client.post(
 35          "/users",
 36          json={
 37              "username": test_user,
 38              "password": "testpass",
 39              "admin": False,
 40              "private": False,
 41          },
 42          auth=("admin", RESTAI_DEFAULT_PASSWORD),
 43      )
 44  
 45      response = client.get("/settings", auth=(test_user, "testpass"))
 46      assert response.status_code == 403
 47  
 48      # Clean up
 49      client.delete(
 50          f"/users/{test_user}",
 51          auth=("admin", RESTAI_DEFAULT_PASSWORD),
 52      )
 53  
 54  
 55  def test_update_settings(client):
 56      # Get current value
 57      original = client.get(
 58          "/settings", auth=("admin", RESTAI_DEFAULT_PASSWORD)
 59      ).json()
 60      original_name = original["app_name"]
 61  
 62      # Update
 63      response = client.patch(
 64          "/settings",
 65          json={"app_name": "TestApp"},
 66          auth=("admin", RESTAI_DEFAULT_PASSWORD),
 67      )
 68      assert response.status_code == 200
 69      assert response.json()["app_name"] == "TestApp"
 70  
 71      # Restore
 72      client.patch(
 73          "/settings",
 74          json={"app_name": original_name},
 75          auth=("admin", RESTAI_DEFAULT_PASSWORD),
 76      )
 77  
 78  
 79  def test_update_settings_invalid(client):
 80      response = client.patch(
 81          "/settings",
 82          json={"max_audio_upload_size": 0},
 83          auth=("admin", RESTAI_DEFAULT_PASSWORD),
 84      )
 85      assert response.status_code == 400
 86  
 87  
 88  def test_update_settings_bool(client):
 89      # Get current value
 90      original = client.get(
 91          "/settings", auth=("admin", RESTAI_DEFAULT_PASSWORD)
 92      ).json()
 93      original_val = original["hide_branding"]
 94  
 95      # Update
 96      response = client.patch(
 97          "/settings",
 98          json={"hide_branding": True},
 99          auth=("admin", RESTAI_DEFAULT_PASSWORD),
100      )
101      assert response.status_code == 200
102      assert response.json()["hide_branding"] is True
103  
104      # Restore
105      client.patch(
106          "/settings",
107          json={"hide_branding": original_val},
108          auth=("admin", RESTAI_DEFAULT_PASSWORD),
109      )
110  
111  
def test_sso_auto_restricted_default(client):
    """``sso_auto_restricted`` is present in the settings and is True."""
    admin_auth = ("admin", RESTAI_DEFAULT_PASSWORD)
    settings_response = client.get("/settings", auth=admin_auth)
    assert settings_response.status_code == 200

    settings = settings_response.json()
    assert "sso_auto_restricted" in settings
    assert settings["sso_auto_restricted"] is True
119  
120  
def test_sso_auto_team_id_default(client):
    """``sso_auto_team_id`` is present in the settings and is an empty string."""
    admin_auth = ("admin", RESTAI_DEFAULT_PASSWORD)
    settings_response = client.get("/settings", auth=admin_auth)
    assert settings_response.status_code == 200

    settings = settings_response.json()
    assert "sso_auto_team_id" in settings
    assert settings["sso_auto_team_id"] == ""
128  
129  
def test_update_sso_auto_restricted(client):
    """Should be able to toggle SSO auto-restricted setting.

    Each ``PATCH`` is checked both in the response body and in the
    ``config.SSO_AUTO_RESTRICTED`` module attribute. The value reported by
    the API before the test is written back in a ``finally`` block, so a
    failure between the two toggles does not leave the setting disabled for
    later tests.
    """
    admin_auth = ("admin", RESTAI_DEFAULT_PASSWORD)

    # Get current value
    current = client.get("/settings", auth=admin_auth)
    assert current.status_code == 200
    original_val = current.json()["sso_auto_restricted"]

    try:
        # Disable
        response = client.patch(
            "/settings",
            json={"sso_auto_restricted": False},
            auth=admin_auth,
        )
        assert response.status_code == 200
        assert response.json()["sso_auto_restricted"] is False
        assert config.SSO_AUTO_RESTRICTED is False

        # Re-enable
        response = client.patch(
            "/settings",
            json={"sso_auto_restricted": True},
            auth=admin_auth,
        )
        assert response.status_code == 200
        assert response.json()["sso_auto_restricted"] is True
        assert config.SSO_AUTO_RESTRICTED is True
    finally:
        # Restore (best effort; not asserted so it cannot hide the original
        # failure).
        client.patch(
            "/settings",
            json={"sso_auto_restricted": original_val},
            auth=admin_auth,
        )
151  
152  
# Randomised team name so repeated runs are unlikely to collide with a
# leftover team from an earlier run.
test_team_name = f"test_sso_team_{random.randint(0, 1000000)}"
154  
155  
def test_update_sso_auto_team_id(client):
    """Should be able to set a default team for SSO users.

    A temporary team is created, assigned as ``sso_auto_team_id`` and then
    cleared again. The setting is reset and the team deleted in a
    ``finally`` block, so a failed assertion neither leaks the team nor
    leaves the setting pointing at it.
    """
    admin_auth = ("admin", RESTAI_DEFAULT_PASSWORD)

    # Create a team
    resp = client.post(
        "/teams",
        json={"name": test_team_name},
        auth=admin_auth,
    )
    assert resp.status_code in (200, 201)
    # The setting is compared as a string, so normalise the id once here.
    team_id = str(resp.json()["id"])

    try:
        # Set the team
        response = client.patch(
            "/settings",
            json={"sso_auto_team_id": team_id},
            auth=admin_auth,
        )
        assert response.status_code == 200
        assert response.json()["sso_auto_team_id"] == team_id
        assert config.SSO_AUTO_TEAM_ID == team_id

        # Clear it
        response = client.patch(
            "/settings",
            json={"sso_auto_team_id": ""},
            auth=admin_auth,
        )
        assert response.status_code == 200
        assert response.json()["sso_auto_team_id"] == ""
    finally:
        # Clean up (best effort; results are not asserted so they cannot
        # hide the original failure). Clear the setting before deleting the
        # team so it never references a team that no longer exists.
        client.patch(
            "/settings",
            json={"sso_auto_team_id": ""},
            auth=admin_auth,
        )
        client.delete(f"/teams/{team_id}", auth=admin_auth)
189          auth=("admin", RESTAI_DEFAULT_PASSWORD),
190      )