/ tests / gateway / test_runner.py
test_runner.py
  1  from pathlib import Path
  2  
  3  import pytest
  4  
  5  from tests.gateway.tools import Gateway, save_yaml
  6  
  7  BASE_ROUTE = "/api/2.0/endpoints/"
  8  
  9  
 10  @pytest.fixture
 11  def basic_config_dict():
 12      return {
 13          "endpoints": [
 14              {
 15                  "name": "completions-gpt4",
 16                  "endpoint_type": "llm/v1/completions",
 17                  "model": {
 18                      "name": "gpt-4",
 19                      "provider": "openai",
 20                      "config": {
 21                          "openai_api_key": "mykey",
 22                          "openai_api_base": "https://api.openai.com/v1",
 23                          "openai_api_version": "2023-05-15",
 24                          "openai_api_type": "openai",
 25                      },
 26                  },
 27              },
 28              {
 29                  "name": "embeddings-gpt4",
 30                  "endpoint_type": "llm/v1/embeddings",
 31                  "model": {
 32                      "name": "gpt-4",
 33                      "provider": "openai",
 34                      "config": {
 35                          "openai_api_key": "mykey",
 36                          "openai_api_base": "https://api.openai.com/v1",
 37                          "openai_api_version": "2023-05-15",
 38                          "openai_api_type": "openai",
 39                      },
 40                  },
 41              },
 42          ]
 43      }
 44  
 45  
 46  @pytest.fixture
 47  def basic_routes():
 48      return [
 49          {
 50              "name": "completions-gpt4",
 51              "endpoint_type": "llm/v1/completions",
 52              "endpoint_url": "/gateway/completions-gpt4/invocations",
 53              "model": {
 54                  "name": "gpt-4",
 55                  "provider": "openai",
 56              },
 57              "limit": None,
 58          },
 59          {
 60              "name": "embeddings-gpt4",
 61              "endpoint_type": "llm/v1/embeddings",
 62              "endpoint_url": "/gateway/embeddings-gpt4/invocations",
 63              "model": {
 64                  "name": "gpt-4",
 65                  "provider": "openai",
 66              },
 67              "limit": None,
 68          },
 69      ]
 70  
 71  
 72  @pytest.fixture
 73  def update_config_dict():
 74      return {
 75          "endpoints": [
 76              {
 77                  "name": "chat-gpt4",
 78                  "endpoint_type": "llm/v1/chat",
 79                  "model": {
 80                      "name": "gpt-4",
 81                      "provider": "openai",
 82                      "config": {
 83                          "openai_api_key": "mykey",
 84                          "openai_api_base": "https://api.openai.com/v1",
 85                          "openai_api_version": "2023-05-15",
 86                          "openai_api_type": "openai",
 87                      },
 88                  },
 89                  "limit": None,
 90              },
 91          ]
 92      }
 93  
 94  
 95  @pytest.fixture
 96  def update_routes():
 97      return [
 98          {
 99              "name": "chat-gpt4",
100              "endpoint_type": "llm/v1/chat",
101              "endpoint_url": "/gateway/chat-gpt4/invocations",
102              "model": {
103                  "name": "gpt-4",
104                  "provider": "openai",
105              },
106              "limit": None,
107          },
108      ]
109  
110  
@pytest.fixture
def invalid_config_dict():
    """Malformed config whose only endpoint has no ``name`` key.

    Its ``model`` block also contains only unrecognized keys.
    """
    bogus_model = {"invalidkey": "invalid", "invalid_provider": "invalid"}
    bogus_endpoint = {
        "invalid_name": "invalid",
        "endpoint_type": "llm/v1/chat",
        "model": bogus_model,
    }
    return {"endpoints": [bogus_endpoint]}
122  
123  
def test_server_update(
    tmp_path: Path, basic_config_dict, update_config_dict, basic_routes, update_routes
):
    """Rewriting the watched config file reloads the served endpoints without downtime."""
    config_path = tmp_path / "config.yaml"
    save_yaml(config_path, basic_config_dict)

    with Gateway(config_path) as gateway:

        def listed_endpoints():
            return gateway.get(BASE_ROUTE).json()["endpoints"]

        assert listed_endpoints() == basic_routes

        # First push the update, then push the original file back; each write must be
        # picked up while the server stays healthy.
        for new_config, expected in (
            (update_config_dict, update_routes),
            (basic_config_dict, basic_routes),
        ):
            save_yaml(config_path, new_config)
            # The server must still answer before the reload happens (no downtime).
            gateway.assert_health()
            # Wait for the app to restart with the new file.
            gateway.wait_reload()
            assert listed_endpoints() == expected
