# test_correctness.py
"""
Property-based tests for GELU numerical correctness.

Feature: ai-system-optimization-series, Property 2: GELU Kernel Numerical Correctness
Validates: Requirements 3.3
"""

import math
import os
import sys
import tempfile

import numpy as np
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st

try:
    import torch

    TORCH_AVAILABLE = True
    CUDA_AVAILABLE = torch.cuda.is_available()
except ImportError:
    TORCH_AVAILABLE = False
    CUDA_AVAILABLE = False

# Make the package root importable when this file is run from its own directory.
_MODULE_DIR = os.path.dirname(os.path.dirname(__file__))
if _MODULE_DIR not in sys.path:
    sys.path.insert(0, _MODULE_DIR)

from python.custom_ops import create_gelu_test_model, gelu_reference, run_inference

# Use shared GELU implementation from common module
from common.utils.gelu import gelu_tanh_approx

# Hoisted: np.vectorize builds a wrapper object; do it once, not per test call.
_VEC_ERF = np.vectorize(math.erf)


def gelu_exact_reference(x: np.ndarray) -> np.ndarray:
    """Reference GELU using the exact definition: x * 0.5 * (1 + erf(x / sqrt(2)))."""
    return x * 0.5 * (1.0 + _VEC_ERF(x / math.sqrt(2.0)))


class TestGeluNumericalCorrectness:
    """Property-based checks of the tanh GELU approximation against the exact form."""

    @given(
        shape=st.tuples(
            st.integers(min_value=1, max_value=256), st.integers(min_value=1, max_value=256)
        )
    )
    @settings(max_examples=100)
    def test_gelu_tanh_approx_matches_reference(self, shape):
        """The tanh approximation stays within loose tolerance of exact GELU."""
        np.random.seed(42)  # deterministic data per shape; hypothesis drives shapes only
        x = np.random.randn(*shape).astype(np.float32)
        y_ref = gelu_exact_reference(x)
        y_approx = gelu_tanh_approx(x)

        np.testing.assert_allclose(
            y_approx,
            y_ref,
            rtol=5e-3,
            atol=1e-3,
            err_msg=f"GELU approximation differs from reference for shape {shape}",
        )

    @given(
        values=st.lists(
            st.floats(min_value=-10.0, max_value=10.0, allow_nan=False, allow_infinity=False),
            min_size=1,
            max_size=1000,
        )
    )
    @settings(max_examples=100)
    def test_gelu_output_range(self, values):
        """gelu(0) == 0, and for clearly positive x: 0 < gelu(x) <= x."""
        x = np.array(values, dtype=np.float32)
        y = gelu_tanh_approx(x)

        zero_mask = np.abs(x) < 1e-6
        if np.any(zero_mask):
            np.testing.assert_allclose(y[zero_mask], 0.0, atol=1e-5)

        pos_mask = x > 0.1
        if np.any(pos_mask):
            assert np.all(y[pos_mask] > 0)
            # GELU(x) = x * Phi-like factor in (0, 1], so it never exceeds x for x > 0.
            assert np.all(y[pos_mask] <= x[pos_mask] + 1e-5)

    @given(x=st.floats(min_value=-10.0, max_value=10.0, allow_nan=False, allow_infinity=False))
    @settings(max_examples=100)
    def test_gelu_monotonicity(self, x):
        """GELU is non-decreasing on the positive axis.

        Restricted to x > 0 deliberately: exact GELU has a small dip (negative
        derivative) around x ~ -0.75, so global monotonicity does not hold.
        """
        if x > 0:
            x_arr = np.array([x, x + 0.1], dtype=np.float32)
            y_arr = gelu_tanh_approx(x_arr)
            assert y_arr[1] >= y_arr[0] - 1e-5

    @given(
        shape=st.tuples(
            st.integers(min_value=1, max_value=128), st.integers(min_value=1, max_value=128)
        )
    )
    @settings(max_examples=100)
    def test_gelu_symmetry_property(self, shape):
        """GELU satisfies the reflection identity gelu(x) - gelu(-x) == x.

        The identity is exact for the tanh form because tanh is odd and the
        inner cubic x + 0.044715*x**3 is odd, so only float32 rounding remains.
        """
        np.random.seed(42)
        x = np.random.randn(*shape).astype(np.float32)
        y_pos = gelu_tanh_approx(x)
        y_neg = gelu_tanh_approx(-x)

        assert np.all(np.isfinite(y_pos))
        assert np.all(np.isfinite(y_neg))
        # Actual symmetry assertion (previously this test only checked finiteness).
        np.testing.assert_allclose(
            y_pos - y_neg,
            x,
            rtol=1e-4,
            atol=1e-5,
            err_msg=f"gelu(x) - gelu(-x) != x for shape {shape}",
        )


@pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not installed")
class TestGeluAgainstPyTorch:
    """Cross-check the shared implementation against torch's tanh-approximate GELU."""

    @given(
        shape=st.tuples(
            st.integers(min_value=1, max_value=256), st.integers(min_value=1, max_value=256)
        )
    )
    # deadline=None: torch's first dispatch/warm-up can exceed hypothesis's
    # default 200 ms deadline and produce spurious flaky failures.
    @settings(max_examples=100, deadline=None)
    def test_matches_pytorch_gelu(self, shape):
        """Elementwise agreement with torch.nn.functional.gelu(approximate='tanh')."""
        np.random.seed(42)
        x_np = np.random.randn(*shape).astype(np.float32)
        x_torch = torch.from_numpy(x_np)
        y_torch = torch.nn.functional.gelu(x_torch, approximate="tanh").numpy()
        y_ours = gelu_tanh_approx(x_np)

        np.testing.assert_allclose(y_ours, y_torch, rtol=1e-5, atol=1e-5)


@pytest.mark.requires_ort
@pytest.mark.requires_cuda
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA path not available")
class TestGeluCustomOpIntegration:
    """Validate the real ONNX Runtime custom op path, not a PyTorch proxy."""

    def test_custom_op_matches_reference(self):
        """Build a one-op ONNX model, run it on CUDA, compare to the reference."""
        shape = [4, 128]
        # delete=False so the path can be handed to ONNX Runtime on platforms
        # (Windows) where an open NamedTemporaryFile cannot be reopened.
        with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f:
            model_path = f.name

        try:
            create_gelu_test_model(input_shape=shape, output_path=model_path)
            np.random.seed(123)
            x = np.random.randn(*shape).astype(np.float32)
            y = run_inference(model_path, x, use_cuda=True)
            y_ref = gelu_reference(x)
            np.testing.assert_allclose(y, y_ref, rtol=1e-5, atol=1e-5)
        finally:
            # Manual cleanup is required because delete=False above.
            if os.path.exists(model_path):
                os.remove(model_path)


class TestGeluEdgeCases:
    """Deterministic spot checks at the extremes of the input range."""

    def test_gelu_zero(self):
        """gelu(0) == 0 exactly (up to float noise)."""
        x = np.array([0.0], dtype=np.float32)
        y = gelu_tanh_approx(x)
        np.testing.assert_allclose(y, [0.0], atol=1e-7)

    def test_gelu_large_positive(self):
        """For large positive x the gate saturates: gelu(x) -> x."""
        x = np.array([10.0, 20.0, 50.0], dtype=np.float32)
        y = gelu_tanh_approx(x)
        np.testing.assert_allclose(y, x, rtol=1e-3)

    def test_gelu_large_negative(self):
        """For large negative x the gate closes: gelu(x) -> 0."""
        x = np.array([-10.0, -20.0, -50.0], dtype=np.float32)
        y = gelu_tanh_approx(x)
        np.testing.assert_allclose(y, [0.0, 0.0, 0.0], atol=1e-3)

    def test_gelu_small_values(self):
        """Finite outputs and bounded finite-difference slope near the origin."""
        x = np.linspace(-1, 1, 100, dtype=np.float32)
        y = gelu_tanh_approx(x)

        assert np.all(np.isfinite(y))
        # GELU's derivative peaks slightly above 1; 2.0 is a generous sanity bound.
        derivative = np.diff(y) / np.diff(x)
        assert np.all(np.abs(derivative) < 2.0)


if __name__ == "__main__":
    pytest.main([__file__, "-v"])