# sglang.0.4.8.post1/sglang/3rdparty/amd/tuning/benchmark_moe_rocm.py

import argparse
import json
import os
import sys

import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from tqdm import tqdm
from transformers import AutoConfig

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
    fused_moe,
    get_config_file_name,
)

padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
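# When SGLANG_MOE_PADDING=1, padding_size is added to the trailing dimension of the
# benchmark weight tensors created in run_timing below.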


def main(model, tp_size, dtype: str, batches):
    method = fused_moe
    for bs in batches:
        run_grid(int(bs), model=model, method=method, tp_size=tp_size, dtype=dtype)
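

# Note: `need_split_k` is called by `prune_configs` below but is not defined in this
# listing. The definition here is a minimal sketch following the common heuristic from
# AMD's Triton GEMM tuning scripts (only split K for skinny GEMMs with a large K dim);
# since SPLIT_K is hard-coded to 1 below, the check never actually fires.
def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 512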


def prune_configs(M, N, K, configs):
    pruned_configs = []
    elemBytes_a = 2  # [DV Note] Hard-coded for float16 (2 bytes)
    elemBytes_b = 2  # [DV Note] Hard-coded for float16 (2 bytes)

    mfma = 16 if M < 32 or N < 32 else 32

    # TODO (zhanglx): figure out the boundary between large and small gemms
    large_gemm = False
    if M >= 2048 and N >= 2048:
        large_gemm = True

    for config in configs:
        BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
        BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
        BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
        num_warps = config.get("num_warps")
        matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
        # kpack = config.get("kpack")
        if matrix_instr_nonkdim > mfma:
            continue
        if mfma == 4 and BLOCK_SIZE_K < 64:
            continue
        # some layouts cannot work properly if the
        # number of elements per thread is less than 1
        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
            continue
        SPLIT_K = 1  # config.get("SPLIT_K")
        GROUP_M = config.get("GROUP_SIZE_M")
        if matrix_instr_nonkdim > BLOCK_SIZE_M or matrix_instr_nonkdim > BLOCK_SIZE_N:
            continue
        if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
            continue
        if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
            continue
        # Skip BLOCK_SIZEs that are too large compared to M/N,
        # unless the BLOCK_SIZE is already small enough
        if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
            continue
        if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
            continue
        # skip large split_k when not necessary
        if SPLIT_K != 1 and not need_split_k(M, N, K):
            continue
        # skip split_k that leads to EVEN_K = false
        leap = SPLIT_K * BLOCK_SIZE_K
        modv = K % leap
        if modv != 0:
            continue
        # skip large GROUP_M
        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
            continue
        # out of shared memory resource
        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
        LDS = (
            BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
            + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
        )
        if LDS > 65536:
            continue
        # Skip small block sizes and num_warps for large gemms.
        # For fp16 and fp8, we want to only use BLOCK_SIZE >= 64
        if large_gemm:
            if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
                continue
            if BLOCK_SIZE_K < 64:
                continue
            if num_warps < 4:
                continue

        pruned_configs.append(config)

    return pruned_configs


def union_of_list_of_dicts(l1, l2):
    result = []
    temp_list = l1.copy()
    temp_list.extend(l2)
    for myDict in temp_list:
        if myDict not in result:
            result.append(myDict)
    return result
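

# Illustrative example of the de-duplicating union above (hypothetical inputs):
#   union_of_list_of_dicts([{"a": 1}], [{"a": 1}, {"b": 2}]) == [{"a": 1}, {"b": 2}]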


def run_grid(bs, model, method, tp_size, dtype: str):
    config = AutoConfig.from_pretrained(model)

    top_k = config.num_experts_per_tok
    d_model = config.hidden_size
    model_intermediate_size = config.intermediate_size
    num_layers = config.num_hidden_layers
    hidden_states_dtype = config.torch_dtype

    if config.num_experts_per_tok:
        if config.architectures[0] == "Grok1ModelForCausalLM":
            num_total_experts = config.num_experts
        else:
            num_total_experts = config.num_local_experts
    else:
        raise ValueError(f"Unsupported Mixtral model {model}")

    # tp_size = 2
    num_warmup_calls = 10
    num_calls = 30
    num_warmup_trials = 1
    num_trials = 1

    full_configs = []

    block_m_range = [16, 32, 64, 128, 256]
    block_n_range = [16, 32, 64, 128, 256]
    block_k_range = [32, 64, 128, 256]  # MUST be >= 32
    num_warps_range = [1, 2, 4, 8]
    group_m_range = [1, 4, 8, 16, 32]
    # For now we see better perf with num_stages=0 for all gemm configs we care about,
    # but keep this explicit so that we do not forget we may need to set it to
    # other values in the future.
    num_stage_range = [2]
    waves_per_eu_range = [0, 1, 2, 4, 8]
    # 32 was removed because of a Triton compile error
    matrix_instr_nonkdim_range = [16]
    kpack_range = [1, 2]
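
    # With the ranges listed above, the Cartesian product built below yields
    # 5 * 5 * 4 * 5 * 4 * 1 * 5 * 1 * 2 = 20,000 candidate configs before pruning.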
    for block_size_m in block_m_range:
        for block_size_n in block_n_range:
            for block_size_k in block_k_range:
                for group_size_m in group_m_range:
                    for num_warps in num_warps_range:
                        for num_stages in num_stage_range:
                            for waves_per_eu in waves_per_eu_range:
                                for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
                                    for kpack in kpack_range:
                                        full_configs.append(
                                            {
                                                "BLOCK_SIZE_M": block_size_m,
                                                "BLOCK_SIZE_N": block_size_n,
                                                "BLOCK_SIZE_K": block_size_k,
                                                "GROUP_SIZE_M": group_size_m,
                                                "num_warps": num_warps,
                                                "num_stages": num_stages,
                                                "waves_per_eu": waves_per_eu,
                                                "matrix_instr_nonkdim": matrix_instr_nonkdim,
                                                "kpack": kpack,
                                            }
                                        )
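
    # The two pruning shapes below correspond to the two GEMMs inside fused_moe:
    # (M1, N1, K1) is the w1 (gate/up) projection and (M2, N2, K2) is the w2 (down)
    # projection, with the intermediate dimension sharded across tp_size.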
    M1 = bs * 2
    N1 = model_intermediate_size * 2 // tp_size
    K1 = d_model
    prune_configs_1 = prune_configs(M1, N1, K1, full_configs)

    M2 = bs * 2
    N2 = d_model
    K2 = model_intermediate_size // tp_size
    prune_configs_2 = prune_configs(M2, N2, K2, full_configs)

    configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2)
    print(
        f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | "
        f"{len(prune_configs_2)=} | {len(configs)=}"
    )

    best_config = None
    best_time_us = 1e20

    print(f"{tp_size=} {bs=}")

    for config in tqdm(configs):
        # warmup
        try:
            print(config)
            for _ in range(num_warmup_trials):
                run_timing(
                    num_calls=num_warmup_calls,
                    bs=bs,
                    d_model=d_model,
                    num_total_experts=num_total_experts,
                    top_k=top_k,
                    tp_size=tp_size,
                    model_intermediate_size=model_intermediate_size,
                    method=method,
                    config=config,
                    dtype=dtype,
                    hidden_states_dtype=hidden_states_dtype,
                )
        except triton.runtime.autotuner.OutOfResources:
            continue

        # trial
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                bs=bs,
                d_model=d_model,
                num_total_experts=num_total_experts,
                top_k=top_k,
                tp_size=tp_size,
                model_intermediate_size=model_intermediate_size,
                method=method,
                config=config,
                dtype=dtype,
                hidden_states_dtype=hidden_states_dtype,
            )

            kernel_dur_us = 1000 * kernel_dur_ms
            model_dur_ms = kernel_dur_ms * num_layers

            if kernel_dur_us < best_time_us:
                best_config = config
                best_time_us = kernel_dur_us

                tqdm.write(
                    f"{kernel_dur_us=:.1f} {model_dur_ms=:.1f}"
                    f" {bs=} {tp_size=} {top_k=} {num_total_experts=} "
                    f"{d_model=} {model_intermediate_size=} {num_layers=}"
                )
print("best_time_us", best_time_us)
print("best_config", best_config)
# holds Dict[str, Dict[str, int]]
filename = get_config_file_name(
num_total_experts,
model_intermediate_size // tp_size,
"float8" if dtype == "float8" else None,
)
print(f"writing config to file {filename}")
existing_content = {}
if os.path.exists(filename):
with open(filename, "r") as f:
existing_content = json.load(f)
existing_content[str(bs)] = best_config
with open(filename, "w") as f:
json.dump(existing_content, f, indent=4)
f.write("\n")


def run_timing(
    num_calls: int,
    bs: int,
    d_model: int,
    num_total_experts: int,
    top_k: int,
    tp_size: int,
    model_intermediate_size: int,
    method,
    config,
    dtype: str,
    hidden_states_dtype,
) -> float:
    shard_intermediate_size = model_intermediate_size // tp_size

    hidden_states = torch.rand(
        (bs, d_model),
        device="cuda:0",
        dtype=hidden_states_dtype,
    )

    w1 = torch.rand(
        (num_total_experts, 2 * shard_intermediate_size, d_model + padding_size),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    w2 = torch.rand(
        (num_total_experts, d_model, shard_intermediate_size + padding_size),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None

    if dtype == "float8":
        w1 = w1.to(torch.float8_e4m3fnuz)
        w2 = w2.to(torch.float8_e4m3fnuz)
        w1_scale = torch.ones(
            num_total_experts, device=hidden_states.device, dtype=torch.float32
        )
        w2_scale = torch.ones(
            num_total_experts, device=hidden_states.device, dtype=torch.float32
        )
        a1_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)
        a2_scale = torch.ones(1, device=hidden_states.device, dtype=torch.float32)

    gating_output = F.softmax(
        torch.rand(
            (num_calls, bs, num_total_experts),
            device=hidden_states.device,
            dtype=torch.float32,
        ),
        dim=-1,
    )
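
    # Timing: launch num_calls fused_moe invocations back to back and report the
    # average per-call time measured with CUDA events (HIP events on ROCm).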
    ##################################

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for i in range(num_calls):
        hidden_states = method(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            gating_output=gating_output[i],
            topk=top_k,
            renormalize=True,
            inplace=True,
            override_config=config,
            use_fp8=dtype == "float8",
        )
    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="benchmark_mixtral_moe",
        description="Benchmark and tune the fused_moe kernel",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["float8", "float16", "bfloat16"],
        help="Data type used for fused_moe kernel computations",
    )
    parser.add_argument("--model", type=str, default="hpcai-tech/grok-1")
    parser.add_argument("--tp-size", type=int, default=2, help="Tensor parallel size")
    parser.add_argument("-b", "--batches", type=str)

    args = parser.parse_args()

    batches = args.batches.split(",")

    sys.exit(main(args.model, args.tp_size, args.dtype, batches))
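
# Example invocation (the model name and batch list here are only illustrative):
#   python benchmark_moe_rocm.py --model hpcai-tech/grok-1 --tp-size 8 --dtype float16 -b 1,2,4,8,16,32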