/ 04_CuTile_NextGen_CUDA / tests / test_tile_gemm.py
test_tile_gemm.py
  1  """
  2  Property-based tests for cuTile GEMM conceptual implementation.
  3  
  4  Feature: ai-system-optimization-series, Property 3: cuTile GEMM Correctness
  5  Validates: Requirements 5.1, 5.2
  6  
  7  For any input matrices of compatible dimensions, the cuTile GEMM
  8  implementation SHALL produce outputs matching numpy matrix multiplication
  9  within floating-point tolerance.
 10  """
 11  
 12  import numpy as np
 13  import pytest
 14  from cutile_cuda.tile_gemm import (
 15      TileArray,
 16      TileShape,
 17      tiled_gemm,
 18      tiled_gemm_with_epilogue,
 19  )
 20  from hypothesis import given, settings
 21  from hypothesis import strategies as st
 22  
 23  # ---------------------------------------------------------------------------
 24  # Strategies
 25  # ---------------------------------------------------------------------------
 26  
 27  dim_strategy = st.integers(min_value=1, max_value=256)
 28  
 29  
 30  class TestTiledGemmCorrectness:
 31      """
 32      Property 3: cuTile GEMM output matches numpy reference.
 33      """
 34  
 35      @given(
 36          m=dim_strategy,
 37          n=dim_strategy,
 38          k=dim_strategy,
 39      )
 40      @settings(max_examples=100, deadline=10000)
 41      def test_tiled_gemm_matches_numpy(self, m, n, k):
 42          """Output of tiled_gemm must equal numpy matmul within tolerance."""
 43          np.random.seed(42)
 44          a_np = np.random.randn(m, k).astype(np.float32)
 45          b_np = np.random.randn(k, n).astype(np.float32)
 46  
 47          A = TileArray(a_np)
 48          B = TileArray(b_np)
 49          C = tiled_gemm(A, B)
 50  
 51          c_ref = a_np @ b_np
 52          np.testing.assert_allclose(
 53              C.data,
 54              c_ref,
 55              rtol=1e-5,
 56              atol=1e-5,
 57              err_msg=f"tiled_gemm mismatch for M={m}, N={n}, K={k}",
 58          )
 59  
 60      @given(
 61          m=dim_strategy,
 62          n=dim_strategy,
 63          k=dim_strategy,
 64          alpha=st.floats(min_value=-10.0, max_value=10.0, allow_nan=False, allow_infinity=False),
 65          beta=st.floats(min_value=-10.0, max_value=10.0, allow_nan=False, allow_infinity=False),
 66      )
 67      @settings(max_examples=100, deadline=10000)
 68      def test_tiled_gemm_with_epilogue(self, m, n, k, alpha, beta):
 69          """D = alpha * A @ B + beta * C must match numpy reference."""
 70          np.random.seed(42)
 71          a_np = np.random.randn(m, k).astype(np.float32)
 72          b_np = np.random.randn(k, n).astype(np.float32)
 73          c_np = np.random.randn(m, n).astype(np.float32)
 74  
 75          A = TileArray(a_np)
 76          B = TileArray(b_np)
 77          C = TileArray(c_np)
 78  
 79          D = tiled_gemm_with_epilogue(A, B, C, alpha=alpha, beta=beta)
 80  
 81          d_ref = alpha * (a_np @ b_np) + beta * c_np
 82          np.testing.assert_allclose(
 83              D.data,
 84              d_ref,
 85              rtol=1e-4,
 86              atol=1e-4,
 87              err_msg=f"epilogue mismatch for M={m}, N={n}, K={k}, alpha={alpha}, beta={beta}",
 88          )
 89  
 90  
 91  class TestTileArray:
 92      """Unit tests for TileArray operations."""
 93  
 94      def test_tile_array_creation(self):
 95          data = np.ones((4, 4), dtype=np.float32)
 96          ta = TileArray(data, TileShape(2, 2))
 97          assert ta.shape == (4, 4)
 98          assert ta.tile_shape.m == 2
 99  
100      def test_tile_array_add(self):
101          a = TileArray(np.array([[1.0, 2.0]], dtype=np.float32))
102          b = TileArray(np.array([[3.0, 4.0]], dtype=np.float32))
103          c = a + b
104          np.testing.assert_array_equal(c.data, [[4.0, 6.0]])
105  
106      def test_tile_array_mul(self):
107          a = TileArray(np.array([[2.0, 3.0]], dtype=np.float32))
108          c = a * 2
109          np.testing.assert_array_equal(c.data, [[4.0, 6.0]])
110  
111      def test_tile_array_matmul(self):
112          a = TileArray(np.eye(3, dtype=np.float32))
113          b = TileArray(np.ones((3, 2), dtype=np.float32))
114          c = a @ b
115          np.testing.assert_array_equal(c.data, np.ones((3, 2)))
116  
117      def test_tile_kernel_decorator_preserves_metadata(self):
118          assert tiled_gemm.__name__ == "tiled_gemm"
119          assert tiled_gemm.__doc__ is not None
120          assert hasattr(tiled_gemm, "_is_tile_kernel")
121          assert tiled_gemm._is_tile_kernel is True
122  
123  
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file as a script
    # reports test failures to the caller (shell, CI) instead of exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))