import pytest
import torch
from sgl_kernel.kvcacheio import (
    transfer_kv_all_layer,
    transfer_kv_all_layer_mla,
    transfer_kv_per_layer,
    transfer_kv_per_layer_mla,
)


def ref_copy_with_indices(src_pool, dst_pool, src_indices, dst_indices):
    dst_pool[dst_indices] = src_pool[src_indices].to(dst_pool.device)


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_items_to_transfer", [1, 128, 1024])
@pytest.mark.parametrize("page_size", [1, 16, 64])
@pytest.mark.parametrize("item_size", [256])
@pytest.mark.parametrize("total_items_in_pool", [10240])
@pytest.mark.parametrize("is_mla", [False, True])
@pytest.mark.parametrize("all_layers", [False, True])
def test_transfer_kv(
    dtype: torch.dtype,
    num_items_to_transfer: int,
    item_size: int,
    page_size: int,
    total_items_in_pool: int,
    is_mla: bool,
    all_layers: bool,
):
    """
    Tests the per-layer and all-layer KV transfer functions, treating tensors as memory pools.
    """
    original_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    device = "cuda"
    torch.cuda.manual_seed(42)

    num_layers = 4  # A small number of layers for pool creation
    total_pages_in_pool = total_items_in_pool // page_size
    num_pages_to_transfer = num_items_to_transfer // page_size
    if num_pages_to_transfer == 0:
        torch.set_default_dtype(original_dtype)
        return

    page_indices = torch.randperm(total_pages_in_pool, dtype=torch.int64)
    src_indices_host = torch.cat(
        [
            torch.arange(p * page_size, (p + 1) * page_size)
            for p in page_indices[:num_pages_to_transfer]
        ]
    )
    src_indices_device = src_indices_host.to(device)
    dst_indices_host = torch.cat(
        [
            torch.arange(p * page_size, (p + 1) * page_size)
            for p in page_indices[num_pages_to_transfer : 2 * num_pages_to_transfer]
        ]
    )
    dst_indices_device = dst_indices_host.to(device)

    # Prepare memory pools based on whether it's an MLA case.
    if is_mla:
        src_pool_host = torch.randn(
            num_layers, total_items_in_pool, item_size
        ).pin_memory()
        dst_pool_ref = torch.zeros_like(src_pool_host).to(device)
        dst_pool_kernel = torch.zeros_like(dst_pool_ref)
        dst_pool_direct = torch.zeros_like(dst_pool_ref)
    else:
        src_k_pool = torch.randn(
            num_layers, total_items_in_pool, item_size
        ).pin_memory()
        src_v_pool = torch.randn(
            num_layers, total_items_in_pool, item_size
        ).pin_memory()
        dst_k_pool_ref = torch.zeros_like(src_k_pool).to(device)
        dst_v_pool_ref = torch.zeros_like(src_v_pool).to(device)
        dst_k_pool_kernel = torch.zeros_like(dst_k_pool_ref)
        dst_v_pool_kernel = torch.zeros_like(dst_v_pool_ref)
        dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
        dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
    torch.cuda.synchronize()

    # We will test the per-layer function on the first layer (index 0) of the pool.
    layer_idx_to_test = 0
    if is_mla:
        if not all_layers:
            ref_copy_with_indices(
                src_pool_host[layer_idx_to_test],
                dst_pool_ref[layer_idx_to_test],
                src_indices_host,
                dst_indices_device,
            )
            transfer_kv_per_layer_mla(
                src_pool_host[layer_idx_to_test],
                dst_pool_kernel[layer_idx_to_test],
                src_indices_device,
                dst_indices_device,
                io_backend="kernel",
                page_size=page_size,
                item_size=item_size,
            )
            transfer_kv_per_layer_mla(
                src_pool_host[layer_idx_to_test],
                dst_pool_direct[layer_idx_to_test],
                src_indices_host,
                dst_indices_device,
                io_backend="direct",
                page_size=page_size,
                item_size=item_size,
            )
        else:
            for layer_id in range(num_layers):
                ref_copy_with_indices(
                    src_pool_host[layer_id],
                    dst_pool_ref[layer_id],
                    src_indices_host,
                    dst_indices_device,
                )
            transfer_kv_all_layer_mla(
                src_pool_host,
                dst_pool_kernel,
                src_indices_device,
                dst_indices_device,
                io_backend="kernel",
                page_size=page_size,
                item_size=item_size,
                num_layers=num_layers,
                src_layer_offset=total_items_in_pool * item_size,
                dst_layer_offset=total_items_in_pool * item_size,
            )
            transfer_kv_all_layer_mla(
                src_pool_host,
                dst_pool_direct,
                src_indices_host,
                dst_indices_device,
                io_backend="direct",
                page_size=page_size,
                item_size=item_size,
                num_layers=num_layers,
                src_layer_offset=total_items_in_pool * item_size,
                dst_layer_offset=total_items_in_pool * item_size,
            )
        torch.cuda.synchronize()
        torch.testing.assert_close(dst_pool_kernel, dst_pool_ref)
        torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
    else:
        if not all_layers:
            ref_copy_with_indices(
                src_k_pool[layer_idx_to_test],
                dst_k_pool_ref[layer_idx_to_test],
                src_indices_host,
                dst_indices_device,
            )
            ref_copy_with_indices(
                src_v_pool[layer_idx_to_test],
                dst_v_pool_ref[layer_idx_to_test],
                src_indices_host,
                dst_indices_device,
            )
            transfer_kv_per_layer(
                src_k_pool[layer_idx_to_test],
                dst_k_pool_kernel[layer_idx_to_test],
                src_v_pool[layer_idx_to_test],
                dst_v_pool_kernel[layer_idx_to_test],
                src_indices_device,
                dst_indices_device,
                io_backend="kernel",
                page_size=page_size,
                item_size=item_size,
            )
            transfer_kv_per_layer(
                src_k_pool[layer_idx_to_test],
                dst_k_pool_direct[layer_idx_to_test],
                src_v_pool[layer_idx_to_test],
                dst_v_pool_direct[layer_idx_to_test],
                src_indices_host,
                dst_indices_device,
                io_backend="direct",
                page_size=page_size,
                item_size=item_size,
            )
        else:
            for layer_id in range(num_layers):
                ref_copy_with_indices(
                    src_k_pool[layer_id],
                    dst_k_pool_ref[layer_id],
                    src_indices_host,
                    dst_indices_device,
                )
                ref_copy_with_indices(
                    src_v_pool[layer_id],
                    dst_v_pool_ref[layer_id],
                    src_indices_host,
                    dst_indices_device,
                )
            transfer_kv_all_layer(
                src_k_pool,
                dst_k_pool_kernel,
                src_v_pool,
                dst_v_pool_kernel,
                src_indices_device,
                dst_indices_device,
                io_backend="kernel",
                page_size=page_size,
                item_size=item_size,
                num_layers=num_layers,
                src_layer_offset=total_items_in_pool * item_size,
                dst_layer_offset=total_items_in_pool * item_size,
            )
            transfer_kv_all_layer(
                src_k_pool,
                dst_k_pool_direct,
                src_v_pool,
                dst_v_pool_direct,
                src_indices_host,
                dst_indices_device,
                io_backend="direct",
                page_size=page_size,
                item_size=item_size,
                num_layers=num_layers,
                src_layer_offset=total_items_in_pool * item_size,
                dst_layer_offset=total_items_in_pool * item_size,
            )
        torch.cuda.synchronize()
        torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)
        torch.testing.assert_close(dst_v_pool_kernel, dst_v_pool_ref)
        torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
        torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)

    torch.set_default_dtype(original_dtype)


if __name__ == "__main__":
    pytest.main([__file__])