# sglang_v0.5.2/flashinfer_0.3.1/tests/test_green_ctx.py
import pytest
import torch
import flashinfer.green_ctx as green_ctx
@pytest.mark.parametrize("device", ["cuda:0"])
@pytest.mark.parametrize("num_groups", [1, 2, 3])
@pytest.mark.parametrize("min_count", [16, 32])
def test_green_ctx_creation(
    device: str,
    num_groups: int,
    min_count: int,
):
    """Splitting a device yields one resource per requested group plus one
    trailing entry for the leftover SMs, each group meeting ``min_count``."""
    _, resources = green_ctx.split_device_green_ctx(
        torch.device(device), num_groups, min_count
    )
    # One extra entry describes the remaining (unpartitioned) SMs.
    assert len(resources) == num_groups + 1
    # Every requested group must have been granted at least min_count SMs.
    for res in resources[:-1]:
        assert res.sm.smCount >= min_count
@pytest.mark.parametrize("device", ["cuda:0"])
@pytest.mark.parametrize("num_groups", [1, 2, 3])
@pytest.mark.parametrize("min_count", [16, 32])
def test_green_ctx_kernel_execution(
    device: str,
    num_groups: int,
    min_count: int,
):
    """Launch a large matmul on each partition's stream to confirm every
    green-context stream can actually execute kernels."""
    streams, resources = green_ctx.split_device_green_ctx(
        torch.device(device), num_groups, min_count
    )
    expected_partitions = num_groups + 1  # requested groups + leftover SMs
    assert len(streams) == expected_partitions
    assert len(resources) == expected_partitions
    for partition_stream in streams:
        with torch.cuda.stream(partition_stream):
            lhs = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16)
            rhs = torch.randn(8192, 8192, device=device, dtype=torch.bfloat16)
            product = lhs @ rhs
            print(product.shape)
@pytest.mark.parametrize("device", ["cuda:0"])
@pytest.mark.parametrize(
    "sm_counts",
    [
        [16, 16, 16],
        [8, 16, 24],
        [32],
        [8, 8, 8, 8],
    ],
)
def test_split_device_green_ctx_by_sm_count_creation(
    device: str,
    sm_counts: list,
):
    """Per-count splitting returns one stream/resource pair per requested
    count plus one for the remainder, honoring each requested SM count."""
    streams, resources = green_ctx.split_device_green_ctx_by_sm_count(
        torch.device(device), sm_counts
    )
    expected_partitions = len(sm_counts) + 1  # requests + remainder entry
    assert len(resources) == expected_partitions
    assert len(streams) == expected_partitions
    # Each partition must have received at least the SM count asked for.
    for requested, granted in zip(sm_counts, resources):
        assert granted.sm.smCount >= requested
@pytest.mark.parametrize("device", ["cuda:0"])
@pytest.mark.parametrize(
    "sm_counts",
    [
        [16, 16, 16],
        [8, 16, 24],
        [32],
    ],
)
def test_split_device_green_ctx_by_sm_count_kernel_execution(
    device: str,
    sm_counts: list,
):
    """Run a matmul on every per-count partition stream to confirm each
    green context can execute kernels."""
    streams, resources = green_ctx.split_device_green_ctx_by_sm_count(
        torch.device(device), sm_counts
    )
    expected_partitions = len(sm_counts) + 1  # requests + remainder entry
    assert len(streams) == expected_partitions
    assert len(resources) == expected_partitions
    for idx, partition_stream in enumerate(streams):
        with torch.cuda.stream(partition_stream):
            lhs = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16)
            rhs = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16)
            product = lhs @ rhs
            print(f"Partition {idx}: {product.shape}")
@pytest.mark.parametrize("device", ["cuda:0"])
@pytest.mark.parametrize(
    "sm_counts",
    [
        [1, 2, 3, 4],  # Should be aligned to minimum requirements
        [7, 8, 9, 10],  # Should be aligned to 8 for compute capability 9+
        [15, 16, 17, 18],  # Should be aligned to 8 for compute capability 9+
    ],
)
def test_split_device_green_ctx_by_sm_count_alignment(
    device: str,
    sm_counts: list,
):
    """Granted SM counts respect the device's minimum size and alignment
    constraints, even for under-sized or misaligned requests."""
    _, resources = green_ctx.split_device_green_ctx_by_sm_count(
        torch.device(device), sm_counts
    )
    # Skip the trailing entry: it holds whatever SMs were left over.
    for res in resources[:-1]:
        granted = res.sm.smCount
        assert granted > 0
        minimum, alignment = green_ctx.get_sm_count_constraint(
            *green_ctx.get_compute_capability(torch.device(device))
        )
        assert granted >= minimum
        assert granted % alignment == 0