from __future__ import annotations """ Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ """ The radix tree data structure for managing the KV cache. """ import heapq import time from collections import defaultdict from functools import partial from typing import TYPE_CHECKING, List, Optional, Tuple import torch from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req class TreeNode: counter = 0 def __init__(self, id: Optional[int] = None): self.children = defaultdict(TreeNode) self.parent = None self.key = None self.value = None self.lock_ref = 0 self.last_access_time = time.time() self.hit_count = 0 # indicating the node is loading KV cache from host self.loading = False # store the host indices of KV cache self.host_value = None self.id = TreeNode.counter if id is None else id TreeNode.counter += 1 @property def evicted(self): return self.value is None @property def backuped(self): return self.host_value is not None def __lt__(self, other: "TreeNode"): return self.last_access_time < other.last_access_time def _key_match_page_size1(key0: List, key1: List): i = 0 for k0, k1 in zip(key0, key1): if k0 != k1: break i += 1 return i def _key_match_paged(key0: List, key1: List, page_size: int): min_len = min(len(key0), len(key1)) i = 0 while i < min_len: if key0[i : i + page_size] != key1[i : i + page_size]: break i += page_size return i class RadixCache(BasePrefixCache): def __init__( self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool_allocator: TokenToKVPoolAllocator, page_size: int, disable: bool = False, ): self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator self.page_size = page_size self.disable = disable if self.token_to_kv_pool_allocator: self.device = self.token_to_kv_pool_allocator.device else: self.device = torch.device("cpu") if self.page_size == 1: self.key_match_fn = _key_match_page_size1 self.get_child_key_fn = lambda key: key[0] else: self.key_match_fn = partial(_key_match_paged, page_size=page_size) self.get_child_key_fn = lambda key: tuple(key[:page_size]) self.reset() ##### Public API ##### def reset(self): self.root_node = TreeNode() self.root_node.key = [] self.root_node.value = [] self.root_node.lock_ref = 1 self.evictable_size_ = 0 self.protected_size_ = 0 def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]: """Find the matching prefix from the radix tree. Args: key: A list of token IDs to find a matching prefix. Returns: A tuple of a tensor of matching prefix token IDs and the last node that contains the prefix values. Note that this API can modify the internal state of the Radix tree. The last node create a new child if the prefix is shorter than the last node's value. """ if self.disable or len(key) == 0: return ( torch.empty( (0,), dtype=torch.int64, device=self.device, ), self.root_node, ) if self.page_size != 1: page_aligned_len = len(key) // self.page_size * self.page_size key = key[:page_aligned_len] value, last_node = self._match_prefix_helper(self.root_node, key) if value: value = torch.cat(value) else: value = torch.empty((0,), dtype=torch.int64, device=self.device) return value, last_node def insert(self, key: List, value=None): if self.disable: return 0 if value is None: value = [x for x in key] return self._insert_helper(self.root_node, key, value) def cache_finished_req(self, req: Req): """Cache request when it finishes.""" if self.disable: kv_indices = self.req_to_token_pool.req_to_token[ req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1 ] self.token_to_kv_pool_allocator.free(kv_indices) self.req_to_token_pool.free(req.req_pool_idx) return token_ids = (req.origin_input_ids + req.output_ids)[:-1] kv_indices = self.req_to_token_pool.req_to_token[ req.req_pool_idx, : len(token_ids) ] if self.page_size != 1: page_aligned_len = len(kv_indices) // self.page_size * self.page_size page_aligned_kv_indices = kv_indices[:page_aligned_len].clone() self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:]) else: page_aligned_len = len(kv_indices) page_aligned_kv_indices = kv_indices.clone() # Radix Cache takes one ref in memory pool new_prefix_len = self.insert( token_ids[:page_aligned_len], page_aligned_kv_indices ) self.token_to_kv_pool_allocator.free( kv_indices[len(req.prefix_indices) : new_prefix_len] ) # Remove req slot release the cache lock self.req_to_token_pool.free(req.req_pool_idx) self.dec_lock_ref(req.last_node) def cache_unfinished_req(self, req: Req): """Cache request when it is unfinished.""" if self.disable: return token_ids = req.fill_ids kv_indices = self.req_to_token_pool.req_to_token[ req.req_pool_idx, : len(token_ids) ] if self.page_size != 1: page_aligned_len = len(kv_indices) // self.page_size * self.page_size page_aligned_kv_indices = kv_indices[:page_aligned_len].clone() else: page_aligned_len = len(kv_indices) page_aligned_kv_indices = kv_indices.clone() page_aligned_token_ids = token_ids[:page_aligned_len] # Radix Cache takes one ref in memory pool new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices) self.token_to_kv_pool_allocator.free( kv_indices[len(req.prefix_indices) : new_prefix_len] ) # The prefix indices could be updated, reuse it new_indices, new_last_node = self.match_prefix(page_aligned_token_ids) self.req_to_token_pool.write( (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))), new_indices[len(req.prefix_indices) :], ) self.dec_lock_ref(req.last_node) self.inc_lock_ref(new_last_node) # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later if self.page_size != 1: req.prefix_indices = torch.cat( [new_indices, kv_indices[len(new_indices) :]] ) else: req.prefix_indices = new_indices req.last_node = new_last_node def pretty_print(self): self._print_helper(self.root_node, 0) print(f"#tokens: {self.total_size()}") def total_size(self): return self._total_size_helper() def evict(self, num_tokens: int): if self.disable: return leaves = self._collect_leaves() heapq.heapify(leaves) num_evicted = 0 while num_evicted < num_tokens and len(leaves): x = heapq.heappop(leaves) if x == self.root_node: break if x.lock_ref > 0: continue self.token_to_kv_pool_allocator.free(x.value) num_evicted += len(x.value) self._delete_leaf(x) if len(x.parent.children) == 0: heapq.heappush(leaves, x.parent) def inc_lock_ref(self, node: TreeNode): if self.disable: return 0 delta = 0 while node != self.root_node: if node.lock_ref == 0: self.evictable_size_ -= len(node.value) self.protected_size_ += len(node.value) delta -= len(node.value) node.lock_ref += 1 node = node.parent return delta def dec_lock_ref(self, node: TreeNode): if self.disable: return 0 delta = 0 while node != self.root_node: if node.lock_ref == 1: self.evictable_size_ += len(node.value) self.protected_size_ -= len(node.value) delta += len(node.value) node.lock_ref -= 1 node = node.parent return delta def evictable_size(self): return self.evictable_size_ def protected_size(self): # protected size refers to the size of the cache that is locked return self.protected_size_ def all_values_flatten(self): values = [] def _dfs_helper(node: TreeNode): for _, child in node.children.items(): values.append(child.value) _dfs_helper(child) _dfs_helper(self.root_node) return torch.cat(values) ##### Internal Helper Functions ##### def _match_prefix_helper(self, node: TreeNode, key: List): node.last_access_time = time.time() child_key = self.get_child_key_fn(key) value = [] while len(key) > 0 and child_key in node.children.keys(): child = node.children[child_key] child.last_access_time = time.time() prefix_len = self.key_match_fn(child.key, key) if prefix_len < len(child.key): new_node = self._split_node(child.key, child, prefix_len) value.append(new_node.value) node = new_node break else: value.append(child.value) node = child key = key[prefix_len:] if len(key): child_key = self.get_child_key_fn(key) return value, node def _split_node(self, key, child: TreeNode, split_len: int): # new_node -> child new_node = TreeNode() new_node.children = {self.get_child_key_fn(key[split_len:]): child} new_node.parent = child.parent new_node.lock_ref = child.lock_ref new_node.key = child.key[:split_len] new_node.value = child.value[:split_len] child.parent = new_node child.key = child.key[split_len:] child.value = child.value[split_len:] new_node.parent.children[self.get_child_key_fn(key)] = new_node return new_node def _insert_helper(self, node: TreeNode, key: List, value): node.last_access_time = time.time() if len(key) == 0: return 0 child_key = self.get_child_key_fn(key) total_prefix_length = 0 while len(key) > 0 and child_key in node.children.keys(): node = node.children[child_key] node.last_access_time = time.time() prefix_len = self.key_match_fn(node.key, key) total_prefix_length += prefix_len key = key[prefix_len:] value = value[prefix_len:] if prefix_len < len(node.key): new_node = self._split_node(node.key, node, prefix_len) node = new_node if len(key): child_key = self.get_child_key_fn(key) if len(key): new_node = TreeNode() new_node.parent = node new_node.key = key new_node.value = value node.children[child_key] = new_node self.evictable_size_ += len(value) return total_prefix_length def _print_helper(self, node: TreeNode, indent: int): """Prints the radix tree in a human-readable format.""" stack = [(node, indent)] while stack: current_node, current_indent = stack.pop() print( " " * current_indent, len(current_node.key), current_node.key[:10], f"r={current_node.lock_ref}", ) for key, child in current_node.children.items(): stack.append((child, current_indent + 2)) assert key == self.get_child_key_fn( child.key ), f"{key=}, {self.get_child_key_fn(child.key)=}" def _delete_leaf(self, node): for k, v in node.parent.children.items(): if v == node: break del node.parent.children[k] self.evictable_size_ -= len(node.key) def _total_size_helper(self): total_size = 0 stack = [self.root_node] while stack: current_node = stack.pop() total_size += len(current_node.value) for child in current_node.children.values(): if child.evicted: continue stack.append(child) return total_size def _collect_leaves(self): ret_list = [] stack = [self.root_node] while stack: cur_node = stack.pop() if len(cur_node.children) == 0: ret_list.append(cur_node) else: stack.extend(cur_node.children.values()) return ret_list if __name__ == "__main__": tree = RadixCache(None, None, page_size=1, disable=False) tree.insert("Hello") tree.insert("Hello") tree.insert("Hello_L.A.!") # tree.insert("Hello_world! Happy") # tree.insert("I love you!") tree.pretty_print() # print(tree.match_prefix("I love you! aha")) # def evict_callback(x): # print("evict", x) # return len(x) # tree.evict(5, evict_callback) # tree.evict(10, evict_callback) # tree.pretty_print()