"""
|
|
Copyright (c) 2024 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""

import pytest
import torch

import flashinfer
from flashinfer.utils import determine_gemm_backend, is_sm90a_supported

DTYPES = [torch.float16]
CUDA_DEVICES = ["cuda:0"]
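

# Build the JIT GEMM modules once per test module so individual tests do not
# pay compilation cost on their first kernel launch.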
@pytest.fixture(autouse=True, scope="module")
def warmup_jit():
    jit_specs = [flashinfer.gemm.gen_gemm_module()]
    if is_sm90a_supported(torch.device("cuda:0")):
        jit_specs.append(flashinfer.gemm.gen_gemm_sm90_module())
    flashinfer.jit.build_jit_specs(jit_specs, verbose=False)
    yield


@pytest.mark.parametrize("batch_size", [1, 77, 199])
@pytest.mark.parametrize("num_rows_per_batch", [3, 10, 99])
@pytest.mark.parametrize("d_in", [128, 1024, 4096])
@pytest.mark.parametrize("d_out", [128, 1024, 4096])
@pytest.mark.parametrize("use_weight_indices", [False, True])
@pytest.mark.parametrize("column_major", [False, True])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("backend", ["sm90", "sm80"])
def test_segment_gemm(
    batch_size,
    num_rows_per_batch,
    d_in,
    d_out,
    use_weight_indices,
    column_major,
    dtype,
    device,
    backend,
):
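    """Compare SegmentGEMMWrapper.run against a torch.matmul reference."""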
    torch.manual_seed(42)
    if batch_size * num_rows_per_batch > 8192:
        pytest.skip("batch_size * num_rows_per_batch too large for test.")
    latest_supported_backend = determine_gemm_backend(torch.device(device))
    if backend == "sm90" and latest_supported_backend == "sm80":
        pytest.skip("sm90 backend not supported on this device.")
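    # 32 MiB device workspace used internally by the grouped GEMM kernels.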
    workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device=device)
    segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer, backend=backend)
    x = torch.randn(batch_size * num_rows_per_batch, d_in, dtype=dtype, device=device)
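    # Column-major weights are stored as (d_out, d_in); row-major as (d_in, d_out).
    # With weight indices, segments index into a shared pool of num_weights matrices.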
    if use_weight_indices:
        num_weights = 1024
        if column_major:
            weight = torch.randn(num_weights, d_out, d_in, dtype=dtype, device=device)
        else:
            weight = torch.randn(num_weights, d_in, d_out, dtype=dtype, device=device)
    else:
        if column_major:
            weight = torch.randn(batch_size, d_out, d_in, dtype=dtype, device=device)
        else:
            weight = torch.randn(batch_size, d_in, d_out, dtype=dtype, device=device)
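    # Run the segment GEMM: seg_lens gives every segment the same row count, and
    # weight_indices (when set) maps segment i to weight matrix i % num_weights.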
    y = segment_gemm.run(
        x,
        weight,
        batch_size,
        weight_column_major=column_major,
        seg_lens=torch.full((batch_size,), num_rows_per_batch, dtype=torch.int64),
        weight_indices=(
            (torch.arange(0, batch_size, device=device) % num_weights)
            if use_weight_indices
            else None
        ),
    )
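
    # Validate against a float32 torch.matmul reference: per segment when weights
    # are indexed, batched otherwise.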
    if use_weight_indices:
        for i in range(batch_size):
            torch.testing.assert_close(
                y[i * num_rows_per_batch : (i + 1) * num_rows_per_batch],
                torch.matmul(
                    x[i * num_rows_per_batch : (i + 1) * num_rows_per_batch].float(),
                    (
                        weight[i % num_weights].float().T
                        if column_major
                        else weight[i % num_weights].float()
                    ),
                ).to(dtype),
                rtol=1e-3,
                atol=1e-3,
            )
    else:
        torch.testing.assert_close(
            y,
            torch.matmul(
                x.view(batch_size, num_rows_per_batch, d_in).float(),
                weight.float().transpose(-1, -2) if column_major else weight.float(),
            )
            .view(batch_size * num_rows_per_batch, d_out)
            .to(dtype),
            rtol=1e-3,
            atol=2e-3,
        )


if __name__ == "__main__":
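    # Smoke-run a few configurations directly; backend="auto" lets the wrapper
    # pick the newest backend the device supports.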
    test_segment_gemm(199, 17, 128, 1024, False, False, torch.float16, "cuda:0", "auto")
    test_segment_gemm(199, 17, 128, 1024, False, True, torch.float16, "cuda:0", "auto")
    test_segment_gemm(199, 17, 128, 1024, True, False, torch.float16, "cuda:0", "auto")
    test_segment_gemm(199, 17, 128, 1024, True, True, torch.float16, "cuda:0", "auto")