registry.py
"""
Registry module
"""

from transformers import AutoModel, AutoModelForQuestionAnswering, AutoModelForSequenceClassification
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING


class Registry:
    """
    Helpers that register custom models with the transformers Auto* classes so pipelines can resolve them.
    """

    @staticmethod
    def register(model, config=None):
        """
        Registers a model's class with the AutoModel factories and the tokenizer mapping.

        The model's class is registered with AutoModel, AutoModelForQuestionAnswering and
        AutoModelForSequenceClassification. If the model has a `config` attribute whose type is
        not yet in TOKENIZER_MAPPING, that type is added too.

        Args:
            model: model instance whose class is registered
            config: config class used as the registration key; a falsy value (including the
                    default None) falls back to the model's own class
        """

        model_cls = model.__class__

        # Fall back to the model class itself as the registration key
        if not config:
            config = model_cls

        # Fill in config_class only when the class defines it but leaves it empty.
        # A missing attribute yields True here, so nothing is assigned in that case.
        if not getattr(model_cls, "config_class", True):
            model_cls.config_class = config

        # Map the config key to this model class in each supported Auto* factory
        auto_classes = (AutoModel, AutoModelForQuestionAnswering, AutoModelForSequenceClassification)
        for auto_cls in auto_classes:
            auto_cls.register(config, model_cls)

        # Add a tokenizer mapping entry for the model's config type if none exists yet.
        # NOTE(review): the value registered is the type's name string, not tokenizer
        # classes. Presumably this only needs to satisfy membership checks in pipelines;
        # confirm before relying on the mapped value.
        if hasattr(model, "config"):
            config_type = type(model.config)
            if config_type not in TOKENIZER_MAPPING:
                TOKENIZER_MAPPING.register(config_type, config_type.__name__)