/ tests / deployments / test_deployments.py
test_deployments.py
  1  from unittest import mock
  2  
  3  import pytest
  4  
  5  from mlflow import deployments
  6  from mlflow.deployments.plugin_manager import DeploymentPlugins
  7  from mlflow.exceptions import MlflowException
  8  
# Fixture values shared by the tests below. Several tests compare the plugin's
# responses against these even when nothing is passed in (e.g.
# list_deployments() / list_endpoints() take no arguments but are expected to
# return these names), so they must stay in sync with the test plugin that
# serves the "faketarget" target.
f_model_uri = "fake_model_uri"
f_endpoint_name = "fake_endpoint_name"
# NOTE: despite the "_id" suffix, this holds a deployment *name*.
f_deployment_id = "fake_deployment_name"
f_flavor = "fake_flavor"
# Target name resolved by deployments.get_deploy_client(); presumably served by
# mlflow_test_plugin's fake_deployment_plugin — confirm against plugin registration.
f_target = "faketarget"
 14  
 15  
 16  def test_create_success():
 17      client = deployments.get_deploy_client(f_target)
 18      ret = client.create_deployment(f_deployment_id, f_model_uri, f_flavor, config={})
 19      assert isinstance(ret, dict)
 20      assert ret["name"] == f_deployment_id
 21      assert ret["flavor"] == f_flavor
 22  
 23      ret2 = client.create_deployment(f_deployment_id, f_model_uri)
 24      assert ret2["flavor"] is None
 25  
 26  
 27  def test_delete_success():
 28      client = deployments.get_deploy_client(f_target)
 29      assert client.delete_deployment(f_deployment_id) is None
 30  
 31  
 32  def test_update_success():
 33      client = deployments.get_deploy_client(f_target)
 34      res = client.update_deployment(f_deployment_id, f_model_uri, f_flavor)
 35      assert res["flavor"] == f_flavor
 36  
 37  
 38  def test_list_success():
 39      client = deployments.get_deploy_client(f_target)
 40      ret = client.list_deployments()
 41      assert ret[0]["name"] == f_deployment_id
 42  
 43  
 44  def test_get_success():
 45      client = deployments.get_deploy_client(f_target)
 46      ret = client.get_deployment(f_deployment_id)
 47      assert ret["key1"] == "val1"
 48  
 49  
 50  def test_endpoint_create_success():
 51      client = deployments.get_deploy_client(f_target)
 52      endpoint = client.create_endpoint(f_endpoint_name)
 53      assert isinstance(endpoint, dict)
 54      assert endpoint["name"] == f_endpoint_name
 55  
 56  
 57  def test_endpoint_delete_success():
 58      client = deployments.get_deploy_client(f_target)
 59      assert client.delete_endpoint(f_endpoint_name) is None
 60  
 61  
 62  def test_endpoint_update_success():
 63      client = deployments.get_deploy_client(f_target)
 64      assert client.update_endpoint(f_endpoint_name) is None
 65  
 66  
 67  def test_endpoint_list_success():
 68      client = deployments.get_deploy_client(f_target)
 69      endpoints = client.list_endpoints()
 70      assert endpoints[0]["name"] == f_endpoint_name
 71  
 72  
 73  def test_endpoint_get_success():
 74      client = deployments.get_deploy_client(f_target)
 75      endpoint = client.get_endpoint(f_endpoint_name)
 76      assert endpoint["name"] == f_endpoint_name
 77  
 78  
 79  def test_wrong_target_name():
 80      with pytest.raises(
 81          MlflowException, match='No plugin found for managing model deployments to "wrong_target"'
 82      ):
 83          deployments.get_deploy_client("wrong_target")
 84  
 85  
 86  def test_plugin_doesnot_have_required_attrib():
 87      class DummyPlugin:
 88          pass
 89  
 90      dummy_plugin = DummyPlugin()
 91      plugin_manager = DeploymentPlugins()
 92      plugin_manager.registry["dummy"] = dummy_plugin
 93      with pytest.raises(MlflowException, match="Plugin registered for the target dummy"):
 94          plugin_manager["dummy"]
 95  
 96  
 97  def test_plugin_raising_error(monkeypatch):
 98      client = deployments.get_deploy_client(f_target)
 99      # special case to raise error
100      monkeypatch.setenv("raiseError", "True")
101      with pytest.raises(RuntimeError, match="Error requested"):
102          client.list_deployments()
103  
104  
def test_target_uri_parsing():
    """Bare, ``:/``-suffixed and ``://``-suffixed target URIs all resolve to a client.

    Each form must produce an actual deployment client, not merely avoid raising.
    """
    target_uris = [f_target, f"{f_target}:/somesuffix", f"{f_target}://somesuffix"]
    for target_uri in target_uris:
        client = deployments.get_deploy_client(target_uri)
        # Guard against a URI form that silently resolves to nothing (e.g. None).
        assert isinstance(client, deployments.BaseDeploymentClient), target_uri
109  
110  
def test_explain_with_no_target_implementation():
    """With the plugin's ``explain`` patched, the client returns exactly the mocked value.

    The patched method must be invoked once with the arguments passed to the client.
    """
    from mlflow_test_plugin import fake_deployment_plugin

    mock_error = MlflowException("MOCK ERROR")
    target_client = deployments.get_deploy_client(f_target)
    plugin = fake_deployment_plugin.PluginDeploymentClient
    with mock.patch.object(plugin, "explain", return_value=mock_error) as mock_explain:
        res = target_client.explain(f_target, "test")
        # Identity is stricter than `type(res) == MlflowException` (an exact-type
        # comparison flagged by E721): it proves the mocked object was returned.
        assert res is mock_error
        # A MagicMock set on the class is not a descriptor, so it is not bound to
        # the instance and records only the caller's positional arguments.
        mock_explain.assert_called_once_with(f_target, "test")
121  
122  
def test_explain_with_target_implementation():
    """Without any patching, the plugin's ``explain`` returns the string "1"."""
    client_under_test = deployments.get_deploy_client(f_target)
    explanation = client_under_test.explain(f_target, "test")
    assert explanation == "1"