/ mlflow / utils / import_hooks / __init__.py
__init__.py
  1  """
  2  NOTE: The contents of this file have been inlined from the wrapt package's source code
  3  https://github.com/GrahamDumpleton/wrapt/blob/1.12.1/src/wrapt/importer.py.
Some modifications have been made in order to:
  5      - avoid duplicate registration of import hooks
  6      - inline functions from dependent wrapt submodules rather than importing them.
  7  
  8  This module implements a post import hook mechanism styled after what is described in PEP-369.
  9  Note that it doesn't cope with modules being reloaded.
 10  It also extends the functionality to support custom hooks for import errors
 11  (as opposed to only successful imports).
 12  """
 13  
import functools
import importlib.resources
import sys
import threading
 17  
 18  string_types = (str,)
 19  
 20  
 21  # from .decorators import synchronized
 22  # NOTE: Instead of using this import (from wrapt's decorator module, see
 23  # https://github.com/GrahamDumpleton/wrapt/blob/68316bea668fd905a4acb21f37f12596d8c30d80/src/wrapt/decorators.py#L430-L456),
 24  # we define a decorator with similar behavior that acquires a lock while calling the decorated
 25  # function
 26  def synchronized(lock):
 27      def decorator(f):
 28          # See e.g. https://www.python.org/dev/peps/pep-0318/#examples
 29          def new_fn(*args, **kwargs):
 30              with lock:
 31                  return f(*args, **kwargs)
 32  
 33          return new_fn
 34  
 35      return decorator
 36  
 37  
 38  # The dictionary registering any post import hooks to be triggered once
 39  # the target module has been imported. Once a module has been imported
 40  # and the hooks fired, the list of hooks recorded against the target
 41  # module will be truncated but the list left in the dictionary. This
 42  # acts as a flag to indicate that the module had already been imported.
 43  
 44  _post_import_hooks = {}
 45  _post_import_hooks_lock = threading.RLock()
 46  
 47  # A dictionary for any import hook error handlers to be triggered when the
 48  # target module import fails.
 49  
 50  _import_error_hooks = {}
 51  _import_error_hooks_lock = threading.RLock()
 52  
 53  _import_hook_finder_init = False
 54  
 55  # Register a new post import hook for the target module name. This
 56  # differs from the PEP-369 implementation in that it also allows the
 57  # hook function to be specified as a string consisting of the name of
 58  # the callback in the form 'module:function'. This will result in a
 59  # proxy callback being registered which will defer loading of the
 60  # specified module containing the callback function until required.
 61  
 62  
 63  def _create_import_hook_from_string(name):
 64      def import_hook(module):
 65          module_name, function = name.split(":")
 66          attrs = function.split(".")
 67          __import__(module_name)
 68          callback = sys.modules[module_name]
 69          for attr in attrs:
 70              callback = getattr(callback, attr)
 71          return callback(module)
 72  
 73      return import_hook
 74  
 75  
 76  def register_generic_import_hook(hook, name, hook_dict, overwrite):
 77      # Create a deferred import hook if hook is a string name rather than
 78      # a callable function.
 79  
 80      if isinstance(hook, string_types):
 81          hook = _create_import_hook_from_string(hook)
 82  
 83      # Automatically install the import hook finder if it has not already
 84      # been installed.
 85  
 86      global _import_hook_finder_init
 87      if not _import_hook_finder_init:
 88          _import_hook_finder_init = True
 89          sys.meta_path.insert(0, ImportHookFinder())
 90  
 91      # Determine if any prior registration of an import hook for
 92      # the target modules has occurred and act appropriately.
 93  
 94      hooks = hook_dict.get(name, None)
 95  
 96      if hooks is None:
 97          # No prior registration of import hooks for the target
 98          # module. We need to check whether the module has already been
 99          # imported. If it has we fire the hook immediately and add an
