# sglang 0.4.5.post1 — python/sglang/srt/mem_cache/paged_allocator.py
"""
Copyright 2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
Page-aligned memory pool.
"""
import torch
import triton
import triton.language as tl
from sglang.srt.mem_cache.memory_pool import KVCache
from sglang.srt.utils import get_bool_env_var, next_power_of_2
@triton.jit
def alloc_extend_kernel(
    pre_lens_ptr,  # int tensor [bs]: cached prefix length per sequence
    seq_lens_ptr,  # int tensor [bs]: total length per sequence after the extend
    last_loc_ptr,  # int tensor [bs]: last already-allocated slot per sequence
    free_page_ptr,  # int tensor: page ids available for allocation
    out_indices,  # int64 output: one slot index per extend token
    ret_values,  # scalar int64 output: (total_new_pages << 32) | total_extend_tokens
    bs_upper: tl.constexpr,  # power-of-2 upper bound on the batch size
    page_size: tl.constexpr,  # tokens per page
    max_num_extend_tokens: tl.constexpr,  # power-of-2 bound on total extend tokens
):
    # One program per sequence in the batch.
    pid = tl.program_id(0)

    # Load lengths for this sequence and every sequence before it.
    # Masked-out lanes (load_offset > pid) read 0, so tl.sum below yields
    # an inclusive prefix sum over sequences 0..pid.
    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
    pre_lens = tl.load(pre_lens_ptr + load_offset, mask=load_offset <= pid)
    extend_lens = seq_lens - pre_lens

    seq_len = tl.load(seq_lens_ptr + pid)
    pre_len = tl.load(pre_lens_ptr + pid)
    extend_len = seq_len - pre_len

    # Where this sequence's tokens start in out_indices:
    # inclusive prefix sum of extend lens minus our own extend len.
    sum_extend_lens = tl.sum(extend_lens)
    output_start_loc = sum_extend_lens - extend_len

    # New pages needed per sequence: ceil(seq/page) - ceil(pre/page).
    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (pre_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before

    # New pages needed by this sequence alone.
    num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
        pre_len + page_size - 1
    ) // page_size

    # Where this sequence's new pages start in free_page_ptr.
    sum_num_new_pages = tl.sum(num_new_pages)
    new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

    # Return value
    # The last program's prefix sums cover the whole batch, so it packs
    # (total new pages, total extend tokens) into one int64 for the host.
    if pid == tl.num_programs(0) - 1:
        merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
            tl.int64
        )
        tl.store(ret_values, merged_value)

    # Part 1: fill the old partial page
    # Tokens that still fit in the last partially-filled page: from pre_len
    # up to the next page boundary (or up to seq_len if it ends earlier).
    last_loc = tl.load(last_loc_ptr + pid)
    num_part1 = (
        min(seq_len, (pre_len + page_size - 1) // page_size * page_size) - pre_len
    )
    offset_one_page = tl.arange(0, page_size)
    tl.store(
        out_indices + output_start_loc + offset_one_page,
        last_loc + 1 + offset_one_page,  # contiguous slots after the last one
        mask=offset_one_page < num_part1,
    )
    if pre_len + num_part1 == seq_len:
        return

    # Part 2: fill the new full pages
    # Whole pages between the first page boundary after pre_len and the last
    # page boundary at or before seq_len, taken from the free-page list.
    num_part2 = (
        seq_len // page_size * page_size
        - (pre_len + page_size - 1) // page_size * page_size
    )
    offset_many_page = tl.arange(0, max_num_extend_tokens)
    page_start = tl.load(
        free_page_ptr + new_page_start_loc + offset_many_page // page_size,
        mask=offset_many_page < num_part2,
    )
    tl.store(
        out_indices + output_start_loc + num_part1 + offset_many_page,
        page_start * page_size + offset_many_page % page_size,
        mask=offset_many_page < num_part2,
    )
    if pre_len + num_part1 + num_part2 == seq_len:
        return

    # Part 3: fill the new partial page
    # Remaining tokens past the last full page boundary go to the start of
    # this sequence's final newly-allocated page.
    num_part3 = seq_len - seq_len // page_size * page_size
    start_loc = tl.load(
        free_page_ptr + new_page_start_loc + num_page_start_loc_self - 1
    )
    tl.store(
        out_indices + output_start_loc + num_part1 + num_part2 + offset_one_page,
        start_loc * page_size + offset_one_page,
        mask=offset_one_page < num_part3,
    )
@triton.jit
def alloc_decode_kernel(
    seq_lens_ptr,  # int tensor [bs]: sequence length per request AFTER the new token
    last_loc_ptr,  # int tensor [bs]: last already-allocated slot per request
    free_page_ptr,  # int tensor: page ids available for allocation
    out_indices,  # int64 output [bs]: slot for each request's new token
    ret_values,  # scalar int64 output: total number of new pages consumed
    bs_upper: tl.constexpr,  # power-of-2 upper bound on the batch size
    page_size: tl.constexpr,  # tokens per page
):
    # One program per request; decode extends each sequence by exactly one token.
    pid = tl.program_id(0)

    # Inclusive-prefix load: lanes with load_offset > pid are masked and read 0.
    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
    # Lengths before this decode step (seq_len - 1). Masked lanes keep their
    # loaded 0 (the where falls through to seq_lens), so they contribute
    # nothing to the prefix sums below.
    pre_lens = tl.where(load_offset <= pid, seq_lens - 1, seq_lens)

    seq_len = tl.load(seq_lens_ptr + pid)
    pre_len = seq_len - 1

    # New pages needed per request: ceil(seq/page) - ceil(pre/page);
    # this is 0 or 1 for a one-token extension.
    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (pre_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before

    num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
        pre_len + page_size - 1
    ) // page_size

    # Offset of this request's (possibly absent) new page in free_page_ptr.
    sum_num_new_pages = tl.sum(num_new_pages)
    new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

    # Return value
    # The last program's prefix sum covers the whole batch.
    if pid == tl.num_programs(0) - 1:
        tl.store(ret_values, sum_num_new_pages)

    if num_page_start_loc_self == 0:
        # Current page still has room: the new slot follows the last one.
        last_loc = tl.load(last_loc_ptr + pid)
        tl.store(out_indices + pid, last_loc + 1)
    else:
        # Page boundary crossed: take the first slot of a fresh page.
        page = tl.load(free_page_ptr + new_page_start_loc)
        tl.store(out_indices + pid, page * page_size)
class PagedTokenToKVPoolAllocator:
    """Allocator for KV-cache slot indices whose outputs are page-aligned.

    Exposes the same interface as `TokenToKVPoolAllocator`, but the slots
    handed to a request always respect page boundaries.

    TODO: fuse last_loc into the kernel.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
    ):
        self.size = size
        self.dtype = dtype
        self.device = device
        self.page_size = page_size
        self.num_pages = size // page_size

        # Free-list state; clear() below populates the free-page tensor.
        self.free_pages = None
        self.is_not_in_free_group = True
        self.free_group = []
        self.clear()

        self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
        self._kvcache = kvcache
        # Scalar scratch tensor the allocation kernels write their totals into.
        self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)

    def available_size(self):
        """Number of free token slots (free pages times the page size)."""
        return self.page_size * len(self.free_pages)

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        """Allocate slots for the extend (prefill) tokens of a batch.

        Returns an int64 tensor of `extend_num_tokens` slot indices, or
        `None` when the pool does not have enough free pages.
        """
        if self.debug_mode:
            # Each request's last slot must sit exactly one position before
            # where its prefix ends within a page.
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        batch = len(prefix_lens)
        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int64, device=self.device
        )
        alloc_extend_kernel[(batch,)](
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.ret_values,
            next_power_of_2(batch),
            self.page_size,
            next_power_of_2(extend_num_tokens),
        )

        # The kernel packs (num_new_pages << 32) | num_extend_tokens.
        packed = self.ret_values.item()
        pages_needed = packed >> 32
        if pages_needed > len(self.free_pages):
            return None

        self.free_pages = self.free_pages[pages_needed:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        """Allocate one slot per request for a decode step.

        Returns an int64 tensor with one slot index per request, or `None`
        when the pool does not have enough free pages.
        """
        if self.debug_mode:
            # seq_lens already includes the token being decoded, hence +2.
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        batch = len(seq_lens)
        out_indices = torch.empty((batch,), dtype=torch.int64, device=self.device)
        alloc_decode_kernel[(batch,)](
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.ret_values,
            next_power_of_2(batch),
            self.page_size,
        )

        pages_needed = self.ret_values.item()
        if pages_needed > len(self.free_pages):
            return None

        self.free_pages = self.free_pages[pages_needed:]
        return out_indices

    def free(self, free_index: torch.Tensor):
        """Return token slots to the pool, deduplicated at page granularity.

        While a free group is open, indices are only buffered; the actual
        release happens in free_group_end().
        """
        if free_index.numel() == 0:
            return
        if not self.is_not_in_free_group:
            self.free_group.append(free_index)
            return
        released_pages = torch.unique(free_index // self.page_size)
        self.free_pages = torch.cat((released_pages, self.free_pages))

    def free_group_begin(self):
        """Start buffering free() calls instead of applying them immediately."""
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        """Release everything buffered since free_group_begin() in one batch."""
        self.is_not_in_free_group = True
        if self.free_group:
            self.free(torch.cat(self.free_group))

    def clear(self):
        """Reset the pool so every usable page is free.

        Page 0 is reserved: the padded slot 0 is used for writing dummy
        outputs from padded tokens, so the free list covers pages
        1..num_pages inclusive.
        """
        self.free_pages = torch.arange(
            1, self.num_pages + 1, dtype=torch.int64, device=self.device
        )
        self.is_not_in_free_group = True
        self.free_group = []