93 lines
2.7 KiB
Python
93 lines
2.7 KiB
Python
import triton
|
|
import triton.language as tl
|
|
|
|
|
|
@triton.jit
def create_flashinfer_kv_indices_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_indptr,
    kv_start_idx,
    kv_indices_ptr,
    req_to_token_ptr_stride: tl.constexpr,
):
    """Copy one request's token indices from ``req_to_token`` into a flat buffer.

    Program ``pid`` handles entry ``pid`` of the per-request input arrays:

    * it reads row ``req_pool_indices_ptr[pid]`` of ``req_to_token_ptr``
      (rows are ``req_to_token_ptr_stride`` elements apart),
    * starting at column ``kv_start_idx[pid]`` when ``kv_start_idx`` is
      truthy, otherwise at column 0,
    * copies ``page_kernel_lens_ptr[pid]`` consecutive entries,
    * and writes them to ``kv_indices_ptr`` beginning at element
      ``kv_indptr[pid]``.

    The copy proceeds in masked chunks of 512 elements.
    NOTE(review): presumably the grid is launched with one program per
    request and callers pass ``None`` for ``kv_start_idx`` to disable the
    offset; confirm against call sites.
    """
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(axis=0)

    # Source row for this program and the output position it writes to.
    req_pool_index = tl.load(req_pool_indices_ptr + pid)
    out_base = tl.load(kv_indptr + pid)

    # Number of entries to copy for this program.
    seq_len = tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

    # Optional column offset into the source row; defaults to column 0.
    kv_start = 0
    if kv_start_idx:
        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)

    # Loop-invariant base pointers for the source row and destination span.
    src_row = req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + kv_start
    dst_row = kv_indices_ptr + out_base

    for blk in range(tl.cdiv(seq_len, BLOCK_SIZE)):
        cols = blk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        # The final chunk may be partial; mask off lanes past seq_len.
        in_bounds = cols < seq_len
        tokens = tl.load(src_row + cols, mask=in_bounds)
        tl.store(dst_row + cols, tokens, mask=in_bounds)
|
|
|
|
|
|
@triton.jit
def create_flashmla_kv_indices_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_start_idx,
    kv_indices_ptr,
    req_to_token_ptr_stride: tl.constexpr,
    kv_indices_ptr_stride: tl.constexpr,
):
    """Build a page table (one entry per 64-token page) for request ``pid``.

    Program ``pid`` does the following:

    * reads row ``req_pool_indices_ptr[pid]`` of ``req_to_token_ptr`` (rows are
      ``req_to_token_ptr_stride`` elements apart),
    * starts at column ``kv_start_idx[pid]`` when ``kv_start_idx`` is truthy,
      otherwise at column 0,
    * covers the next ``page_kernel_lens_ptr[pid]`` entries,
    * takes the token index stored at the first slot of each 64-entry page
      (``ceil(len / 64)`` pages in total) and divides it by 64,
    * writes the results contiguously into row ``pid`` of ``kv_indices_ptr``
      (row stride ``kv_indices_ptr_stride``).

    NOTE(review): ``data // PAGED_SIZE`` yields a page id only if each page's
    first token index is page-aligned in the KV pool; presumably guaranteed
    by a 64-token paged allocator -- confirm.
    NOTE(review): nothing here bounds the page count by
    ``kv_indices_ptr_stride``; the caller must size each output row to hold
    ``ceil(len / 64)`` entries.
    """
    PAGED_SIZE: tl.constexpr = 64
    # Tokens covered per loop iteration. Kept equal to
    # NUM_PAGE_PER_BLOCK * PAGED_SIZE so that one iteration handles exactly
    # one vector of NUM_PAGE_PER_BLOCK pages.
    BLOCK_SIZE: tl.constexpr = 4096
    NUM_PAGE_PER_BLOCK: tl.constexpr = 64
    pid = tl.program_id(axis=0)

    # find the req pool idx, this is for batch to token
    req_pool_index = tl.load(req_pool_indices_ptr + pid)

    kv_start = 0
    kv_end = 0
    if kv_start_idx:
        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
        kv_end = kv_start
    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

    # Pages spanned by this request (a partial last page counts as one).
    num_paged = tl.cdiv(kv_end - kv_start, PAGED_SIZE)
    num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)

    for i in range(num_pages_loop):
        # Page numbers (relative to kv_start) handled in this iteration.
        paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
        # Offset of the first token slot of each of those pages.
        paged_offset = paged_offset_out * PAGED_SIZE

        # Valid pages are 0 .. num_paged - 1, so the comparison must be
        # strict. Including page num_paged would read one page past the
        # request's tokens and store an extra entry at
        # pid * kv_indices_ptr_stride + num_paged. That lands in row pid + 1
        # (or past the buffer for the last program) when
        # num_paged == kv_indices_ptr_stride. The same predicate guards both
        # the load and the store.
        mask = paged_offset_out < num_paged

        data = tl.load(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + kv_start
            + paged_offset,
            mask=mask,
        )
        tl.store(
            kv_indices_ptr + pid * kv_indices_ptr_stride + paged_offset_out,
            data // PAGED_SIZE,
            mask=mask,
        )
|