import pytest
import torch
import torch.nn.functional as F

from sgl_kernel import tree_speculative_sampling_target_only

test_cases = [
    (
        1,  # threshold_single
        1,  # threshold_acc
        [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18],  # expected_predicts
        [[0, 3, 4, 5], [6, 10, 11, -1]],  # expected_accept_index
        [3, 2],  # expected_accept_token_num
    ),
    (
        0,  # threshold_single
        0,  # threshold_acc
        [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18],
        [[0, 1, 2, -1], [6, 10, 11, -1]],
        [2, 2],
    ),
]


@pytest.mark.parametrize(
    "threshold_single, threshold_acc, expected_predicts, expected_accept_index, expected_accept_token_num",
    test_cases,
)
def test_tree_speculative_sampling_target_only(
    threshold_single,
    threshold_acc,
    expected_predicts,
    expected_accept_index,
    expected_accept_token_num,
):
    """
    Test tree_speculative_sampling_target_only across acceptance-threshold
    settings using pytest parameterization.
    """
    device = "cuda"

    # Draft token ids for two requests, six draft tokens each.
    candidates = torch.tensor(
        [
            [0, 1, 2, 3, 4, 5],
            [7, 8, 9, 10, 11, 12],
        ],
        dtype=torch.int64,
        device=device,
    )
    # Flat indices into `predicts` for each draft position.
    retrive_index = torch.tensor(
        [
            [0, 1, 2, 3, 4, 5],
            [6, 7, 8, 9, 10, 11],
        ],
        dtype=torch.int64,
        device=device,
    )
    # Child/sibling encoding of the draft tree (-1 means "no child"/"no sibling").
    retrive_next_token = torch.tensor(
        [
            [1, 2, -1, 4, 5, -1],
            [4, 2, 3, -1, 5, -1],
        ],
        dtype=torch.int64,
        device=device,
    )
    retrive_next_sibling = torch.tensor(
        [
            [-1, 3, -1, -1, -1, -1],
            [-1, -1, -1, -1, 1, -1],
        ],
        dtype=torch.int64,
        device=device,
    )

    # Uniform logits with a strong peak at the token expected along the
    # accepted path; every remaining position peaks at token 18.
    target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device=device)
    target_logits[0, 0, 3] = 10
    target_logits[0, 3, 4] = 10
    target_logits[0, 4, 5] = 10
    target_logits[1, 0, 11] = 10
    target_logits[1, 4, 12] = 10
    for i in range(target_logits.shape[0]):
        for j in range(target_logits.shape[1]):
            if torch.max(target_logits[i, j]) < 10:
                target_logits[i, j, 18] = 10

    # Near-zero temperature makes the softmaxed target distribution effectively one-hot.
    temperatures = torch.tensor([0.01, 0.01], dtype=torch.float32, device=device)
    bs, num_draft_tokens = candidates.shape
    num_spec_step = len(expected_accept_index[0])

    # Output buffers filled by the kernel.
    predict_shape = (len(expected_predicts),)
    predicts = torch.full(predict_shape, -1, dtype=torch.int32, device=device)
    accept_index = torch.full(
        (bs, num_spec_step), -1, dtype=torch.int32, device=device
    )
    accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device=device)

    expanded_temperature = temperatures.unsqueeze(1).unsqueeze(1)
    target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
    draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device=device)

    # Uniform samples consumed by the sampling kernel.
    coins = torch.rand(bs, num_draft_tokens, device=device, dtype=torch.float32)
    coins_for_final_sampling = torch.rand(bs, device=device).to(torch.float32)

    tree_speculative_sampling_target_only(
        predicts=predicts,
        accept_index=accept_index,
        accept_token_num=accept_token_num,
        candidates=candidates,
        retrive_index=retrive_index,
        retrive_next_token=retrive_next_token,
        retrive_next_sibling=retrive_next_sibling,
        uniform_samples=coins,
        uniform_samples_for_final_sampling=coins_for_final_sampling,
        target_probs=target_probs,
        draft_probs=draft_probs,
        threshold_single=threshold_single,
        threshold_acc=threshold_acc,
        deterministic=True,
    )

    assert (
        predicts.tolist() == expected_predicts
    ), f"Predicts mismatch for thresholds ({threshold_single}, {threshold_acc})"
    assert (
        accept_index.tolist() == expected_accept_index
    ), f"Accept index mismatch for thresholds ({threshold_single}, {threshold_acc})"
    assert (
        accept_token_num.tolist() == expected_accept_token_num
    ), f"Accept token num mismatch for thresholds ({threshold_single}, {threshold_acc})"


if __name__ == "__main__":
    pytest.main([__file__])
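

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the kernel API and not used by the test.
# It assumes the child/sibling encoding suggested by the fixtures above:
# retrive_next_token[b][i] is the first child of draft node i and
# retrive_next_sibling[b][i] is its next sibling, with -1 meaning "none".
# The hypothetical helper below walks that encoding and yields every
# root-to-leaf path, which makes the expected accept indices easier to read.
def enumerate_draft_paths(next_token_row, next_sibling_row, root=0):
    """Yield root-to-leaf index paths for one batch row of the draft tree."""
    stack = [(root, [root])]
    while stack:
        node, path = stack.pop()
        child = next_token_row[node]
        if child == -1:
            # Leaf node: the accumulated path is a complete draft chain.
            yield path
            continue
        # Push the first child and then all of its siblings.
        while child != -1:
            stack.append((child, path + [child]))
            child = next_sibling_row[child]


# For batch row 0 above ([1, 2, -1, 4, 5, -1] / [-1, 3, -1, -1, -1, -1]) this
# yields [0, 3, 4, 5] and [0, 1, 2], matching the accept-index candidates.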