test_mlflow.py
from unittest import mock

import pytest

from mlflow.deployments import get_deploy_client
from mlflow.deployments.mlflow import MlflowDeploymentClient
from mlflow.environment_variables import MLFLOW_DEPLOYMENT_CLIENT_HTTP_REQUEST_TIMEOUT


def test_get_deploy_client():
    client = get_deploy_client("http://localhost:5000")
    assert isinstance(client, MlflowDeploymentClient)


def test_create_endpoint():
    client = get_deploy_client("http://localhost:5000")
    with pytest.raises(NotImplementedError, match=r".*"):
        client.create_endpoint(name="test")


def test_update_endpoint():
    client = get_deploy_client("http://localhost:5000")
    with pytest.raises(NotImplementedError, match=r".*"):
        client.update_endpoint(endpoint="test")


def test_delete_endpoint():
    client = get_deploy_client("http://localhost:5000")
    with pytest.raises(NotImplementedError, match=r".*"):
        client.delete_endpoint(endpoint="test")


def test_get_endpoint():
    client = get_deploy_client("http://localhost:5000")
    mock_resp = mock.Mock()
    mock_resp.json.return_value = {
        "model": {"name": "gpt-4", "provider": "openai"},
        "name": "completions",
        "endpoint_type": "llm/v1/completions",
        "endpoint_url": "http://localhost:5000/endpoints/chat/invocations",
        "limit": None,
    }
    mock_resp.status_code = 200
    with mock.patch("requests.Session.request", return_value=mock_resp) as mock_request:
        resp = client.get_endpoint(endpoint="test")
        mock_request.assert_called_once()
        assert resp.model_dump() == {
            "name": "completions",
            "endpoint_type": "llm/v1/completions",
            "model": {"name": "gpt-4", "provider": "openai"},
            "endpoint_url": "http://localhost:5000/endpoints/chat/invocations",
            "limit": None,
        }
        ((_, url), _) = mock_request.call_args
        assert url == "http://localhost:5000/api/2.0/endpoints/test"


def test_list_endpoints():
    client = get_deploy_client("http://localhost:5000")
    mock_resp = mock.Mock()
    mock_resp.json.return_value = {
        "endpoints": [
            {
                "model": {"name": "gpt-4", "provider": "openai"},
                "name": "completions",
                "endpoint_type": "llm/v1/completions",
                "endpoint_url": "http://localhost:5000/endpoints/chat/invocations",
                "limit": None,
            }
        ]
    }
    mock_resp.status_code = 200
    with mock.patch("requests.Session.request", return_value=mock_resp) as mock_request:
        resp = client.list_endpoints()
        mock_request.assert_called_once()
        assert [r.model_dump() for r in resp] == [
            {
                "model": {"name": "gpt-4", "provider": "openai"},
                "name": "completions",
                "endpoint_type": "llm/v1/completions",
                "endpoint_url": "http://localhost:5000/endpoints/chat/invocations",
                "limit": None,
            }
        ]
        ((_, url), _) = mock_request.call_args
        assert url == "http://localhost:5000/api/2.0/endpoints/"


def test_list_endpoints_paginated():
    client = get_deploy_client("http://localhost:5000")
    mock_resp = mock.Mock()
    mock_resp.json.side_effect = [
        {
            "endpoints": [
                {
                    "model": {"name": "gpt-4", "provider": "openai"},
                    "name": "chat",
                    "endpoint_type": "llm/v1/chat",
                    "endpoint_url": "http://localhost:5000/endpoints/chat/invocations",
                    "limit": None,
                }
            ],
            "next_page_token": "token",
        },
        {
            "endpoints": [
                {
                    "model": {"name": "gpt-4", "provider": "openai"},
                    "name": "completions",
                    "endpoint_type": "llm/v1/completions",
                    "endpoint_url": "http://localhost:5000/endpoints/chat/invocations",
                    "limit": None,
                }
            ]
        },
    ]
    mock_resp.status_code = 200
    with mock.patch("requests.Session.request", return_value=mock_resp) as mock_request:
        resp = client.list_endpoints()
        assert mock_request.call_count == 2
        assert [r.model_dump() for r in resp] == [
            {
                "model": {"name": "gpt-4", "provider": "openai"},
                "name": "chat",
                "endpoint_type": "llm/v1/chat",
                "endpoint_url": "http://localhost:5000/endpoints/chat/invocations",
                "limit": None,
            },
            {
                "model": {"name": "gpt-4", "provider": "openai"},
                "name": "completions",
                "endpoint_type": "llm/v1/completions",
                "endpoint_url": "http://localhost:5000/endpoints/chat/invocations",
                "limit": None,
            },
        ]


