/* * Copyright (c) 2024 by FlashInfer team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include "pytorch_extension_utils.h" using namespace flashinfer; void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_p_arr, double top_p_val) { CHECK_INPUT(probs); auto device = probs.device(); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) unsigned int batch_size = probs.size(0); unsigned int vocab_size = probs.size(1); bool has_top_p_arr = maybe_top_p_arr.has_value(); const c10::cuda::OptionalCUDAGuard device_guard(device); auto stream = at::cuda::getCurrentCUDAStream(); cudaError_t status = sampling::TopPRenormProb( static_cast(probs.data_ptr()), static_cast(renorm_probs.data_ptr()), has_top_p_arr ? static_cast(maybe_top_p_arr->data_ptr()) : nullptr, batch_size, top_p_val, vocab_size, stream); TORCH_CHECK(status == cudaSuccess, "TopPRenormProb failed with error code " + std::string(cudaGetErrorString(status))); } void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_k_arr, int64_t top_k_val) { CHECK_INPUT(probs); auto device = probs.device(); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) unsigned int batch_size = probs.size(0); unsigned int vocab_size = probs.size(1); bool has_top_k_arr = maybe_top_k_arr.has_value(); const c10::cuda::OptionalCUDAGuard device_guard(device); auto stream = at::cuda::getCurrentCUDAStream(); cudaError_t status = sampling::TopKRenormProb( static_cast(probs.data_ptr()), static_cast(renorm_probs.data_ptr()), has_top_k_arr ? static_cast(maybe_top_k_arr->data_ptr()) : nullptr, batch_size, top_k_val, vocab_size, stream); TORCH_CHECK(status == cudaSuccess, "TopKRenormProb failed with error code " + std::string(cudaGetErrorString(status))); } void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits, std::optional maybe_top_k_arr, int64_t top_k_val) { CHECK_INPUT(logits); auto device = logits.device(); CHECK_DIM(2, logits); // logits: (batch_size, vocab_size) unsigned int batch_size = logits.size(0); unsigned int vocab_size = logits.size(1); bool has_top_k_arr = maybe_top_k_arr.has_value(); const c10::cuda::OptionalCUDAGuard device_guard(device); auto stream = at::cuda::getCurrentCUDAStream(); cudaError_t status = sampling::TopKMaskLogits( static_cast(logits.data_ptr()), static_cast(mask_logits.data_ptr()), has_top_k_arr ? static_cast(maybe_top_k_arr->data_ptr()) : nullptr, batch_size, top_k_val, vocab_size, stream); TORCH_CHECK(status == cudaSuccess, "TopKMaskLogits failed with error code " + std::string(cudaGetErrorString(status))); }