test_deployments.py
from unittest import mock

import pytest

from mlflow import deployments
from mlflow.deployments.plugin_manager import DeploymentPlugins
from mlflow.exceptions import MlflowException

# Tests for the deployment client API, exercised against the fake deployment
# plugin registered under the target name ``f_target``.

f_model_uri = "fake_model_uri"
f_endpoint_name = "fake_endpoint_name"
# NOTE(review): the value says "name" while the variable says "id". It is kept
# as-is because tests such as test_list_success compare the fake plugin's
# responses against this exact string -- confirm against the plugin before renaming.
f_deployment_id = "fake_deployment_name"
f_flavor = "fake_flavor"
f_target = "faketarget"


def test_create_success():
    """create_deployment returns a dict echoing the name and (optional) flavor."""
    client = deployments.get_deploy_client(f_target)
    ret = client.create_deployment(f_deployment_id, f_model_uri, f_flavor, config={})
    assert isinstance(ret, dict)
    assert ret["name"] == f_deployment_id
    assert ret["flavor"] == f_flavor

    # Omitting the flavor argument must yield a ``None`` flavor in the result.
    ret2 = client.create_deployment(f_deployment_id, f_model_uri)
    assert ret2["flavor"] is None


def test_delete_success():
    """delete_deployment returns None."""
    client = deployments.get_deploy_client(f_target)
    assert client.delete_deployment(f_deployment_id) is None


def test_update_success():
    """update_deployment returns a dict echoing the requested flavor."""
    client = deployments.get_deploy_client(f_target)
    res = client.update_deployment(f_deployment_id, f_model_uri, f_flavor)
    assert res["flavor"] == f_flavor


def test_list_success():
    """list_deployments returns a sequence whose first entry is the fake deployment."""
    client = deployments.get_deploy_client(f_target)
    ret = client.list_deployments()
    assert ret[0]["name"] == f_deployment_id


def test_get_success():
    """get_deployment returns the fake plugin's canned mapping."""
    client = deployments.get_deploy_client(f_target)
    ret = client.get_deployment(f_deployment_id)
    assert ret["key1"] == "val1"


def test_endpoint_create_success():
    """create_endpoint returns a dict echoing the endpoint name."""
    client = deployments.get_deploy_client(f_target)
    endpoint = client.create_endpoint(f_endpoint_name)
    assert isinstance(endpoint, dict)
    assert endpoint["name"] == f_endpoint_name


def test_endpoint_delete_success():
    """delete_endpoint returns None."""
    client = deployments.get_deploy_client(f_target)
    assert client.delete_endpoint(f_endpoint_name) is None


def test_endpoint_update_success():
    """update_endpoint returns None."""
    client = deployments.get_deploy_client(f_target)
    assert client.update_endpoint(f_endpoint_name) is None


def test_endpoint_list_success():
    """list_endpoints returns a sequence whose first entry is the fake endpoint."""
    client = deployments.get_deploy_client(f_target)
    endpoints = client.list_endpoints()
    assert endpoints[0]["name"] == f_endpoint_name


def test_endpoint_get_success():
    """get_endpoint returns a mapping echoing the endpoint name."""
    client = deployments.get_deploy_client(f_target)
    endpoint = client.get_endpoint(f_endpoint_name)
    assert endpoint["name"] == f_endpoint_name


def test_wrong_target_name():
    """An unregistered target name raises MlflowException naming the target."""
    with pytest.raises(
        MlflowException, match='No plugin found for managing model deployments to "wrong_target"'
    ):
        deployments.get_deploy_client("wrong_target")


def test_plugin_doesnot_have_required_attrib():
    """Looking up a registered plugin object lacking the required attributes raises."""

    class DummyPlugin:
        pass

    dummy_plugin = DummyPlugin()
    plugin_manager = DeploymentPlugins()
    # Inject the bare object directly into the registry, bypassing normal loading.
    plugin_manager.registry["dummy"] = dummy_plugin
    with pytest.raises(MlflowException, match="Plugin registered for the target dummy"):
        plugin_manager["dummy"]


def test_plugin_raising_error(monkeypatch):
    """Errors raised inside the plugin propagate unchanged to the caller."""
    client = deployments.get_deploy_client(f_target)
    # special case to raise error: the fake plugin checks this env var.
    monkeypatch.setenv("raiseError", "True")
    with pytest.raises(RuntimeError, match="Error requested"):
        client.list_deployments()


def test_target_uri_parsing():
    """get_deploy_client accepts a bare target and ``:/`` / ``://`` suffixed forms."""
    deployments.get_deploy_client(f_target)
    deployments.get_deploy_client(f"{f_target}:/somesuffix")
    deployments.get_deploy_client(f"{f_target}://somesuffix")


def test_explain_with_no_target_implementation():
    """The client's explain call is routed to the plugin's (patched) explain method."""
    from mlflow_test_plugin import fake_deployment_plugin

    mock_error = MlflowException("MOCK ERROR")
    target_client = deployments.get_deploy_client(f_target)
    plugin = fake_deployment_plugin.PluginDeploymentClient
    with mock.patch.object(plugin, "explain", return_value=mock_error) as mock_explain:
        res = target_client.explain(f_target, "test")
        # isinstance rather than ``type(...) ==`` comparison (ruff E721).
        assert isinstance(res, MlflowException)
        mock_explain.assert_called_once()


def test_explain_with_target_implementation():
    """The fake plugin's own explain implementation returns "1"."""
    target_client = deployments.get_deploy_client(f_target)
    res = target_client.explain(f_target, "test")
    assert res == "1"