# sglang.0.4.8.post1/sglang/sgl-kernel/tests/test_kvcacheio.py

import pytest
import torch
from sgl_kernel.kvcacheio import (
transfer_kv_all_layer,
transfer_kv_all_layer_mla,
transfer_kv_per_layer,
transfer_kv_per_layer_mla,
)
def ref_copy_with_indices(src_pool, dst_pool, src_indices, dst_indices):
    """Reference scatter/gather copy: dst_pool[dst_indices] <- src_pool[src_indices].

    Gathers the rows of ``src_pool`` selected by ``src_indices``, moves them
    to ``dst_pool``'s device, and scatters them into ``dst_pool`` at
    ``dst_indices``. Serves as the ground truth the kernels are checked against.
    """
    gathered = src_pool[src_indices]
    dst_pool[dst_indices] = gathered.to(dst_pool.device)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_items_to_transfer", [1, 128, 1024])
@pytest.mark.parametrize("page_size", [1, 16, 64])
@pytest.mark.parametrize("item_size", [256])
@pytest.mark.parametrize("total_items_in_pool", [10240])
@pytest.mark.parametrize("is_mla", [False, True])
@pytest.mark.parametrize("all_layers", [False, True])
def test_transfer_kv(
    dtype: torch.dtype,
    num_items_to_transfer: int,
    item_size: int,
    page_size: int,
    total_items_in_pool: int,
    is_mla: bool,
    all_layers: bool,
):
    """
    Tests the KV transfer functions, treating tensors as memory pools.

    Checks both io backends ("kernel" and "direct") of the ``transfer_kv_*``
    functions against a plain indexed-copy reference, covering the MLA
    (single combined pool) and non-MLA (separate K/V pools) layouts, in
    both per-layer and all-layer variants. Sources live in pinned host
    memory; destinations live on the GPU.
    """
    original_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    # try/finally so the global default dtype is restored even when an
    # assertion or kernel call raises — otherwise the leaked dtype would
    # pollute every test that runs after this one.
    try:
        device = "cuda"
        # Seed BOTH generators: randperm()/randn() below draw from the CPU
        # generator, which torch.cuda.manual_seed alone does not touch.
        torch.manual_seed(42)
        torch.cuda.manual_seed(42)
        num_layers = 4  # A small number of layers for pool creation
        total_pages_in_pool = total_items_in_pool // page_size
        num_pages_to_transfer = num_items_to_transfer // page_size
        if num_pages_to_transfer == 0:
            # Fewer items than a single page: nothing to transfer for this
            # parameter combination (dtype is restored by the finally block).
            return
        # Pick disjoint source and destination pages at random, then expand
        # each page index into the contiguous item indices it covers.
        page_indices = torch.randperm(total_pages_in_pool, dtype=torch.int64)
        src_indices_host = torch.cat(
            [
                torch.arange(p * page_size, (p + 1) * page_size)
                for p in page_indices[:num_pages_to_transfer]
            ]
        )
        src_indices_device = src_indices_host.to(device)
        dst_indices_host = torch.cat(
            [
                torch.arange(p * page_size, (p + 1) * page_size)
                for p in page_indices[num_pages_to_transfer : 2 * num_pages_to_transfer]
            ]
        )
        dst_indices_device = dst_indices_host.to(device)
        # Prepare memory pools based on whether it's an MLA case.
        if is_mla:
            # MLA: one combined pool per layer (no separate K/V tensors).
            src_pool_host = torch.randn(
                num_layers, total_items_in_pool, item_size
            ).pin_memory()
            dst_pool_ref = torch.zeros_like(src_pool_host).to(device)
            dst_pool_kernel = torch.zeros_like(dst_pool_ref)
            dst_pool_direct = torch.zeros_like(dst_pool_ref)
        else:
            src_k_pool = torch.randn(
                num_layers, total_items_in_pool, item_size
            ).pin_memory()
            src_v_pool = torch.randn(
                num_layers, total_items_in_pool, item_size
            ).pin_memory()
            dst_k_pool_ref = torch.zeros_like(src_k_pool).to(device)
            dst_v_pool_ref = torch.zeros_like(src_v_pool).to(device)
            dst_k_pool_kernel = torch.zeros_like(dst_k_pool_ref)
            dst_v_pool_kernel = torch.zeros_like(dst_v_pool_ref)
            dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
            dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
        torch.cuda.synchronize()
        # We will test the per-layer function on the first layer (index 0) of the pool.
        layer_idx_to_test = 0
        if is_mla:
            if not all_layers:
                ref_copy_with_indices(
                    src_pool_host[layer_idx_to_test],
                    dst_pool_ref[layer_idx_to_test],
                    src_indices_host,
                    dst_indices_device,
                )
                # NOTE: the "kernel" backend takes device-resident source
                # indices, while "direct" takes host-resident ones.
                transfer_kv_per_layer_mla(
                    src_pool_host[layer_idx_to_test],
                    dst_pool_kernel[layer_idx_to_test],
                    src_indices_device,
                    dst_indices_device,
                    io_backend="kernel",
                    page_size=page_size,
                    item_size=item_size,
                )
                transfer_kv_per_layer_mla(
                    src_pool_host[layer_idx_to_test],
                    dst_pool_direct[layer_idx_to_test],
                    src_indices_host,
                    dst_indices_device,
                    io_backend="direct",
                    page_size=page_size,
                    item_size=item_size,
                )
            else:
                for layer_id in range(num_layers):
                    ref_copy_with_indices(
                        src_pool_host[layer_id],
                        dst_pool_ref[layer_id],
                        src_indices_host,
                        dst_indices_device,
                    )
                # Layer offsets are in elements — presumably the stride
                # between consecutive layers of the flat pool; TODO confirm
                # against the kernel API's documented units.
                transfer_kv_all_layer_mla(
                    src_pool_host,
                    dst_pool_kernel,
                    src_indices_device,
                    dst_indices_device,
                    io_backend="kernel",
                    page_size=page_size,
                    item_size=item_size,
                    num_layers=num_layers,
                    src_layer_offset=total_items_in_pool * item_size,
                    dst_layer_offset=total_items_in_pool * item_size,
                )
                transfer_kv_all_layer_mla(
                    src_pool_host,
                    dst_pool_direct,
                    src_indices_host,
                    dst_indices_device,
                    io_backend="direct",
                    page_size=page_size,
                    item_size=item_size,
                    num_layers=num_layers,
                    src_layer_offset=total_items_in_pool * item_size,
                    dst_layer_offset=total_items_in_pool * item_size,
                )
            torch.cuda.synchronize()
            torch.testing.assert_close(dst_pool_kernel, dst_pool_ref)
            torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
        else:
            if not all_layers:
                ref_copy_with_indices(
                    src_k_pool[layer_idx_to_test],
                    dst_k_pool_ref[layer_idx_to_test],
                    src_indices_host,
                    dst_indices_device,
                )
                ref_copy_with_indices(
                    src_v_pool[layer_idx_to_test],
                    dst_v_pool_ref[layer_idx_to_test],
                    src_indices_host,
                    dst_indices_device,
                )
                transfer_kv_per_layer(
                    src_k_pool[layer_idx_to_test],
                    dst_k_pool_kernel[layer_idx_to_test],
                    src_v_pool[layer_idx_to_test],
                    dst_v_pool_kernel[layer_idx_to_test],
                    src_indices_device,
                    dst_indices_device,
                    io_backend="kernel",
                    page_size=page_size,
                    item_size=item_size,
                )
                transfer_kv_per_layer(
                    src_k_pool[layer_idx_to_test],
                    dst_k_pool_direct[layer_idx_to_test],
                    src_v_pool[layer_idx_to_test],
                    dst_v_pool_direct[layer_idx_to_test],
                    src_indices_host,
                    dst_indices_device,
                    io_backend="direct",
                    page_size=page_size,
                    item_size=item_size,
                )
            else:
                for layer_id in range(num_layers):
                    ref_copy_with_indices(
                        src_k_pool[layer_id],
                        dst_k_pool_ref[layer_id],
                        src_indices_host,
                        dst_indices_device,
                    )
                    ref_copy_with_indices(
                        src_v_pool[layer_id],
                        dst_v_pool_ref[layer_id],
                        src_indices_host,
                        dst_indices_device,
                    )
                transfer_kv_all_layer(
                    src_k_pool,
                    dst_k_pool_kernel,
                    src_v_pool,
                    dst_v_pool_kernel,
                    src_indices_device,
                    dst_indices_device,
                    io_backend="kernel",
                    page_size=page_size,
                    item_size=item_size,
                    num_layers=num_layers,
                    src_layer_offset=total_items_in_pool * item_size,
                    dst_layer_offset=total_items_in_pool * item_size,
                )
                transfer_kv_all_layer(
                    src_k_pool,
                    dst_k_pool_direct,
                    src_v_pool,
                    dst_v_pool_direct,
                    src_indices_host,
                    dst_indices_device,
                    io_backend="direct",
                    page_size=page_size,
                    item_size=item_size,
                    num_layers=num_layers,
                    src_layer_offset=total_items_in_pool * item_size,
                    dst_layer_offset=total_items_in_pool * item_size,
                )
            torch.cuda.synchronize()
            torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)
            torch.testing.assert_close(dst_v_pool_kernel, dst_v_pool_ref)
            torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
            torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)
    finally:
        # Always undo the global default-dtype change, pass or fail.
        torch.set_default_dtype(original_dtype)
# Allow running this file directly (python test_kvcacheio.py) in addition
# to a normal pytest invocation.
if __name__ == "__main__":
    pytest.main([__file__])