#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include "pytorch_extension_utils.h"

// Helper functions to convert between different data types
// (float, half, bfloat16) for the merge attention states kernel.
inline __device__ float to_float(float u) {
  return u;
}
inline __device__ float to_float(half u) {
  return __half2float(u);
}
inline __device__ float to_float(__nv_bfloat16 u) {
  return __bfloat162float(u);
}
inline __device__ void from_float(float& d, float s) {
  d = s;
}
inline __device__ void from_float(half& d, float s) {
  d = __float2half(s);
}
inline __device__ void from_float(__nv_bfloat16& d, float s) {
  d = __float2bfloat16(s);
}

// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
template <typename scalar_t, const uint NUM_THREADS>
__global__ void merge_attn_states_kernel(
    scalar_t* output,
    float* output_lse,
    const scalar_t* prefix_output,
    const float* prefix_lse,
    const scalar_t* suffix_output,
    const float* suffix_lse,
    const uint num_tokens,
    const uint num_heads,
    const uint head_size) {
  using pack_128b_t = uint4;
  const uint pack_size = 16 / sizeof(scalar_t);
  const uint threads_per_head = head_size / pack_size;

  const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
  const uint token_head_threads = num_tokens * num_heads * threads_per_head;

  if (global_idx >= token_head_threads) return;

  // global_idx -> token_idx + head_idx + pack_idx
  const uint token_head_idx = global_idx / threads_per_head;
  const uint pack_idx = global_idx % threads_per_head;

  const uint token_idx = token_head_idx / num_heads;
  const uint head_idx = token_head_idx % num_heads;

  const uint pack_offset = pack_idx * pack_size;  // (0~15)*8, etc.
  const uint head_offset = token_idx * num_heads * head_size + head_idx * head_size;
  const scalar_t* prefix_head_ptr = prefix_output + head_offset;
  const scalar_t* suffix_head_ptr = suffix_output + head_offset;
  scalar_t* output_head_ptr = output + head_offset;

  // float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
  // float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
  float p_lse = prefix_lse[token_idx * num_heads + head_idx];
  float s_lse = suffix_lse[token_idx * num_heads + head_idx];
  p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
  s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;

  const float max_lse = fmaxf(p_lse, s_lse);
  p_lse = p_lse - max_lse;
  s_lse = s_lse - max_lse;
  const float p_se = expf(p_lse);
  const float s_se = expf(s_lse);
  const float out_se = p_se + s_se;
  const float p_scale = p_se / out_se;
  const float s_scale = s_se / out_se;

  if (pack_offset < head_size) {
    // Pack 128b load
    pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(prefix_head_ptr)[pack_offset / pack_size];
    pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(suffix_head_ptr)[pack_offset / pack_size];
    pack_128b_t o_out_pack;

#pragma unroll
    for (uint i = 0; i < pack_size; ++i) {
      // Always use float for FMA to keep high precision.
      // half(uint16_t), bfloat16, float -> float.
      const float p_out_f = to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
      const float s_out_f = to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
      // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
      const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
      // float -> half(uint16_t), bfloat16, float.
      from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
    }

    // Pack 128b storage
    reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] = o_out_pack;
  }

  // We only need to write to output_lse once per head.
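  // The merged LSE recovers the unshifted value in a numerically stable way:
  // out_se = exp(p_lse - max_lse) + exp(s_lse - max_lse), so
  // logf(out_se) + max_lse == log(exp(p_lse) + exp(s_lse)), i.e. the
  // log-sum-exp of the combined prefix and suffix softmax denominators.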
  if (output_lse != nullptr && pack_idx == 0) {
    float out_lse = logf(out_se) + max_lse;
    output_lse[token_idx * num_heads + head_idx] = out_lse;
  }
}

// The following macro is used to dispatch the conversion function based on
// the output data type. The FN is a macro that calls a templated function.
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn)                      \
  {                                                                      \
    if (scalar_dtype == at::ScalarType::Float) {                        \
      fn(float);                                                         \
    } else if (scalar_dtype == at::ScalarType::Half) {                  \
      fn(half);                                                          \
    } else if (scalar_dtype == at::ScalarType::BFloat16) {              \
      fn(__nv_bfloat16);                                                 \
    } else {                                                             \
      TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
    }                                                                    \
  }

#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS)            \
  {                                                                 \
    merge_attn_states_kernel<scalar_t, NUM_THREADS>                \
        <<<grid, block, 0, stream>>>(                               \
            reinterpret_cast<scalar_t*>(output.data_ptr()),        \
            reinterpret_cast<float*>(output_lse.data_ptr()),       \
            reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
            reinterpret_cast<float*>(prefix_lse.data_ptr()),       \
            reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
            reinterpret_cast<float*>(suffix_lse.data_ptr()),       \
            num_tokens,                                             \
            num_heads,                                              \
            head_size);                                             \
  }

/**
 * @brief Merges the attention states from prefix and suffix
 * into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
 *
 * @param output [n,h,d] The output tensor to store the merged attention states.
 * @param output_lse [n,h] Optional tensor to store the log-sum-exp values.
 * @param prefix_output [n,h,d] The prefix attention states.
 * @param prefix_lse [n,h] The log-sum-exp values for the prefix attention states.
 * @param suffix_output [n,h,d] The suffix attention states.
 * @param suffix_lse [n,h] The log-sum-exp values for the suffix attention states.
 */
template <typename scalar_t>
void merge_attn_states_launcher(
    const at::Tensor& prefix_output,  // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    const at::Tensor& prefix_lse,     // [NUM_TOKENS, NUM_HEADS]
    const at::Tensor& suffix_output,  // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    const at::Tensor& suffix_lse,     // [NUM_TOKENS, NUM_HEADS]
    at::Tensor& output,               // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    at::Tensor& output_lse            // [NUM_TOKENS, NUM_HEADS]
) {
  constexpr uint NUM_THREADS = 128;
  const uint num_tokens = output.size(0);
  const uint num_heads = output.size(1);
  const uint head_size = output.size(2);
  const uint pack_size = 16 / sizeof(scalar_t);
  TORCH_CHECK(head_size % pack_size == 0, "head_size must be a multiple of pack_size: ", pack_size);
  // Process one pack of elements per thread. For float, pack_size is 4;
  // for half/bf16, pack_size is 8.
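  // Example (illustrative numbers, not a constraint of this kernel):
  // head_size = 128 with half/bf16 gives pack_size = 8 and
  // threads_per_head = 16, so each 128-thread block covers 8
  // (token, head) pairs.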
  const uint threads_per_head = head_size / pack_size;
  const uint total_threads = num_tokens * num_heads * threads_per_head;

  dim3 block(NUM_THREADS);
  dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);

  const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
  auto stream = at::cuda::getCurrentCUDAStream();

  LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
}

#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
  { merge_attn_states_launcher<scalar_t>(v_a, s_a, v_b, s_b, v_merged, s_merged); }

void merge_state_v2(
    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged) {
  // Input tensors must be contiguous
  CHECK_INPUT(v_a);  // v_a prefix_output (seq_len, num_heads, head_dim)
  CHECK_INPUT(s_a);  // s_a prefix_lse (seq_len, num_heads)
  CHECK_INPUT(v_b);  // v_b suffix_output (seq_len, num_heads, head_dim)
  CHECK_INPUT(s_b);  // s_b suffix_lse (seq_len, num_heads)
  // v_merged output (seq_len, num_heads, head_dim)
  // s_merged output_lse (seq_len, num_heads)
  auto device = v_a.device();
  CHECK_EQ(s_a.device(), device);
  CHECK_EQ(v_b.device(), device);
  CHECK_EQ(s_b.device(), device);
  CHECK_DIM(3, v_a);
  CHECK_DIM(2, s_a);
  CHECK_DIM(3, v_b);
  CHECK_DIM(2, s_b);
  CHECK_SHAPE(v_a, v_b);
  CHECK_SHAPE(s_a, s_b);
  CHECK_EQ(v_a.size(0), s_a.size(0));
  CHECK_EQ(v_a.size(1), s_b.size(1));
  DISPATCH_BY_SCALAR_DTYPE(v_merged.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
}
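// Usage sketch (illustrative only; the tensor names, shapes, and the half
// dtype below are assumptions, not part of this file): merging two partial
// attention results computed over disjoint KV ranges for the same queries.
//
//   const int64_t n = 4, h = 8, d = 128;
//   auto opts  = at::TensorOptions().device(at::kCUDA);
//   auto v_a   = at::randn({n, h, d}, opts.dtype(at::kHalf));  // prefix output
//   auto s_a   = at::randn({n, h}, opts.dtype(at::kFloat));    // prefix LSE
//   auto v_b   = at::randn({n, h, d}, opts.dtype(at::kHalf));  // suffix output
//   auto s_b   = at::randn({n, h}, opts.dtype(at::kFloat));    // suffix LSE
//   auto v_out = at::empty_like(v_a);
//   auto s_out = at::empty_like(s_a);
//   merge_state_v2(v_a, s_a, v_b, s_b, v_out, s_out);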