/*
 * Copyright (c) 2024 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <flashinfer/sampling.cuh>

// Stream and device-guard declarations used below; these may already be
// pulled in by pytorch_extension_utils.h, but are included explicitly so
// the file stands alone.
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "pytorch_extension_utils.h"

using namespace flashinfer;

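// Top-p (nucleus) renormalization: for each row of `probs`, zero out entries
// outside the smallest set whose cumulative probability reaches the top-p
// threshold, then rescale the survivors so the row sums to 1. The threshold
// is `top_p_val` for all rows, or per-row values from `maybe_top_p_arr`.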
void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
                        std::optional<at::Tensor> maybe_top_p_arr, double top_p_val) {
  CHECK_INPUT(probs);
  auto device = probs.device();
  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
  unsigned int batch_size = probs.size(0);
  unsigned int vocab_size = probs.size(1);
  bool has_top_p_arr = maybe_top_p_arr.has_value();

  const c10::cuda::OptionalCUDAGuard device_guard(device);
  auto stream = at::cuda::getCurrentCUDAStream();
  cudaError_t status = sampling::TopPRenormProb<float>(
      static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()),
      has_top_p_arr ? static_cast<float*>(maybe_top_p_arr->data_ptr()) : nullptr, batch_size,
      top_p_val, vocab_size, stream);
  TORCH_CHECK(status == cudaSuccess,
              "TopPRenormProb failed with error code " + std::string(cudaGetErrorString(status)));
}
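
// Top-k renormalization: for each row of `probs`, keep only the k largest
// probabilities and rescale them so the row sums to 1. The k is `top_k_val`
// for all rows, or per-row values from `maybe_top_k_arr`.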
|
|
|
|
void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs,
|
|
std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val) {
|
|
CHECK_INPUT(probs);
|
|
auto device = probs.device();
|
|
CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
|
|
unsigned int batch_size = probs.size(0);
|
|
unsigned int vocab_size = probs.size(1);
|
|
bool has_top_k_arr = maybe_top_k_arr.has_value();
|
|
|
|
const c10::cuda::OptionalCUDAGuard device_guard(device);
|
|
auto stream = at::cuda::getCurrentCUDAStream();
|
|
cudaError_t status = sampling::TopKRenormProb<float>(
|
|
static_cast<float*>(probs.data_ptr()), static_cast<float*>(renorm_probs.data_ptr()),
|
|
has_top_k_arr ? static_cast<int*>(maybe_top_k_arr->data_ptr()) : nullptr, batch_size,
|
|
top_k_val, vocab_size, stream);
|
|
|
|
TORCH_CHECK(status == cudaSuccess,
|
|
"TopKRenormProb failed with error code " + std::string(cudaGetErrorString(status)));
|
|
}
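
// Top-k logit masking: for each row of `logits`, set every entry outside the
// k largest to -inf so that a subsequent softmax/sampling step ignores them.
// The k is `top_k_val` for all rows, or per-row values from `maybe_top_k_arr`.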
|
|
|
|
void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits,
|
|
std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val) {
|
|
CHECK_INPUT(logits);
|
|
auto device = logits.device();
|
|
CHECK_DIM(2, logits); // logits: (batch_size, vocab_size)
|
|
unsigned int batch_size = logits.size(0);
|
|
unsigned int vocab_size = logits.size(1);
|
|
bool has_top_k_arr = maybe_top_k_arr.has_value();
|
|
|
|
const c10::cuda::OptionalCUDAGuard device_guard(device);
|
|
auto stream = at::cuda::getCurrentCUDAStream();
|
|
cudaError_t status = sampling::TopKMaskLogits<float>(
|
|
static_cast<float*>(logits.data_ptr()), static_cast<float*>(mask_logits.data_ptr()),
|
|
has_top_k_arr ? static_cast<int*>(maybe_top_k_arr->data_ptr()) : nullptr, batch_size,
|
|
top_k_val, vocab_size, stream);
|
|
|
|
TORCH_CHECK(status == cudaSuccess,
|
|
"TopKMaskLogits failed with error code " + std::string(cudaGetErrorString(status)));
|
|
}
|
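
// --- Hypothetical binding sketch (not part of the upstream file) ---
// FlashInfer's build registers these entry points elsewhere; the sketch below
// shows one minimal way a standalone extension could expose them, assuming
// <torch/extension.h> is available (e.g., via pytorch_extension_utils.h) and
// TORCH_EXTENSION_NAME is defined by the build system.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("top_p_renorm_probs", &top_p_renorm_probs,
        "Renormalize probabilities by top-p thresholding (CUDA)");
  m.def("top_k_renorm_probs", &top_k_renorm_probs,
        "Renormalize probabilities by top-k thresholding (CUDA)");
  m.def("top_k_mask_logits", &top_k_mask_logits,
        "Mask logits outside the top-k entries to -inf (CUDA)");
}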