# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Request scheduler policy"""

import os
import random
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
from typing import Dict, List, Optional, Set, Union

import torch

from sglang.srt.managers.schedule_batch import (
    Req,
    ScheduleBatch,
    global_server_args_dict,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

# Clip the estimation of max_new_tokens for requests whose max_new_tokens is very large.
# This can prevent the server from being too conservative.
# Note that this only clips the estimation in the scheduler but does not change the stop
# condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
)

# Threshold for in-batch prefix caching.
# If a request has a matched prefix length (against the existing cache) less than this value,
# the scheduler runs the in-batch prefix caching check for this request.
# If we set it to -1, in-batch prefix caching is disabled.
IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD = int(
    os.environ.get("IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD", "32")
)

# Threshold for in-batch prefix caching.
# If a request has a matched prefix length (within the waiting queue) larger than this value,
# the scheduler deprioritizes this request.
IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
    os.environ.get("IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD", "32")
)


class CacheAwarePolicy(Enum):
    """Scheduling policies that are aware of the tree cache."""

    LPM = "lpm"  # longest prefix match
    DFS_WEIGHT = "dfs-weight"  # depth-first search weighting


class CacheAgnosticPolicy(Enum):
    """Scheduling policies that are not aware of the tree cache."""

    FCFS = "fcfs"  # first come first serve
    LOF = "lof"  # longest output first
    RANDOM = "random"


class SchedulePolicy:
    Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]

    def __init__(
        self,
        policy: str,
        tree_cache: BasePrefixCache,
        enable_hierarchical_cache: bool,
    ):
        self.policy = self._validate_and_adjust_policy(policy, tree_cache)
        self.tree_cache = tree_cache
        self.enable_hierarchical_cache = enable_hierarchical_cache

        # It is used to find the matching prefix for in-batch prefix caching.
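        # Requests from the waiting queue are inserted into this tree with dummy
        # values (see _compute_prefix_matches), so that later requests in the same
        # scheduling round can detect prefixes they share with earlier ones.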
        self.waiting_queue_radix_tree = RadixCache(
            req_to_token_pool=None,
            token_to_kv_pool_allocator=None,
            page_size=1,
            disable=False,
        )

    def calc_priority(self, waiting_queue: List[Req]) -> bool:
        if self.policy == CacheAgnosticPolicy.FCFS:
            # A shortcut for FCFS
            return False

        policy = self._determine_active_policy(waiting_queue)

        prefix_computed = False
        if isinstance(policy, CacheAwarePolicy):
            prefix_computed = True
            temporary_deprioritized = self._compute_prefix_matches(
                waiting_queue, policy
            )
            if policy == CacheAwarePolicy.LPM:
                SchedulePolicy._sort_by_longest_prefix(
                    waiting_queue, temporary_deprioritized
                )
            elif policy == CacheAwarePolicy.DFS_WEIGHT:
                SchedulePolicy._sort_by_dfs_weight(waiting_queue, self.tree_cache)
            else:
                raise ValueError(f"Unknown CacheAware Policy: {policy=}")
        else:
            if policy == CacheAgnosticPolicy.FCFS:
                pass
            elif policy == CacheAgnosticPolicy.LOF:
                SchedulePolicy._sort_by_longest_output(waiting_queue)
            elif policy == CacheAgnosticPolicy.RANDOM:
                SchedulePolicy._sort_randomly(waiting_queue)
            else:
                raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}")

        return prefix_computed

    def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
        if self.policy == CacheAwarePolicy.LPM and len(waiting_queue) > 128:
            # Turn off the expensive prefix matching and sorting when the queue is large.
            return CacheAgnosticPolicy.FCFS
        return self.policy

    def _validate_and_adjust_policy(
        self, policy: str, tree_cache: BasePrefixCache
    ) -> Policy:
        """
        Validates the policy and adjusts it if necessary based on tree cache settings.
        """
        try:
            policy_enum = CacheAwarePolicy(policy)
            if tree_cache.disable:
                # If the tree cache is disabled, fall back to a cache-agnostic policy.
                return CacheAgnosticPolicy.FCFS
            return policy_enum
        except ValueError:
            try:
                return CacheAgnosticPolicy(policy)
            except ValueError:
                raise ValueError(f"Unknown schedule_policy: {policy=}")

    def _compute_prefix_matches(
        self, waiting_queue: List[Req], policy: CacheAwarePolicy
    ) -> Set[int]:
        """
        Computes and caches the matching prefixes for requests in the waiting queue,
        and handles the in-batch prefix caching logic.
        """
        temporary_deprioritized: Set[int] = set()
        self.waiting_queue_radix_tree.reset()

        for r in waiting_queue:
            prefix_ids = r.adjust_max_prefix_ids()

            # NOTE: the prefix_indices must always be aligned with last_node
            if self.enable_hierarchical_cache:
                r.prefix_indices, r.last_node, r.last_node_global = (
                    self.tree_cache.match_prefix(key=prefix_ids, include_evicted=True)
                )
            else:
                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
                    rid=r.rid, key=prefix_ids
                )

            # NOTE(sang): This logic is for in-batch prefix caching.
            # If more than one request has only a small matching prefix in the
            # existing cache, but those requests share the same prefix among
            # themselves, we prefer to schedule only one of them so that we can
            # increase the cache hit rate.
            # We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because a
            # threshold that is too small means we cannot use in-batch prefix caching
            # for short prefixes, and short matches against the existing cache are
            # common when the engine has been running for a long time
            # (e.g., imagine the prefix "the").
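            # In short: when a request matched little from the existing cache, check
            # whether it shares a longer prefix with another request already in the
            # waiting queue; if so, deprioritize it so the earlier request can
            # populate the cache first.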
            if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
                in_batch_matching_prefixes, _ = (
                    self.waiting_queue_radix_tree.match_prefix(
                        rid=r.rid, key=prefix_ids
                    )
                )
                if (
                    len(in_batch_matching_prefixes)
                    >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
                ):
                    temporary_deprioritized.add(r.rid)
                else:
                    # Insert with a dummy value
                    self.waiting_queue_radix_tree.insert(
                        prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
                    )

        return temporary_deprioritized

    @staticmethod
    def _sort_by_longest_prefix(
        waiting_queue: List[Req], temporary_deprioritized: Set[int]
    ) -> None:
        """Sorts the waiting queue based on the longest prefix match."""
        waiting_queue.sort(
            key=lambda r: (
                -len(r.prefix_indices)
                if r.rid not in temporary_deprioritized
                else float("inf")
            )
        )

    @staticmethod
    def _sort_by_dfs_weight(
        waiting_queue: List[Req], tree_cache: BasePrefixCache
    ) -> None:
        """Sorts the waiting queue based on a depth-first search weighting."""
        last_node_to_reqs = defaultdict(list)
        for req in waiting_queue:
            last_node_to_reqs[req.last_node].append(req)

        node_to_weight = defaultdict(int)
        for node in last_node_to_reqs:
            node_to_weight[node] = len(last_node_to_reqs[node])
        SchedulePolicy._calc_weight(tree_cache.root_node, node_to_weight)

        waiting_queue.clear()
        SchedulePolicy._get_dfs_priority(
            tree_cache.root_node,
            node_to_weight,
            last_node_to_reqs,
            waiting_queue,
        )

    @staticmethod
    def _sort_by_longest_output(waiting_queue: List[Req]) -> None:
        """Sorts the waiting queue based on the longest output (max_new_tokens)."""
        waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)

    @staticmethod
    def _sort_randomly(waiting_queue: List[Req]) -> None:
        """Shuffles the waiting queue randomly."""
        random.shuffle(waiting_queue)

    @staticmethod
    def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
        for child in cur_node.children.values():
            SchedulePolicy._calc_weight(child, node_to_weight)
            node_to_weight[cur_node] += node_to_weight[child]

    @staticmethod
    def _get_dfs_priority(
        cur_node: TreeNode,
        node_to_priority: Dict[TreeNode, int],
        last_node_to_reqs: Dict[TreeNode, List[Req]],
        q: List,
    ) -> None:
        childs = list(cur_node.children.values())
        childs.sort(key=lambda x: -node_to_priority[x])
        for child in childs:
            SchedulePolicy._get_dfs_priority(
                child, node_to_priority, last_node_to_reqs, q
            )
        q.extend(last_node_to_reqs[cur_node])


class AddReqResult(Enum):
    CONTINUE = auto()  # Continue to add requests
    NO_TOKEN = auto()  # No token left
    OTHER = auto()  # Other reasons to stop adding requests


class PrefillAdder:
    def __init__(
        self,
        tree_cache: BasePrefixCache,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        running_batch: ScheduleBatch,
        new_token_ratio: float,
        rem_input_tokens: int,
        rem_chunk_tokens: Optional[int],
        mixed_with_decode_tokens: int = 0,
    ):
        self.tree_cache = tree_cache
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.running_batch = running_batch
        self.new_token_ratio = new_token_ratio
        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
        self.rem_chunk_tokens = rem_chunk_tokens
        if self.rem_chunk_tokens is not None:
            self.rem_chunk_tokens -= mixed_with_decode_tokens

        self.rem_total_token_offset = mixed_with_decode_tokens
        self.cur_rem_token_offset = mixed_with_decode_tokens

        self.req_states = None
        self.can_run_list = []
        self.new_chunked_req = None
        self.log_hit_tokens = 0
        self.log_input_tokens = 0

        if running_batch is not None:
            self.rem_total_token_offset += sum(
                [
                    min(
                        (r.sampling_params.max_new_tokens - len(r.output_ids)),
                        CLIP_MAX_NEW_TOKENS_ESTIMATION,
                    )
                    * self.new_token_ratio
                    for r in running_batch.reqs
                ]
            )

    @property
    def rem_total_tokens(self):
        return (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
            - self.rem_total_token_offset
        )

    @property
    def cur_rem_tokens(self):
        return (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
            - self.cur_rem_token_offset
        )

    def budget_state(self):
        if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
            return AddReqResult.NO_TOKEN

        if self.rem_input_tokens <= 0 or (
            self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0
        ):
            return AddReqResult.OTHER

        return AddReqResult.CONTINUE

    def _prefill_one_req(
        self, prefix_len: int, extend_input_len: int, max_new_tokens: int
    ):
        self.rem_total_token_offset += extend_input_len + max_new_tokens
        self.cur_rem_token_offset += extend_input_len
        self.rem_input_tokens -= extend_input_len
        if self.rem_chunk_tokens is not None:
            self.rem_chunk_tokens -= extend_input_len

        self.log_hit_tokens += prefix_len
        self.log_input_tokens += extend_input_len

    def add_chunked_req(self, req: Req):
        truncated = req.extend_input_len > self.rem_chunk_tokens
        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
        self.can_run_list.append(req)
        self._prefill_one_req(
            0,
            req.extend_input_len,
            (
                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
                if not truncated
                else 0
            ),
        )

        # Return the req if its chunked prefill has not finished
        return req if truncated else None

    @contextmanager
    def _lock_node(self, last_node: TreeNode):
        try:
            self.tree_cache.inc_lock_ref(last_node)
            yield None
        finally:
            self.tree_cache.dec_lock_ref(last_node)

    def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
        def add_req_state(r, insert_sort=False):
            new_token_ratio = (
                1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
            )
            tokens_left = r.sampling_params.max_new_tokens * new_token_ratio - len(
                r.output_ids
            )
            tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)

            if tokens_left > 0:
                if not insert_sort:
                    self.req_states.append((tokens_left, tokens_occupied))
                else:
                    i = 0
                    for i in range(len(self.req_states)):
                        if tokens_left <= self.req_states[i][0]:
                            break
                    self.req_states.insert(i, (tokens_left, tokens_occupied))

        if self.req_states is None:
            self.req_states = []
            add_req_state(req)
            if self.running_batch is not None:
                for r in self.running_batch.reqs:
                    add_req_state(r)
            for r in self.can_run_list:
                add_req_state(r)
            self.req_states.sort(key=lambda x: x[0])
        else:
            add_req_state(req, insert_sort=True)

        cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
        tokens_freed = 0
        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
            decode_steps = (
                self.req_states[i + 1][0]
                if i + 1 < len(self.req_states)
                else tokens_left
            )
            bs = len(self.req_states) - i
            if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
                return AddReqResult.NO_TOKEN
            tokens_freed += tokens_occupied

        if (
            self.rem_chunk_tokens is None  # chunked prefill is disabled
            or req.extend_input_len <= self.rem_chunk_tokens  # it is the last chunk
        ):
            # Non-chunked prefill
            self.can_run_list.append(req)
            self._prefill_one_req(
                0,
                req.extend_input_len,
                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
            )
        else:
            if self.rem_chunk_tokens == 0:
                return AddReqResult.OTHER

            # Chunked prefill
            trunc_len = self.rem_chunk_tokens
            req.extend_input_len = trunc_len
            req.fill_ids = req.fill_ids[:trunc_len]
            self.can_run_list.append(req)
            self.new_chunked_req = req
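            # Charge only the prefill tokens for this chunk; the new-token budget
            # is accounted for when the final chunk is scheduled (see add_chunked_req).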
            self._prefill_one_req(0, trunc_len, 0)

        return self.budget_state()

    def add_one_req(
        self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
    ):
        if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
            return self.add_one_req_ignore_eos(req, has_chunked_req)

        total_tokens = req.extend_input_len + min(
            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
        )
        input_tokens = req.extend_input_len
        prefix_len = len(req.prefix_indices)

        if total_tokens >= self.rem_total_tokens:
            return AddReqResult.NO_TOKEN

        if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
            return AddReqResult.OTHER

        with self._lock_node(req.last_node):
            if total_tokens > self.rem_total_tokens:
                return AddReqResult.NO_TOKEN

            if (
                enable_hierarchical_cache
                and req.last_node_global is not None
                and req.last_node_global.evicted
            ):
                req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
                    req.last_node_global, req.prefix_indices
                )
                req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
                input_tokens = req.extend_input_len
                prefix_len = len(req.prefix_indices)

            if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
                # Non-chunked prefill
                self.can_run_list.append(req)
                self.tree_cache.inc_lock_ref(req.last_node)
                self._prefill_one_req(
                    prefix_len,
                    input_tokens,
                    min(
                        req.sampling_params.max_new_tokens,
                        CLIP_MAX_NEW_TOKENS_ESTIMATION,
                    ),
                )
            else:
                if self.rem_chunk_tokens == 0:
                    return AddReqResult.OTHER

                # Chunked prefill
                trunc_len = self.rem_chunk_tokens

                req.extend_input_len = trunc_len
                req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
                self.can_run_list.append(req)
                self.new_chunked_req = req
                self.tree_cache.inc_lock_ref(req.last_node)
                self._prefill_one_req(prefix_len, trunc_len, 0)

        return self.budget_state()
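

# A rough sketch of how a scheduler is expected to drive these classes (the
# scheduler-side variable names below are illustrative, not part of this module):
#
#   policy = SchedulePolicy(schedule_policy, tree_cache, enable_hierarchical_cache)
#   policy.calc_priority(waiting_queue)
#   adder = PrefillAdder(
#       tree_cache,
#       token_to_kv_pool_allocator,
#       running_batch,
#       new_token_ratio,
#       rem_input_tokens,
#       rem_chunk_tokens,
#   )
#   for req in waiting_queue:
#       if adder.add_one_req(req, has_chunked_req=False) != AddReqResult.CONTINUE:
#           break
#   # adder.can_run_list now holds the requests selected for the next prefill batch.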