factory.py
1 from __future__ import annotations 2 3 """Retriever factory for creating retriever implementations.""" 4 5 from collections.abc import Callable 6 from typing import Any, ClassVar 7 8 from .protocol import Retriever 9 from .similarity_retriever import create_similarity_retriever 10 from .types import RetrieverType 11 12 13 class RetrieverFactory: 14 """Factory for creating retriever implementations. Easily extensible.""" 15 16 _registry: ClassVar[dict[RetrieverType, Callable[[dict[str, Any]], Retriever]]] = {} 17 18 @classmethod 19 def register(cls, retriever_type: RetrieverType): 20 """Register a new retriever factory. 21 22 Parameters 23 ---------- 24 retriever_type 25 Type to register the retriever under. 26 """ 27 def decorator(factory_func: Callable[[dict[str, Any]], Retriever]): 28 cls._registry[retriever_type] = factory_func 29 return factory_func 30 return decorator 31 32 @classmethod 33 def create(cls, retriever_type: RetrieverType, **kwargs) -> Retriever: 34 """Create a retriever by type. 35 36 Parameters 37 ---------- 38 retriever_type 39 Type of the retriever to create. 40 **kwargs 41 Additional arguments passed to the retriever factory. 42 43 Returns 44 ------- 45 Retriever instance. 46 """ 47 if retriever_type not in cls._registry: 48 available = ", ".join(t.value for t in cls._registry) 49 raise ValueError( 50 f"Unknown retriever: {retriever_type}. " 51 f"Available retrievers: {available}" 52 ) 53 return cls._registry[retriever_type](kwargs) 54 55 RetrieverFactory.register(RetrieverType.SIMILARITY)(create_similarity_retriever)