sglang_v0.5.2/flashinfer_0.3.1/benchmarks/bench_sampling.py

226 lines
9.0 KiB
Python

import numpy as np
import torch
import flashinfer
from flashinfer.testing.utils import bench_gpu_time
def normal_distribution(std):
    """Build a logit generator that draws i.i.d. Gaussian noise.

    Args:
        std: factor applied to standard-normal samples (the noise's std).

    Returns:
        A callable ``(shape, device) -> torch.Tensor``. Its ``__name__`` is
        set to ``"normal_distribution(std=<std>)"`` so benchmark output can
        label which distribution produced the logits.
    """

    def normal_noise(shape, device):
        standard = torch.randn(shape, device=device)
        return standard * std

    # Overwrite the closure's name so callers can print a readable label.
    normal_noise.__name__ = f"normal_distribution(std={std})"
    return normal_noise
def gumbel_distribution(beta):
    """Build a logit generator that draws i.i.d. standard Gumbel noise / ``beta``.

    A standard Gumbel sample is ``-log(-log(U))`` with ``U ~ Uniform[0, 1)``.
    Here ``beta`` acts as an inverse scale (the sample is divided by it), so a
    smaller ``beta`` yields larger-magnitude logits and a more peaked softmax.

    Args:
        beta: divisor applied to each standard Gumbel sample.

    Returns:
        A callable ``(shape, device) -> torch.Tensor``. Its ``__name__`` is set
        to ``"gumbel_distribution(beta=<beta>)"`` for labelling benchmark output.
    """

    def gumbel_noise(shape, device):
        U = torch.rand(shape, device=device)
        # eps keeps both logs finite when U == 0 or -log(U) underflows to 0.
        eps = 1e-20
        # Leading minus: without it this is log(-log U), the *negated* Gumbel
        # (heavy left tail), not the Gumbel distribution the name promises.
        return -torch.log(-torch.log(U + eps) + eps) / beta

    gumbel_noise.__name__ = f"gumbel_distribution(beta={beta})"
    return gumbel_noise
def init_seed_sampling(*args, **kwargs):
    """Reseed torch with 42, then run ``flashinfer.sampling.sampling_from_probs``.

    Every call starts from the same torch RNG state. Presumably this is so the
    sampler draws reproducibly — confirm that flashinfer uses torch's default
    generator when none is passed. All arguments are forwarded unchanged.
    """
    torch.manual_seed(42)
    sampler = flashinfer.sampling.sampling_from_probs
    return sampler(*args, **kwargs)
def init_seed_sampling_from_logits(*args, **kwargs):
    """Reseed torch, then sample directly from logits with flashinfer.

    Delegates to ``flashinfer.sampling.sampling_from_logits``. Every positional
    and keyword argument is passed through untouched.
    """
    fixed_seed = 42
    torch.manual_seed(fixed_seed)
    sample_from_logits = flashinfer.sampling.sampling_from_logits
    return sample_from_logits(*args, **kwargs)
def init_seed_sampling_from_softmax_logits(logits, *args, **kwargs):
    """Reseed torch, softmax ``logits`` over the last dim, then sample from it.

    This is the two-step baseline (torch softmax + flashinfer's
    ``sampling_from_probs``) to compare against sampling from logits directly.
    Extra arguments are forwarded after the computed probabilities.
    """
    torch.manual_seed(42)
    probs = torch.softmax(logits, dim=-1)
    sampler = flashinfer.sampling.sampling_from_probs
    return sampler(probs, *args, **kwargs)
def init_seed_top_k_sampling(*args, **kwargs):
    """Reseed torch, then delegate to ``flashinfer.sampling.top_k_sampling_from_probs``.

    Arguments are forwarded verbatim. In this file the callers pass
    ``probs, k, deterministic=...``.
    """
    seed = 42
    torch.manual_seed(seed)
    top_k_sampler = flashinfer.sampling.top_k_sampling_from_probs
    return top_k_sampler(*args, **kwargs)
def init_seed_top_p_sampling(*args, **kwargs):
    """Reseed torch, then delegate to ``flashinfer.sampling.top_p_sampling_from_probs``.

    Arguments are forwarded verbatim. In this file the callers pass
    ``probs, p, deterministic=...``.
    """
    seed = 42
    torch.manual_seed(seed)
    top_p_sampler = flashinfer.sampling.top_p_sampling_from_probs
    return top_p_sampler(*args, **kwargs)
_VOCAB_SIZES = [128512]
_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
_TOP_K_VALUES = [10, 100, 1000, 5000]
_TOP_P_VALUES = [0.1, 0.5, 0.9]


def _distributions():
    """Return a fresh list of the logit generators swept by every section."""
    return [
        normal_distribution(1),
        normal_distribution(5),
        gumbel_distribution(0.1),
        gumbel_distribution(1),
    ]


