"""Python wrappers around the ``torch.ops.sgl_kernel`` KV-cache transfer ops.

Each function is a thin dispatcher to the compiled kernel of the same name.
The ``_pf_lf`` / ``_lf_pf`` suffixes name the source and destination memory
layouts (page-first, layer-first); the ``_mla`` variants take a single KV
tensor instead of separate K and V pools.
"""

from typing import List

import torch


def is_hip() -> bool:
    # True when PyTorch was built against ROCm/HIP (AMD GPUs).
    return torch.version.hip is not None


# Cached at import time and used to pick per-backend launch defaults below.
_is_hip = is_hip()


def transfer_kv_per_layer(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    torch.ops.sgl_kernel.transfer_kv_per_layer(
        src_k,
        dst_k,
        src_v,
        dst_v,
        src_indices,
        dst_indices,
        item_size,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_per_layer_pf_lf(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    layer_id: int,
    item_size: int,
    src_layout_dim: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf(
        src_k,
        dst_k,
        src_v,
        dst_v,
        src_indices,
        dst_indices,
        layer_id,
        item_size,
        src_layout_dim,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_all_layer(
    src_k_layers: torch.Tensor,
    dst_k_layers: torch.Tensor,
    src_v_layers: torch.Tensor,
    dst_v_layers: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    torch.ops.sgl_kernel.transfer_kv_all_layer(
        src_k_layers,
        dst_k_layers,
        src_v_layers,
        dst_v_layers,
        src_indices,
        dst_indices,
        item_size,
        num_layers,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_all_layer_lf_pf(
    src_k_layers: torch.Tensor,
    dst_k: torch.Tensor,
    src_v_layers: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    dst_layout_dim: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf(
        src_k_layers,
        dst_k,
        src_v_layers,
        dst_v,
        src_indices,
        dst_indices,
        item_size,
        dst_layout_dim,
        num_layers,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_direct(
    src_layers: List[torch.Tensor],
    dst_layers: List[torch.Tensor],
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    page_size: int,
):
    torch.ops.sgl_kernel.transfer_kv_direct(
        src_layers, dst_layers, src_indices, dst_indices, page_size
    )


def transfer_kv_per_layer_direct_pf_lf(
    src_ptrs: List[torch.Tensor],
    dst_ptrs: List[torch.Tensor],
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    layer_id: int,
    page_size: int,
):
    torch.ops.sgl_kernel.transfer_kv_per_layer_direct_pf_lf(
        src_ptrs, dst_ptrs, src_indices, dst_indices, layer_id, page_size
    )


def transfer_kv_all_layer_direct_lf_pf(
    src_ptrs: List[torch.Tensor],
    dst_ptrs: List[torch.Tensor],
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    page_size: int,
):
    torch.ops.sgl_kernel.transfer_kv_all_layer_direct_lf_pf(
        src_ptrs, dst_ptrs, src_indices, dst_indices, page_size
    )


def transfer_kv_per_layer_mla(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
        src,
        dst,
        src_indices,
        dst_indices,
        item_size,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_per_layer_mla_pf_lf(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    layer_id: int,
    item_size: int,
    src_layout_dim: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf(
        src,
        dst,
        src_indices,
        dst_indices,
        layer_id,
        item_size,
        src_layout_dim,
        block_quota,
        num_warps_per_block,
    )
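
# Usage sketch for the per-layer MLA path (illustrative only; the pool shapes,
# dtype, and the units of ``item_size`` below are assumptions, not part of the
# kernel contract -- consult the C++ op registration for exact semantics):
#
#   src = torch.randn(num_src_slots, kv_dim, device="cuda", dtype=torch.bfloat16)
#   dst = torch.empty(num_dst_slots, kv_dim, device="cuda", dtype=torch.bfloat16)
#   src_indices = torch.tensor([0, 1, 7], device="cuda", dtype=torch.int64)
#   dst_indices = torch.tensor([3, 4, 5], device="cuda", dtype=torch.int64)
#   transfer_kv_per_layer_mla(
#       src, dst, src_indices, dst_indices,
#       item_size=src.stride(0) * src.element_size(),  # assumed: bytes per slot
#   )
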
def transfer_kv_all_layer_mla(
    src_layers: torch.Tensor,
    dst_layers: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
        src_layers,
        dst_layers,
        src_indices,
        dst_indices,
        item_size,
        num_layers,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_all_layer_mla_lf_pf(
    src_layers: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    dst_layout_dim: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf(
        src_layers,
        dst,
        src_indices,
        dst_indices,
        item_size,
        dst_layout_dim,
        num_layers,
        block_quota,
        num_warps_per_block,
    )
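
# Usage sketch for the direct (list-of-tensors) path (illustrative only; the
# shapes and the page-copy semantics assumed here are inferred from the
# signature, not a documented contract). The op is expected to copy
# ``page_size``-token pages from each source layer to the matching
# destination layer at the given start indices:
#
#   num_layers, num_pages, page_size, head_dim = 2, 16, 4, 64
#   src_layers = [
#       torch.randn(num_pages * page_size, head_dim, device="cuda")
#       for _ in range(num_layers)
#   ]
#   dst_layers = [torch.empty_like(t) for t in src_layers]
#   src_indices = torch.tensor([0, 4], device="cuda", dtype=torch.int64)
#   dst_indices = torch.tensor([8, 12], device="cuda", dtype=torch.int64)
#   transfer_kv_direct(src_layers, dst_layers, src_indices, dst_indices, page_size)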