#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "cutlass_extensions/gemm/collective/collective_builder.hpp" #include "cutlass_extensions/gemm/dispatch_policy.hpp" #include "utils.h" using namespace cute; template void launch_sm90_fp8_blockwise_scaled_mm( torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, const torch::Tensor& scales_b) { using ElementAccumulator = float; using ElementCompute = float; using ElementBlockScale = float; using ElementA = cutlass::float_e4m3_t; using LayoutA = cutlass::layout::RowMajor; constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; using ElementB = cutlass::float_e4m3_t; using LayoutB = cutlass::layout::ColumnMajor; constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; using ElementC = void; using LayoutC = cutlass::layout::RowMajor; constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; using ElementD = OutType; using LayoutD = cutlass::layout::RowMajor; constexpr int AlignmentD = AlignmentC; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT; using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, EpilogueSchedule, StoreEpilogueCompute>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, KernelSchedule>::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, // Indicates ProblemShape CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::PersistentScheduler>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; Gemm gemm_op; int m = a.size(0); int k = a.size(1); int n = b.size(1); auto a_ptr = static_cast(a.data_ptr()); auto b_ptr = static_cast(b.data_ptr()); auto o_ptr = static_cast(out.data_ptr()); auto a_s_ptr = static_cast(scales_a.data_ptr()); auto b_s_ptr = static_cast(scales_b.data_ptr()); using StrideA = typename Gemm::GemmKernel::StrideA; using StrideB = typename Gemm::GemmKernel::StrideB; using StrideC = typename Gemm::GemmKernel::StrideC; using StrideD = typename Gemm::GemmKernel::StrideD; StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); StrideC stride_c; StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1)); typename GemmKernel::MainloopArguments mainloop_args{a_ptr, stride_a, b_ptr, stride_b, 4, a_s_ptr, b_s_ptr}; typename GemmKernel::EpilogueArguments epilogue_args{{}, nullptr, stride_d, o_ptr, stride_d}; typename Gemm::Arguments args = { cutlass::gemm::GemmUniversalMode::kGemm, {m, n, k, 1}, mainloop_args, epilogue_args, }; size_t workspace_size = gemm_op.get_workspace_size(args); auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); auto workspace = torch::empty(workspace_size, workspace_options); auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); auto can_implement = gemm_op.can_implement(args); TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement)) auto status = gemm_op.run(args, workspace.data_ptr(), stream); TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status)) } template < typename OutType, typename MmaTileShape, typename PerSmTileShape, typename EpilogueTileShape, typename ScalesPerTile, int TileSizeM_ = 128, class ClusterShape = Shape<_1, _1, _1>> void launch_sm100_fp8_blockwise_scaled_mm( torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, const torch::Tensor& scales_b) { static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{}); static constexpr int ScaleGranularityM = size<0>(MmaTileShape{}) / ScaleMsPerTile; static constexpr int ScaleGranularityN = size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{}); static constexpr int ScaleGranularityK = size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{}); using ElementAB = cutlass::float_e4m3_t; using ElementA = ElementAB; using ElementB = ElementAB; using ElementC = void; using ElementD = OutType; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutD = cutlass::layout::RowMajor; using LayoutC = LayoutD; // This means both SFA and SFB are column-major. using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig< ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, cute::UMMA::Major::MN, cute::UMMA::Major::K>; using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; static constexpr int AlignmentC = AlignmentD; using ElementAccumulator = float; using ElementBlockScale = float; using ElementCompute = float; using ArchTag = cutlass::arch::Sm100; using OperatorClass = cutlass::arch::OpClassTensorOp; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, cutlass::arch::OpClassTensorOp, PerSmTileShape, ClusterShape, EpilogueTileShape, ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, cutlass::epilogue::TmaWarpSpecialized1Sm>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, ElementA, cute::tuple, AlignmentA, ElementB, cute::tuple, AlignmentB, ElementAccumulator, MmaTileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::PersistentScheduler>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; Gemm gemm_op; int m = a.size(0); int k = a.size(1); int n = b.size(1); auto a_ptr = static_cast(a.data_ptr()); auto b_ptr = static_cast(b.data_ptr()); auto scales_a_ptr = static_cast(scales_a.data_ptr()); auto scales_b_ptr = static_cast(scales_b.data_ptr()); auto c_ptr = static_cast(out.data_ptr()); using StrideA = typename GemmKernel::StrideA; using StrideB = typename GemmKernel::StrideB; using StrideD = typename GemmKernel::StrideD; using StrideC = typename GemmKernel::StrideD; StrideA a_stride = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); StrideB b_stride = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); StrideC c_stride = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1)); LayoutSFA layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1)); LayoutSFB layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); typename GemmKernel::MainloopArguments mainloop_args{ a_ptr, a_stride, b_ptr, b_stride, scales_a_ptr, layout_SFA, scales_b_ptr, layout_SFB}; typename GemmKernel::EpilogueArguments epilogue_args{{}, c_ptr, c_stride, c_ptr, c_stride}; epilogue_args.thread.alpha = 1.0f; typename GemmKernel::Arguments args = { cutlass::gemm::GemmUniversalMode::kGemm, {m, n, k, 1}, mainloop_args, epilogue_args}; auto can_implement = gemm_op.can_implement(args); TORCH_CHECK(can_implement == cutlass::Status::kSuccess, cutlassGetStatusString(can_implement)) size_t workspace_size = gemm_op.get_workspace_size(args); cutlass::device_memory::allocation workspace(workspace_size); auto init_status = gemm_op.initialize(args, workspace.get()); TORCH_CHECK(init_status == cutlass::Status::kSuccess, cutlassGetStatusString(init_status)); auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); auto status = gemm_op.run(stream); TORCH_CHECK(status == cutlass::Status::kSuccess, cutlassGetStatusString(status)) } template void sm90_fp8_blockwise_dispatch_shape( torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, const torch::Tensor& scales_b) { using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_1, _1, _1>; launch_sm90_fp8_blockwise_scaled_mm(out, a, b, scales_a, scales_b); } template void sm100_fp8_blockwise_dispatch_shape( torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, const torch::Tensor& scales_b) { if (a.size(0) <= 128) { using MmaTileShape = Shape<_64, _128, _128>; using PerSmTileShape = Shape<_64, _128, _128>; using EpilogueTileShape = Shape<_64, _64>; using ScalesPerTile = Shape<_64, _1, _1>; launch_sm100_fp8_blockwise_scaled_mm( out, a, b, scales_a, scales_b); } else { using MmaTileShape = Shape<_128, _128, _128>; using PerSmTileShape = Shape<_128, _128, _128>; using EpilogueTileShape = Shape<_128, _64>; using ScalesPerTile = Shape<_128, _1, _1>; launch_sm100_fp8_blockwise_scaled_mm( out, a, b, scales_a, scales_b); } } torch::Tensor fp8_blockwise_scaled_mm( const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, const torch::Tensor& scales_b, const torch::Dtype& out_dtype) { TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor"); TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor"); TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor"); TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor"); TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor"); TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor"); TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied"); TORCH_CHECK( (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment"); TORCH_CHECK( (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment"); TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn"); TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn"); TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16"); auto is_contiguous_vector = [](const torch::Tensor& t) { auto t_sizes = t.sizes(); return t.is_contiguous() && (t.dim() == 1 || (t.dim() == 2 && *std::min_element(t_sizes.begin(), t_sizes.end()) == 1)); }; TORCH_CHECK(mat_a.size(0) == scales_a.size(0), "size of scales_a is not matched"); TORCH_CHECK(mat_a.size(1) / 128 == scales_a.size(1), "size of scales_a is not matched"); TORCH_CHECK(scales_a.stride(0) == 1 || is_contiguous_vector(scales_a), "scales_a must be M major"); TORCH_CHECK(mat_b.size(0) / 128 == scales_b.size(0), "size of scales_b is not matched"); TORCH_CHECK(mat_b.size(1) / 128 == scales_b.size(1), "size of scales_b is not matched"); TORCH_CHECK(scales_b.stride(0) == 1 || is_contiguous_vector(scales_b), "scales_b must be K major"); TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32"); TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32"); torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype)); TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0, "out must be multiple of 16 bytes for memory alignment"); auto sm_version = getSMVersion(); #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) #if defined CUDA_VERSION && CUDA_VERSION >= 12000 if (sm_version == 90) { if (out_dtype == torch::kBFloat16) { sm90_fp8_blockwise_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b); } else { sm90_fp8_blockwise_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b); } return out; } #endif #endif #if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) #if defined CUDA_VERSION && CUDA_VERSION >= 12080 if (sm_version == 100) { int64_t original_rows = mat_a.size(0); torch::Tensor mat_a_padded = pad_tensor(mat_a, /*alignment=*/4); torch::Tensor scales_a_padded = pad_tensor(scales_a, /*alignment=*/4, /*col_major=*/true); torch::Tensor out_padded = torch::empty({mat_a_padded.size(0), mat_b.size(1)}, out.options()); if (out_dtype == torch::kBFloat16) { sm100_fp8_blockwise_dispatch_shape( out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b); } else { sm100_fp8_blockwise_dispatch_shape(out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b); } return out_padded.slice(0, 0, original_rows); } #endif #endif TORCH_CHECK_NOT_IMPLEMENTED( false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version); }