sglang0.4.5.post1/python/sglang/srt/layers/logits_processor.py

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logits processing."""
import dataclasses
import logging
from typing import List, Optional, Union
import torch
import triton
import triton.language as tl
from torch import nn
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
)
from sglang.srt.layers.dp_attention import (
dp_gather_replicate,
dp_scatter,
get_attention_dp_rank,
get_attention_dp_size,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.utils import dump_to_file

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class LogitsProcessorOutput:
## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
# The logits of the next tokens. shape: [#seq, vocab_size]
next_token_logits: torch.Tensor
# Used by speculative decoding (EAGLE)
# The last hidden layers
hidden_states: Optional[torch.Tensor] = None
## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
# The logprobs of the next tokens. shape: [#seq]
next_token_logprobs: Optional[torch.Tensor] = None
# The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
next_token_top_logprobs_val: Optional[List] = None
next_token_top_logprobs_idx: Optional[List] = None
# The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
next_token_token_ids_logprobs_val: Optional[List] = None
next_token_token_ids_logprobs_idx: Optional[List] = None
## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
# The logprobs of input tokens. shape: [#token]
input_token_logprobs: Optional[torch.Tensor] = None
# The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
    input_top_logprobs_val: Optional[List] = None
    input_top_logprobs_idx: Optional[List] = None
# The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
input_token_ids_logprobs_val: Optional[List] = None
input_token_ids_logprobs_idx: Optional[List] = None


@dataclasses.dataclass
class LogitsMetadata:
forward_mode: ForwardMode
capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
extend_return_logprob: bool = False
extend_return_top_logprob: bool = False
extend_token_ids_logprob: bool = False
extend_seq_lens: Optional[torch.Tensor] = None
extend_seq_lens_cpu: Optional[List[int]] = None
extend_logprob_start_lens_cpu: Optional[List[int]] = None
extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
top_logprobs_nums: Optional[List[int]] = None
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
token_ids_logprobs: Optional[List[List[int]]] = None
# logits and logprobs post processing
temp_scaled_logprobs: bool = False
    temperature: Optional[torch.Tensor] = None
    top_p_normalized_logprobs: bool = False
    top_p: Optional[torch.Tensor] = None
# DP attention metadata. Not needed when DP attention is not used.
# Number of tokens in the request.
global_num_tokens_gpu: Optional[torch.Tensor] = None
# The start position of local hidden states.
dp_local_start_pos: Optional[torch.Tensor] = None
dp_local_num_tokens: Optional[torch.Tensor] = None
gathered_buffer: Optional[torch.Tensor] = None
# Buffer to gather logits from all ranks.
forward_batch_gathered_buffer: Optional[torch.Tensor] = None
# Number of tokens to sample per DP rank
global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
# for padding
padded_static_len: int = -1

    @classmethod
def from_forward_batch(cls, forward_batch: ForwardBatch):
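        # Decide what logprob work this batch needs: input-token logprobs are only
        # computed for extend (prefill) batches that request them, and never for the
        # target-verify pass of speculative decoding.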
if (
forward_batch.forward_mode.is_extend()
and forward_batch.return_logprob
and not forward_batch.forward_mode.is_target_verify()
):
extend_return_top_logprob = any(
x > 0 for x in forward_batch.top_logprobs_nums
)
extend_token_ids_logprob = any(
x is not None for x in forward_batch.token_ids_logprobs
)
extend_return_logprob = False
extend_logprob_pruned_lens_cpu = []
for extend_len, start_len in zip(
forward_batch.extend_seq_lens_cpu,
forward_batch.extend_logprob_start_lens_cpu,
):
if extend_len - start_len > 0:
extend_return_logprob = True
extend_logprob_pruned_lens_cpu.append(extend_len - start_len)
else:
extend_return_logprob = extend_return_top_logprob = (
extend_token_ids_logprob
) = extend_logprob_pruned_lens_cpu = False
return cls(
forward_mode=forward_batch.forward_mode,
capture_hidden_mode=forward_batch.capture_hidden_mode,
extend_return_logprob=extend_return_logprob,
extend_return_top_logprob=extend_return_top_logprob,
extend_token_ids_logprob=extend_token_ids_logprob,
extend_seq_lens=forward_batch.extend_seq_lens,
extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
top_logprobs_nums=forward_batch.top_logprobs_nums,
token_ids_logprobs=forward_batch.token_ids_logprobs,
extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
padded_static_len=forward_batch.padded_static_len,
global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
dp_local_start_pos=forward_batch.dp_local_start_pos,
dp_local_num_tokens=forward_batch.dp_local_num_tokens,
gathered_buffer=forward_batch.gathered_buffer,
forward_batch_gathered_buffer=forward_batch.gathered_buffer,
global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
)

    def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
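        # Derive this DP rank's slice of the gathered hidden states (start offset and
        # token count) from the per-rank token counts, and allocate the zero-filled
        # buffer that dp_gather_replicate will fill.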
        if self.global_num_tokens_for_logprob_cpu is None:
            # We are capturing a CUDA graph.
            return
cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
dp_rank = get_attention_dp_rank()
if dp_rank == 0:
dp_local_start_pos = torch.zeros_like(
self.global_num_tokens_for_logprob_gpu[0]
)
else:
dp_local_start_pos = cumtokens[dp_rank - 1]
dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
gathered_buffer = torch.zeros(
(
sum(self.global_num_tokens_for_logprob_cpu),
hidden_states.shape[1],
),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
self.dp_local_start_pos = dp_local_start_pos
self.dp_local_num_tokens = dp_local_num_tokens
self.gathered_buffer = gathered_buffer


class LogitsProcessor(nn.Module):
def __init__(
self, config, skip_all_gather: bool = False, logit_scale: Optional[float] = None
):
super().__init__()
self.config = config
self.logit_scale = logit_scale
self.do_tensor_parallel_all_gather = (
not skip_all_gather and get_tensor_model_parallel_world_size() > 1
)
self.do_tensor_parallel_all_gather_dp_attn = (
self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
)
self.final_logit_softcapping = getattr(
self.config, "final_logit_softcapping", None
)
if (
self.final_logit_softcapping is not None
and self.final_logit_softcapping < 0
):
self.final_logit_softcapping = None
self.debug_tensor_dump_output_folder = global_server_args_dict.get(
"debug_tensor_dump_output_folder", None
)

    def forward(
self,
input_ids,
hidden_states,
lm_head: VocabParallelEmbedding,
logits_metadata: Union[LogitsMetadata, ForwardBatch],
aux_hidden_states: Optional[torch.Tensor] = None,
) -> LogitsProcessorOutput:
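        """Compute next-token logits for the batch and, when requested, input-token
        logprobs (prefill only) and the hidden states captured for speculative decoding."""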
if isinstance(logits_metadata, ForwardBatch):
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
# Get the last hidden states and last logits for the next token prediction
if (
logits_metadata.forward_mode.is_decode_or_idle()
or logits_metadata.forward_mode.is_target_verify()
):
pruned_states = hidden_states
if aux_hidden_states is not None:
aux_pruned_states = [hidden for hidden in aux_hidden_states]
sample_indices = None
input_logprob_indices = None
elif (
logits_metadata.forward_mode.is_extend()
and not logits_metadata.extend_return_logprob
):
# Prefill without input logprobs.
if logits_metadata.padded_static_len < 0:
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
else:
                # If padded_static_len is 5 and extend_seq_lens is [2, 3],
                # then the batch looks like [t00, t01, p, p, p, t10, t11, t12, p, p],
                # and this retrieves t01 and t12, the valid last tokens.
idx = torch.arange(
len(logits_metadata.extend_seq_lens),
device=logits_metadata.extend_seq_lens.device,
)
last_index = (
idx * logits_metadata.padded_static_len
+ logits_metadata.extend_seq_lens
- 1
)
pruned_states = hidden_states[last_index]
if aux_hidden_states is not None:
aux_pruned_states = [hidden[last_index] for hidden in aux_hidden_states]
sample_indices = None
input_logprob_indices = None
else:
# Input logprobs are required.
# Find 3 different indices.
# 1. pruned_states: hidden states that we want logprobs from.
# 2. sample_indices: Indices that have sampled tokens.
# 3. input_logprob_indices: Indices that have input logprob tokens.
sample_index_pt = -1
sample_indices = []
input_logprob_indices_pt = 0
input_logprob_indices = []
pt, pruned_states = 0, []
for extend_logprob_start_len, extend_len in zip(
logits_metadata.extend_logprob_start_lens_cpu,
logits_metadata.extend_seq_lens_cpu,
):
                # This can happen in chunked prefill: we still need to sample 1 token,
                # but we don't want to include it in the input logprobs.
if extend_len == extend_logprob_start_len:
start_len = extend_logprob_start_len - 1
else:
start_len = extend_logprob_start_len
                # We always need at least 1 token to sample because that is
                # required by the caller.
assert extend_len > start_len
pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
pt += extend_len
sample_index_pt += extend_len - start_len
sample_indices.append(sample_index_pt)
input_logprob_indices.extend(
[
input_logprob_indices_pt + i
for i in range(extend_len - extend_logprob_start_len)
]
)
input_logprob_indices_pt += extend_len - start_len
pruned_states = torch.cat(pruned_states)
sample_indices = torch.tensor(
sample_indices, device=pruned_states.device, dtype=torch.int64
)
input_logprob_indices = torch.tensor(
input_logprob_indices, device=pruned_states.device, dtype=torch.int64
)
# Compute logits for both input and sampled tokens.
logits = self._get_logits(pruned_states, lm_head, logits_metadata)
sampled_logits = (
logits[sample_indices] if sample_indices is not None else logits
)
if self.debug_tensor_dump_output_folder:
assert (
not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
), "dp attention + sharded lm_head doesn't support full logits"
full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
hidden_states_to_store: Optional[torch.Tensor] = None
if logits_metadata.capture_hidden_mode.need_capture():
if logits_metadata.capture_hidden_mode.is_full():
if aux_hidden_states is not None:
aux_hidden_states = torch.cat(aux_hidden_states, dim=-1)
hidden_states_to_store = aux_hidden_states
else:
hidden_states_to_store = hidden_states
            elif logits_metadata.capture_hidden_mode.is_last():
                # Get the last-token hidden states. If sample_indices is None,
                # pruned_states already contains only the last tokens.
                if aux_hidden_states is not None:
                    aux_pruned_states = torch.cat(aux_pruned_states, dim=-1)
                    hidden_states_to_store = (
                        aux_pruned_states[sample_indices]
                        if sample_indices is not None
                        else aux_pruned_states
                    )
                else:
                    hidden_states_to_store = (
                        pruned_states[sample_indices]
                        if sample_indices is not None
                        else pruned_states
                    )
else:
assert False, "Should never reach"
if not logits_metadata.extend_return_logprob:
# Decode mode or extend mode without return_logprob.
return LogitsProcessorOutput(
next_token_logits=sampled_logits,
hidden_states=hidden_states_to_store,
)
else:
input_logprobs = logits[input_logprob_indices]
del hidden_states, logits
# Normalize the logprob w/o temperature, top-p
pruned_lens = torch.tensor(
logits_metadata.extend_logprob_pruned_lens_cpu,
device=input_logprobs.device,
)
if logits_metadata.temp_scaled_logprobs:
logits_metadata.temperature = torch.repeat_interleave(
logits_metadata.temperature.view(-1),
pruned_lens,
).view(-1, 1)
if logits_metadata.top_p_normalized_logprobs:
logits_metadata.top_p = torch.repeat_interleave(
logits_metadata.top_p,
pruned_lens,
)
input_logprobs = self.compute_temp_top_p_normalized_logprobs(
input_logprobs, logits_metadata
)
# Get the logprob of top-k tokens
if logits_metadata.extend_return_top_logprob:
(
input_top_logprobs_val,
input_top_logprobs_idx,
) = self.get_top_logprobs(input_logprobs, logits_metadata)
else:
input_top_logprobs_val = input_top_logprobs_idx = None
# Get the logprob of given token id
if logits_metadata.extend_token_ids_logprob:
(
input_token_ids_logprobs_val,
input_token_ids_logprobs_idx,
) = self.get_token_ids_logprobs(input_logprobs, logits_metadata)
else:
input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None
input_token_logprobs = input_logprobs[
torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
logits_metadata.extend_input_logprob_token_ids_gpu,
]
return LogitsProcessorOutput(
next_token_logits=sampled_logits,
input_token_logprobs=input_token_logprobs,
input_top_logprobs_val=input_top_logprobs_val,
input_top_logprobs_idx=input_top_logprobs_idx,
hidden_states=hidden_states_to_store,
input_token_ids_logprobs_val=input_token_ids_logprobs_val,
input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
)

    def _get_logits(
self,
hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
logits_metadata: LogitsMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Get logits from hidden_states.
If sampled_logits_only is True, it means hidden_states only contain the
last position (e.g., extend without input logprobs). The caller should
guarantee the given hidden_states follow this constraint.
"""
if self.do_tensor_parallel_all_gather_dp_attn:
logits_metadata.compute_dp_attention_metadata(hidden_states)
hidden_states, local_hidden_states = (
logits_metadata.gathered_buffer,
hidden_states.clone(),
)
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
if hasattr(lm_head, "weight"):
logits = torch.matmul(
hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
)
else:
# GGUF models
logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
if self.logit_scale is not None:
logits.mul_(self.logit_scale)
if self.do_tensor_parallel_all_gather:
logits = tensor_model_parallel_all_gather(logits)
if self.do_tensor_parallel_all_gather_dp_attn:
logits, global_logits = (
torch.empty(
(local_hidden_states.shape[0], logits.shape[1]),
device=logits.device,
dtype=logits.dtype,
),
logits,
)
dp_scatter(logits, global_logits, logits_metadata)
logits = logits[:, : self.config.vocab_size].float()
if self.final_logit_softcapping:
fused_softcap(logits, self.final_logit_softcapping)
return logits

    @staticmethod
def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
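        # For each request, return per-position lists of the top-k logprob values and
        # token ids, where k is that request's entry in top_logprobs_nums.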
max_k = max(logits_metadata.top_logprobs_nums)
ret = all_logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()
input_top_logprobs_val, input_top_logprobs_idx = [], []
pt = 0
for k, pruned_len in zip(
logits_metadata.top_logprobs_nums,
logits_metadata.extend_logprob_pruned_lens_cpu,
):
if pruned_len <= 0:
input_top_logprobs_val.append([])
input_top_logprobs_idx.append([])
continue
input_top_logprobs_val.append(
[values[pt + j][:k] for j in range(pruned_len)]
)
input_top_logprobs_idx.append(
[indices[pt + j][:k] for j in range(pruned_len)]
)
pt += pruned_len
return input_top_logprobs_val, input_top_logprobs_idx

    @staticmethod
def get_token_ids_logprobs(
all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
):
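        # For each request, return the logprobs of the explicitly requested token ids
        # at every pruned input position, along with the requested ids themselves.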
input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
pt = 0
for token_ids, pruned_len in zip(
logits_metadata.token_ids_logprobs,
logits_metadata.extend_logprob_pruned_lens_cpu,
):
if pruned_len <= 0:
input_token_ids_logprobs_val.append([])
input_token_ids_logprobs_idx.append([])
continue
input_token_ids_logprobs_val.append(
[all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
)
input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
pt += pruned_len
return input_token_ids_logprobs_val, input_token_ids_logprobs_idx

    @staticmethod
def compute_temp_top_p_normalized_logprobs(
last_logits: torch.Tensor, logits_metadata: LogitsMetadata
) -> torch.Tensor:
"""
compute logprobs for the output token from the given logits.
Returns:
torch.Tensor: logprobs from logits
"""
# Scale logits if temperature scaling is enabled
if logits_metadata.temp_scaled_logprobs:
last_logits = last_logits / logits_metadata.temperature
# Normalize logprobs if top_p normalization is enabled
# NOTE: only normalize logprobs when top_p is set and not equal to 1.0
if (
logits_metadata.top_p_normalized_logprobs
and (logits_metadata.top_p != 1.0).any()
):
from sglang.srt.layers.sampler import top_p_normalize_probs_torch
probs = torch.softmax(last_logits, dim=-1)
del last_logits
probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p)
return torch.log(probs)
else:
return torch.nn.functional.log_softmax(last_logits, dim=-1)


@triton.jit
def fused_softcap_kernel(
full_logits_ptr,
softcapping_value,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0).to(tl.int64)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
# Load values
x = tl.load(full_logits_ptr + offsets, mask=mask)
# Perform operations in-place
x = x / softcapping_value
# Manual tanh implementation using exp
exp2x = tl.exp(2 * x)
x = (exp2x - 1) / (exp2x + 1)
x = x * softcapping_value
# Store result
tl.store(full_logits_ptr + offsets, x, mask=mask)


def fused_softcap(full_logits, final_logit_softcapping):
n_elements = full_logits.numel()
BLOCK_SIZE = 1024
grid = ((n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE, 1, 1)
fused_softcap_kernel[grid](
full_logits_ptr=full_logits,
softcapping_value=final_logit_softcapping,
n_elements=n_elements,
BLOCK_SIZE=BLOCK_SIZE,
)
return full_logits
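

# A minimal sketch (not part of the library): fused_softcap applies
# cap * tanh(x / cap) in place, the final-logit softcapping used by models such as
# Gemma-2. The check below assumes a CUDA device is available, since the Triton
# kernel runs on GPU.
if __name__ == "__main__":
    if torch.cuda.is_available():
        cap = 30.0
        x = torch.randn(4, 8, device="cuda", dtype=torch.float32)
        expected = cap * torch.tanh(x / cap)
        out = fused_softcap(x.clone(), cap)  # modifies its input in place and returns it
        assert torch.allclose(out, expected, atol=1e-5), "softcap mismatch"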