150          response = gateway.get(BASE_ROUTE)
151          assert response.json()["endpoints"] == basic_routes
152  
153  
def test_server_update_with_invalid_config(
    tmp_path: Path, basic_config_dict, invalid_config_dict, basic_routes
):
    """An invalid config edit is discarded and the previously loaded endpoints stay served."""
    config_path = tmp_path / "config.yaml"
    save_yaml(config_path, basic_config_dict)

    with Gateway(config_path) as gateway:
        initial_endpoints = gateway.get(BASE_ROUTE).json()["endpoints"]
        assert initial_endpoints == basic_routes
        # Let the file watcher complete a cycle before editing the file.
        gateway.wait_reload()

        # Write the broken config; the server must stay up both immediately and after
        # the watcher has run through its aborted-change handling.
        save_yaml(config_path, invalid_config_dict)
        gateway.assert_health()
        gateway.wait_reload()
        gateway.assert_health()

        final_endpoints = gateway.get(BASE_ROUTE).json()["endpoints"]
        assert final_endpoints == basic_routes
173  
174  
def test_server_update_config_removed_then_recreated(
    tmp_path: Path, basic_config_dict, basic_routes
):
    """Deleting the config keeps the server healthy; recreating it loads the new contents."""
    config_path = tmp_path / "config.yaml"
    save_yaml(config_path, basic_config_dict)

    with Gateway(config_path) as gateway:
        assert gateway.get(BASE_ROUTE).json()["endpoints"] == basic_routes
        # Let the file watcher complete a cycle before touching the file.
        gateway.wait_reload()

        # Delete the watched file; the server must survive the missing config.
        config_path.unlink()
        gateway.wait_reload()
        gateway.assert_health()

        # Recreate the file with every endpoint except the first one.
        remaining_endpoints = basic_config_dict["endpoints"][1:]
        save_yaml(config_path, {"endpoints": remaining_endpoints})
        gateway.wait_reload()
        assert gateway.get(BASE_ROUTE).json()["endpoints"] == basic_routes[1:]
195  
196  
def test_server_static_endpoints(tmp_path, basic_config_dict, basic_routes):
    """Check that the documentation pages and every per-endpoint GET route are served.

    Each configured endpoint must be retrievable at ``BASE_ROUTE + name``. Its payload
    must equal the entry at the same position in ``basic_routes``.
    """
    config = tmp_path / "config.yaml"
    save_yaml(config, basic_config_dict)

    with Gateway(config) as gateway:
        response = gateway.get(BASE_ROUTE)
        assert response.json()["endpoints"] == basic_routes

        # Interactive API documentation pages.
        for doc_page in ["docs", "redoc"]:
            response = gateway.get(doc_page)
            assert response.status_code == 200

        # The config endpoints and the expected routes are parallel lists. Iterate them
        # in lockstep, but first pin their lengths: zip() would otherwise silently drop
        # unmatched entries rather than fail.
        configured_endpoints = basic_config_dict["endpoints"]
        assert len(configured_endpoints) == len(basic_routes)
        for endpoint, expected_route in zip(configured_endpoints, basic_routes):
            response = gateway.get(f"{BASE_ROUTE}{endpoint['name']}")
            assert response.json() == expected_route
212  
213  
def test_request_invalid_route(tmp_path, basic_config_dict):
    """An unknown endpoint name yields 404 on GET and 405 on POST."""
    config_path = tmp_path / "config.yaml"
    save_yaml(config_path, basic_config_dict)

    expected_missing_detail = (
        "The endpoint 'invalid' is not present or active on the server. Please "
        "verify the endpoint name."
    )

    with Gateway(config_path) as gateway:
        # GET of an endpoint that is not configured.
        get_response = gateway.get(f"{BASE_ROUTE}invalid/")
        assert get_response.status_code == 404
        assert get_response.json() == {"detail": expected_missing_detail}

        # POST to the same unknown name is rejected as a disallowed method.
        post_response = gateway.post(f"{BASE_ROUTE}invalid", json={"input": "should fail"})
        assert post_response.status_code == 405
        assert post_response.json() == {"detail": "Method Not Allowed"}