/ src / retrievers / factory.py
factory.py
 1  from __future__ import annotations
 2  
 3  """Retriever factory for creating retriever implementations."""
 4  
 5  from collections.abc import Callable
 6  from typing import Any, ClassVar
 7  
 8  from .protocol import Retriever
 9  from .similarity_retriever import create_similarity_retriever
10  from .types import RetrieverType
11  
12  
class RetrieverFactory:
    """Registry-backed factory for creating retriever implementations.

    Factories are registered per ``RetrieverType`` with :meth:`register`
    and invoked through :meth:`create`. Each registered factory is called
    with a single ``dict`` holding the keyword arguments given to
    :meth:`create`. New retriever kinds are added by registering another
    factory; this class itself never needs to change.
    """

    # Maps each registered retriever type to the callable that builds it.
    # NOTE: this dict lives on the class object, so subclasses that do not
    # define their own ``_registry`` share (and mutate) this same mapping.
    _registry: ClassVar[dict[RetrieverType, Callable[[dict[str, Any]], Retriever]]] = {}

    @classmethod
    def register(
        cls, retriever_type: RetrieverType
    ) -> Callable[
        [Callable[[dict[str, Any]], Retriever]],
        Callable[[dict[str, Any]], Retriever],
    ]:
        """Return a decorator that registers a retriever factory.

        Parameters
        ----------
        retriever_type
            Type to register the retriever under. Registering the same type
            again replaces the previously registered factory.

        Returns
        -------
        A decorator that stores the factory in the registry and returns it
        unchanged, so it works both with ``@`` syntax and as a plain call.
        """

        def decorator(
            factory_func: Callable[[dict[str, Any]], Retriever],
        ) -> Callable[[dict[str, Any]], Retriever]:
            cls._registry[retriever_type] = factory_func
            return factory_func

        return decorator

    @classmethod
    def create(cls, retriever_type: RetrieverType, **kwargs: Any) -> Retriever:
        """Create a retriever by type.

        Parameters
        ----------
        retriever_type
            Type of the retriever to create.
        **kwargs
            Additional arguments, passed to the registered factory as a
            single ``dict``.

        Returns
        -------
        Retriever instance produced by the registered factory.

        Raises
        ------
        ValueError
            If no factory is registered for ``retriever_type``.
        """
        # Keep the try body to the lookup only, so a KeyError raised inside
        # the factory itself is not misreported as an unknown type.
        try:
            factory_func = cls._registry[retriever_type]
        except KeyError:
            # str() keeps the message buildable even for non-string enum values.
            available = ", ".join(str(t.value) for t in cls._registry)
            raise ValueError(
                f"Unknown retriever: {retriever_type}. "
                f"Available retrievers: {available}"
            ) from None
        return factory_func(kwargs)


# Built-in registration, executed at import time: makes the similarity
# retriever available via RetrieverFactory.create(RetrieverType.SIMILARITY, ...).
RetrieverFactory.register(RetrieverType.SIMILARITY)(create_similarity_retriever)