src/evidently/legacy/ui/managers/base.py
base.py
 1  from inspect import Parameter
 2  from inspect import Signature
 3  from typing import Any
 4  from typing import ClassVar
 5  from typing import Dict
 6  from typing import Type
 7  
 8  from litestar.params import Dependency
 9  from typing_extensions import Annotated
10  from typing_inspect import is_classvar
11  
12  
def replace_signature(annotations: Dict[str, Any], return_annotation=..., is_method=False):
    """Decorator to trick Litestar DI into providing arguments needed.

    Overrides the decorated function's ``__signature__`` and ``__annotations__`` so that
    signature introspection sees one positional-or-keyword parameter per entry of
    *annotations*, regardless of the function's real parameters.

    Args:
        annotations: Mapping of parameter name to annotation. It is copied, never mutated,
            so the same mapping (or the returned decorator) can safely be reused.
        return_annotation: Annotation recorded as the return type. The default is the
            literal ``...`` (Ellipsis), which is stored as-is.
            NOTE(review): ``inspect.Signature.empty`` may have been intended — confirm.
        is_method: If true, prepend an unannotated ``self`` parameter to the signature.

    Returns:
        A decorator that updates the function in place and returns the same function.
    """

    def dec(f):
        parameters = [Parameter(n, kind=Parameter.POSITIONAL_OR_KEYWORD, annotation=t) for n, t in annotations.items()]
        if is_method:
            parameters.insert(0, Parameter("self", kind=Parameter.POSITIONAL_OR_KEYWORD))
        f.__signature__ = Signature(parameters=parameters, return_annotation=return_annotation)
        # Build a fresh dict instead of assigning ``annotations`` and then adding "return".
        # Doing that would mutate the caller's mapping, and a later reuse would try to
        # create a parameter named "return", which ``Parameter`` rejects.
        f.__annotations__ = {**annotations, "return": return_annotation}
        return f

    return dec
26  
27  
class ProviderGetter:
    """Descriptor exposing a dependency-provider factory for the owning class.

    Reading ``SomeDependant.provide`` returns a freshly built async factory whose
    introspectable signature lists the class's dependencies (as computed by
    ``_get_manager_deps``), each wrapped as
    ``Annotated[type, Dependency(skip_validation=True)]``. Calling the factory constructs
    the class from the keyword arguments it receives, awaits ``post_provide`` on the new
    object and returns it.
    """

    def __get__(self, instance, owner: Type["BaseDependant"]):
        # ``instance`` is not used: class and instance access both yield a factory that
        # builds ``owner``, and a new factory is created on every access.
        injected = {}
        for dep_name, dep_type in _get_manager_deps(owner).items():
            injected[dep_name] = Annotated[dep_type, Dependency(skip_validation=True)]
        with_signature = replace_signature(injected, None)

        async def provide(**kwargs):
            dependant = owner(**kwargs)
            await dependant.post_provide()
            return dependant

        provide = with_signature(provide)
        # Name the factory after its owning class rather than the generic "provide".
        provide.__name__ = f"{owner.__name__}.provide"
        return provide
40  
41  
class BaseDependant:
    """Base class that allows to define dependencies as class fields.

    Every annotation declared on a subclass (or on a ``BaseDependant`` ancestor) whose
    name does not start with ``_`` and which is not a ``ClassVar`` is a required
    dependency. It must be passed to ``__init__`` as a keyword argument and is then stored
    as an instance attribute of the same name. ``provide`` exposes an async factory whose
    signature lists those dependencies.
    """

    # Per-class cache of dependency name -> annotation, populated lazily by
    # ``_get_manager_deps``. ClassVar, so it is never treated as a dependency itself.
    __dependencies__: ClassVar

    provide = ProviderGetter()

    def __init__(self, **dependencies):
        self._validate(dependencies)
        for attr_name, attr_value in dependencies.items():
            setattr(self, attr_name, attr_value)

    @classmethod
    def _validate(cls, dependencies: Dict[str, Any]):
        """Raise ``ValueError`` unless *dependencies* has exactly the declared dependency names.

        Unexpected names are reported before missing ones.
        """
        provided = set(dependencies.keys())
        expected = set(_get_manager_deps(cls))
        if provided - expected:
            raise ValueError(f"Extra dependencies {provided - expected}")
        if expected - provided:
            raise ValueError(f"Missing dependencies {expected - provided}")

    async def post_provide(self):
        """Hook awaited by the ``provide`` factory right after construction; no-op here."""
67  
68  
def _get_manager_deps(dependant_type: Type[BaseDependant]) -> Dict[str, Type]:
    """Return the dependency fields of *dependant_type*, computing them once per class.

    A dependency is an annotation declared on *dependant_type* or on any of its
    ``BaseDependant`` ancestors whose name does not start with ``_`` and which is not a
    ``ClassVar``. The result maps field name to annotation and is cached in the
    ``__dependencies__`` attribute of *dependant_type* itself (not of its bases).
    """
    # Look in the class's own namespace rather than using ``hasattr``. ``hasattr`` also
    # finds a ``__dependencies__`` cache inherited from an already-resolved parent, which
    # would make a subclass silently reuse its parent's (smaller) dependency set.
    if "__dependencies__" not in vars(dependant_type):
        deps: Dict[str, Type] = {}
        # Walk the MRO base-first so that a field re-annotated on a subclass overrides
        # the annotation inherited from its base.
        for bt in reversed(dependant_type.mro()):
            if not issubclass(bt, BaseDependant):
                continue
            for name, cls in getattr(bt, "__annotations__", {}).items():
                if not is_classvar(cls) and not name.startswith("_"):
                    deps[name] = cls
        dependant_type.__dependencies__ = deps
    return dependant_type.__dependencies__
79  
80  
class BaseManager(BaseDependant):
    """Common base for manager classes; adds nothing beyond ``BaseDependant``."""