custom_ops.py
1 """ 2 Python bindings for custom ONNX Runtime operators. 3 4 Requirements: 3.5 5 """ 6 7 from __future__ import annotations 8 9 import os 10 import tempfile 11 from functools import lru_cache 12 from pathlib import Path 13 14 import numpy as np 15 16 try: 17 import onnx 18 from onnx import TensorProto, helper 19 20 ONNX_AVAILABLE = True 21 except ImportError: 22 ONNX_AVAILABLE = False 23 24 try: 25 import onnxruntime as ort 26 27 ORT_AVAILABLE = True 28 except ImportError: 29 ORT_AVAILABLE = False 30 31 # Import shared GELU implementation 32 from common.utils.gelu import gelu_tanh_approx as gelu_reference 33 34 35 CUSTOM_OP_DOMAIN = "custom.ops" 36 CUSTOM_OP_VERSION = 1 37 38 39 def _candidate_library_paths() -> list[Path]: 40 build_dir = (Path(__file__).resolve().parent / ".." / "build").resolve() 41 lib_names = ["libcustom_gelu_op.so", "custom_gelu_op.dll", "libcustom_gelu_op.dylib"] 42 search_dirs = [build_dir, *(build_dir / sub for sub in ("Release", "Debug", "RelWithDebInfo"))] 43 return [directory / name for directory in search_dirs for name in lib_names] 44 45 46 def get_library_path() -> str: 47 """Get the path to the compiled custom op library.""" 48 for path in _candidate_library_paths(): 49 if path.exists(): 50 return str(path.resolve()) 51 52 raise FileNotFoundError( 53 "Custom op library not found. Please build the project first:\n" 54 " mkdir build && cd build && cmake .. && make" 55 ) 56 57 58 def is_custom_op_library_available() -> bool: 59 try: 60 get_library_path() 61 except FileNotFoundError: 62 return False 63 return True 64 65 66 def register_custom_ops(session_options: ort.SessionOptions) -> str: 67 """Register custom operators with ONNX Runtime session.""" 68 if not ORT_AVAILABLE: 69 raise ImportError("onnxruntime is required") 70 71 library_path = get_library_path() 72 try: 73 session_options.register_custom_ops_library(library_path) 74 except Exception as exc: 75 raise RuntimeError(f"Failed to register custom ops library {library_path}: {exc}") from exc 76 77 return library_path 78 79 80 def get_available_providers() -> list[str]: 81 if not ORT_AVAILABLE: 82 return [] 83 return list(ort.get_available_providers()) 84 85 86 def has_cuda_execution_provider() -> bool: 87 return ORT_AVAILABLE and "CUDAExecutionProvider" in get_available_providers() 88 89 90 def has_cpu_execution_provider() -> bool: 91 return ORT_AVAILABLE and "CPUExecutionProvider" in get_available_providers() 92 93 94 def _resolve_providers(use_cuda: bool) -> list[str]: 95 if not ORT_AVAILABLE: 96 raise RuntimeError("onnxruntime is not installed") 97 98 available = get_available_providers() 99 100 if use_cuda: 101 if "CUDAExecutionProvider" not in available: 102 raise RuntimeError( 103 "CUDAExecutionProvider is not available in this ONNX Runtime installation. " 104 f"Available providers: {available}" 105 ) 106 return ["CUDAExecutionProvider"] 107 108 if "CPUExecutionProvider" not in available: 109 raise RuntimeError( 110 "CPUExecutionProvider is not available in this ONNX Runtime installation. " 111 f"Available providers: {available}" 112 ) 113 return ["CPUExecutionProvider"] 114 115 116 @lru_cache(maxsize=8) 117 def _create_cached_session(model_path: str, use_cuda: bool) -> tuple[ort.InferenceSession, str]: 118 session_options = ort.SessionOptions() 119 library_path = register_custom_ops(session_options) 120 providers = _resolve_providers(use_cuda) 121 session = ort.InferenceSession(model_path, session_options, providers=providers) 122 return session, library_path 123 124 125 def clear_session_cache() -> None: 126 _create_cached_session.cache_clear() 127 128 129 def create_session( 130 model_path: str, 131 use_cuda: bool = True, 132 session_options: ort.SessionOptions | None = None, 133 ) -> tuple[ort.InferenceSession, str]: 134 """Create an ORT inference session with the custom op library registered.""" 135 if not ORT_AVAILABLE: 136 raise ImportError("onnxruntime is required") 137 138 normalized_model_path = os.path.abspath(model_path) 139 if session_options is None: 140 return _create_cached_session(normalized_model_path, use_cuda) 141 142 library_path = register_custom_ops(session_options) 143 providers = _resolve_providers(use_cuda) 144 session = ort.InferenceSession(normalized_model_path, session_options, providers=providers) 145 return session, library_path 146 147 148 def assert_custom_op_is_active(session: ort.InferenceSession, use_cuda: bool = True) -> None: 149 """Validate that the session is configured to execute the custom CUDA op as expected.""" 150 providers = session.get_providers() 151 152 if use_cuda: 153 if providers[:1] != ["CUDAExecutionProvider"]: 154 raise RuntimeError( 155 f"CustomGelu requires the CUDAExecutionProvider. Session providers: {providers}" 156 ) 157 elif not providers or providers[0] != "CPUExecutionProvider": 158 raise RuntimeError(f"Expected CPUExecutionProvider session, got: {providers}") 159 160 inputs = session.get_inputs() 161 outputs = session.get_outputs() 162 if len(inputs) != 1 or len(outputs) != 1: 163 raise RuntimeError("Expected the custom GELU test model to expose one input and one output") 164 165 if inputs[0].name != "X" or outputs[0].name != "Y": 166 raise RuntimeError( 167 "Unexpected model IO names; expected custom GELU test model with X -> Y mapping" 168 ) 169 170 171 def create_gelu_test_model( 172 input_shape: list[int] | None = None, output_path: str | None = None 173 ) -> onnx.ModelProto: 174 """Create an ONNX model with CustomGelu operator for testing.""" 175 if input_shape is None: 176 input_shape = [1, 256] 177 if not ONNX_AVAILABLE: 178 raise ImportError("onnx is required") 179 180 x = helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape) 181 y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, input_shape) 182 183 gelu_node = helper.make_node("CustomGelu", inputs=["X"], outputs=["Y"], domain=CUSTOM_OP_DOMAIN) 184 graph = helper.make_graph([gelu_node], "gelu_test_graph", [x], [y]) 185 opset_imports = [ 186 helper.make_opsetid("", 14), 187 helper.make_opsetid(CUSTOM_OP_DOMAIN, CUSTOM_OP_VERSION), 188 ] 189 model = helper.make_model(graph, opset_imports=opset_imports, producer_name="custom_ops_test") 190 191 if output_path: 192 onnx.save(model, output_path) 193 194 return model 195 196 197 def create_temp_gelu_test_model(input_shape: list[int]) -> tuple[str, callable]: 198 temp_file = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) 199 temp_file.close() 200 create_gelu_test_model(input_shape=input_shape, output_path=temp_file.name) 201 202 def cleanup() -> None: 203 try: 204 os.remove(temp_file.name) 205 except FileNotFoundError: 206 pass 207 208 return temp_file.name, cleanup 209 210 211 def run_inference(model_path: str, input_data: np.ndarray, use_cuda: bool = True) -> np.ndarray: 212 """Run inference with the registered custom op library.""" 213 if not ORT_AVAILABLE: 214 raise ImportError("onnxruntime is required") 215 216 session, _ = create_session(model_path, use_cuda=use_cuda) 217 assert_custom_op_is_active(session, use_cuda=use_cuda) 218 input_name = session.get_inputs()[0].name 219 outputs = session.run(None, {input_name: input_data}) 220 return outputs[0] 221 222 223 def get_registered_library_path(model_path: str, use_cuda: bool = True) -> str: 224 """Create a session and return the path of the registered custom op library.""" 225 session, library_path = create_session(model_path, use_cuda=use_cuda) 226 assert_custom_op_is_active(session, use_cuda=use_cuda) 227 return library_path 228 229 230 def test_custom_gelu() -> None: 231 """Quick manual smoke test of the custom GELU operator.""" 232 print("Testing Custom GELU Operator") 233 print("=" * 40) 234 235 input_shape = [1, 256] 236 model_path, cleanup = create_temp_gelu_test_model(input_shape) 237 238 np.random.seed(42) 239 input_data = np.random.randn(*input_shape).astype(np.float32) 240 241 try: 242 output = run_inference(model_path, input_data, use_cuda=True) 243 expected = gelu_reference(input_data) 244 max_diff = np.max(np.abs(output - expected)) 245 print(f"Max difference from reference: {max_diff:.2e}") 246 print("✓ Test PASSED" if max_diff < 1e-5 else "✗ Test FAILED") 247 except Exception as exc: 248 print(f"Error: {exc}") 249 print("Make sure the custom op library is built.") 250 finally: 251 cleanup() 252 253 254 if __name__ == "__main__": 255 test_custom_gelu()