// References:
// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmgroupedbatchedex
// https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLAS/Extensions/GemmGroupedBatchedEx/cublas_GemmGroupedBatchedEx_example.cu
// https://github.com/zhihu/ZhiLight/blob/main/src/nn/linear/gemm_grouped.cpp

#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <torch/all.h>

#include <cstdint>
#include <string>
#include <vector>

#include "utils.h"

static void check_group_count(
    const std::vector<torch::Tensor>& inputs,
    const std::vector<torch::Tensor>& weights,
    const std::vector<torch::Tensor>& outputs) {
  TORCH_CHECK(
      (inputs.size() == weights.size()) && (inputs.size() == outputs.size()),
      "The group count of inputs, weights and outputs should be the same.");
}

static void check_device_dtype(const torch::Dtype& dtype, const std::vector<torch::Tensor>& tensors) {
  for (const auto& t : tensors) {
    TORCH_CHECK(dtype == t.dtype(), "dtype of all the tensors should be the same");
    TORCH_CHECK(t.is_cuda(), "All tensors should be in CUDA memory");
  }
}

static std::vector<int> get_dims(const std::vector<torch::Tensor>& tensors, int dim) {
  std::vector<int> results;
  for (const auto& t : tensors) {
    TORCH_CHECK(t.dim() == 2, "Should pass in 2D matrices");
    results.push_back(t.size(dim));
  }
  return results;
}

static std::vector<int> get_strides(const std::vector<torch::Tensor>& tensors, int dim) {
  std::vector<int> results;
  for (const auto& t : tensors) {
    results.push_back(t.stride(dim));
  }
  return results;
}

static void check_equal(const std::vector<int>& a, const std::vector<int>& b, const std::string& err_msg) {
  for (size_t i = 0; i < a.size(); ++i) {
    TORCH_CHECK(a[i] == b[i], err_msg);
  }
}

static std::vector<void*> get_tensor_ptrs(const std::vector<torch::Tensor>& tensors) {
  std::vector<void*> ptrs;
  for (const auto& t : tensors) {
    ptrs.push_back(t.data_ptr());
  }
  return ptrs;
}

// Copy an array of host-side device pointers to GPU memory so cuBLAS can read it.
// A kDouble tensor is used only as an 8-byte-per-element storage buffer.
static torch::Tensor create_ptr_pointer(const std::vector<void*>& ptrs, cudaStream_t stream) {
  auto options = torch::TensorOptions().dtype(torch::kDouble).device(torch::kCUDA);
  torch::Tensor gpu_ptrs = torch::empty({static_cast<int64_t>(ptrs.size())}, options);
  TORCH_CHECK(
      cudaMemcpyAsync(
          gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice, stream) ==
      cudaSuccess);
  return gpu_ptrs;
}

// We want to compute input @ weight^T in row major.
// This is equivalent to computing weight @ input^T in column major.
// cuBLAS only accepts matrices in column major, so this arrangement is needed.
void cublas_grouped_gemm(
    const std::vector<torch::Tensor>& inputs,   // b: (m, k) row major = (k, m) col major
    const std::vector<torch::Tensor>& weights,  // a: (n, k) row major = (k, n) col major
    const std::vector<torch::Tensor>& outputs,  // c: (m, n) row major = (n, m) col major
    const torch::Dtype& out_dtype,
    int64_t cublas_handle,
    int64_t cuda_stream) {
  TORCH_CHECK(
      out_dtype == torch::kHalf || out_dtype == torch::kBFloat16,
      "cublas grouped_gemm can only be applied to float16 and bfloat16 dtype");

  int group_count = inputs.size();
  check_group_count(inputs, weights, outputs);
  std::vector<int> group_size(group_count, 1);

  // Make sure all tensors are on CUDA and use the same dtype
  check_device_dtype(out_dtype, inputs);
  check_device_dtype(out_dtype, weights);
  check_device_dtype(out_dtype, outputs);

  // Weights should be transposed to (n, k) of column major
  std::vector<cublasOperation_t> transa_array(group_count, CUBLAS_OP_T);
  std::vector<cublasOperation_t> transb_array(group_count, CUBLAS_OP_N);

  // Get dim arrays
  std::vector<int> m_array = get_dims(weights, 0);
  std::vector<int> n_array = get_dims(inputs, 0);
  std::vector<int> k_array = get_dims(inputs, 1);

  // Make sure the dimensions in each group match
  std::vector<int> m_array1 = get_dims(outputs, 1);
  std::vector<int> n_array1 = get_dims(outputs, 0);
  std::vector<int> k_array1 = get_dims(weights, 1);
  check_equal(m_array, m_array1, "sizes don't match on m dimension");
  check_equal(n_array, n_array1, "sizes don't match on n dimension");
  check_equal(k_array, k_array1, "sizes don't match on k dimension");

  // Get leading dimensions
  std::vector<int> lda_array = get_strides(weights, 0);
  std::vector<int> ldb_array = get_strides(inputs, 0);
  std::vector<int> ldc_array = get_strides(outputs, 0);

  // Use default scaling factors
  std::vector<float> alpha_array(group_count, 1);
  std::vector<float> beta_array(group_count, 0);

  std::vector<void*> a_array = get_tensor_ptrs(weights);
  std::vector<void*> b_array = get_tensor_ptrs(inputs);
  std::vector<void*> c_array = get_tensor_ptrs(outputs);

  auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);

  // Should allocate device tensors for storage of the per-group pointers
  torch::Tensor d_a = create_ptr_pointer(a_array, stream);
  torch::Tensor d_b = create_ptr_pointer(b_array, stream);
  torch::Tensor d_c = create_ptr_pointer(c_array, stream);

#if defined CUDA_VERSION && CUDA_VERSION >= 12050
  auto handle = reinterpret_cast<cublasHandle_t>(cublas_handle);
  cudaDataType_t cuda_data_type = (out_dtype == torch::kHalf ? CUDA_R_16F : CUDA_R_16BF);
  auto status = cublasGemmGroupedBatchedEx(
      handle,
      transa_array.data(),
      transb_array.data(),
      m_array.data(),
      n_array.data(),
      k_array.data(),
      alpha_array.data(),
      (void**)d_a.data_ptr(),
      cuda_data_type,
      lda_array.data(),
      (void**)d_b.data_ptr(),
      cuda_data_type,
      ldb_array.data(),
      beta_array.data(),
      (void**)d_c.data_ptr(),
      cuda_data_type,
      ldc_array.data(),
      group_count,
      group_size.data(),
      CUBLAS_COMPUTE_32F);
  TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "cublas grouped gemm failed: ", cublasGetStatusString(status));
  TORCH_CHECK(cudaStreamSynchronize(stream) == cudaSuccess, "Failed to synchronize the CUDA stream");
  return;
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "Cublas GroupGemm requires CUDA >= 12.5 and is not available in this build; current compute capability: ",
      getSMVersion());
}
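
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, excluded from the build): a minimal example
// of how cublas_grouped_gemm might be driven from C++, assuming the caller
// creates its own cuBLAS handle and CUDA stream and passes them as integers.
// The function name `example_grouped_gemm` and the shapes below are
// hypothetical; real callers typically reuse handles owned by the framework.
#if 0
static void example_grouped_gemm() {
  const int64_t m = 16, n = 32, k = 64, groups = 4;
  auto opts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);

  std::vector<torch::Tensor> inputs, weights, outputs;
  for (int64_t g = 0; g < groups; ++g) {
    inputs.push_back(torch::randn({m, k}, opts));   // (m, k) activations
    weights.push_back(torch::randn({n, k}, opts));  // (n, k), multiplied as weight^T
    outputs.push_back(torch::empty({m, n}, opts));  // (m, n) results
  }

  cublasHandle_t handle = nullptr;
  cudaStream_t stream = nullptr;
  TORCH_CHECK(cublasCreate(&handle) == CUBLAS_STATUS_SUCCESS);
  TORCH_CHECK(cudaStreamCreate(&stream) == cudaSuccess);
  TORCH_CHECK(cublasSetStream(handle, stream) == CUBLAS_STATUS_SUCCESS);

  cublas_grouped_gemm(
      inputs,
      weights,
      outputs,
      torch::kHalf,
      reinterpret_cast<int64_t>(handle),
      reinterpret_cast<int64_t>(stream));
  // outputs[g] now holds inputs[g] @ weights[g]^T for every group g.

  TORCH_CHECK(cublasDestroy(handle) == CUBLAS_STATUS_SUCCESS);
  TORCH_CHECK(cudaStreamDestroy(stream) == cudaSuccess);
}
#endif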