/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/
|
|
|
|
#pragma once
|
|
|
|
#include <algorithm>
#include <cmath>
#include <vector>
|
|
|
|
// Heuristic deciding whether PackGQA should be used.
//
// Parameters:
//   varlen_q        - when true the function returns true unconditionally, because only
//                     max_seqlen_q (not the real seqlen_q) is known in that case.
//   seqlen_q        - query sequence length; only consulted when varlen_q is false.
//   qhead_per_khead - factor by which seqlen_q is multiplied to get the packed row count
//                     (presumably the number of query heads per KV head, per its name).
//   blockM          - block size the row counts are rounded up to; must be non-zero
//                     (it is used as a divisor).
//
// Returns true when the non-packed tile efficiency (rows / rows rounded up to blockM)
// is below 90% of the packed tile efficiency.
inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) {
    // If varlen, we don't actually know seqlen_q but only max_seqlen_q.
    if (varlen_q) return true;
    // Heuristic: PackGQA is a bit slower but can help if seqlen_q is small or not near a multiple of kBlockM
    // The arithmetic is done in 64 bits so that seqlen_q * qhead_per_khead and the
    // round-up addition cannot overflow a 32-bit int.
    auto round_up = [](long long a, long long b) { return (a + b - 1) / b * b; };
    const long long packed_rows = static_cast<long long>(seqlen_q) * qhead_per_khead;
    const float nopack_gqa_efficiency = float(seqlen_q) / float(round_up(seqlen_q, blockM));
    const float pack_gqa_efficiency = float(packed_rows) / float(round_up(packed_rows, blockM));
    return nopack_gqa_efficiency < 0.9 * pack_gqa_efficiency;
}
|
|
|
|
// Find the number of splits that maximizes the occupancy. For example, if we have
// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
// splits as that would incur more HBM reads/writes.
// So we find the best efficiency, then find the smallest number of splits that gets 85%
// of the best efficiency.
//
// Parameters:
//   batch_nheads_mblocks - amount of work available without splitting; compared against
//                          num_SMs and multiplied by the candidate split count.
//   num_SMs              - number of SMs used to compute the number of waves.
//   num_n_blocks         - upper bound on the split count; values <= 4 disable splitting.
//   max_splits           - caller-imposed upper bound on the split count.
//
// Returns the chosen number of splits; always at least 1.
inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
    // If we have enough to almost fill the SMs, then just use 1 split
    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
    // If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512.
    if (num_n_blocks <= 4) { return 1; }
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
    // With at most one candidate the answer is always 1. Returning here also keeps a
    // non-positive bound from reaching vector::reserve, where a negative int would
    // convert to a huge unsigned size and throw std::length_error.
    if (max_splits <= 1) { return 1; }
    float max_efficiency = 0.f;
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        // Fraction of the last wave that is actually occupied: n_waves / ceil(n_waves).
        const float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
        const float eff = n_waves / std::ceil(n_waves);
        // printf("num_splits = %d, eff = %f\n", num_splits, eff);
        if (eff > max_efficiency) { max_efficiency = eff; }
        efficiency.push_back(eff);
    }
    // Pick the smallest split count whose efficiency is within 85% of the best one.
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
            // printf("num_splits chosen = %d\n", num_splits);
            return num_splits;
        }
    }
    return 1;
}
|