def test_predict():
    client = get_deploy_client("http://localhost:5000")
    mock_resp = mock.Mock()
    mock_resp.json.return_value = {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677652288,
        "model": "gpt-4o-mini",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "hello",
                },
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": 9,
            "completion_tokens": 12,
            "total_tokens": 21,
        },
    }

    mock_resp.status_code = 200
    with mock.patch("requests.Session.request", return_value=mock_resp) as mock_request:
        resp = client.predict(endpoint="test", inputs={})
        mock_request.assert_called_once()
        assert resp == {
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1677652288,
            "model": "gpt-4o-mini",
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": "hello"},
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 9,
                "completion_tokens": 12,
                "total_tokens": 21,
            },
        }
        ((_, url), _) = mock_request.call_args
        assert url == "http://localhost:5000/endpoints/test/invocations"


def test_call_endpoint_uses_default_timeout():
    client = get_deploy_client("http://localhost:5000")

    with mock.patch("mlflow.deployments.mlflow.http_request") as mock_http_request:
        mock_http_request.return_value.json.return_value = {"test": "response"}
        mock_http_request.return_value.status_code = 200

        client._call_endpoint("GET", "/test")

        mock_http_request.assert_called_once()
        call_args = mock_http_request.call_args
        assert call_args.kwargs["timeout"] == MLFLOW_DEPLOYMENT_CLIENT_HTTP_REQUEST_TIMEOUT.get()


def test_call_endpoint_respects_custom_timeout():
    client = get_deploy_client("http://localhost:5000")
    custom_timeout = 600

    with mock.patch("mlflow.deployments.mlflow.http_request") as mock_http_request:
        mock_http_request.return_value.json.return_value = {"test": "response"}
        mock_http_request.return_value.status_code = 200

        client._call_endpoint("GET", "/test", timeout=custom_timeout)

        mock_http_request.assert_called_once()
        call_args = mock_http_request.call_args
        assert call_args.kwargs["timeout"] == custom_timeout


def test_call_endpoint_timeout_with_environment_variable(monkeypatch):
    custom_timeout = 420
    monkeypatch.setenv("MLFLOW_DEPLOYMENT_CLIENT_HTTP_REQUEST_TIMEOUT", str(custom_timeout))

    client = get_deploy_client("http://localhost:5000")

    with mock.patch("mlflow.deployments.mlflow.http_request") as mock_http_request:
        mock_http_request.return_value.json.return_value = {"test": "response"}
        mock_http_request.return_value.status_code = 200

        client._call_endpoint("GET", "/test")

        mock_http_request.assert_called_once()
        call_args = mock_http_request.call_args
        assert call_args.kwargs["timeout"] == custom_timeout


def test_get_endpoint_uses_deployment_client_timeout():
    client = get_deploy_client("http://localhost:5000")

    with mock.patch("mlflow.deployments.mlflow.http_request") as mock_http_request:
        mock_http_request.return_value.json.return_value = {
            "model": {"name": "gpt-4", "provider": "openai"},
            "name": "test",
            "endpoint_type": "llm/v1/chat",
            "endpoint_url": "http://localhost:5000/endpoints/test/invocations",
            "limit": None,
        }
        mock_http_request.return_value.status_code = 200

        client.get_endpoint("test")

        mock_http_request.assert_called_once()
        call_args = mock_http_request.call_args
        assert call_args.kwargs["timeout"] == MLFLOW_DEPLOYMENT_CLIENT_HTTP_REQUEST_TIMEOUT.get()


def test_list_endpoints_uses_deployment_client_timeout():
    client = get_deploy_client("http://localhost:5000")

    with mock.patch("mlflow.deployments.mlflow.http_request") as mock_http_request:
        mock_http_request.return_value.json.return_value = {"endpoints": []}
        mock_http_request.return_value.status_code = 200

        client.list_endpoints()

        mock_http_request.assert_called_once()
        call_args = mock_http_request.call_args
        assert call_args.kwargs["timeout"] == MLFLOW_DEPLOYMENT_CLIENT_HTTP_REQUEST_TIMEOUT.get()