/ tests / pyfunc / test_mlserver.py
test_mlserver.py
 1  import os
 2  from typing import Any
 3  
 4  import pytest
 5  
 6  from mlflow.pyfunc.mlserver import MLServerDefaultModelName, MLServerMLflowRuntime, get_cmd
 7  
 8  
 9  @pytest.mark.parametrize(
10      ("params", "expected"),
11      [
12          (
13              {"port": 5000, "host": "0.0.0.0", "nworkers": 4},
14              {
15                  "MLSERVER_HTTP_PORT": "5000",
16                  "MLSERVER_HOST": "0.0.0.0",
17                  "MLSERVER_PARALLEL_WORKERS": "4",
18                  "MLSERVER_MODEL_NAME": MLServerDefaultModelName,
19              },
20          ),
21          (
22              {"host": "0.0.0.0", "nworkers": 4},
23              {
24                  "MLSERVER_HOST": "0.0.0.0",
25                  "MLSERVER_PARALLEL_WORKERS": "4",
26                  "MLSERVER_MODEL_NAME": MLServerDefaultModelName,
27              },
28          ),
29          (
30              {"port": 5000, "nworkers": 4},
31              {
32                  "MLSERVER_HTTP_PORT": "5000",
33                  "MLSERVER_PARALLEL_WORKERS": "4",
34                  "MLSERVER_MODEL_NAME": MLServerDefaultModelName,
35              },
36          ),
37          (
38              {"port": 5000},
39              {
40                  "MLSERVER_HTTP_PORT": "5000",
41                  "MLSERVER_MODEL_NAME": MLServerDefaultModelName,
42              },
43          ),
44          (
45              {"model_name": "mymodel", "model_version": "12"},
46              {"MLSERVER_MODEL_NAME": "mymodel", "MLSERVER_MODEL_VERSION": "12"},
47          ),
48          ({}, {"MLSERVER_MODEL_NAME": MLServerDefaultModelName}),
49      ],
50  )
51  def test_get_cmd(params: dict[str, Any], expected: dict[str, Any]):
52      model_uri = "/foo/bar"
53      cmd, cmd_env = get_cmd(model_uri=model_uri, **params)
54  
55      assert cmd == f"mlserver start {model_uri}"
56  
57      assert cmd_env == {
58          "MLSERVER_MODEL_URI": model_uri,
59          "MLSERVER_MODEL_IMPLEMENTATION": MLServerMLflowRuntime,
60          **expected,
61          **os.environ.copy(),
62      }