100          # empty list to the registry to indicate that the module has
101          # already been imported and hooks have fired. Otherwise add
102          # the post import hook to the registry.
103  
104          module = sys.modules.get(name, None)
105  
106          if module is not None:
107              hook_dict[name] = []
108              hook(module)
109  
110          else:
111              hook_dict[name] = [hook]
112  
113      elif hooks == []:
114          # A prior registration of import hooks for the target
115          # module was done and the hooks already fired. Fire the hook
116          # immediately.
117  
118          module = sys.modules[name]
119          hook(module)
120  
121      else:
122          # A prior registration of import hooks for the target
123          # module was done but the module has not yet been imported.
124  
125          def hooks_equal(existing_hook, hook):
126              if hasattr(existing_hook, "__name__") and hasattr(hook, "__name__"):
127                  return existing_hook.__name__ == hook.__name__
128              else:
129                  return False
130  
131          if overwrite:
132              hook_dict[name] = [
133                  existing_hook
134                  for existing_hook in hook_dict[name]
135                  if not hooks_equal(existing_hook, hook)
136              ]
137  
138          hook_dict[name].append(hook)
139  
140  
@synchronized(_import_error_hooks_lock)
def register_import_error_hook(hook, name, overwrite=True):
    """
    Register a hook to run when importing a module fails.

    Args:
        hook: A callable, or a ``"module:attr.path"`` string entrypoint, to invoke when
            importing the module ``name`` raises an error. When fired for an import
            error, it receives the module name. NOTE(review): if the module is already
            imported at registration time, it is instead invoked immediately with the
            module object.
        name: The name of the module to watch for import errors.
        overwrite: How to treat already-registered hooks for the same function /
            entrypoint on this module. If `True`, every preexisting matching hook is
            removed and replaced with a single instance of `hook`.
    """
    register_generic_import_hook(
        hook,
        name,
        hook_dict=_import_error_hooks,
        overwrite=overwrite,
    )
154  
155  
@synchronized(_post_import_hooks_lock)
def register_post_import_hook(hook, name, overwrite=True):
    """
    Register a hook to run once a module has been imported.

    Args:
        hook: A callable, or a ``"module:attr.path"`` string entrypoint, that receives the
            imported module object. If the module is already imported, it runs right away.
        name: The name of the module that triggers the hook.
        overwrite: How to treat already-registered hooks for the same function /
            entrypoint on this module. If `True`, every preexisting matching hook is
            removed and replaced with a single instance of `hook`.
    """
    register_generic_import_hook(
        hook,
        name,
        hook_dict=_post_import_hooks,
        overwrite=overwrite,
    )
168  
169  
@synchronized(_post_import_hooks_lock)
def get_post_import_hooks(name):
    """Return the post import hook registry entry for the module ``name``.

    The result is ``None`` if nothing was ever registered for the module, and an empty
    list once its hooks have fired. Otherwise it is the list of hooks still waiting for
    the module to be imported. The live registry list is returned, not a copy.
    """
    pending = _post_import_hooks.get(name, None)
    return pending
173  
174  
# Support for post import hooks defined as package entry points.


def _create_import_hook_from_entrypoint(entrypoint):
    """Build a hook that lazily resolves and calls the callback an entry point refers to.

    Args:
        entrypoint: An object exposing ``module_name`` (the module to import) and
            ``attrs`` (attribute names leading from that module to the callback).

    Returns:
        A function taking the imported module. When called, it imports
        ``entrypoint.module_name``, follows ``entrypoint.attrs`` with ``getattr`` and
        calls the resulting object with the module.
    """

    def import_hook(module):
        owner_name = entrypoint.module_name
        __import__(owner_name)
        target = sys.modules[owner_name]
        for attr_name in entrypoint.attrs:
            target = getattr(target, attr_name)
        return target(module)

    return import_hook
187  
188  
def discover_post_import_hooks(group):
    """Register the post import hooks advertised as package entry points in ``group``.

    Each entry point's name is the module to watch. Its value (``module:attr``)
    references the hook callable, which is loaded only when the watched module is
    imported.

    Args:
        group: The entry point group to scan. Nothing is registered if no installed
            distribution defines entry points in this group.
    """
    import importlib.metadata  # clint: disable=lazy-import

    def _make_hook(entrypoint):
        def import_hook(module):
            # Load lazily so the module defining the hook is imported only when needed.
            return entrypoint.load()(module)

        # Pending hooks are de-duplicated by ``__name__`` when registered with
        # overwrite=True. Naming each hook after its entry point value keeps distinct
        # entry points for the same module from replacing each other.
        import_hook.__name__ = entrypoint.value
        return import_hook

    available = importlib.metadata.entry_points()
    if hasattr(available, "select"):
        # Python 3.10+ API.
        selected = available.select(group=group)
    else:
        # Older API: a mapping of group name to entry points.
        selected = available.get(group, ())

    for entrypoint in selected:
        register_post_import_hook(_make_hook(entrypoint), entrypoint.name)