def _sweep():
    """Yield ``(vocab_size, batch_size, distrib, deterministic)`` in benchmark order."""
    for vocab_size in _VOCAB_SIZES:
        for batch_size in _BATCH_SIZES:
            for distrib in _distributions():
                for deterministic in (True, False):
                    yield vocab_size, batch_size, distrib, deterministic


def _describe(vocab_size, batch_size, distrib, deterministic):
    """Format the configuration prefix shared by every result line."""
    return (
        f"vocab_size: {vocab_size}, batch_size: {batch_size}, "
        f"distrib: {distrib.__name__}, deterministic: {deterministic}"
    )


def _output_buffer(batch_size, device):
    """Allocate an int32 tensor sized like one sample per row (I/O accounting only)."""
    return torch.zeros(batch_size, dtype=torch.int32, device=device)


def _bench_and_report(label, fn, *tensors):
    """Time ``fn`` on the GPU and print its median latency and effective bandwidth.

    Args:
        label: configuration prefix printed before the timing fields.
        fn: zero-argument callable to benchmark.
        *tensors: tensors whose combined byte size is counted as the I/O.
    """
    measurements = bench_gpu_time(fn, dry_run_time_ms=100, repeat_time_ms=1000)
    ms = np.median(measurements)
    io = sum(t.numel() * t.element_size() for t in tensors)
    # (bytes / ms) * 1e-6 == GB/s
    bandwidth = io * 1e-6 / ms
    print(
        f"{label}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
    )


@torch.inference_mode()
def main():
    """Sweep flashinfer's sampling kernels and print latency/bandwidth per config.

    Sections: naive sampling from probs, top-k, top-p, sampling from
    ``softmax(logits)``, and sampling directly from logits. Each benchmarked
    lambda is invoked synchronously inside ``_bench_and_report``, so capturing
    the loop variables by closure is safe here.
    """
    print("---")
    print("naive sampling")
    for vocab_size, batch_size, distrib, deterministic in _sweep():
        logits = distrib((batch_size, vocab_size), device="cuda")
        probs = torch.softmax(logits, dim=-1)
        samples = _output_buffer(batch_size, probs.device)
        _bench_and_report(
            _describe(vocab_size, batch_size, distrib, deterministic),
            lambda: init_seed_sampling(probs, deterministic=deterministic),
            probs,
            samples,
        )

    print("---")
    print("top-k sampling")
    for vocab_size, batch_size, distrib, deterministic in _sweep():
        for k in _TOP_K_VALUES:
            logits = distrib((batch_size, vocab_size), device="cuda")
            probs = torch.softmax(logits, dim=-1)
            samples = _output_buffer(batch_size, probs.device)
            _bench_and_report(
                f"{_describe(vocab_size, batch_size, distrib, deterministic)}, k: {k}",
                lambda: init_seed_top_k_sampling(
                    probs, k, deterministic=deterministic
                ),
                probs,
                samples,
            )

    print("---")
    print("top-p sampling")
    for vocab_size, batch_size, distrib, deterministic in _sweep():
        for p in _TOP_P_VALUES:
            logits = distrib((batch_size, vocab_size), device="cuda")
            probs = torch.softmax(logits, dim=-1)
            samples = _output_buffer(batch_size, probs.device)
            _bench_and_report(
                f"{_describe(vocab_size, batch_size, distrib, deterministic)}, p: {p}",
                lambda: init_seed_top_p_sampling(
                    probs, p, deterministic=deterministic
                ),
                probs,
                samples,
            )

    # In the two logits-based sections below, ``samples`` is used only for I/O
    # accounting and must NOT be passed to the sampler. flashinfer's
    # sampling_from_probs / sampling_from_logits take ``indices`` (a per-output
    # row map) as their second parameter. An all-zero buffer there would make
    # every output sample from row 0, so the kernel would not read the whole
    # tensor the bandwidth figure assumes.
    # Note: the softmax(logits) figure counts only logits in + samples out; the
    # intermediate probs tensor written/read by softmax is not included.
    print("---")
    print("sampling from softmax(logits)")
    for vocab_size, batch_size, distrib, deterministic in _sweep():
        logits = distrib((batch_size, vocab_size), device="cuda")
        samples = _output_buffer(batch_size, logits.device)
        _bench_and_report(
            _describe(vocab_size, batch_size, distrib, deterministic),
            lambda: init_seed_sampling_from_softmax_logits(
                logits, deterministic=deterministic
            ),
            logits,
            samples,
        )

    print("---")
    print("sampling from logits")
    for vocab_size, batch_size, distrib, deterministic in _sweep():
        logits = distrib((batch_size, vocab_size), device="cuda")
        samples = _output_buffer(batch_size, logits.device)
        _bench_and_report(
            _describe(vocab_size, batch_size, distrib, deterministic),
            lambda: init_seed_sampling_from_logits(
                logits, deterministic=deterministic
            ),
            logits,
            samples,
        )


if __name__ == "__main__":
    main()