/ 01_TVM_End2End_Optimization / tests / test_tvm_optimization.py
test_tvm_optimization.py
  1  """
  2  Property-based tests for TensorIR schedule transformations.
  3  
  4  Feature: ai-system-optimization-series, Property 1: TensorIR Schedule Transformation Correctness
  5  Validates: Requirements 2.4
  6  
  7  For any valid TensorIR schedule with applied transformations (split, reorder, bind),
  8  the generated code SHALL preserve the original computation semantics—the output tensor
  9  values must be identical to the unoptimized baseline within floating-point tolerance.
 10  """
 11  
 12  import importlib.util
 13  from pathlib import Path
 14  
 15  import numpy as np
 16  import pytest
 17  from hypothesis import assume, given, settings
 18  from hypothesis import strategies as st
 19  
 20  # TVM imports
 21  try:
 22      import tvm
 23      from tvm import tir
 24  
 25      TVM_AVAILABLE = True
 26  except ImportError:
 27      TVM_AVAILABLE = False
 28  
 29  pytestmark = pytest.mark.skipif(not TVM_AVAILABLE, reason="TVM not installed")
 30  
 31  
 32  if TVM_AVAILABLE:
 33      MANUAL_SCHEDULE_PATH = Path(__file__).resolve().parents[1] / "3_tensorir_manual_schedule.py"
 34      _spec = importlib.util.spec_from_file_location("tensorir_manual_schedule", MANUAL_SCHEDULE_PATH)
 35      assert _spec is not None and _spec.loader is not None
 36      _tensorir_manual_schedule = importlib.util.module_from_spec(_spec)
 37      _spec.loader.exec_module(_tensorir_manual_schedule)
 38  
 39      create_matmul_module = _tensorir_manual_schedule.create_matmul_module
 40      apply_schedule_primitives = _tensorir_manual_schedule.apply_schedule_primitives
 41  else:
 42      create_matmul_module = None
 43      apply_schedule_primitives = None
 44  
 45  
 46  def run_and_get_output(mod, a_np, b_np, target: str = "llvm"):
 47      """Compile and run the module, return output."""
 48      m, k = a_np.shape
 49      _, n = b_np.shape
 50  
 51      target = tvm.target.Target(target)
 52      with tvm.transform.PassContext(opt_level=3):
 53          lib = tvm.build(mod, target=target)
 54  
 55      dev = tvm.device(str(target), 0)
 56      a_tvm = tvm.nd.array(a_np, dev)
 57      b_tvm = tvm.nd.array(b_np, dev)
 58      c_tvm = tvm.nd.array(np.zeros((m, n), dtype="float32"), dev)
 59  
 60      func = lib["main"]
 61      func(a_tvm, b_tvm, c_tvm)
 62  
 63      return c_tvm.numpy()
 64  
 65  
 66  class TestTensorIRScheduleCorrectness:
 67      """
 68      Property 1: TensorIR Schedule Transformation Correctness
 69  
 70      For any valid TensorIR schedule with applied transformations,
 71      the output must match the unoptimized baseline within floating-point tolerance.
 72      """
 73  
 74      @given(
 75          m=st.integers(min_value=64, max_value=512),
 76          n=st.integers(min_value=64, max_value=512),
 77          k=st.integers(min_value=64, max_value=512),
 78      )
 79      @settings(max_examples=100, deadline=60000)  # 60s deadline for compilation
 80      def test_schedule_preserves_semantics(self, m, n, k):
 81          """
 82          Property: For any matrix dimensions, scheduled computation equals baseline.
 83          """
 84          # Ensure dimensions are multiples of 32 for clean tiling
 85          m = (m // 32) * 32
 86          n = (n // 32) * 32
 87          k = (k // 32) * 32
 88          assume(m >= 32 and n >= 32 and k >= 32)
 89  
 90          # Create module
 91          mod = create_matmul_module(m, n, k)
 92  
 93          # Generate random input
 94          np.random.seed(42)  # Reproducibility
 95          a_np = np.random.randn(m, k).astype("float32")
 96          b_np = np.random.randn(k, n).astype("float32")
 97  
 98          # Run baseline (unoptimized)
 99          baseline_sch = tir.Schedule(mod)
100          baseline_output = run_and_get_output(baseline_sch.mod, a_np, b_np)
101  
102          # Run optimized (with transformations)
103          optimized_sch = apply_schedule_primitives(mod)
104          optimized_output = run_and_get_output(optimized_sch.mod, a_np, b_np)
105  
106          # Verify outputs match
107          np.testing.assert_allclose(
108              optimized_output,
109              baseline_output,
110              rtol=1e-5,
111              atol=1e-5,
112              err_msg=f"Schedule transformation changed output for M={m}, N={n}, K={k}",
113          )
114  
115      @given(
116          m=st.integers(min_value=64, max_value=512),
117          n=st.integers(min_value=64, max_value=512),
118          k=st.integers(min_value=64, max_value=512),
119      )
120      @settings(max_examples=100, deadline=60000)
121      def test_schedule_matches_numpy(self, m, n, k):
122          """
123          Property: Scheduled computation matches numpy reference.
124          """
125          m = (m // 32) * 32
126          n = (n // 32) * 32
127          k = (k // 32) * 32
128          assume(m >= 32 and n >= 32 and k >= 32)
129  
130          mod = create_matmul_module(m, n, k)
131  
132          np.random.seed(42)
133          a_np = np.random.randn(m, k).astype("float32")
134          b_np = np.random.randn(k, n).astype("float32")
135  
136          # Numpy reference
137          c_ref = np.matmul(a_np, b_np)
138  
139          # TVM optimized
140          optimized_sch = apply_schedule_primitives(mod)
141          c_tvm = run_and_get_output(optimized_sch.mod, a_np, b_np)
142  
143          np.testing.assert_allclose(
144              c_tvm,
145              c_ref,
146              rtol=1e-5,
147              atol=1e-5,
148              err_msg=f"TVM output differs from numpy for M={m}, N={n}, K={k}",
149          )
150  
151      @given(block_size=st.sampled_from([16, 32, 64, 128]))
152      @settings(max_examples=100, deadline=60000)
153      def test_different_block_sizes(self, block_size):
154          """
155          Property: Different block sizes produce same results.
156          """
157          m, n, k = 256, 256, 256
158          mod = create_matmul_module(m, n, k)
159  
160          np.random.seed(42)
161          a_np = np.random.randn(m, k).astype("float32")
162          b_np = np.random.randn(k, n).astype("float32")
163  
164          # Reference
165          c_ref = np.matmul(a_np, b_np)
166  
167          # TVM with specific block size
168          actual_block = min(block_size, m, n)
169          sch = apply_schedule_primitives(mod, block_size=actual_block)
170  
171          c_tvm = run_and_get_output(sch.mod, a_np, b_np)
172  
173          np.testing.assert_allclose(
174              c_tvm,
175              c_ref,
176              rtol=1e-5,
177              atol=1e-5,
178              err_msg=f"Block size {block_size} produces incorrect results",
179          )
180  
181  
class TestSchedulePrimitives:
    """Unit tests for individual schedule primitives."""

    # An absolute tolerance is required: with randn inputs some entries of C are close
    # to zero, and a purely relative check fails there on harmless accumulation-order
    # differences between TVM and numpy's matmul.
    RTOL = 1e-4
    ATOL = 1e-4

    @staticmethod
    def _matmul_case(m=128, n=128, k=128):
        """Return (module, A, B, numpy reference C) for an m x n x k float32 matmul."""
        mod = create_matmul_module(m, n, k)
        # Local generator instead of ``np.random.seed`` so global RNG state is untouched.
        rng = np.random.default_rng(42)
        a_np = rng.standard_normal((m, k), dtype=np.float32)
        b_np = rng.standard_normal((k, n), dtype=np.float32)
        return mod, a_np, b_np, np.matmul(a_np, b_np)

    def test_split_preserves_computation(self):
        """Test that split alone preserves computation."""
        mod, a_np, b_np, c_ref = self._matmul_case()

        sch = tir.Schedule(mod)
        block = sch.get_block("matmul")
        i, _, _ = sch.get_loops(block)
        sch.split(i, factors=[None, 32])
        # Splitting one of the block's three loops must leave four loops around it.
        assert len(sch.get_loops(block)) == 4

        c_tvm = run_and_get_output(sch.mod, a_np, b_np)
        np.testing.assert_allclose(c_tvm, c_ref, rtol=self.RTOL, atol=self.ATOL)

    def test_reorder_preserves_computation(self):
        """Test that reorder alone preserves computation."""
        mod, a_np, b_np, c_ref = self._matmul_case()

        sch = tir.Schedule(mod)
        block = sch.get_block("matmul")
        i, j, k_loop = sch.get_loops(block)
        # Reorder to j, i, k
        sch.reorder(j, i, k_loop)

        c_tvm = run_and_get_output(sch.mod, a_np, b_np)
        np.testing.assert_allclose(c_tvm, c_ref, rtol=self.RTOL, atol=self.ATOL)
221  
222  
if __name__ == "__main__":
    # Propagate pytest's exit code so running this file as a script fails on test failures.
    raise SystemExit(pytest.main([__file__, "-v"]))