# Source: sglang_v0.5.2/sglang/test/srt/test_forward_split_prefill.py
# (listing header: 300 lines, 11 KiB, Python)

"""
Test forward_split_prefill functionality.
Usage:
python3 -m unittest test_forward_split_prefill.TestForwardSplitPrefill
or
python3 test_forward_split_prefill.py
"""
import time
import unittest
import numpy as np
import torch
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase
class TestForwardSplitPrefill(CustomTestCase):
    """Test cases for forward_split_prefill functionality.

    A single ``ModelRunner`` is created once in ``setUpClass`` and shared by
    every test. Each test builds synthetic batches with
    ``prepare_test_batch`` and drives ``ModelRunner.forward_split_prefill``
    over the model's hidden layers, ``forward_count`` layers per call.
    """

    @classmethod
    def setUpClass(cls):
        """Set up the test environment once for all tests."""
        cls.model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.tp_size = 1
        cls.device = "cuda"

        # Initialize server args
        cls.server_args = ServerArgs(
            model_path=cls.model_path,
            tokenizer_path=cls.model_path,
            host="127.0.0.1",
            disable_cuda_graph=True,  # Disable CUDA graph for testing split prefill
            disable_hybrid_swa_memory=True,
            port=30000,
            tp_size=cls.tp_size,
            mem_fraction_static=0.8,
            trust_remote_code=True,
        )
        cls.port_args = PortArgs.init_new(cls.server_args)

        # Load model and tokenizer
        cls.model_config = ModelConfig.from_server_args(cls.server_args)
        cls.model_runner = ModelRunner(
            model_config=cls.model_config,
            mem_fraction_static=cls.server_args.mem_fraction_static,
            gpu_id=0,
            tp_rank=0,
            tp_size=cls.tp_size,
            pp_rank=0,
            pp_size=1,
            nccl_port=cls.port_args.nccl_port,
            server_args=cls.server_args,
        )
        cls.tokenizer = get_tokenizer(
            cls.server_args.tokenizer_path,
            tokenizer_mode=cls.server_args.tokenizer_mode,
            trust_remote_code=cls.server_args.trust_remote_code,
        )

        print(
            f"Test with model: {cls.model_path}, num_hidden_layers: {cls.model_config.num_hidden_layers}"
        )

    def prepare_test_batch(self, batch_size=2, input_len=128, is_split_prefill=True):
        """Prepare a test batch for split prefill testing.

        Args:
            batch_size: Number of synthetic requests in the batch.
            input_len: Number of random token ids generated per request.
            is_split_prefill: When true the batch is prepared with
                ``prepare_for_split_prefill``; otherwise with
                ``prepare_for_extend``.

        Returns:
            The ``ForwardBatch`` built from the prepared schedule batch.
        """
        # Create synthetic input (random token ids in [10, 1000))
        input_ids = np.random.randint(10, 1000, (batch_size, input_len), dtype=np.int32)

        sampling_params = SamplingParams(
            temperature=0.0,
            max_new_tokens=8,
        )

        reqs = []
        for i in range(batch_size):
            req = Req(
                rid=i,
                origin_input_text="",
                # tolist() yields plain Python ints rather than np.int32 scalars.
                origin_input_ids=input_ids[i].tolist(),
                sampling_params=sampling_params,
            )
            # No cached prefix: the whole input is treated as new tokens.
            req.prefix_indices = []
            req.fill_ids = req.origin_input_ids
            req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
            req.logprob_start_len = len(req.origin_input_ids) - 1
            reqs.append(req)

        batch = ScheduleBatch.init_new(
            reqs=reqs,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
            tree_cache=None,
            model_config=self.model_config,
            enable_overlap=False,
            spec_algorithm=SpeculativeAlgorithm.NONE,
            enable_custom_logit_processor=False,
        )
        if is_split_prefill:
            batch.prepare_for_split_prefill()
        else:
            batch.prepare_for_extend()

        # Create forward batch
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        return forward_batch

    def _run_split_prefill(self, forward_batch, chunk_size):
        """Drive split prefill to completion and return the last non-None result.

        Calls ``forward_split_prefill`` with ``forward_count=chunk_size`` until
        ``forward_batch.split_index`` reaches the number of hidden layers.
        ``chunk_size`` must be >= 1, otherwise the loop would never advance.
        """
        num_layers = self.model_config.num_hidden_layers
        final_result = None
        while forward_batch.split_index < num_layers:
            result = self.model_runner.forward_split_prefill(
                forward_batch=forward_batch,
                # Request a backend re-init on the first chunk, matching the
                # other tests in this class that pass this flag explicitly.
                reinit_attn_backend=(forward_batch.split_index == 0),
                forward_count=chunk_size,
            )
            if result is not None:
                final_result = result
        return final_result

    def test_split_prefill_functionality(self):
        """Test that split prefill can complete successfully."""
        print("\n=== Testing split prefill functionality ===")

        forward_batch = self.prepare_test_batch(batch_size=2, input_len=64)

        # Reset split index
        forward_batch.split_index = 0

        # Test split prefill in chunks
        num_layers = self.model_config.num_hidden_layers
        chunk_size = max(1, num_layers // 4)  # Split into 4 chunks

        results = []
        split_count = 0
        while forward_batch.split_index < num_layers:
            print(
                f"Processing split {split_count}, split_index: {forward_batch.split_index}"
            )
            result = self.model_runner.forward_split_prefill(
                forward_batch=forward_batch,
                reinit_attn_backend=(split_count == 0),
                forward_count=chunk_size,
            )
            results.append(result)
            split_count += 1

            # Verify split_index is updated correctly
            expected_next_index = min(split_count * chunk_size, num_layers)
            self.assertEqual(forward_batch.split_index, expected_next_index)

        # The last result should contain logits
        self.assertIsNotNone(results[-1], "Final split should return logits")
        print(f"Split prefill completed in {split_count} splits")

    def test_split_prefill_vs_normal_prefill(self):
        """Test that split prefill produces the same results as normal prefill."""
        print("\n=== Testing split prefill vs normal prefill consistency ===")

        forward_batch_normal = self.prepare_test_batch(
            batch_size=2, input_len=128, is_split_prefill=False
        )
        forward_batch_split = self.prepare_test_batch(
            batch_size=2, input_len=128, is_split_prefill=True
        )

        # Ensure same input: each prepare_test_batch call draws fresh random
        # token ids, so copy the normal batch's inputs onto the split batch.
        forward_batch_split.input_ids = forward_batch_normal.input_ids.clone()
        forward_batch_split.positions = forward_batch_normal.positions.clone()

        # Method 1: Normal extend (prefill)
        print("Running normal extend (prefill)...")
        normal_result = self.model_runner.forward_extend(forward_batch_normal)

        # Method 2: Split prefill
        print("Running split prefill...")
        num_layers = self.model_config.num_hidden_layers
        chunk_size = max(1, num_layers // 3)  # Split into 3 chunks
        split_result = self._run_split_prefill(forward_batch_split, chunk_size)

        # Compare results
        self.assertIsNotNone(normal_result, "Normal prefill should return result")
        self.assertIsNotNone(split_result, "Split prefill should return result")

        # Compare logits shapes
        self.assertEqual(
            normal_result.next_token_logits.shape,
            split_result.next_token_logits.shape,
            "Logits shapes should match",
        )

        # Compare logits values (should be very close due to same computation)
        # Use a larger tolerance for numerical differences in split computation
        torch.testing.assert_close(
            normal_result.next_token_logits,
            split_result.next_token_logits,
            rtol=1e-3,
            atol=1e-3,
            msg="Split prefill and normal prefill should produce similar logits",
        )
        print("✓ Split prefill and normal prefill produce consistent results")

    def test_split_prefill_different_chunk_sizes(self):
        """Test split prefill with different chunk sizes."""
        print("\n=== Testing split prefill with different chunk sizes ===")

        num_layers = self.model_config.num_hidden_layers
        candidates = [1, 2, max(1, num_layers // 2), num_layers]
        # dict.fromkeys de-duplicates while keeping order (for small models the
        # candidates repeat); sizes larger than the model are not tested.
        chunk_sizes = [c for c in dict.fromkeys(candidates) if c <= num_layers]

        # Prepare identical batches for each test
        base_batch = self.prepare_test_batch(batch_size=1, input_len=16)
        base_input_ids = base_batch.input_ids.clone()
        base_positions = base_batch.positions.clone()

        # (chunk_size, result) pairs so a failure names the chunk size that
        # actually produced the mismatching logits.
        results = []
        for chunk_size in chunk_sizes:
            print(f"Testing chunk size: {chunk_size}")

            # Prepare fresh batch
            forward_batch = self.prepare_test_batch(batch_size=1, input_len=16)
            forward_batch.input_ids = base_input_ids.clone()
            forward_batch.positions = base_positions.clone()
            forward_batch.split_index = 0

            # Run split prefill
            split_result = self._run_split_prefill(forward_batch, chunk_size)
            self.assertIsNotNone(
                split_result,
                f"Split prefill should succeed with chunk_size={chunk_size}",
            )
            results.append((chunk_size, split_result))

        # Compare all results should be identical (same input, same computation)
        if results:
            _, reference = results[0]
            for chunk_size, result in results[1:]:
                torch.testing.assert_close(
                    reference.next_token_logits,
                    result.next_token_logits,
                    rtol=1e-3,
                    atol=1e-3,
                    msg=f"Results with different chunk sizes should be identical (chunk_size {chunk_size})",
                )
        print("✓ All chunk sizes produce consistent results")

    def test_split_prefill_edge_cases(self):
        """Test edge cases for split prefill."""
        print("\n=== Testing split prefill edge cases ===")

        # Test with single layer chunks
        forward_batch = self.prepare_test_batch(batch_size=1, input_len=8)

        # Process one layer at a time
        num_layers = self.model_config.num_hidden_layers
        for layer_idx in range(num_layers):
            result = self.model_runner.forward_split_prefill(
                forward_batch=forward_batch,
                reinit_attn_backend=(layer_idx == 0),
                forward_count=1,  # One layer at a time
            )
            if layer_idx == num_layers - 1:
                # Last layer should return result
                self.assertIsNotNone(result, "Last layer should return logits")
            else:
                # Intermediate layers should return None
                self.assertIsNone(result, f"Layer {layer_idx} should return None")
        print("✓ Single layer processing works correctly")
# Allow running this file directly: python3 test_forward_split_prefill.py
if __name__ == "__main__":
    unittest.main()