"""
|
|
Copyright (c) 2024 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""

import pytest
import torch

import flashinfer


def normal_distribution(std):
    def normal_noise(shape, device):
        return torch.randn(shape, device=device) * std

    normal_noise.__name__ = f"normal_distribution(std={std})"
    return normal_noise


def gumbel_distribution(beta):
    def gumbel_noise(shape, device):
        U = torch.rand(shape, device=device)
        eps = 1e-20
        return torch.log(-torch.log(U + eps) + eps) / beta

    gumbel_noise.__name__ = f"gumbel_distribution(beta={beta})"
    return gumbel_noise
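
# NOTE: gumbel_noise uses the inverse-CDF trick: for U ~ Uniform(0, 1),
# -log(-log(U)) is a standard Gumbel sample, so the value returned above is a
# negated Gumbel variate scaled by 1 / beta; eps = 1e-20 only guards against
# log(0).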


@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize(
    "distribution",
    [
        normal_distribution(1),
        normal_distribution(5),
        gumbel_distribution(0.1),
    ],
)
@pytest.mark.parametrize("temperature", [1.0, 0.5, 0.1])
@pytest.mark.parametrize("temperature_arr", [True, False])
@pytest.mark.parametrize("neg_inf_input", [True, False])
def test_softmax(
    batch_size, vocab_size, distribution, temperature, temperature_arr, neg_inf_input
):
    torch.manual_seed(42)
    logits = distribution((batch_size, vocab_size), "cuda:0")
    if neg_inf_input:
        # set a random subset of the logits to -inf
        num_inf = torch.randint(0, logits.numel() - 1, (), device=logits.device).item()
        inf_idx = torch.randperm(logits.numel(), device=logits.device)[:num_inf]
        logits.view(-1).index_fill_(0, inf_idx, float("-inf"))

    if temperature_arr:
        temperature_arr = torch.full((batch_size,), temperature, device="cuda:0")
        probs = flashinfer.sampling.softmax(logits, temperature=temperature_arr)
        logits_scaled = logits / temperature_arr.unsqueeze(-1)
    else:
        probs = flashinfer.sampling.softmax(logits, temperature=temperature)
        logits_scaled = logits / temperature

    probs_ref = torch.softmax(logits_scaled, dim=-1)

    assert torch.allclose(probs, probs_ref, atol=1e-5)
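
# The *_freq tests below are Monte Carlo checks: they draw num_trials samples
# from a single probability row (broadcast across trials via `indices`) and
# require the empirical token frequencies to have cosine similarity > 0.99
# with the reference distribution.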
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize(
    "distribution",
    [
        normal_distribution(1),
        normal_distribution(5),
        gumbel_distribution(0.1),
    ],
)
@pytest.mark.parametrize("zero_ratio", [0.0, 0.5, 0.9])
def test_sampling_freq(vocab_size, distribution, zero_ratio):
    torch.manual_seed(42)
    num_trials = 5000000
    logits = distribution((1, vocab_size), "cuda:0")
    zero_indices = torch.randperm(vocab_size)[: int(vocab_size * zero_ratio)]
    logits[:, zero_indices] = -float("inf")
    probs = torch.softmax(logits, dim=-1)
    counter = torch.zeros(vocab_size, dtype=torch.int32, device=logits.device)

    samples = flashinfer.sampling.sampling_from_probs(
        probs, indices=torch.zeros(num_trials, dtype=torch.int32, device=logits.device)
    )
    counter.scatter_add_(0, samples.long(), torch.ones_like(samples))
    freq = counter.float() / num_trials

    assert torch.all(counter[zero_indices] == 0)
    similarity = torch.cosine_similarity(freq, probs)
    assert similarity > 0.99, f"similarity: {similarity}"
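
# Top-p reference mask: sort probabilities ascending, take the cumulative sum,
# and mark positions where cdf > 1 - p; scattering back through `indices`
# recovers the nucleus, i.e. the highest-probability tokens whose total mass
# reaches p.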
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize(
    "distribution",
    [
        normal_distribution(1),
        normal_distribution(5),
        gumbel_distribution(0.1),
    ],
)
@pytest.mark.parametrize("p", [0.1, 0.5, 0.9])
def test_top_p_sampling_freq(vocab_size, distribution, p):
    torch.manual_seed(42)
    logits = distribution((1, vocab_size), "cuda:0")
    probs = torch.softmax(logits, dim=-1)
    sorted_prob, indices = torch.sort(probs, descending=False)
    cdf = torch.cumsum(sorted_prob, dim=-1)
    mask = torch.zeros(1, vocab_size, dtype=torch.int32, device=logits.device)
    mask.scatter_add_(1, indices, (cdf > (1 - p)).int())

    renorm_probs = flashinfer.sampling.top_p_renorm_probs(probs, p)
    counter = torch.zeros(vocab_size, dtype=torch.int32, device=logits.device)
    num_trials = 5000000
    samples = flashinfer.sampling.top_p_sampling_from_probs(
        probs,
        p,
        indices=torch.zeros(num_trials, dtype=torch.int32, device=logits.device),
    )
    counter.scatter_add_(0, samples.long(), torch.ones_like(samples))
    freq = counter.float() / num_trials
    assert torch.all(mask[torch.arange(1), samples] == 1)
    similarity = torch.cosine_similarity(freq, renorm_probs)
    assert similarity > 0.99, f"similarity: {similarity}"
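
# Top-k reference mask: the pivot is the k-th largest probability, and every
# token with probability >= pivot is admissible (ties at the pivot are kept,
# so the mask may contain more than k tokens).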
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize(
    "distribution",
    [
        normal_distribution(1),
        normal_distribution(5),
        gumbel_distribution(0.1),
    ],
)
@pytest.mark.parametrize("k", [10, 100, 500])
def test_top_k_sampling_freq(vocab_size, distribution, k):
    if k > vocab_size:
        pytest.skip("k should be less than vocab_size")
    torch.manual_seed(42)
    logits = distribution((1, vocab_size), "cuda:0")
    probs = torch.softmax(logits, dim=-1)
    sorted_prob, _ = torch.sort(probs, descending=True)
    pivot = sorted_prob[:, k - 1]
    mask = (probs >= pivot.unsqueeze(-1)).int()

    renorm_probs = flashinfer.sampling.top_k_renorm_probs(probs, k)
    counter = torch.zeros(vocab_size, dtype=torch.int32, device=logits.device)
    num_trials = 5000000
    samples = flashinfer.sampling.top_k_sampling_from_probs(
        probs,
        k,
        indices=torch.zeros(num_trials, dtype=torch.int32, device=logits.device),
    )
    counter.scatter_add_(0, samples.long(), torch.ones_like(samples))
    freq = counter.float() / num_trials
    assert torch.all(mask[torch.arange(1), samples] == 1)
    similarity = torch.cosine_similarity(freq, renorm_probs)
    assert similarity > 0.99, f"similarity: {similarity}"


@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
def test_sampling(batch_size, vocab_size):
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)

    num_trials = 5000
    for _ in range(num_trials):
        samples = flashinfer.sampling.sampling_from_probs(normalized_prob)
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)


@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
def test_sampling_from_logits(batch_size, vocab_size):
    torch.manual_seed(42)
    logits = torch.randn(batch_size, vocab_size, device="cuda:0")
    num_trials = 5000
    for _ in range(num_trials):
        samples = flashinfer.sampling.sampling_from_logits(logits)
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)


@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize(
    "distribution",
    [
        normal_distribution(1),
        normal_distribution(5),
        gumbel_distribution(0.1),
    ],
)
def test_sampling_from_logits_freq(vocab_size, distribution):
    torch.manual_seed(42)
    num_trials = 5000000
    logits = distribution((1, vocab_size), "cuda:0")
    probs = torch.softmax(logits, dim=-1)
    counter = torch.zeros(vocab_size, dtype=torch.int32, device=logits.device)
    samples = flashinfer.sampling.sampling_from_logits(
        logits, indices=torch.zeros(num_trials, dtype=torch.int32, device=logits.device)
    )
    counter.scatter_add_(0, samples.long(), torch.ones_like(samples))
    freq = counter.float() / num_trials
    similarity = torch.cosine_similarity(freq, probs)
    assert similarity > 0.99, f"similarity: {similarity}"


@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5, 0.9])
def test_top_p_sampling(batch_size, vocab_size, p):
    torch.manual_seed(42)
    eps = 1e-4
    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    cdf = torch.cumsum(sorted_prob, dim=-1)
    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
    mask.scatter_add_(1, indices, (cdf > (1 - p) - eps).int())

    num_trials = 1000
    for _ in range(num_trials):
        samples = flashinfer.sampling.top_p_sampling_from_probs(normalized_prob, p)
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
        assert torch.all(mask[torch.arange(batch_size), samples] == 1)


@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [10, 100, 500])
def test_top_k_sampling(batch_size, vocab_size, k):
    if k > vocab_size:
        pytest.skip("k should be less than vocab_size")
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, _ = torch.sort(normalized_prob, descending=True)
    pivot = sorted_prob[:, k - 1]
    mask = (normalized_prob >= pivot.unsqueeze(-1)).int()

    num_trials = 1000
    for _ in range(num_trials):
        samples = flashinfer.sampling.top_k_sampling_from_probs(normalized_prob, k)
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
        assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
            torch.arange(batch_size), samples
        ]


@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [10, 100, 500])
def test_top_k_sampling_with_variable_k(batch_size, vocab_size, k):
    if k > vocab_size:
        pytest.skip("k should be less than vocab_size")
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, _ = torch.sort(normalized_prob, descending=True)
    # draw a per-row k in [1, k] and build the mask from each row's own pivot
    k = torch.randint(1, k + 1, (batch_size,), device="cuda:0")
    pivot = sorted_prob[torch.arange(batch_size), k - 1]
    mask = (normalized_prob >= pivot.unsqueeze(-1)).int()

    num_trials = 1000
    for _ in range(num_trials):
        samples = flashinfer.sampling.top_k_sampling_from_probs(normalized_prob, k)
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
        assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
            torch.arange(batch_size), samples
        ]
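
# Min-p filtering keeps every token whose probability is at least p times the
# row's maximum probability, then renormalizes over the survivors.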
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1])
def test_min_p_sampling(batch_size, vocab_size, p):
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    # scale min-p by the largest probability in each row
    top_probs = sorted_prob[:, -1].unsqueeze(-1)
    scaled_p = p * top_probs
    # min-p mask
    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
    mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int())
    min_p_tensor = torch.full((batch_size,), p, device="cuda:0")

    num_trials = 1000
    for _ in range(num_trials):
        samples = flashinfer.sampling.min_p_sampling_from_probs(
            normalized_prob,
            min_p_tensor,
        )

        assert torch.all(mask[torch.arange(batch_size), samples] == 1), samples[
            torch.nonzero(mask[torch.arange(batch_size), samples] == 0)
        ]
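
# The joint filter accepts a token only if it survives both the top-k pivot
# test and the top-p tail-mass test, hence the elementwise minimum of the two
# masks below.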
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
    torch.manual_seed(42)
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")
    eps = 1e-4
    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    # top-p mask
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    cdf = torch.cumsum(sorted_prob, dim=-1)
    mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
    mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int())
    # top-k mask
    sorted_prob, _ = torch.sort(normalized_prob, descending=True)
    pivot = sorted_prob[:, k - 1]
    mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int()
    # overall mask
    mask = torch.minimum(mask_top_p, mask_top_k)
    top_p_tensor = torch.full((batch_size,), p, device="cuda:0")
    top_k_tensor = torch.full((batch_size,), k, device="cuda:0")

    num_trials = 1000
    for _ in range(num_trials):
        samples = flashinfer.sampling.top_k_top_p_sampling_from_probs(
            normalized_prob,
            top_k_tensor,
            top_p_tensor,
            filter_apply_order="joint",
        )
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
        assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
            torch.arange(batch_size), samples
        ]
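
# clone_state() forks the CUDA RNG, so the logits-based and probs-based
# kernels below consume identical random streams and must return identical
# samples.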
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [100])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_sampling_from_probs_logits_alignment(batch_size, vocab_size, k, p):
    torch.manual_seed(42)
    logits = torch.randn(batch_size, vocab_size, device="cuda:0") * 5
    generator_logits = torch.Generator("cuda:0")
    generator_probs = generator_logits.clone_state()
    samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(
        logits, k, p, filter_apply_order="top_k_first", generator=generator_logits
    )
    samples_ref = flashinfer.sampling.top_k_top_p_sampling_from_probs(
        torch.softmax(logits, dim=-1),
        k,
        p,
        filter_apply_order="top_k_first",
        generator=generator_probs,
    )
    assert torch.all(samples == samples_ref)


@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_joint_sampling_from_logits(batch_size, vocab_size, p):
    torch.manual_seed(42)
    logits = torch.rand(batch_size, vocab_size, device="cuda:0") * 5
    generator_logits = torch.Generator("cuda:0")
    generator_probs = generator_logits.clone_state()
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")

    samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(
        logits, k, p, filter_apply_order="joint", generator=generator_logits
    )

    samples_ref = flashinfer.sampling.top_k_top_p_sampling_from_probs(
        torch.softmax(logits, dim=-1),
        k,
        p,
        filter_apply_order="joint",
        generator=generator_probs,
    )
    assert torch.all(samples == samples_ref)
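
# The renorm tests build the ground truth by zeroing the filtered-out tail of
# the distribution and rescaling the surviving probabilities to sum to one.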
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5, 0.9, 1.0])
def test_top_p_renorm_probs(batch_size, vocab_size, p):
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    cdf = torch.cumsum(sorted_prob, dim=-1)
    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
    mask.scatter_add_(1, indices, (cdf >= (1 - p)).int())
    renorm_prob_ground_truth = normalized_prob.clone()
    renorm_prob_ground_truth[mask == 0] = 0
    renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
        dim=-1, keepdim=True
    )

    renorm_prob = flashinfer.sampling.top_p_renorm_probs(normalized_prob, p)
    torch.testing.assert_close(
        renorm_prob_ground_truth,
        renorm_prob,
        rtol=1e-3,
        atol=1e-3,
    )


@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [10, 100, 500])
def test_top_k_renorm_probs(batch_size, vocab_size, k):
    if k > vocab_size:
        pytest.skip("k should be less than vocab_size")
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, _ = torch.sort(normalized_prob, descending=True)
    pivot = sorted_prob[:, k - 1]
    mask = (normalized_prob >= pivot.unsqueeze(-1)).int()
    renorm_prob_ground_truth = normalized_prob.clone()
    renorm_prob_ground_truth[mask == 0] = 0
    renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
        dim=-1, keepdim=True
    )

    renorm_prob = flashinfer.sampling.top_k_renorm_probs(normalized_prob, k)
    for i in range(batch_size):
        torch.testing.assert_close(
            renorm_prob_ground_truth[i],
            renorm_prob[i],
            rtol=1e-3,
            atol=1e-3,
        )
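
# Masking logits to -inf before softmax is mathematically equivalent to top-k
# renormalization of the probabilities, so the two code paths must agree.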
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [10, 100, 500])
@pytest.mark.parametrize("neginf_input", [False, True])
def test_top_k_mask_logits(batch_size, vocab_size, k, neginf_input):
    if k > vocab_size:
        pytest.skip("k should be less than vocab_size")
    torch.manual_seed(42)
    logits = torch.randn(batch_size, vocab_size, device="cuda:0") * 5
    if neginf_input:
        num_neginf = torch.randint(1, vocab_size * batch_size, (1,)).item()
        idxs = torch.randperm(batch_size * vocab_size, device="cuda:0")[:num_neginf]
        logits[idxs // vocab_size, idxs % vocab_size] = -float("inf")
    probs = torch.softmax(logits, dim=-1)
    masked_logits = flashinfer.sampling.top_k_mask_logits(logits, k)
    renormed_probs = torch.softmax(masked_logits, dim=-1)
    renormed_probs_ref = flashinfer.sampling.top_k_renorm_probs(probs, k)

    torch.testing.assert_close(
        renormed_probs,
        renormed_probs_ref,
        rtol=1e-3,
        atol=1e-3,
    )
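
# Chain speculative sampling validates draft tokens against the target
# distribution left to right: at the first rejected draft token, one token is
# resampled from the target and all later output positions are set to -1.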
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("num_speculate_tokens", [1, 3, 5, 7])
@pytest.mark.parametrize("onehot_target", [False, True])
def test_chain_speculative_sampling(
    batch_size,
    vocab_size,
    num_speculate_tokens,
    onehot_target,
):
    pre_norm_draft_prob = torch.rand(
        batch_size, num_speculate_tokens, vocab_size, device="cuda:0"
    )
    normalized_draft_prob = pre_norm_draft_prob / pre_norm_draft_prob.sum(
        dim=-1, keepdim=True
    )
    draft_token_ids = torch.randint(
        vocab_size, (batch_size, num_speculate_tokens), device="cuda:0"
    )
    if not onehot_target:
        pre_norm_target_prob = torch.rand(
            batch_size, num_speculate_tokens + 1, vocab_size, device="cuda:0"
        )
        target_onehot_prob = pre_norm_target_prob / pre_norm_target_prob.sum(
            dim=-1, keepdim=True
        )
    else:
        # one-hot target that always agrees with the draft tokens, so every
        # draft token must be accepted
        target_token_ids = torch.randint(
            vocab_size, (batch_size, num_speculate_tokens + 1), device="cuda:0"
        )
        target_token_ids[..., :num_speculate_tokens] = draft_token_ids
        target_onehot_prob = torch.zeros(
            (batch_size, num_speculate_tokens + 1, vocab_size), device="cuda:0"
        )
        target_onehot_prob.scatter_(2, target_token_ids.unsqueeze(-1), 1)

    # NOTE(Zihao): this is a very simple test that only checks whether output is valid or not.
    for trials in range(10):  # noqa: B007
        accepted_num = torch.zeros(batch_size, dtype=torch.int32, device="cuda:0")
        emitted_num = torch.zeros(batch_size, dtype=torch.int32, device="cuda:0")
        (
            output_token_ids,
            accepted_num,
            emitted_num,
        ) = flashinfer.sampling.chain_speculative_sampling(
            normalized_draft_prob,
            draft_token_ids,
            target_onehot_prob,
            accepted_num,
            emitted_num,
        )
        if onehot_target:
            assert torch.all(output_token_ids == target_token_ids)
        else:
            assert torch.all(output_token_ids[output_token_ids >= 0] < vocab_size)
            assert output_token_ids.shape == (batch_size, num_speculate_tokens + 1)
            mismatches = output_token_ids[..., :-1] != draft_token_ids
            for row in range(batch_size):
                mismatch_idx = torch.nonzero(mismatches[row], as_tuple=True)[0]
                if len(mismatch_idx) > 0:
                    # mismatch_idx should be contiguous
                    assert torch.all(mismatch_idx[1:] == mismatch_idx[:-1] + 1)
                    # from the second mismatched token on, the output tokens should be -1
                    assert torch.all(output_token_ids[row, mismatch_idx[0] + 1 :] == -1)

        assert torch.all(emitted_num + 1 == (output_token_ids != -1).sum(dim=1))


if __name__ == "__main__":
    # test_sampling_freq(128256, gumbel_distribution(0.1), 0.5)
    test_sampling_from_logits_freq(128256, gumbel_distribution(0.1))
    # test_top_p_sampling_freq(128256, gumbel_distribution(0.1), 0.5)
    # test_top_k_sampling_freq(1, 128256, 10)
    # test_sampling(19, 500)
    # test_sampling(1, 111)
    # test_top_p_sampling(3, 111, 0.9)
    # test_top_k_sampling(3, 111, 10)
    # test_top_p_renorm_probs(3, 111, 0.9)
    # test_top_k_renorm_probs(3, 111, 10)
    # test_top_k_mask_logits(99, 989, 10)
    # test_chain_speculative_sampling(3, 111, 3, False)
    # test_chain_speculative_sampling(3, 111, 3, True)