# base.py
from inspect import Parameter
from inspect import Signature
from typing import Any
from typing import ClassVar
from typing import Dict
from typing import Type

from litestar.params import Dependency
from typing_extensions import Annotated
from typing_inspect import is_classvar


def replace_signature(annotations: Dict[str, Any], return_annotation=..., is_method=False):
    """Decorator to trick Litestar DI into providing arguments needed.

    The decorated function gets a synthetic ``__signature__`` and
    ``__annotations__`` built from ``annotations``, regardless of the
    parameters it actually declares.

    Args:
        annotations: Mapping of parameter name to annotation. Each entry
            becomes one ``POSITIONAL_OR_KEYWORD`` parameter, in mapping order.
            The mapping itself is left untouched.
        return_annotation: Value stored as the signature's return annotation
            and under the ``"return"`` key of ``__annotations__``.
        is_method: When true, an un-annotated ``self`` parameter is placed
            in front of the generated parameters.

    Returns:
        A decorator that patches the function in place and returns it.
    """

    def dec(f):
        parameters = [
            Parameter(name, kind=Parameter.POSITIONAL_OR_KEYWORD, annotation=annotation)
            for name, annotation in annotations.items()
        ]
        if is_method:
            parameters.insert(0, Parameter("self", kind=Parameter.POSITIONAL_OR_KEYWORD))
        f.__signature__ = Signature(parameters=parameters, return_annotation=return_annotation)
        # Build a fresh dict: storing the caller's mapping and then adding
        # "return" to it would mutate the argument and share one dict between
        # every function decorated by this decorator instance.
        f.__annotations__ = {**annotations, "return": return_annotation}
        return f

    return dec


class ProviderGetter:
    """Descriptor that builds an async ``provide`` callable for its owner class."""

    def __get__(self, instance, owner: Type["BaseDependant"]):
        """Return a new provider coroutine function for ``owner``.

        The provider's advertised signature has one parameter per dependency
        of ``owner``, each annotated with
        ``Annotated[<type>, Dependency(skip_validation=True)]``. When called,
        it instantiates ``owner`` with the received keyword arguments, awaits
        ``post_provide()`` on the new object and returns it.

        ``instance`` is not used; a new function is created on every access.
        """
        deps = _get_manager_deps(owner)
        injected = {
            name: Annotated[cls, Dependency(skip_validation=True)]
            for name, cls in deps.items()
        }

        @replace_signature(injected, None)
        async def provide(**kwargs):
            obj = owner(**kwargs)
            await obj.post_provide()
            return obj

        provide.__name__ = f"{owner.__name__}.provide"
        return provide


class BaseDependant:
    """Base class that allows to define dependencies as class fields.

    Every public, non-``ClassVar`` annotation declared on the class or on its
    ``BaseDependant`` ancestors is treated as a dependency that has to be
    passed to ``__init__`` as a keyword argument.
    """

    # Per-class cache of the resolved dependency map, filled lazily by
    # ``_get_manager_deps``.
    __dependencies__: ClassVar

    provide = ProviderGetter()

    def __init__(self, **dependencies):
        """Validate ``dependencies`` and store each one as an instance attribute.

        Raises:
            ValueError: If the given names do not exactly match the declared
                dependencies.
        """
        self._validate(dependencies)
        for k, v in dependencies.items():
            setattr(self, k, v)

    @classmethod
    def _validate(cls, dependencies: Dict[str, Any]):
        """Check that ``dependencies`` has exactly the declared dependency names.

        Raises:
            ValueError: If unexpected names are present, or otherwise if
                declared names are missing. Extra names are reported first
                when both problems occur.
        """
        deps = set(dependencies.keys())
        required = set(_get_manager_deps(cls))
        if deps == required:
            return
        if deps - required:
            raise ValueError(f"Extra dependencies {deps - required}")
        if required - deps:
            raise ValueError(f"Missing dependencies {required - deps}")

    async def post_provide(self):
        """Hook awaited by the provider right after construction; no-op here."""
        pass


def _get_manager_deps(dependant_type: Type[BaseDependant]) -> Dict[str, Type]:
    """Return the dependency map (field name -> annotation) of a dependant class.

    The map is collected from the ``__annotations__`` of every
    ``BaseDependant`` subclass in the MRO, skipping ``ClassVar`` annotations
    and names starting with an underscore. The result is cached on the class
    as ``__dependencies__``.
    """
    # Look only at the class's own namespace. ``hasattr`` would also find a
    # cache inherited from an already-resolved base class, and a subclass
    # would then silently reuse its parent's (incomplete) dependency map.
    if "__dependencies__" not in vars(dependant_type):
        collected: Dict[str, Type] = {}
        for bt in dependant_type.mro():
            if not issubclass(bt, BaseDependant):
                continue
            for name, cls in getattr(bt, "__annotations__", {}).items():
                if name.startswith("_") or is_classvar(cls):
                    continue
                # The MRO runs from most-derived to base, so the first
                # annotation seen for a name is the overriding one; keep it.
                collected.setdefault(name, cls)
        dependant_type.__dependencies__ = collected
    return dependant_type.__dependencies__


class BaseManager(BaseDependant):
    """Dependant base class for managers; adds nothing to ``BaseDependant``."""

    pass