# Source: sglang 0.4.5.post1 — python/sglang/srt/mem_cache/radix_cache.py
# (465 lines, 15 KiB, Python)

from __future__ import annotations
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
The radix tree data structure for managing the KV cache.
"""
import heapq
import time
from collections import defaultdict
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
class TreeNode:
    """One node of the radix tree.

    A node holds a segment of the key (``key``) and the values stored for that
    segment (``value``).  ``value is None`` means the node is evicted (see
    ``evicted``); ``host_value`` holds host-side indices when a backup exists
    (see ``backuped``).  Nodes order by ``last_access_time`` so they can be
    placed in a ``heapq``.
    """

    # Class-wide counter used to assign default node ids.
    counter = 0

    def __init__(self, id: Optional[int] = None):
        # Child nodes; a missing-key access creates a fresh TreeNode via the
        # defaultdict factory.
        self.children = defaultdict(TreeNode)
        self.parent = None
        self.key = None
        self.value = None
        # Reference count maintained by RadixCache.inc_lock_ref/dec_lock_ref;
        # RadixCache.evict skips nodes whose count is > 0.
        self.lock_ref = 0
        self.last_access_time = time.time()
        self.hit_count = 0

        # True while this node's KV cache is being loaded from host memory.
        self.loading = False
        # Host-side indices of this node's KV cache (None when not backed up).
        self.host_value = None

        # An explicit id takes precedence; the class counter advances either way.
        self.id = id if id is not None else TreeNode.counter
        TreeNode.counter += 1

    @property
    def evicted(self):
        """True when the node no longer holds a device-side value."""
        return self.value is None

    @property
    def backuped(self):
        """True when the node has host-side indices recorded."""
        return self.host_value is not None

    def __lt__(self, other: "TreeNode"):
        # Least recently accessed node compares smallest (min-heap eviction order).
        return other.last_access_time > self.last_access_time
def _key_match_page_size1(key0: List, key1: List):
i = 0
for k0, k1 in zip(key0, key1):
if k0 != k1:
break
i += 1
return i
def _key_match_paged(key0: List, key1: List, page_size: int):
min_len = min(len(key0), len(key1))
i = 0
while i < min_len:
if key0[i : i + page_size] != key1[i : i + page_size]:
break
i += page_size
return i
class RadixCache(BasePrefixCache):
    """Radix tree mapping token-id sequences to KV-cache indices.

    Each ``TreeNode`` stores a segment of the key (``key``) and the values for
    that segment (``value``; the KV-cache indices when populated through
    ``cache_finished_req`` / ``cache_unfinished_req``).  With
    ``page_size == 1`` children are keyed by their first token; otherwise keys
    are compared a page at a time and children are keyed by their first page
    (a tuple).

    ``evictable_size_`` counts tokens in nodes with ``lock_ref == 0``;
    ``protected_size_`` counts tokens pinned through ``inc_lock_ref``.

    With ``disable=True`` lookups always miss, inserts are no-ops and finished
    requests free their KV indices directly.
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        page_size: int,
        disable: bool = False,
    ):
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = page_size
        self.disable = disable

        # Device used for empty result tensors; CPU when no allocator is given.
        if self.token_to_kv_pool_allocator:
            self.device = self.token_to_kv_pool_allocator.device
        else:
            self.device = torch.device("cpu")

        if self.page_size == 1:
            self.key_match_fn = _key_match_page_size1
            self.get_child_key_fn = lambda key: key[0]
        else:
            self.key_match_fn = partial(_key_match_paged, page_size=page_size)
            self.get_child_key_fn = lambda key: tuple(key[:page_size])
        self.reset()

    ##### Public API #####

    def reset(self):
        """Discard the tree and start from an empty root.

        KV indices held by the discarded nodes are not freed here.
        """
        self.root_node = TreeNode()
        self.root_node.key = []
        self.root_node.value = []
        # The root stays locked: lock-ref walks stop before it and evict()
        # never frees it.
        self.root_node.lock_ref = 1
        self.evictable_size_ = 0
        self.protected_size_ = 0

    def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, TreeNode]:
        """Find the longest cached prefix of ``key``.

        Args:
            key: Token ids to look up.  With ``page_size != 1`` the key is
                first truncated to a whole number of pages.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            ``(value, last_node)``: ``value`` is the concatenation of the
            ``value`` tensors of the matched nodes (an empty int64 tensor on
            ``self.device`` on a miss or when disabled), and ``last_node`` is
            the deepest node covering the match (the root on a miss).

        Note:
            This can modify the tree: when the match ends inside a node, that
            node is split so that ``last_node`` ends exactly where the match
            ends.
        """
        if self.disable or len(key) == 0:
            return (
                torch.empty(
                    (0,),
                    dtype=torch.int64,
                    device=self.device,
                ),
                self.root_node,
            )

        if self.page_size != 1:
            page_aligned_len = len(key) // self.page_size * self.page_size
            key = key[:page_aligned_len]

        value, last_node = self._match_prefix_helper(self.root_node, key)
        if value:
            value = torch.cat(value)
        else:
            value = torch.empty((0,), dtype=torch.int64, device=self.device)
        return value, last_node

    def insert(self, key: List, value=None):
        """Insert ``key`` -> ``value`` into the tree.

        When ``value`` is None a list copy of ``key`` is stored as the value.
        Returns the number of leading elements of ``key`` that were already
        in the tree (0 when disabled).
        """
        if self.disable:
            return 0

        if value is None:
            value = list(key)
        return self._insert_helper(self.root_node, key, value)

    def cache_finished_req(self, req: Req):
        """Cache a finished request's KV entries and release the request.

        The cached span covers ``origin_input_ids + output_ids`` minus the last
        token.  KV indices the tree already covered, and (with
        ``page_size != 1``) the trailing partial page, are returned to the
        allocator.  The request's pool slot is freed and its lock on
        ``req.last_node`` is released.
        """
        if self.disable:
            kv_indices = self.req_to_token_pool.req_to_token[
                req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
            ]
            self.token_to_kv_pool_allocator.free(kv_indices)
            self.req_to_token_pool.free(req.req_pool_idx)
            return

        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        if self.page_size != 1:
            page_aligned_len = len(kv_indices) // self.page_size * self.page_size
            page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
            # The trailing partial page is not inserted, so free it now.
            self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
        else:
            page_aligned_len = len(kv_indices)
            page_aligned_kv_indices = kv_indices.clone()

        # Radix Cache takes one ref in memory pool
        new_prefix_len = self.insert(
            token_ids[:page_aligned_len], page_aligned_kv_indices
        )
        # Positions [len(prefix_indices), new_prefix_len) were already in the
        # tree before this insert, so the request's own indices for them are
        # redundant.
        self.token_to_kv_pool_allocator.free(
            kv_indices[len(req.prefix_indices) : new_prefix_len]
        )

        # Release the request's pool slot and its lock on the matched prefix.
        self.req_to_token_pool.free(req.req_pool_idx)
        self.dec_lock_ref(req.last_node)

    def cache_unfinished_req(self, req: Req):
        """Cache the KV entries computed so far for an in-flight request.

        The page-aligned part of ``req.fill_ids`` is inserted and redundant
        indices are freed.  The request's ``req_to_token`` row is then rewritten
        to use the tree's indices, and the lock moves from the old
        ``req.last_node`` to the newly matched node.  ``req.prefix_indices``
        and ``req.last_node`` are updated in place.
        """
        if self.disable:
            return

        token_ids = req.fill_ids
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        if self.page_size != 1:
            page_aligned_len = len(kv_indices) // self.page_size * self.page_size
            page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
        else:
            page_aligned_len = len(kv_indices)
            page_aligned_kv_indices = kv_indices.clone()
        page_aligned_token_ids = token_ids[:page_aligned_len]

        # Radix Cache takes one ref in memory pool
        new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices)
        self.token_to_kv_pool_allocator.free(
            kv_indices[len(req.prefix_indices) : new_prefix_len]
        )

        # The prefix indices could be updated, reuse it
        new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
        self.req_to_token_pool.write(
            (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
            new_indices[len(req.prefix_indices) :],
        )

        self.dec_lock_ref(req.last_node)
        self.inc_lock_ref(new_last_node)

        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
        if self.page_size != 1:
            # Keep the unaligned tail (not in the tree) after the cached prefix.
            req.prefix_indices = torch.cat(
                [new_indices, kv_indices[len(new_indices) :]]
            )
        else:
            req.prefix_indices = new_indices
        req.last_node = new_last_node

    def pretty_print(self):
        """Print the tree structure followed by the total cached token count."""
        self._print_helper(self.root_node, 0)
        print(f"#tokens: {self.total_size()}")

    def total_size(self):
        """Return the number of values held by non-evicted nodes."""
        return self._total_size_helper()

    def evict(self, num_tokens: int):
        """Evict unlocked leaves until at least ``num_tokens`` values are freed.

        Leaves are taken least-recently-accessed first (``TreeNode.__lt__``);
        a parent whose last child is removed becomes a candidate itself.
        Nodes with ``lock_ref > 0`` are skipped.  Eviction may free fewer
        tokens than requested if candidates run out, or more, since whole
        nodes are freed.
        """
        if self.disable:
            return

        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and leaves:
            x = heapq.heappop(leaves)

            # The root is only a candidate once it has no children left.
            if x is self.root_node:
                break
            if x.lock_ref > 0:
                continue

            self.token_to_kv_pool_allocator.free(x.value)
            num_evicted += len(x.value)
            self._delete_leaf(x)

            if len(x.parent.children) == 0:
                heapq.heappush(leaves, x.parent)

    def inc_lock_ref(self, node: TreeNode):
        """Lock ``node`` and its ancestors (excluding the root).

        Nodes going from 0 to 1 locks move their size from ``evictable_size_``
        to ``protected_size_``.  Returns the resulting change in evictable size
        (<= 0; 0 when disabled).
        """
        if self.disable:
            return 0

        delta = 0
        while node is not self.root_node:
            if node.lock_ref == 0:
                self.evictable_size_ -= len(node.value)
                self.protected_size_ += len(node.value)
                delta -= len(node.value)
            node.lock_ref += 1
            node = node.parent
        return delta

    def dec_lock_ref(self, node: TreeNode):
        """Undo one ``inc_lock_ref`` on ``node`` and its ancestors.

        Nodes going from 1 to 0 locks move their size back to
        ``evictable_size_``.  Returns the resulting change in evictable size
        (>= 0; 0 when disabled).
        """
        if self.disable:
            return 0

        delta = 0
        while node is not self.root_node:
            if node.lock_ref == 1:
                self.evictable_size_ += len(node.value)
                self.protected_size_ -= len(node.value)
                delta += len(node.value)
            node.lock_ref -= 1
            node = node.parent
        return delta

    def evictable_size(self):
        """Number of cached tokens in unlocked nodes."""
        return self.evictable_size_

    def protected_size(self):
        # protected size refers to the size of the cache that is locked
        return self.protected_size_

    def all_values_flatten(self):
        """Concatenate the values of all non-root nodes in DFS pre-order.

        Returns an empty int64 tensor on ``self.device`` when the tree has no
        nodes besides the root (``torch.cat`` rejects an empty list).
        """
        values = []
        # Iterative pre-order DFS (avoids the recursion limit on deep trees);
        # children are pushed in reverse so they pop in dict insertion order.
        stack = list(reversed(list(self.root_node.children.values())))
        while stack:
            node = stack.pop()
            values.append(node.value)
            stack.extend(reversed(list(node.children.values())))

        if not values:
            return torch.empty((0,), dtype=torch.int64, device=self.device)
        return torch.cat(values)

    ##### Internal Helper Functions #####

    def _match_prefix_helper(self, node: TreeNode, key: List):
        """Walk down from ``node`` along ``key``.

        Returns ``(values, last_node)``: the list of matched node values and
        the deepest matched node.  A node matched only partially is split
        first.  Access times of visited nodes are refreshed.
        """
        node.last_access_time = time.time()
        child_key = self.get_child_key_fn(key)

        value = []
        while len(key) > 0 and child_key in node.children:
            child = node.children[child_key]
            child.last_access_time = time.time()
            prefix_len = self.key_match_fn(child.key, key)
            if prefix_len < len(child.key):
                new_node = self._split_node(child.key, child, prefix_len)
                value.append(new_node.value)
                node = new_node
                break
            else:
                value.append(child.value)
                node = child
                key = key[prefix_len:]

                if len(key):
                    child_key = self.get_child_key_fn(key)

        return value, node

    def _split_node(self, key, child: TreeNode, split_len: int):
        """Split ``child`` so its first ``split_len`` elements form a new parent.

        ``key`` is the child's key before the split (callers pass
        ``child.key``).  The new node inherits ``child``'s parent and lock
        count and is returned.
        """
        # new_node -> child
        new_node = TreeNode()
        new_node.children = {self.get_child_key_fn(key[split_len:]): child}
        new_node.parent = child.parent
        new_node.lock_ref = child.lock_ref
        new_node.key = child.key[:split_len]
        new_node.value = child.value[:split_len]
        child.parent = new_node
        child.key = child.key[split_len:]
        child.value = child.value[split_len:]
        new_node.parent.children[self.get_child_key_fn(key)] = new_node

        return new_node

    def _insert_helper(self, node: TreeNode, key: List, value):
        """Insert ``key`` -> ``value`` below ``node``.

        Returns the number of leading elements of ``key`` that were already
        present.  Any unmatched remainder becomes a new leaf and is added to
        ``evictable_size_``.
        """
        node.last_access_time = time.time()
        if len(key) == 0:
            return 0

        child_key = self.get_child_key_fn(key)

        total_prefix_length = 0
        while len(key) > 0 and child_key in node.children:
            node = node.children[child_key]
            node.last_access_time = time.time()
            prefix_len = self.key_match_fn(node.key, key)
            total_prefix_length += prefix_len
            key = key[prefix_len:]
            value = value[prefix_len:]

            if prefix_len < len(node.key):
                new_node = self._split_node(node.key, node, prefix_len)
                node = new_node

            if len(key):
                child_key = self.get_child_key_fn(key)

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = value
            node.children[child_key] = new_node
            self.evictable_size_ += len(value)
        return total_prefix_length

    def _print_helper(self, node: TreeNode, indent: int):
        """Prints the radix tree in a human-readable format."""
        stack = [(node, indent)]
        while stack:
            current_node, current_indent = stack.pop()
            print(
                " " * current_indent,
                len(current_node.key),
                current_node.key[:10],
                f"r={current_node.lock_ref}",
            )
            for key, child in current_node.children.items():
                stack.append((child, current_indent + 2))

                # Sanity check: each child is stored under the key derived
                # from its own first token/page.
                assert key == self.get_child_key_fn(
                    child.key
                ), f"{key=}, {self.get_child_key_fn(child.key)=}"

    def _delete_leaf(self, node):
        """Detach ``node`` from its parent and subtract its size from
        ``evictable_size_``.

        Raises:
            ValueError: if ``node`` is not among its parent's children, rather
                than deleting an unrelated sibling.
        """
        siblings = node.parent.children
        for child_key, child in siblings.items():
            if child is node:
                break
        else:
            raise ValueError("node is not registered as a child of its parent")
        del siblings[child_key]
        self.evictable_size_ -= len(node.key)

    def _total_size_helper(self):
        """Sum ``len(value)`` over nodes reachable without crossing an evicted node."""
        total_size = 0
        stack = [self.root_node]
        while stack:
            current_node = stack.pop()
            total_size += len(current_node.value)
            for child in current_node.children.values():
                if child.evicted:
                    continue
                stack.append(child)
        return total_size

    def _collect_leaves(self):
        """Return all childless nodes (the root itself when the tree is empty)."""
        ret_list = []
        stack = [self.root_node]

        while stack:
            cur_node = stack.pop()
            if len(cur_node.children) == 0:
                ret_list.append(cur_node)
            else:
                stack.extend(cur_node.children.values())

        return ret_list
if __name__ == "__main__":
    # Minimal smoke demo without memory pools: insert() falls back to storing
    # the key itself as the value, and pretty_print() only walks the tree.
    # match_prefix() and evict() are not exercised here. match_prefix()
    # torch.cat's the stored values, which are plain lists in this demo, and
    # evict() frees through the allocator, which is None.
    tree = RadixCache(None, None, page_size=1, disable=False)

    tree.insert("Hello")
    tree.insert("Hello")
    tree.insert("Hello_L.A.!")

    tree.pretty_print()