129 lines
3.7 KiB
Python
129 lines
3.7 KiB
Python
import pytest
|
|
import torch
|
|
|
|
import flashinfer.green_ctx as green_ctx
|
|
|
|
|
|
@pytest.mark.parametrize("device", ["cuda:0"])
@pytest.mark.parametrize("num_groups", [1, 2, 3])
@pytest.mark.parametrize("min_count", [16, 32])
def test_green_ctx_creation(
    device: str,
    num_groups: int,
    min_count: int,
):
    """Check partition count and per-group SM floor of ``split_device_green_ctx``.

    The split is expected to return ``num_groups + 1`` streams and the same
    number of resources. Every resource except the last one must report at
    least ``min_count`` SMs.
    """
    streams, resources = green_ctx.split_device_green_ctx(
        torch.device(device), num_groups, min_count
    )

    num_partitions = num_groups + 1
    assert len(streams) == num_partitions
    assert len(resources) == num_partitions

    # The last entry is skipped: presumably it holds the leftover SMs and is
    # not bound by min_count -- TODO confirm against green_ctx.
    for resource in resources[:-1]:
        assert resource.sm.smCount >= min_count
|
|
|
|
|
|
@pytest.mark.parametrize("device", ["cuda:0"])
@pytest.mark.parametrize("num_groups", [1, 2, 3])
@pytest.mark.parametrize("min_count", [16, 32])
def test_green_ctx_kernel_execution(
    device: str,
    num_groups: int,
    min_count: int,
):
    """Run a bf16 matmul on every stream returned by ``split_device_green_ctx``.

    After checking that ``num_groups + 1`` streams and resources come back,
    one 8192x8192 matmul is enqueued per stream. Each stream is then
    synchronized and the result shape is verified.
    """
    streams, resources = green_ctx.split_device_green_ctx(
        torch.device(device), num_groups, min_count
    )
    num_partitions = num_groups + 1
    assert len(streams) == num_partitions
    assert len(resources) == num_partitions

    size = 8192
    outputs = []
    # Enqueue work on all streams first so the partitions are free to overlap.
    for stream in streams:
        with torch.cuda.stream(stream):
            x = torch.randn(size, size, device=device, dtype=torch.bfloat16)
            y = torch.randn(size, size, device=device, dtype=torch.bfloat16)
            outputs.append(x @ y)

    # Kernel launches are asynchronous; wait on each stream so that a failing
    # kernel is reported inside this test instead of going unnoticed.
    for stream, z in zip(streams, outputs):
        stream.synchronize()
        assert z.shape == (size, size)
|
|
|
|
|
|
@pytest.mark.parametrize("device", ["cuda:0"])
@pytest.mark.parametrize(
    "sm_counts",
    [
        [16, 16, 16],
        [8, 16, 24],
        [32],
        [8, 8, 8, 8],
    ],
)
def test_split_device_green_ctx_by_sm_count_creation(
    device: str,
    sm_counts: list,
):
    """Verify the partitions produced by ``split_device_green_ctx_by_sm_count``.

    One stream and one resource are expected per requested SM count, plus one
    additional entry. Each requested partition must report at least the SM
    count that was asked for.
    """
    target = torch.device(device)
    streams, resources = green_ctx.split_device_green_ctx_by_sm_count(
        target, sm_counts
    )

    expected_total = 1 + len(sm_counts)
    assert len(resources) == expected_total
    assert len(streams) == expected_total

    # Pair each requested count with its resource; the extra trailing
    # resource has no requested count and is therefore not inspected.
    for partition, requested in zip(resources, sm_counts):
        granted = partition.sm.smCount
        assert granted >= requested
|
|
|
|
|
|
@pytest.mark.parametrize("device", ["cuda:0"])
@pytest.mark.parametrize(
    "sm_counts",
    [
        [16, 16, 16],
        [8, 16, 24],
        [32],
    ],
)
def test_split_device_green_ctx_by_sm_count_kernel_execution(
    device: str,
    sm_counts: list,
):
    """Run a bf16 matmul on every stream from ``split_device_green_ctx_by_sm_count``.

    After checking that ``len(sm_counts) + 1`` streams and resources come
    back, one 4096x4096 matmul is enqueued per stream. Each stream is then
    synchronized and the result shape is verified.
    """
    streams, resources = green_ctx.split_device_green_ctx_by_sm_count(
        torch.device(device), sm_counts
    )
    num_partitions = len(sm_counts) + 1
    assert len(streams) == num_partitions
    assert len(resources) == num_partitions

    size = 4096
    outputs = []
    # Enqueue work on all streams first so the partitions are free to overlap.
    for stream in streams:
        with torch.cuda.stream(stream):
            x = torch.randn(size, size, device=device, dtype=torch.bfloat16)
            y = torch.randn(size, size, device=device, dtype=torch.bfloat16)
            outputs.append(x @ y)

    # Kernel launches are asynchronous; wait on each stream so that a failing
    # kernel is reported inside this test instead of going unnoticed.
    for i, (stream, z) in enumerate(zip(streams, outputs)):
        stream.synchronize()
        assert z.shape == (size, size), f"Partition {i}: {z.shape}"
|
|
|
|
|
|
@pytest.mark.parametrize("device", ["cuda:0"])
@pytest.mark.parametrize(
    "sm_counts",
    [
        [1, 2, 3, 4],  # Should be aligned to minimum requirements
        [7, 8, 9, 10],  # Should be aligned to 8 for compute capability 9+
        [15, 16, 17, 18],  # Should be aligned to 8 for compute capability 9+
    ],
)
def test_split_device_green_ctx_by_sm_count_alignment(
    device: str,
    sm_counts: list,
):
    """Check that partition SM counts honor the device's SM-count constraint.

    Every resource except the last must have a positive SM count that is at
    least the minimum, and a multiple of the alignment, reported by
    ``green_ctx.get_sm_count_constraint`` for the device's compute capability.
    """
    torch_device = torch.device(device)
    _, resources = green_ctx.split_device_green_ctx_by_sm_count(
        torch_device, sm_counts
    )

    # The constraint is derived from the device's compute capability only, so
    # it is queried once rather than once per resource.
    min_sm_count, sm_alignment = green_ctx.get_sm_count_constraint(
        *green_ctx.get_compute_capability(torch_device)
    )

    for resource in resources[:-1]:  # Exclude remaining SMs
        sm_count = resource.sm.smCount
        assert sm_count > 0
        assert sm_count >= min_sm_count
        assert sm_count % sm_alignment == 0
|