# src/python/txtai/models/onnx.py
  1  """
  2  ONNX module
  3  """
  4  
  5  # Conditional import
  6  try:
  7      import onnxruntime as ort
  8  
  9      ONNX_RUNTIME = True
 10  except ImportError:
 11      ONNX_RUNTIME = False
 12  
 13  import numpy as np
 14  import torch
 15  
 16  from transformers import AutoConfig
 17  from transformers.configuration_utils import PretrainedConfig
 18  from transformers.modeling_outputs import SequenceClassifierOutput
 19  from transformers.modeling_utils import PreTrainedModel
 20  
 21  from .registry import Registry
 22  
 23  
 24  # pylint: disable=W0223
 25  class OnnxModel(PreTrainedModel):
 26      """
 27      Provides a Transformers/PyTorch compatible interface for ONNX models. Handles casting inputs
 28      and outputs with minimal to no copying of data.
 29      """
 30  
 31      def __init__(self, model, config=None):
 32          """
 33          Creates a new OnnxModel.
 34  
 35          Args:
 36              model: path to model or InferenceSession
 37              config: path to model configuration
 38          """
 39  
 40          if not ONNX_RUNTIME:
 41              raise ImportError('onnxruntime is not available - install "model" extra to enable')
 42  
 43          super().__init__(AutoConfig.from_pretrained(config) if config else OnnxConfig())
 44  
 45          # Create ONNX session
 46          self.model = ort.InferenceSession(model, ort.SessionOptions(), self.providers())
 47  
 48          # Add references for this class to supported AutoModel classes
 49          Registry.register(self)
 50  
 51      @property
 52      def device(self):
 53          """
 54          Returns model device id.
 55  
 56          Returns:
 57              model device id
 58          """
 59  
 60          return -1
 61  
 62      def providers(self):
 63          """
 64          Returns a list of available and usable providers.
 65  
 66          Returns:
 67              list of available and usable providers
 68          """
 69  
 70          # Create list of providers, prefer CUDA provider if available
 71          # CUDA provider only available if GPU is available and onnxruntime-gpu installed
 72          if torch.cuda.is_available() and "CUDAExecutionProvider" in ort.get_available_providers():
 73              return ["CUDAExecutionProvider", "CPUExecutionProvider"]
 74  
 75          # Default when CUDA provider isn't available
 76          return ["CPUExecutionProvider"]
 77  
 78      def forward(self, **inputs):
 79          """
 80          Runs inputs through an ONNX model and returns outputs. This method handles casting inputs
 81          and outputs between torch tensors and numpy arrays as shared memory (no copy).
 82  
 83          Args:
 84              inputs: model inputs
 85  
 86          Returns:
 87              model outputs
 88          """
 89  
 90          inputs = self.parse(inputs)
 91  
 92          # Run inputs through ONNX model
 93          results = self.model.run(None, inputs)
 94  
 95          # pylint: disable=E1101
 96          # Detect if logits is an output and return classifier output in that case
 97          if any(x.name for x in self.model.get_outputs() if x.name == "logits"):
 98              return SequenceClassifierOutput(logits=torch.from_numpy(np.array(results[0])))
 99  
100          return torch.from_numpy(np.array(results))
101  
102      def parse(self, inputs):
103          """
104          Parse model inputs and handle converting to ONNX compatible inputs.
105  
106          Args:
107              inputs: model inputs
108  
109          Returns:
110              ONNX compatible model inputs
111          """
112  
113          features = {}
114  
115          # Select features from inputs
116          for key in ["input_ids", "attention_mask", "token_type_ids"]:
117              if key in inputs:
118                  value = inputs[key]
119  
120                  # Cast torch tensors to numpy
121                  if hasattr(value, "cpu"):
122                      value = value.cpu().numpy()
123  
124                  # Cast to numpy array if not already one
125                  features[key] = np.asarray(value)
126  
127          return features
128  
129  
class OnnxConfig(PretrainedConfig):
    """
    Configuration for ONNX models. Default used by OnnxModel when no configuration path is provided.
    """