plugin_manager.py
import abc
import importlib.metadata
import inspect

import importlib_metadata

from mlflow.deployments.base import BaseDeploymentClient
from mlflow.deployments.utils import parse_target_uri
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR, RESOURCE_DOES_NOT_EXIST
from mlflow.utils.annotations import developer_stable
from mlflow.utils.plugins import get_entry_points

# TODO: refactor to have a common base class for all the plugin implementation in MLflow
# mlflow/tracking/context/registry.py
# mlflow/tracking/registry
# mlflow/store/artifact/artifact_repository_registry.py


@developer_stable
class PluginManager(abc.ABC):
    """
    Abstract class defining an entrypoint-based plugin registration.

    This class allows the registration of a function or class to provide an implementation
    for a given key/name. Implementations declared through the entrypoints can be automatically
    registered through the `register_entrypoints` method.
    """

    def __init__(self, group_name):
        """
        Args:
            group_name: Name of the entry-point group plugins are discovered from
                (for example ``"mlflow.deployments"``).
        """
        self._registry = {}
        self.group_name = group_name
        # False until `register_entrypoints` has completed at least once.
        self._has_registered = False

    @abc.abstractmethod
    def __getitem__(self, item):
        # Letting the child class create this function so that the child
        # can raise custom exceptions if it needs to
        pass

    @property
    def registry(self):
        """
        Registry stores the registered plugin as a key value pair where key is the
        name of the plugin and value is the plugin object (or a not-yet-loaded
        entry point that resolves to it).
        """
        return self._registry

    @property
    def has_registered(self):
        """
        Returns bool representing whether the "register_entrypoints" has run or not. This
        doesn't return True if `register` method is called outside of `register_entrypoints`
        to register plugins
        """
        return self._has_registered

    def register(self, target_name, plugin_module):
        """Register a plugin implementation under the given name.

        The plugin is stored as an entry point in this manager's group and is only
        imported when it is first looked up.

        Args:
            target_name: The name to register the plugin under. For deployment plugins
                this is the target name used by `get_deploy_client()` to retrieve a
                deployment client from the plugin store.
            plugin_module: The entry-point value (import path) of the module that
                implements the plugin interface.
        """
        self.registry[target_name] = importlib.metadata.EntryPoint(
            target_name, plugin_module, self.group_name
        )

    def register_entrypoints(self):
        """
        Runs through all the packages that have the `group_name` defined as the entrypoint
        and registers them into the registry.
        """
        for entrypoint in get_entry_points(self.group_name):
            self.registry[entrypoint.name] = entrypoint
        self._has_registered = True


@developer_stable
class DeploymentPlugins(PluginManager):
    """Plugin manager for the ``mlflow.deployments`` entry-point group."""

    def __init__(self):
        super().__init__("mlflow.deployments")
        self.register_entrypoints()

    def __getitem__(self, item):
        """Override __getitem__ so that we can directly look up plugins via dict-like syntax.

        Args:
            item: A deployment target URI; it is reduced to a target name with
                `parse_target_uri` before the registry lookup.

        Returns:
            The loaded and validated plugin object for the target.

        Raises:
            MlflowException: If no plugin is registered for the target
                (RESOURCE_DOES_NOT_EXIST), or the plugin does not expose the
                required deployment interfaces (INTERNAL_ERROR).
            RuntimeError: If loading the plugin's entry point fails.
        """
        try:
            target_name = parse_target_uri(item)
            plugin_like = self.registry[target_name]
        except KeyError:
            msg = (
                f'No plugin found for managing model deployments to "{item}". '
                f'In order to deploy models to "{item}", find and install an appropriate '
                "plugin from "
                "https://mlflow.org/docs/latest/plugins.html#community-plugins using "
                "your package manager (pip, conda etc)."
            )
            raise MlflowException(msg, error_code=RESOURCE_DOES_NOT_EXIST)

        if isinstance(plugin_like, (importlib_metadata.EntryPoint, importlib.metadata.EntryPoint)):
            try:
                plugin_obj = plugin_like.load()
            except (AttributeError, ImportError) as exc:
                raise RuntimeError(f'Failed to load the plugin "{item}": {exc}') from exc
        else:
            plugin_obj = plugin_like

        self._validate_plugin(item, plugin_obj)
        # Cache under the parsed target name -- the same key used for the lookup
        # above -- so later lookups reuse the loaded object instead of re-loading
        # the entry point. (Caching under the raw URI was never hit.)
        self.registry[target_name] = plugin_obj
        return plugin_obj

    @staticmethod
    def _validate_plugin(item, plugin_obj):
        """Check that `plugin_obj` exposes the deployment plugin interface.

        The plugin must provide `target_help` and `run_local` members and exactly
        one subclass of `BaseDeploymentClient` (other than the base class itself).

        Args:
            item: The target URI being looked up; used only in error messages.
            plugin_obj: The loaded plugin object to inspect.

        Raises:
            MlflowException: If any requirement above is not met.
        """
        expected = {"target_help", "run_local"}
        deployment_classes = []
        for name, obj in inspect.getmembers(plugin_obj):
            if name in expected:
                expected.remove(name)
            elif (
                inspect.isclass(obj)
                and issubclass(obj, BaseDeploymentClient)
                and obj is not BaseDeploymentClient
            ):
                deployment_classes.append(name)
        if expected:
            raise MlflowException(
                f"Plugin registered for the target {item} does not have all "
                "the required interfaces. Raise an issue with the "
                "plugin developers.\n"
                f"Missing interfaces: {expected}",
                error_code=INTERNAL_ERROR,
            )
        if len(deployment_classes) > 1:
            raise MlflowException(
                f"Plugin registered for the target {item} has more than one "
                "child class of BaseDeploymentClient. Raise an issue with"
                " the plugin developers. "
                f"Classes found are {deployment_classes}",
                error_code=INTERNAL_ERROR,
            )
        elif not deployment_classes:
            raise MlflowException(
                f"Plugin registered for the target {item} has no child class"
                " of BaseDeploymentClient. Raise an issue with the "
                "plugin developers",
                error_code=INTERNAL_ERROR,
            )