from __future__ import annotations from dataclasses import dataclass from typing import TYPE_CHECKING, List, Optional import torch import torch.nn.functional as F import triton import triton.language as tl from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient from sglang.srt.utils import is_cuda_available, is_hip if is_cuda_available(): from sgl_kernel import ( top_k_renorm_prob, top_p_renorm_prob, tree_speculative_sampling_target_only, verify_tree_greedy, ) elif is_hip(): from sgl_kernel import verify_tree_greedy if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import ScheduleBatch import logging logger = logging.getLogger(__name__) @dataclass class EagleDraftInput: # The inputs for decode # shape: (b, topk) topk_p: torch.Tensor = None topk_index: torch.Tensor = None # shape: (b, hidden_size) hidden_states: torch.Tensor = None capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL # Inputs for extend # shape: (b,) verified_id: torch.Tensor = None accept_length: torch.Tensor = None accept_length_cpu: List[int] = None # Inputs for the attention backends # shape: (b + 1,) kv_indptr: torch.Tensor = None kv_indices: torch.Tensor = None all_padding_lens: Optional[torch.Tensor] = None def prepare_for_extend(self, batch: ScheduleBatch): # Prefill only generate 1 token. assert len(self.verified_id) == len(batch.seq_lens) pt = 0 for i, extend_len in enumerate(batch.extend_lens): input_ids = batch.input_ids[pt : pt + extend_len] batch.input_ids[pt : pt + extend_len] = torch.cat( (input_ids[1:], self.verified_id[i].reshape(1)) ) pt += extend_len def prepare_extend_after_decode( self, batch: ScheduleBatch, speculative_num_steps: int, ): assert len(self.verified_id) == len(batch.out_cache_loc) accept_length_cpu = batch.spec_info.accept_length_cpu batch.extend_lens = [x + 1 for x in accept_length_cpu] batch.extend_num_tokens = sum(batch.extend_lens) batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend seq_lens_cpu = batch.seq_lens.tolist() self.positions = torch.empty_like(self.verified_id, dtype=torch.long) new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32) self.accept_length.add_(1) create_extend_spec_info[(self.accept_length.numel(),)]( self.verified_id, batch.seq_lens, self.accept_length, torch.cumsum(self.accept_length, axis=0, dtype=torch.int), self.positions, new_verified_id, triton.next_power_of_2(speculative_num_steps + 1), ) batch.seq_lens_sum = sum(seq_lens_cpu) batch.input_ids = self.verified_id self.verified_id = new_verified_id def generate_attn_arg_prefill( self, req_pool_indices: torch.Tensor, paged_kernel_lens: torch.Tensor, paged_kernel_lens_sum: int, req_to_token: torch.Tensor, ): bs = self.accept_length.numel() qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0) cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) # TODO: replace cum_kv_seq_len[-1] with paged_kernel_lens_sum to avoid the device sync. kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda") create_flashinfer_kv_indices_triton[(bs,)]( req_to_token, req_pool_indices, paged_kernel_lens, cum_kv_seq_len, None, kv_indices, req_to_token.size(1), ) return kv_indices, cum_kv_seq_len, qo_indptr, None def filter_batch(self, new_indices: torch.Tensor): self.topk_p = self.topk_p[: len(new_indices)] self.topk_index = self.topk_index[: len(new_indices)] self.hidden_states = self.hidden_states[: len(new_indices)] self.verified_id = self.verified_id[: len(new_indices)] def merge_batch(self, spec_info: EagleDraftInput): if self.hidden_states is None: self.hidden_states = spec_info.hidden_states self.verified_id = spec_info.verified_id self.topk_p = spec_info.topk_p self.topk_index = spec_info.topk_index return if spec_info.hidden_states is None: return self.hidden_states = torch.cat( [self.hidden_states, spec_info.hidden_states], axis=0 ) self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) self.topk_p = torch.cat([self.topk_p, spec_info.topk_p]) self.topk_index = torch.cat([self.topk_index, spec_info.topk_index]) @dataclass class EagleVerifyOutput: # Draft input batch draft_input: EagleDraftInput # Logit outputs from target worker logits_output: LogitsProcessorOutput # Accepeted token ids including the bonus token verified_id: torch.Tensor # Accepeted token length per sequence in a batch in CPU. accept_length_per_req_cpu: List[int] # Accepeted indices from logits_output.next_token_logits accepeted_indices: torch.Tensor @dataclass class EagleVerifyInput: draft_token: torch.Tensor custom_mask: torch.Tensor positions: torch.Tensor retrive_index: torch.Tensor retrive_next_token: torch.Tensor retrive_next_sibling: torch.Tensor retrive_cum_len: torch.Tensor draft_token_num: int spec_steps: int capture_hidden_mode: CaptureHiddenMode @classmethod def create( cls, verified_id: torch.Tensor, score_list: List[torch.Tensor], token_list: List[torch.Tensor], parents_list: List[torch.Tensor], seq_lens: torch.Tensor, seq_lens_sum: int, topk: int, spec_steps: int, num_verify_tokens: int, ): ( tree_mask, position, retrive_index, retrive_next_token, retrive_next_sibling, draft_tokens, ) = build_tree_kernel_efficient( verified_id, score_list, token_list, parents_list, seq_lens, seq_lens_sum, topk, spec_steps, num_verify_tokens, ) return cls( draft_tokens, tree_mask, position, retrive_index, retrive_next_token, retrive_next_sibling, None, num_verify_tokens, spec_steps, CaptureHiddenMode.FULL, ) def prepare_for_verify(self, batch: ScheduleBatch): batch.input_ids = self.draft_token batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) bs = batch.batch_size() assign_req_to_token_pool[(bs,)]( batch.req_pool_indices, batch.req_to_token_pool.req_to_token, batch.seq_lens, batch.seq_lens + self.draft_token_num, batch.out_cache_loc, batch.req_to_token_pool.req_to_token.shape[1], triton.next_power_of_2(bs), ) def generate_attn_arg_prefill( self, req_pool_indices: torch.Tensor, paged_kernel_lens: torch.Tensor, paged_kernel_lens_sum: int, req_to_token: torch.Tensor, ): batch_size = len(req_pool_indices) qo_indptr = torch.arange( 0, (1 + batch_size) * self.draft_token_num, step=self.draft_token_num, dtype=torch.int32, device="cuda", ) cum_kv_seq_len = torch.zeros( (batch_size + 1,), dtype=torch.int32, device="cuda" ) paged_kernel_lens = paged_kernel_lens + self.draft_token_num cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) kv_indices = torch.empty( paged_kernel_lens_sum + self.draft_token_num * batch_size, dtype=torch.int32, device="cuda", ) create_flashinfer_kv_indices_triton[(batch_size,)]( req_to_token, req_pool_indices, paged_kernel_lens, cum_kv_seq_len, None, kv_indices, req_to_token.size(1), ) return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask def verify( self, batch: ScheduleBatch, logits_output: torch.Tensor, token_to_kv_pool_allocator: TokenToKVPoolAllocator, ) -> torch.Tensor: """ Verify and find accepted tokens based on logits output and batch (which contains spec decoding information). WARNING: This API in-place modifies the states of logits_output This API updates values inside logits_output based on the accepted tokens. I.e., logits_output.next_token_logits only contains accepeted token logits. """ bs = self.retrive_index.shape[0] candidates = self.draft_token.reshape(bs, self.draft_token_num) sampling_info = batch.sampling_info predict_shape = list(logits_output.next_token_logits.shape)[:-1] predict_shape[-1] += 1 predict = torch.empty(predict_shape, dtype=torch.int32, device="cuda") accept_index = torch.full( (bs, self.spec_steps + 1), -1, dtype=torch.int32, device="cuda" ) accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda") if sampling_info.penalizer_orchestrator.is_required: # This is a relaxed version of penalties for speculative decoding. linear_penalty = torch.zeros( (bs, logits_output.next_token_logits.shape[1]), dtype=torch.float32, device="cuda", ) sampling_info.apply_logits_bias(linear_penalty) logits_output.next_token_logits.add_( torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0) ) if batch.sampling_info.is_all_greedy: target_predict = torch.argmax(logits_output.next_token_logits, dim=-1) target_predict = target_predict.reshape(bs, self.draft_token_num) verify_tree_greedy( predicts=predict, # mutable accept_index=accept_index, # mutable accept_token_num=accept_length, # mutable candidates=candidates.to(torch.int32), retrive_index=self.retrive_index.to(torch.int32), retrive_next_token=self.retrive_next_token.to(torch.int32), retrive_next_sibling=self.retrive_next_sibling.to(torch.int32), target_predict=target_predict.to(torch.int32), ) else: # apply temperature and get target probs expanded_temperature = torch.repeat_interleave( sampling_info.temperatures, self.draft_token_num, dim=0 ) # (bs * draft_token_num, 1) target_probs = F.softmax( logits_output.next_token_logits / expanded_temperature, dim=-1 ) # (bs * draft_token_num, vocab_size) target_probs = top_k_renorm_prob( target_probs, torch.repeat_interleave( sampling_info.top_ks, self.draft_token_num, dim=0 ), ) # (bs * draft_token_num, vocab_size) target_probs = top_p_renorm_prob( target_probs, torch.repeat_interleave( sampling_info.top_ps, self.draft_token_num, dim=0 ), ) target_probs = target_probs.reshape(bs, self.draft_token_num, -1) draft_probs = torch.zeros( target_probs.shape, dtype=torch.float32, device="cuda" ) coins = torch.rand_like(candidates, dtype=torch.float32, device="cuda") tree_speculative_sampling_target_only( predicts=predict, # mutable accept_index=accept_index, # mutable accept_token_num=accept_length, # mutable candidates=candidates.to(torch.int32), retrive_index=self.retrive_index.to(torch.int32), retrive_next_token=self.retrive_next_token.to(torch.int32), retrive_next_sibling=self.retrive_next_sibling.to(torch.int32), uniform_samples=coins, target_probs=target_probs, draft_probs=draft_probs, threshold_single=global_server_args_dict[ "speculative_accept_threshold_single" ], threshold_acc=global_server_args_dict[ "speculative_accept_threshold_acc" ], deterministic=True, ) new_accept_index = [] unfinished_index = [] accept_index_cpu = accept_index.tolist() predict_cpu = predict.tolist() has_finished = False # iterate every accepted token and check if req has finished after append the token # should be checked BEFORE free kv cache slots for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)): new_accept_index_ = [] for j, idx in enumerate(accept_index_row): if idx == -1: break id = predict_cpu[idx] # if not found_finished: req.output_ids.append(id) req.check_finished() if req.finished(): has_finished = True # set all tokens after finished token to -1 and break accept_index[i, j + 1 :] = -1 break else: new_accept_index_.append(idx) if not req.finished(): new_accept_index.extend(new_accept_index_) unfinished_index.append(i) req.spec_verify_ct += 1 if not has_finished: accept_index = accept_index[accept_index != -1] verified_id = predict[accept_index] evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) evict_mask[accept_index] = False mem_need_free_idx = batch.out_cache_loc[evict_mask] token_to_kv_pool_allocator.free(mem_need_free_idx) batch.out_cache_loc = batch.out_cache_loc[accept_index] assign_req_to_token_pool[(bs,)]( batch.req_pool_indices, batch.req_to_token_pool.req_to_token, batch.seq_lens, batch.seq_lens + accept_length + 1, batch.out_cache_loc, batch.req_to_token_pool.req_to_token.shape[1], triton.next_power_of_2(bs), ) batch.seq_lens.add_(accept_length + 1) accept_length_cpu = accept_length.tolist() draft_input = EagleDraftInput() draft_input.hidden_states = batch.spec_info.hidden_states[accept_index] draft_input.verified_id = verified_id draft_input.accept_length = accept_length draft_input.accept_length_cpu = accept_length_cpu draft_input.seq_lens_for_draft_extend = batch.seq_lens draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices return EagleVerifyOutput( draft_input=draft_input, logits_output=logits_output, verified_id=verified_id, accept_length_per_req_cpu=accept_length_cpu, accepeted_indices=accept_index, ) else: accept_length = (accept_index != -1).sum(dim=1) - 1 accept_index = accept_index[accept_index != -1] verified_id = predict[accept_index] evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) evict_mask[accept_index] = False mem_need_free_idx = batch.out_cache_loc[evict_mask] token_to_kv_pool_allocator.free(mem_need_free_idx) assign_req_to_token_pool[(bs,)]( batch.req_pool_indices, batch.req_to_token_pool.req_to_token, batch.seq_lens, batch.seq_lens + accept_length + 1, batch.out_cache_loc[accept_index], batch.req_to_token_pool.req_to_token.shape[1], triton.next_power_of_2(bs), ) batch.seq_lens.add_(accept_length + 1) accept_length_cpu = accept_length.tolist() draft_input = EagleDraftInput() if len(new_accept_index) > 0: new_accept_index = torch.tensor(new_accept_index, device="cuda") draft_input.hidden_states = batch.spec_info.hidden_states[ new_accept_index ] draft_input.verified_id = predict[new_accept_index] draft_input.accept_length = accept_length[unfinished_index] draft_input.accept_length_cpu = [ accept_length_cpu[i] for i in unfinished_index ] if has_finished: draft_input.seq_lens_for_draft_extend = batch.seq_lens[ unfinished_index ] draft_input.req_pool_indices_for_draft_extend = ( batch.req_pool_indices[unfinished_index] ) else: draft_input.seq_lens_for_draft_extend = batch.seq_lens draft_input.req_pool_indices_for_draft_extend = ( batch.req_pool_indices ) batch.out_cache_loc = batch.out_cache_loc[new_accept_index] return EagleVerifyOutput( draft_input=draft_input, logits_output=logits_output, verified_id=verified_id, accept_length_per_req_cpu=accept_length_cpu, accepeted_indices=accept_index, ) @triton.jit def create_extend_spec_info( verified_id, seq_len, accept_len, accept_len_cum, positions, new_verified_id, accept_len_upper: tl.constexpr, ): pid = tl.program_id(axis=0) offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1) seq_length = tl.load(seq_len + pid) accept_length = tl.load(accept_len + pid) positions_ptr = positions + offset data = tl.arange(0, accept_len_upper) mask = data < accept_length tl.store(positions_ptr + data, seq_length - accept_length + data, mask) offset = tl.load(accept_len_cum + pid) - 1 verified_id_data = tl.load(verified_id + offset) tl.store(new_verified_id + pid, verified_id_data) @triton.jit def assign_req_to_token_pool( req_pool_indices, req_to_token, start_offset, end_offset, out_cache_loc, pool_len: tl.constexpr, bs_upper: tl.constexpr, ): BLOCK_SIZE: tl.constexpr = 32 pid = tl.program_id(axis=0) kv_start = tl.load(start_offset + pid) kv_end = tl.load(end_offset + pid) token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len length_offset = tl.arange(0, bs_upper) start = tl.load(start_offset + length_offset, mask=length_offset < pid) end = tl.load(end_offset + length_offset, mask=length_offset < pid) out_offset = tl.sum(end - start, axis=0) out_cache_ptr = out_cache_loc + out_offset save_offset = tl.arange(0, BLOCK_SIZE) + kv_start load_offset = tl.arange(0, BLOCK_SIZE) num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) for _ in range(num_loop): mask = save_offset < kv_end data = tl.load(out_cache_ptr + load_offset, mask=mask) tl.store(token_pool + save_offset, data, mask=mask) save_offset += BLOCK_SIZE load_offset += BLOCK_SIZE @triton.jit def assign_draft_cache_locs( req_pool_indices, req_to_token, seq_lens, out_cache_loc, pool_len: tl.constexpr, topk: tl.constexpr, speculative_num_steps: tl.constexpr, ): BLOCK_SIZE: tl.constexpr = 32 pid = tl.program_id(axis=0) kv_start = tl.load(seq_lens + pid) kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE) for i in range(num_loop): save_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + kv_start load_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE mask = save_offset < kv_end data = tl.load(out_cache_ptr + load_offset, mask=mask) tl.store(token_pool + save_offset, data, mask=mask) @triton.jit def generate_draft_decode_kv_indices( req_pool_indices, req_to_token, paged_kernel_lens, kv_indices, kv_indptr, positions, num_seqs: tl.constexpr, topk: tl.constexpr, pool_len: tl.constexpr, kv_indices_stride: tl.constexpr, kv_indptr_stride: tl.constexpr, bs_upper: tl.constexpr, iter_upper: tl.constexpr, num_tokens_upper: tl.constexpr, ): BLOCK_SIZE: tl.constexpr = 128 iters = tl.program_id(axis=0) bid = tl.program_id(axis=1) topk_id = tl.program_id(axis=2) kv_indices += kv_indices_stride * iters kv_indptr += kv_indptr_stride * iters iters += 1 load_offset = tl.arange(0, bs_upper) seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid) seq_len = tl.load(paged_kernel_lens + bid) cum_seq_len = tl.sum(seq_lens) kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters) kv_ptr = kv_indices + kv_offset token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len kv_offset = tl.arange(0, BLOCK_SIZE) num_loop = tl.cdiv(seq_len, BLOCK_SIZE) for _ in range(num_loop): mask = kv_offset < seq_len data = tl.load(token_pool_ptr + kv_offset, mask=mask) tl.store(kv_ptr + kv_offset, data, mask=mask) kv_offset += BLOCK_SIZE extend_offset = tl.arange(0, iter_upper) extend_data = tl.load( token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id, mask=extend_offset < iters, ) tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters) # Update kv_indptr bs_offset = tl.arange(0, num_tokens_upper) zid = bid * topk + topk_id if zid == 0: zid = num_seqs * topk positions = tl.load(positions + bs_offset, mask=bs_offset < zid) base = tl.sum(positions) tl.store(kv_indptr + zid, base + zid * iters) @torch.compile(dynamic=True) def select_top_k_tokens( i: int, topk_p: torch.Tensor, topk_index: torch.Tensor, hidden_states: torch.Tensor, scores: torch.Tensor, topk: int, ): if i == 0: # The first step after extend input_ids = topk_index.flatten() hidden_states = hidden_states.repeat_interleave(topk, dim=0) scores = topk_p # shape: (b, topk) tree_info = ( topk_p.unsqueeze(1), # shape: (b, 1, topk) topk_index, # shape: (b, topk) torch.arange(-1, topk, dtype=torch.long, device="cuda") .unsqueeze(0) .repeat(topk_p.shape[0], 1), # shape: (b, topk + 1) ) else: # The later decode steps expand_scores = torch.mul( scores.unsqueeze(2), topk_p.reshape(-1, topk, topk) ) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk) topk_cs_p, topk_cs_index = fast_topk( expand_scores.flatten(start_dim=1), topk, dim=-1 ) # (b, topk) scores = topk_cs_p # shape: (b, topk) topk_index = topk_index.reshape(-1, topk**2) input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten() selected_input_index = topk_cs_index.flatten() // topk + torch.arange( 0, hidden_states.shape[0], step=topk, device="cuda" ).repeat_interleave(topk) hidden_states = hidden_states[selected_input_index, :] tree_info = ( expand_scores, # shape: (b, topk, topk) topk_index, # shape: (b, topk * topk) topk_cs_index + (topk**2 * (i - 1) + topk), # shape: (b, topk) ) return input_ids, hidden_states, scores, tree_info def fast_topk(values, topk, dim): if topk == 1: # Use max along the specified dimension to get both value and index max_value, max_index = torch.max(values, dim=dim) return max_value.unsqueeze(1), max_index.unsqueeze(1) else: # Use topk for efficiency with larger k values return torch.topk(values, topk, dim=dim)