sglang_v0.5.2/flashinfer_0.3.1/benchmarks/bench_cutlass_fused_moe.py


"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import argparse
import pprint

import numpy as np
import torch
from torch.nn import functional as F

import flashinfer.fused_moe as fused_moe
from flashinfer import fp4_quantize
from flashinfer.autotuner import AutoTuner, autotune, get_config_path
from flashinfer.testing.utils import bench_gpu_time

FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
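
# MoE problem shapes to benchmark: model hidden size, number of experts,
# experts activated per token (top_k), and FFN intermediate size.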
test_configs = [
    {
        "hidden_size": 7168,
        "num_experts": 256,
        "top_k": 8,
        "intermediate_size": 256,
    },
    {
        "hidden_size": 7168,
        "num_experts": 32,
        "top_k": 8,
        "intermediate_size": 2048,
    },
]


def compute_routing(
    router_logits: torch.Tensor, top_k: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute routing weights and selected experts from router logits.

    Args:
        router_logits (torch.Tensor): Router logits of shape [batch_size, num_experts]
        top_k (int): Number of experts to route to per token

    Returns:
        tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - routing_weights: Expert weights of shape [batch_size, top_k]
            - selected_experts: Expert indices of shape [batch_size, top_k]
    """
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.float()
    return routing_weights, selected_experts


def bench_cutlass_fused_moe(
    batch_size,
    hidden_size,
    num_experts,
    top_k,
    intermediate_size,
    skip_autotune,
):
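    """Quantize one MoE layer's weights and activations to FP4 and benchmark fused_moe.cutlass_fused_moe."""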
    torch.manual_seed(42)
    quant_blocksize = 16
    round_up = lambda x, y: (x + y - 1) // y * y
    e = num_experts
    m = batch_size
    n = intermediate_size
    k = hidden_size
    otype = torch.bfloat16
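    # Random bf16 reference weights: w1 ([e, 2n, k]) fuses the two FFN input
    # projections, w2 ([e, k, n]) is the output projection; w1_cutlass holds the
    # same weights with the two halves of dim 1 swapped.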
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=otype) / 10
    w1_cutlass = torch.cat((w1[:, n:, :], w1[:, :n, :]), dim=1).contiguous()

    sf_w1_2n = round_up(2 * n, 128)
    sf_w1_k = round_up(k // quant_blocksize, 4)
    w1_blockscale = torch.empty(
        (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn
    )
    w1_blockscale_cutlass = torch.empty(
        (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn
    )

    w2 = torch.randn((e, k, n), device="cuda", dtype=otype) / 10
    sf_w2_k = round_up(k, 128)
    sf_w2_n = round_up(n // quant_blocksize, 4)
    w2_blockscale = torch.empty(
        (e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn
    )

    w1_q = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
    w1_q_cutlass = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
    w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
    w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
    w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
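    # Quantize each expert's weights to packed FP4 with FP8 block scales. The global
    # scales w1_gs/w2_gs are derived from the amax of the full weight tensors, so
    # every expert ends up with the same global scale here.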
    for expert in range(e):
        w1_amax = torch.abs(w1).max().to(torch.float32)
        w2_amax = torch.abs(w2).max().to(torch.float32)
        w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
        w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
        w1_q[expert], w1_blockscale[expert] = fp4_quantize(w1[expert], w1_gs[expert])
        w1_q_cutlass[expert], w1_blockscale_cutlass[expert] = fp4_quantize(
            w1_cutlass[expert], w1_gs[expert]
        )
        w2_q[expert], w2_blockscale[expert] = fp4_quantize(w2[expert], w2_gs[expert])
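    # Random input activations. The amax-based a1_gs is computed first but then
    # overridden with 1.0, so both activation global scales are fixed at 1.0.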
    x = torch.randn(m, k, dtype=otype).cuda()
    a1_gs = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(x).max().to(
        torch.float32
    ).cuda()
    a1_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    a2_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32)

    router_logits = torch.randn(m, e, dtype=otype).cuda()
    routing_weights, selected_experts = compute_routing(router_logits, top_k)
    flash_output = torch.zeros_like(x)
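    # Quantization scales passed to cutlass_fused_moe: activation global scale,
    # w1 block scales (viewed as int32), and 1/(a1_gs * w1_gs) for the first GEMM,
    # followed by the corresponding three entries for the second GEMM.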
    quant_scales = [
        a1_gs,
        w1_blockscale.view(torch.int32),
        1.0 / (a1_gs * w1_gs),
        a2_gs,
        w2_blockscale.view(torch.int32),
        1.0 / (a2_gs * w2_gs),
    ]
    hidden_states = x
    hidden_states, input_sf = fp4_quantize(x, a1_gs)
    # Warmup
    for _ in range(3):
        _ = fused_moe.cutlass_fused_moe(
            hidden_states,
            selected_experts.to(torch.int),
            routing_weights,
            w1_q.contiguous().view(torch.long),
            w2_q.contiguous().view(torch.long),
            otype,
            quant_scales=quant_scales,
            input_sf=input_sf,
            output=flash_output,
            tune_max_num_tokens=16384,
        )
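    # Optionally let the autotuner search kernel tactics before the timed run.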
    if not skip_autotune:
        with torch.inference_mode(), autotune(True):
            _ = fused_moe.cutlass_fused_moe(
                hidden_states,
                selected_experts.to(torch.int),
                routing_weights,
                w1_q.contiguous().view(torch.long),
                w2_q.contiguous().view(torch.long),
                otype,
                quant_scales=quant_scales,
                input_sf=input_sf,
                output=flash_output,
                tune_max_num_tokens=16384,
            )
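    # Time the kernel; bench_gpu_time returns per-iteration timings in milliseconds.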
    ms_list = bench_gpu_time(
        lambda: fused_moe.cutlass_fused_moe(
            hidden_states,
            selected_experts.to(torch.int),
            routing_weights,
            w1_q.contiguous().view(torch.long),
            w2_q.contiguous().view(torch.long),
            otype,
            quant_scales=quant_scales,
            input_sf=input_sf,
            output=flash_output,
        ),
    )
    median_ms = np.median(ms_list)

    print(f"{'input':<15} {'weight1':<20} {'weight2':<20} {'time(ms)'}")
    print(
        f"{str(tuple(hidden_states.shape)):<15} {str(tuple(w1.shape)):<20} {str(tuple(w2.shape)):<20} {median_ms:.3f}"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--update-config",
        action="store_true",
        help="Update the config file with the new profiling results",
    )
    parser.add_argument(
        "--num-tokens", type=int, default=32, help="Number of tokens to profile"
    )
    parser.add_argument("--skip-autotune", action="store_true", help="Skip autotuning")
    args = parser.parse_args()

    AutoTuner.get().clear_cache()
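    # Benchmark every configuration; autotuning results accumulate in the profiling cache.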
    for config in test_configs:
        bench_cutlass_fused_moe(
            args.num_tokens,
            config["hidden_size"],
            config["num_experts"],
            config["top_k"],
            config["intermediate_size"],
            args.skip_autotune,
        )
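    # Optionally persist the best configs found by the autotuner for reuse in later runs.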
    configs = AutoTuner.get().profiling_cache
    if args.update_config and configs:
        # The original key contains a runner hash in k[2], which can differ across
        # machines, so we drop it for now; v[0] and v[1] are the runner id and the tactic.
        converted = {str((k[0], k[1], k[3])): (v[0], v[1]) for k, v in configs.items()}
        config_path = get_config_path(is_module=False)
        with open(config_path, "w") as f:
            f.write("best_configs = ")
            pprint.pprint(converted, stream=f)
        print(f"Saved the cache to {config_path}")