// sglang_v0.5.2/flashinfer_0.3.1/csrc/renorm.cu
//
// File-viewer metadata: 80 lines, 3.4 KiB, Plaintext

/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/sampling.cuh>
#include "pytorch_extension_utils.h"
using namespace flashinfer;
// Validates the tensors and launches sampling::TopPRenormProb<float> on the current CUDA stream
// of probs' device.
//
// probs:           CUDA, contiguous, 2-D float32 tensor, read as (batch_size, vocab_size).
// renorm_probs:    output buffer; it is written through a raw float pointer, so it must be a
//                  CUDA, contiguous float32 tensor on the same device and of the same shape as
//                  probs.
// maybe_top_p_arr: optional float32 array; when present its data pointer is handed to the
//                  launcher, otherwise nullptr is passed. It is required to hold batch_size
//                  elements (one per row, per the FlashInfer Python API contract).
// top_p_val:       scalar forwarded to the launcher next to the (possibly null) array pointer
//                  (presumably the threshold used when no array is given -- confirm against
//                  sampling.cuh).
//
// Raises (through TORCH_CHECK) when a tensor fails validation or when the launcher returns a
// status other than cudaSuccess.
void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
                        std::optional<at::Tensor> maybe_top_p_arr, double top_p_val) {
  CHECK_INPUT(probs);
  CHECK_INPUT(renorm_probs);
  auto device = probs.device();
  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)

  // The raw pointers below are reinterpreted as float*, so any other dtype would be misread.
  TORCH_CHECK(probs.scalar_type() == at::kFloat, "probs must be a float32 tensor");
  TORCH_CHECK(renorm_probs.scalar_type() == at::kFloat, "renorm_probs must be a float32 tensor");
  TORCH_CHECK(renorm_probs.device() == device,
              "renorm_probs must be on the same device as probs");
  TORCH_CHECK(renorm_probs.sizes() == probs.sizes(),
              "renorm_probs must have the same shape as probs");

  unsigned int batch_size = probs.size(0);
  unsigned int vocab_size = probs.size(1);

  float* top_p_arr = nullptr;
  if (maybe_top_p_arr.has_value()) {
    const at::Tensor& top_p_tensor = maybe_top_p_arr.value();
    TORCH_CHECK(top_p_tensor.is_cuda(), "maybe_top_p_arr must be a CUDA tensor");
    TORCH_CHECK(top_p_tensor.device() == device,
                "maybe_top_p_arr must be on the same device as probs");
    TORCH_CHECK(top_p_tensor.is_contiguous(), "maybe_top_p_arr must be contiguous");
    TORCH_CHECK(top_p_tensor.scalar_type() == at::kFloat,
                "maybe_top_p_arr must be a float32 tensor");
    // A shorter buffer would be read past its end by the kernel.
    TORCH_CHECK(top_p_tensor.numel() == probs.size(0),
                "maybe_top_p_arr must contain one value per row of probs");
    top_p_arr = static_cast<float*>(top_p_tensor.data_ptr());
  }

  // Make probs' device current before querying the stream and launching.
  const c10::cuda::OptionalCUDAGuard device_guard(device);
  auto stream = at::cuda::getCurrentCUDAStream();
  cudaError_t status = sampling::TopPRenormProb<float>(
      static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()),
      top_p_arr, batch_size, top_p_val, vocab_size, stream);
  TORCH_CHECK(status == cudaSuccess,
              "TopPRenormProb failed with error code " + std::string(cudaGetErrorString(status)));
}
// Validates the tensors and launches sampling::TopKRenormProb<float> on the current CUDA stream
// of probs' device.
//
// probs:           CUDA, contiguous, 2-D float32 tensor, read as (batch_size, vocab_size).
// renorm_probs:    output buffer; it is written through a raw float pointer, so it must be a
//                  CUDA, contiguous float32 tensor on the same device and of the same shape as
//                  probs.
// maybe_top_k_arr: optional int32 array; when present its data pointer is handed to the
//                  launcher as int*, otherwise nullptr is passed. It is required to hold
//                  batch_size elements (one per row, per the FlashInfer Python API contract).
// top_k_val:       scalar forwarded to the launcher next to the (possibly null) array pointer
//                  (presumably the k used when no array is given -- confirm against
//                  sampling.cuh).
//
// Raises (through TORCH_CHECK) when a tensor fails validation or when the launcher returns a
// status other than cudaSuccess.
void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
                        std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val) {
  CHECK_INPUT(probs);
  CHECK_INPUT(renorm_probs);
  auto device = probs.device();
  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)

  // The raw pointers below are reinterpreted as float*, so any other dtype would be misread.
  TORCH_CHECK(probs.scalar_type() == at::kFloat, "probs must be a float32 tensor");
  TORCH_CHECK(renorm_probs.scalar_type() == at::kFloat, "renorm_probs must be a float32 tensor");
  TORCH_CHECK(renorm_probs.device() == device,
              "renorm_probs must be on the same device as probs");
  TORCH_CHECK(renorm_probs.sizes() == probs.sizes(),
              "renorm_probs must have the same shape as probs");

  unsigned int batch_size = probs.size(0);
  unsigned int vocab_size = probs.size(1);

  int* top_k_arr = nullptr;
  if (maybe_top_k_arr.has_value()) {
    const at::Tensor& top_k_tensor = maybe_top_k_arr.value();
    TORCH_CHECK(top_k_tensor.is_cuda(), "maybe_top_k_arr must be a CUDA tensor");
    TORCH_CHECK(top_k_tensor.device() == device,
                "maybe_top_k_arr must be on the same device as probs");
    TORCH_CHECK(top_k_tensor.is_contiguous(), "maybe_top_k_arr must be contiguous");
    // The pointer is reinterpreted as int*, so e.g. an int64 tensor would be misread.
    TORCH_CHECK(top_k_tensor.scalar_type() == at::kInt,
                "maybe_top_k_arr must be an int32 tensor");
    // A shorter buffer would be read past its end by the kernel.
    TORCH_CHECK(top_k_tensor.numel() == probs.size(0),
                "maybe_top_k_arr must contain one value per row of probs");
    top_k_arr = static_cast<int*>(top_k_tensor.data_ptr());
  }

  // Make probs' device current before querying the stream and launching.
  const c10::cuda::OptionalCUDAGuard device_guard(device);
  auto stream = at::cuda::getCurrentCUDAStream();
  cudaError_t status = sampling::TopKRenormProb<float>(
      static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()),
      top_k_arr, batch_size, top_k_val, vocab_size, stream);
  TORCH_CHECK(status == cudaSuccess,
              "TopKRenormProb failed with error code " + std::string(cudaGetErrorString(status)));
}
// Validates the tensors and launches sampling::TopKMaskLogits<float> on the current CUDA stream
// of logits' device.
//
// logits:          CUDA, contiguous, 2-D float32 tensor, read as (batch_size, vocab_size).
// mask_logits:     output buffer; it is written through a raw float pointer, so it must be a
//                  CUDA, contiguous float32 tensor on the same device and of the same shape as
//                  logits.
// maybe_top_k_arr: optional int32 array; when present its data pointer is handed to the
//                  launcher as int*, otherwise nullptr is passed. It is required to hold
//                  batch_size elements (one per row, per the FlashInfer Python API contract).
// top_k_val:       scalar forwarded to the launcher next to the (possibly null) array pointer
//                  (presumably the k used when no array is given -- confirm against
//                  sampling.cuh).
//
// Raises (through TORCH_CHECK) when a tensor fails validation or when the launcher returns a
// status other than cudaSuccess.
void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits,
                       std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val) {
  CHECK_INPUT(logits);
  CHECK_INPUT(mask_logits);
  auto device = logits.device();
  CHECK_DIM(2, logits);  // logits: (batch_size, vocab_size)

  // The raw pointers below are reinterpreted as float*, so any other dtype would be misread.
  TORCH_CHECK(logits.scalar_type() == at::kFloat, "logits must be a float32 tensor");
  TORCH_CHECK(mask_logits.scalar_type() == at::kFloat, "mask_logits must be a float32 tensor");
  TORCH_CHECK(mask_logits.device() == device,
              "mask_logits must be on the same device as logits");
  TORCH_CHECK(mask_logits.sizes() == logits.sizes(),
              "mask_logits must have the same shape as logits");

  unsigned int batch_size = logits.size(0);
  unsigned int vocab_size = logits.size(1);

  int* top_k_arr = nullptr;
  if (maybe_top_k_arr.has_value()) {
    const at::Tensor& top_k_tensor = maybe_top_k_arr.value();
    TORCH_CHECK(top_k_tensor.is_cuda(), "maybe_top_k_arr must be a CUDA tensor");
    TORCH_CHECK(top_k_tensor.device() == device,
                "maybe_top_k_arr must be on the same device as logits");
    TORCH_CHECK(top_k_tensor.is_contiguous(), "maybe_top_k_arr must be contiguous");
    // The pointer is reinterpreted as int*, so e.g. an int64 tensor would be misread.
    TORCH_CHECK(top_k_tensor.scalar_type() == at::kInt,
                "maybe_top_k_arr must be an int32 tensor");
    // A shorter buffer would be read past its end by the kernel.
    TORCH_CHECK(top_k_tensor.numel() == logits.size(0),
                "maybe_top_k_arr must contain one value per row of logits");
    top_k_arr = static_cast<int*>(top_k_tensor.data_ptr());
  }

  // Make logits' device current before querying the stream and launching.
  const c10::cuda::OptionalCUDAGuard device_guard(device);
  auto stream = at::cuda::getCurrentCUDAStream();
  cudaError_t status = sampling::TopKMaskLogits<float>(
      static_cast<float*>(logits.data_ptr()), static_cast<float*>(mask_logits.data_ptr()),
      top_k_arr, batch_size, top_k_val, vocab_size, stream);
  TORCH_CHECK(status == cudaSuccess,
              "TopKMaskLogits failed with error code " + std::string(cudaGetErrorString(status)));
}