test_tvm_optimization.py
1 """ 2 Property-based tests for TensorIR schedule transformations. 3 4 Feature: ai-system-optimization-series, Property 1: TensorIR Schedule Transformation Correctness 5 Validates: Requirements 2.4 6 7 For any valid TensorIR schedule with applied transformations (split, reorder, bind), 8 the generated code SHALL preserve the original computation semantics—the output tensor 9 values must be identical to the unoptimized baseline within floating-point tolerance. 10 """ 11 12 import importlib.util 13 from pathlib import Path 14 15 import numpy as np 16 import pytest 17 from hypothesis import assume, given, settings 18 from hypothesis import strategies as st 19 20 # TVM imports 21 try: 22 import tvm 23 from tvm import tir 24 25 TVM_AVAILABLE = True 26 except ImportError: 27 TVM_AVAILABLE = False 28 29 pytestmark = pytest.mark.skipif(not TVM_AVAILABLE, reason="TVM not installed") 30 31 32 if TVM_AVAILABLE: 33 MANUAL_SCHEDULE_PATH = Path(__file__).resolve().parents[1] / "3_tensorir_manual_schedule.py" 34 _spec = importlib.util.spec_from_file_location("tensorir_manual_schedule", MANUAL_SCHEDULE_PATH) 35 assert _spec is not None and _spec.loader is not None 36 _tensorir_manual_schedule = importlib.util.module_from_spec(_spec) 37 _spec.loader.exec_module(_tensorir_manual_schedule) 38 39 create_matmul_module = _tensorir_manual_schedule.create_matmul_module 40 apply_schedule_primitives = _tensorir_manual_schedule.apply_schedule_primitives 41 else: 42 create_matmul_module = None 43 apply_schedule_primitives = None 44 45 46 def run_and_get_output(mod, a_np, b_np, target: str = "llvm"): 47 """Compile and run the module, return output.""" 48 m, k = a_np.shape 49 _, n = b_np.shape 50 51 target = tvm.target.Target(target) 52 with tvm.transform.PassContext(opt_level=3): 53 lib = tvm.build(mod, target=target) 54 55 dev = tvm.device(str(target), 0) 56 a_tvm = tvm.nd.array(a_np, dev) 57 b_tvm = tvm.nd.array(b_np, dev) 58 c_tvm = tvm.nd.array(np.zeros((m, n), dtype="float32"), dev) 59 60 func = lib["main"] 61 func(a_tvm, b_tvm, c_tvm) 62 63 return c_tvm.numpy() 64 65 66 class TestTensorIRScheduleCorrectness: 67 """ 68 Property 1: TensorIR Schedule Transformation Correctness 69 70 For any valid TensorIR schedule with applied transformations, 71 the output must match the unoptimized baseline within floating-point tolerance. 72 """ 73 74 @given( 75 m=st.integers(min_value=64, max_value=512), 76 n=st.integers(min_value=64, max_value=512), 77 k=st.integers(min_value=64, max_value=512), 78 ) 79 @settings(max_examples=100, deadline=60000) # 60s deadline for compilation 80 def test_schedule_preserves_semantics(self, m, n, k): 81 """ 82 Property: For any matrix dimensions, scheduled computation equals baseline. 
83 """ 84 # Ensure dimensions are multiples of 32 for clean tiling 85 m = (m // 32) * 32 86 n = (n // 32) * 32 87 k = (k // 32) * 32 88 assume(m >= 32 and n >= 32 and k >= 32) 89 90 # Create module 91 mod = create_matmul_module(m, n, k) 92 93 # Generate random input 94 np.random.seed(42) # Reproducibility 95 a_np = np.random.randn(m, k).astype("float32") 96 b_np = np.random.randn(k, n).astype("float32") 97 98 # Run baseline (unoptimized) 99 baseline_sch = tir.Schedule(mod) 100 baseline_output = run_and_get_output(baseline_sch.mod, a_np, b_np) 101 102 # Run optimized (with transformations) 103 optimized_sch = apply_schedule_primitives(mod) 104 optimized_output = run_and_get_output(optimized_sch.mod, a_np, b_np) 105 106 # Verify outputs match 107 np.testing.assert_allclose( 108 optimized_output, 109 baseline_output, 110 rtol=1e-5, 111 atol=1e-5, 112 err_msg=f"Schedule transformation changed output for M={m}, N={n}, K={k}", 113 ) 114 115 @given( 116 m=st.integers(min_value=64, max_value=512), 117 n=st.integers(min_value=64, max_value=512), 118 k=st.integers(min_value=64, max_value=512), 119 ) 120 @settings(max_examples=100, deadline=60000) 121 def test_schedule_matches_numpy(self, m, n, k): 122 """ 123 Property: Scheduled computation matches numpy reference. 124 """ 125 m = (m // 32) * 32 126 n = (n // 32) * 32 127 k = (k // 32) * 32 128 assume(m >= 32 and n >= 32 and k >= 32) 129 130 mod = create_matmul_module(m, n, k) 131 132 np.random.seed(42) 133 a_np = np.random.randn(m, k).astype("float32") 134 b_np = np.random.randn(k, n).astype("float32") 135 136 # Numpy reference 137 c_ref = np.matmul(a_np, b_np) 138 139 # TVM optimized 140 optimized_sch = apply_schedule_primitives(mod) 141 c_tvm = run_and_get_output(optimized_sch.mod, a_np, b_np) 142 143 np.testing.assert_allclose( 144 c_tvm, 145 c_ref, 146 rtol=1e-5, 147 atol=1e-5, 148 err_msg=f"TVM output differs from numpy for M={m}, N={n}, K={k}", 149 ) 150 151 @given(block_size=st.sampled_from([16, 32, 64, 128])) 152 @settings(max_examples=100, deadline=60000) 153 def test_different_block_sizes(self, block_size): 154 """ 155 Property: Different block sizes produce same results. 
156 """ 157 m, n, k = 256, 256, 256 158 mod = create_matmul_module(m, n, k) 159 160 np.random.seed(42) 161 a_np = np.random.randn(m, k).astype("float32") 162 b_np = np.random.randn(k, n).astype("float32") 163 164 # Reference 165 c_ref = np.matmul(a_np, b_np) 166 167 # TVM with specific block size 168 actual_block = min(block_size, m, n) 169 sch = apply_schedule_primitives(mod, block_size=actual_block) 170 171 c_tvm = run_and_get_output(sch.mod, a_np, b_np) 172 173 np.testing.assert_allclose( 174 c_tvm, 175 c_ref, 176 rtol=1e-5, 177 atol=1e-5, 178 err_msg=f"Block size {block_size} produces incorrect results", 179 ) 180 181 182 class TestSchedulePrimitives: 183 """Unit tests for individual schedule primitives.""" 184 185 def test_split_preserves_computation(self): 186 """Test that split alone preserves computation.""" 187 m, n, k = 128, 128, 128 188 mod = create_matmul_module(m, n, k) 189 190 np.random.seed(42) 191 a_np = np.random.randn(m, k).astype("float32") 192 b_np = np.random.randn(k, n).astype("float32") 193 c_ref = np.matmul(a_np, b_np) 194 195 sch = tir.Schedule(mod) 196 block = sch.get_block("matmul") 197 i, j, k_loop = sch.get_loops(block) 198 sch.split(i, factors=[None, 32]) 199 200 c_tvm = run_and_get_output(sch.mod, a_np, b_np) 201 np.testing.assert_allclose(c_tvm, c_ref, rtol=1e-5) 202 203 def test_reorder_preserves_computation(self): 204 """Test that reorder alone preserves computation.""" 205 m, n, k = 128, 128, 128 206 mod = create_matmul_module(m, n, k) 207 208 np.random.seed(42) 209 a_np = np.random.randn(m, k).astype("float32") 210 b_np = np.random.randn(k, n).astype("float32") 211 c_ref = np.matmul(a_np, b_np) 212 213 sch = tir.Schedule(mod) 214 block = sch.get_block("matmul") 215 i, j, k_loop = sch.get_loops(block) 216 # Reorder to j, i, k 217 sch.reorder(j, i, k_loop) 218 219 c_tvm = run_and_get_output(sch.mod, a_np, b_np) 220 np.testing.assert_allclose(c_tvm, c_ref, rtol=1e-5) 221 222 223 if __name__ == "__main__": 224 pytest.main([__file__, "-v"])