224 lines
7.9 KiB
Plaintext
224 lines
7.9 KiB
Plaintext
/*
 * Copyright (c) 2025 by SGLang team.
 * Copyright (c) 2024-2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
#ifndef SPECULATIVE_SAMPLING_CUH_
|
|
#define SPECULATIVE_SAMPLING_CUH_
|
|
|
|
#include <assert.h>
|
|
|
|
#include <flashinfer/sampling.cuh>
|
|
|
|
namespace flashinfer {
|
|
|
|
namespace sampling {
|
|
|
|
using namespace cub;
|
|
|
|
// Verifies a tree of draft tokens against the target distribution and samples
// one extra token from relu(target_probs - draft_probs).
//
// The kernel indexes every per-request array with blockIdx.x, so it expects a
// 1-D grid with one block per request and BLOCK_THREADS threads per block.
// Dynamic shared memory must hold one
// SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>.
//
// Phase 1 (tree walk): starting from node 0 of the request, follow
// retrive_next_token to the first child and retrive_next_sibling across the
// candidates at that depth; -1 terminates either list. A candidate is accepted
// when `coin <= prob_acc / threshold_acc` or its target probability is at least
// `threshold_single`. There is no threadIdx guard in this phase: every thread
// of the block performs the same walk and issues the same global writes.
//
// Phase 2 (resampling): the block cooperatively reduces and samples from
// relu(target_probs - draft_probs) in the probability row of the last accepted
// node, and stores the result in predicts[last accepted retrive index].
//
// Parameters:
//   predicts          [out] token ids, indexed by values read from retrive_index.
//   accept_index      [out] per request, num_speculative_tokens slots; slot 0 gets
//                     the root's retrive index, following slots the accepted ones.
//   accept_token_num  [out] per request, number of accepted draft tokens.
//   candidates        per request, num_draft_tokens candidate token ids.
//   retrive_index     per request, num_draft_tokens output positions.
//   retrive_next_token / retrive_next_sibling
//                     per request, num_draft_tokens links; -1 means "none".
//   uniform_samples   per request, num_draft_tokens random numbers
//                     (presumably uniform in [0, 1) -- confirm with the caller).
//   target_probs      rows of length d, addressed as (bx * num_draft_tokens + node) * d.
//   draft_probs       same addressing as target_probs; MUTATED for rejected candidates.
//   batch_size        not read inside the kernel.
//   num_speculative_tokens  upper bound (exclusive) on accepted depth + 1.
//   num_draft_tokens  number of tree nodes per request.
//   d                 length of one probability row (assumed to be the vocabulary size).
//   threshold_single, threshold_acc  acceptance thresholds (see phase 1).
//
// NOTE(review): vectorised loads of VEC_SIZE elements are issued at offsets that
// are multiples of VEC_SIZE within a row; the caller is assumed to pick VEC_SIZE
// so that d is a multiple of it -- confirm against the launcher.
template <
    uint32_t BLOCK_THREADS,
    BlockScanAlgorithm SCAN_ALGORITHM,
    BlockReduceAlgorithm REDUCE_ALGORITHM,
    uint32_t VEC_SIZE,
    bool DETERMINISTIC,
    typename DType,
    typename IdType>
__global__ void TreeSpeculativeSamplingTargetOnly(
    IdType* predicts,          // mutable
    IdType* accept_index,      // mutable
    IdType* accept_token_num,  // mutable
    IdType* candidates,
    IdType* retrive_index,
    IdType* retrive_next_token,
    IdType* retrive_next_sibling,
    DType* uniform_samples,
    DType* target_probs,
    DType* draft_probs,
    uint32_t batch_size,
    uint32_t num_speculative_tokens,
    uint32_t num_draft_tokens,
    uint32_t d,
    DType threshold_single,
    DType threshold_acc) {
  const uint32_t bx = blockIdx.x, tx = threadIdx.x;

  extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
      uint8_t smem_sampling[];
  auto& temp_storage =
      reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);

  DType prob_acc(0);
  // Element offset of the probability row currently being examined. Kept in
  // 64 bits: batch * num_draft_tokens * d can exceed 2^32, and a wrapped
  // 32-bit offset would silently address the wrong row.
  uint64_t cur_prob_offset = static_cast<uint64_t>(bx) * num_draft_tokens * d;
  DType coin = uniform_samples[bx * num_draft_tokens];
  IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
  accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
  uint32_t num_accepted_tokens = 0;
  IdType cur_index = 0;

  // ---- Phase 1: walk the draft tree, at most one accepted node per depth ----
  for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
    // Descend to the first child of the last accepted node.
    cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
    while (cur_index != -1) {
      IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
      IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
      DType target_prob_single = target_probs[cur_prob_offset + draft_token_id];
      // Probability mass of all siblings tried so far at this depth.
      prob_acc += target_prob_single;

      if (coin <= prob_acc / threshold_acc || target_prob_single >= threshold_single) {
        // accept token
        prob_acc = DType(0);
        // Switch to the accepted node's probability row and random number.
        cur_prob_offset = (static_cast<uint64_t>(bx) * num_draft_tokens + cur_index) * d;
        coin = uniform_samples[bx * num_draft_tokens + cur_index];
        predicts[last_accepted_retrive_idx] = draft_token_id;
        ++num_accepted_tokens;
        accept_index[bx * num_speculative_tokens + num_accepted_tokens] = draft_index;
        last_accepted_retrive_idx = draft_index;
        break;
      } else {
        // FIXME: leverage draft probs
        // Copy the target probability of the rejected token into draft_probs so
        // that relu(target - draft) is zero for it in phase 2.
        draft_probs[cur_prob_offset + draft_token_id] = target_probs[cur_prob_offset + draft_token_id];
        cur_index = retrive_next_sibling[bx * num_draft_tokens + cur_index];
      }
    }
    // Every candidate at this depth was rejected (or there was none): stop.
    if (cur_index == -1) break;
  }
  accept_token_num[bx] = num_accepted_tokens;

  // ---- Phase 2: sample from relu(target_probs - draft_probs) ----
  const uint32_t num_rounds = ceil_div(d, BLOCK_THREADS * VEC_SIZE);
  // When the walk accepted a token at every depth, draft_probs is not read and
  // is treated as all zeros.
  const bool use_draft_probs = (num_accepted_tokens != num_speculative_tokens - 1);

  DType sum_relu_q_minus_p(0);
  vec_t<DType, VEC_SIZE> q_vec, p_vec;
  DType relu_q_minus_p[VEC_SIZE];
  for (uint32_t i = 0; i < num_rounds; ++i) {
    q_vec.fill(DType(0));
    p_vec.fill(DType(0));
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      q_vec.load(target_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      if (use_draft_probs) {
        // there is no draft_probs for the bonus token
        p_vec.load(draft_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      }
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      relu_q_minus_p[j] = max(q_vec[j] - p_vec[j], DType(0));
    }
    // `template` is required here because BlockReduce<...> is a dependent type.
    sum_relu_q_minus_p += BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                              .template Sum<VEC_SIZE>(relu_q_minus_p);
    // The reduction scratch space is reused in the next round.
    __syncthreads();
  }
  // Publish thread 0's accumulated sum to the whole block through shared memory.
  if (tx == 0) {
    temp_storage.block_aggregate.value = sum_relu_q_minus_p;
  }
  // init the first rejected token to (d - 1)
  // (presumably the fallback when the sampler below selects nothing -- confirm
  // against DeviceSamplingFromProb)
  temp_storage.sampled_id = d - 1;
  __syncthreads();
  sum_relu_q_minus_p = temp_storage.block_aggregate.value;
  // Scale the random number to the un-normalised mass being sampled from.
  DType u = coin * sum_relu_q_minus_p;

  DType aggregate_relu_q_minus_p(0);
  for (uint32_t i = 0; i < num_rounds; ++i) {
    q_vec.fill(DType(0));
    p_vec.fill(DType(0));
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      q_vec.load(target_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      if (use_draft_probs) {
        // there is no draft_probs for the bonus token
        p_vec.load(draft_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      }
    }

    vec_t<DType, VEC_SIZE> relu_q_minus_p_vec;
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0));
    }

    DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC, DType>(
        i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec, aggregate_relu_q_minus_p, &temp_storage);
    // Stop scanning once the running aggregate has passed the threshold u.
    if (aggregate_relu_q_minus_p > u) {
      break;
    }
  }
  __syncthreads();
  // set the first rejected token
  predicts[last_accepted_retrive_idx] = temp_storage.sampled_id;
  // value at not used indices are undefined
}
|
|
|
|
// Host-side launcher for the TreeSpeculativeSamplingTargetOnly kernel.
//
// Launches `batch_size` blocks of 1024 threads on `stream`, with dynamic shared
// memory sized for one SamplingTempStorage. The vector width handed to the
// kernel is gcd(16 / sizeof(DType), d), so it always divides d. The
// deterministic flag selects the DETERMINISTIC template argument.
//
// Parameters:
//   predicts, output_token_ids, output_accepted_token_num
//                     device buffers written by the kernel (forwarded as its
//                     predicts / accept_index / accept_token_num arguments).
//   candidates, retrive_index, retrive_next_token, retrive_next_sibling,
//   uniform_samples, target_probs, draft_probs
//                     forwarded unchanged to the kernel (assumed to be device
//                     pointers -- the launcher never dereferences them).
//   batch_size, num_speculative_tokens, num_draft_tokens, d
//                     forwarded unchanged; batch_size is also the grid size.
//   threshold_single  forwarded unchanged.
//   threshold_acc     clamped from below to 1e-9 before being forwarded; the
//                     kernel divides by it.
//   deterministic     chooses the deterministic kernel instantiation.
//   stream            CUDA stream the kernel is enqueued on.
//
// Returns cudaSuccess once the launch has been enqueued. A failing
// FLASHINFER_CUDA_CALL is expected to return the CUDA error early; the macro is
// defined elsewhere -- confirm.
template <typename DType, typename IdType>
cudaError_t TreeSpeculativeSamplingTargetOnly(
    IdType* predicts,                   // mutable
    IdType* output_token_ids,           // mutable
    IdType* output_accepted_token_num,  // mutable
    IdType* candidates,
    IdType* retrive_index,
    IdType* retrive_next_token,
    IdType* retrive_next_sibling,
    DType* uniform_samples,
    DType* target_probs,
    DType* draft_probs,
    uint32_t batch_size,
    uint32_t num_speculative_tokens,
    uint32_t num_draft_tokens,
    uint32_t d,
    DType threshold_single = 1,
    DType threshold_acc = 1,
    bool deterministic = true,
    cudaStream_t stream = 0) {
  constexpr uint32_t BLOCK_THREADS = 1024;
  // Widest vector (in elements) that fits 16 bytes and still divides d.
  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);

  const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
  dim3 nblks(batch_size);
  dim3 nthrs(BLOCK_THREADS);
  // Keep the divisor used by the kernel strictly positive. The capped value
  // MUST be stored as DType: cudaLaunchKernel copies sizeof(DType) bytes from
  // the address placed in `args`, so pointing it at a `float` is only correct
  // when DType happens to be float.
  DType capped_threshold_acc = static_cast<DType>(fmaxf(static_cast<float>(threshold_acc), 1e-9f));
  // Order and types must match the kernel's parameter list exactly.
  void* args[] = {
      &predicts,
      &output_token_ids,
      &output_accepted_token_num,
      &candidates,
      &retrive_index,
      &retrive_next_token,
      &retrive_next_sibling,
      &uniform_samples,
      &target_probs,
      &draft_probs,
      &batch_size,
      &num_speculative_tokens,
      &num_draft_tokens,
      &d,
      &threshold_single,
      &capped_threshold_acc};
  DISPATCH_ALIGNED_VEC_SIZE(
      vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
        auto kernel = TreeSpeculativeSamplingTargetOnly<
            BLOCK_THREADS,
            SCAN_ALGO,
            REDUCE_ALGO,
            VEC_SIZE,
            DETERMINISTIC,
            DType,
            IdType>;
        // Opt in to the requested amount of dynamic shared memory for this instantiation.
        FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
      })});
  return cudaSuccess;
}
|
|
|
|
} // namespace sampling
|
|
|
|
} // namespace flashinfer
|
|
|
|
#endif // SPECULATIVE_SAMPLING_CUH_
|