"""
Models module
"""

import os

import torch

from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForQuestionAnswering,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES

from .onnx import OnnxModel


class Models:
    """
    Utility methods for working with machine learning models
    """

    @staticmethod
    def checklength(config, tokenizer):
        """
        Checks the length for a Hugging Face Transformers tokenizer using a Hugging Face Transformers config. Copies the
        max_position_embeddings parameter when the tokenizer has no model_max_length set. This helps with backwards compatibility
        with older tokenizers.

        Args:
            config: transformers config
            tokenizer: transformers tokenizer
        """

        # Unpack nested config, handles passing model directly
        if hasattr(config, "config"):
            config = config.config

        if (
            hasattr(config, "max_position_embeddings")
            and tokenizer
            and hasattr(tokenizer, "model_max_length")
            and tokenizer.model_max_length == int(1e30)
        ):
            tokenizer.model_max_length = config.max_position_embeddings

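    # Minimal usage sketch for checklength (illustrative model id; whether the copy
    # fires depends on the tokenizer - it only applies when model_max_length is the
    # transformers sentinel int(1e30)):
    #
    #   config = AutoConfig.from_pretrained("bert-base-uncased")
    #   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    #   Models.checklength(config, tokenizer)
    #   # For an older tokenizer, model_max_length now equals config.max_position_embeddings
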
    @staticmethod
    def maxlength(config, tokenizer):
        """
        Gets the best max length to use for generate calls. This method returns config.max_length if it's explicitly set.
        Otherwise, it returns tokenizer.model_max_length.

        Args:
            config: transformers config
            tokenizer: transformers tokenizer

        Returns:
            best max length
        """

        # Unpack nested config, handles passing model directly
        if hasattr(config, "config"):
            config = config.config

        # Get fields that differ from the transformers defaults
        keys = config.to_diff_dict()

        # Use config.max_length when it's explicitly set, else fall back to tokenizer.model_max_length when available
        return config.max_length if "max_length" in keys or not hasattr(tokenizer, "model_max_length") else tokenizer.model_max_length

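    # Sketch of the diff-dict check above (hypothetical values): to_diff_dict() only
    # includes fields that differ from the transformers defaults, so "max_length" is
    # present only when it was explicitly configured.
    #
    #   config.max_length = 512                # explicitly set
    #   Models.maxlength(config, tokenizer)    # -> 512
    #   # With max_length left at its default, tokenizer.model_max_length is returned
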
    @staticmethod
    def deviceid(gpu):
        """
        Translates input gpu argument into a device id.

        Args:
            gpu: True/False if GPU should be enabled, also supports a device id/string/instance

        Returns:
            device id
        """

        # Return if this is already a torch device
        # pylint: disable=E1101
        if isinstance(gpu, torch.device):
            return gpu

        # Always return -1 if gpu is None or an accelerator device is unavailable
        if gpu is None or not Models.hasaccelerator():
            return -1

        # Default to device 0 if gpu is True and not otherwise specified
        if isinstance(gpu, bool):
            return 0 if gpu else -1

        # Return gpu as device id if gpu flag is an int
        return int(gpu)

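    # Example translations (assumes a machine where hasaccelerator() is True):
    #
    #   Models.deviceid(True)                     # -> 0
    #   Models.deviceid(False)                    # -> -1
    #   Models.deviceid(None)                     # -> -1
    #   Models.deviceid(1)                        # -> 1
    #   Models.deviceid(torch.device("cuda:0"))   # -> torch.device passed through
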
    @staticmethod
    def device(deviceid):
        """
        Gets a tensor device.

        Args:
            deviceid: device id

        Returns:
            tensor device
        """

        # Torch device
        # pylint: disable=E1101
        return deviceid if isinstance(deviceid, torch.device) else torch.device(Models.reference(deviceid))

    @staticmethod
    def reference(deviceid):
        """
        Gets a tensor device reference.

        Args:
            deviceid: device id

        Returns:
            device reference
        """

        return (
            deviceid
            if isinstance(deviceid, str)
            else (
                "cpu"
                if deviceid < 0
                else f"cuda:{deviceid}" if torch.cuda.is_available() else "mps" if Models.hasmpsdevice() else Models.finddevice()
            )
        )

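    # Resolution sketch for reference(): strings pass through, negative ids map to
    # "cpu" and non-negative ids map to the first available backend, checked in
    # order: CUDA, then MPS, then alternative devices from finddevice().
    #
    #   Models.reference("cuda:1")   # -> "cuda:1"
    #   Models.reference(-1)         # -> "cpu"
    #   Models.reference(0)          # -> "cuda:0" with CUDA, "mps" on Apple Silicon
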
    @staticmethod
    def acceleratorcount():
        """
        Gets the number of accelerator devices available.

        Returns:
            number of accelerators available
        """

        return max(torch.cuda.device_count(), int(Models.hasaccelerator()))

    @staticmethod
    def hasaccelerator():
        """
        Checks if there is an accelerator device available.

        Returns:
            True if an accelerator device is available, False otherwise
        """

        return torch.cuda.is_available() or Models.hasmpsdevice() or bool(Models.finddevice())

    @staticmethod
    def hasmpsdevice():
        """
        Checks if there is an MPS device available.

        Returns:
            True if an MPS device is available, False otherwise
        """

        return os.environ.get("PYTORCH_MPS_DISABLE") != "1" and torch.backends.mps.is_available()

    @staticmethod
    def finddevice():
        """
        Attempts to find an alternative accelerator device.

        Returns:
            name of the first alternative accelerator available or None if not found
        """

        return next((device for device in ["xpu"] if hasattr(torch, device) and getattr(torch, device).is_available()), None)

    @staticmethod
    def load(path, config=None, task="default", modelargs=None):
        """
        Loads a machine learning model. Handles multiple model frameworks (ONNX, Transformers).

        Args:
            path: path to model
            config: path to model configuration
            task: task name used to lookup model type
            modelargs: additional model keyword arguments

        Returns:
            machine learning model
        """

        # Detect ONNX models
        if isinstance(path, bytes) or (isinstance(path, str) and os.path.isfile(path)):
            return OnnxModel(path, config)

        # Return path, if path isn't a string
        if not isinstance(path, str):
            return path

        # Transformer models
        models = {
            "default": AutoModel.from_pretrained,
            "question-answering": AutoModelForQuestionAnswering.from_pretrained,
            "summarization": AutoModelForSeq2SeqLM.from_pretrained,
            "text-classification": AutoModelForSequenceClassification.from_pretrained,
            "zero-shot-classification": AutoModelForSequenceClassification.from_pretrained,
        }

        # Pass modelargs as keyword arguments
        modelargs = modelargs if modelargs else {}

        # Load model for supported tasks. Return path for unsupported tasks.
        return models[task](path, **modelargs) if task in models else path

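    # Usage sketch (model ids are illustrative): the task key selects the AutoModel
    # loader, unsupported tasks fall through and return path unchanged.
    #
    #   model = Models.load("bert-base-uncased", task="text-classification")
    #   model = Models.load("t5-small", task="summarization")
    #   model = Models.load("/path/to/model.onnx")   # local file -> OnnxModel
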
    @staticmethod
    def tokenizer(path, **kwargs):
        """
        Loads a tokenizer from path.

        Args:
            path: path to tokenizer
            kwargs: optional additional keyword arguments

        Returns:
            tokenizer
        """

        return AutoTokenizer.from_pretrained(path, **kwargs) if isinstance(path, str) else path

    @staticmethod
    def task(path, **kwargs):
        """
        Attempts to detect the model task from path.

        Args:
            path: path to model
            kwargs: optional additional keyword arguments

        Returns:
            inferred model task
        """

        # Get model configuration
        config = None
        if isinstance(path, (list, tuple)) and hasattr(path[0], "config"):
            config = path[0].config
        elif isinstance(path, str):
            config = AutoConfig.from_pretrained(path, **kwargs)

        # Attempt to resolve task using configuration
        task = None
        if config:
            architecture = config.architectures[0] if config.architectures else None
            if architecture:
                if architecture in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values():
                    task = "vision"
                elif any(x in architecture for x in ["LMHead", "CausalLM"]):
                    task = "language-generation"
                elif "QuestionAnswering" in architecture:
                    task = "question-answering"
                elif "ConditionalGeneration" in architecture:
                    task = "sequence-sequence"

        return task
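
    # Detection sketch (class names below are real transformers architectures, used
    # here for illustration): task() checks config.architectures[0] against the
    # image-text-to-text mapping first, then falls back to substring matches.
    #
    #   "GPT2LMHeadModel"              -> "language-generation"
    #   "BertForQuestionAnswering"     -> "question-answering"
    #   "T5ForConditionalGeneration"   -> "sequence-sequence"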