# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logits processing."""

import dataclasses
import logging
from typing import List, Optional, Union

import torch
import triton
import triton.language as tl
from torch import nn

from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
from sglang.srt.layers.dp_attention import (
    dp_gather_replicate,
    dp_scatter,
    get_attention_dp_rank,
    get_attention_dp_size,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
    ForwardMode,
)
from sglang.srt.utils import dump_to_file

logger = logging.getLogger(__name__)

@dataclasses.dataclass
class LogitsProcessorOutput:
    ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
    # The logits of the next tokens. shape: [#seq, vocab_size]
    next_token_logits: torch.Tensor
    # Used by speculative decoding (EAGLE)
    # The last hidden layers
    hidden_states: Optional[torch.Tensor] = None

    ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
    # The logprobs of the next tokens. shape: [#seq]
    next_token_logprobs: Optional[torch.Tensor] = None
    # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
    next_token_top_logprobs_val: Optional[List] = None
    next_token_top_logprobs_idx: Optional[List] = None
    # The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
    next_token_token_ids_logprobs_val: Optional[List] = None
    next_token_token_ids_logprobs_idx: Optional[List] = None

    ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
    # The logprobs of input tokens. shape: [#token]
    input_token_logprobs: Optional[torch.Tensor] = None
    # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
    input_top_logprobs_val: Optional[List] = None
    input_top_logprobs_idx: Optional[List] = None
    # The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
    input_token_ids_logprobs_val: Optional[List] = None
    input_token_ids_logprobs_idx: Optional[List] = None

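# A minimal, illustrative sketch (not part of the original file): how a
# decode-only result is typically assembled. Only Part 1 is filled by the
# LogitsProcessor; the sampler fills Part 2 afterwards. The batch size (2)
# and vocab size (8) below are made-up values for demonstration.
def _example_decode_output() -> LogitsProcessorOutput:
    next_token_logits = torch.randn(2, 8)  # [#seq, vocab_size]
    return LogitsProcessorOutput(next_token_logits=next_token_logits)
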
@dataclasses.dataclass
class LogitsMetadata:
    forward_mode: ForwardMode
    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL

    extend_return_logprob: bool = False
    extend_return_top_logprob: bool = False
    extend_token_ids_logprob: bool = False
    extend_seq_lens: Optional[torch.Tensor] = None
    extend_seq_lens_cpu: Optional[List[int]] = None
    extend_logprob_start_lens_cpu: Optional[List[int]] = None
    extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
    top_logprobs_nums: Optional[List[int]] = None
    extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
    token_ids_logprobs: Optional[List[List[int]]] = None

    # Logits and logprobs post processing
    temp_scaled_logprobs: bool = False
    temperature: Optional[torch.Tensor] = None
    top_p_normalized_logprobs: bool = False
    top_p: Optional[torch.Tensor] = None

    # DP attention metadata. Not needed when DP attention is not used.
    # Number of tokens in the request.
    global_num_tokens_gpu: Optional[torch.Tensor] = None
    # The start position of local hidden states.
    dp_local_start_pos: Optional[torch.Tensor] = None
    dp_local_num_tokens: Optional[torch.Tensor] = None
    gathered_buffer: Optional[torch.Tensor] = None
    # Buffer to gather logits from all ranks.
    forward_batch_gathered_buffer: Optional[torch.Tensor] = None
    # Number of tokens to sample per DP rank
    global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
    global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None

    # For padding
    padded_static_len: int = -1

    @classmethod
    def from_forward_batch(cls, forward_batch: ForwardBatch):
        if (
            forward_batch.forward_mode.is_extend()
            and forward_batch.return_logprob
            and not forward_batch.forward_mode.is_target_verify()
        ):
            extend_return_top_logprob = any(
                x > 0 for x in forward_batch.top_logprobs_nums
            )
            extend_token_ids_logprob = any(
                x is not None for x in forward_batch.token_ids_logprobs
            )
            extend_return_logprob = False
            extend_logprob_pruned_lens_cpu = []
            for extend_len, start_len in zip(
                forward_batch.extend_seq_lens_cpu,
                forward_batch.extend_logprob_start_lens_cpu,
            ):
                if extend_len - start_len > 0:
                    extend_return_logprob = True
                extend_logprob_pruned_lens_cpu.append(extend_len - start_len)
        else:
            extend_return_logprob = extend_return_top_logprob = (
                extend_token_ids_logprob
            ) = False
            extend_logprob_pruned_lens_cpu = None

        return cls(
            forward_mode=forward_batch.forward_mode,
            capture_hidden_mode=forward_batch.capture_hidden_mode,
            extend_return_logprob=extend_return_logprob,
            extend_return_top_logprob=extend_return_top_logprob,
            extend_token_ids_logprob=extend_token_ids_logprob,
            extend_seq_lens=forward_batch.extend_seq_lens,
            extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
            extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
            extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
            top_logprobs_nums=forward_batch.top_logprobs_nums,
            token_ids_logprobs=forward_batch.token_ids_logprobs,
            extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
            padded_static_len=forward_batch.padded_static_len,
            global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
            dp_local_start_pos=forward_batch.dp_local_start_pos,
            dp_local_num_tokens=forward_batch.dp_local_num_tokens,
            gathered_buffer=forward_batch.gathered_buffer,
            forward_batch_gathered_buffer=forward_batch.gathered_buffer,
            global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
            global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
        )
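    # Worked example (added commentary, not original code): with
    # extend_seq_lens_cpu = [5, 4] and extend_logprob_start_lens_cpu = [2, 4],
    # the pruned lens become [3, 0]. Request 0 still wants logprobs for its
    # last 3 input positions, while request 1 (e.g., a resumed chunked
    # prefill) contributes none, so extend_return_logprob is True only
    # because of request 0.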
    def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
        if self.global_num_tokens_for_logprob_cpu is None:
            # We are capturing a CUDA graph; the metadata is not needed here.
            return

        cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
        dp_rank = get_attention_dp_rank()
        if dp_rank == 0:
            dp_local_start_pos = torch.zeros_like(
                self.global_num_tokens_for_logprob_gpu[0]
            )
        else:
            dp_local_start_pos = cumtokens[dp_rank - 1]
        dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
        gathered_buffer = torch.zeros(
            (
                sum(self.global_num_tokens_for_logprob_cpu),
                hidden_states.shape[1],
            ),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        self.dp_local_start_pos = dp_local_start_pos
        self.dp_local_num_tokens = dp_local_num_tokens
        self.gathered_buffer = gathered_buffer

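# A toy, self-contained sketch (not part of the original file) of the index
# arithmetic above: with per-rank token counts [3, 5, 2], rank 1 starts at
# cumsum[0] = 3 and owns 5 rows of the gathered buffer.
def _example_dp_local_slice(dp_rank: int = 1):
    global_num_tokens = torch.tensor([3, 5, 2])  # tokens owned by each DP rank
    cumtokens = torch.cumsum(global_num_tokens, dim=0)
    if dp_rank == 0:
        start = torch.zeros_like(global_num_tokens[0])
    else:
        start = cumtokens[dp_rank - 1]
    num = global_num_tokens[dp_rank]
    return start, num  # tensor(3), tensor(5) when dp_rank == 1
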
class LogitsProcessor(nn.Module):
    def __init__(
        self, config, skip_all_gather: bool = False, logit_scale: Optional[float] = None
    ):
        super().__init__()
        self.config = config
        self.logit_scale = logit_scale
        self.do_tensor_parallel_all_gather = (
            not skip_all_gather and get_tensor_model_parallel_world_size() > 1
        )
        self.do_tensor_parallel_all_gather_dp_attn = (
            self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
        )
        self.final_logit_softcapping = getattr(
            self.config, "final_logit_softcapping", None
        )
        if (
            self.final_logit_softcapping is not None
            and self.final_logit_softcapping < 0
        ):
            self.final_logit_softcapping = None

        self.debug_tensor_dump_output_folder = global_server_args_dict.get(
            "debug_tensor_dump_output_folder", None
        )

    def forward(
        self,
        input_ids,
        hidden_states,
        lm_head: VocabParallelEmbedding,
        logits_metadata: Union[LogitsMetadata, ForwardBatch],
        aux_hidden_states: Optional[torch.Tensor] = None,
    ) -> LogitsProcessorOutput:
        if isinstance(logits_metadata, ForwardBatch):
            logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
        # Get the last hidden states and last logits for the next token prediction
        if (
            logits_metadata.forward_mode.is_decode_or_idle()
            or logits_metadata.forward_mode.is_target_verify()
        ):
            pruned_states = hidden_states
            if aux_hidden_states is not None:
                aux_pruned_states = [hidden for hidden in aux_hidden_states]
            sample_indices = None
            input_logprob_indices = None
        elif (
            logits_metadata.forward_mode.is_extend()
            and not logits_metadata.extend_return_logprob
        ):
            # Prefill without input logprobs.
            if logits_metadata.padded_static_len < 0:
                last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
            else:
                # If padded_static_len is 5 and extend_seq_lens is [2, 3],
                # then the batch looks like [t00, t01, p, p, p, t10, t11, t12, p, p],
                # and this retrieves t01 and t12, which are the valid last tokens.
                idx = torch.arange(
                    len(logits_metadata.extend_seq_lens),
                    device=logits_metadata.extend_seq_lens.device,
                )
                last_index = (
                    idx * logits_metadata.padded_static_len
                    + logits_metadata.extend_seq_lens
                    - 1
                )
            pruned_states = hidden_states[last_index]
            if aux_hidden_states is not None:
                aux_pruned_states = [hidden[last_index] for hidden in aux_hidden_states]
            sample_indices = None
            input_logprob_indices = None
        else:
            # Input logprobs are required.
            # Find 3 different indices:
            # 1. pruned_states: hidden states that we want logprobs from.
            # 2. sample_indices: indices that have sampled tokens.
            # 3. input_logprob_indices: indices that have input logprob tokens.
            sample_index_pt = -1
            sample_indices = []
            input_logprob_indices_pt = 0
            input_logprob_indices = []
            pt, pruned_states = 0, []
            for extend_logprob_start_len, extend_len in zip(
                logits_metadata.extend_logprob_start_lens_cpu,
                logits_metadata.extend_seq_lens_cpu,
            ):
                # This can happen in chunked prefill: we still need to sample
                # one token, but we don't want to include it in the input logprobs.
                if extend_len == extend_logprob_start_len:
                    start_len = extend_logprob_start_len - 1
                else:
                    start_len = extend_logprob_start_len

                # We always need at least one token to sample because that is
                # required by the caller.
                assert extend_len > start_len
                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
                pt += extend_len
                sample_index_pt += extend_len - start_len
                sample_indices.append(sample_index_pt)
                input_logprob_indices.extend(
                    [
                        input_logprob_indices_pt + i
                        for i in range(extend_len - extend_logprob_start_len)
                    ]
                )
                input_logprob_indices_pt += extend_len - start_len

            pruned_states = torch.cat(pruned_states)
            sample_indices = torch.tensor(
                sample_indices, device=pruned_states.device, dtype=torch.int64
            )
            input_logprob_indices = torch.tensor(
                input_logprob_indices, device=pruned_states.device, dtype=torch.int64
            )

        # Compute logits for both input and sampled tokens.
        logits = self._get_logits(pruned_states, lm_head, logits_metadata)
        sampled_logits = (
            logits[sample_indices] if sample_indices is not None else logits
        )

        if self.debug_tensor_dump_output_folder:
            assert (
                not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
            ), "dp attention + sharded lm_head doesn't support full logits"
            full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
            dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)

        hidden_states_to_store: Optional[torch.Tensor] = None
        if logits_metadata.capture_hidden_mode.need_capture():
            if logits_metadata.capture_hidden_mode.is_full():
                if aux_hidden_states is not None:
                    aux_hidden_states = torch.cat(aux_hidden_states, dim=-1)
                    hidden_states_to_store = aux_hidden_states
                else:
                    hidden_states_to_store = hidden_states
            elif logits_metadata.capture_hidden_mode.is_last():
                # Get the last token hidden states. If sample_indices is None,
                # pruned states only contain the last tokens already.
                if aux_hidden_states is not None:
                    aux_pruned_states = torch.cat(aux_pruned_states, dim=-1)
                    hidden_states_to_store = (
                        aux_pruned_states[sample_indices]
                        if sample_indices is not None
                        else aux_pruned_states
                    )
                else:
                    hidden_states_to_store = (
                        pruned_states[sample_indices]
                        if sample_indices is not None
                        else pruned_states
                    )
            else:
                assert False, "Should never reach"

        if not logits_metadata.extend_return_logprob:
            # Decode mode or extend mode without return_logprob.
            return LogitsProcessorOutput(
                next_token_logits=sampled_logits,
                hidden_states=hidden_states_to_store,
            )
        else:
            input_logprobs = logits[input_logprob_indices]
            del hidden_states, logits

            # Normalize the logprobs with temperature and top-p.
            pruned_lens = torch.tensor(
                logits_metadata.extend_logprob_pruned_lens_cpu,
                device=input_logprobs.device,
            )
            if logits_metadata.temp_scaled_logprobs:
                logits_metadata.temperature = torch.repeat_interleave(
                    logits_metadata.temperature.view(-1),
                    pruned_lens,
                ).view(-1, 1)
            if logits_metadata.top_p_normalized_logprobs:
                logits_metadata.top_p = torch.repeat_interleave(
                    logits_metadata.top_p,
                    pruned_lens,
                )
            input_logprobs = self.compute_temp_top_p_normalized_logprobs(
                input_logprobs, logits_metadata
            )

            # Get the logprobs of the top-k tokens.
            if logits_metadata.extend_return_top_logprob:
                (
                    input_top_logprobs_val,
                    input_top_logprobs_idx,
                ) = self.get_top_logprobs(input_logprobs, logits_metadata)
            else:
                input_top_logprobs_val = input_top_logprobs_idx = None

            # Get the logprobs of the requested token ids.
            if logits_metadata.extend_token_ids_logprob:
                (
                    input_token_ids_logprobs_val,
                    input_token_ids_logprobs_idx,
                ) = self.get_token_ids_logprobs(input_logprobs, logits_metadata)
            else:
                input_token_ids_logprobs_val = input_token_ids_logprobs_idx = None

            input_token_logprobs = input_logprobs[
                torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
                logits_metadata.extend_input_logprob_token_ids_gpu,
            ]

            return LogitsProcessorOutput(
                next_token_logits=sampled_logits,
                input_token_logprobs=input_token_logprobs,
                input_top_logprobs_val=input_top_logprobs_val,
                input_top_logprobs_idx=input_top_logprobs_idx,
                hidden_states=hidden_states_to_store,
                input_token_ids_logprobs_val=input_token_ids_logprobs_val,
                input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
            )
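    # Illustrative walk-through (added commentary, not original code): with
    # extend_seq_lens_cpu = [4, 3] and extend_logprob_start_lens_cpu = [2, 3],
    # request 0 keeps positions 2..3 (2 rows) and request 1, being fully
    # pruned, keeps only its last position (1 row). So pruned_states has
    # 3 rows, sample_indices == [1, 2] (the last row of each request), and
    # input_logprob_indices == [0, 1] (request 1 contributes no input rows).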
    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        logits_metadata: LogitsMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Get logits from hidden_states.

        For some forward modes (e.g., extend without input logprobs),
        hidden_states may only contain the last position of each sequence.
        The caller should guarantee that the given hidden_states follow
        this constraint.
        """
        if self.do_tensor_parallel_all_gather_dp_attn:
            logits_metadata.compute_dp_attention_metadata(hidden_states)
            hidden_states, local_hidden_states = (
                logits_metadata.gathered_buffer,
                hidden_states.clone(),
            )
            dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

        if hasattr(lm_head, "weight"):
            logits = torch.matmul(
                hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
            )
        else:
            # GGUF models
            logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)

        if self.logit_scale is not None:
            logits.mul_(self.logit_scale)

        if self.do_tensor_parallel_all_gather:
            logits = tensor_model_parallel_all_gather(logits)

        if self.do_tensor_parallel_all_gather_dp_attn:
            logits, global_logits = (
                torch.empty(
                    (local_hidden_states.shape[0], logits.shape[1]),
                    device=logits.device,
                    dtype=logits.dtype,
                ),
                logits,
            )
            dp_scatter(logits, global_logits, logits_metadata)

        logits = logits[:, : self.config.vocab_size].float()

        if self.final_logit_softcapping:
            fused_softcap(logits, self.final_logit_softcapping)

        return logits
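    # On a single GPU (no TP all-gather, no DP attention), the method above
    # reduces to roughly this sketch (added commentary, not original code):
    #   logits = hidden_states.to(w.dtype) @ w.T        # w = lm_head.weight
    #   logits = logits[:, :vocab_size].float()         # drop padded vocab tail
    #   if cap: logits = cap * torch.tanh(logits / cap) # final logit softcap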
    @staticmethod
    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
        max_k = max(logits_metadata.top_logprobs_nums)
        ret = all_logprobs.topk(max_k, dim=1)
        values = ret.values.tolist()
        indices = ret.indices.tolist()

        input_top_logprobs_val, input_top_logprobs_idx = [], []

        pt = 0
        for k, pruned_len in zip(
            logits_metadata.top_logprobs_nums,
            logits_metadata.extend_logprob_pruned_lens_cpu,
        ):
            if pruned_len <= 0:
                input_top_logprobs_val.append([])
                input_top_logprobs_idx.append([])
                continue

            input_top_logprobs_val.append(
                [values[pt + j][:k] for j in range(pruned_len)]
            )
            input_top_logprobs_idx.append(
                [indices[pt + j][:k] for j in range(pruned_len)]
            )
            pt += pruned_len

        return input_top_logprobs_val, input_top_logprobs_idx
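    # Example (added commentary, not original code): with two requests,
    # top_logprobs_nums = [2, 1] and extend_logprob_pruned_lens_cpu = [1, 2],
    # a single topk(max_k=2) call covers both; request 0 keeps 2 values for
    # its single position, and request 1 keeps only the top-1 value for each
    # of its 2 positions.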
    @staticmethod
    def get_token_ids_logprobs(
        all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
    ):
        input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
        pt = 0
        for token_ids, pruned_len in zip(
            logits_metadata.token_ids_logprobs,
            logits_metadata.extend_logprob_pruned_lens_cpu,
        ):
            if pruned_len <= 0:
                input_token_ids_logprobs_val.append([])
                input_token_ids_logprobs_idx.append([])
                continue

            input_token_ids_logprobs_val.append(
                [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
            )
            input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
            pt += pruned_len

        return input_token_ids_logprobs_val, input_token_ids_logprobs_idx
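    # Example (added commentary, not original code): if request 0 asks for
    # token_ids = [5, 9] over 2 pruned positions, it receives a 2 x 2 list of
    # logprobs (one row per position) and the same [5, 9] id list repeated
    # for each position.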
    @staticmethod
    def compute_temp_top_p_normalized_logprobs(
        last_logits: torch.Tensor, logits_metadata: LogitsMetadata
    ) -> torch.Tensor:
        """Compute logprobs from the given logits, applying temperature
        scaling and top-p renormalization when they are enabled.

        Returns:
            torch.Tensor: logprobs from logits
        """
        # Scale logits if temperature scaling is enabled
        if logits_metadata.temp_scaled_logprobs:
            last_logits = last_logits / logits_metadata.temperature

        # Normalize logprobs if top_p normalization is enabled
        # NOTE: only normalize logprobs when top_p is set and not equal to 1.0
        if (
            logits_metadata.top_p_normalized_logprobs
            and (logits_metadata.top_p != 1.0).any()
        ):
            from sglang.srt.layers.sampler import top_p_normalize_probs_torch

            probs = torch.softmax(last_logits, dim=-1)
            del last_logits
            probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p)
            return torch.log(probs)
        else:
            return torch.nn.functional.log_softmax(last_logits, dim=-1)

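# A small, self-contained check (not part of the original file) of the
# temperature path above: dividing logits by T before log_softmax is the
# same as computing temperature-scaled logprobs directly.
def _example_temp_scaled_logprobs() -> torch.Tensor:
    logits = torch.tensor([[2.0, 1.0, 0.5]])
    temperature = 2.0
    return torch.nn.functional.log_softmax(logits / temperature, dim=-1)
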
@triton.jit
def fused_softcap_kernel(
    full_logits_ptr,
    softcapping_value,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Load values
    x = tl.load(full_logits_ptr + offsets, mask=mask)

    # Perform operations in-place
    x = x / softcapping_value

    # Manual tanh implementation using exp: tanh(x) = (e^(2x) - 1) / (e^(2x) + 1)
    exp2x = tl.exp(2 * x)
    x = (exp2x - 1) / (exp2x + 1)

    x = x * softcapping_value

    # Store result
    tl.store(full_logits_ptr + offsets, x, mask=mask)


def fused_softcap(full_logits, final_logit_softcapping):
    n_elements = full_logits.numel()
    BLOCK_SIZE = 1024
    grid = ((n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE, 1, 1)

    fused_softcap_kernel[grid](
        full_logits_ptr=full_logits,
        softcapping_value=final_logit_softcapping,
        n_elements=n_elements,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return full_logits
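
# A hedged reference check (not part of the original file): the kernel above
# computes cap * tanh(x / cap) element-wise, in place. On a CUDA device, the
# fused version should match the plain torch expression up to float tolerance.
def _example_softcap_reference(cap: float = 30.0) -> bool:
    if not torch.cuda.is_available():
        return False  # the Triton kernel requires a GPU
    x = torch.randn(4, 16, device="cuda", dtype=torch.float32)
    expected = cap * torch.tanh(x / cap)
    fused_softcap(x, cap)  # modifies x in place
    return torch.allclose(x, expected, atol=1e-5)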