# test_tile_gemm.py
"""
Property-based tests for cuTile GEMM conceptual implementation.

Feature: ai-system-optimization-series, Property 3: cuTile GEMM Correctness
Validates: Requirements 5.1, 5.2

For any input matrices of compatible dimensions, the cuTile GEMM
implementation SHALL produce outputs matching numpy matrix multiplication
within floating-point tolerance.

"Within floating-point tolerance" is made concrete below: results are
compared against a float64 reference using an elementwise bound derived from
float32 dot-product error analysis, so the tolerance scales with K and with
the magnitudes being summed rather than being a fixed constant.
"""

import numpy as np
import pytest
from cutile_cuda.tile_gemm import (
    TileArray,
    TileShape,
    tiled_gemm,
    tiled_gemm_with_epilogue,
)
from hypothesis import given, settings
from hypothesis import strategies as st

# ---------------------------------------------------------------------------
# Strategies
# ---------------------------------------------------------------------------

dim_strategy = st.integers(min_value=1, max_value=256)

# Seed for a per-example numpy Generator. Drawing it from Hypothesis gives each
# example different matrix contents while keeping failures reproducible.
seed_strategy = st.integers(min_value=0, max_value=2**32 - 1)

scalar_strategy = st.floats(
    min_value=-10.0, max_value=10.0, allow_nan=False, allow_infinity=False
)

# ---------------------------------------------------------------------------
# Error-bound helpers
# ---------------------------------------------------------------------------

# float32 machine epsilon (2**-23); the unit roundoff u is _EPS32 / 2.
_EPS32 = float(np.finfo(np.float32).eps)

# Absolute slack for values that pass through the float32 subnormal range
# (e.g. Hypothesis may draw alpha ~ 1e-40, which becomes a subnormal or 0.0
# once cast to float32); a purely relative bound cannot cover that case.
_ABS_FLOOR = 16 * float(np.finfo(np.float32).tiny)


def _gemm_error_bound(k, magnitude, extra_roundings=0):
    """Return an elementwise absolute error bound for a float32 GEMM result.

    A length-``k`` dot product evaluated in float32 with any summation order
    (including tile-by-tile accumulation) has forward error of roughly at
    most ``k * u * (|A| @ |B|)`` (Higham, *Accuracy and Stability of
    Numerical Algorithms*, section 3.1). The bound returned here is about 4x
    that (``2 * eps == 4 * u``) for headroom. Real tiling or indexing bugs
    produce O(1) errors and are still caught.

    Args:
        k: Inner (reduction) dimension of the GEMM.
        magnitude: float64 array of summed absolute terms, e.g.
            ``np.abs(A) @ np.abs(B)``.
        extra_roundings: Additional float32 operations applied per output
            element (e.g. the alpha/beta epilogue).

    Returns:
        float64 array with the same shape as ``magnitude``.
    """
    return 2 * (k + 1 + extra_roundings) * _EPS32 * magnitude + _ABS_FLOOR


def _assert_within_bound(actual, expected, bound, err_msg):
    """Assert ``|actual - expected| <= bound`` elementwise.

    Args:
        actual: Result produced by the implementation under test.
        expected: float64 reference array.
        bound: float64 array of per-element absolute tolerances.
        err_msg: Context prefix for the failure message.

    Raises:
        AssertionError: on shape mismatch, or if any element exceeds its
            bound. NaN in ``actual`` counts as a failure because
            ``NaN <= x`` is False.
    """
    actual = np.asarray(actual, dtype=np.float64)
    assert actual.shape == expected.shape, (
        f"{err_msg}: shape {actual.shape} != expected {expected.shape}"
    )
    err = np.abs(actual - expected)
    ok = err <= bound
    if not np.all(ok):
        bad = ~ok
        # Report the first offending element to make failures diagnosable.
        idx = tuple(int(i) for i in np.argwhere(bad)[0])
        raise AssertionError(
            f"{err_msg}: {np.count_nonzero(bad)} of {err.size} element(s) "
            f"exceed the error bound; first at {idx}: "
            f"actual={actual[idx]:.9g}, expected={expected[idx]:.9g}, "
            f"|err|={err[idx]:.3e} > bound={bound[idx]:.3e}"
        )


