# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/93e1a2634e22355b0856246b032b285ad1d1da6b/tests/test_sampling.py
import pytest
import sgl_kernel
import torch


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
    torch.manual_seed(42)
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")
    max_top_k_trails = 32
    eps = 1e-4
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    # top-p mask
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    cdf = torch.cumsum(sorted_prob, dim=-1)
    mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
    mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int())
    # top-k mask
    sorted_prob, _ = torch.sort(normalized_prob, descending=True)
    pivot = sorted_prob[:, k - 1]
    mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int()
    # overall mask: a token is admissible only if it survives both filters
    mask = torch.minimum(mask_top_p, mask_top_k)
    uniform_samples = torch.empty(max_top_k_trails, batch_size, dtype=torch.float32).to(
        0
    )
    top_p_tensor = torch.full((batch_size,), p).to(0)
    top_k_tensor = torch.full((batch_size,), k).to(0)

    num_trails = 1000
    for _ in range(num_trails):
        uniform_samples.uniform_()
        samples, success = sgl_kernel.top_k_top_p_sampling_from_probs(
            normalized_prob,
            uniform_samples,
            top_k_tensor,
            top_p_tensor,
            filter_apply_order="joint",
        )
        assert torch.all(success)
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
        assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
            torch.arange(batch_size), samples
        ]


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5, 0.9])
def test_top_p_renorm_probs(batch_size, vocab_size, p):
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    cdf = torch.cumsum(sorted_prob, dim=-1)
    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
    mask.scatter_add_(1, indices, (cdf >= (1 - p)).int())
    # Build the reference result on a copy so the kernel input is not mutated.
    renorm_prob_ground_truth = normalized_prob.clone()
    renorm_prob_ground_truth[mask == 0] = 0
    renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
        dim=-1, keepdim=True
    )

    renorm_prob = sgl_kernel.top_p_renorm_prob(normalized_prob, p)
    torch.testing.assert_close(
        renorm_prob_ground_truth,
        renorm_prob,
        rtol=1e-3,
        atol=1e-3,
    )


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("k", [10, 100, 500])
def test_top_k_renorm_probs(batch_size, vocab_size, k):
    if k > vocab_size:
        pytest.skip("k should not be greater than vocab_size")
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, _ = torch.sort(normalized_prob, descending=True)
    pivot = sorted_prob[:, k - 1]
    mask = (normalized_prob >= pivot.unsqueeze(-1)).int()
    # Build the reference result on a copy so the kernel input is not mutated.
    renorm_prob_ground_truth = normalized_prob.clone()
    renorm_prob_ground_truth[mask == 0] = 0
    renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
        dim=-1, keepdim=True
    )
    renorm_prob = sgl_kernel.top_k_renorm_prob(normalized_prob, k)
    torch.testing.assert_close(
        renorm_prob_ground_truth,
        renorm_prob,
        rtol=1e-3,
        atol=1e-3,
    )


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1])
def test_min_p_sampling(batch_size, vocab_size, p):
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    # scale min-p
    top_probs = sorted_prob[:, -1].unsqueeze(-1)
    scaled_p = p * top_probs
    # min-p mask
    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
    mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int())
    uniform_samples = torch.empty(batch_size, dtype=torch.float32).to(0)
    min_p_tensor = torch.full((batch_size,), p).to(0)

    num_trails = 1000
    for _ in range(num_trails):
        uniform_samples.uniform_()
        samples = sgl_kernel.min_p_sampling_from_probs(
            normalized_prob,
            uniform_samples,
            min_p_tensor,
        )

        assert torch.all(mask[torch.arange(batch_size), samples] == 1), samples[
            torch.nonzero(mask[torch.arange(batch_size), samples] == 0)
        ]


if __name__ == "__main__":
    pytest.main([__file__])