# sglang_v0.5.2/sglang/sgl-kernel/python/sgl_kernel/kvcacheio.py

from typing import List
import torch
def is_hip() -> bool:
    """Return True when this torch build targets AMD ROCm (HIP)."""
    hip_version = torch.version.hip
    return hip_version is not None


# Evaluated once at import time; used below to pick per-backend warp defaults.
_is_hip = is_hip()
def transfer_kv_per_layer(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Dispatch the ``sgl_kernel.transfer_kv_per_layer`` custom op.

    Pure pass-through wrapper: every argument is forwarded positionally
    and unchanged. The copy semantics (single-layer K/V transfer keyed by
    ``src_indices``/``dst_indices`` — presumably slot indices into the
    caches) live in the compiled kernel, not here.
    """
    op_args = (
        src_k,
        dst_k,
        src_v,
        dst_v,
        src_indices,
        dst_indices,
        item_size,
        block_quota,
        num_warps_per_block,
    )
    torch.ops.sgl_kernel.transfer_kv_per_layer(*op_args)
def transfer_kv_per_layer_pf_lf(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    layer_id: int,
    item_size: int,
    src_layout_dim: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Dispatch the ``sgl_kernel.transfer_kv_per_layer_pf_lf`` custom op.

    Pure pass-through; arguments are forwarded positionally. The
    ``pf_lf`` suffix presumably denotes a page-first -> layer-first
    layout conversion — semantics are defined by the compiled kernel.
    """
    op_args = (
        src_k,
        dst_k,
        src_v,
        dst_v,
        src_indices,
        dst_indices,
        layer_id,
        item_size,
        src_layout_dim,
        block_quota,
        num_warps_per_block,
    )
    torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf(*op_args)
def transfer_kv_all_layer(
    src_k_layers: torch.Tensor,
    dst_k_layers: torch.Tensor,
    src_v_layers: torch.Tensor,
    dst_v_layers: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Dispatch the ``sgl_kernel.transfer_kv_all_layer`` custom op.

    Pure pass-through; all ``num_layers`` layers are handled by a single
    kernel launch on the C++ side. Arguments forwarded positionally.
    """
    op_args = (
        src_k_layers,
        dst_k_layers,
        src_v_layers,
        dst_v_layers,
        src_indices,
        dst_indices,
        item_size,
        num_layers,
        block_quota,
        num_warps_per_block,
    )
    torch.ops.sgl_kernel.transfer_kv_all_layer(*op_args)
def transfer_kv_all_layer_lf_pf(
    src_k_layers: torch.Tensor,
    dst_k: torch.Tensor,
    src_v_layers: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    dst_layout_dim: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Dispatch the ``sgl_kernel.transfer_kv_all_layer_lf_pf`` custom op.

    Pure pass-through; ``lf_pf`` presumably denotes a layer-first ->
    page-first layout conversion — verify against the kernel source.
    Arguments forwarded positionally.
    """
    op_args = (
        src_k_layers,
        dst_k,
        src_v_layers,
        dst_v,
        src_indices,
        dst_indices,
        item_size,
        dst_layout_dim,
        num_layers,
        block_quota,
        num_warps_per_block,
    )
    torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf(*op_args)
def transfer_kv_direct(
    src_layers: List[torch.Tensor],
    dst_layers: List[torch.Tensor],
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    page_size: int,
):
    """Dispatch the ``sgl_kernel.transfer_kv_direct`` custom op.

    Pure pass-through: one tensor per layer in each list, forwarded
    positionally and unchanged to the compiled kernel.
    """
    torch.ops.sgl_kernel.transfer_kv_direct(
        src_layers,
        dst_layers,
        src_indices,
        dst_indices,
        page_size,
    )
def transfer_kv_per_layer_direct_pf_lf(
    src_ptrs: List[torch.Tensor],
    dst_ptrs: List[torch.Tensor],
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    layer_id: int,
    page_size: int,
):
    """Dispatch the ``sgl_kernel.transfer_kv_per_layer_direct_pf_lf`` op.

    Pure pass-through wrapper; all arguments forwarded positionally.
    Layout semantics (``pf_lf``) are defined by the compiled kernel.
    """
    torch.ops.sgl_kernel.transfer_kv_per_layer_direct_pf_lf(
        src_ptrs,
        dst_ptrs,
        src_indices,
        dst_indices,
        layer_id,
        page_size,
    )
def transfer_kv_all_layer_direct_lf_pf(
    src_ptrs: List[torch.Tensor],
    dst_ptrs: List[torch.Tensor],
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    page_size: int,
):
    """Dispatch the ``sgl_kernel.transfer_kv_all_layer_direct_lf_pf`` op.

    Pure pass-through wrapper; all arguments forwarded positionally.
    Layout semantics (``lf_pf``) are defined by the compiled kernel.
    """
    torch.ops.sgl_kernel.transfer_kv_all_layer_direct_lf_pf(
        src_ptrs,
        dst_ptrs,
        src_indices,
        dst_indices,
        page_size,
    )
def transfer_kv_per_layer_mla(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Dispatch the ``sgl_kernel.transfer_kv_per_layer_mla`` custom op.

    MLA variant: a single combined cache tensor per side instead of
    separate K/V tensors. Pure pass-through, positional forwarding.
    """
    op_args = (
        src,
        dst,
        src_indices,
        dst_indices,
        item_size,
        block_quota,
        num_warps_per_block,
    )
    torch.ops.sgl_kernel.transfer_kv_per_layer_mla(*op_args)
def transfer_kv_per_layer_mla_pf_lf(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    layer_id: int,
    item_size: int,
    src_layout_dim: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Dispatch the ``sgl_kernel.transfer_kv_per_layer_mla_pf_lf`` op.

    MLA variant with ``pf_lf`` layout handling; a single combined cache
    tensor per side. Pure pass-through, positional forwarding.
    """
    op_args = (
        src,
        dst,
        src_indices,
        dst_indices,
        layer_id,
        item_size,
        src_layout_dim,
        block_quota,
        num_warps_per_block,
    )
    torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf(*op_args)
def transfer_kv_all_layer_mla(
    src_layers: torch.Tensor,
    dst_layers: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Dispatch the ``sgl_kernel.transfer_kv_all_layer_mla`` custom op.

    MLA all-layer variant: one combined cache tensor per side covering
    ``num_layers`` layers. Pure pass-through, positional forwarding.
    """
    op_args = (
        src_layers,
        dst_layers,
        src_indices,
        dst_indices,
        item_size,
        num_layers,
        block_quota,
        num_warps_per_block,
    )
    torch.ops.sgl_kernel.transfer_kv_all_layer_mla(*op_args)
def transfer_kv_all_layer_mla_lf_pf(
    src_layers: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    dst_layout_dim: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Dispatch the ``sgl_kernel.transfer_kv_all_layer_mla_lf_pf`` op.

    MLA all-layer variant with ``lf_pf`` layout handling. Pure
    pass-through, positional forwarding; semantics live in the kernel.
    """
    op_args = (
        src_layers,
        dst,
        src_indices,
        dst_indices,
        item_size,
        dst_layout_dim,
        num_layers,
        block_quota,
        num_warps_per_block,
    )
    torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf(*op_args)