onnx.py
"""
ONNX module
"""

# Conditional import - onnxruntime is an optional dependency
try:
    import onnxruntime as ort

    ONNX_RUNTIME = True
except ImportError:
    ONNX_RUNTIME = False

import numpy as np
import torch

from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.modeling_utils import PreTrainedModel

from .registry import Registry


# pylint: disable=W0223
class OnnxModel(PreTrainedModel):
    """
    Provides a Transformers/PyTorch compatible interface for ONNX models. Handles casting inputs
    and outputs with minimal to no copying of data.
    """

    def __init__(self, model, config=None):
        """
        Creates a new OnnxModel.

        Args:
            model: path to model or InferenceSession
            config: path to model configuration

        Raises:
            ImportError: if onnxruntime is not installed
        """

        if not ONNX_RUNTIME:
            raise ImportError('onnxruntime is not available - install "model" extra to enable')

        # Load the saved configuration when provided, otherwise fall back to an empty config
        super().__init__(AutoConfig.from_pretrained(config) if config else OnnxConfig())

        # Create ONNX session
        self.model = ort.InferenceSession(model, ort.SessionOptions(), self.providers())

        # Add references for this class to supported AutoModel classes
        Registry.register(self)

    @property
    def device(self):
        """
        Returns model device id.

        Returns:
            model device id (-1, the CPU convention expected by Transformers pipelines)
        """

        return -1

    def providers(self):
        """
        Returns a list of available and usable providers.

        Returns:
            list of available and usable providers
        """

        # Create list of providers, prefer CUDA provider if available
        # CUDA provider only available if GPU is available and onnxruntime-gpu installed
        if torch.cuda.is_available() and "CUDAExecutionProvider" in ort.get_available_providers():
            return ["CUDAExecutionProvider", "CPUExecutionProvider"]

        # Default when CUDA provider isn't available
        return ["CPUExecutionProvider"]

    def forward(self, **inputs):
        """
        Runs inputs through an ONNX model and returns outputs. This method handles casting inputs
        and outputs between torch tensors and numpy arrays as shared memory (no copy).

        Args:
            inputs: model inputs

        Returns:
            model outputs
        """

        inputs = self.parse(inputs)

        # Run inputs through ONNX model
        results = self.model.run(None, inputs)

        # pylint: disable=E1101
        # Detect if logits is an output and return classifier output in that case
        if any(output.name == "logits" for output in self.model.get_outputs()):
            # asarray is a no-op when the output is already an ndarray, preserving shared memory
            return SequenceClassifierOutput(logits=torch.from_numpy(np.asarray(results[0])))

        return torch.from_numpy(np.array(results))

    def parse(self, inputs):
        """
        Parse model inputs and handle converting to ONNX compatible inputs.

        Args:
            inputs: model inputs

        Returns:
            ONNX compatible model inputs
        """

        features = {}

        # Select features from inputs - only these keys are passed to the ONNX session
        for key in ["input_ids", "attention_mask", "token_type_ids"]:
            if key in inputs:
                value = inputs[key]

                # Cast torch tensors to numpy
                if hasattr(value, "cpu"):
                    value = value.cpu().numpy()

                # Cast to numpy array if not already one
                features[key] = np.asarray(value)

        return features


class OnnxConfig(PretrainedConfig):
    """
    Configuration for ONNX models.
    """