/ 02_ORT_Custom_CUDA_Op / tests / test_inference.py
test_inference.py
  1  """
  2  Inference tests for custom ONNX Runtime operators.
  3  
  4  Requirements: 3.5
  5  """
  6  
  7  import os
  8  import sys
  9  
 10  import numpy as np
 11  import pytest
 12  
 13  try:
 14      import onnx
 15  
 16      ONNX_AVAILABLE = True
 17  except ImportError:
 18      ONNX_AVAILABLE = False
 19  
 20  try:
 21      import onnxruntime as ort
 22  
 23      ORT_AVAILABLE = True
 24  except ImportError:
 25      ORT_AVAILABLE = False
 26  
 27  _MODULE_DIR = os.path.dirname(os.path.dirname(__file__))
 28  if _MODULE_DIR not in sys.path:
 29      sys.path.insert(0, _MODULE_DIR)
 30  
 31  from python.custom_ops import (
 32      CUSTOM_OP_DOMAIN,
 33      assert_custom_op_is_active,
 34      create_gelu_test_model,
 35      create_session,
 36      create_temp_gelu_test_model,
 37      gelu_reference,
 38      get_library_path,
 39      get_registered_library_path,
 40      has_cuda_execution_provider,
 41      is_custom_op_library_available,
 42      run_inference,
 43  )
 44  
# Module-wide marker: every test collected from this file is tagged `requires_ort`.
pytestmark = [pytest.mark.requires_ort]

# Evaluated once at import time so the class-level skipif markers below can reference them.
CUDA_AVAILABLE = has_cuda_execution_provider()
CUSTOM_OP_LIBRARY_AVAILABLE = is_custom_op_library_available()
 49  
 50  
 51  def _run_and_compare(model_path: str, input_data: np.ndarray) -> None:
 52      output = run_inference(model_path, input_data, use_cuda=True)
 53      expected = gelu_reference(input_data)
 54      np.testing.assert_allclose(output, expected, rtol=1e-5, atol=1e-5)
 55  
 56  
 57  @pytest.mark.skipif(not (ONNX_AVAILABLE and ORT_AVAILABLE), reason="ONNX and ONNXRuntime required")
 58  class TestGeluModelCreation:
 59      """Test GELU model creation."""
 60  
 61      def test_create_model_default_shape(self):
 62          model = create_gelu_test_model()
 63  
 64          assert model is not None
 65          assert len(model.graph.node) == 1
 66          assert model.graph.node[0].op_type == "CustomGelu"
 67          assert model.graph.node[0].domain == CUSTOM_OP_DOMAIN
 68  
 69      def test_create_model_custom_shape(self):
 70          shape = [2, 128, 64]
 71          model = create_gelu_test_model(input_shape=shape)
 72  
 73          input_info = model.graph.input[0]
 74          dims = [d.dim_value for d in input_info.type.tensor_type.shape.dim]
 75          assert dims == shape
 76  
 77      def test_save_model(self):
 78          model_path, cleanup = create_temp_gelu_test_model([1, 256])
 79  
 80          try:
 81              loaded = onnx.load(model_path)
 82              assert loaded.graph.node[0].op_type == "CustomGelu"
 83          finally:
 84              cleanup()
 85  
 86  
 87  class TestGeluReference:
 88      """Test reference GELU implementation."""
 89  
 90      def test_gelu_zero(self):
 91          x = np.array([0.0], dtype=np.float32)
 92          y = gelu_reference(x)
 93          np.testing.assert_allclose(y, [0.0], atol=1e-7)
 94  
 95      def test_gelu_positive(self):
 96          x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
 97          y = gelu_reference(x)
 98  
 99          assert np.all(y > 0)
100          assert np.all(y <= x)
101  
102      def test_gelu_negative(self):
103          x = np.array([-1.0, -2.0, -3.0], dtype=np.float32)
104          y = gelu_reference(x)
105  
106          assert y[-1] > -0.1
107  
108      def test_gelu_batch(self):
109          np.random.seed(42)
110          x = np.random.randn(4, 128).astype(np.float32)
111          y = gelu_reference(x)
112  
113          assert y.shape == x.shape
114          assert np.all(np.isfinite(y))
115  
116  
@pytest.mark.requires_cuda
@pytest.mark.skipif(not ORT_AVAILABLE, reason="ONNXRuntime required")
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDAExecutionProvider not available")
@pytest.mark.skipif(not CUSTOM_OP_LIBRARY_AVAILABLE, reason="Custom op library not built")
class TestGeluInference:
    """Test GELU inference with the real custom op library.

    The ORT_AVAILABLE guard matters because the negative test calls
    ``ort.InferenceSession`` directly; without it a missing onnxruntime would
    surface as a NameError instead of a skip.
    """

    @staticmethod
    def _random_input(shape, seed=0):
        """Return a reproducible float32 standard-normal array of the given shape."""
        # Seeded so that a tolerance failure can be replayed with identical data.
        return np.random.default_rng(seed).standard_normal(shape, dtype=np.float32)

    def test_registers_expected_library(self):
        """The session registers the library at ``get_library_path()``."""
        model_path, cleanup = create_temp_gelu_test_model([1, 8])

        try:
            library_path = get_registered_library_path(model_path, use_cuda=True)
            assert os.path.abspath(library_path) == os.path.abspath(get_library_path())
        finally:
            cleanup()

    def test_session_uses_cuda_execution_provider(self):
        """A CUDA session passes ``assert_custom_op_is_active`` and uses the expected library."""
        model_path, cleanup = create_temp_gelu_test_model([1, 16])

        try:
            session, library_path = create_session(model_path, use_cuda=True)
            assert_custom_op_is_active(session, use_cuda=True)
            assert os.path.abspath(library_path) == os.path.abspath(get_library_path())
        finally:
            cleanup()

    def test_inference_basic(self):
        """A small (1, 64) input matches the NumPy reference."""
        model_path, cleanup = create_temp_gelu_test_model([1, 64])

        try:
            input_data = self._random_input((1, 64), seed=0)
            _run_and_compare(model_path, input_data)
        finally:
            cleanup()

    def test_inference_large_batch(self):
        """A larger (32, 1024) input matches the NumPy reference."""
        model_path, cleanup = create_temp_gelu_test_model([32, 1024])

        try:
            input_data = self._random_input((32, 1024), seed=1)
            _run_and_compare(model_path, input_data)
        finally:
            cleanup()

    def test_missing_custom_op_registration_breaks_session_creation(self):
        """Without registering the custom op library, session creation fails."""
        model_path, cleanup = create_temp_gelu_test_model([1, 4])

        try:
            # onnxruntime surfaces graph/registration failures as its own
            # pybind-registered exception types, so a broad Exception is matched here.
            with pytest.raises(Exception):
                ort.InferenceSession(model_path, providers=["CUDAExecutionProvider"])
        finally:
            cleanup()
166          finally:
167              cleanup()
168  
169  
if __name__ == "__main__":
    # pytest.main returns the exit code instead of exiting; propagate it so a
    # direct `python test_inference.py` run reports failures to the shell/CI.
    sys.exit(pytest.main([__file__, "-v"]))