# 02_ORT_Custom_CUDA_Op/tests/test_correctness.py
  1  """
  2  Property-based tests for GELU numerical correctness.
  3  
  4  Feature: ai-system-optimization-series, Property 2: GELU Kernel Numerical Correctness
  5  Validates: Requirements 3.3
  6  """
  7  
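# The property tests below need only numpy + hypothesis; the PyTorch comparison and
# the ONNX Runtime custom-op test skip themselves when torch or a CUDA-capable ORT
# build is absent, so the file runs as plain `pytest -v`.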
import math
import os
import sys
import tempfile

import numpy as np
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st

try:
    import torch

    TORCH_AVAILABLE = True
    CUDA_AVAILABLE = torch.cuda.is_available()
except ImportError:
    TORCH_AVAILABLE = False
    CUDA_AVAILABLE = False
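# Note: torch.cuda.is_available() is only a proxy for GPU availability; it does not
# by itself guarantee that the installed onnxruntime build ships a working
# CUDAExecutionProvider. The requires_ort / requires_cuda markers below add a
# second gate for the ORT-specific tests.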
# Make the module root (one level above tests/) importable regardless of the
# working directory the tests are launched from.
_MODULE_DIR = os.path.dirname(os.path.dirname(__file__))
if _MODULE_DIR not in sys.path:
    sys.path.insert(0, _MODULE_DIR)

from python.custom_ops import create_gelu_test_model, gelu_reference, run_inference

# Use shared GELU implementation from common module
from common.utils.gelu import gelu_tanh_approx
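# gelu_tanh_approx is assumed (not re-derived here) to implement the standard tanh
# approximation, the same formula PyTorch uses for approximate="tanh":
#     GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))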


def gelu_exact_reference(x: np.ndarray) -> np.ndarray:
    """Reference GELU via the exact definition: GELU(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))."""
    # np.vectorize is a Python-level loop, not a fast kernel; adequate for test-sized arrays.
    return x * 0.5 * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))
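# Worked example: Phi(1.0) ~= 0.84134, so gelu_exact_reference(np.array([1.0])) ~= [0.84134].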


class TestGeluNumericalCorrectness:
    @given(
        shape=st.tuples(
            st.integers(min_value=1, max_value=256), st.integers(min_value=1, max_value=256)
        )
    )
    @settings(max_examples=100)
    def test_gelu_tanh_approx_matches_reference(self, shape):
        """The tanh approximation stays within the rtol=5e-3 / atol=1e-3 band of exact GELU."""
        np.random.seed(42)
        x = np.random.randn(*shape).astype(np.float32)
        y_ref = gelu_exact_reference(x)
        y_approx = gelu_tanh_approx(x)

        np.testing.assert_allclose(
            y_approx,
            y_ref,
            rtol=5e-3,
            atol=1e-3,
            err_msg=f"GELU approximation differs from reference for shape {shape}",
        )

    @given(
        values=st.lists(
            st.floats(min_value=-10.0, max_value=10.0, allow_nan=False, allow_infinity=False),
            min_size=1,
            max_size=1000,
        )
    )
    @settings(max_examples=100)
    def test_gelu_output_range(self, values):
        """GELU(0) == 0, and for x > 0 we have 0 < GELU(x) <= x,
        since GELU(x) = x * Phi(x) with 0 < Phi(x) < 1."""
        x = np.array(values, dtype=np.float32)
        y = gelu_tanh_approx(x)

        zero_mask = np.abs(x) < 1e-6
        if np.any(zero_mask):
            np.testing.assert_allclose(y[zero_mask], 0.0, atol=1e-5)

        pos_mask = x > 0.1
        if np.any(pos_mask):
            assert np.all(y[pos_mask] > 0)
            assert np.all(y[pos_mask] <= x[pos_mask] + 1e-5)

    @given(x=st.floats(min_value=0.0, max_value=10.0, allow_nan=False, allow_infinity=False))
    @settings(max_examples=100)
    def test_gelu_monotonicity(self, x):
        """GELU is monotonically increasing on x >= 0 (it is non-monotonic only near x ~ -0.75)."""
        x_arr = np.array([x, x + 0.1], dtype=np.float32)
        y_arr = gelu_tanh_approx(x_arr)
        assert y_arr[1] >= y_arr[0] - 1e-5

    @given(
        shape=st.tuples(
            st.integers(min_value=1, max_value=128), st.integers(min_value=1, max_value=128)
        )
    )
    @settings(max_examples=100)
    def test_gelu_symmetry_property(self, shape):
        """GELU satisfies the reflection identity GELU(x) - GELU(-x) = x.

        With GELU(x) = x * Phi(x): GELU(x) - GELU(-x) = x * Phi(x) + x * (1 - Phi(x)) = x.
        The tanh approximation preserves the identity exactly because its inner
        polynomial is odd.
        """
        np.random.seed(42)
        x = np.random.randn(*shape).astype(np.float32)
        y_pos = gelu_tanh_approx(x)
        y_neg = gelu_tanh_approx(-x)

        assert np.all(np.isfinite(y_pos))
        assert np.all(np.isfinite(y_neg))
        np.testing.assert_allclose(y_pos - y_neg, x, rtol=1e-4, atol=1e-5)


@pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not installed")
class TestGeluAgainstPyTorch:
    @given(
        shape=st.tuples(
            st.integers(min_value=1, max_value=256), st.integers(min_value=1, max_value=256)
        )
    )
    @settings(max_examples=100)
    def test_matches_pytorch_gelu(self, shape):
        """PyTorch's approximate="tanh" GELU uses the same formula, so tight tolerances apply."""
        np.random.seed(42)
        x_np = np.random.randn(*shape).astype(np.float32)
        x_torch = torch.from_numpy(x_np)
        y_torch = torch.nn.functional.gelu(x_torch, approximate="tanh").numpy()
        y_ours = gelu_tanh_approx(x_np)

        np.testing.assert_allclose(y_ours, y_torch, rtol=1e-5, atol=1e-5)


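# requires_ort / requires_cuda are custom markers, assumed to be registered in the
# project's pytest configuration (pytest.ini / conftest.py); unregistered marks
# trigger PytestUnknownMarkWarning and act as plain labels.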
@pytest.mark.requires_ort
@pytest.mark.requires_cuda
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA path not available")
class TestGeluCustomOpIntegration:
    """Validate the real ONNX Runtime custom op path, not a PyTorch proxy."""

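    # Assumptions about the helpers imported from python.custom_ops:
    # create_gelu_test_model builds a single-node ONNX model whose GELU op resolves
    # to the custom CUDA kernel, and run_inference(model_path, x, use_cuda=True)
    # runs it through an ORT session with the custom-op library registered.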
    def test_custom_op_matches_reference(self):
        """End-to-end: build a one-op ONNX model, run it on CUDA, compare to the reference."""
        shape = [4, 128]
        with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f:
            model_path = f.name

        try:
            create_gelu_test_model(input_shape=shape, output_path=model_path)
            np.random.seed(123)
            x = np.random.randn(*shape).astype(np.float32)
            y = run_inference(model_path, x, use_cuda=True)
            y_ref = gelu_reference(x)
            np.testing.assert_allclose(y, y_ref, rtol=1e-5, atol=1e-5)
        finally:
            if os.path.exists(model_path):
                os.remove(model_path)


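# Deterministic edge cases. At the magnitudes tested below (|x| >= 10) the cubic
# term drives tanh(...) into full floating-point saturation (+/-1), so the tanh
# approximation returns exactly x for large positive inputs and exactly 0.0 for
# large negative ones.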
class TestGeluEdgeCases:
    def test_gelu_zero(self):
        x = np.array([0.0], dtype=np.float32)
        y = gelu_tanh_approx(x)
        np.testing.assert_allclose(y, [0.0], atol=1e-7)

    def test_gelu_large_positive(self):
        # tanh saturates to 1, so GELU(x) -> x for large positive x.
        x = np.array([10.0, 20.0, 50.0], dtype=np.float32)
        y = gelu_tanh_approx(x)
        np.testing.assert_allclose(y, x, rtol=1e-3)

    def test_gelu_large_negative(self):
        # tanh saturates to -1, so GELU(x) -> 0 for large negative x.
        x = np.array([-10.0, -20.0, -50.0], dtype=np.float32)
        y = gelu_tanh_approx(x)
        np.testing.assert_allclose(y, [0.0, 0.0, 0.0], atol=1e-3)

    def test_gelu_small_values(self):
        # Finite outputs and a bounded finite-difference slope on [-1, 1];
        # the true GELU derivative never exceeds ~1.13 in magnitude, so 2.0 is a loose sanity bound.
        x = np.linspace(-1, 1, 100, dtype=np.float32)
        y = gelu_tanh_approx(x)

        assert np.all(np.isfinite(y))
        derivative = np.diff(y) / np.diff(x)
        assert np.all(np.abs(derivative) < 2.0)


if __name__ == "__main__":
    pytest.main([__file__, "-v"])