custom_op.h
1 /** 2 * ONNX Runtime Custom Operator definitions. 3 * 4 * Requirements: 3.2 5 */ 6 7 #ifndef CUSTOM_OP_H 8 #define CUSTOM_OP_H 9 10 #include <onnxruntime_cxx_api.h> 11 #include <onnxruntime/core/providers/cuda/cuda_context.h> 12 #include <cuda_runtime.h> 13 #include "cuda_kernels.h" 14 15 namespace custom_ops { 16 17 // Domain for custom operators 18 constexpr const char* CUSTOM_OP_DOMAIN = "custom.ops"; 19 constexpr int CUSTOM_OP_VERSION = 1; 20 21 /** 22 * GELU Kernel implementation for ONNX Runtime. 23 * 24 * This class handles the actual computation when the operator is invoked. 25 */ 26 struct GeluKernel { 27 GeluKernel(const OrtKernelInfo* info) { 28 // No attributes to parse for GELU 29 (void)info; 30 } 31 32 void Compute(OrtKernelContext* context); 33 34 private: 35 static cudaStream_t GetCudaStream(const OrtKernelContext& context); 36 }; 37 38 /** 39 * Custom GELU Operator definition. 40 * 41 * Inherits from Ort::CustomOpBase to integrate with ONNX Runtime. 42 */ 43 struct GeluCustomOp : Ort::CustomOpBase<GeluCustomOp, GeluKernel> { 44 /** 45 * Get the operator name. 46 */ 47 const char* GetName() const { 48 return "CustomGelu"; 49 } 50 51 /** 52 * Get the execution provider type. 53 * Returns "CUDAExecutionProvider" for GPU execution. 54 */ 55 const char* GetExecutionProviderType() const { 56 return "CUDAExecutionProvider"; 57 } 58 59 /** 60 * Get the number of inputs. 61 */ 62 size_t GetInputTypeCount() const { 63 return 1; 64 } 65 66 /** 67 * Get the type of input at given index. 68 */ 69 ONNXTensorElementDataType GetInputType(size_t index) const { 70 (void)index; 71 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; 72 } 73 74 /** 75 * Get input characteristics. 76 */ 77 OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const { 78 (void)index; 79 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; 80 } 81 82 /** 83 * Get the number of outputs. 84 */ 85 size_t GetOutputTypeCount() const { 86 return 1; 87 } 88 89 /** 90 * Get the type of output at given index. 91 */ 92 ONNXTensorElementDataType GetOutputType(size_t index) const { 93 (void)index; 94 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; 95 } 96 97 /** 98 * Get output characteristics. 99 */ 100 OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t index) const { 101 (void)index; 102 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; 103 } 104 }; 105 106 /** 107 * Register all custom operators with ONNX Runtime. 108 * 109 * @param options Session options to register operators with 110 */ 111 void RegisterCustomOps(Ort::CustomOpDomain& domain); 112 113 /** 114 * Get the custom op domain. 115 */ 116 Ort::CustomOpDomain& GetCustomOpDomain(); 117 118 } // namespace custom_ops 119 120 // C API for dynamic loading 121 extern "C" { 122 OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api); 123 } 124 125 #endif // CUSTOM_OP_H