/ src / embeddings / factory.py
factory.py
 1  from __future__ import annotations
 2  
 3  """Factory for creating embedding models."""
 4  
 5  from collections.abc import Callable
 6  from typing import Any, ClassVar
 7  
 8  from .gemini import create_gemini_embedding
 9  from .huggingface import create_huggingface_embedding
10  from .openai import create_openai_embedding
11  from .protocol import Embeddings
12  from .types import EmbeddingModelType
13  
14  
class EmbeddingModelFactory:
    """Registry-backed factory that builds embedding models by type.

    Providers are plugged in by storing a factory callable under an
    ``EmbeddingModelType`` with :meth:`register`; :meth:`create` later looks
    that callable up and invokes it.
    """

    # Maps each registered model type to the callable that builds it. The
    # callable receives the keyword arguments given to ``create`` as one dict.
    # The dict is mutated in place (never rebound), so it is shared by the
    # class and any subclass.
    _registry: ClassVar[dict[EmbeddingModelType, Callable[[dict[str, Any]], Embeddings]]] = {}

    @classmethod
    def register(cls, model_type: EmbeddingModelType):
        """Return a decorator that records a factory callable for *model_type*.

        Parameters
        ----------
        model_type
            Key under which the decorated callable is stored. Registering the
            same key again replaces the previously stored callable.

        Returns
        -------
        A decorator that stores its argument in the registry and hands the
        argument back unchanged, so it works with ``@`` syntax or when
        called directly.
        """

        def _store(factory_func: Callable[[dict[str, Any]], Embeddings]):
            cls._registry[model_type] = factory_func
            return factory_func

        return _store

    @classmethod
    def create(cls, model_type: EmbeddingModelType, **kwargs) -> Embeddings:
        """Build an embedding model of the requested type.

        Parameters
        ----------
        model_type
            Registered type to instantiate.
        **kwargs
            Collected into a dict and passed as the single positional
            argument of the registered factory callable.

        Returns
        -------
        The object returned by the registered factory (typed as
        ``Embeddings``).

        Raises
        ------
        ValueError
            If nothing is registered under *model_type*. The message lists
            the ``value`` of every registered type, in registration order.
        """
        registry = cls._registry
        if model_type in registry:
            return registry[model_type](kwargs)

        # Only reached for unregistered types: report what *is* available.
        known_values = ", ".join(known.value for known in registry)
        raise ValueError(
            f"Unknown model: {model_type}. Available models: {known_values}"
        )
56  
# Register the built-in providers at import time. Order matters: the
# "Available models" list in ``EmbeddingModelFactory.create`` follows the
# registry's insertion order.
for _builtin_type, _builtin_factory in (
    (EmbeddingModelType.HUGGINGFACE, create_huggingface_embedding),
    (EmbeddingModelType.OPENAI, create_openai_embedding),
    (EmbeddingModelType.GEMINI, create_gemini_embedding),
):
    EmbeddingModelFactory.register(_builtin_type)(_builtin_factory)
# Drop the loop variables so the module namespace is unchanged.
del _builtin_type, _builtin_factory