# test_inference.py
"""
Inference tests for custom ONNX Runtime operators.

Requirements: 3.5
"""

import os
import sys

import numpy as np
import pytest

try:
    import onnx

    ONNX_AVAILABLE = True
except ImportError:
    ONNX_AVAILABLE = False

try:
    import onnxruntime as ort

    ORT_AVAILABLE = True
except ImportError:
    ORT_AVAILABLE = False

# Make the directory one level above this file's directory importable so that
# the ``python.custom_ops`` package resolves. ``abspath`` keeps this correct
# when ``__file__`` is relative (e.g. when the file is executed as a script).
_MODULE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _MODULE_DIR not in sys.path:
    sys.path.insert(0, _MODULE_DIR)

from python.custom_ops import (  # noqa: E402  (must follow the sys.path setup)
    CUSTOM_OP_DOMAIN,
    assert_custom_op_is_active,
    create_gelu_test_model,
    create_session,
    create_temp_gelu_test_model,
    gelu_reference,
    get_library_path,
    get_registered_library_path,
    has_cuda_execution_provider,
    is_custom_op_library_available,
    run_inference,
)

pytestmark = [pytest.mark.requires_ort]

# Evaluated once at import time; used by the skip conditions below.
CUDA_AVAILABLE = has_cuda_execution_provider()
CUSTOM_OP_LIBRARY_AVAILABLE = is_custom_op_library_available()

# Fixed seed so that a numeric mismatch can be reproduced on a re-run.
_RNG_SEED = 42


def _random_input(*shape: int, seed: int = _RNG_SEED) -> np.ndarray:
    """Return a float32 array of standard-normal samples with the given shape.

    A local, seeded generator is used so the process-wide NumPy RNG state is
    left untouched and every call with the same arguments yields the same data.
    """
    rng = np.random.default_rng(seed)
    return rng.standard_normal(shape).astype(np.float32)


def _run_and_compare(model_path: str, input_data: np.ndarray) -> None:
    """Run the model at *model_path* and compare against ``gelu_reference``.

    Inference is requested with ``use_cuda=True``; the result must match the
    reference output within ``rtol=1e-5`` / ``atol=1e-5``.

    Raises:
        AssertionError: if the outputs differ beyond the tolerances.
    """
    output = run_inference(model_path, input_data, use_cuda=True)
    expected = gelu_reference(input_data)
    np.testing.assert_allclose(output, expected, rtol=1e-5, atol=1e-5)


@pytest.mark.skipif(not (ONNX_AVAILABLE and ORT_AVAILABLE), reason="ONNX and ONNXRuntime required")
class TestGeluModelCreation:
    """Test GELU model creation."""

    def test_create_model_default_shape(self):
        """The default model has a single ``CustomGelu`` node in the custom domain."""
        model = create_gelu_test_model()

        assert model is not None
        assert len(model.graph.node) == 1
        assert model.graph.node[0].op_type == "CustomGelu"
        assert model.graph.node[0].domain == CUSTOM_OP_DOMAIN

    def test_create_model_custom_shape(self):
        """The requested input shape is recorded on the first graph input."""
        shape = [2, 128, 64]
        model = create_gelu_test_model(input_shape=shape)

        input_info = model.graph.input[0]
        dims = [d.dim_value for d in input_info.type.tensor_type.shape.dim]
        assert dims == shape

    def test_save_model(self):
        """A temporary model file can be loaded back with ``onnx.load``."""
        model_path, cleanup = create_temp_gelu_test_model([1, 256])

        try:
            loaded = onnx.load(model_path)
            assert loaded.graph.node[0].op_type == "CustomGelu"
        finally:
            cleanup()


class TestGeluReference:
    """Test reference GELU implementation."""

    def test_gelu_zero(self):
        """GELU(0) is 0."""
        x = np.array([0.0], dtype=np.float32)
        y = gelu_reference(x)
        np.testing.assert_allclose(y, [0.0], atol=1e-7)

    def test_gelu_positive(self):
        """For positive inputs the output is positive and never exceeds the input."""
        x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
        y = gelu_reference(x)

        assert np.all(y > 0)
        assert np.all(y <= x)

    def test_gelu_negative(self):
        """The output for the most negative input (-3.0) stays above -0.1."""
        x = np.array([-1.0, -2.0, -3.0], dtype=np.float32)
        y = gelu_reference(x)

        assert y[-1] > -0.1

    def test_gelu_batch(self):
        """A 2-D batch keeps its shape and produces only finite values."""
        x = _random_input(4, 128)
        y = gelu_reference(x)

        assert y.shape == x.shape
        assert np.all(np.isfinite(y))


@pytest.mark.requires_cuda
# ``ort`` is referenced directly below; without this guard an unbound ``ort``
# would raise NameError, which ``pytest.raises(Exception)`` would accept.
@pytest.mark.skipif(not ORT_AVAILABLE, reason="ONNXRuntime required")
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDAExecutionProvider not available")
@pytest.mark.skipif(not CUSTOM_OP_LIBRARY_AVAILABLE, reason="Custom op library not built")
class TestGeluInference:
    """Test GELU inference with the real custom op library."""

    def test_registers_expected_library(self):
        """The library path reported for the model equals ``get_library_path()``."""
        model_path, cleanup = create_temp_gelu_test_model([1, 8])

        try:
            library_path = get_registered_library_path(model_path, use_cuda=True)
            assert os.path.abspath(library_path) == os.path.abspath(get_library_path())
        finally:
            cleanup()

    def test_session_uses_cuda_execution_provider(self):
        """A session from ``create_session`` passes the custom-op activity check."""
        model_path, cleanup = create_temp_gelu_test_model([1, 16])

        try:
            session, library_path = create_session(model_path, use_cuda=True)
            assert_custom_op_is_active(session, use_cuda=True)
            assert os.path.abspath(library_path) == os.path.abspath(get_library_path())
        finally:
            cleanup()

    def test_inference_basic(self):
        """Inference on a ``[1, 64]`` input matches the reference GELU."""
        model_path, cleanup = create_temp_gelu_test_model([1, 64])

        try:
            input_data = _random_input(1, 64)
            _run_and_compare(model_path, input_data)
        finally:
            cleanup()

    def test_inference_large_batch(self):
        """Inference on a ``[32, 1024]`` input matches the reference GELU."""
        model_path, cleanup = create_temp_gelu_test_model([32, 1024])

        try:
            input_data = _random_input(32, 1024)
            _run_and_compare(model_path, input_data)
        finally:
            cleanup()

    def test_missing_custom_op_registration_breaks_session_creation(self):
        """Creating a plain session without registering the library must fail."""
        model_path, cleanup = create_temp_gelu_test_model([1, 4])

        try:
            # NOTE(review): the concrete exception class raised by ONNX Runtime
            # is not pinned here; narrow this once it has been confirmed.
            with pytest.raises(Exception):
                ort.InferenceSession(model_path, providers=["CUDAExecutionProvider"])
        finally:
            cleanup()


if __name__ == "__main__":
    # Propagate pytest's exit code so a failing run is visible to the caller.
    sys.exit(pytest.main([__file__, "-v"]))