""" Test forward_split_prefill functionality. Usage: python3 -m unittest test_forward_split_prefill.TestForwardSplitPrefill or python3 test_forward_split_prefill.py """ import time import unittest import numpy as np import torch from sglang.srt.configs.model_config import ModelConfig from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase class TestForwardSplitPrefill(CustomTestCase): """Test cases for forward_split_prefill functionality.""" @classmethod def setUpClass(cls): """Set up the test environment once for all tests.""" cls.model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.tp_size = 1 cls.device = "cuda" # Initialize server args cls.server_args = ServerArgs( model_path=cls.model_path, tokenizer_path=cls.model_path, host="127.0.0.1", disable_cuda_graph=True, # Disable CUDA graph for testing split prefill disable_hybrid_swa_memory=True, port=30000, tp_size=cls.tp_size, mem_fraction_static=0.8, trust_remote_code=True, ) cls.port_args = PortArgs.init_new(cls.server_args) # Load model and tokenizer cls.model_config = ModelConfig.from_server_args(cls.server_args) cls.model_runner = ModelRunner( model_config=cls.model_config, mem_fraction_static=cls.server_args.mem_fraction_static, gpu_id=0, tp_rank=0, tp_size=cls.tp_size, pp_rank=0, pp_size=1, nccl_port=cls.port_args.nccl_port, server_args=cls.server_args, ) cls.tokenizer = get_tokenizer( cls.server_args.tokenizer_path, tokenizer_mode=cls.server_args.tokenizer_mode, trust_remote_code=cls.server_args.trust_remote_code, ) print( f"Test with model: {cls.model_path}, num_hidden_layers: {cls.model_config.num_hidden_layers}" ) def prepare_test_batch(self, batch_size=2, input_len=128, is_split_prefill=True): """Prepare a test batch for split prefill testing.""" # Create synthetic input input_ids = np.random.randint(10, 1000, (batch_size, input_len), dtype=np.int32) sampling_params = SamplingParams( temperature=0.0, max_new_tokens=8, ) reqs = [] for i in range(batch_size): req = Req( rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]), sampling_params=sampling_params, ) req.prefix_indices = [] req.fill_ids = req.origin_input_ids req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) req.logprob_start_len = len(req.origin_input_ids) - 1 reqs.append(req) batch = ScheduleBatch.init_new( reqs=reqs, req_to_token_pool=self.model_runner.req_to_token_pool, token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator, tree_cache=None, model_config=self.model_config, enable_overlap=False, spec_algorithm=SpeculativeAlgorithm.NONE, enable_custom_logit_processor=False, ) if is_split_prefill: batch.prepare_for_split_prefill() else: batch.prepare_for_extend() # Create forward batch model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) return forward_batch def test_split_prefill_functionality(self): """Test that split prefill can complete successfully.""" print("\n=== Testing split prefill functionality ===") forward_batch = self.prepare_test_batch(batch_size=2, input_len=64) 
        # Reset the split index
        forward_batch.split_index = 0

        # Run split prefill in chunks
        num_layers = self.model_config.num_hidden_layers
        chunk_size = max(1, num_layers // 4)  # Split into roughly 4 chunks

        results = []
        split_count = 0
        while forward_batch.split_index < num_layers:
            print(
                f"Processing split {split_count}, split_index: {forward_batch.split_index}"
            )
            result = self.model_runner.forward_split_prefill(
                forward_batch=forward_batch,
                reinit_attn_backend=(split_count == 0),
                forward_count=chunk_size,
            )
            results.append(result)
            split_count += 1

            # Verify that split_index is advanced correctly
            expected_next_index = min(split_count * chunk_size, num_layers)
            self.assertEqual(forward_batch.split_index, expected_next_index)

        # The last result should contain logits
        self.assertIsNotNone(results[-1], "Final split should return logits")

        print(f"Split prefill completed in {split_count} splits")

    def test_split_prefill_vs_normal_prefill(self):
        """Test that split prefill produces the same results as normal prefill."""
        print("\n=== Testing split prefill vs normal prefill consistency ===")

        forward_batch_normal = self.prepare_test_batch(
            batch_size=2, input_len=128, is_split_prefill=False
        )
        forward_batch_split = self.prepare_test_batch(
            batch_size=2, input_len=128, is_split_prefill=True
        )

        # Ensure both methods see the same input
        forward_batch_split.input_ids = forward_batch_normal.input_ids.clone()
        forward_batch_split.positions = forward_batch_normal.positions.clone()

        # Method 1: normal extend (prefill)
        print("Running normal extend (prefill)...")
        normal_result = self.model_runner.forward_extend(forward_batch_normal)

        # Method 2: split prefill
        print("Running split prefill...")
        num_layers = self.model_config.num_hidden_layers
        chunk_size = max(1, num_layers // 3)  # Split into roughly 3 chunks

        split_result = None
        while forward_batch_split.split_index < num_layers:
            result = self.model_runner.forward_split_prefill(
                forward_batch=forward_batch_split,
                forward_count=chunk_size,
            )
            if result is not None:
                split_result = result

        # Compare results
        self.assertIsNotNone(normal_result, "Normal prefill should return a result")
        self.assertIsNotNone(split_result, "Split prefill should return a result")

        # Compare logits shapes
        self.assertEqual(
            normal_result.next_token_logits.shape,
            split_result.next_token_logits.shape,
            "Logits shapes should match",
        )

        # Compare logits values; they should be very close since the computation
        # is the same. Use a slightly larger tolerance to absorb numerical
        # differences introduced by the split computation.
        torch.testing.assert_close(
            normal_result.next_token_logits,
            split_result.next_token_logits,
            rtol=1e-3,
            atol=1e-3,
            msg="Split prefill and normal prefill should produce similar logits",
        )

        print("✓ Split prefill and normal prefill produce consistent results")

    def test_split_prefill_different_chunk_sizes(self):
        """Test split prefill with different chunk sizes."""
        print("\n=== Testing split prefill with different chunk sizes ===")

        num_layers = self.model_config.num_hidden_layers
        chunk_sizes = [1, 2, max(1, num_layers // 2), num_layers]

        # Prepare identical inputs for every chunk size
        base_batch = self.prepare_test_batch(batch_size=1, input_len=16)
        base_input_ids = base_batch.input_ids.clone()
        base_positions = base_batch.positions.clone()

        results = []
        for chunk_size in chunk_sizes:
            if chunk_size > num_layers:
                continue

            print(f"Testing chunk size: {chunk_size}")

            # Prepare a fresh batch with the same inputs
            forward_batch = self.prepare_test_batch(batch_size=1, input_len=16)
            forward_batch.input_ids = base_input_ids.clone()
            forward_batch.positions = base_positions.clone()
            forward_batch.split_index = 0

            # Run split prefill
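            # Note: a chunk_size that does not divide num_layers evenly should
            # still work; the final call is expected to process only the
            # remaining layers (cf. the split_index clamp checked in
            # test_split_prefill_functionality).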
            split_result = None
            while forward_batch.split_index < num_layers:
                result = self.model_runner.forward_split_prefill(
                    forward_batch=forward_batch,
                    forward_count=chunk_size,
                )
                if result is not None:
                    split_result = result

            self.assertIsNotNone(
                split_result,
                f"Split prefill should succeed with chunk_size={chunk_size}",
            )
            # Record the chunk size alongside the result so the comparison
            # below reports the right value even when a chunk size is skipped.
            results.append((chunk_size, split_result))

        # All results should be identical (same input, same computation)
        if len(results) > 1:
            base_chunk_size, base_result = results[0]
            for chunk_size, result in results[1:]:
                torch.testing.assert_close(
                    base_result.next_token_logits,
                    result.next_token_logits,
                    rtol=1e-3,
                    atol=1e-3,
                    msg=(
                        f"Results with different chunk sizes should be identical "
                        f"({base_chunk_size} vs {chunk_size})"
                    ),
                )

        print("✓ All chunk sizes produce consistent results")

    def test_split_prefill_edge_cases(self):
        """Test edge cases for split prefill."""
        print("\n=== Testing split prefill edge cases ===")

        # Test with single-layer chunks
        forward_batch = self.prepare_test_batch(batch_size=1, input_len=8)

        # Process one layer at a time
        num_layers = self.model_config.num_hidden_layers
        for layer_idx in range(num_layers):
            result = self.model_runner.forward_split_prefill(
                forward_batch=forward_batch,
                reinit_attn_backend=(layer_idx == 0),
                forward_count=1,  # One layer at a time
            )

            if layer_idx == num_layers - 1:
                # The last layer should return the logits
                self.assertIsNotNone(result, "Last layer should return logits")
            else:
                # Intermediate layers should return None
                self.assertIsNone(result, f"Layer {layer_idx} should return None")

        print("✓ Single-layer processing works correctly")


if __name__ == "__main__":
    unittest.main()
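
# A minimal sketch of the driver loop these tests exercise (not part of the
# test suite; `model_runner`, `forward_batch`, `num_layers`, and `chunk_size`
# are assumed to be set up as in prepare_test_batch above):
#
#     forward_batch.split_index = 0
#     logits = None
#     while forward_batch.split_index < num_layers:
#         out = model_runner.forward_split_prefill(
#             forward_batch=forward_batch,
#             reinit_attn_backend=(forward_batch.split_index == 0),
#             forward_count=chunk_size,
#         )
#         if out is not None:
#             logits = out  # only the final chunk returns logits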