"""
|
|
Copyright (c) 2024 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import argparse
import pprint

import numpy as np
import torch
from torch.nn import functional as F

import flashinfer.fused_moe as fused_moe
from flashinfer import fp4_quantize
from flashinfer.autotuner import AutoTuner, autotune, get_config_path
from flashinfer.testing.utils import bench_gpu_time

FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
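
# Two-level FP4 scaling, as used throughout this benchmark: fp4_quantize emits
# packed e2m1 values plus one float8_e4m3 scale per 16-element block
# (quant_blocksize below), and each tensor additionally carries a float32
# global scale of the form FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax so that
# the block scales themselves stay within the e4m3 range.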


test_configs = [
    {
        "hidden_size": 7168,
        "num_experts": 256,
        "top_k": 8,
        "intermediate_size": 256,
    },
    {
        "hidden_size": 7168,
        "num_experts": 32,
        "top_k": 8,
        "intermediate_size": 2048,
    },
]
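
# The first config above stresses many small experts (256 experts,
# intermediate size 256); the second uses fewer but larger experts
# (32 experts, intermediate size 2048) at the same hidden size.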


def compute_routing(
    router_logits: torch.Tensor, top_k: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute routing weights and selected experts from router logits.

    Args:
        router_logits (torch.Tensor): Router logits of shape [batch_size, num_experts]
        top_k (int): Number of experts to route to per token

    Returns:
        tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - routing_weights: Expert weights of shape [batch_size, top_k]
            - selected_experts: Expert indices of shape [batch_size, top_k]
    """
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    # Renormalize so the selected experts' weights sum to 1 for each token.
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    return routing_weights, selected_experts
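
# For example, router_logits of shape [4, 8] with top_k=2 yields
# routing_weights of shape [4, 2] (each row summing to 1) and
# selected_experts of shape [4, 2] holding expert indices in [0, 8).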


def bench_cutlass_fused_moe(
    batch_size,
    hidden_size,
    num_experts,
    top_k,
    intermediate_size,
    skip_autotune,
):
    torch.manual_seed(42)
    quant_blocksize = 16
    round_up = lambda x, y: (x + y - 1) // y * y
    e = num_experts
    m = batch_size
    n = intermediate_size
    k = hidden_size
    otype = torch.bfloat16
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=otype) / 10
    # w1 with its two n-row halves swapped, matching the operand order used by
    # the CUTLASS path.
    w1_cutlass = torch.cat((w1[:, n:, :], w1[:, :n, :]), dim=1).contiguous()

    # Block-scale buffers use padded shapes: rows rounded up to a multiple of
    # 128 and scale columns rounded up to a multiple of 4.
    sf_w1_2n = round_up(2 * n, 128)
    sf_w1_k = round_up(k // quant_blocksize, 4)
    w1_blockscale = torch.empty(
        (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn
    )
    w1_blockscale_cutlass = torch.empty(
        (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn
    )

    w2 = torch.randn((e, k, n), device="cuda", dtype=otype) / 10
    sf_w2_k = round_up(k, 128)
    sf_w2_n = round_up(n // quant_blocksize, 4)
    w2_blockscale = torch.empty(
        (e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn
    )
    # Packed FP4 weights: two e2m1 values per uint8 byte, hence the halved
    # trailing dimension.
    w1_q = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
    w1_q_cutlass = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
    w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
    w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
    w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
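
    # Per-expert quantization: fp4_quantize returns the packed e2m1 tensor and
    # its float8_e4m3 block scales, one scale per quant_blocksize elements.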
    for expert in range(e):
        # amax is taken over the whole weight tensor, so every expert ends up
        # with the same global scale here.
        w1_amax = torch.abs(w1).max().to(torch.float32)
        w2_amax = torch.abs(w2).max().to(torch.float32)
        w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
        w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax

        w1_q[expert], w1_blockscale[expert] = fp4_quantize(w1[expert], w1_gs[expert])
        w1_q_cutlass[expert], w1_blockscale_cutlass[expert] = fp4_quantize(
            w1_cutlass[expert], w1_gs[expert]
        )
        w2_q[expert], w2_blockscale[expert] = fp4_quantize(w2[expert], w2_gs[expert])

    x = torch.randn(m, k, dtype=otype).cuda()
    # Unit activation scales keep the benchmark simple; a calibrated scale
    # would be FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / x.abs().max(), as for the
    # weights above.
    a1_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    a2_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    router_logits = torch.randn(m, e, dtype=otype).cuda()
    routing_weights, selected_experts = compute_routing(router_logits, top_k)

    flash_output = torch.zeros_like(x)

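    # Six quant scales, grouped per GEMM: (activation global scale, weight
    # block scales viewed as int32, combined dequant scale 1/(act_gs * w_gs))
    # for the first and then the second GEMM.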
    quant_scales = [
        a1_gs,
        w1_blockscale.view(torch.int32),
        1.0 / (a1_gs * w1_gs),
        a2_gs,
        w2_blockscale.view(torch.int32),
        1.0 / (a2_gs * w2_gs),
    ]
    hidden_states, input_sf = fp4_quantize(x, a1_gs)

    # Warmup: a few untimed iterations so one-time setup does not skew the
    # measurements below.
    for _ in range(3):
        _ = fused_moe.cutlass_fused_moe(
            hidden_states,
            selected_experts.to(torch.int),
            routing_weights,
            w1_q.contiguous().view(torch.long),
            w2_q.contiguous().view(torch.long),
            otype,
            quant_scales=quant_scales,
            input_sf=input_sf,
            output=flash_output,
            tune_max_num_tokens=16384,
        )

    if not skip_autotune:
        # One pass under autotune(True) lets the AutoTuner profile candidate
        # tactics and populate its cache before the timed runs.
        with torch.inference_mode(), autotune(True):
            _ = fused_moe.cutlass_fused_moe(
                hidden_states,
                selected_experts.to(torch.int),
                routing_weights,
                w1_q.contiguous().view(torch.long),
                w2_q.contiguous().view(torch.long),
                otype,
                quant_scales=quant_scales,
                input_sf=input_sf,
                output=flash_output,
                tune_max_num_tokens=16384,
            )
    ms_list = bench_gpu_time(
        lambda: fused_moe.cutlass_fused_moe(
            hidden_states,
            selected_experts.to(torch.int),
            routing_weights,
            w1_q.contiguous().view(torch.long),
            w2_q.contiguous().view(torch.long),
            otype,
            quant_scales=quant_scales,
            input_sf=input_sf,
            output=flash_output,
        ),
    )
    median_ms = np.median(ms_list)
    print(f"{'input':<15} {'weight1':<20} {'weight2':<20} {'time(ms)'}")
    print(
        f"{str(tuple(hidden_states.shape)):<15} {str(tuple(w1.shape)):<20} {str(tuple(w2.shape)):<20} {median_ms:.3f}"
    )
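

# Typical invocations (script filename assumed for illustration; an
# FP4-capable CUDA GPU is required):
#   python bench_cutlass_fused_moe.py --num-tokens 32 --skip-autotune
#   python bench_cutlass_fused_moe.py --update-config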
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument(
|
|
"--update-config",
|
|
action="store_true",
|
|
help="Update the config file with the new profiling results",
|
|
)
|
|
parser.add_argument(
|
|
"--num-tokens", type=int, default=32, help="Number of tokens to profile"
|
|
)
|
|
parser.add_argument("--skip-autotune", action="store_true", help="Skip autotuning")
|
|
args = parser.parse_args()
|
|
AutoTuner.get().clear_cache()
|
|
|
|
    for config in test_configs:
        bench_cutlass_fused_moe(
            args.num_tokens,
            config["hidden_size"],
            config["num_experts"],
            config["top_k"],
            config["intermediate_size"],
            args.skip_autotune,
        )

    configs = AutoTuner.get().profiling_cache
    if args.update_config and configs:
        # The original key holds a runner hash in k[2], which can differ across
        # machines, so it is dropped here; v[0] and v[1] are the runner id and
        # the tactic.
        converted = {str((k[0], k[1], k[3])): (v[0], v[1]) for k, v in configs.items()}
        config_path = get_config_path(is_module=False)
        with open(config_path, "w") as f:
            f.write("best_configs = ")
            pprint.pprint(converted, stream=f)
        print(f"Saved the cache to {config_path}")