197  
198  
# Indicate that a module has been loaded. Any post import hooks which
# were registered against the target module will be invoked. If an
# exception is raised in any of the post import hooks, that will cause
# the import of the target module to fail.


@synchronized(_post_import_hooks_lock)
def notify_module_loaded(module):
    """Fire the post import hooks registered for ``module``, at most once.

    The registry entry is reset to an empty list before any hook runs, marking the
    module as imported. As a result, if a hook raises, the remaining hooks are skipped
    and are not kept for a later import.
    """
    module_name = getattr(module, "__name__", None)
    pending = _post_import_hooks.get(module_name)
    if not pending:
        return
    _post_import_hooks[module_name] = []
    for callback in pending:
        callback(module)
213  
214  
@synchronized(_import_error_hooks_lock)
def notify_module_import_error(module_name):
    """Call every import error hook registered for ``module_name``, passing that name.

    Unlike post import hooks, error hooks are not removed after firing, so they can
    fire again on a later failed import of the same module.
    """
    for handler in _import_error_hooks.get(module_name) or ():
        handler(module_name)
222  
223  
# A custom module import finder (ImportHookFinder) intercepts attempts to
# import modules and watches out for attempts to import target modules of
# interest. For those modules it substitutes the loader below, which invokes
# any registered post import hooks once the module has been loaded.


class _ImportHookChainedLoader:
    """Loader proxy that fires the registered import hooks for the module it loads.

    All real loading work is delegated to the wrapped ``loader``:

    - Legacy (PEP 302) path: ``load_module()`` loads the module through the wrapped
      loader's ``load_module()``, then fires the post import hooks.
    - Spec-based (PEP 451) path: ``create_module()`` and ``exec_module()`` are exposed
      only when the wrapped loader implements them. This lets the import system use them
      instead of falling back to the deprecated ``load_module()``. The post import hooks
      fire after ``exec_module()``.

    On both paths, an ``ImportError`` or ``AttributeError`` raised while loading the
    module (or by a post import hook) fires the module's import error hooks and is then
    re-raised. Any other attribute lookup is forwarded to the wrapped loader.
    """

    def __init__(self, loader):
        self.loader = loader
        # Mirror the wrapped loader's capabilities: the import system probes for
        # create_module()/exec_module() with hasattr() to decide which protocol to use.
        if hasattr(loader, "create_module"):
            self.create_module = self._create_module
        if hasattr(loader, "exec_module"):
            self.exec_module = self._exec_module

    def __getattr__(self, name):
        # Only reached for attributes not found on the proxy itself. ``loader`` is read
        # from __dict__ so a partially initialized instance cannot recurse into itself.
        loader = self.__dict__.get("loader")
        if loader is None:
            raise AttributeError(name)
        return getattr(loader, name)

    def load_module(self, fullname):
        """Load ``fullname`` via the wrapped loader's legacy ``load_module()``."""
        try:
            module = self.loader.load_module(fullname)
            notify_module_loaded(module)
        except (ImportError, AttributeError):
            notify_module_import_error(fullname)
            raise

        return module

    def _create_module(self, spec):
        # Exposed as ``create_module`` when the wrapped loader has one. Failures here
        # (e.g. an extension module that cannot be loaded) count as import errors, just as
        # they did when the whole load went through load_module().
        try:
            return self.loader.create_module(spec)
        except (ImportError, AttributeError):
            notify_module_import_error(spec.name)
            raise

    def _exec_module(self, module):
        # Exposed as ``exec_module`` when the wrapped loader has one.
        fullname = getattr(module, "__name__", None)
        self._restore_loader(module)
        try:
            self.loader.exec_module(module)
            notify_module_loaded(module)
        except (ImportError, AttributeError):
            notify_module_import_error(fullname)
            raise

    def _restore_loader(self, module):
        # The module's ``__loader__`` and ``__spec__.loader`` are initialized from the
        # spec, i.e. to this proxy. Point them back at the wrapped loader, so that code
        # inspecting them during and after execution (source lookup, resource readers,
        # reload) talks to the real loader.
        current = getattr(module, "__loader__", None)
        if current is None or current is self:
            try:
                module.__loader__ = self.loader
            except AttributeError:
                pass

        spec = getattr(module, "__spec__", None)
        if spec is not None and getattr(spec, "loader", None) is self:
            spec.loader = self.loader
