# chatai/sglang/sgl-kernel/tests/test_sampling.py
#
# 142 lines
# 5.4 KiB
# Python

# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/93e1a2634e22355b0856246b032b285ad1d1da6b/tests/test_sampling.py
import pytest
import sgl_kernel
import torch
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
    """Joint top-k/top-p sampling may only return tokens kept by BOTH filters.

    A reference mask is built with plain torch ops: a token is admissible when
    it survives the top-p cut (with an ``eps`` slack on the cumulative sum) and
    is at least as likely as the k-th most likely token of its row. The kernel
    is then sampled repeatedly with fresh uniform draws, and every returned
    index must be in range and inside that mask.
    """
    torch.manual_seed(42)

    # Each top-p threshold is paired with a top-k budget given as a fraction
    # of the vocabulary size.
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")

    max_top_k_trials = 32
    eps = 1e-4

    raw = torch.rand(batch_size, vocab_size).to(0)
    probs = raw / raw.sum(dim=-1, keepdim=True)
    rows = torch.arange(batch_size)

    # Top-p reference: ascending sort + cumulative sum, then scatter the
    # keep-flags back to the original token positions.
    ascending, order = torch.sort(probs, descending=False)
    cdf = torch.cumsum(ascending, dim=-1)
    keep_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
    keep_top_p.scatter_add_(1, order, (cdf > (1 - p) - eps).int())

    # Top-k reference: keep everything >= the k-th largest value of the row.
    descending = torch.sort(probs, descending=True)[0]
    kth_largest = descending[:, k - 1].unsqueeze(-1)
    keep_top_k = (probs >= kth_largest).int()

    # Element-wise min of two 0/1 masks == logical AND.
    allowed = torch.minimum(keep_top_p, keep_top_k)

    uniform_samples = torch.empty(
        max_top_k_trials, batch_size, dtype=torch.float32
    ).to(0)
    top_p_tensor = torch.full((batch_size,), p).to(0)
    top_k_tensor = torch.full((batch_size,), k).to(0)

    num_trials = 1000
    for _ in range(num_trials):
        uniform_samples.uniform_()
        samples, success = sgl_kernel.top_k_top_p_sampling_from_probs(
            probs,
            uniform_samples,
            top_k_tensor,
            top_p_tensor,
            filter_apply_order="joint",
        )
        assert torch.all(success)
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
        # On failure, report the probabilities of the drawn tokens.
        assert torch.all(allowed[rows, samples] == 1), probs[rows, samples]
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5, 0.9])
def test_top_p_renorm_probs(batch_size, vocab_size, p):
    """Check ``sgl_kernel.top_p_renorm_prob`` against a torch reference.

    The reference keeps every token whose ascending-order cumulative
    probability is ``>= 1 - p``, zeroes all other tokens, and renormalizes
    each row to sum to one.
    """
    # Seed like the sibling tests so any failure is reproducible.
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    cdf = torch.cumsum(sorted_prob, dim=-1)
    # Scatter the keep-flags from sorted order back to token positions.
    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
    mask.scatter_add_(1, indices, (cdf >= (1 - p)).int())
    # Clone before masking in place: assigning through an alias would also
    # zero entries of ``normalized_prob``, and the kernel below would then be
    # fed an already-filtered, un-normalized distribution.
    renorm_prob_ground_truth = normalized_prob.clone()
    renorm_prob_ground_truth[mask == 0] = 0
    renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
        dim=-1, keepdim=True
    )
    renorm_prob = sgl_kernel.top_p_renorm_prob(normalized_prob, p)
    torch.testing.assert_close(
        renorm_prob_ground_truth,
        renorm_prob,
        rtol=1e-3,
        atol=1e-3,
    )
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("k", [10, 100, 500])
def test_top_k_renorm_probs(batch_size, vocab_size, k):
    """Check ``sgl_kernel.top_k_renorm_prob`` against a torch reference.

    The reference keeps every token at least as likely as the k-th largest
    probability of its row, zeroes the rest, and renormalizes each row.
    Configurations with ``k > vocab_size`` are skipped.
    """
    if k > vocab_size:
        pytest.skip("k should be less than vocab_size")
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, _ = torch.sort(normalized_prob, descending=True)
    # k-th largest value per row acts as the inclusive cut-off.
    pivot = sorted_prob[:, k - 1]
    mask = (normalized_prob >= pivot.unsqueeze(-1)).int()
    # Clone before masking in place: assigning through an alias would also
    # zero entries of ``normalized_prob``, and the kernel below would then be
    # fed an already-filtered, un-normalized distribution.
    renorm_prob_ground_truth = normalized_prob.clone()
    renorm_prob_ground_truth[mask == 0] = 0
    renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
        dim=-1, keepdim=True
    )
    renorm_prob = sgl_kernel.top_k_renorm_prob(normalized_prob, k)
    torch.testing.assert_close(
        renorm_prob_ground_truth,
        renorm_prob,
        rtol=1e-3,
        atol=1e-3,
    )
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1])
def test_min_p_sampling(batch_size, vocab_size, p):
    """Min-p sampling may only return tokens with prob >= p * (row max prob).

    The reference mask is computed with torch ops. The kernel is then sampled
    ``num_trials`` times with fresh uniform draws, and every sampled index
    must fall inside the mask.
    """
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    # scale min-p by each row's largest probability (last entry of the
    # ascending sort)
    top_probs = sorted_prob[:, -1].unsqueeze(-1)
    scaled_p = p * top_probs
    # min-p mask, scattered from sorted order back to token positions
    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
    mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int())
    uniform_samples = torch.empty(batch_size, dtype=torch.float32).to(0)
    # torch.full infers int64 from an int fill value (the p=1 case above);
    # pin float32 so every parametrization hands the kernel the same dtype.
    min_p_tensor = torch.full((batch_size,), p, dtype=torch.float32).to(0)
    num_trials = 1000
    for _ in range(num_trials):
        uniform_samples.uniform_()
        samples = sgl_kernel.min_p_sampling_from_probs(
            normalized_prob,
            uniform_samples,
            min_p_tensor,
        )
        # On failure, report the offending sampled token ids.
        assert torch.all(mask[torch.arange(batch_size), samples] == 1), samples[
            torch.nonzero(mask[torch.arange(batch_size), samples] == 0)
        ]
if __name__ == "__main__":
    # Propagate pytest's exit code so running this file directly reports
    # failures to the caller (e.g. CI) instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))