class TestTiledGemmCorrectness:
    """
    Property 3: cuTile GEMM output matches numpy reference.
    """

    @given(
        m=dim_strategy,
        n=dim_strategy,
        k=dim_strategy,
        data_seed=seed_strategy,
    )
    @settings(max_examples=100, deadline=10000)
    def test_tiled_gemm_matches_numpy(self, m, n, k, data_seed):
        """Output of tiled_gemm must equal numpy matmul within tolerance."""
        # Local Generator: no mutation of numpy's global RNG state.
        rng = np.random.default_rng(data_seed)
        a_np = rng.standard_normal((m, k), dtype=np.float32)
        b_np = rng.standard_normal((k, n), dtype=np.float32)

        C = tiled_gemm(TileArray(a_np), TileArray(b_np))

        # float64 reference of the exact float32 inputs is effectively exact
        # relative to float32 rounding, so only the implementation's own
        # rounding is measured.
        a64 = a_np.astype(np.float64)
        b64 = b_np.astype(np.float64)
        c_ref = a64 @ b64
        bound = _gemm_error_bound(k, np.abs(a64) @ np.abs(b64))

        _assert_within_bound(
            C.data,
            c_ref,
            bound,
            err_msg=f"tiled_gemm mismatch for M={m}, N={n}, K={k}, seed={data_seed}",
        )

    @given(
        m=dim_strategy,
        n=dim_strategy,
        k=dim_strategy,
        alpha=scalar_strategy,
        beta=scalar_strategy,
        data_seed=seed_strategy,
    )
    @settings(max_examples=100, deadline=10000)
    def test_tiled_gemm_with_epilogue(self, m, n, k, alpha, beta, data_seed):
        """D = alpha * A @ B + beta * C must match numpy reference."""
        rng = np.random.default_rng(data_seed)
        a_np = rng.standard_normal((m, k), dtype=np.float32)
        b_np = rng.standard_normal((k, n), dtype=np.float32)
        c_np = rng.standard_normal((m, n), dtype=np.float32)

        D = tiled_gemm_with_epilogue(
            TileArray(a_np), TileArray(b_np), TileArray(c_np), alpha=alpha, beta=beta
        )

        a64 = a_np.astype(np.float64)
        b64 = b_np.astype(np.float64)
        c64 = c_np.astype(np.float64)
        d_ref = alpha * (a64 @ b64) + beta * c64
        # Magnitude covers both epilogue terms; the extra roundings cover
        # casting alpha/beta to float32, the two scalings and the final add.
        magnitude = abs(alpha) * (np.abs(a64) @ np.abs(b64)) + abs(beta) * np.abs(c64)
        bound = _gemm_error_bound(k, magnitude, extra_roundings=3)

        _assert_within_bound(
            D.data,
            d_ref,
            bound,
            err_msg=(
                f"epilogue mismatch for M={m}, N={n}, K={k}, "
                f"alpha={alpha}, beta={beta}, seed={data_seed}"
            ),
        )


class TestTileArray:
    """Unit tests for TileArray operations."""

    def test_tile_array_creation(self):
        """Constructor keeps the data shape and the given tile shape."""
        data = np.ones((4, 4), dtype=np.float32)
        ta = TileArray(data, TileShape(2, 2))
        assert ta.shape == (4, 4)
        assert ta.tile_shape.m == 2

    def test_tile_array_add(self):
        """``+`` adds elementwise."""
        a = TileArray(np.array([[1.0, 2.0]], dtype=np.float32))
        b = TileArray(np.array([[3.0, 4.0]], dtype=np.float32))
        c = a + b
        np.testing.assert_array_equal(c.data, [[4.0, 6.0]])

    def test_tile_array_mul(self):
        """``*`` by a scalar scales every element."""
        a = TileArray(np.array([[2.0, 3.0]], dtype=np.float32))
        c = a * 2
        np.testing.assert_array_equal(c.data, [[4.0, 6.0]])

    def test_tile_array_matmul(self):
        """``@`` with an identity left operand returns the right operand."""
        a = TileArray(np.eye(3, dtype=np.float32))
        b = TileArray(np.ones((3, 2), dtype=np.float32))
        c = a @ b
        np.testing.assert_array_equal(c.data, np.ones((3, 2)))

    def test_tile_kernel_decorator_preserves_metadata(self):
        """The tile-kernel decorator keeps __name__/__doc__ and sets its marker."""
        assert tiled_gemm.__name__ == "tiled_gemm"
        assert tiled_gemm.__doc__ is not None
        assert hasattr(tiled_gemm, "_is_tile_kernel")
        assert tiled_gemm._is_tile_kernel is True


if __name__ == "__main__":
    # Propagate pytest's exit status so scripted runs fail on test failures.
    raise SystemExit(pytest.main([__file__, "-v"]))