243  
244  
class ImportHookFinder:
    """Meta path finder that intercepts imports of modules with registered hooks.

    ``register_generic_import_hook`` inserts an instance at the front of
    ``sys.meta_path``. For a module that has post import or import error hooks
    registered, the finder looks the module up again through the rest of the import
    system. It then wraps the loader it finds in ``_ImportHookChainedLoader``, so that the
    hooks fire when the module is loaded. If the module cannot be found, the import error
    hooks fire right away. All other modules are left to the remaining finders.
    """

    def __init__(self):
        # Names of the modules this finder is currently resolving. The nested lookup
        # below re-enters this finder; this set tells it to step aside on that pass.
        self.in_progress = {}

    def _should_handle(self, fullname):
        # Only modules with registered hooks are of interest, and only on the outer
        # lookup. The nested lookup through the import system must skip this finder,
        # or it would recurse into an infinite loop.
        if fullname not in _post_import_hooks and fullname not in _import_error_hooks:
            return False
        return fullname not in self.in_progress

    @synchronized(_post_import_hooks_lock)
    @synchronized(_import_error_hooks_lock)
    def find_module(self, fullname, path=None):
        """Legacy (PEP 302) finder API: return a hook-firing loader for ``fullname``, or None."""
        if not self._should_handle(fullname):
            return None

        self.in_progress[fullname] = True
        try:
            import importlib.util  # clint: disable=lazy-import

            try:
                # find_spec() returns None for a module that cannot be found, which
                # makes the ``.loader`` access raise AttributeError.
                loader = importlib.util.find_spec(fullname).loader
            except (ImportError, AttributeError):
                # If an ImportError (or AttributeError) is encountered while finding the
                # module, notify the hooks for import errors.
                notify_module_import_error(fullname)
                # importlib.find_loader() has been deprecated since Python 3.4 and is
                # gone from newer releases; without it there is no fallback loader.
                find_loader = getattr(importlib, "find_loader", None)
                loader = find_loader(fullname, path) if find_loader is not None else None

            if loader:
                return _ImportHookChainedLoader(loader)
            return None
        finally:
            del self.in_progress[fullname]

    @synchronized(_post_import_hooks_lock)
    @synchronized(_import_error_hooks_lock)
    def find_spec(self, fullname, path, target=None):
        """Return the spec for ``fullname`` with a hook-firing loader, or None.

        ``path`` and ``target`` are not used. The spec is looked up again with
        ``importlib.util.find_spec``, which consults the other finders (this one steps
        aside while the lookup is in progress).
        """
        if not self._should_handle(fullname):
            return None

        self.in_progress[fullname] = True
        try:
            import importlib.util  # clint: disable=lazy-import

            try:
                spec = importlib.util.find_spec(fullname)
                # Replace the module spec's loader with a wrapped version that executes
                # import hooks when the module is loaded. A spec without a loader (e.g.
                # a namespace package on some Python versions) is returned unwrapped,
                # because wrapping ``None`` would make its import fail. Post import
                # hooks do not fire for such modules.
                if spec is not None and spec.loader is not None:
                    spec.loader = _ImportHookChainedLoader(spec.loader)
            except (ImportError, AttributeError):
                spec = None

            if spec is None:
                # The module could not be located, or looking it up failed: notify the
                # import error hooks and let the import system carry on without us.
                notify_module_import_error(fullname)
            return spec
        finally:
            del self.in_progress[fullname]
331  
332  
# Decorator for marking that a function should be called as a post import
# hook when the target module is imported. If error_handler is True, the
# marked function is registered as an import error hook instead (fired when
# importing the target module fails).
# It is assumed that all error hooks are added during driver start-up, and
# thus added prior to any import calls. If an error hook is added after a
# module has already failed the import, there's no guarantee that the hook
# will fire.


def when_imported(name, error_handler=False):
    """Return a decorator that registers the decorated function as a hook for ``name``.

    Args:
        name: The module the hook is attached to.
        error_handler: If ``True``, register an import error hook; otherwise a post
            import hook.

    Returns:
        A decorator that registers its argument and returns it unchanged.
    """

    def register(hook):
        add_hook = register_import_error_hook if error_handler else register_post_import_hook
        add_hook(hook, name)
        return hook

    return register