import heapq
import logging
import threading
import time
from typing import List, Optional

import torch

from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MHATokenToKVPoolHost,
    MLATokenToKVPool,
    MLATokenToKVPoolHost,
    ReqToTokenPool,
    TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match

logger = logging.getLogger(__name__)


class HiRadixCache(RadixCache):
    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        tp_cache_group: torch.distributed.ProcessGroup,
        page_size: int,
        hicache_ratio: float,
    ):
        if page_size != 1:
            raise ValueError(
                "Page size larger than 1 is not yet supported in HiRadixCache."
            )
        self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
        if isinstance(self.kv_cache, MHATokenToKVPool):
            self.token_to_kv_pool_host = MHATokenToKVPoolHost(
                self.kv_cache, hicache_ratio
            )
        elif isinstance(self.kv_cache, MLATokenToKVPool):
            self.token_to_kv_pool_host = MLATokenToKVPoolHost(
                self.kv_cache, hicache_ratio
            )
        else:
            raise ValueError("Only MHA and MLA KV caches support swapping to host.")

        self.tp_group = tp_cache_group
        self.page_size = page_size

        self.load_cache_event = threading.Event()
        self.cache_controller = HiCacheController(
            token_to_kv_pool_allocator,
            self.token_to_kv_pool_host,
            load_cache_event=self.load_cache_event,
        )

        # record the nodes with ongoing write through
        self.ongoing_write_through = {}
        # record the node segments with ongoing load back
        self.ongoing_load_back = {}
        # todo: dynamically adjust the threshold
        self.write_through_threshold = 1
        self.load_back_threshold = 10
        super().__init__(
            req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False
        )

    def reset(self):
        TreeNode.counter = 0
        self.cache_controller.reset()
        self.token_to_kv_pool_host.clear()
        super().reset()

    def get_height(self, node: TreeNode):
        height = 0
        while node != self.root_node:
            node = node.parent
            height += 1
        return height

    def write_backup(self, node: TreeNode):
        host_indices = self.cache_controller.write(
            device_indices=node.value,
            node_id=node.id,
        )
        if host_indices is None:
            # make room on the host pool and retry once
            self.evict_host(len(node.value))
            host_indices = self.cache_controller.write(
                device_indices=node.value,
                node_id=node.id,
            )
        if host_indices is not None:
            node.host_value = host_indices
            self.ongoing_write_through[node.id] = node
            self.inc_lock_ref(node)
        else:
            return None

        return len(host_indices)

    def inc_hit_count(self, node: TreeNode):
        if self.cache_controller.write_policy != "write_through_selective":
            return
        node.hit_count += 1
        if node.host_value is None and node.hit_count > self.write_through_threshold:
            self.write_backup(node)
            node.hit_count = 0

    def writing_check(self):
        queue_size = torch.tensor(
            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
        )
        if torch.distributed.get_world_size(group=self.tp_group) > 1:
            # synchronize TP workers so they make the same update to the radix cache
            torch.distributed.all_reduce(
                queue_size,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
        for _ in range(queue_size.item()):
            ack_id = self.cache_controller.ack_write_queue.get()
            self.dec_lock_ref(self.ongoing_write_through[ack_id])
            del self.ongoing_write_through[ack_id]

    def loading_check(self):
        while not self.cache_controller.ack_load_queue.empty():
            try:
                ack_id = self.cache_controller.ack_load_queue.get_nowait()
                start_node, end_node = self.ongoing_load_back[ack_id]
                self.dec_lock_ref(end_node)
                while end_node != start_node:
                    assert end_node.loading
                    end_node.loading = False
                    end_node = end_node.parent
                # clear the reference
                del self.ongoing_load_back[ack_id]
            except Exception:
                break
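    # `writing_check` and `loading_check` drain the acknowledgements posted by the
    # HiCacheController's background threads; they need to be polled regularly so
    # lock references and `loading` flags do not linger. A minimal sketch of a
    # driving loop (hypothetical caller code, not part of this module):
    #
    #   while serving:
    #       tree_cache.writing_check()  # release locks of acknowledged host writes
    #       tree_cache.loading_check()  # clear flags of finished load-backs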
    def evictable_size(self):
        return self.evictable_size_

    def evict(self, num_tokens: int):
        leaves = self._collect_leaves_device()
        heapq.heapify(leaves)

        num_evicted = 0
        pending_nodes = []
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)
            if x.lock_ref > 0:
                continue

            if x.host_value is None:
                if self.cache_controller.write_policy == "write_back":
                    num_evicted += self.write_backup(x)
                    # defer device eviction until the write back completes
                    pending_nodes.append(x)
                elif self.cache_controller.write_policy == "write_through_selective":
                    num_evicted += self._evict_write_through_selective(x)
                else:
                    assert (
                        self.cache_controller.write_policy != "write_through"
                    ), "write_through should be inclusive"
                    raise NotImplementedError
            else:
                num_evicted += self._evict_write_through(x)

            for child in x.parent.children.values():
                if child in pending_nodes:
                    continue
                if not child.evicted:
                    break
            else:
                # all children are evicted or no children
                heapq.heappush(leaves, x.parent)

        if self.cache_controller.write_policy == "write_back":
            # blocking till all write back complete
            while len(self.ongoing_write_through) > 0:
                self.writing_check()
                time.sleep(0.1)
            # host copies are durable now, free the deferred device memory
            for node in pending_nodes:
                assert node.host_value is not None
                self._evict_write_through(node)

    def _evict_write_through(self, node: TreeNode):
        # evict a node already written to host
        num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
        assert num_evicted > 0
        self.evictable_size_ -= num_evicted
        node.value = None
        return num_evicted

    def _evict_write_through_selective(self, node: TreeNode):
        # evict a node that has not initiated a write through to host
        self.cache_controller.mem_pool_device_allocator.free(node.value)
        num_evicted = len(node.value)
        self._delete_leaf(node)
        return num_evicted

    def evict_host(self, num_tokens: int):
        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)
            if x == self.root_node:
                break
            # only evict the host value of evicted nodes
            if not x.evicted:
                continue

            assert x.lock_ref == 0 and x.host_value is not None
            num_evicted += self.cache_controller.evict_host(x.host_value)

            # unlink the node from its parent
            for k, v in x.parent.children.items():
                if v == x:
                    break
            del x.parent.children[k]

            if len(x.parent.children) == 0 and x.parent.evicted:
                heapq.heappush(leaves, x.parent)

    def load_back(
        self, node: TreeNode, mem_quota: Optional[int] = None
    ) -> Optional[torch.Tensor]:
        # todo: more loading policies

        last_hit_node = node
        nodes_to_load = []
        while node.evicted:
            assert (
                node.backuped
            ), "No backup available on evicted nodes, should not happen"
            nodes_to_load.insert(0, node)
            node = node.parent
        ancestor_node = node

        # protect the ancestor nodes from eviction
        delta = self.inc_lock_ref(ancestor_node)

        # load it all or not at all
        host_indices = torch.cat([n.host_value for n in nodes_to_load])
        if len(host_indices) < self.load_back_threshold or (
            len(host_indices) > mem_quota + delta if mem_quota is not None else False
        ):
            # skip loading back if the total size is too small or exceeds the memory quota
            self.dec_lock_ref(ancestor_node)
            return None

        device_indices = self.cache_controller.load(
            host_indices=host_indices, node_id=last_hit_node.id
        )
        if device_indices is None:
            self.evict(len(host_indices))
            device_indices = self.cache_controller.load(
                host_indices=host_indices, node_id=last_hit_node.id
            )
        self.dec_lock_ref(ancestor_node)
        if device_indices is None:
            # no sufficient GPU memory to load back KV caches
            return None

        self.ongoing_load_back[last_hit_node.id] = (ancestor_node, last_hit_node)
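        # slice the newly allocated device buffer back onto each node of the
        # evicted path (root-to-leaf order); `loading` marks the nodes until the
        # asynchronous host-to-device copy is acknowledged via loading_check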
        offset = 0
        for node in nodes_to_load:
            node.value = device_indices[offset : offset + len(node.host_value)]
            offset += len(node.host_value)
            node.loading = True
        self.evictable_size_ += len(device_indices)
        self.inc_lock_ref(last_hit_node)

        return device_indices

    def init_load_back(
        self,
        last_node: TreeNode,
        prefix_indices: torch.Tensor,
        mem_quota: Optional[int] = None,
    ):
        assert (
            len(prefix_indices) == 0 or prefix_indices.is_cuda
        ), "indices of device KV caches should be on GPU"
        if last_node.evicted:
            loading_values = self.load_back(last_node, mem_quota)
            if loading_values is not None:
                prefix_indices = (
                    loading_values
                    if len(prefix_indices) == 0
                    else torch.cat([prefix_indices, loading_values])
                )
                logger.debug(
                    f"loading back {len(loading_values)} tokens for node {last_node.id}"
                )

            while last_node.evicted:
                last_node = last_node.parent

        return last_node, prefix_indices

    def ready_to_load_cache(self):
        self.load_cache_event.set()

    def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
        if self.disable:
            if include_evicted:
                return [], self.root_node, self.root_node
            return [], self.root_node

        value, last_node = self._match_prefix_helper(self.root_node, key)
        if value:
            value = torch.cat(value)
        else:
            value = torch.tensor([], dtype=torch.int64)

        last_node_global = last_node
        while last_node.evicted:
            last_node = last_node.parent

        if include_evicted:
            return value, last_node, last_node_global
        else:
            return value, last_node

    def _match_prefix_helper(self, node: TreeNode, key: List):
        node.last_access_time = time.time()
        value = []
        while len(key) > 0 and key[0] in node.children.keys():
            child = node.children[key[0]]
            child.last_access_time = time.time()
            prefix_len = _key_match(child.key, key)
            if prefix_len < len(child.key):
                new_node = self._split_node(child.key, child, prefix_len)
                if not new_node.evicted:
                    value.append(new_node.value)
                node = new_node
                break
            else:
                if not child.evicted:
                    value.append(child.value)
                node = child
                key = key[prefix_len:]
        return value, node

    def _split_node(self, key, child: TreeNode, split_len: int):
        # child node split into new_node -> child
        new_node = TreeNode()
        new_node.children = {key[split_len]: child}
        new_node.parent = child.parent
        new_node.lock_ref = child.lock_ref
        new_node.key = child.key[:split_len]
        new_node.loading = child.loading

        # split value and host value if exists
        if child.evicted:
            new_node.value = None
        else:
            new_node.value = child.value[:split_len]
            child.value = child.value[split_len:]
        if child.host_value is not None:
            new_node.host_value = child.host_value[:split_len]
            child.host_value = child.host_value[split_len:]
        child.parent = new_node
        child.key = child.key[split_len:]
        new_node.parent.children[key[0]] = new_node

        return new_node

    def _insert_helper(self, node: TreeNode, key: List, value):
        node.last_access_time = time.time()
        if len(key) == 0:
            return 0

        if key[0] in node.children.keys():
            child = node.children[key[0]]
            prefix_len = _key_match(child.key, key)

            if prefix_len == len(child.key):
                if child.evicted:
                    # change the reference if the node is evicted
                    # this often happens in the case of KV cache recomputation
                    child.value = value[:prefix_len]
                    self.token_to_kv_pool_host.update_synced(child.host_value)
                    self.evictable_size_ += len(value[:prefix_len])
                    return self._insert_helper(
                        child, key[prefix_len:], value[prefix_len:]
                    )
                else:
                    self.inc_hit_count(child)
                    return prefix_len + self._insert_helper(
                        child, key[prefix_len:], value[prefix_len:]
                    )

            # partial match, split the node
            new_node = self._split_node(child.key, child, prefix_len)
            if new_node.evicted:
                new_node.value = value[:prefix_len]
                self.token_to_kv_pool_host.update_synced(new_node.host_value)
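                # the device copy of this prefix was just recomputed, so the host
                # backup is marked in sync and the node becomes evictable again;
                # the matched length is not counted as a reusable prefix here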
                self.evictable_size_ += len(new_node.value)
                return self._insert_helper(
                    new_node, key[prefix_len:], value[prefix_len:]
                )
            else:
                self.inc_hit_count(new_node)
                return prefix_len + self._insert_helper(
                    new_node, key[prefix_len:], value[prefix_len:]
                )

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = value
            node.children[key[0]] = new_node
            self.evictable_size_ += len(value)

            if self.cache_controller.write_policy == "write_through":
                self.write_backup(new_node)
        return 0

    def _collect_leaves_device(self):
        def is_leaf(node):
            if node.evicted:
                return False
            if node == self.root_node:
                return False
            if len(node.children) == 0:
                return True
            for child in node.children.values():
                if not child.evicted:
                    return False
            return True

        ret_list = []
        stack = [self.root_node]
        while stack:
            cur_node = stack.pop()
            if is_leaf(cur_node):
                ret_list.append(cur_node)
            else:
                for cur_child in cur_node.children.values():
                    if not cur_child.evicted:
                        stack.append(cur_child)
        return ret_list
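# A minimal construction sketch (hypothetical surrounding objects; the pools,
# allocator, and TP process group come from the sglang runtime, not this module):
#
#   tree_cache = HiRadixCache(
#       req_to_token_pool=req_to_token_pool,
#       token_to_kv_pool_allocator=token_to_kv_pool_allocator,
#       tp_cache_group=tp_group,
#       page_size=1,        # only page_size == 1 is supported
#       hicache_ratio=2.0,  # host pool size relative to the device pool
#   )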