sglang.0.4.8.post1/sglang/test/srt/test_io_struct.py

512 lines
19 KiB
Python

import copy
import unittest
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
)
class TestGenerateReqInputNormalization(CustomTestCase):
"""Test the normalization of GenerateReqInput for batch processing and different input formats."""
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
def setUp(self):
# Common setup for all tests
self.base_req = GenerateReqInput(
text=["Hello", "World"],
sampling_params=[{}, {}],
rid=["id1", "id2"],
)
def test_single_image_to_list_of_lists(self):
"""Test that a single image is converted to a list of single-image lists."""
req = copy.deepcopy(self.base_req)
req.image_data = "single_image.jpg" # A single image (non-list)
req.normalize_batch_and_arguments()
# Should be converted to [[image], [image]]
self.assertEqual(len(req.image_data), 2)
self.assertEqual(len(req.image_data[0]), 1)
self.assertEqual(len(req.image_data[1]), 1)
self.assertEqual(req.image_data[0][0], "single_image.jpg")
self.assertEqual(req.image_data[1][0], "single_image.jpg")
# Check modalities
self.assertEqual(req.modalities, ["image", "image"])
def test_list_of_images_to_list_of_lists(self):
"""Test that a list of images is converted to a list of single-image lists."""
req = copy.deepcopy(self.base_req)
req.image_data = ["image1.jpg", "image2.jpg"] # List of images
req.normalize_batch_and_arguments()
# Should be converted to [[image1], [image2]]
self.assertEqual(len(req.image_data), 2)
self.assertEqual(len(req.image_data[0]), 1)
self.assertEqual(len(req.image_data[1]), 1)
self.assertEqual(req.image_data[0][0], "image1.jpg")
self.assertEqual(req.image_data[1][0], "image2.jpg")
# Check modalities
self.assertEqual(req.modalities, ["image", "image"])
def test_list_of_lists_with_different_modalities(self):
"""Test handling of list of lists of images with different modalities."""
req = copy.deepcopy(self.base_req)
req.image_data = [
["image1.jpg"], # Single image (image modality)
["image2.jpg", "image3.jpg"], # Multiple images (multi-images modality)
]
req.normalize_batch_and_arguments()
# Structure should remain the same
self.assertEqual(len(req.image_data), 2)
self.assertEqual(len(req.image_data[0]), 1)
self.assertEqual(len(req.image_data[1]), 2)
# Check modalities
self.assertEqual(req.modalities, ["image", "multi-images"])
def test_list_of_lists_with_none_values(self):
"""Test handling of list of lists with None values."""
req = copy.deepcopy(self.base_req)
req.image_data = [
[None], # None value
["image.jpg"], # Single image
]
req.normalize_batch_and_arguments()
# Structure should remain the same
self.assertEqual(len(req.image_data), 2)
self.assertEqual(len(req.image_data[0]), 1)
self.assertEqual(len(req.image_data[1]), 1)
# Check modalities
self.assertEqual(req.modalities, [None, "image"])
def test_expanding_parallel_sample_correlation(self):
"""Test that when expanding with parallel samples, prompts, images and modalities are properly correlated."""
req = copy.deepcopy(self.base_req)
req.text = ["Prompt 1", "Prompt 2"]
req.image_data = [
["image1.jpg"],
["image2.jpg", "image3.jpg"],
]
req.sampling_params = {"n": 3} # All prompts get 3 samples
# Define expected values before normalization
expected_text = req.text * 3
expected_images = req.image_data * 3
expected_modalities = ["image", "multi-images"] * 3
req.normalize_batch_and_arguments()
# Should be expanded to 6 items (2 original * 3 parallel)
self.assertEqual(len(req.image_data), 6)
# Check that images are properly expanded
self.assertEqual(req.image_data, expected_images)
# Check modalities
self.assertEqual(req.modalities, expected_modalities)
# Ensure that text items are properly duplicated too
self.assertEqual(req.text, expected_text)
def test_specific_parallel_n_per_sample(self):
"""Test parallel expansion when different samples have different n values."""
req = copy.deepcopy(self.base_req)
req.text = ["Prompt 1", "Prompt 2"]
req.image_data = [
["image1.jpg"],
["image2.jpg", "image3.jpg"],
]
req.sampling_params = [
{"n": 2},
{"n": 2},
] # First prompt gets 2 samples, second prompt gets 2 samples
expected_images = req.image_data * 2
expected_modalities = ["image", "multi-images"] * 2
expected_text = req.text * 2
req.normalize_batch_and_arguments()
# Should be expanded to 4 items (2 original * 2 parallel)
self.assertEqual(len(req.image_data), 4)
# Check that the first 2 are copies for the first prompt
self.assertEqual(req.image_data, expected_images)
# Check modalities
self.assertEqual(req.modalities, expected_modalities)
# Check text expansion
self.assertEqual(req.text, expected_text)
def test_mixed_none_and_images_with_parallel_samples(self):
"""Test that when some batch items have images and others None, parallel expansion works correctly."""
req = copy.deepcopy(self.base_req)
req.text = ["Prompt 1", "Prompt 2", "Prompt 3"]
req.image_data = [
["image1.jpg"],
None,
["image3_1.jpg", "image3_2.jpg"],
]
req.sampling_params = {"n": 2} # All prompts get 2 samples
expected_images = req.image_data * 2
expected_modalities = ["image", None, "multi-images"] * 2
expected_text = req.text * 2
req.normalize_batch_and_arguments()
# Should be expanded to 6 items (3 original * 2 parallel)
self.assertEqual(len(req.image_data), 6)
# Check image data
self.assertEqual(req.image_data, expected_images)
# Check modalities
self.assertEqual(req.modalities, expected_modalities)
# Check text expansion
self.assertEqual(req.text, expected_text)
def test_correlation_with_sampling_params(self):
"""Test that sampling parameters are correctly correlated with prompts during expansion."""
req = copy.deepcopy(self.base_req)
req.text = ["Prompt 1", "Prompt 2"]
req.image_data = [
["image1.jpg"],
["image2.jpg"],
]
req.sampling_params = [
{"temperature": 0.7, "n": 2},
{"temperature": 0.9, "n": 2},
]
req.normalize_batch_and_arguments()
# Check sampling params expansion
self.assertEqual(len(req.sampling_params), 4)
self.assertEqual(req.sampling_params[0]["temperature"], 0.7)
self.assertEqual(req.sampling_params[1]["temperature"], 0.9)
self.assertEqual(req.sampling_params[2]["temperature"], 0.7)
self.assertEqual(req.sampling_params[3]["temperature"], 0.9)
# Should be expanded to 4 items (2 original * 2 parallel)
self.assertEqual(len(req.image_data), 4)
# Check correlation with images
self.assertEqual(req.image_data[0], ["image1.jpg"])
self.assertEqual(req.image_data[1], ["image2.jpg"])
self.assertEqual(req.image_data[2], ["image1.jpg"])
self.assertEqual(req.image_data[3], ["image2.jpg"])
def test_single_example_with_image(self):
"""Test handling of single example with image."""
req = GenerateReqInput(
text="Hello",
image_data="single_image.jpg",
)
req.normalize_batch_and_arguments()
# For single examples, image_data doesn't get processed into lists
self.assertEqual(req.image_data, "single_image.jpg")
self.assertIsNone(req.modalities) # Modalities isn't set for single examples
def test_single_to_batch_with_parallel_sampling(self):
"""Test single example converted to batch with parallel sampling."""
req = GenerateReqInput(
text="Hello",
image_data="single_image.jpg",
sampling_params={"n": 3}, # parallel_sample_num = 3
)
# Define expected values before normalization
expected_text = ["Hello"] * 3
req.normalize_batch_and_arguments()
# Should be converted to batch with text=["Hello"]
self.assertEqual(req.text, expected_text)
# Image should be automatically wrapped to list of lists with length 1*3=3
self.assertEqual(len(req.image_data), 3)
self.assertEqual(req.image_data[0][0], "single_image.jpg")
self.assertEqual(req.image_data[1][0], "single_image.jpg")
self.assertEqual(req.image_data[2][0], "single_image.jpg")
# Modalities should be set for all 3 examples
self.assertEqual(req.modalities, ["image", "image", "image"])
def test_audio_data_handling(self):
"""Test handling of audio_data."""
req = copy.deepcopy(self.base_req)
req.audio_data = "audio.mp3" # Single audio
req.normalize_batch_and_arguments()
# Should be converted to ["audio.mp3", "audio.mp3"]
self.assertEqual(len(req.audio_data), 2)
self.assertEqual(req.audio_data[0], "audio.mp3")
self.assertEqual(req.audio_data[1], "audio.mp3")
# Test with list
req = copy.deepcopy(self.base_req)
req.audio_data = ["audio1.mp3", "audio2.mp3"]
req.normalize_batch_and_arguments()
# Should remain the same
self.assertEqual(len(req.audio_data), 2)
self.assertEqual(req.audio_data[0], "audio1.mp3")
self.assertEqual(req.audio_data[1], "audio2.mp3")
def test_input_ids_normalization(self):
"""Test normalization of input_ids instead of text."""
# Test single input_ids
req = GenerateReqInput(input_ids=[1, 2, 3])
req.normalize_batch_and_arguments()
self.assertTrue(req.is_single)
self.assertEqual(req.batch_size, 1)
# Test batch input_ids
req = GenerateReqInput(input_ids=[[1, 2, 3], [4, 5, 6]])
req.normalize_batch_and_arguments()
self.assertFalse(req.is_single)
self.assertEqual(req.batch_size, 2)
# Test with parallel sampling
req = GenerateReqInput(
input_ids=[[1, 2, 3], [4, 5, 6]], sampling_params={"n": 2}
)
req.normalize_batch_and_arguments()
self.assertEqual(len(req.input_ids), 4) # 2 original * 2 parallel
def test_input_embeds_normalization(self):
"""Test normalization of input_embeds."""
# Test single input_embeds
req = GenerateReqInput(input_embeds=[[0.1, 0.2], [0.3, 0.4]])
req.normalize_batch_and_arguments()
self.assertTrue(req.is_single)
self.assertEqual(req.batch_size, 1)
# Test batch input_embeds
req = GenerateReqInput(input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]])
req.normalize_batch_and_arguments()
self.assertFalse(req.is_single)
self.assertEqual(req.batch_size, 2)
def test_lora_path_normalization(self):
"""Test normalization of lora_path."""
# Test single lora_path with batch input
req = GenerateReqInput(text=["Hello", "World"], lora_path="path/to/lora")
# Define expected lora_paths before normalization
expected_lora_paths = ["path/to/lora", "path/to/lora"]
req.normalize_batch_and_arguments()
self.assertEqual(req.lora_path, expected_lora_paths)
# Test list of lora_paths
req = GenerateReqInput(text=["Hello", "World"], lora_path=["path1", "path2"])
# Define expected lora_paths before normalization
expected_lora_paths = ["path1", "path2"]
req.normalize_batch_and_arguments()
self.assertEqual(req.lora_path, expected_lora_paths)
# Test with parallel sampling
req = GenerateReqInput(
text=["Hello", "World"],
lora_path=["path1", "path2"],
sampling_params={"n": 2},
)
# Define expected lora_paths before normalization
expected_lora_paths = ["path1", "path2"] * 2
req.normalize_batch_and_arguments()
self.assertEqual(req.lora_path, expected_lora_paths)
def test_logprob_parameters_normalization(self):
"""Test normalization of logprob-related parameters."""
# Test single example
req = GenerateReqInput(
text="Hello",
return_logprob=True,
logprob_start_len=10,
top_logprobs_num=5,
token_ids_logprob=[7, 8, 9],
)
req.normalize_batch_and_arguments()
self.assertEqual(req.return_logprob, True)
self.assertEqual(req.logprob_start_len, 10)
self.assertEqual(req.top_logprobs_num, 5)
self.assertEqual(req.token_ids_logprob, [7, 8, 9])
# Test batch with scalar values
req = GenerateReqInput(
text=["Hello", "World"],
return_logprob=True,
logprob_start_len=10,
top_logprobs_num=5,
token_ids_logprob=[7, 8, 9],
)
req.normalize_batch_and_arguments()
self.assertEqual(req.return_logprob, [True, True])
self.assertEqual(req.logprob_start_len, [10, 10])
self.assertEqual(req.top_logprobs_num, [5, 5])
self.assertEqual(req.token_ids_logprob, [[7, 8, 9], [7, 8, 9]])
# Test batch with list values
req = GenerateReqInput(
text=["Hello", "World"],
return_logprob=[True, False],
logprob_start_len=[10, 5],
top_logprobs_num=[5, 3],
token_ids_logprob=[[7, 8, 9], [4, 5, 6]],
return_hidden_states=[False, False, True],
)
req.normalize_batch_and_arguments()
self.assertEqual(req.return_logprob, [True, False])
self.assertEqual(req.logprob_start_len, [10, 5])
self.assertEqual(req.top_logprobs_num, [5, 3])
self.assertEqual(req.token_ids_logprob, [[7, 8, 9], [4, 5, 6]])
self.assertEqual(req.return_hidden_states, [False, False, True])
def test_custom_logit_processor_normalization(self):
"""Test normalization of custom_logit_processor."""
# Test single processor
req = GenerateReqInput(
text=["Hello", "World"], custom_logit_processor="serialized_processor"
)
req.normalize_batch_and_arguments()
self.assertEqual(
req.custom_logit_processor, ["serialized_processor", "serialized_processor"]
)
# Test list of processors
req = GenerateReqInput(
text=["Hello", "World"], custom_logit_processor=["processor1", "processor2"]
)
req.normalize_batch_and_arguments()
self.assertEqual(req.custom_logit_processor, ["processor1", "processor2"])
def test_session_params_handling(self):
"""Test handling of session_params."""
# Test with dict
req = GenerateReqInput(
text=["Hello", "World"], session_params={"id": "session1", "offset": 10}
)
req.normalize_batch_and_arguments()
self.assertEqual(req.session_params, {"id": "session1", "offset": 10})
# Test with list of dicts
req = GenerateReqInput(
text=["Hello", "World"],
session_params=[{"id": "session1"}, {"id": "session2"}],
)
req.normalize_batch_and_arguments()
self.assertEqual(req.session_params, [{"id": "session1"}, {"id": "session2"}])
def test_getitem_method(self):
"""Test the __getitem__ method."""
req = GenerateReqInput(
text=["Hello", "World"],
image_data=[["img1.jpg"], ["img2.jpg"]],
audio_data=["audio1.mp3", "audio2.mp3"],
sampling_params=[{"temp": 0.7}, {"temp": 0.8}],
rid=["id1", "id2"],
return_logprob=[True, False],
logprob_start_len=[10, 5],
top_logprobs_num=[5, 3],
token_ids_logprob=[[7, 8, 9], [4, 5, 6]],
stream=True,
log_metrics=True,
modalities=["image", "image"],
lora_path=["path1", "path2"],
custom_logit_processor=["processor1", "processor2"],
return_hidden_states=True,
)
req.normalize_batch_and_arguments()
# Get the first item
item0 = req[0]
self.assertEqual(item0.text, "Hello")
self.assertEqual(item0.image_data, ["img1.jpg"])
self.assertEqual(item0.audio_data, "audio1.mp3")
self.assertEqual(item0.sampling_params, {"temp": 0.7})
self.assertEqual(item0.rid, "id1")
self.assertEqual(item0.return_logprob, True)
self.assertEqual(item0.logprob_start_len, 10)
self.assertEqual(item0.top_logprobs_num, 5)
self.assertEqual(item0.token_ids_logprob, [7, 8, 9])
self.assertEqual(item0.stream, True)
self.assertEqual(item0.log_metrics, True)
self.assertEqual(item0.modalities, "image")
self.assertEqual(item0.lora_path, "path1")
self.assertEqual(item0.custom_logit_processor, "processor1")
self.assertEqual(item0.return_hidden_states, True)
def test_regenerate_rid(self):
"""Test the regenerate_rid method."""
req = GenerateReqInput(text="Hello")
req.normalize_batch_and_arguments()
original_rid = req.rid
new_rid = req.regenerate_rid()
self.assertNotEqual(original_rid, new_rid)
self.assertEqual(req.rid, new_rid)
def test_error_cases(self):
"""Test various error cases."""
# Test when neither text, input_ids, nor input_embeds is provided
with self.assertRaises(ValueError):
req = GenerateReqInput()
req.normalize_batch_and_arguments()
# Test when all of text, input_ids, and input_embeds are provided
with self.assertRaises(ValueError):
req = GenerateReqInput(
text="Hello", input_ids=[1, 2, 3], input_embeds=[[0.1, 0.2]]
)
req.normalize_batch_and_arguments()
def test_multiple_input_formats(self):
"""Test different combinations of input formats."""
# Test with text only
req = GenerateReqInput(text="Hello")
req.normalize_batch_and_arguments()
self.assertTrue(req.is_single)
# Test with input_ids only
req = GenerateReqInput(input_ids=[1, 2, 3])
req.normalize_batch_and_arguments()
self.assertTrue(req.is_single)
# Test with input_embeds only
req = GenerateReqInput(input_embeds=[[0.1, 0.2]])
req.normalize_batch_and_arguments()
self.assertTrue(req.is_single)
if __name__ == "__main__":
unittest.main()