"""
Copyright 2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
Page-aligned memory pool.
"""

from typing import Optional

import torch
import triton
import triton.language as tl

from sglang.srt.mem_cache.memory_pool import KVCache
from sglang.srt.utils import get_bool_env_var, next_power_of_2


@triton.jit
def alloc_extend_kernel(
    pre_lens_ptr,
    seq_lens_ptr,
    last_loc_ptr,
    free_page_ptr,
    out_indices,
    ret_values,
    bs_upper: tl.constexpr,
    page_size: tl.constexpr,
    max_num_extend_tokens: tl.constexpr,
):
    """Write the KV slot indices for the extended tokens of one request.

    One program handles one request (``pid``). Every program re-reads the
    lengths of requests ``0..pid`` and reduces them, which yields both its
    own write offset into ``out_indices`` and its own read offset into the
    free-page list without any inter-program communication.

    The last program additionally stores
    ``(total_new_pages << 32) | total_extend_tokens`` into ``ret_values``.

    ``bs_upper`` and ``max_num_extend_tokens`` are used as ``tl.arange``
    extents; the host passes powers of two for both.
    """
    pid = tl.program_id(0)

    # Lanes beyond `pid` are masked off. `other=0` pins them to zero so that
    # they contribute nothing to the reductions below; without it Triton
    # leaves the masked lanes undefined.
    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(
        seq_lens_ptr + load_offset, mask=load_offset <= pid, other=0
    )
    pre_lens = tl.load(
        pre_lens_ptr + load_offset, mask=load_offset <= pid, other=0
    )
    extend_lens = seq_lens - pre_lens

    seq_len = tl.load(seq_lens_ptr + pid)
    pre_len = tl.load(pre_lens_ptr + pid)
    extend_len = seq_len - pre_len

    # Inclusive sum over requests 0..pid minus this request's own share gives
    # the position of this request's first output element.
    sum_extend_lens = tl.sum(extend_lens)
    output_start_loc = sum_extend_lens - extend_len

    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (pre_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before

    # Number of new pages this request needs, and the index in the free-page
    # list at which its pages start (same inclusive-sum-minus-self scheme).
    num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
        pre_len + page_size - 1
    ) // page_size
    sum_num_new_pages = tl.sum(num_new_pages)
    new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

    # Return value: only the last program has the totals over the full batch.
    if pid == tl.num_programs(0) - 1:
        merged_value = (sum_num_new_pages.to(tl.int64) << 32) | sum_extend_lens.to(
            tl.int64
        )
        tl.store(ret_values, merged_value)

    # Part 1: fill the old partial page
    last_loc = tl.load(last_loc_ptr + pid)
    num_part1 = (
        min(seq_len, (pre_len + page_size - 1) // page_size * page_size) - pre_len
    )
    offset_one_page = tl.arange(0, page_size)
    tl.store(
        out_indices + output_start_loc + offset_one_page,
        last_loc + 1 + offset_one_page,
        mask=offset_one_page < num_part1,
    )
    if pre_len + num_part1 == seq_len:
        return

    # Part 2: fill the new full pages
    num_part2 = (
        seq_len // page_size * page_size
        - (pre_len + page_size - 1) // page_size * page_size
    )

    offset_many_page = tl.arange(0, max_num_extend_tokens)
    page_start = tl.load(
        free_page_ptr + new_page_start_loc + offset_many_page // page_size,
        mask=offset_many_page < num_part2,
        other=0,
    )
    tl.store(
        out_indices + output_start_loc + num_part1 + offset_many_page,
        page_start * page_size + offset_many_page % page_size,
        mask=offset_many_page < num_part2,
    )
    if pre_len + num_part1 + num_part2 == seq_len:
        return

    # Part 3: fill the new partial page (the last page taken by this request)
    num_part3 = seq_len - seq_len // page_size * page_size
    start_loc = tl.load(
        free_page_ptr + new_page_start_loc + num_page_start_loc_self - 1
    )
    tl.store(
        out_indices + output_start_loc + num_part1 + num_part2 + offset_one_page,
        start_loc * page_size + offset_one_page,
        mask=offset_one_page < num_part3,
    )


@triton.jit
def alloc_decode_kernel(
    seq_lens_ptr,
    last_loc_ptr,
    free_page_ptr,
    out_indices,
    ret_values,
    bs_upper: tl.constexpr,
    page_size: tl.constexpr,
):
    """Write the KV slot index for the single new token of one request.

    One program handles one request (``pid``). The previous length of every
    request is taken to be ``seq_len - 1``. A request either continues in its
    current page (slot ``last_loc + 1``) or takes the first slot of a page
    read from the free-page list.

    The last program stores the total number of new pages into ``ret_values``.
    """
    pid = tl.program_id(0)

    # Lanes beyond `pid` are masked off; `other=0` keeps them well defined.
    # For those lanes `pre_lens == seq_lens`, so they add zero new pages.
    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(
        seq_lens_ptr + load_offset, mask=load_offset <= pid, other=0
    )
    pre_lens = tl.where(load_offset <= pid, seq_lens - 1, seq_lens)

    seq_len = tl.load(seq_lens_ptr + pid)
    pre_len = seq_len - 1

    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (pre_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before

    # Inclusive sum over requests 0..pid minus this request's own count gives
    # the index of this request's page in the free-page list.
    num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
        pre_len + page_size - 1
    ) // page_size
    sum_num_new_pages = tl.sum(num_new_pages)
    new_page_start_loc = sum_num_new_pages - num_page_start_loc_self

    # Return value: only the last program has the total over the full batch.
    if pid == tl.num_programs(0) - 1:
        tl.store(ret_values, sum_num_new_pages)

    if num_page_start_loc_self == 0:
        # The new token still fits in the current page.
        last_loc = tl.load(last_loc_ptr + pid)
        tl.store(out_indices + pid, last_loc + 1)
    else:
        # The new token opens a fresh page.
        page = tl.load(free_page_ptr + new_page_start_loc)
        tl.store(out_indices + pid, page * page_size)


class PagedTokenToKVPoolAllocator:
    """
    An allocator managing the indices to kv cache data.

    This class has the same interface as `TokenToKVPoolAllocator` but the output
    of one request is always page-aligned.

    TODO: fuse last_loc into the kernel.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
    ):
        """Create the allocator with every page initially free.

        Args:
            size: Pool capacity in tokens; ``size // page_size`` pages are managed.
            page_size: Number of token slots per page.
            dtype: Stored on the instance; not otherwise used in this class.
            device: Device on which the index tensors are created.
            kvcache: Stored on the instance; not otherwise used in this class.
        """
        self.size = size
        self.dtype = dtype
        self.device = device
        self.page_size = page_size
        self.num_pages = size // page_size

        self.free_pages = None
        self.is_not_in_free_group = True
        self.free_group = []
        self.clear()
        self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")

        self._kvcache = kvcache
        # Scalar buffer the kernels use to report how many pages they consumed.
        self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)

    def available_size(self) -> int:
        """Return the number of free token slots (free pages * page size)."""
        return len(self.free_pages) * self.page_size

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ) -> Optional[torch.Tensor]:
        """Allocate slots for the tokens that extend each request.

        Args:
            prefix_lens: Per-request length before the extension.
            seq_lens: Per-request length after the extension.
            last_loc: Per-request slot index of the last existing token.
            extend_num_tokens: Length of the returned index tensor.

        Returns:
            An int64 tensor of ``extend_num_tokens`` slot indices, or ``None``
            when more pages are needed than are currently free.
        """
        if self.debug_mode:
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        bs = len(prefix_lens)
        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int64, device=self.device
        )
        if bs == 0:
            # With an empty batch no kernel program runs, so `ret_values`
            # would still hold the count from a previous call. Nothing is
            # allocated, so leave the free list untouched.
            return out_indices

        # NOTE(review): the kernel indexes `free_pages` before the capacity
        # check below; when the pool is short those reads may go past the end
        # of the tensor — confirm this is acceptable.
        alloc_extend_kernel[(bs,)](
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.ret_values,
            next_power_of_2(bs),
            self.page_size,
            next_power_of_2(extend_num_tokens),
        )

        # High 32 bits: number of new pages; low 32 bits: number of tokens.
        merged_value = self.ret_values.item()
        num_new_pages = merged_value >> 32
        if num_new_pages > len(self.free_pages):
            return None

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Allocate one slot per request for a decode step.

        Args:
            seq_lens: Per-request length including the new token.
            last_loc: Per-request slot index of the last existing token.

        Returns:
            An int64 tensor with one slot index per request, or ``None`` when
            more pages are needed than are currently free.
        """
        if self.debug_mode:
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        bs = len(seq_lens)
        out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
        if bs == 0:
            # With an empty batch no kernel program runs, so `ret_values`
            # would still hold the count from a previous call. Nothing is
            # allocated, so leave the free list untouched.
            return out_indices

        # NOTE(review): the kernel indexes `free_pages` before the capacity
        # check below; when the pool is short those reads may go past the end
        # of the tensor — confirm this is acceptable.
        alloc_decode_kernel[(bs,)](
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.ret_values,
            next_power_of_2(bs),
            self.page_size,
        )

        num_new_pages = self.ret_values.item()
        if num_new_pages > len(self.free_pages):
            return None

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def free(self, free_index: torch.Tensor) -> None:
        """Return the pages covering ``free_index`` to the free list.

        Inside a free group the indices are only queued; they are released
        together by `free_group_end`.
        """
        if free_index.numel() == 0:
            return

        if self.is_not_in_free_group:
            # Several slots map to the same page; release each page once.
            free_page_indices = torch.unique(free_index // self.page_size)
            self.free_pages = torch.cat((free_page_indices, self.free_pages))
        else:
            self.free_group.append(free_index)

    def free_group_begin(self) -> None:
        """Start queuing `free` calls instead of applying them immediately."""
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self) -> None:
        """Stop queuing and release everything queued since `free_group_begin`."""
        self.is_not_in_free_group = True
        if self.free_group:
            self.free(torch.cat(self.free_group))

    def clear(self) -> None:
        """Reset the allocator so that pages ``1..num_pages`` are all free."""
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.free_pages = torch.arange(
            1, self.num_pages + 1, dtype=torch.int64, device=self.device
        )
        self.is_not_in_free_group = True
        self.free_group = []