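"""Tests for MlflowDeploymentClient (mlflow.deployments.mlflow).

All HTTP traffic is mocked, so no MLflow Deployments Server needs to be running.
"""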
from unittest import mock

import pytest

from mlflow.deployments import get_deploy_client
from mlflow.deployments.mlflow import MlflowDeploymentClient
from mlflow.environment_variables import MLFLOW_DEPLOYMENT_CLIENT_HTTP_REQUEST_TIMEOUT


def test_get_deploy_client():
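    """A plain HTTP target URI should resolve to an MlflowDeploymentClient."""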
    client = get_deploy_client("http://localhost:5000")
    assert isinstance(client, MlflowDeploymentClient)


def test_create_endpoint():
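    """Creating endpoints is unsupported and should raise NotImplementedError."""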
    client = get_deploy_client("http://localhost:5000")
    with pytest.raises(NotImplementedError, match=r".*"):
        client.create_endpoint(name="test")


def test_update_endpoint():
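    """Updating endpoints is unsupported and should raise NotImplementedError."""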
    client = get_deploy_client("http://localhost:5000")
    with pytest.raises(NotImplementedError, match=r".*"):
        client.update_endpoint(endpoint="test")


def test_delete_endpoint():
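    """Deleting endpoints is unsupported and should raise NotImplementedError."""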
    client = get_deploy_client("http://localhost:5000")
    with pytest.raises(NotImplementedError, match=r".*"):
        client.delete_endpoint(endpoint="test")


def test_get_endpoint():
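    """get_endpoint should GET /api/2.0/endpoints/<name> and parse the response body."""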
    client = get_deploy_client("http://localhost:5000")
    mock_resp = mock.Mock()
    mock_resp.json.return_value = {
        "model": {"name": "gpt-4", "provider": "openai"},
        "name": "completions",
        "endpoint_type": "llm/v1/completions",
        "endpoint_url": "http://localhost:5000/endpoints/chat/invocations",
        "limit": None,
    }
    mock_resp.status_code = 200
    with mock.patch("requests.Session.request", return_value=mock_resp) as mock_request:
        resp = client.get_endpoint(endpoint="test")
        mock_request.assert_called_once()
        assert resp.model_dump() == {
            "name": "completions",
            "endpoint_type": "llm/v1/completions",
            "model": {"name": "gpt-4", "provider": "openai"},
            "endpoint_url": "http://localhost:5000/endpoints/chat/invocations",
            "limit": None,
        }
        # The positional args of Session.request are (method, url).
        ((_, url), _) = mock_request.call_args
        assert url == "http://localhost:5000/api/2.0/endpoints/test"


def test_list_endpoints():
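    """list_endpoints should GET /api/2.0/endpoints/ and parse every endpoint."""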
    client = get_deploy_client("http://localhost:5000")
    mock_resp = mock.Mock()
    mock_resp.json.return_value = {
        "endpoints": [
            {
                "model": {"name": "gpt-4", "provider": "openai"},
                "name": "completions",
                "endpoint_type": "llm/v1/completions",
                "endpoint_url": "http://localhost:5000/endpoints/chat/invocations",
                "limit": None,
            }
        ]
    }
    mock_resp.status_code = 200
    with mock.patch("requests.Session.request", return_value=mock_resp) as mock_request:
        resp = client.list_endpoints()
        mock_request.assert_called_once()
        assert [r.model_dump() for r in resp] == [
            {
                "model": {"name": "gpt-4", "provider": "openai"},
                "name": "completions",
                "endpoint_type": "llm/v1/completions",
                "endpoint_url": "http://localhost:5000/endpoints/chat/invocations",
                "limit": None,
            }
        ]
        # The positional args of Session.request are (method, url).
        ((_, url), _) = mock_request.call_args
        assert url == "http://localhost:5000/api/2.0/endpoints/"


def test_list_endpoints_paginated():
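    """list_endpoints should follow next_page_token and merge the results of all pages."""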
    client = get_deploy_client("http://localhost:5000")
    mock_resp = mock.Mock()
    # Two pages: the first response carries a next_page_token, the second does not.
    mock_resp.json.side_effect = [
        {
            "endpoints": [
                {
                    "model": {"name": "gpt-4", "provider": "openai"},
                    "name": "chat",
                    "endpoint_type": "llm/v1/chat",
                    "endpoint_url": "http://localhost:5000/endpoints/chat/invocations",
                    "limit": None,
                }
            ],
            "next_page_token": "token",
        },
        {
            "endpoints": [
                {
                    "model": {"name": "gpt-4", "provider": "openai"},
                    "name": "completions",
                    "endpoint_type": "llm/v1/completions",
                    "endpoint_url": "http://localhost:5000/endpoints/chat/invocations",
                    "limit": None,
                }
            ]
        },
    ]
    mock_resp.status_code = 200
    with mock.patch("requests.Session.request", return_value=mock_resp) as mock_request:
        resp = client.list_endpoints()
        assert mock_request.call_count == 2
        assert [r.model_dump() for r in resp] == [
            {
                "model": {"name": "gpt-4", "provider": "openai"},
                "name": "chat",
                "endpoint_type": "llm/v1/chat",
                "endpoint_url": "http://localhost:5000/endpoints/chat/invocations",
                "limit": None,
            },
            {
                "model": {"name": "gpt-4", "provider": "openai"},
                "name": "completions",
                "endpoint_type": "llm/v1/completions",
                "endpoint_url": "http://localhost:5000/endpoints/chat/invocations",
                "limit": None,
            },
        ]


def test_predict():
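    """predict should call the endpoint's invocations URL and return the JSON payload as-is."""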
    client = get_deploy_client("http://localhost:5000")
    mock_resp = mock.Mock()
    # A canned OpenAI-style chat completion payload.
    mock_resp.json.return_value = {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677652288,
        "model": "gpt-4o-mini",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "hello",
                },
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": 9,
            "completion_tokens": 12,
            "total_tokens": 21,
        },
    }
    mock_resp.status_code = 200
    with mock.patch("requests.Session.request", return_value=mock_resp) as mock_request:
        resp = client.predict(endpoint="test", inputs={})
        mock_request.assert_called_once()
        assert resp == {
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1677652288,
            "model": "gpt-4o-mini",
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": "hello"},
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 9,
                "completion_tokens": 12,
                "total_tokens": 21,
            },
        }
        # The positional args of Session.request are (method, url).
        ((_, url), _) = mock_request.call_args
        assert url == "http://localhost:5000/endpoints/test/invocations"


def test_call_endpoint_uses_default_timeout():
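    """Without an explicit timeout, _call_endpoint should use the environment-variable default."""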
    client = get_deploy_client("http://localhost:5000")

    with mock.patch("mlflow.deployments.mlflow.http_request") as mock_http_request:
        mock_http_request.return_value.json.return_value = {"test": "response"}
        mock_http_request.return_value.status_code = 200

        client._call_endpoint("GET", "/test")

        mock_http_request.assert_called_once()
        call_args = mock_http_request.call_args
        assert call_args.kwargs["timeout"] == MLFLOW_DEPLOYMENT_CLIENT_HTTP_REQUEST_TIMEOUT.get()


def test_call_endpoint_respects_custom_timeout():
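    """An explicit timeout passed to _call_endpoint should be forwarded to http_request."""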
    client = get_deploy_client("http://localhost:5000")
    custom_timeout = 600

    with mock.patch("mlflow.deployments.mlflow.http_request") as mock_http_request:
        mock_http_request.return_value.json.return_value = {"test": "response"}
        mock_http_request.return_value.status_code = 200

        client._call_endpoint("GET", "/test", timeout=custom_timeout)

        mock_http_request.assert_called_once()
        call_args = mock_http_request.call_args
        assert call_args.kwargs["timeout"] == custom_timeout


def test_call_endpoint_timeout_with_environment_variable(monkeypatch):
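    """MLFLOW_DEPLOYMENT_CLIENT_HTTP_REQUEST_TIMEOUT should override the default timeout."""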
    custom_timeout = 420
    monkeypatch.setenv("MLFLOW_DEPLOYMENT_CLIENT_HTTP_REQUEST_TIMEOUT", str(custom_timeout))

    client = get_deploy_client("http://localhost:5000")

    with mock.patch("mlflow.deployments.mlflow.http_request") as mock_http_request:
        mock_http_request.return_value.json.return_value = {"test": "response"}
        mock_http_request.return_value.status_code = 200

        client._call_endpoint("GET", "/test")

        mock_http_request.assert_called_once()
        call_args = mock_http_request.call_args
        assert call_args.kwargs["timeout"] == custom_timeout


def test_get_endpoint_uses_deployment_client_timeout():
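    """get_endpoint should issue its request with the deployment client's timeout."""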
    client = get_deploy_client("http://localhost:5000")

    with mock.patch("mlflow.deployments.mlflow.http_request") as mock_http_request:
        mock_http_request.return_value.json.return_value = {
            "model": {"name": "gpt-4", "provider": "openai"},
            "name": "test",
            "endpoint_type": "llm/v1/chat",
            "endpoint_url": "http://localhost:5000/endpoints/test/invocations",
            "limit": None,
        }
        mock_http_request.return_value.status_code = 200

        client.get_endpoint("test")

        mock_http_request.assert_called_once()
        call_args = mock_http_request.call_args
        assert call_args.kwargs["timeout"] == MLFLOW_DEPLOYMENT_CLIENT_HTTP_REQUEST_TIMEOUT.get()


def test_list_endpoints_uses_deployment_client_timeout():
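    """list_endpoints should issue its request with the deployment client's timeout."""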
    client = get_deploy_client("http://localhost:5000")

    with mock.patch("mlflow.deployments.mlflow.http_request") as mock_http_request:
        mock_http_request.return_value.json.return_value = {"endpoints": []}
        mock_http_request.return_value.status_code = 200

        client.list_endpoints()

        mock_http_request.assert_called_once()
        call_args = mock_http_request.call_args
        assert call_args.kwargs["timeout"] == MLFLOW_DEPLOYMENT_CLIENT_HTTP_REQUEST_TIMEOUT.get()