/ src / python / txtai / models / registry.py
registry.py
 1  """
 2  Registry module
 3  """
 4  
 5  from transformers import AutoModel, AutoModelForQuestionAnswering, AutoModelForSequenceClassification
 6  from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING
 7  
 8  
 9  class Registry:
10      """
11      Methods to register models and fully support pipelines.
12      """
13  
14      @staticmethod
15      def register(model, config=None):
16          """
17          Registers a model with auto model and tokenizer configuration to fully support pipelines.
18  
19          Args:
20              model: model to register
21              config: config class name
22          """
23  
24          # Default config class to model class if not provided
25          config = config if config else model.__class__
26  
27          # Default model config_class if empty
28          if hasattr(model.__class__, "config_class") and not model.__class__.config_class:
29              model.__class__.config_class = config
30  
31          # Add references for this class to supported AutoModel classes
32          for mapping in [AutoModel, AutoModelForQuestionAnswering, AutoModelForSequenceClassification]:
33              mapping.register(config, model.__class__)
34  
35          # Add references for this class to support pipeline AutoTokenizers
36          if hasattr(model, "config") and type(model.config) not in TOKENIZER_MAPPING:
37              TOKENIZER_MAPPING.register(type(model.config), type(model.config).__name__)