/*
 * Copyright (c) 2025 by SGLang team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#ifndef USE_ROCM
#include "pytorch_extension_utils.h"
#else
#include "pytorch_extension_utils_rocm.h"
#endif

// Shapes of the kernel arguments:
//   parent_list          [bs, topk * (depth - 1) + 1]
//   selected_index       [bs, draft_token_num - 1]
//   verified_seq_len     [bs]
//   tree_mask            [draft_token * (seq_len[0] + draft_token) | draft_token * (seq_len[1] + draft_token) | ...]
//                        = [sum(verified_seq_len) * draft_token + bs * draft_token * draft_token]
//   positions            [bs * draft_token]
//   retrive_index        [bs, draft_token]
//   retrive_next_token   [bs, draft_token]
//   retrive_next_sibling [bs, draft_token]
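//
// Illustrative sizing example (values assumed for illustration, not taken from the
// original source): with bs = 1, verified_seq_len = [4] and draft_token_num = 3,
// each draft token owns a mask row of seq_len + draft_token_num = 7 booleans, so
// tree_mask holds 3 * 7 = 21 entries for this sequence. The kernel only writes the
// trailing draft-token columns of each row; the prefix columns (and the -1 fill of
// retrive_next_token / retrive_next_sibling) are assumed to be initialized by the
// caller before launch.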
__global__ void build_tree_efficient(
    int64_t* parent_list,
    int64_t* selected_index,
    int32_t* verified_seq_len,
    bool* tree_mask,
    int64_t* positions,
    int64_t* retrive_index,
    int64_t* retrive_next_token,
    int64_t* retrive_next_sibling,
    int topk,
    int depth,
    int draft_token_num) {
  int bid = blockIdx.x;
  int tid = threadIdx.x;

  if (tid >= draft_token_num) {
    return;
  }

  // Locate this sequence's slice of the flattened tree_mask, then this thread's row.
  int seq_tree_idx = draft_token_num * draft_token_num * bid;
  for (int i = 0; i < bid; i++) {
    seq_tree_idx += verified_seq_len[i] * draft_token_num;
  }
  int seq_len = verified_seq_len[bid];
  int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
  for (int i = 0; i < draft_token_num - 1; i++) {
    tree_mask[token_tree_idx + i] = false;
  }

  int position = 0;
  if (tid == 0) {
    // Thread 0 builds the retrieval structure: retrive_next_token links a node to its
    // first child and retrive_next_sibling chains the remaining children.
    positions[bid * draft_token_num] = seq_len;

    int retrive_index_offset = bid * draft_token_num;
    for (int i = draft_token_num - 1; i > 0; --i) {
      int current_token_idx = retrive_index_offset + i;
      retrive_index[bid * draft_token_num + i] = current_token_idx;
      int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk;
      int parent_position = 0;
      if (parent_tb_idx > 0) {
        int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
        for (; parent_position < draft_token_num; ++parent_position) {
          if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) {
            ++parent_position;
            break;
          }
        }
      }
      if (parent_position == draft_token_num) {
        printf(
            "WARNING: invalid eagle tree!!! Detected a token with no parent token selected. "
            "Please check if the logprob has nan. The token will be ignored to keep proceeding.\n");
        continue;
      }

      if (retrive_next_token[bid * draft_token_num + parent_position] == -1) {
        retrive_next_token[bid * draft_token_num + parent_position] = i;
      } else {
        int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position];
        retrive_next_token[bid * draft_token_num + parent_position] = i;
        retrive_next_sibling[bid * draft_token_num + i] = origin_next_token;
      }
    }
    retrive_index[bid * draft_token_num] = bid * draft_token_num;
  } else {
    // Every other thread walks from its token up to the root, filling its row of the
    // attention mask and counting its depth to compute the position id.
    int cur_position = tid - 1;
    while (true) {
      position += 1;
      tree_mask[token_tree_idx + cur_position] = true;
      int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
      if (parent_tb_idx == 0) {
        break;
      }

      int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
      for (cur_position = 0; cur_position < draft_token_num; ++cur_position) {
        if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
          break;
        }
      }
    }
    positions[bid * draft_token_num + tid] = position + seq_len;
  }
}

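// Worked example of the retrieval structure (assumed tree, for illustration only):
// suppose draft token 0 is the root, tokens 1 and 3 both select it as parent, and
// token 2 is the child of token 1. Because thread 0 scans i from high to low, token 3
// is linked first and then displaced by token 1, giving retrive_next_token[0] = 1,
// retrive_next_sibling[1] = 3, and retrive_next_token[1] = 2. A later traversal
// therefore visits children in ascending index order via next_token followed by the
// next_sibling chain.
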
void build_tree_kernel_efficient(
    at::Tensor parent_list,
    at::Tensor selected_index,
    at::Tensor verified_seq_len,
    at::Tensor tree_mask,
    at::Tensor positions,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    int64_t topk,
    int64_t depth,
    int64_t draft_token_num) {
  // TODO (ying) check shape
  // TODO (ying) check type
  int bs = parent_list.size(0);
  dim3 grid(bs);
  dim3 block(draft_token_num);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  build_tree_efficient<<<grid, block, 0, stream>>>(
      static_cast<int64_t*>(parent_list.data_ptr()),
      static_cast<int64_t*>(selected_index.data_ptr()),
      static_cast<int32_t*>(verified_seq_len.data_ptr()),
      static_cast<bool*>(tree_mask.data_ptr()),
      static_cast<int64_t*>(positions.data_ptr()),
      static_cast<int64_t*>(retrive_index.data_ptr()),
      static_cast<int64_t*>(retrive_next_token.data_ptr()),
      static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
      int32_t(topk),
      int32_t(depth),
      int32_t(draft_token_num));
}

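// A minimal caller-side sketch (illustrative only; tensor names, shapes, and the
// mask_numel size are assumptions, not this project's actual call site). One block
// is launched per sequence with draft_token_num threads, and the outputs are
// expected to be pre-filled before the launch:
//
//   auto opts_i64 = parent_list.options().dtype(at::kLong);
//   auto positions = at::zeros({bs, draft_token_num}, opts_i64);
//   auto retrive_index = at::zeros({bs, draft_token_num}, opts_i64);
//   auto retrive_next_token = at::full({bs, draft_token_num}, -1, opts_i64);
//   auto retrive_next_sibling = at::full({bs, draft_token_num}, -1, opts_i64);
//   auto tree_mask = at::ones({mask_numel}, parent_list.options().dtype(at::kBool));
//   build_tree_kernel_efficient(
//       parent_list, selected_index, verified_seq_len, tree_mask, positions,
//       retrive_index, retrive_next_token, retrive_next_sibling, topk, depth,
//       draft_token_num);
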
template <typename IdType>
__global__ void VerifyTreeGreedy(
    IdType* predicts,
    IdType* accept_index,
    IdType* accept_token_num,  // mutable
    IdType* candidates,
    IdType* retrive_index,
    IdType* retrive_next_token,
    IdType* retrive_next_sibling,
    IdType* target_predict,
    uint32_t batch_size,
    uint32_t num_speculative_tokens,
    uint32_t num_draft_tokens) {
  uint32_t bx = blockIdx.x;

  IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
  accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
  uint32_t num_accepted_tokens = 0;
  IdType cur_index = 0;

  // Walk the tree level by level: at each accepted node, scan its children (linked
  // through retrive_next_token / retrive_next_sibling) for one whose draft token
  // matches the target model's prediction at the accepted node.
  for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
    cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
    while (cur_index != -1) {
      IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
      IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
      IdType target_token_id = target_predict[last_accepted_retrive_idx];

      if (draft_token_id == target_token_id) {
        // accept token
        predicts[last_accepted_retrive_idx] = target_token_id;
        ++num_accepted_tokens;
        accept_index[bx * num_speculative_tokens + num_accepted_tokens] = draft_index;
        last_accepted_retrive_idx = draft_index;
        break;
      } else {
        cur_index = retrive_next_sibling[bx * num_draft_tokens + cur_index];
      }
    }
    if (cur_index == -1) break;
  }
  accept_token_num[bx] = num_accepted_tokens;
  // Bonus token: the target model's prediction at the last accepted position.
  predicts[last_accepted_retrive_idx] = target_predict[last_accepted_retrive_idx];
}

// predicts:             [tot_num_draft_tokens]
// accept_index:         [bs, num_spec_step]
// accept_token_num:     [bs]
// candidates:           [bs, num_draft_tokens]
// retrive_index:        [bs, num_draft_tokens]
// retrive_next_token:   [bs, num_draft_tokens]
// retrive_next_sibling: [bs, num_draft_tokens]
// target_predict:       [bs, num_draft_tokens]
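//
// Worked example (assumed values, for illustration only): take one sequence with
// num_draft_tokens = 4, num_spec_step = 3, candidates = [a, b, c, d],
// retrive_next_token = [1, 2, -1, -1], retrive_next_sibling = [-1, 3, -1, -1], so the
// root a has children b and d, and b has child c. If the target model predicts b at
// the root position and d at b's position, the kernel accepts b (it matches), then
// scans b's only child c, finds no match, and stops with accept_token_num = 1; the
// target prediction at b's position (d) is written into predicts as the bonus token.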
void verify_tree_greedy(
    at::Tensor predicts,
    at::Tensor accept_index,
    at::Tensor accept_token_num,  // mutable
    at::Tensor candidates,
    at::Tensor retrive_index,
    at::Tensor retrive_next_token,
    at::Tensor retrive_next_sibling,
    at::Tensor target_predict,
    int64_t cuda_stream = 0) {
  CHECK_INPUT(candidates);
  CHECK_INPUT(retrive_index);
  CHECK_INPUT(retrive_next_token);
  CHECK_INPUT(retrive_next_sibling);
  CHECK_INPUT(target_predict);
  auto device = target_predict.device();
  CHECK_EQ(candidates.device(), device);
  CHECK_EQ(retrive_index.device(), device);
  CHECK_EQ(retrive_next_token.device(), device);
  CHECK_EQ(retrive_next_sibling.device(), device);
  CHECK_EQ(target_predict.device(), device);
  CHECK_DIM(1, predicts);
  CHECK_DIM(2, accept_index);
  CHECK_DIM(1, accept_token_num);
  CHECK_DIM(2, candidates);
  CHECK_DIM(2, retrive_index);
  CHECK_DIM(2, retrive_next_token);
  CHECK_DIM(2, retrive_next_sibling);
  CHECK_DIM(2, target_predict);
  unsigned int batch_size = candidates.size(0);
  unsigned int num_spec_step = accept_index.size(1);
  unsigned int num_draft_tokens = candidates.size(1);
  CHECK_EQ(batch_size, accept_index.size(0));
  CHECK_EQ(batch_size, accept_token_num.size(0));
  CHECK_EQ(batch_size, retrive_index.size(0));
  CHECK_EQ(batch_size, retrive_next_token.size(0));
  CHECK_EQ(batch_size, retrive_next_sibling.size(0));
  CHECK_EQ(batch_size, target_predict.size(0));
  CHECK_EQ(num_draft_tokens, retrive_index.size(1));
  CHECK_EQ(num_draft_tokens, retrive_next_token.size(1));
  CHECK_EQ(num_draft_tokens, retrive_next_sibling.size(1));
  CHECK_EQ(num_draft_tokens, target_predict.size(1));
  if (predicts.scalar_type() != at::kInt) {
    throw std::runtime_error("Expected 'predicts' to be of type int (torch.int32).");
  }
  if (accept_index.scalar_type() != at::kInt) {
    throw std::runtime_error("Expected 'accept_index' to be of type int (torch.int32).");
  }
  if (accept_token_num.scalar_type() != at::kInt) {
    throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
  }
  if (candidates.scalar_type() != at::kInt) {
    throw std::runtime_error("Expected 'candidates' to be of type int (torch.int32).");
  }
  if (retrive_index.scalar_type() != at::kInt) {
    throw std::runtime_error("Expected 'retrive_index' to be of type int (torch.int32).");
  }
  if (retrive_next_token.scalar_type() != at::kInt) {
    throw std::runtime_error("Expected 'retrive_next_token' to be of type int (torch.int32).");
  }
  if (retrive_next_sibling.scalar_type() != at::kInt) {
    throw std::runtime_error("Expected 'retrive_next_sibling' to be of type int (torch.int32).");
  }
  if (target_predict.scalar_type() != at::kInt) {
    throw std::runtime_error("Expected 'target_predict' to be of type int (torch.int32).");
  }

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  // One thread per sequence: the verification walk is sequential within a sequence.
  dim3 grid(batch_size);
  dim3 block(1);

  VerifyTreeGreedy<int><<<grid, block, 0, stream>>>(
      static_cast<int*>(predicts.data_ptr()),
      static_cast<int*>(accept_index.data_ptr()),
      static_cast<int*>(accept_token_num.data_ptr()),
      static_cast<int*>(candidates.data_ptr()),
      static_cast<int*>(retrive_index.data_ptr()),
      static_cast<int*>(retrive_next_token.data_ptr()),
      static_cast<int*>(retrive_next_sibling.data_ptr()),
      static_cast<int*>(target_predict.data_ptr()),
      batch_size,
      num_spec_step,
      num_draft_tokens);
}
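
// How these entry points are typically exposed to Python (illustrative sketch only;
// the module name and registration file are assumptions, not taken from this source).
// A standard pybind11 extension binding would look like:
//
//   #include <torch/extension.h>
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("build_tree_kernel_efficient", &build_tree_kernel_efficient,
//           "Build the EAGLE speculation tree mask and retrieval indices (CUDA)");
//     m.def("verify_tree_greedy", &verify_tree_greedy,
//           "Greedy verification of a draft token tree (CUDA)");
//   }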