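"""Benchmark FlashInfer sampling kernels.

Measures GPU latency and effective memory bandwidth of the
``flashinfer.sampling`` entry points (plain, top-k, top-p, and
logits-based sampling) across batch sizes, input distributions, and
deterministic/non-deterministic modes.
"""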
import numpy as np
import torch

import flashinfer
from flashinfer.testing.utils import bench_gpu_time


def normal_distribution(std):
    """Return a generator of Gaussian noise with standard deviation ``std``."""

    def normal_noise(shape, device):
        return torch.randn(shape, device=device) * std

    # Override __name__ so the benchmark log identifies the distribution.
    normal_noise.__name__ = f"normal_distribution(std={std})"
    return normal_noise


def gumbel_distribution(beta):
    """Return a generator of Gumbel-style noise with scale ``1 / beta``."""

    def gumbel_noise(shape, device):
        U = torch.rand(shape, device=device)
        eps = 1e-20  # guards against log(0) at the edges of the uniform draw
        return torch.log(-torch.log(U + eps) + eps) / beta

    gumbel_noise.__name__ = f"gumbel_distribution(beta={beta})"
    return gumbel_noise


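# Each wrapper below reseeds the RNG before launching a kernel, so every
# timed repetition inside bench_gpu_time draws identical random numbers and
# the sampled tokens (and hence the measured work) stay reproducible.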
def init_seed_sampling(*args, **kwargs):
    torch.manual_seed(42)
    return flashinfer.sampling.sampling_from_probs(*args, **kwargs)


def init_seed_sampling_from_logits(*args, **kwargs):
    torch.manual_seed(42)
    return flashinfer.sampling.sampling_from_logits(*args, **kwargs)


def init_seed_sampling_from_softmax_logits(logits, *args, **kwargs):
    # Baseline that materializes the softmax explicitly before sampling,
    # for comparison against the direct logits path above.
    torch.manual_seed(42)
    return flashinfer.sampling.sampling_from_probs(
        torch.softmax(logits, dim=-1), *args, **kwargs
    )


def init_seed_top_k_sampling(*args, **kwargs):
    torch.manual_seed(42)
    return flashinfer.sampling.top_k_sampling_from_probs(*args, **kwargs)


def init_seed_top_p_sampling(*args, **kwargs):
    torch.manual_seed(42)
    return flashinfer.sampling.top_p_sampling_from_probs(*args, **kwargs)


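# For every configuration the loop below reports the median kernel time and
# an "effective bandwidth": bytes of probabilities/logits read plus sampled
# token ids written, divided by the elapsed time. For example, at
# batch_size=128 and vocab_size=128512 the fp32 input alone is
# 128 * 128512 * 4 B ~= 65.8 MB per call.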
@torch.inference_mode()
def main():
    print("---")
    print("naive sampling")
    for vocab_size in [128512]:
        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
            for distrib in [
                normal_distribution(1),
                normal_distribution(5),
                gumbel_distribution(0.1),
                gumbel_distribution(1),
            ]:
                for deterministic in [True, False]:
                    logits = distrib((batch_size, vocab_size), device="cuda")
                    probs = torch.softmax(logits, dim=-1)
                    # `samples` is not passed to the kernel here; it only
                    # accounts for the int32 token ids written per call in
                    # the bandwidth estimate below.
                    samples = torch.zeros(
                        batch_size, dtype=torch.int32, device=probs.device
                    )
                    measurements = bench_gpu_time(
                        lambda: init_seed_sampling(probs, deterministic=deterministic),
                        dry_run_time_ms=100,
                        repeat_time_ms=1000,
                    )
                    ms = np.median(measurements)

                    io = (
                        probs.numel() * probs.element_size()
                        + samples.numel() * samples.element_size()
                    )
                    bandwidth = io * 1e-6 / ms  # bytes * 1e-6 / ms == GB/s
                    print(
                        f"vocab_size: {vocab_size}, batch_size: {batch_size}, "
                        f"distrib: {distrib.__name__}, deterministic: {deterministic}, "
                        f"duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
                    )

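    # Top-k sampling restricts the draw to the k highest-probability tokens;
    # sweeping k from 10 to 5000 shows how the filtering cost scales.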
print("---")
|
|
print("top-k sampling")
|
|
for vocab_size in [128512]:
|
|
for batch_size in [1, 16, 32, 64, 128, 256, 512]:
|
|
for distrib in [
|
|
normal_distribution(1),
|
|
normal_distribution(5),
|
|
gumbel_distribution(0.1),
|
|
gumbel_distribution(1),
|
|
]:
|
|
for deterministic in [True, False]:
|
|
for k in [10, 100, 1000, 5000]:
|
|
logits = distrib((batch_size, vocab_size), device="cuda")
|
|
probs = torch.softmax(logits, dim=-1)
|
|
samples = torch.zeros(
|
|
batch_size, dtype=torch.int32, device=probs.device
|
|
)
|
|
measurements = bench_gpu_time(
|
|
lambda: init_seed_top_k_sampling(
|
|
probs, k, deterministic=deterministic
|
|
),
|
|
dry_run_time_ms=100,
|
|
repeat_time_ms=1000,
|
|
)
|
|
ms = np.median(measurements)
|
|
|
|
io = (
|
|
probs.numel() * probs.element_size()
|
|
+ samples.numel() * samples.element_size()
|
|
)
|
|
bandwidth = io * 1e-6 / ms
|
|
print(
|
|
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
|
|
)
|
|
|
|
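    # Top-p (nucleus) sampling keeps the smallest set of tokens whose
    # cumulative probability reaches p, so smaller p prunes more of the vocab.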
print("---")
|
|
print("top-p sampling")
|
|
|
|
for vocab_size in [128512]:
|
|
for batch_size in [1, 16, 32, 64, 128, 256, 512]:
|
|
for distrib in [
|
|
normal_distribution(1),
|
|
normal_distribution(5),
|
|
gumbel_distribution(0.1),
|
|
gumbel_distribution(1),
|
|
]:
|
|
for deterministic in [True, False]:
|
|
for p in [0.1, 0.5, 0.9]:
|
|
logits = distrib((batch_size, vocab_size), device="cuda")
|
|
probs = torch.softmax(logits, dim=-1)
|
|
samples = torch.zeros(
|
|
batch_size, dtype=torch.int32, device=probs.device
|
|
)
|
|
measurements = bench_gpu_time(
|
|
lambda: init_seed_top_p_sampling(
|
|
probs, p, deterministic=deterministic
|
|
),
|
|
dry_run_time_ms=100,
|
|
repeat_time_ms=1000,
|
|
)
|
|
ms = np.median(measurements)
|
|
|
|
io = (
|
|
probs.numel() * probs.element_size()
|
|
+ samples.numel() * samples.element_size()
|
|
)
|
|
bandwidth = io * 1e-6 / ms
|
|
print(
|
|
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
|
|
)
|
|
|
|
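    # Baseline: the softmax runs as a separate kernel before the sampling
    # call, so the probability tensor makes an extra round trip through
    # global memory.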
print("---")
|
|
print("sampling from softmax(logits)")
|
|
for vocab_size in [128512]:
|
|
for batch_size in [1, 16, 32, 64, 128, 256, 512]:
|
|
for distrib in [
|
|
normal_distribution(1),
|
|
normal_distribution(5),
|
|
gumbel_distribution(0.1),
|
|
gumbel_distribution(1),
|
|
]:
|
|
for deterministic in [True, False]:
|
|
logits = distrib((batch_size, vocab_size), device="cuda")
|
|
samples = torch.zeros(
|
|
batch_size, dtype=torch.int32, device=logits.device
|
|
)
|
|
measurements = bench_gpu_time(
|
|
lambda: init_seed_sampling_from_softmax_logits(
|
|
logits, samples, deterministic=deterministic
|
|
),
|
|
dry_run_time_ms=100,
|
|
repeat_time_ms=1000,
|
|
)
|
|
ms = np.median(measurements)
|
|
io = (
|
|
logits.numel() * logits.element_size()
|
|
+ samples.numel() * samples.element_size()
|
|
)
|
|
bandwidth = io * 1e-6 / ms
|
|
print(
|
|
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
|
|
)
|
|
|
|
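    # Direct path: sampling_from_logits consumes the raw logits, so this
    # benchmark never materializes a separate probability tensor first.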
print("---")
|
|
print("sampling from logits")
|
|
for vocab_size in [128512]:
|
|
for batch_size in [1, 16, 32, 64, 128, 256, 512]:
|
|
for distrib in [
|
|
normal_distribution(1),
|
|
normal_distribution(5),
|
|
gumbel_distribution(0.1),
|
|
gumbel_distribution(1),
|
|
]:
|
|
for deterministic in [True, False]:
|
|
logits = distrib((batch_size, vocab_size), device="cuda")
|
|
samples = torch.zeros(
|
|
batch_size, dtype=torch.int32, device=logits.device
|
|
)
|
|
measurements = bench_gpu_time(
|
|
lambda: init_seed_sampling_from_logits(
|
|
logits, samples, deterministic=deterministic
|
|
),
|
|
dry_run_time_ms=100,
|
|
repeat_time_ms=1000,
|
|
)
|
|
ms = np.median(measurements)
|
|
|
|
io = (
|
|
logits.numel() * logits.element_size()
|
|
+ samples.numel() * samples.element_size()
|
|
)
|
|
bandwidth = io * 1e-6 / ms
|
|
print(
|
|
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, deterministic: {deterministic}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
    main()