/* * Copyright (c) 2023 by FlashInfer team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef FLASHINFER_MMA_CUH_ #define FLASHINFER_MMA_CUH_ #include #include #include #include #include namespace flashinfer { namespace mma { #if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120400) #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890)) #define FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED #endif #endif #if (__CUDACC_VER_MAJOR__ >= 11) #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 900)) #define FLASHINFER_STMATRIX_M8N8X4_ENABLED #endif #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) #define FLASHINFER_MMA_F16F16F32_M16N8K16_ENABLED #define FLASHINFER_MMA_F16F16F16_M16N8K16_ENABLED #endif #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 750)) #define FLASHINFER_MMA_F16F16F32_M16N8K8_ENABLED #define FLASHINFER_MMA_F16F16F16_M16N8K8_ENABLED #define FLASHINFER_LDMATRIX_M8N8X4_ENABLED #endif #endif #if defined(__CUDA_ARCH__) #define FLASHINFER_RUNTIME_ASSERT(x) __brkpt() #else #define FLASHINFER_RUNTIME_ASSERT(x) assert(0 && x) #endif enum class MMAMode { kInit = 0U, kInplaceUpdate = 1U, }; /*! * \brief Wrapper of PTX ldmatrix m8n8.x4 instruction, loads data from shared memory * to fragment * \tparam T data type of the fragment * \param R pointer to the fragment * \param smem_ptr pointer to the shared memory */ template __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t* R, T* smem_ptr) { #ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) : "r"(smem_int_ptr)); #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); #endif } /*! * \brief Wrapper of PTX ldmatrix m8n8.x4 instruction, loads data from shared memory * to fragment * \tparam T data type of the fragment * \param R pointer to the fragment * \param smem_ptr pointer to the shared memory */ template __device__ __forceinline__ void ldmatrix_m8n8x4_left_half(uint32_t* R, T* smem_ptr) { #ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, _, %1, _}, [%2];\n" : "=r"(R[0]), "=r"(R[1]) : "r"(smem_int_ptr)); #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); #endif } /*! * \brief Wrapper of PTX ldmatrix m8n8.x4 instruction, loads data from shared memory * to fragment * \tparam T data type of the fragment * \param R pointer to the fragment * \param smem_ptr pointer to the shared memory */ template __device__ __forceinline__ void ldmatrix_m8n8x4_right_half(uint32_t* R, T* smem_ptr) { #ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {_, %0, _, %1}, [%2];\n" : "=r"(R[0]), "=r"(R[1]) : "r"(smem_int_ptr)); #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); #endif } /*! * \brief Wrapper of PTX ldmatrix m8n8.x4 transposed instruction, loads data from * shared memory to fragment and transposes the fragment * \tparam T data type of the fragment * \param R pointer to the fragment * \param smem_ptr pointer to the shared memory */ template __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t* R, T* smem_ptr) { #ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) : "r"(smem_int_ptr)); #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); #endif } /*! * \brief Wrapper of PTX ldmatrix m8n8.x4 transposed instruction, loads data from * shared memory to fragment and transposes the fragment * \tparam T data type of the fragment * \param R pointer to the fragment * \param smem_ptr pointer to the shared memory */ template __device__ __forceinline__ void ldmatrix_m8n8x4_trans_left_half(uint32_t* R, T* smem_ptr) { #ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, _, _}, [%2];\n" : "=r"(R[0]), "=r"(R[1]) : "r"(smem_int_ptr)); #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); #endif } /*! * \brief Wrapper of PTX ldmatrix m8n8.x4 transposed instruction, loads data from * shared memory to fragment and transposes the fragment * \tparam T data type of the fragment * \param R pointer to the fragment * \param smem_ptr pointer to the shared memory */ template __device__ __forceinline__ void ldmatrix_m8n8x4_trans_right_half(uint32_t* R, T* smem_ptr) { #ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {_, _, %0, %1}, [%2];\n" : "=r"(R[0]), "=r"(R[1]) : "r"(smem_int_ptr)); #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); #endif } /*! * \brief Wrapper of PTX stmatrix m8n8.x4 instruction, stores data from fragment * to shared memory * \tparam T data type of the fragment * \param R pointer to the fragment * \param smem_ptr pointer to the shared memory */ template __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t* R, T* smem_ptr) { #ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};\n" : : "r"(smem_int_ptr), "r"(R[0]), "r"(R[1]), "r"(R[2]), "r"(R[3])); #else // Fallback implementation, slower than PTX instruction const uint32_t tx = threadIdx.x; uint4 word; #pragma unroll for (uint32_t reg_id = 0; reg_id < 4; ++reg_id) { word.x = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4); word.y = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 1); word.z = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 2); word.w = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 3); if (tx / 8 == reg_id) { *(uint4*)smem_ptr = word; } } #endif } /*! * \brief Wrapper of two mma m16n8k32 instructions for row major and column major f8 matrix * multiplication, accumulated in f32. * \tparam T data type of the fragment * \tparam mma_mode whether we are initializing the accumulator or updating it * \param C pointer to the accumulator * \param A pointer to the fragment of matrix A * \param B pointer to the fragment of matrix B */ template __device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32(float* C, uint32_t* A, uint32_t* B) { static_assert(sizeof(T) == 1, "DType must be 8bit floating data type"); #if defined(FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED) if constexpr (mma_mode == MMAMode::kInit) { if constexpr (std::is_same_v) { asm volatile( "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n" : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); asm volatile( "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n" : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); } else { // e5m2 asm volatile( "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n" : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); asm volatile( "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n" : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); } } else { if constexpr (std::is_same_v) { asm volatile( "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n" : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); asm volatile( "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n" : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); } else { // e5m2 asm volatile( "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n" : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); asm volatile( "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n" : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); } } #else FLASHINFER_RUNTIME_ASSERT( "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+"); #endif } /*! * \brief Wrapper of two mma m16n8k16 instructions for row major and column major f16 matrix * multiplication, accumulated in f32. * \tparam T data type of the fragment * \tparam mma_mode whether we are initializing the accumulator or updating it * \param C pointer to the accumulator * \param A pointer to the fragment of matrix A * \param B pointer to the fragment of matrix B */ template __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, uint32_t* A, uint32_t* B) { #if defined(FLASHINFER_MMA_F16F16F32_M16N8K16_ENABLED) if constexpr (mma_mode == MMAMode::kInit) { if constexpr (std::is_same_v) { asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n" : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n" : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); } else { asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n" : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n" : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); } } else { if constexpr (std::is_same_v) { asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n" : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n" : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); } else { asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n" : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " "{%0, %1, %2, %3}," "{%4, %5, %6, %7}," "{%8, %9}," "{%10, %11, %12, %13};\n" : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); } } #elif defined(FLASHINFER_MMA_F16F16F32_M16N8K8_ENABLED) if constexpr (std::is_same_v) { if constexpr (mma_mode == MMAMode::kInit) { asm volatile( "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3}," "{%4, %5}," "{%6}," "{%7, %8, %9, %10};\n" : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); asm volatile( "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3}," "{%4, %5}," "{%6}," "{%7, %8, %9, %10};\n" : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) : "r"(A[2]), "r"(A[3]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); asm volatile( "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3}," "{%4, %5}," "{%6}," "{%7, %8, %9, %10};\n" : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) : "r"(A[0]), "r"(A[1]), "r"(B[2]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); asm volatile( "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3}," "{%4, %5}," "{%6}," "{%7, %8, %9, %10};\n" : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) : "r"(A[2]), "r"(A[3]), "r"(B[3]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); } else { asm volatile( "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3}," "{%4, %5}," "{%6}," "{%7, %8, %9, %10};\n" : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); asm volatile( "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3}," "{%4, %5}," "{%6}," "{%7, %8, %9, %10};\n" : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) : "r"(A[2]), "r"(A[3]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); asm volatile( "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3}," "{%4, %5}," "{%6}," "{%7, %8, %9, %10};\n" : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) : "r"(A[0]), "r"(A[1]), "r"(B[2]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); asm volatile( "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3}," "{%4, %5}," "{%6}," "{%7, %8, %9, %10};\n" : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) : "r"(A[2]), "r"(A[3]), "r"(B[3]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); } } else { FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); } #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); #endif } /*! * \brief Use mma instructions to compute rowsum. */ template __device__ __forceinline__ void m16k32_rowsum_f8f8f32(float* d, DType* s) { static_assert(sizeof(DType) == 1, "DType must be 8bit floating data type"); uint32_t* s_u32 = (uint32_t*)(s); #if defined(FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED) if constexpr (std::is_same_v) { asm volatile( "{\n" "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " "{%0, _, %1, _}," "{%2, %3, %4, %5}," "{%6, %7}," "{%8, 0., %9, 0.};\n" "}\n" : "=f"(d[0]), "=f"(d[1]) : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), "r"(s_u32[3]), "r"(943208504), "r"(943208504), "f"(d[0]), "f"(d[1])); } else { // e5m2 asm volatile( "{\n" "mma.sync.aligned.m16n8k16.row.col.f32.e5m2.e5m2.f32 " "{%0, _, %1, _}," "{%2, %3, %4, %5}," "{%6, %7}," "{%8, 0., %9, 0.};\n" "}\n" : "=f"(d[0]), "=f"(d[1]) : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), "r"(s_u32[3]), "r"(1010580540), "r"(1010580540), "f"(d[0]), "f"(d[1])); } #else FLASHINFER_RUNTIME_ASSERT( "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+"); #endif } /*! * \brief Use mma instructions to compute rowsum. */ template __device__ __forceinline__ void m16k16_rowsum_f16f16f32(float* d, DType* s) { static_assert(sizeof(DType) == 2, "DType must be 16bit floating data type"); uint32_t* s_u32 = (uint32_t*)(s); #if defined(FLASHINFER_MMA_F16F16F32_M16N8K16_ENABLED) if constexpr (std::is_same_v) { asm volatile( "{\n" "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " "{%0, _, %1, _}," "{%2, %3, %4, %5}," "{%6, %7}," "{%8, 0., %9, 0.};\n" "}\n" : "=f"(d[0]), "=f"(d[1]) : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), "r"(s_u32[3]), "r"(1006648320), "r"(1006648320), "f"(d[0]), "f"(d[1])); } else { asm volatile( "{\n" "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " "{%0, _, %1, _}," "{%2, %3, %4, %5}," "{%6, %7}," "{%8, 0., %9, 0.};\n" "}\n" : "=f"(d[0]), "=f"(d[1]) : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), "r"(s_u32[3]), "r"(1065369472), "r"(1065369472), "f"(d[0]), "f"(d[1])); } #elif defined(FLASHINFER_MMA_F16F16F32_M16N8K8_ENABLED) if constexpr (std::is_same_v) { asm volatile( "{\n" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " "{%0, _, %1, _}," "{%2, %3}," "{%4}," "{%5, 0., %6, 0.};\n" "}\n" : "=f"(d[0]), "=f"(d[1]) : "r"(s_u32[0]), "r"(s_u32[1]), "r"(1006648320), "f"(d[0]), "f"(d[1])); asm volatile( "{\n" "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " "{%0, _, %1, _}," "{%2, %3}," "{%4}," "{%5, 0., %6, 0.};\n" "}\n" : "=f"(d[0]), "=f"(d[1]) : "r"(s_u32[2]), "r"(s_u32[3]), "r"(1006648320), "f"(d[0]), "f"(d[1])); } else { FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); } #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); #endif } /*! * \brief Wrapper of two mma m16n8k16 instructions for row major and column major f16 matrix * multiplication, accumulated in f16. * \tparam mma_mode whether we are initializing the accumulator or updating it * \param C pointer to the accumulator * \param A pointer to the fragment of matrix A * \param B pointer to the fragment of matrix B */ template __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f16(uint32_t* C, uint32_t* A, uint32_t* B) { #if defined(FLASHINFER_MMA_F16F16F16_M16N8K16_ENABLED) if constexpr (mma_mode == MMAMode::kInit) { asm volatile( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " "{%0, %1}," "{%2, %3, %4, %5}," "{%6, %7}," "{%8, %9};\n" : "=r"(C[0]), "=r"(C[1]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0)); asm volatile( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " "{%0, %1}," "{%2, %3, %4, %5}," "{%6, %7}," "{%8, %9};\n" : "=r"(C[2]), "=r"(C[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(0), "r"(0)); } else { asm volatile( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " "{%0, %1}," "{%2, %3, %4, %5}," "{%6, %7}," "{%8, %9};\n" : "=r"(C[0]), "=r"(C[1]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1])); asm volatile( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " "{%0, %1}," "{%2, %3, %4, %5}," "{%6, %7}," "{%8, %9};\n" : "=r"(C[2]), "=r"(C[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(C[2]), "r"(C[3])); } #elif defined(FLASHINFER_MMA_F16F16F16_M16N8K8_ENABLED) if constexpr (mma_mode == MMAMode::kInit) { asm volatile( "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " "{%0, %1}," "{%2, %3}," "{%4}," "{%5, %6};\n" : "=r"(C[0]), "=r"(C[1]) : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(0), "r"(0)); asm volatile( "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " "{%0, %1}," "{%2, %3}," "{%4}," "{%5, %6};\n" : "=r"(C[0]), "=r"(C[1]) : "r"(A[2]), "r"(A[3]), "r"(B[1]), "r"(0), "r"(0)); asm volatile( "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " "{%0, %1}," "{%2, %3}," "{%4}," "{%5, %6};\n" : "=r"(C[2]), "=r"(C[3]) : "r"(A[0]), "r"(A[1]), "r"(B[2]), "r"(0), "r"(0)); asm volatile( "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " "{%0, %1}," "{%2, %3}," "{%4}," "{%5, %6};\n" : "=r"(C[2]), "=r"(C[3]) : "r"(A[2]), "r"(A[3]), "r"(B[3]), "r"(0), "r"(0)); } else { asm volatile( "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " "{%0, %1}," "{%2, %3}," "{%4}," "{%5, %6};\n" : "=r"(C[0]), "=r"(C[1]) : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1])); asm volatile( "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " "{%0, %1}," "{%2, %3}," "{%4}," "{%5, %6};\n" : "=r"(C[0]), "=r"(C[1]) : "r"(A[2]), "r"(A[3]), "r"(B[1]), "r"(C[0]), "r"(C[1])); asm volatile( "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " "{%0, %1}," "{%2, %3}," "{%4}," "{%5, %6};\n" : "=r"(C[2]), "=r"(C[3]) : "r"(A[0]), "r"(A[1]), "r"(B[2]), "r"(C[2]), "r"(C[3])); asm volatile( "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " "{%0, %1}," "{%2, %3}," "{%4}," "{%5, %6};\n" : "=r"(C[2]), "=r"(C[3]) : "r"(A[2]), "r"(A[3]), "r"(B[3]), "r"(C[2]), "r"(C[3])); } #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); #endif } } // namespace mma } // namespace flashinfer #endif // FLASHINFER_MMA_CUH_