// PyTorch / ATen
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>

#include <algorithm>
#include <cuda.h>

// CUTLASS / CuTe
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/detail/blockwise_scale_layout.hpp>
#include <cutlass/epilogue/collective/collective_builder.hpp>
#include <cutlass/epilogue/dispatch_policy.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/numeric_types.h>
#include <cutlass/util/device_memory.h>
#include <cutlass/util/packed_stride.hpp>

#include "cutlass_extensions/gemm/cutlass_gemm_caller.cuh"
#include "cutlass_extensions/gemm/fp8_blockwise_gemm_sm90_dispatch.cuh"
#include "utils.h"

using namespace cute;

template <
    typename OutType,
    typename MmaTileShape,
    typename PerSmTileShape,
    typename EpilogueTileShape,
    typename ScalesPerTile,
    int TileSizeM_ = 128,
    class ClusterShape = Shape<_1, _1, _1>>
void launch_sm100_fp8_blockwise_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  // Scale granularity: how many elements of the MMA tile share one scale factor along M/N/K.
  static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
  static constexpr int ScaleGranularityM = size<0>(MmaTileShape{}) / ScaleMsPerTile;
  static constexpr int ScaleGranularityN = size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
  static constexpr int ScaleGranularityK = size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});

  using ElementAB = cutlass::float_e4m3_t;
  using ElementA = ElementAB;
  using ElementB = ElementAB;
  using ElementC = void;
  using ElementD = OutType;

  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutD = cutlass::layout::RowMajor;
  using LayoutC = LayoutD;

  // This means both SFA and SFB are column-major.
  using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
      ScaleGranularityM,
      ScaleGranularityN,
      ScaleGranularityK,
      cute::UMMA::Major::MN,
      cute::UMMA::Major::K>;

  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());

  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
  static constexpr int AlignmentC = AlignmentD;

  using ElementAccumulator = float;
  using ElementBlockScale = float;
  using ElementCompute = float;
  using ArchTag = cutlass::arch::Sm100;
  using OperatorClass = cutlass::arch::OpClassTensorOp;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      cutlass::arch::OpClassTensorOp,
      PerSmTileShape,
      ClusterShape,
      EpilogueTileShape,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementD,
      LayoutD,
      AlignmentD,
      cutlass::epilogue::TmaWarpSpecialized1Sm>::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      cute::tuple<LayoutA, LayoutSFA>,
      AlignmentA,
      ElementB,
      cute::tuple<LayoutB, LayoutSFB>,
      AlignmentB,
      ElementAccumulator,
      MmaTileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,
      CollectiveMainloop,
      CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  Gemm gemm_op;

  int m = a.size(0);
  int k = a.size(1);
  int n = b.size(1);

  auto a_ptr = static_cast<ElementA*>(a.data_ptr());
  auto b_ptr = static_cast<ElementB*>(b.data_ptr());
  auto scales_a_ptr = static_cast<ElementBlockScale*>(scales_a.data_ptr());
  auto scales_b_ptr = static_cast<ElementBlockScale*>(scales_b.data_ptr());
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
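  // A is row-major (m, k) and B is column-major (n, k); the packed strides below follow
  // directly from those shapes. layout_SFA / layout_SFB describe how the blockwise scale
  // factors tile the (m, n, k) problem for this ScaleConfig.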
  using StrideA = typename GemmKernel::StrideA;
  using StrideB = typename GemmKernel::StrideB;
  using StrideD = typename GemmKernel::StrideD;
  using StrideC = typename GemmKernel::StrideD;

  StrideA a_stride = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB b_stride = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC c_stride = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
  LayoutSFA layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
  LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));

  typename GemmKernel::MainloopArguments mainloop_args{
      a_ptr, a_stride, b_ptr, b_stride, scales_a_ptr, layout_SFA, scales_b_ptr, layout_SFB};
  // ElementC is void, so there is no source accumulator: the epilogue writes D = alpha * acc.
  typename GemmKernel::EpilogueArguments epilogue_args{{}, c_ptr, c_stride, c_ptr, c_stride};
  epilogue_args.thread.alpha = 1.0f;
  typename GemmKernel::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm, {m, n, k, 1}, mainloop_args, epilogue_args};

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
  auto init_status = gemm_op.initialize(args, workspace.get());
  TORCH_CHECK(init_status == cutlass::Status::kSuccess, cutlassGetStatusString(init_status));

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
  auto status = gemm_op.run(stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status));
}

template <typename OutType>
void sm100_fp8_blockwise_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  if (a.size(0) <= 128) {
    // Small-M problems use a 64-row MMA tile.
    using MmaTileShape = Shape<_64, _128, _128>;
    using PerSmTileShape = Shape<_64, _128, _128>;
    using EpilogueTileShape = Shape<_64, _64>;
    using ScalesPerTile = Shape<_64, _1, _1>;
    launch_sm100_fp8_blockwise_scaled_mm<OutType, MmaTileShape, PerSmTileShape, EpilogueTileShape, ScalesPerTile>(
        out, a, b, scales_a, scales_b);
  } else {
    using MmaTileShape = Shape<_128, _128, _128>;
    using PerSmTileShape = Shape<_128, _128, _128>;
    using EpilogueTileShape = Shape<_128, _64>;
    using ScalesPerTile = Shape<_128, _1, _1>;
    launch_sm100_fp8_blockwise_scaled_mm<OutType, MmaTileShape, PerSmTileShape, EpilogueTileShape, ScalesPerTile>(
        out, a, b, scales_a, scales_b);
  }
}
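// Scale granularity worked example: with MmaTileShape = (128, 128, 128) and
// ScalesPerTile = (128, 1, 1), the granularities are (M, N, K) = (1, 128, 128),
// i.e. one scale per row of A for every 128-wide K group, and one scale per
// 128x128 block of B.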
template <
    typename OutType,
    typename MmaTileShape,
    typename PerSmTileShape,
    typename EpilogueTileShape,
    typename ScalesPerTile,
    int TileSizeM_ = 128,
    class ClusterShape = Shape<_1, _1, _1>>
void launch_sm120_fp8_blockwise_scaled_mm(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  using ElementBlockScale = float;

  // A matrix configuration
  using ElementA = cutlass::float_e4m3_t;        // Element type for A matrix operand
  using LayoutATag = cutlass::layout::RowMajor;  // Layout type for A matrix operand
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;  // Memory access granularity/alignment of A
                                                                           // matrix in units of elements (up to 16 bytes)

  // B matrix configuration
  using ElementB = cutlass::float_e4m3_t;           // Element type for B matrix operand
  using LayoutBTag = cutlass::layout::ColumnMajor;  // Layout type for B matrix operand
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;  // Memory access granularity/alignment of B
                                                                           // matrix in units of elements (up to 16 bytes)

  // C/D matrix configuration
  using ElementD = OutType;                      // Element type for D matrix operand
  using ElementC = void;                         // Element type for C matrix operand
  using LayoutCTag = cutlass::layout::RowMajor;  // Layout type for C matrix operand
  using LayoutDTag = cutlass::layout::RowMajor;  // Layout type for D matrix operand
  constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;  // Memory access granularity/alignment of D
                                                                           // matrix in units of elements (up to 16 bytes)
  constexpr int AlignmentC = AlignmentD;  // Memory access granularity/alignment of C matrix in units of elements
                                          // (up to 16 bytes)

  // Kernel functional config
  using ElementAccumulator = float;      // Element type for internal accumulation
  using ArchTag = cutlass::arch::Sm120;  // Tag indicating the minimum SM that supports the intended feature
  using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag; changed from OpClassBlockScaledTensorOp

  static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
  static constexpr int ScaleGranularityM = size<0>(MmaTileShape{}) / ScaleMsPerTile;
  static constexpr int ScaleGranularityN = size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
  static constexpr int ScaleGranularityK = size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});

  // FP8 blockwise scaling configuration
  using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig<
      ScaleGranularityM,
      ScaleGranularityN,
      ScaleGranularityK,
      cute::UMMA::Major::MN,
      cute::UMMA::Major::K>;
  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());  // Layout type for SFA matrix operand
  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());  // Layout type for SFB matrix operand

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      PerSmTileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementAccumulator,
      ElementC,
      LayoutCTag,
      AlignmentC,
      ElementD,
      LayoutDTag,
      AlignmentD,
      cutlass::epilogue::collective::EpilogueScheduleAuto  // Epilogue schedule policy
      >::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      cute::tuple<LayoutATag, LayoutSFA>,
      AlignmentA,
      ElementB,
      cute::tuple<LayoutBTag, LayoutSFB>,
      AlignmentB,
      ElementAccumulator,
      MmaTileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      cutlass::gemm::collective::KernelScheduleAuto  // Kernel schedule policy. Auto defaults to the
                                                     // cooperative kernel schedule.
      >::CollectiveOp;
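  // GemmUniversal composes the mainloop and epilogue into a kernel. Passing void as the
  // tile scheduler lets CUTLASS select its default scheduler for SM120, whereas the SM100
  // path above requests cutlass::gemm::PersistentScheduler explicitly.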
  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      void>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  Gemm gemm_op;

  int m = a.size(0);
  int k = a.size(1);
  int n = b.size(1);

  auto a_ptr = static_cast<ElementA*>(a.data_ptr());
  auto b_ptr = static_cast<ElementB*>(b.data_ptr());
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  auto scales_a_ptr = static_cast<ElementBlockScale*>(scales_a.data_ptr());
  auto scales_b_ptr = static_cast<ElementBlockScale*>(scales_b.data_ptr());

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideD = typename Gemm::GemmKernel::StrideD;
  using StrideC = typename Gemm::GemmKernel::StrideD;

  StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
  LayoutSFA layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
  LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));

  typename GemmKernel::MainloopArguments mainloop_args{
      a_ptr, stride_a, b_ptr, stride_b, scales_a_ptr, layout_SFA, scales_b_ptr, layout_SFB};
  typename GemmKernel::EpilogueArguments epilogue_args{{}, c_ptr, stride_c, c_ptr, stride_c};
  epilogue_args.thread.alpha = 1.0f;
  typename Gemm::Arguments args = {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, 1},
      mainloop_args,
      epilogue_args,
  };

  auto can_implement = gemm_op.can_implement(args);
  TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
  auto init_status = gemm_op.initialize(args, workspace.get());
  TORCH_CHECK(init_status == cutlass::Status::kSuccess, cutlassGetStatusString(init_status));

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
  auto status = gemm_op.run(stream);
  TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status));
}

template <typename OutType>
void sm120_fp8_blockwise_dispatch_shape(
    torch::Tensor& out,
    const torch::Tensor& a,
    const torch::Tensor& b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b) {
  using MmaTileShape = Shape<_128, _128, _128>;
  using PerSmTileShape = Shape<_128, _128, _128>;
  using EpilogueTileShape = Shape<_128, _64>;
  using ScalesPerTile = Shape<_128, _1, _1>;
  launch_sm120_fp8_blockwise_scaled_mm<OutType, MmaTileShape, PerSmTileShape, EpilogueTileShape, ScalesPerTile>(
      out, a, b, scales_a, scales_b);
}
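// Entry point contract (summarizing the checks below): mat_a is a row-major (M, K)
// Float8_e4m3fn tensor, mat_b a column-major (K, N) Float8_e4m3fn tensor, scales_a an
// (M, K/128) Float32 tensor stored M-major, and scales_b a (K/128, N/128) Float32 tensor
// stored K-major. The kernels compute, in float accumulation,
//   out[m, n] = sum_k a[m, k] * scales_a[m, k/128] * b[k, n] * scales_b[k/128, n/128]
// and cast the result to out_dtype (Half or BFloat16).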
torch::Tensor fp8_blockwise_scaled_mm(
    const torch::Tensor& mat_a,
    const torch::Tensor& mat_b,
    const torch::Tensor& scales_a,
    const torch::Tensor& scales_b,
    const torch::Dtype& out_dtype) {
  TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
  TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
  TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
  TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
  TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");

  TORCH_CHECK(
      (mat_a.size(1) * mat_a.element_size()) % 16 == 0,
      "mat_a row size must be a multiple of 16 bytes for memory alignment");
  TORCH_CHECK(
      (mat_b.size(0) * mat_b.element_size()) % 16 == 0,
      "mat_b column size must be a multiple of 16 bytes for memory alignment");
  TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn");
  TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn");
  TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16");

  // A "contiguous vector" is a 1D contiguous tensor, or a contiguous 2D tensor with one
  // of its dimensions equal to 1.
  auto is_contiguous_vector = [](const torch::Tensor& t) {
    auto t_sizes = t.sizes();
    return t.is_contiguous() &&
           (t.dim() == 1 || (t.dim() == 2 && *std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
  };

  TORCH_CHECK(mat_a.size(0) == scales_a.size(0), "scales_a shape does not match mat_a");
  TORCH_CHECK(mat_a.size(1) / 128 == scales_a.size(1), "scales_a shape does not match mat_a");
  TORCH_CHECK(scales_a.stride(0) == 1 || is_contiguous_vector(scales_a), "scales_a must be M major");
  TORCH_CHECK(mat_b.size(0) / 128 == scales_b.size(0), "scales_b shape does not match mat_b");
  TORCH_CHECK(mat_b.size(1) / 128 == scales_b.size(1), "scales_b shape does not match mat_b");
  TORCH_CHECK(scales_b.stride(0) == 1 || is_contiguous_vector(scales_b), "scales_b must be K major");
  TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32");
  TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32");

  torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype));
  TORCH_CHECK(
      (out.size(1) * out.element_size()) % 16 == 0,
      "out row size must be a multiple of 16 bytes for memory alignment");

  auto sm_version = getSMVersion();

  // Pad M up to a multiple of 4 rows so the padded operands meet the kernels' alignment
  // requirements; the extra rows are sliced off before returning.
  int64_t original_rows = mat_a.size(0);
  torch::Tensor mat_a_padded = pad_tensor(mat_a, /*alignment=*/4);
  torch::Tensor scales_a_padded = pad_tensor(scales_a, /*alignment=*/4, /*col_major=*/true);
  torch::Tensor out_padded = torch::empty({mat_a_padded.size(0), mat_b.size(1)}, out.options());

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
  if (sm_version == 90) {
    torch::Tensor scales_b_contiguous = scales_b.contiguous();
    if (out_dtype == torch::kBFloat16) {
      cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
          out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous);
    } else {
      cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
          out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous);
    }
    return out_padded.slice(0, 0, original_rows);
  }
#endif
#endif

#if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080
  if (sm_version == 100
#if CUDA_VERSION >= 12090
      || sm_version == 103
#endif
  ) {
    if (out_dtype == torch::kBFloat16) {
      sm100_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(
          out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
    } else {
      sm100_fp8_blockwise_dispatch_shape<cutlass::half_t>(out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
    }
    return out_padded.slice(0, 0, original_rows);
  }
#endif
#endif

#if defined(CUTLASS_ARCH_MMA_SM120A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080
  if (sm_version == 120) {
    if (out_dtype == torch::kBFloat16) {
      sm120_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(
          out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
    } else {
      sm120_fp8_blockwise_dispatch_shape<cutlass::half_t>(out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
    }
    return out_padded.slice(0, 0, original_rows);
  }
#endif
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
}