/ 02_ORT_Custom_CUDA_Op / src / custom_op.h
custom_op.h
  1  /**
  2   * ONNX Runtime Custom Operator definitions.
  3   * 
  4   * Requirements: 3.2
  5   */
  6  
  7  #ifndef CUSTOM_OP_H
  8  #define CUSTOM_OP_H
  9  
 10  #include <onnxruntime_cxx_api.h>
 11  #include <onnxruntime/core/providers/cuda/cuda_context.h>
 12  #include <cuda_runtime.h>
 13  #include "cuda_kernels.h"
 14  
 15  namespace custom_ops {
 16  
 // Name of the operator domain under which the custom operators are published.
// NOTE(review): presumably the domain string models use for CustomGelu nodes and
// the one handed to Ort::CustomOpDomain at registration — confirm in the .cc/.cu.
constexpr char const* CUSTOM_OP_DOMAIN{"custom.ops"};

// Version number paired with CUSTOM_OP_DOMAIN.
// NOTE(review): presumably the opset version models import for this domain — confirm.
constexpr int CUSTOM_OP_VERSION{1};
 20  
 21  /**
 22   * GELU Kernel implementation for ONNX Runtime.
 23   * 
 24   * This class handles the actual computation when the operator is invoked.
 25   */
 26  struct GeluKernel {
 27      GeluKernel(const OrtKernelInfo* info) {
 28          // No attributes to parse for GELU
 29          (void)info;
 30      }
 31  
 32      void Compute(OrtKernelContext* context);
 33  
 34    private:
 35      static cudaStream_t GetCudaStream(const OrtKernelContext& context);
 36  };
 37  
 38  /**
 39   * Custom GELU Operator definition.
 40   * 
 41   * Inherits from Ort::CustomOpBase to integrate with ONNX Runtime.
 42   */
 43  struct GeluCustomOp : Ort::CustomOpBase<GeluCustomOp, GeluKernel> {
 44      /**
 45       * Get the operator name.
 46       */
 47      const char* GetName() const {
 48          return "CustomGelu";
 49      }
 50      
 51      /**
 52       * Get the execution provider type.
 53       * Returns "CUDAExecutionProvider" for GPU execution.
 54       */
 55      const char* GetExecutionProviderType() const {
 56          return "CUDAExecutionProvider";
 57      }
 58      
 59      /**
 60       * Get the number of inputs.
 61       */
 62      size_t GetInputTypeCount() const {
 63          return 1;
 64      }
 65      
 66      /**
 67       * Get the type of input at given index.
 68       */
 69      ONNXTensorElementDataType GetInputType(size_t index) const {
 70          (void)index;
 71          return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
 72      }
 73      
 74      /**
 75       * Get input characteristics.
 76       */
 77      OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const {
 78          (void)index;
 79          return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
 80      }
 81      
 82      /**
 83       * Get the number of outputs.
 84       */
 85      size_t GetOutputTypeCount() const {
 86          return 1;
 87      }
 88      
 89      /**
 90       * Get the type of output at given index.
 91       */
 92      ONNXTensorElementDataType GetOutputType(size_t index) const {
 93          (void)index;
 94          return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
 95      }
 96      
 97      /**
 98       * Get output characteristics.
 99       */
100      OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t index) const {
101          (void)index;
102          return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
103      }
104  };
105  
106  /**
107   * Register all custom operators with ONNX Runtime.
108   * 
109   * @param options Session options to register operators with
110   */
111  void RegisterCustomOps(Ort::CustomOpDomain& domain);
112  
113  /**
114   * Get the custom op domain.
115   */
116  Ort::CustomOpDomain& GetCustomOpDomain();
117  
118  }  // namespace custom_ops
119  
120  // C API for dynamic loading
121  extern "C" {
122      OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api);
123  }
124  
125  #endif // CUSTOM_OP_H