# test_mlserver.py
import os
from typing import Any

import pytest

from mlflow.pyfunc.mlserver import MLServerDefaultModelName, MLServerMLflowRuntime, get_cmd


@pytest.mark.parametrize(
    ("params", "expected"),
    [
        # All three tunables supplied -> all three env vars set.
        (
            {"port": 5000, "host": "0.0.0.0", "nworkers": 4},
            {
                "MLSERVER_HTTP_PORT": "5000",
                "MLSERVER_HOST": "0.0.0.0",
                "MLSERVER_PARALLEL_WORKERS": "4",
                "MLSERVER_MODEL_NAME": MLServerDefaultModelName,
            },
        ),
        # Missing port -> MLSERVER_HTTP_PORT omitted.
        (
            {"host": "0.0.0.0", "nworkers": 4},
            {
                "MLSERVER_HOST": "0.0.0.0",
                "MLSERVER_PARALLEL_WORKERS": "4",
                "MLSERVER_MODEL_NAME": MLServerDefaultModelName,
            },
        ),
        # Missing host -> MLSERVER_HOST omitted.
        (
            {"port": 5000, "nworkers": 4},
            {
                "MLSERVER_HTTP_PORT": "5000",
                "MLSERVER_PARALLEL_WORKERS": "4",
                "MLSERVER_MODEL_NAME": MLServerDefaultModelName,
            },
        ),
        # Port only.
        (
            {"port": 5000},
            {
                "MLSERVER_HTTP_PORT": "5000",
                "MLSERVER_MODEL_NAME": MLServerDefaultModelName,
            },
        ),
        # Explicit model name/version override the default name.
        (
            {"model_name": "mymodel", "model_version": "12"},
            {"MLSERVER_MODEL_NAME": "mymodel", "MLSERVER_MODEL_VERSION": "12"},
        ),
        # No params at all -> only the default model name.
        ({}, {"MLSERVER_MODEL_NAME": MLServerDefaultModelName}),
    ],
)
def test_get_cmd(params: dict[str, Any], expected: dict[str, Any]):
    """get_cmd returns the mlserver CLI invocation plus the environment it needs.

    The environment must always carry the model URI and the MLflow runtime
    implementation, layer in the per-case MLSERVER_* settings, and inherit the
    caller's process environment on top.
    """
    model_uri = "/foo/bar"
    command, command_env = get_cmd(model_uri=model_uri, **params)

    # The command line itself is fixed apart from the model URI.
    assert command == f"mlserver start {model_uri}"

    # Build the expected environment in layers: base entries, then the
    # case-specific settings, then the inherited process environment
    # (which takes precedence, mirroring get_cmd's behavior).
    expected_env: dict[str, Any] = {
        "MLSERVER_MODEL_URI": model_uri,
        "MLSERVER_MODEL_IMPLEMENTATION": MLServerMLflowRuntime,
    }
    expected_env.update(expected)
    expected_env.update(os.environ)

    assert command_env == expected_env