// Adapted from // https://github.com/mlc-ai/xgrammar/blob/v0.1.18/python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.cu /* * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // clang-format off #include #include #include #include #include #if !defined(CUDA_VERSION) || CUDA_VERSION < 12040 void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional indices = at::nullopt) { TORCH_CHECK(false, "CUDA version must be >= 12.4 for ApplyTokenBitmaskInplace"); } #else #ifndef CUDART_INF_FP16 #define CUDART_INF_FP16 __ushort_as_half((unsigned short)0x7C00U) #endif #ifndef CUDART_INF_BF16 #define CUDART_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U) #endif constexpr int32_t BITS_PER_BLOCK = 32; constexpr int32_t THREADS_PER_THREAD_BLOCK = 256; template __device__ T NegativeInfinity() { return -INFINITY; } template <> __device__ __half NegativeInfinity<__half>() { return -CUDART_INF_FP16; } template <> __device__ __nv_bfloat16 NegativeInfinity<__nv_bfloat16>() { return -CUDART_INF_BF16; } template __device__ PackedT PackedNegativeInfinity() { constexpr int kAlignment = sizeof(PackedT) / sizeof(T); T packed[kAlignment]; #pragma unroll for (int i = 0; i < kAlignment; i++) { packed[i] = NegativeInfinity(); } return *reinterpret_cast(packed); } template __global__ void __launch_bounds__(THREADS_PER_THREAD_BLOCK) LogitsBitmaskKernel( T* __restrict__ logits, const int32_t* __restrict__ bitmask, const int32_t* __restrict__ indices, int32_t vocab_size, int32_t logits_stride, int32_t bitmask_stride) { constexpr int kAlignment = sizeof(PackedT) / sizeof(T); constexpr uint32_t kPackedMask = (1 << kAlignment) - 1; const int batch_idx = (indices == nullptr) ? blockIdx.y : indices[blockIdx.y]; const int block_offset = blockIdx.x * THREADS_PER_THREAD_BLOCK * kBitsPerThread; T* logits_gmem_ptr = logits + batch_idx * logits_stride + block_offset; const int32_t* bitmask_gmem_ptr = bitmask + batch_idx * bitmask_stride + block_offset / BITS_PER_BLOCK; const int bitmask_inner_idx = threadIdx.x % (BITS_PER_BLOCK / kAlignment); T logits_reg[kAlignment]; #pragma unroll for (int offset = threadIdx.x * kAlignment; offset < THREADS_PER_THREAD_BLOCK * kBitsPerThread; offset += THREADS_PER_THREAD_BLOCK * kAlignment) { if (block_offset + offset >= vocab_size) { break; } const uint32_t bitmask_val = (~bitmask_gmem_ptr[offset / BITS_PER_BLOCK] >> (bitmask_inner_idx * kAlignment)) & kPackedMask; if (bitmask_val == 0) { continue; } if (bitmask_val == kPackedMask) { *reinterpret_cast(logits_gmem_ptr + offset) = PackedNegativeInfinity(); continue; } *reinterpret_cast(logits_reg) = *reinterpret_cast(logits_gmem_ptr + offset); #pragma unroll for (int i = 0; i < kAlignment; i++) { if (((bitmask_val >> i) & 1)) { logits_reg[i] = NegativeInfinity(); } } *reinterpret_cast(logits_gmem_ptr + offset) = *reinterpret_cast(logits_reg); } } template ::value>> constexpr auto CeilDiv(T numerator, T denominator) { return (numerator + denominator - 1) / denominator; } template void ApplyTokenBitmaskInplaceDispatchToBitsPerThread( T* __restrict__ logits, const int32_t* __restrict__ bitmask, const int32_t* __restrict__ indices, int32_t vocab_size, int32_t logits_stride, int32_t bitmask_stride, int32_t num_rows) { constexpr int kAlignment = sizeof(PackedT) / sizeof(T); const int32_t num_blocks_per_row = CeilDiv(2048 / THREADS_PER_THREAD_BLOCK * 128, num_rows); const int32_t num_bits_per_thread = CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * num_blocks_per_row); const dim3 block(THREADS_PER_THREAD_BLOCK); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); if (num_bits_per_thread <= 4 && kAlignment <= 4) { const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 4), num_rows); LogitsBitmaskKernel <<>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); } else if (num_bits_per_thread <= 8 && kAlignment <= 8) { const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 8), num_rows); LogitsBitmaskKernel <<>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); } else if (num_bits_per_thread <= 16 && kAlignment <= 16) { const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 16), num_rows); LogitsBitmaskKernel <<>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); } else { const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 32), num_rows); LogitsBitmaskKernel <<>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); } } template void ApplyTokenBitmaskInplaceDispatchToPackedT( T* __restrict__ logits, const int32_t* __restrict__ bitmask, const int32_t* __restrict__ indices, int32_t vocab_size, int32_t logits_stride, int32_t bitmask_stride, int32_t num_rows) { if (logits_stride % (sizeof(float4) / sizeof(T)) == 0) { ApplyTokenBitmaskInplaceDispatchToBitsPerThread( logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows); } else { ApplyTokenBitmaskInplaceDispatchToBitsPerThread( logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows); } } void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional indices = at::nullopt) { TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor."); TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous."); TORCH_CHECK(logits.dim() == 1 || logits.dim() == 2, "logits must be a 1D or 2D tensor."); std::pair logits_shape = logits.dim() == 2 ? std::make_pair(static_cast(logits.size(0)), static_cast(logits.size(1))) : std::make_pair(1, static_cast(logits.size(0))); TORCH_CHECK(bitmask.is_cuda(), "bitmask must be a CUDA tensor."); TORCH_CHECK(bitmask.is_contiguous(), "bitmask must be contiguous."); TORCH_CHECK(bitmask.dim() == 1 || bitmask.dim() == 2, "bitmask must be a 1D or 2D tensor."); std::pair bitmask_shape = bitmask.dim() == 2 ? std::make_pair(static_cast(bitmask.size(0)), static_cast(bitmask.size(1))) : std::make_pair(1, static_cast(bitmask.size(0))); TORCH_CHECK(bitmask.dtype() == torch::kInt32, "bitmask must be of type int32."); TORCH_CHECK( (logits_shape.second + BITS_PER_BLOCK - 1) / BITS_PER_BLOCK >= bitmask_shape.second, "The provided logits's vocab size should be no less than the bitmask's vocab size " "(converted from bitmask size). But got vocab size ", logits_shape.second, " vs bitmask size ", bitmask_shape.second); int vocab_size = std::min(logits_shape.second, bitmask_shape.second * BITS_PER_BLOCK); int32_t num_rows = logits_shape.first; int32_t* indices_ptr = nullptr; if (indices) { TORCH_CHECK(indices->is_cuda(), "indices must be a CUDA tensor."); TORCH_CHECK(indices->is_contiguous(), "indices must be contiguous."); TORCH_CHECK(indices->dim() == 1, "indices must be a 1D tensor."); TORCH_CHECK(indices->dtype() == torch::kInt32, "indices must be of type int32."); num_rows = indices->size(0); indices_ptr = indices->data_ptr(); } else { TORCH_CHECK(logits_shape.first == bitmask_shape.first, "logits and bitmask must have the same batch size."); } switch (logits.scalar_type()) { case torch::kFloat32: { ApplyTokenBitmaskInplaceDispatchToPackedT( logits.data_ptr(), bitmask.data_ptr(), indices_ptr, vocab_size, logits_shape.second, bitmask_shape.second, num_rows); break; } case torch::kFloat16: { ApplyTokenBitmaskInplaceDispatchToPackedT( reinterpret_cast<__half*>(logits.data_ptr()), bitmask.data_ptr(), indices_ptr, vocab_size, logits_shape.second, bitmask_shape.second, num_rows); break; } case torch::kBFloat16: { ApplyTokenBitmaskInplaceDispatchToPackedT( reinterpret_cast<__nv_bfloat16*>(logits.data_ptr()), bitmask.data_ptr(), indices_ptr, vocab_size, logits_shape.second, bitmask_shape.second, num_rows); break; } default: TORCH_CHECK(false, "logits dtype must be float, half or bfloat16."); break; } } #endif // clang-format on