# 02_ORT_Custom_CUDA_Op/python/custom_ops.py
  1  """
  2  Python bindings for custom ONNX Runtime operators.
  3  
  4  Requirements: 3.5
  5  """

from __future__ import annotations

import os
import tempfile
from functools import lru_cache
from pathlib import Path
from typing import Callable

import numpy as np

try:
    import onnx
    from onnx import TensorProto, helper

    ONNX_AVAILABLE = True
except ImportError:
    ONNX_AVAILABLE = False

try:
    import onnxruntime as ort

    ORT_AVAILABLE = True
except ImportError:
    ORT_AVAILABLE = False

# Import shared GELU implementation
from common.utils.gelu import gelu_tanh_approx as gelu_reference
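
# If the shared helper above is unavailable (e.g. in a stripped-down checkout),
# an equivalent local fallback is the standard tanh approximation of GELU
# (a sketch only; common.utils.gelu remains the canonical reference):
#
#   def gelu_reference(x: np.ndarray) -> np.ndarray:
#       c = np.sqrt(2.0 / np.pi)
#       return 0.5 * x * (1.0 + np.tanh(c * (x + 0.044715 * x**3)))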


CUSTOM_OP_DOMAIN = "custom.ops"
CUSTOM_OP_VERSION = 1


def _candidate_library_paths() -> list[Path]:
    """Enumerate plausible library locations across platforms and CMake configurations."""
    build_dir = (Path(__file__).resolve().parent / ".." / "build").resolve()
    lib_names = ["libcustom_gelu_op.so", "custom_gelu_op.dll", "libcustom_gelu_op.dylib"]
    search_dirs = [build_dir, *(build_dir / sub for sub in ("Release", "Debug", "RelWithDebInfo"))]
    return [directory / name for directory in search_dirs for name in lib_names]


def get_library_path() -> str:
    """Get the path to the compiled custom op library."""
    for path in _candidate_library_paths():
        if path.exists():
            return str(path.resolve())

    raise FileNotFoundError(
        "Custom op library not found. Please build the project first:\n"
        "  mkdir build && cd build && cmake .. && make"
    )


def is_custom_op_library_available() -> bool:
    """Return True if the compiled custom op library can be located."""
    try:
        get_library_path()
    except FileNotFoundError:
        return False
    return True


def register_custom_ops(session_options: ort.SessionOptions) -> str:
    """Register the custom op library on the given ONNX Runtime session options."""
    if not ORT_AVAILABLE:
        raise ImportError("onnxruntime is required")

    library_path = get_library_path()
    try:
        session_options.register_custom_ops_library(library_path)
    except Exception as exc:
        raise RuntimeError(
            f"Failed to register custom ops library {library_path}: {exc}"
        ) from exc

    return library_path
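
# Example wiring (a sketch): register the library on SessionOptions before the
# session is constructed, then hand the options to ort.InferenceSession.
# create_session() below does this automatically; manual wiring looks like:
#
#   so = ort.SessionOptions()
#   register_custom_ops(so)
#   sess = ort.InferenceSession("model.onnx", so, providers=["CPUExecutionProvider"])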


def get_available_providers() -> list[str]:
    if not ORT_AVAILABLE:
        return []
    return list(ort.get_available_providers())


def has_cuda_execution_provider() -> bool:
    return ORT_AVAILABLE and "CUDAExecutionProvider" in get_available_providers()


def has_cpu_execution_provider() -> bool:
    return ORT_AVAILABLE and "CPUExecutionProvider" in get_available_providers()


def _resolve_providers(use_cuda: bool) -> list[str]:
    if not ORT_AVAILABLE:
        raise RuntimeError("onnxruntime is not installed")

    available = get_available_providers()

    if use_cuda:
        if "CUDAExecutionProvider" not in available:
            raise RuntimeError(
                "CUDAExecutionProvider is not available in this ONNX Runtime installation. "
                f"Available providers: {available}"
            )
        # ORT implicitly appends CPUExecutionProvider as a final fallback.
        return ["CUDAExecutionProvider"]

    if "CPUExecutionProvider" not in available:
        raise RuntimeError(
            "CPUExecutionProvider is not available in this ONNX Runtime installation. "
            f"Available providers: {available}"
        )
    return ["CPUExecutionProvider"]


@lru_cache(maxsize=8)
def _create_cached_session(model_path: str, use_cuda: bool) -> tuple[ort.InferenceSession, str]:
    session_options = ort.SessionOptions()
    library_path = register_custom_ops(session_options)
    providers = _resolve_providers(use_cuda)
    session = ort.InferenceSession(model_path, session_options, providers=providers)
    return session, library_path


def clear_session_cache() -> None:
    _create_cached_session.cache_clear()
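
# Note (usage sketch): sessions are cached per (model_path, use_cuda) pair, so
# repeated create_session() calls reuse the already-loaded library. Clear the
# cache before deleting a cached model file from disk, e.g.:
#
#   session, lib = create_session("model.onnx", use_cuda=False)
#   clear_session_cache()
#   os.remove("model.onnx")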


def create_session(
    model_path: str,
    use_cuda: bool = True,
    session_options: ort.SessionOptions | None = None,
) -> tuple[ort.InferenceSession, str]:
    """Create an ORT inference session with the custom op library registered."""
    if not ORT_AVAILABLE:
        raise ImportError("onnxruntime is required")

    normalized_model_path = os.path.abspath(model_path)
    if session_options is None:
        return _create_cached_session(normalized_model_path, use_cuda)

    # Caller-supplied options bypass the session cache so per-call state is not shared.
    library_path = register_custom_ops(session_options)
    providers = _resolve_providers(use_cuda)
    session = ort.InferenceSession(normalized_model_path, session_options, providers=providers)
    return session, library_path


def assert_custom_op_is_active(session: ort.InferenceSession, use_cuda: bool = True) -> None:
    """Validate that the session is configured to execute the custom op as expected."""
    providers = session.get_providers()

    if use_cuda:
        if providers[:1] != ["CUDAExecutionProvider"]:
            raise RuntimeError(
                f"CustomGelu requires the CUDAExecutionProvider. Session providers: {providers}"
            )
    elif not providers or providers[0] != "CPUExecutionProvider":
        raise RuntimeError(f"Expected CPUExecutionProvider session, got: {providers}")

    inputs = session.get_inputs()
    outputs = session.get_outputs()
    if len(inputs) != 1 or len(outputs) != 1:
        raise RuntimeError("Expected the custom GELU test model to expose one input and one output")

    if inputs[0].name != "X" or outputs[0].name != "Y":
        raise RuntimeError(
            "Unexpected model IO names; expected custom GELU test model with X -> Y mapping"
        )


def create_gelu_test_model(
    input_shape: list[int] | None = None, output_path: str | None = None
) -> onnx.ModelProto:
    """Create an ONNX model containing a single CustomGelu node for testing."""
    if not ONNX_AVAILABLE:
        raise ImportError("onnx is required")
    if input_shape is None:
        input_shape = [1, 256]

    x = helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)
    y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, input_shape)

    gelu_node = helper.make_node("CustomGelu", inputs=["X"], outputs=["Y"], domain=CUSTOM_OP_DOMAIN)
    graph = helper.make_graph([gelu_node], "gelu_test_graph", [x], [y])
    opset_imports = [
        helper.make_opsetid("", 14),
        helper.make_opsetid(CUSTOM_OP_DOMAIN, CUSTOM_OP_VERSION),
    ]
    model = helper.make_model(graph, opset_imports=opset_imports, producer_name="custom_ops_test")

    if output_path:
        onnx.save(model, output_path)

    return model
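
# Example (a sketch): build the model in memory and verify the custom domain is
# declared in its opset imports before saving or running it.
#
#   model = create_gelu_test_model([2, 8])
#   assert any(op.domain == CUSTOM_OP_DOMAIN for op in model.opset_import)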


def create_temp_gelu_test_model(input_shape: list[int]) -> tuple[str, Callable[[], None]]:
    """Serialize a temporary GELU test model; return (model_path, cleanup)."""
    temp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
    temp_file.close()
    create_gelu_test_model(input_shape=input_shape, output_path=temp_file.name)

    def cleanup() -> None:
        try:
            os.remove(temp_file.name)
        except FileNotFoundError:
            pass

    return temp_file.name, cleanup


def run_inference(model_path: str, input_data: np.ndarray, use_cuda: bool = True) -> np.ndarray:
    """Run inference with the registered custom op library."""
    if not ORT_AVAILABLE:
        raise ImportError("onnxruntime is required")

    session, _ = create_session(model_path, use_cuda=use_cuda)
    assert_custom_op_is_active(session, use_cuda=use_cuda)
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: input_data})
    return outputs[0]


def get_registered_library_path(model_path: str, use_cuda: bool = True) -> str:
    """Create a session and return the path of the registered custom op library."""
    session, library_path = create_session(model_path, use_cuda=use_cuda)
    assert_custom_op_is_active(session, use_cuda=use_cuda)
    return library_path


def test_custom_gelu() -> None:
    """Quick manual smoke test of the custom GELU operator."""
    print("Testing Custom GELU Operator")
    print("=" * 40)

    input_shape = [1, 256]
    model_path, cleanup = create_temp_gelu_test_model(input_shape)

    np.random.seed(42)
    input_data = np.random.randn(*input_shape).astype(np.float32)

    try:
        output = run_inference(model_path, input_data, use_cuda=True)
        expected = gelu_reference(input_data)
        max_diff = np.max(np.abs(output - expected))
        print(f"Max difference from reference: {max_diff:.2e}")
        print("✓ Test PASSED" if max_diff < 1e-5 else "✗ Test FAILED")
    except Exception as exc:
        print(f"Error: {exc}")
        print("Make sure the custom op library is built.")
    finally:
        cleanup()


if __name__ == "__main__":
    test_custom_gelu()