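"""Benchmark FlashInfer sampling renormalization kernels.

Measures median GPU latency and effective memory bandwidth of
flashinfer.sampling.top_p_renorm_probs, top_k_renorm_probs, and
top_k_mask_logits across vocabulary sizes, batch sizes, input noise
distributions, and top-p/top-k thresholds. Requires a CUDA device.
"""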
import numpy as np
import torch

import flashinfer
from flashinfer.testing.utils import bench_gpu_time


def normal_distribution(std):
    """Return a Gaussian noise generator with standard deviation ``std``."""

    def normal_noise(shape, device):
        return torch.randn(shape, device=device) * std

    # __name__ is set so the benchmark output labels each distribution readably.
    normal_noise.__name__ = f"normal_distribution(std={std})"
    return normal_noise


def gumbel_distribution(beta):
    """Return a Gumbel-shaped noise generator via inverse-CDF sampling."""

    def gumbel_noise(shape, device):
        U = torch.rand(shape, device=device)
        eps = 1e-20  # guards both logs against evaluating log(0)
        return torch.log(-torch.log(U + eps) + eps) / beta

    gumbel_noise.__name__ = f"gumbel_distribution(beta={beta})"
    return gumbel_noise


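# Note on gumbel_noise: inverse-CDF sampling of the standard Gumbel
# distribution, F(x) = exp(-exp(-x)), gives x = -log(-log(U)) for
# U ~ Uniform(0, 1). The expression above drops the leading minus sign,
# yielding a mirrored (min-)Gumbel scaled by 1/beta; presumably immaterial
# here, since the benchmark only needs inputs whose softmax has a
# Gumbel-like concentration profile.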
@torch.inference_mode()
def main():
    torch.manual_seed(42)
    print("---")
    print("top-p renorm")
    for vocab_size in [128512]:
        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
            for distrib in [
                normal_distribution(1),
                normal_distribution(5),
                gumbel_distribution(0.1),
                gumbel_distribution(1),
            ]:
                for p in [0.1, 0.5, 0.9, 1.0]:
                    logits = distrib((batch_size, vocab_size), device="cuda")
                    probs = torch.softmax(logits, dim=-1)
                    measurements = bench_gpu_time(
                        lambda: flashinfer.sampling.top_p_renorm_probs(probs, p),
                        dry_run_time_ms=100,
                        repeat_time_ms=1000,
                    )
                    ms = np.median(measurements)

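                    # Effective-bandwidth model: assume the kernel reads the
                    # input tensor once and writes an equally sized output
                    # once, i.e. 2x the tensor's bytes of traffic;
                    # bytes * 1e-6 / ms is then GB/s (1e9 B/GB, 1e3 ms/s).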
                    io = (probs.numel() * probs.element_size()) * 2
                    bandwidth = io * 1e-6 / ms
                    print(
                        f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
                    )

print("---")
|
|
print("top-k renorm")
|
|
for vocab_size in [128512]:
|
|
for batch_size in [1, 16, 32, 64, 128, 256, 512]:
|
|
for distrib in [
|
|
normal_distribution(1),
|
|
normal_distribution(5),
|
|
gumbel_distribution(0.1),
|
|
gumbel_distribution(1),
|
|
]:
|
|
for k in [10, 100, 1000, 5000]:
|
|
logits = distrib((batch_size, vocab_size), device="cuda")
|
|
probs = torch.softmax(logits, dim=-1)
|
|
measurements = bench_gpu_time(
|
|
lambda: flashinfer.sampling.top_k_renorm_probs(probs, k),
|
|
dry_run_time_ms=100,
|
|
repeat_time_ms=1000,
|
|
)
|
|
ms = np.median(measurements)
|
|
|
|
io = (probs.numel() * probs.element_size()) * 2
|
|
bandwidth = io * 1e-6 / ms
|
|
print(
|
|
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
|
|
)
|
|
|
|
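    # Unlike the renorm kernels above, top_k_mask_logits operates on raw
    # logits, masking entries outside the top-k to -inf, so no softmax pass
    # is needed before benchmarking it.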
print("---")
|
|
print("top-k mask logits")
|
|
for vocab_size in [128512]:
|
|
for batch_size in [1, 16, 32, 64, 128, 256, 512]:
|
|
for distrib in [
|
|
normal_distribution(1),
|
|
normal_distribution(5),
|
|
gumbel_distribution(0.1),
|
|
gumbel_distribution(1),
|
|
]:
|
|
for k in [10, 100, 1000, 5000]:
|
|
logits = distrib((batch_size, vocab_size), device="cuda")
|
|
measurements = bench_gpu_time(
|
|
lambda: flashinfer.sampling.top_k_mask_logits(logits, k),
|
|
dry_run_time_ms=100,
|
|
repeat_time_ms=1000,
|
|
)
|
|
ms = np.median(measurements)
|
|
|
|
io = (logits.numel() * logits.element_size()) * 2
|
|
bandwidth = io * 1e-6 / ms
|
|
print(
|
|
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
    main()