138 lines
3.5 KiB
Python
138 lines
3.5 KiB
Python
import torch
|
|
|
|
|
|
def transfer_kv_per_layer(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    io_backend: str,
    page_size: int,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
    """Dispatch a per-layer K/V transfer to the op selected by ``io_backend``.

    This is a thin wrapper: it forwards its arguments unchanged to one of two
    ``torch.ops.sgl_kernel`` ops and returns nothing.

    Args:
        src_k: Source tensor forwarded as the op's first argument
            (presumably the key cache -- confirm against the op schema).
        dst_k: Destination counterpart of ``src_k``.
        src_v: Source tensor forwarded as the op's third argument
            (presumably the value cache -- confirm against the op schema).
        dst_v: Destination counterpart of ``src_v``.
        src_indices: Index tensor forwarded alongside the source tensors.
        dst_indices: Index tensor forwarded alongside the destination tensors.
        io_backend: ``"kernel"`` calls ``transfer_kv_per_layer``;
            ``"direct"`` calls ``transfer_kv_per_layer_direct``.
        page_size: Forwarded only by the ``"direct"`` backend.
        item_size: Forwarded only by the ``"kernel"`` backend.
        block_quota: Forwarded only by the ``"kernel"`` backend.
        num_warps_per_block: Forwarded only by the ``"kernel"`` backend.

    Raises:
        ValueError: If ``io_backend`` is neither ``"kernel"`` nor ``"direct"``.
    """
    if io_backend == "kernel":
        torch.ops.sgl_kernel.transfer_kv_per_layer(
            src_k,
            dst_k,
            src_v,
            dst_v,
            src_indices,
            dst_indices,
            item_size,
            block_quota,
            num_warps_per_block,
        )
    elif io_backend == "direct":
        torch.ops.sgl_kernel.transfer_kv_per_layer_direct(
            src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size
        )
    else:
        # Report the rejected value so a misconfigured caller is easy to spot;
        # the original message was a placeholder-less f-string.
        raise ValueError(
            f"Unsupported io backend: {io_backend!r} "
            "(expected 'kernel' or 'direct')"
        )
|
|
|
|
|
|
def transfer_kv_all_layer(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    io_backend: str,
    page_size: int,
    item_size: int,
    num_layers: int,
    src_layer_offset: int,
    dst_layer_offset: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
    """Dispatch an all-layer K/V transfer to the op selected by ``io_backend``.

    This is a thin wrapper: it forwards its arguments unchanged to one of two
    ``torch.ops.sgl_kernel`` ops and returns nothing.

    Args:
        src_k: Source tensor forwarded as the op's first argument
            (presumably the key cache -- confirm against the op schema).
        dst_k: Destination counterpart of ``src_k``.
        src_v: Source tensor forwarded as the op's third argument
            (presumably the value cache -- confirm against the op schema).
        dst_v: Destination counterpart of ``src_v``.
        src_indices: Index tensor forwarded alongside the source tensors.
        dst_indices: Index tensor forwarded alongside the destination tensors.
        io_backend: ``"kernel"`` calls ``transfer_kv_all_layer``;
            ``"direct"`` calls ``transfer_kv_all_layer_direct``.
        page_size: Forwarded only by the ``"direct"`` backend.
        item_size: Forwarded only by the ``"kernel"`` backend.
        num_layers: Forwarded by both backends.
        src_layer_offset: Forwarded only by the ``"kernel"`` backend.
        dst_layer_offset: Forwarded only by the ``"kernel"`` backend.
        block_quota: Forwarded only by the ``"kernel"`` backend.
        num_warps_per_block: Forwarded only by the ``"kernel"`` backend.

    Raises:
        ValueError: If ``io_backend`` is neither ``"kernel"`` nor ``"direct"``.
    """
    if io_backend == "kernel":
        torch.ops.sgl_kernel.transfer_kv_all_layer(
            src_k,
            dst_k,
            src_v,
            dst_v,
            src_indices,
            dst_indices,
            item_size,
            num_layers,
            src_layer_offset,
            dst_layer_offset,
            block_quota,
            num_warps_per_block,
        )
    elif io_backend == "direct":
        # NOTE: the layer offsets are not part of the direct op's call.
        torch.ops.sgl_kernel.transfer_kv_all_layer_direct(
            src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers
        )
    else:
        # Report the rejected value so a misconfigured caller is easy to spot;
        # the original message was a placeholder-less f-string.
        raise ValueError(
            f"Unsupported io backend: {io_backend!r} "
            "(expected 'kernel' or 'direct')"
        )
|
|
|
|
|
|
def transfer_kv_per_layer_mla(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    io_backend: str,
    page_size: int,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
    """Dispatch a per-layer MLA transfer to the op selected by ``io_backend``.

    Unlike the K/V variant, this takes a single source/destination tensor
    pair. It forwards its arguments unchanged to one of two
    ``torch.ops.sgl_kernel`` ops and returns nothing.

    Args:
        src: Source tensor forwarded as the op's first argument.
        dst: Destination tensor forwarded as the op's second argument.
        src_indices: Index tensor forwarded alongside ``src``.
        dst_indices: Index tensor forwarded alongside ``dst``.
        io_backend: ``"kernel"`` calls ``transfer_kv_per_layer_mla``;
            ``"direct"`` calls ``transfer_kv_per_layer_mla_direct``.
        page_size: Forwarded only by the ``"direct"`` backend.
        item_size: Forwarded only by the ``"kernel"`` backend.
        block_quota: Forwarded only by the ``"kernel"`` backend.
        num_warps_per_block: Forwarded only by the ``"kernel"`` backend.

    Raises:
        ValueError: If ``io_backend`` is neither ``"kernel"`` nor ``"direct"``.
    """
    if io_backend == "kernel":
        torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
            src,
            dst,
            src_indices,
            dst_indices,
            item_size,
            block_quota,
            num_warps_per_block,
        )
    elif io_backend == "direct":
        torch.ops.sgl_kernel.transfer_kv_per_layer_mla_direct(
            src, dst, src_indices, dst_indices, page_size
        )
    else:
        # Report the rejected value so a misconfigured caller is easy to spot;
        # the original message was a placeholder-less f-string.
        raise ValueError(
            f"Unsupported io backend: {io_backend!r} "
            "(expected 'kernel' or 'direct')"
        )
|
|
|
|
|
|
def transfer_kv_all_layer_mla(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    io_backend: str,
    page_size: int,
    item_size: int,
    num_layers: int,
    src_layer_offset: int,
    dst_layer_offset: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
    """Dispatch an all-layer MLA transfer to the op selected by ``io_backend``.

    Unlike the K/V variant, this takes a single source/destination tensor
    pair. It forwards its arguments unchanged to one of two
    ``torch.ops.sgl_kernel`` ops and returns nothing.

    Args:
        src: Source tensor forwarded as the op's first argument.
        dst: Destination tensor forwarded as the op's second argument.
        src_indices: Index tensor forwarded alongside ``src``.
        dst_indices: Index tensor forwarded alongside ``dst``.
        io_backend: ``"kernel"`` calls ``transfer_kv_all_layer_mla``;
            ``"direct"`` calls ``transfer_kv_all_layer_mla_direct``.
        page_size: Forwarded only by the ``"direct"`` backend.
        item_size: Forwarded only by the ``"kernel"`` backend.
        num_layers: Forwarded by both backends.
        src_layer_offset: Forwarded only by the ``"kernel"`` backend.
        dst_layer_offset: Forwarded only by the ``"kernel"`` backend.
        block_quota: Forwarded only by the ``"kernel"`` backend.
        num_warps_per_block: Forwarded only by the ``"kernel"`` backend.

    Raises:
        ValueError: If ``io_backend`` is neither ``"kernel"`` nor ``"direct"``.
    """
    if io_backend == "kernel":
        torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
            src,
            dst,
            src_indices,
            dst_indices,
            item_size,
            num_layers,
            src_layer_offset,
            dst_layer_offset,
            block_quota,
            num_warps_per_block,
        )
    elif io_backend == "direct":
        # NOTE: the layer offsets are not part of the direct op's call.
        torch.ops.sgl_kernel.transfer_kv_all_layer_mla_direct(
            src, dst, src_indices, dst_indices, page_size, num_layers
        )
    else:
        # Report the rejected value so a misconfigured caller is easy to spot;
        # the original message was a placeholder-less f-string.
        raise ValueError(
            f"Unsupported io backend: {io_backend!r} "
            "(expected 'kernel' or 'direct')"
        )
|