inference/sglang/sgl-kernel/tests/test_cublas_grouped_gemm.py

50 lines
1.7 KiB
Python

import unittest
import torch
from sgl_kernel import cublas_grouped_gemm
def torch_grouped_gemm(a_array, b_array, out_dtype):
    """Reference grouped GEMM built from one ``torch.matmul`` per group.

    Pairs ``a`` and ``b`` in lockstep from *a_array* and *b_array* (stopping
    at the shorter list, as ``zip`` does), computes ``a @ b.t()`` for each
    pair and casts the product to *out_dtype*.

    Returns:
        A new list holding one output tensor per pair.
    """
    return [
        torch.matmul(lhs, rhs.t()).to(out_dtype)
        for lhs, rhs in zip(a_array, b_array)
    ]
def parse_cuda_version(version):
    """Parse a dotted CUDA version string into a tuple of ints.

    ``"12.5"`` becomes ``(12, 5)``. ``None`` is passed through unchanged;
    that is the value of ``torch.version.cuda`` on PyTorch builds without
    CUDA (e.g. CPU-only or ROCm builds).
    """
    if version is None:
        return None
    return tuple(int(part) for part in version.split("."))


# The tests only run on a CUDA device with CUDA >= 12.5 (the same gate the
# ``__main__`` guard applies). Declaring it as a skip condition also covers
# pytest / ``python -m unittest`` discovery, which never execute the
# ``__main__`` guard and would otherwise fail on unsupported machines.
_CUDA_VERSION = parse_cuda_version(torch.version.cuda)
_MIN_CUDA_VERSION = (12, 5)


@unittest.skipUnless(
    torch.cuda.is_available()
    and _CUDA_VERSION is not None
    and _CUDA_VERSION >= _MIN_CUDA_VERSION,
    "requires a CUDA device and CUDA >= 12.5",
)
class TestGroupedGemm(unittest.TestCase):
    """Check ``cublas_grouped_gemm`` against a per-group ``torch.matmul`` reference."""

    def _test_accuracy(self, Ms, Ns, Ks, out_dtype):
        """Run one grouped GEMM and compare every group with the reference.

        Group ``i`` multiplies an ``(Ms[i], Ks[i])`` matrix ``A`` by the
        transpose of an ``(Ns[i], Ks[i])`` matrix ``B``. The result is written
        by ``cublas_grouped_gemm`` into a preallocated ``(Ms[i], Ns[i])``
        tensor.

        Args:
            Ms, Ns, Ks: per-group GEMM dimensions; must have equal lengths.
            out_dtype: dtype used for the inputs and the outputs.
        """
        # zip() below would silently drop trailing groups on a length mismatch.
        self.assertEqual(len(Ms), len(Ns))
        self.assertEqual(len(Ms), len(Ks))
        shapes = list(zip(Ms, Ns, Ks))

        a_array, b_array, c_array_cublas = [], [], []
        for M, N, K in shapes:
            a_array.append(torch.randn((M, K), device="cuda", dtype=out_dtype) * 5)
            b_array.append(torch.randn((N, K), device="cuda", dtype=out_dtype) * 5)
            # Output buffers that cublas_grouped_gemm is expected to fill.
            c_array_cublas.append(torch.empty((M, N), device="cuda", dtype=out_dtype))

        c_array_torch = torch_grouped_gemm(a_array, b_array, out_dtype)
        cublas_grouped_gemm(a_array, b_array, c_array_cublas, out_dtype)

        for (M, N, K), reference, result in zip(shapes, c_array_torch, c_array_cublas):
            # subTest names the failing group instead of stopping at the first
            # mismatch with no shape information.
            with self.subTest(M=M, N=N, K=K, out_dtype=out_dtype):
                # The kernel output is `actual`; the torch reference is `expected`.
                torch.testing.assert_close(result, reference)
                print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")

    def test_accuracy(self):
        """Run five groups of mixed shapes in float16 and bfloat16."""
        Ms = [1, 16, 32, 256, 1024]
        Ns = [2, 16, 128, 256, 4096]
        Ks = [3, 16, 32, 512, 8192]
        for out_dtype in (torch.float16, torch.bfloat16):
            with self.subTest(out_dtype=out_dtype):
                self._test_accuracy(Ms, Ns, Ks, out_dtype)
if __name__ == "__main__":
    # Only run the tests on a CUDA device with CUDA >= 12.5 (presumably the
    # kernel relies on cuBLAS grouped-GEMM APIs that first shipped there —
    # confirm against sgl_kernel's build requirements).
    cuda_version_str = torch.version.cuda
    if not torch.cuda.is_available():
        print("CUDA is not available, not executing tests.")
    elif cuda_version_str is None:
        # e.g. ROCm builds: a device is available but torch.version.cuda is
        # None, so the version string cannot be parsed.
        print("PyTorch was not built with CUDA, not executing tests.")
    else:
        cuda_version = tuple(map(int, cuda_version_str.split(".")))
        if cuda_version >= (12, 5):
            unittest.main()
        else:
            print(f"Cuda version {cuda_version} lower than 12.5, not executing tests.")