factory.py
"""Factory for creating embedding models."""

from __future__ import annotations

from collections.abc import Callable
from typing import Any, ClassVar

from .gemini import create_gemini_embedding
from .huggingface import create_huggingface_embedding
from .openai import create_openai_embedding
from .protocol import Embeddings
from .types import EmbeddingModelType


class EmbeddingModelFactory:
    """Registry-backed factory for creating embedding models. Easily extensible.

    Each :class:`EmbeddingModelType` maps to a factory callable. New backends
    are added with :meth:`register`; models are built with :meth:`create`.
    """

    # Model type -> factory. A factory receives the keyword arguments that
    # were passed to ``create`` bundled into a single dict, and returns an
    # ``Embeddings`` instance.
    _registry: ClassVar[
        dict[EmbeddingModelType, Callable[[dict[str, Any]], Embeddings]]
    ] = {}

    @classmethod
    def register(
        cls, model_type: EmbeddingModelType
    ) -> Callable[
        [Callable[[dict[str, Any]], Embeddings]],
        Callable[[dict[str, Any]], Embeddings],
    ]:
        """Register a new embedding model factory.

        Returns a decorator; applying it to a factory function stores that
        function under ``model_type`` and returns the function unchanged.
        Registering the same ``model_type`` twice silently replaces the
        earlier factory.

        Parameters
        ----------
        model_type
            Type to register the model under.

        Returns
        -------
        Decorator that registers and returns the decorated factory.
        """

        def decorator(
            factory_func: Callable[[dict[str, Any]], Embeddings],
        ) -> Callable[[dict[str, Any]], Embeddings]:
            cls._registry[model_type] = factory_func
            return factory_func

        return decorator

    @classmethod
    def create(cls, model_type: EmbeddingModelType, **kwargs: Any) -> Embeddings:
        """Create an embedding model by type.

        Parameters
        ----------
        model_type
            Type of the model to create.
        **kwargs
            Additional arguments, passed to the registered factory as a
            single dict (not unpacked as keyword arguments).

        Returns
        -------
        Embeddings instance produced by the registered factory.

        Raises
        ------
        ValueError
            If no factory is registered for ``model_type``. The message lists
            the ``.value`` of every registered type.
        """
        try:
            factory_func = cls._registry[model_type]
        except KeyError:
            available = ", ".join(t.value for t in cls._registry)
            # ``from None``: callers only need the ValueError, not the
            # internal KeyError from the registry lookup.
            raise ValueError(
                f"Unknown model: {model_type}. "
                f"Available models: {available}"
            ) from None
        return factory_func(kwargs)


# Built-in backends are registered at import time so they are available as
# soon as this module is loaded.
EmbeddingModelFactory.register(EmbeddingModelType.HUGGINGFACE)(create_huggingface_embedding)
EmbeddingModelFactory.register(EmbeddingModelType.OPENAI)(create_openai_embedding)
EmbeddingModelFactory.register(EmbeddingModelType.GEMINI)(create_gemini_embedding)