models.py
1 """ 2 Models module 3 """ 4 5 import os 6 7 import torch 8 9 from transformers import ( 10 AutoConfig, 11 AutoModel, 12 AutoModelForQuestionAnswering, 13 AutoModelForSeq2SeqLM, 14 AutoModelForSequenceClassification, 15 AutoTokenizer, 16 ) 17 from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES 18 19 from .onnx import OnnxModel 20 21 22 class Models: 23 """ 24 Utility methods for working with machine learning models 25 """ 26 27 @staticmethod 28 def checklength(config, tokenizer): 29 """ 30 Checks the length for a Hugging Face Transformers tokenizer using a Hugging Face Transformers config. Copies the 31 max_position_embeddings parameter if the tokenizer has no max_length set. This helps with backwards compatibility 32 with older tokenizers. 33 34 Args: 35 config: transformers config 36 tokenizer: transformers tokenizer 37 """ 38 39 # Unpack nested config, handles passing model directly 40 if hasattr(config, "config"): 41 config = config.config 42 43 if ( 44 hasattr(config, "max_position_embeddings") 45 and tokenizer 46 and hasattr(tokenizer, "model_max_length") 47 and tokenizer.model_max_length == int(1e30) 48 ): 49 tokenizer.model_max_length = config.max_position_embeddings 50 51 @staticmethod 52 def maxlength(config, tokenizer): 53 """ 54 Gets the best max length to use for generate calls. This method will return config.max_length if it's set. Otherwise, it will return 55 tokenizer.model_max_length. 56 57 Args: 58 config: transformers config 59 tokenizer: transformers tokenizer 60 """ 61 62 # Unpack nested config, handles passing model directly 63 if hasattr(config, "config"): 64 config = config.config 65 66 # Get non-defaulted fields 67 keys = config.to_diff_dict() 68 69 # Use config.max_length if not set to default value, else use tokenizer.model_max_length if available 70 return config.max_length if "max_length" in keys or not hasattr(tokenizer, "model_max_length") else tokenizer.model_max_length 71 72 @staticmethod 73 def deviceid(gpu): 74 """ 75 Translates input gpu argument into a device id. 76 77 Args: 78 gpu: True/False if GPU should be enabled, also supports a device id/string/instance 79 80 Returns: 81 device id 82 """ 83 84 # Return if this is already a torch device 85 # pylint: disable=E1101 86 if isinstance(gpu, torch.device): 87 return gpu 88 89 # Always return -1 if gpu is None or an accelerator device is unavailable 90 if gpu is None or not Models.hasaccelerator(): 91 return -1 92 93 # Default to device 0 if gpu is True and not otherwise specified 94 if isinstance(gpu, bool): 95 return 0 if gpu else -1 96 97 # Return gpu as device id if gpu flag is an int 98 return int(gpu) 99 100 @staticmethod 101 def device(deviceid): 102 """ 103 Gets a tensor device. 104 105 Args: 106 deviceid: device id 107 108 Returns: 109 tensor device 110 """ 111 112 # Torch device 113 # pylint: disable=E1101 114 return deviceid if isinstance(deviceid, torch.device) else torch.device(Models.reference(deviceid)) 115 116 @staticmethod 117 def reference(deviceid): 118 """ 119 Gets a tensor device reference. 120 121 Args: 122 deviceid: device id 123 124 Returns: 125 device reference 126 """ 127 128 return ( 129 deviceid 130 if isinstance(deviceid, str) 131 else ( 132 "cpu" 133 if deviceid < 0 134 else f"cuda:{deviceid}" if torch.cuda.is_available() else "mps" if Models.hasmpsdevice() else Models.finddevice() 135 ) 136 ) 137 138 @staticmethod 139 def acceleratorcount(): 140 """ 141 Gets the number of accelerator devices available. 

    @staticmethod
    def hasaccelerator():
        """
        Checks if there is an accelerator device available.

        Returns:
            True if an accelerator device is available, False otherwise
        """

        return torch.cuda.is_available() or Models.hasmpsdevice() or bool(Models.finddevice())

    @staticmethod
    def hasmpsdevice():
        """
        Checks if there is an MPS device available.

        Returns:
            True if an MPS device is available, False otherwise
        """

        return os.environ.get("PYTORCH_MPS_DISABLE") != "1" and torch.backends.mps.is_available()

    @staticmethod
    def finddevice():
        """
        Attempts to find an alternative accelerator device.

        Returns:
            name of first alternative accelerator available or None if not found
        """

        return next((device for device in ["xpu"] if hasattr(torch, device) and getattr(torch, device).is_available()), None)

    @staticmethod
    def load(path, config=None, task="default", modelargs=None):
        """
        Loads a machine learning model. Handles multiple model frameworks (ONNX, Transformers).

        Args:
            path: path to model
            config: path to model configuration
            task: task name used to lookup model type
            modelargs: optional keyword arguments passed to the model's from_pretrained method

        Returns:
            machine learning model
        """

        # Detect ONNX models
        if isinstance(path, bytes) or (isinstance(path, str) and os.path.isfile(path)):
            return OnnxModel(path, config)

        # Return path, if path isn't a string
        if not isinstance(path, str):
            return path

        # Transformer models
        models = {
            "default": AutoModel.from_pretrained,
            "question-answering": AutoModelForQuestionAnswering.from_pretrained,
            "summarization": AutoModelForSeq2SeqLM.from_pretrained,
            "text-classification": AutoModelForSequenceClassification.from_pretrained,
            "zero-shot-classification": AutoModelForSequenceClassification.from_pretrained,
        }

        # Pass modelargs as keyword arguments
        modelargs = modelargs if modelargs else {}

        # Load model for supported tasks. Return path for unsupported tasks.
        return models[task](path, **modelargs) if task in models else path

    @staticmethod
    def tokenizer(path, **kwargs):
        """
        Loads a tokenizer from path.

        Args:
            path: path to tokenizer
            kwargs: optional additional keyword arguments

        Returns:
            tokenizer
        """

        return AutoTokenizer.from_pretrained(path, **kwargs) if isinstance(path, str) else path

    @staticmethod
    def task(path, **kwargs):
        """
        Attempts to detect the model task from path.

        Args:
            path: path to model
            kwargs: optional additional keyword arguments

        Returns:
            inferred model task
        """

        # Get model configuration
        config = None
        if isinstance(path, (list, tuple)) and hasattr(path[0], "config"):
            config = path[0].config
        elif isinstance(path, str):
            config = AutoConfig.from_pretrained(path, **kwargs)

        # Attempt to resolve task using configuration
        task = None
        if config:
            architecture = config.architectures[0] if config.architectures else None
            if architecture:
                if architecture in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values():
                    task = "vision"
                elif any(x for x in ["LMHead", "CausalLM"] if x in architecture):
                    task = "language-generation"
                elif "QuestionAnswering" in architecture:
                    task = "question-answering"
                elif "ConditionalGeneration" in architecture:
                    task = "sequence-sequence"

        return task
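

# Minimal usage sketch, not part of the module's API: shows how these utility methods
# fit together. The model path below ("bert-base-uncased") is only an illustrative
# assumption; substitute any local or Hugging Face Hub model path. Because of the
# relative import above, run this within the package (e.g. python -m <package>.models).
if __name__ == "__main__":
    # Hypothetical example model path (assumption for illustration only)
    example = "bert-base-uncased"

    # Resolve a device id from a boolean GPU flag and build the torch device
    deviceid = Models.deviceid(torch.cuda.is_available())
    device = Models.device(deviceid)

    # Load model and tokenizer, then backfill tokenizer max length from the model config
    model = Models.load(example, task="default")
    tokenizer = Models.tokenizer(example)
    Models.checklength(model, tokenizer)

    # Print resolved device, inferred task (None if not detected) and max generate length
    print(device, Models.task(example), Models.maxlength(model, tokenizer))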