/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FLASHINFER_ATTENTION_HOPPER_HEADER_CUH_
#define FLASHINFER_ATTENTION_HOPPER_HEADER_CUH_

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <cstdint>
#include <cute/arch/mma_sm90_desc.hpp>
#include <cute/arch/mma_sm90_gmma.hpp>
#include <cute/swizzle.hpp>

#include "../permuted_smem.cuh"

namespace flashinfer {

using namespace cute::SM90::GMMA;

// using WGMMA_NN_64x32x16_F32BF16BF16_SS =

// Thin wrappers around the CuTe SM90 GMMA atoms, specialized by input/output dtype and
// MxNxK tile shape. WGMMA_ASYNC_SS reads both operands from shared memory (descriptor +
// descriptor); WGMMA_ASYNC_RS reads the A operand from registers (fragment + descriptor).
// NOTE: TransposeA/TransposeB are taken as booleans and are assumed to select the MN-major
// (transposed) vs. K-major (non-transposed) operand layout of the underlying MMA atom.
template <typename DTypeIn, typename DTypeOut, uint32_t M, uint32_t N, uint32_t K,
          bool TransposeA, bool TransposeB, ScaleIn scaleA, ScaleIn scaleB>
struct WGMMA_ASYNC_SS {};

template <typename DTypeIn, typename DTypeOut, uint32_t M, uint32_t N, uint32_t K,
          bool TransposeA, bool TransposeB, ScaleIn scaleA, ScaleIn scaleB>
struct WGMMA_ASYNC_RS {};

#define EXPAND_FRAG_ARGS_4(x) x[0], x[1], x[2], x[3]
#define EXPAND_FRAG_ARGS_8(x) x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]
#define EXPAND_FRAG_ARGS_16(x)                                                                    \
  x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9], x[10], x[11], x[12], x[13], x[14], \
      x[15]
#define EXPAND_FRAG_ARGS_32(x)                                                                    \
  x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9], x[10], x[11], x[12], x[13], x[14], \
      x[15], x[16], x[17], x[18], x[19], x[20], x[21], x[22], x[23], x[24], x[25], x[26], x[27], \
      x[28], x[29], x[30], x[31]
#define EXPAND_FRAG_ARGS_64(x)                                                                    \
  x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9], x[10], x[11], x[12], x[13], x[14], \
      x[15], x[16], x[17], x[18], x[19], x[20], x[21], x[22], x[23], x[24], x[25], x[26], x[27], \
      x[28], x[29], x[30], x[31], x[32], x[33], x[34], x[35], x[36], x[37], x[38], x[39], x[40], \
      x[41], x[42], x[43], x[44], x[45], x[46], x[47], x[48], x[49], x[50], x[51], x[52], x[53], \
      x[54], x[55], x[56], x[57], x[58], x[59], x[60], x[61], x[62], x[63]
#define EXPAND_FRAG_ARGS_128(x)                                                                   \
  x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9], x[10], x[11], x[12], x[13], x[14], \
      x[15], x[16], x[17], x[18], x[19], x[20], x[21], x[22], x[23], x[24], x[25], x[26], x[27], \
      x[28], x[29], x[30], x[31], x[32], x[33], x[34], x[35], x[36], x[37], x[38], x[39], x[40], \
      x[41], x[42], x[43], x[44], x[45], x[46], x[47], x[48], x[49], x[50], x[51], x[52], x[53], \
      x[54], x[55], x[56], x[57], x[58], x[59], x[60], x[61], x[62], x[63], x[64], x[65], x[66], \
      x[67], x[68], x[69], x[70], x[71], x[72], x[73], x[74], x[75], x[76], x[77], x[78], x[79], \
      x[80], x[81], x[82], x[83], x[84], x[85], x[86], x[87], x[88], x[89], x[90], x[91], x[92], \
      x[93], x[94], x[95], x[96], x[97], x[98], x[99], x[100], x[101], x[102], x[103], x[104],   \
      x[105], x[106], x[107], x[108], x[109], x[110], x[111], x[112], x[113], x[114], x[115],    \
      x[116], x[117], x[118], x[119], x[120], x[121], x[122], x[123], x[124], x[125], x[126],    \
      x[127]
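
// The EXPAND_FRAG_ARGS_* macros unpack a register-fragment array into the comma-separated
// scalar operands expected by the CuTe MMA atoms' fma(), which takes each register as a
// separate reference. For example, EXPAND_FRAG_ARGS_8(d_frag) expands to
//   d_frag[0], d_frag[1], ..., d_frag[7].
// The operand count matches the per-thread accumulator size of the corresponding 64xNx16
// atom: N * 64 / 128 fp32 registers per thread across the 128-thread warpgroup.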

template <bool TransposeA, bool TransposeB, ScaleIn scaleA, ScaleIn scaleB>
struct WGMMA_ASYNC_RS<__half, float, 64, 16, 16, TransposeA, TransposeB, scaleA, scaleB> {
  template <bool init>
  static __device__ __forceinline__ void op(uint32_t* a_frag, uint64_t desc_b, float* d_frag) {
    constexpr auto scale_d = init ? ScaleOut::Zero : ScaleOut::One;
    MMA_64x16x16_F32F16F16_RS<TransposeA ? Major::MN : Major::K,
                              TransposeB ? Major::MN : Major::K, scaleA, scaleB>::fma(
        EXPAND_FRAG_ARGS_4(a_frag), desc_b, EXPAND_FRAG_ARGS_8(d_frag), scale_d);
  }
};

template <bool TransposeA, bool TransposeB, ScaleIn scaleA, ScaleIn scaleB>
struct WGMMA_ASYNC_SS<__half, float, 64, 16, 16, TransposeA, TransposeB, scaleA, scaleB> {
  template <bool init>
  static __device__ __forceinline__ void op(uint64_t desc_a, uint64_t desc_b, float* d_frag) {
    constexpr auto scale_d = init ? ScaleOut::Zero : ScaleOut::One;
    MMA_64x16x16_F32F16F16_SS<TransposeA ? Major::MN : Major::K,
                              TransposeB ? Major::MN : Major::K, scaleA, scaleB>::fma(
        desc_a, desc_b, EXPAND_FRAG_ARGS_8(d_frag), scale_d);
  }
};

template <bool TransposeA, bool TransposeB, ScaleIn scaleA, ScaleIn scaleB>
struct WGMMA_ASYNC_RS<__half, float, 64, 32, 16, TransposeA, TransposeB, scaleA, scaleB> {
  template <bool init>
  static __device__ __forceinline__ void op(uint32_t* a_frag, uint64_t desc_b, float* d_frag) {
    constexpr auto scale_d = init ? ScaleOut::Zero : ScaleOut::One;
    MMA_64x32x16_F32F16F16_RS<TransposeA ? Major::MN : Major::K,
                              TransposeB ? Major::MN : Major::K, scaleA, scaleB>::fma(
        EXPAND_FRAG_ARGS_4(a_frag), desc_b, EXPAND_FRAG_ARGS_16(d_frag), scale_d);
  }
};

template <bool TransposeA, bool TransposeB, ScaleIn scaleA, ScaleIn scaleB>
struct WGMMA_ASYNC_SS<__half, float, 64, 32, 16, TransposeA, TransposeB, scaleA, scaleB> {
  template <bool init>
  static __device__ __forceinline__ void op(uint64_t desc_a, uint64_t desc_b, float* d_frag) {
    constexpr auto scale_d = init ? ScaleOut::Zero : ScaleOut::One;
    MMA_64x32x16_F32F16F16_SS<TransposeA ? Major::MN : Major::K,
                              TransposeB ? Major::MN : Major::K, scaleA, scaleB>::fma(
        desc_a, desc_b, EXPAND_FRAG_ARGS_16(d_frag), scale_d);
  }
};

template <bool TransposeA, bool TransposeB, ScaleIn scaleA, ScaleIn scaleB>
struct WGMMA_ASYNC_RS<__half, float, 64, 64, 16, TransposeA, TransposeB, scaleA, scaleB> {
  template <bool init>
  static __device__ __forceinline__ void op(uint32_t* a_frag, uint64_t desc_b, float* d_frag) {
    constexpr auto scale_d = init ? ScaleOut::Zero : ScaleOut::One;
    MMA_64x64x16_F32F16F16_RS<TransposeA ? Major::MN : Major::K,
                              TransposeB ? Major::MN : Major::K, scaleA, scaleB>::fma(
        EXPAND_FRAG_ARGS_4(a_frag), desc_b, EXPAND_FRAG_ARGS_32(d_frag), scale_d);
  }
};

template <bool TransposeA, bool TransposeB, ScaleIn scaleA, ScaleIn scaleB>
struct WGMMA_ASYNC_SS<__half, float, 64, 64, 16, TransposeA, TransposeB, scaleA, scaleB> {
  template <bool init>
  static __device__ __forceinline__ void op(uint64_t desc_a, uint64_t desc_b, float* d_frag) {
    constexpr auto scale_d = init ? ScaleOut::Zero : ScaleOut::One;
    MMA_64x64x16_F32F16F16_SS<TransposeA ? Major::MN : Major::K,
                              TransposeB ? Major::MN : Major::K, scaleA, scaleB>::fma(
        desc_a, desc_b, EXPAND_FRAG_ARGS_32(d_frag), scale_d);
  }
};

template <bool TransposeA, bool TransposeB, ScaleIn scaleA, ScaleIn scaleB>
struct WGMMA_ASYNC_RS<__half, float, 64, 128, 16, TransposeA, TransposeB, scaleA, scaleB> {
  template <bool init>
  static __device__ __forceinline__ void op(uint32_t* a_frag, uint64_t desc_b, float* d_frag) {
    constexpr auto scale_d = init ? ScaleOut::Zero : ScaleOut::One;
    MMA_64x128x16_F32F16F16_RS<TransposeA ? Major::MN : Major::K,
                               TransposeB ? Major::MN : Major::K, scaleA, scaleB>::fma(
        EXPAND_FRAG_ARGS_4(a_frag), desc_b, EXPAND_FRAG_ARGS_64(d_frag), scale_d);
  }
};

template <bool TransposeA, bool TransposeB, ScaleIn scaleA, ScaleIn scaleB>
struct WGMMA_ASYNC_SS<__half, float, 64, 128, 16, TransposeA, TransposeB, scaleA, scaleB> {
  template <bool init>
  static __device__ __forceinline__ void op(uint64_t desc_a, uint64_t desc_b, float* d_frag) {
    constexpr auto scale_d = init ? ScaleOut::Zero : ScaleOut::One;
    MMA_64x128x16_F32F16F16_SS<TransposeA ? Major::MN : Major::K,
                               TransposeB ? Major::MN : Major::K, scaleA, scaleB>::fma(
        desc_a, desc_b, EXPAND_FRAG_ARGS_64(d_frag), scale_d);
  }
};

template <bool TransposeA, bool TransposeB, ScaleIn scaleA, ScaleIn scaleB>
struct WGMMA_ASYNC_RS<__half, float, 64, 256, 16, TransposeA, TransposeB, scaleA, scaleB> {
  template <bool init>
  static __device__ __forceinline__ void op(uint32_t* a_frag, uint64_t desc_b, float* d_frag) {
    constexpr auto scale_d = init ? ScaleOut::Zero : ScaleOut::One;
    MMA_64x256x16_F32F16F16_RS<TransposeA ? Major::MN : Major::K,
                               TransposeB ? Major::MN : Major::K, scaleA, scaleB>::fma(
        EXPAND_FRAG_ARGS_4(a_frag), desc_b, EXPAND_FRAG_ARGS_128(d_frag), scale_d);
  }
};

template <bool TransposeA, bool TransposeB, ScaleIn scaleA, ScaleIn scaleB>
struct WGMMA_ASYNC_SS<__half, float, 64, 256, 16, TransposeA, TransposeB, scaleA, scaleB> {
  template <bool init>
  static __device__ __forceinline__ void op(uint64_t desc_a, uint64_t desc_b, float* d_frag) {
    constexpr auto scale_d = init ? ScaleOut::Zero : ScaleOut::One;
    MMA_64x256x16_F32F16F16_SS<TransposeA ? Major::MN : Major::K,
                               TransposeB ? Major::MN : Major::K, scaleA, scaleB>::fma(
        desc_a, desc_b, EXPAND_FRAG_ARGS_128(d_frag), scale_d);
  }
};

template <bool TransposeA, bool TransposeB, ScaleIn scaleA, ScaleIn scaleB>
struct WGMMA_ASYNC_SS<__nv_bfloat16, float, 64, 16, 16, TransposeA, TransposeB, scaleA, scaleB> {
  template <bool init>
  static __device__ __forceinline__ void op(uint64_t desc_a, uint64_t desc_b, float* d_frag) {
    constexpr auto scale_d = init ? ScaleOut::Zero : ScaleOut::One;
    MMA_64x16x16_F32BF16BF16_SS<TransposeA ? Major::MN : Major::K,
                                TransposeB ? Major::MN : Major::K, scaleA, scaleB>::fma(
        desc_a, desc_b, EXPAND_FRAG_ARGS_8(d_frag), scale_d);
  }
};

template <bool TransposeA, bool TransposeB, ScaleIn scaleA, ScaleIn scaleB>
struct WGMMA_ASYNC_SS<__nv_bfloat16, float, 64, 32, 16, TransposeA, TransposeB, scaleA, scaleB> {
  template <bool init>
  static __device__ __forceinline__ void op(uint64_t desc_a, uint64_t desc_b, float* d_frag) {
    constexpr auto scale_d = init ? ScaleOut::Zero : ScaleOut::One;
    MMA_64x32x16_F32BF16BF16_SS<TransposeA ? Major::MN : Major::K,
                                TransposeB ? Major::MN : Major::K, scaleA, scaleB>::fma(
        desc_a, desc_b, EXPAND_FRAG_ARGS_16(d_frag), scale_d);
  }
};

template <bool TransposeA, bool TransposeB, ScaleIn scaleA, ScaleIn scaleB>
struct WGMMA_ASYNC_SS<__nv_bfloat16, float, 64, 64, 16, TransposeA, TransposeB, scaleA, scaleB> {
  template <bool init>
  static __device__ __forceinline__ void op(uint64_t desc_a, uint64_t desc_b, float* d_frag) {
    constexpr auto scale_d = init ? ScaleOut::Zero : ScaleOut::One;
    MMA_64x64x16_F32BF16BF16_SS<TransposeA ? Major::MN : Major::K,
                                TransposeB ? Major::MN : Major::K, scaleA, scaleB>::fma(
        desc_a, desc_b, EXPAND_FRAG_ARGS_32(d_frag), scale_d);
  }
};

template <bool TransposeA, bool TransposeB, ScaleIn scaleA, ScaleIn scaleB>
struct WGMMA_ASYNC_SS<__nv_bfloat16, float, 64, 128, 16, TransposeA, TransposeB, scaleA, scaleB> {
  template <bool init>
  static __device__ __forceinline__ void op(uint64_t desc_a, uint64_t desc_b, float* d_frag) {
    constexpr auto scale_d = init ? ScaleOut::Zero : ScaleOut::One;
    MMA_64x128x16_F32BF16BF16_SS<TransposeA ? Major::MN : Major::K,
                                 TransposeB ? Major::MN : Major::K, scaleA, scaleB>::fma(
        desc_a, desc_b, EXPAND_FRAG_ARGS_64(d_frag), scale_d);
  }
};

template <bool TransposeA, bool TransposeB, ScaleIn scaleA, ScaleIn scaleB>
struct WGMMA_ASYNC_SS<__nv_bfloat16, float, 64, 256, 16, TransposeA, TransposeB, scaleA, scaleB> {
  template <bool init>
  static __device__ __forceinline__ void op(uint64_t desc_a, uint64_t desc_b, float* d_frag) {
    constexpr auto scale_d = init ? ScaleOut::Zero : ScaleOut::One;
    MMA_64x256x16_F32BF16BF16_SS<TransposeA ? Major::MN : Major::K,
                                 TransposeB ? Major::MN : Major::K, scaleA, scaleB>::fma(
        desc_a, desc_b, EXPAND_FRAG_ARGS_128(d_frag), scale_d);
  }
};

using Swizzle128B = cute::Swizzle<3, 4, 3>;
using Swizzle64B = cute::Swizzle<2, 4, 3>;
using Swizzle32B = cute::Swizzle<1, 4, 3>;

// Computes the swizzled linear offset (in 16-byte units) of the logical coordinate (i, j)
// in a shared-memory tile whose rows are `stride` 16-byte units apart.
template <SwizzleMode swizzle_mode, uint32_t stride>
__device__ __forceinline__ uint32_t get_swizzle_offset(uint32_t i, uint32_t j) {
  constexpr uint32_t M = 8;
  if constexpr (swizzle_mode == SwizzleMode::k128B) {
    constexpr uint32_t N = 8;
    return Swizzle128B{}(((i / M) * M * stride + ((j / N) * M + i % M) * N + (j % N)) << 4) >> 4;
  } else {
    constexpr uint32_t N = 4;
    return Swizzle64B{}(((i / M) * M * stride + ((j / N) * M + i % M) * N + (j % N)) << 4) >> 4;
  }
}

// Encodes an address or byte offset into the 14-bit, 16-byte-granular field format used by
// the GMMA shared-memory matrix descriptor.
__device__ __forceinline__ uint64_t matrix_descriptor_encode(uint64_t x) {
  return (((x) & 0x3FFFF) >> 0x4);
}

// Builds a GMMA shared-memory matrix descriptor for the tile starting at `ptr`:
// bits [0, 14)  -> start address, bits [16, 30) -> leading byte offset,
// bits [32, 46) -> stride byte offset, bits [62, 64) -> swizzle mode
// (1 = 128B, 2 = 64B, 3 = 32B).
template <SwizzleMode swizzle_mode, uint32_t leading_byte_offset, uint32_t stride_byte_offset,
          typename T>
__device__ uint64_t make_smem_desc(T* ptr) {
  uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
  uint64_t desc = 0x0000000000000000;
  desc |= matrix_descriptor_encode(addr);
  // leading byte offset
  desc |= matrix_descriptor_encode(leading_byte_offset) << 16;
  // stride byte offset
  desc |= matrix_descriptor_encode(stride_byte_offset) << 32;
  desc |= ((swizzle_mode == SwizzleMode::k128B)  ? 1llu
           : (swizzle_mode == SwizzleMode::k64B) ? 2llu
                                                 : 3llu)
          << 62;
  return desc;
}

// wgmma.fence: orders prior accumulator/register accesses before the asynchronous MMAs.
__device__ __forceinline__ void warpgroup_arrive() { cute::warpgroup_arrive(); }

// wgmma.wait_group: blocks until at most `n` committed wgmma groups are still pending.
template <int n>
__device__ __forceinline__ void warpgroup_wait() {
  cute::warpgroup_wait<n>();
}

// wgmma.commit_group: commits all previously issued, uncommitted wgmma operations.
__device__ __forceinline__ void warpgroup_commit_batch() { cute::warpgroup_commit_batch(); }

// Fences each accumulator register so the compiler does not reorder accesses around the
// asynchronous MMA.
template <uint32_t size>
__device__ __forceinline__ void warpgroup_fence_frag(float* frag) {
#pragma unroll
  for (uint32_t i = 0; i < size; ++i) {
    cute::warpgroup_fence_operand(frag[i]);
  }
}
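
// Illustrative sketch (not part of the original interface): one possible way to drive the
// wrappers above from a 128-thread warpgroup, shown for a single 64x64x16 fp16 SS tile.
// The example_* name, the Transpose/Scale choices, and the leading/stride byte offsets are
// assumptions made for the sake of the example; real kernels derive them from their tile
// shape and shared-memory layout.
__device__ __forceinline__ void example_wgmma_f16_ss_64x64x16(__half* smem_a, __half* smem_b,
                                                              float (&d_frag)[32]) {
  // Build GMMA descriptors for the A and B tiles in shared memory (placeholder offsets).
  uint64_t desc_a =
      make_smem_desc<SwizzleMode::k128B, /*leading_byte_offset=*/16,
                     /*stride_byte_offset=*/1024>(smem_a);
  uint64_t desc_b =
      make_smem_desc<SwizzleMode::k128B, /*leading_byte_offset=*/16,
                     /*stride_byte_offset=*/1024>(smem_b);
  // Fence the 32 per-thread accumulator registers, then order them before the async MMA.
  warpgroup_fence_frag<32>(d_frag);
  warpgroup_arrive();
  // init = true selects ScaleOut::Zero, i.e. the accumulator is overwritten, not added to.
  WGMMA_ASYNC_SS<__half, float, 64, 64, 16, /*TransposeA=*/false, /*TransposeB=*/true,
                 ScaleIn::One, ScaleIn::One>::op<true>(desc_a, desc_b, d_frag);
  // Commit the wgmma batch and wait until no group is still in flight.
  warpgroup_commit_batch();
  warpgroup_wait<0>();
  warpgroup_fence_frag<32>(d_frag);
}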

}  // namespace flashinfer

#endif  // FLASHINFER_ATTENTION_HOPPER_HEADER_CUH_