/*
 * Copyright (c) 2025 by SGLang team.
 * Copyright (c) 2024-2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef SPECULATIVE_SAMPLING_CUH_
#define SPECULATIVE_SAMPLING_CUH_

#include <flashinfer/sampling.cuh>  // SamplingTempStorage, DeviceSamplingFromProb, vec_t, dispatch macros
#include <numeric>                  // std::gcd

namespace flashinfer {

namespace sampling {

using namespace cub;

template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
          typename DType, typename IdType>
__global__ void TreeSpeculativeSamplingTargetOnly(
    IdType* predicts,          // mutable
    IdType* accept_index,      // mutable
    IdType* accept_token_num,  // mutable
    IdType* candidates, IdType* retrive_index, IdType* retrive_next_token,
    IdType* retrive_next_sibling, DType* uniform_samples, DType* target_probs, DType* draft_probs,
    uint32_t batch_size, uint32_t num_speculative_tokens, uint32_t num_draft_tokens, uint32_t d,
    DType threshold_single, DType threshold_acc) {
  const uint32_t bx = blockIdx.x, tx = threadIdx.x;

  extern __shared__ __align__(
      alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
      uint8_t smem_sampling[];
  auto& temp_storage =
      reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
          smem_sampling);

  DType prob_acc = 0.0;
  uint32_t cur_prob_offset = bx * num_draft_tokens * d;
  DType coin = uniform_samples[bx * num_draft_tokens];
  IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
  accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
  uint32_t num_accepted_tokens = 0;
  IdType cur_index = 0;

  // walk the draft tree level by level; within a level, try siblings until one token is accepted
  for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
    cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
    while (cur_index != -1) {
      IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
      IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
      DType target_prob_single = target_probs[cur_prob_offset + draft_token_id];
      prob_acc += target_prob_single;

      if (coin <= prob_acc / threshold_acc || target_prob_single >= threshold_single) {
        // accept token
        prob_acc = 0.;
        cur_prob_offset = (bx * num_draft_tokens + cur_index) * d;
        coin = uniform_samples[bx * num_draft_tokens + cur_index];
        predicts[last_accepted_retrive_idx] = draft_token_id;
        ++num_accepted_tokens;
        accept_index[bx * num_speculative_tokens + num_accepted_tokens] = draft_index;
        last_accepted_retrive_idx = draft_index;
        break;
      } else {
        // FIXME: leverage draft probs
        draft_probs[cur_prob_offset + draft_token_id] =
            target_probs[cur_prob_offset + draft_token_id];
        cur_index = retrive_next_sibling[bx * num_draft_tokens + cur_index];
      }
    }
    if (cur_index == -1) break;
  }
  accept_token_num[bx] = num_accepted_tokens;

  // sample from relu(target_probs - draft_probs)
  DType sum_relu_q_minus_p(0);
  vec_t<DType, VEC_SIZE> q_vec, p_vec;
  DType relu_q_minus_p[VEC_SIZE];
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    q_vec.fill(DType(0));
    p_vec.fill(DType(0));
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      q_vec.load(target_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      if (num_accepted_tokens != num_speculative_tokens - 1) {
        // there is no draft_probs for the bonus token
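        // (if every speculative position was accepted, cur_prob_offset already points at the
        // bonus position, which has no draft proposal: p_vec stays zero and the residual
        // relu(q - p) computed below reduces to the target distribution q itself)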
        p_vec.load(draft_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      }
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      relu_q_minus_p[j] = max(q_vec[j] - p_vec[j], DType(0));
    }
    sum_relu_q_minus_p +=
        BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
            .Sum<VEC_SIZE>(relu_q_minus_p);
    __syncthreads();
  }
  if (tx == 0) {
    temp_storage.block_aggregate.value = sum_relu_q_minus_p;
  }
  // init the first rejected token to (d - 1)
  temp_storage.sampled_id = d - 1;
  __syncthreads();
  sum_relu_q_minus_p = temp_storage.block_aggregate.value;
  DType u = coin * sum_relu_q_minus_p;

  DType aggregate_relu_q_minus_p(0);
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    q_vec.fill(DType(0));
    p_vec.fill(DType(0));
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      q_vec.load(target_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      if (num_accepted_tokens != num_speculative_tokens - 1) {
        // there is no draft_probs for the bonus token
        p_vec.load(draft_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      }
    }
    vec_t<DType, VEC_SIZE> relu_q_minus_p_vec;
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0));
    }
    DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC,
                           DType>(i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec,
                                  aggregate_relu_q_minus_p, &temp_storage);
    if (aggregate_relu_q_minus_p > u) {
      break;
    }
  }
  __syncthreads();
  // set the first rejected token
  predicts[last_accepted_retrive_idx] = temp_storage.sampled_id;
  // values at unused indices are undefined
}

template <typename DType, typename IdType>
cudaError_t TreeSpeculativeSamplingTargetOnly(
    IdType* predicts,                   // mutable
    IdType* output_token_ids,           // mutable
    IdType* output_accepted_token_num,  // mutable
    IdType* candidates, IdType* retrive_index, IdType* retrive_next_token,
    IdType* retrive_next_sibling, DType* uniform_samples, DType* target_probs, DType* draft_probs,
    uint32_t batch_size, uint32_t num_speculative_tokens, uint32_t num_draft_tokens, uint32_t d,
    DType threshold_single = 1, DType threshold_acc = 1, bool deterministic = true,
    cudaStream_t stream = 0) {
  constexpr uint32_t BLOCK_THREADS = 1024;
  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);

  const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
  dim3 nblks(batch_size);
  dim3 nthrs(BLOCK_THREADS);
  // guard against division by zero inside the kernel when threshold_acc == 0
  float capped_threshold_acc = fmaxf(threshold_acc, 1e-9f);
  void* args[] = {&predicts,
                  &output_token_ids,
                  &output_accepted_token_num,
                  &candidates,
                  &retrive_index,
                  &retrive_next_token,
                  &retrive_next_sibling,
                  &uniform_samples,
                  &target_probs,
                  &draft_probs,
                  &batch_size,
                  &num_speculative_tokens,
                  &num_draft_tokens,
                  &d,
                  &threshold_single,
                  &capped_threshold_acc};
  DISPATCH_ALIGNED_VEC_SIZE(
      vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
        auto kernel = TreeSpeculativeSamplingTargetOnly<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
                                                        VEC_SIZE, DETERMINISTIC, DType, IdType>;
        FLASHINFER_CUDA_CALL(
            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        FLASHINFER_CUDA_CALL(
            cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
      })});
  return cudaSuccess;
}

}  // namespace sampling

}  // namespace flashinfer

#endif  // SPECULATIVE_SAMPLING_CUH_
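
// Illustrative usage sketch (comment only). Buffer names below are hypothetical device pointers;
// the shapes are inferred from the indexing in the kernel above and may need adjustment for a
// particular caller.
//
//   using namespace flashinfer::sampling;
//   // predicts, candidates, retrive_index, retrive_next_token, retrive_next_sibling:
//   //     [batch_size * num_draft_tokens]          (int32_t)
//   // accept_index:        [batch_size * num_speculative_tokens]   (int32_t, mutable)
//   // accept_token_num:    [batch_size]                            (int32_t, mutable)
//   // uniform_samples:     [batch_size * num_draft_tokens]         (float in [0, 1))
//   // target_probs, draft_probs: [batch_size * num_draft_tokens * d] (float)
//   cudaError_t status = TreeSpeculativeSamplingTargetOnly<float, int32_t>(
//       predicts, accept_index, accept_token_num, candidates, retrive_index, retrive_next_token,
//       retrive_next_sibling, uniform_samples, target_probs, draft_probs,
//       /*batch_size=*/2, /*num_speculative_tokens=*/3, /*num_draft_tokens=*/4, /*d=*/32000,
//       /*threshold_single=*/1.0f, /*threshold_acc=*/1.0f, /*deterministic=*/true, stream);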