#include "pytorch_extension_utils.h" template struct ConvertToFP8 { static __device__ __nv_fp8_storage_t convert_to_fp8(T value) { return 0; } }; template <> struct ConvertToFP8<__nv_bfloat16> { static __device__ __nv_fp8_storage_t convert_to_fp8(__nv_bfloat16 value) { return __nv_cvt_bfloat16raw_to_fp8(value, __NV_SATFINITE, __NV_E4M3); } }; template <> struct ConvertToFP8 { static __device__ __nv_fp8_storage_t convert_to_fp8(half value) { return __nv_cvt_halfraw_to_fp8(value, __NV_SATFINITE, __NV_E4M3); } }; template struct ConvertFromFloat { static __device__ T convert_from_float(float value) { return 0; } }; template <> struct ConvertFromFloat<__nv_bfloat16> { static __device__ __nv_bfloat16 convert_from_float(float value) { return __float2bfloat16(value); } }; template <> struct ConvertFromFloat { static __device__ half convert_from_float(float value) { return __float2half(value); } }; template __global__ void fused_downcast_kernel( const T* cache_k, const T* cache_v, const float* k_scale, const float* v_scale, __nv_fp8_storage_t* output_k, __nv_fp8_storage_t* output_v, const int input_sl, const int head, const int dim, const T max_fp8, const T min_fp8, const int64_t mult, const int64_t offset, const int64_t* loc) { // TODO: change name int token_idx = blockIdx.x; int thread_idx = threadIdx.x; int total_threads = blockDim.x; T k_scale_val = ConvertFromFloat::convert_from_float(k_scale[0]); T v_scale_val = ConvertFromFloat::convert_from_float(v_scale[0]); T k_scale_inv = static_cast(1.f) / k_scale_val; T v_scale_inv = static_cast(1.f) / v_scale_val; auto clamp = [&](T val) { return val > max_fp8 ? max_fp8 : (min_fp8 > val ? min_fp8 : val); }; if (token_idx < input_sl) { int out_seq_idx = loc[token_idx]; #pragma unroll for (int i = thread_idx; i < head * dim; i += total_threads) { int in_idx = token_idx * head * dim + i; int out_idx = (out_seq_idx * mult + offset) * head * dim + i; T k_val = cache_k[in_idx] * k_scale_inv; k_val = clamp(k_val); output_k[out_idx] = ConvertToFP8::convert_to_fp8(k_val); T v_val = cache_v[in_idx] * v_scale_inv; v_val = clamp(v_val); output_v[out_idx] = ConvertToFP8::convert_to_fp8(v_val); } } } template void downcast_fp8_impl( at::Tensor& k, at::Tensor& v, at::Tensor& k_out, at::Tensor& v_out, at::Tensor& k_scale, at::Tensor& v_scale, at::Tensor& loc, int64_t mult, int64_t offset, cudaStream_t stream) { CHECK_INPUT(k); CHECK_INPUT(v); CHECK_INPUT(k_out); CHECK_INPUT(v_out); CHECK_INPUT(k_scale); CHECK_INPUT(v_scale); CHECK_INPUT(loc); int64_t input_sl = k.size(0); int64_t head = k.size(1); int64_t dim = k.size(2); dim3 grid(input_sl * head); int vec_size = 8; dim3 block(std::min(int(dim) / vec_size, 1024)); const T max_fp8 = static_cast(448.0f); const T min_fp8 = static_cast(-448.0f); fused_downcast_kernel<<>>( static_cast(k.data_ptr()), static_cast(v.data_ptr()), static_cast(k_scale.data_ptr()), static_cast(v_scale.data_ptr()), static_cast<__nv_fp8_storage_t*>(k_out.data_ptr()), static_cast<__nv_fp8_storage_t*>(v_out.data_ptr()), input_sl, head, dim, max_fp8, min_fp8, mult, offset, static_cast(loc.data_ptr())); cudaError_t status = cudaGetLastError(); TORCH_CHECK(status == cudaSuccess, "Kernel launch failed: " + std::string(cudaGetErrorString(status))); } void downcast_fp8( at::Tensor& k, at::Tensor& v, at::Tensor& k_out, at::Tensor& v_out, at::Tensor& k_scale, at::Tensor& v_scale, at::Tensor& loc, int64_t mult, int64_t offset, int64_t cuda_stream) { CHECK_INPUT(k); CHECK_INPUT(v); CHECK_INPUT(k_out); CHECK_INPUT(v_out); cudaStream_t stream = reinterpret_cast(cuda_stream); switch (k.scalar_type()) { case at::ScalarType::BFloat16: downcast_fp8_impl<__nv_bfloat16>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream); break; case at::ScalarType::Half: downcast_fp8_impl<__half>(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream); break; default: TORCH_CHECK(false, "Unsupported input type for downcast_fp8. Expected bfloat16 or float16."); } }