cuda_kernels.cu
/**
 * CUDA Kernel implementations for custom ONNX Runtime operators.
 *
 * Requirements: 3.3
 */

#include "cuda_kernels.h"
#include <cuda_fp16.h>
#include <cmath>

// Constants for the tanh-based GELU approximation:
// GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
constexpr float SQRT_2_OVER_PI = 0.7978845608028654f; // sqrt(2/π)
constexpr float GELU_COEF = 0.044715f;

// Block size for the GELU kernels.
// 256 threads (a multiple of the warp size) provides good occupancy for an
// element-wise kernel:
// - Sufficient parallelism for typical tensor sizes
// - Balances register pressure vs shared memory usage
constexpr int GELU_BLOCK_SIZE = 256;

// Upper bound on blocks per launch. The kernels below use grid-stride loops,
// so correctness never depends on launching one thread per element; capping
// the grid keeps the block count well inside the 1-D grid limit and avoids
// int overflow when `count` is very large (the old int narrowing of
// ceil(count / 256) could overflow for counts above ~2^39).
constexpr int64_t GELU_MAX_GRID_SIZE = 65535;

/**
 * Scalar GELU via the tanh approximation, shared by all kernel variants.
 *
 * The operation order (compute cdf first, then multiply by x) matches the
 * original per-kernel code exactly, so results are bit-identical.
 */
__device__ __forceinline__ float GeluApprox(float x) {
    float x_cubed = x * x * x;
    float inner = SQRT_2_OVER_PI * (x + GELU_COEF * x_cubed);
    float cdf = 0.5f * (1.0f + tanhf(inner));
    return x * cdf;
}

/**
 * Compute a clamped 1-D grid size for `work_items` work items.
 *
 * Performed entirely in 64-bit arithmetic, then clamped to
 * GELU_MAX_GRID_SIZE; safe because every kernel here is grid-stride.
 */
static inline int GeluGridSizeClamped(int64_t work_items) {
    int64_t blocks = (work_items + GELU_BLOCK_SIZE - 1) / GELU_BLOCK_SIZE;
    if (blocks > GELU_MAX_GRID_SIZE) {
        blocks = GELU_MAX_GRID_SIZE;
    }
    return static_cast<int>(blocks);
}

/**
 * GELU activation kernel (float32).
 *
 * Grid-stride loop over `count` elements; correct for any 1-D launch
 * configuration. Expects device pointers; no shared memory, no alignment
 * requirement beyond natural float alignment.
 */
__global__ void gelu_kernel_f32(
    const float* __restrict__ input,
    float* __restrict__ output,
    int64_t count
) {
    // 64-bit index/stride: blockIdx.x * blockDim.x can exceed INT_MAX.
    int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
    for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < count;
         idx += stride) {
        output[idx] = GeluApprox(input[idx]);
    }
}

/**
 * GELU activation kernel (float16).
 *
 * Computation is performed in float32 for accuracy, then converted back to
 * half. Grid-stride loop; correct for any 1-D launch configuration.
 */
__global__ void gelu_kernel_f16(
    const __half* __restrict__ input,
    __half* __restrict__ output,
    int64_t count
) {
    int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
    for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < count;
         idx += stride) {
        // Widen to float, apply GELU, narrow back to half.
        float x = __half2float(input[idx]);
        output[idx] = __float2half(GeluApprox(x));
    }
}

/**
 * Vectorized GELU kernel using float4 for better memory throughput.
 *
 * Preconditions (enforced by LaunchGeluKernel): element count divisible by
 * 4 and both pointers 16-byte aligned. Uses direct .x/.y/.z/.w member
 * access instead of pointer type-punning into the float4.
 */
__global__ void gelu_kernel_f32_vec4(
    const float4* __restrict__ input,
    float4* __restrict__ output,
    int64_t count_vec4
) {
    int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
    for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         idx < count_vec4;
         idx += stride) {
        float4 in = input[idx];
        float4 out;
        out.x = GeluApprox(in.x);
        out.y = GeluApprox(in.y);
        out.z = GeluApprox(in.z);
        out.w = GeluApprox(in.w);
        output[idx] = out;
    }
}

// ============================================================================
// Public API
// ============================================================================

/**
 * Launch the float32 GELU kernel on `stream`.
 *
 * Dispatches to the float4-vectorized kernel when `count` is a multiple of
 * 4 and both pointers are 16-byte aligned; otherwise falls back to the
 * scalar kernel. Asynchronous with respect to the host; returns the result
 * of cudaGetLastError() to surface launch-configuration errors.
 *
 * @param input   Device pointer to `count` floats.
 * @param output  Device pointer to `count` floats (may alias `input`'s
 *                layout requirements; in-place is element-wise safe).
 * @param count   Number of elements; 0 is a no-op returning cudaSuccess.
 * @param stream  CUDA stream for the launch.
 */
cudaError_t LaunchGeluKernel(
    const float* input,
    float* output,
    int64_t count,
    cudaStream_t stream
) {
    if (count == 0) return cudaSuccess;

    bool can_vectorize =
        (count % 4 == 0) &&
        (reinterpret_cast<uintptr_t>(input) % 16 == 0) &&
        (reinterpret_cast<uintptr_t>(output) % 16 == 0);

    if (can_vectorize) {
        int64_t count_vec4 = count / 4;
        int grid_size = GeluGridSizeClamped(count_vec4);

        gelu_kernel_f32_vec4<<<grid_size, GELU_BLOCK_SIZE, 0, stream>>>(
            reinterpret_cast<const float4*>(input),
            reinterpret_cast<float4*>(output),
            count_vec4
        );
    } else {
        int grid_size = GeluGridSizeClamped(count);

        gelu_kernel_f32<<<grid_size, GELU_BLOCK_SIZE, 0, stream>>>(
            input, output, count
        );
    }

    return cudaGetLastError();
}

/**
 * Launch the float16 GELU kernel on `stream`.
 *
 * Asynchronous with respect to the host; returns the result of
 * cudaGetLastError() to surface launch-configuration errors.
 *
 * @param input   Device pointer to `count` half-precision values.
 * @param output  Device pointer to `count` half-precision values.
 * @param count   Number of elements; 0 is a no-op returning cudaSuccess.
 * @param stream  CUDA stream for the launch.
 */
cudaError_t LaunchGeluKernelHalf(
    const __half* input,
    __half* output,
    int64_t count,
    cudaStream_t stream
) {
    if (count == 0) return cudaSuccess;

    int grid_size = GeluGridSizeClamped(count);

    gelu_kernel_f16<<<grid_size, GELU_BLOCK_SIZE, 0, stream>>>(
        input, output, count
    );

    return cudaGetLastError();
}

/** Block size used by all GELU kernel launches. */
int GetGeluBlockSize() {
    return GELU_BLOCK_SIZE;
}

/**
 * Unclamped ceil(count / GELU_BLOCK_SIZE) grid size for `count` elements.
 *
 * Kept formula-identical for external callers. NOTE(review): the int
 * return narrows for counts above ~2^39; the internal launchers use a
 * clamped 64-bit computation instead.
 */
int GetGeluGridSize(int64_t count) {
    return (count + GELU_BLOCK_SIZE - 1) / GELU_BLOCK_SIZE;
}