// sglang_v0.5.2/flashinfer_0.3.1/tvm_binding/sampling.cu

#include <flashinfer/attention/hopper/attention_updater.cuh>
#include <flashinfer/attention/hopper/variant_helper.cuh>
#include <flashinfer/cutlass_utils.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/math.cuh>
#include <flashinfer/sampling.cuh>
#include "tvm_binding_utils.h"
using namespace flashinfer;
// TODO: change the philox seeds and offsets to DLTensor once the underlying API for sampling
// changes to support multiple seeds
/*!
 * \brief TVM/DLPack entry point that forwards to flashinfer::sampling::SamplingFromProb.
 *
 * \param probs          Probability tensor; must be 2-D. shape[1] is used as the vocabulary
 *                       size. Its data is reinterpreted as float (no dtype check is done here).
 * \param output         Destination tensor; its data is reinterpreted as int. shape[0] is used
 *                       as the batch size passed to the kernel.
 * \param maybe_indices  Optional tensor (may be nullptr), reinterpreted as int and forwarded.
 *                       NOTE(review): presumably maps each output row to a row of `probs` —
 *                       confirm against flashinfer/sampling.cuh.
 * \param deterministic  Forwarded unchanged to SamplingFromProb.
 * \param philox_seed    Philox RNG seed, forwarded unchanged.
 * \param philox_offset  Philox RNG offset, forwarded unchanged.
 * \param cuda_stream    cudaStream_t passed as an integer handle; work is enqueued on it.
 *
 * Aborts via CHECK on unexpected tensor rank or if SamplingFromProb returns an error.
 * NOTE(review): probs is assumed float32 and output/maybe_indices int32 — consider
 * validating DLTensor::dtype before reinterpreting the buffers.
 */
void SamplingFromProbs(DLTensor* probs, DLTensor* output, DLTensor* maybe_indices,
                       bool deterministic, uint64_t philox_seed, uint64_t philox_offset,
                       int64_t cuda_stream) {
  CHECK(probs->ndim == 2) << "Probs should have 2 dimensions";
  // output->shape[0] is read below, so the tensor needs at least one dimension.
  CHECK(output->ndim >= 1) << "Output should have at least 1 dimension";

  unsigned int batch_size = output->shape[0];
  unsigned int vocab_size = probs->shape[1];
  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);

  // DLTensor::byte_offset is measured in BYTES. It must be added to a byte pointer
  // before casting to the element type; adding it to a float*/int* would advance the
  // pointer by byte_offset * sizeof(T) bytes instead.
  auto byte_base = [](DLTensor* t) -> char* {
    return static_cast<char*>(t->data) + t->byte_offset;
  };

  float* probs_cast = reinterpret_cast<float*>(byte_base(probs));
  int* output_cast = reinterpret_cast<int*>(byte_base(output));
  int* maybe_indices_cast =
      maybe_indices ? reinterpret_cast<int*>(byte_base(maybe_indices)) : nullptr;

  cudaError_t status =
      sampling::SamplingFromProb(probs_cast, output_cast, maybe_indices_cast, batch_size,
                                 vocab_size, deterministic, philox_seed, philox_offset, stream);
  CHECK(status == cudaSuccess) << "SamplingFromProbs failed with error "
                               << cudaGetErrorString(status);
}