512 lines
19 KiB
Python
512 lines
19 KiB
Python
import copy
|
|
import unittest
|
|
|
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
|
from sglang.test.test_utils import (
|
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
|
DEFAULT_URL_FOR_TEST,
|
|
CustomTestCase,
|
|
)
|
|
|
|
|
|
class TestGenerateReqInputNormalization(CustomTestCase):
    """Checks for ``GenerateReqInput.normalize_batch_and_arguments``.

    The cases cover single and batched prompts, image/audio payloads,
    parallel sampling (``n`` > 1), and the auxiliary per-request arguments.
    """

    @classmethod
    def setUpClass(cls):
        """Record the default test URL and small model name on the class."""
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST

    def setUp(self):
        """Build a two-prompt request that tests deep-copy before mutating."""
        self.base_req = GenerateReqInput(
            text=["Hello", "World"],
            sampling_params=[{}, {}],
            rid=["id1", "id2"],
        )

    def test_single_image_to_list_of_lists(self):
        """A bare image value becomes one single-image list per prompt."""
        batch = copy.deepcopy(self.base_req)
        batch.image_data = "single_image.jpg"  # scalar, not wrapped in a list

        batch.normalize_batch_and_arguments()

        # Expected layout: [[image], [image]]
        self.assertEqual(len(batch.image_data), 2)
        for idx in (0, 1):
            self.assertEqual(len(batch.image_data[idx]), 1)
        for idx in (0, 1):
            self.assertEqual(batch.image_data[idx][0], "single_image.jpg")

        self.assertEqual(batch.modalities, ["image", "image"])

    def test_list_of_images_to_list_of_lists(self):
        """A flat list of images becomes one single-image list per prompt."""
        batch = copy.deepcopy(self.base_req)
        batch.image_data = ["image1.jpg", "image2.jpg"]  # flat list

        batch.normalize_batch_and_arguments()

        # Expected layout: [[image1], [image2]]
        self.assertEqual(len(batch.image_data), 2)
        for idx in (0, 1):
            self.assertEqual(len(batch.image_data[idx]), 1)
        for idx, name in enumerate(("image1.jpg", "image2.jpg")):
            self.assertEqual(batch.image_data[idx][0], name)

        self.assertEqual(batch.modalities, ["image", "image"])

    def test_list_of_lists_with_different_modalities(self):
        """Nested image lists keep their shape; modality follows the image count."""
        batch = copy.deepcopy(self.base_req)
        batch.image_data = [
            ["image1.jpg"],  # one image -> "image"
            ["image2.jpg", "image3.jpg"],  # several images -> "multi-images"
        ]

        batch.normalize_batch_and_arguments()

        # The nesting must be left untouched.
        self.assertEqual(len(batch.image_data), 2)
        for idx, count in enumerate((1, 2)):
            self.assertEqual(len(batch.image_data[idx]), count)

        self.assertEqual(batch.modalities, ["image", "multi-images"])

    def test_list_of_lists_with_none_values(self):
        """A ``[None]`` entry keeps its shape and yields a ``None`` modality."""
        batch = copy.deepcopy(self.base_req)
        batch.image_data = [
            [None],  # placeholder entry
            ["image.jpg"],  # one real image
        ]

        batch.normalize_batch_and_arguments()

        # The nesting must be left untouched.
        self.assertEqual(len(batch.image_data), 2)
        for idx in (0, 1):
            self.assertEqual(len(batch.image_data[idx]), 1)

        self.assertEqual(batch.modalities, [None, "image"])

    def test_expanding_parallel_sample_correlation(self):
        """With a shared ``n``, text, images and modalities expand in lockstep."""
        batch = copy.deepcopy(self.base_req)
        batch.text = ["Prompt 1", "Prompt 2"]
        batch.image_data = [
            ["image1.jpg"],
            ["image2.jpg", "image3.jpg"],
        ]
        batch.sampling_params = {"n": 3}  # applies to every prompt

        # Expectations are captured before normalization rewrites the fields.
        want_text = batch.text * 3
        want_images = batch.image_data * 3
        want_modalities = ["image", "multi-images"] * 3

        batch.normalize_batch_and_arguments()

        # 2 prompts * 3 samples = 6 entries
        self.assertEqual(len(batch.image_data), 6)
        self.assertEqual(batch.image_data, want_images)
        self.assertEqual(batch.modalities, want_modalities)
        self.assertEqual(batch.text, want_text)

    def test_specific_parallel_n_per_sample(self):
        """Parallel expansion when ``n`` is given per prompt (both set to 2)."""
        batch = copy.deepcopy(self.base_req)
        batch.text = ["Prompt 1", "Prompt 2"]
        batch.image_data = [
            ["image1.jpg"],
            ["image2.jpg", "image3.jpg"],
        ]
        batch.sampling_params = [
            {"n": 2},
            {"n": 2},
        ]  # each prompt asks for 2 samples

        want_images = batch.image_data * 2
        want_modalities = ["image", "multi-images"] * 2
        want_text = batch.text * 2

        batch.normalize_batch_and_arguments()

        # 2 prompts * 2 samples = 4 entries
        self.assertEqual(len(batch.image_data), 4)

        # The expected value repeats the whole batch, so prompts interleave.
        self.assertEqual(batch.image_data, want_images)
        self.assertEqual(batch.modalities, want_modalities)
        self.assertEqual(batch.text, want_text)

    def test_mixed_none_and_images_with_parallel_samples(self):
        """Prompts without images survive parallel expansion next to ones with."""
        batch = copy.deepcopy(self.base_req)
        batch.text = ["Prompt 1", "Prompt 2", "Prompt 3"]
        batch.image_data = [
            ["image1.jpg"],
            None,
            ["image3_1.jpg", "image3_2.jpg"],
        ]
        batch.sampling_params = {"n": 2}  # applies to every prompt

        want_images = batch.image_data * 2
        want_modalities = ["image", None, "multi-images"] * 2
        want_text = batch.text * 2

        batch.normalize_batch_and_arguments()

        # 3 prompts * 2 samples = 6 entries
        self.assertEqual(len(batch.image_data), 6)
        self.assertEqual(batch.image_data, want_images)
        self.assertEqual(batch.modalities, want_modalities)
        self.assertEqual(batch.text, want_text)

    def test_correlation_with_sampling_params(self):
        """Expanded sampling params stay aligned with the expanded images."""
        batch = copy.deepcopy(self.base_req)
        batch.text = ["Prompt 1", "Prompt 2"]
        batch.image_data = [
            ["image1.jpg"],
            ["image2.jpg"],
        ]
        batch.sampling_params = [
            {"temperature": 0.7, "n": 2},
            {"temperature": 0.9, "n": 2},
        ]

        batch.normalize_batch_and_arguments()

        # Sampling params: 2 prompts * 2 samples, temperatures alternating.
        self.assertEqual(len(batch.sampling_params), 4)
        for idx, temperature in enumerate((0.7, 0.9, 0.7, 0.9)):
            self.assertEqual(batch.sampling_params[idx]["temperature"], temperature)

        # Images follow the same alternating order.
        self.assertEqual(len(batch.image_data), 4)
        want_images = (
            ["image1.jpg"],
            ["image2.jpg"],
            ["image1.jpg"],
            ["image2.jpg"],
        )
        for idx, images in enumerate(want_images):
            self.assertEqual(batch.image_data[idx], images)

    def test_single_example_with_image(self):
        """A single (non-batch) request keeps its image as given."""
        single = GenerateReqInput(
            text="Hello",
            image_data="single_image.jpg",
        )

        single.normalize_batch_and_arguments()

        # No list wrapping and no modalities for a single request.
        self.assertEqual(single.image_data, "single_image.jpg")
        self.assertIsNone(single.modalities)

    def test_single_to_batch_with_parallel_sampling(self):
        """A single request with ``n=3`` is turned into a three-entry batch."""
        single = GenerateReqInput(
            text="Hello",
            image_data="single_image.jpg",
            sampling_params={"n": 3},  # three parallel samples
        )

        # Expectation is fixed before normalization.
        want_text = ["Hello"] * 3

        single.normalize_batch_and_arguments()

        # The prompt is repeated once per sample.
        self.assertEqual(single.text, want_text)

        # The image is wrapped into a list of lists, one per sample (1 * 3).
        self.assertEqual(len(single.image_data), 3)
        for idx in (0, 1, 2):
            self.assertEqual(single.image_data[idx][0], "single_image.jpg")

        # Every sample gets a modality.
        self.assertEqual(single.modalities, ["image", "image", "image"])

    def test_audio_data_handling(self):
        """Scalar audio is repeated per prompt; a per-prompt list is kept."""
        broadcast = copy.deepcopy(self.base_req)
        broadcast.audio_data = "audio.mp3"  # one clip for the whole batch

        broadcast.normalize_batch_and_arguments()

        # Expected: ["audio.mp3", "audio.mp3"]
        self.assertEqual(len(broadcast.audio_data), 2)
        for idx in (0, 1):
            self.assertEqual(broadcast.audio_data[idx], "audio.mp3")

        # One clip per prompt.
        listed = copy.deepcopy(self.base_req)
        listed.audio_data = ["audio1.mp3", "audio2.mp3"]

        listed.normalize_batch_and_arguments()

        # The list must come back unchanged.
        self.assertEqual(len(listed.audio_data), 2)
        for idx, clip in enumerate(("audio1.mp3", "audio2.mp3")):
            self.assertEqual(listed.audio_data[idx], clip)

    def test_input_ids_normalization(self):
        """Token-id prompts: single vs. batch detection and ``n`` expansion."""
        # A flat list of ids is one prompt.
        single = GenerateReqInput(input_ids=[1, 2, 3])
        single.normalize_batch_and_arguments()
        self.assertTrue(single.is_single)
        self.assertEqual(single.batch_size, 1)

        # A list of id lists is a batch.
        batched = GenerateReqInput(input_ids=[[1, 2, 3], [4, 5, 6]])
        batched.normalize_batch_and_arguments()
        self.assertFalse(batched.is_single)
        self.assertEqual(batched.batch_size, 2)

        # Parallel sampling on a batch.
        sampled = GenerateReqInput(
            input_ids=[[1, 2, 3], [4, 5, 6]], sampling_params={"n": 2}
        )
        sampled.normalize_batch_and_arguments()
        self.assertEqual(len(sampled.input_ids), 4)  # 2 prompts * 2 samples

    def test_input_embeds_normalization(self):
        """Embedding prompts: single vs. batch detection."""
        # Two nesting levels is one prompt.
        single = GenerateReqInput(input_embeds=[[0.1, 0.2], [0.3, 0.4]])
        single.normalize_batch_and_arguments()
        self.assertTrue(single.is_single)
        self.assertEqual(single.batch_size, 1)

        # Three nesting levels is a batch.
        batched = GenerateReqInput(input_embeds=[[[0.1, 0.2]], [[0.3, 0.4]]])
        batched.normalize_batch_and_arguments()
        self.assertFalse(batched.is_single)
        self.assertEqual(batched.batch_size, 2)

    def test_lora_path_normalization(self):
        """``lora_path`` as a scalar, as a list, and under parallel sampling."""
        # One path shared by a two-prompt batch.
        shared = GenerateReqInput(text=["Hello", "World"], lora_path="path/to/lora")

        # Expectation is fixed before normalization.
        want_paths = ["path/to/lora", "path/to/lora"]

        shared.normalize_batch_and_arguments()
        self.assertEqual(shared.lora_path, want_paths)

        # One path per prompt.
        listed = GenerateReqInput(text=["Hello", "World"], lora_path=["path1", "path2"])

        want_paths = ["path1", "path2"]

        listed.normalize_batch_and_arguments()
        self.assertEqual(listed.lora_path, want_paths)

        # One path per prompt, two samples each.
        sampled = GenerateReqInput(
            text=["Hello", "World"],
            lora_path=["path1", "path2"],
            sampling_params={"n": 2},
        )

        want_paths = ["path1", "path2"] * 2

        sampled.normalize_batch_and_arguments()
        self.assertEqual(sampled.lora_path, want_paths)

    def test_logprob_parameters_normalization(self):
        """Logprob-related arguments for single, scalar-batch and list-batch input."""
        # Single request: values are kept as given.
        single = GenerateReqInput(
            text="Hello",
            return_logprob=True,
            logprob_start_len=10,
            top_logprobs_num=5,
            token_ids_logprob=[7, 8, 9],
        )
        single.normalize_batch_and_arguments()
        self.assertEqual(single.return_logprob, True)
        self.assertEqual(single.logprob_start_len, 10)
        self.assertEqual(single.top_logprobs_num, 5)
        self.assertEqual(single.token_ids_logprob, [7, 8, 9])

        # Batch with scalar values: each value is repeated per prompt.
        broadcast = GenerateReqInput(
            text=["Hello", "World"],
            return_logprob=True,
            logprob_start_len=10,
            top_logprobs_num=5,
            token_ids_logprob=[7, 8, 9],
        )
        broadcast.normalize_batch_and_arguments()
        self.assertEqual(broadcast.return_logprob, [True, True])
        self.assertEqual(broadcast.logprob_start_len, [10, 10])
        self.assertEqual(broadcast.top_logprobs_num, [5, 5])
        self.assertEqual(broadcast.token_ids_logprob, [[7, 8, 9], [7, 8, 9]])

        # Batch with per-prompt lists: lists are kept as given.
        # NOTE(review): return_hidden_states has three entries for a two-prompt
        # batch; the assertion only checks it comes back unchanged. Confirm
        # the length mismatch is intentional.
        listed = GenerateReqInput(
            text=["Hello", "World"],
            return_logprob=[True, False],
            logprob_start_len=[10, 5],
            top_logprobs_num=[5, 3],
            token_ids_logprob=[[7, 8, 9], [4, 5, 6]],
            return_hidden_states=[False, False, True],
        )
        listed.normalize_batch_and_arguments()
        self.assertEqual(listed.return_logprob, [True, False])
        self.assertEqual(listed.logprob_start_len, [10, 5])
        self.assertEqual(listed.top_logprobs_num, [5, 3])
        self.assertEqual(listed.token_ids_logprob, [[7, 8, 9], [4, 5, 6]])
        self.assertEqual(listed.return_hidden_states, [False, False, True])

    def test_custom_logit_processor_normalization(self):
        """``custom_logit_processor`` as a scalar and as a per-prompt list."""
        # One processor shared by the batch.
        shared = GenerateReqInput(
            text=["Hello", "World"], custom_logit_processor="serialized_processor"
        )
        shared.normalize_batch_and_arguments()
        self.assertEqual(
            shared.custom_logit_processor,
            ["serialized_processor", "serialized_processor"],
        )

        # One processor per prompt.
        listed = GenerateReqInput(
            text=["Hello", "World"], custom_logit_processor=["processor1", "processor2"]
        )
        listed.normalize_batch_and_arguments()
        self.assertEqual(listed.custom_logit_processor, ["processor1", "processor2"])

    def test_session_params_handling(self):
        """``session_params`` comes back unchanged, as a dict or a list of dicts."""
        # Dict form.
        as_dict = GenerateReqInput(
            text=["Hello", "World"], session_params={"id": "session1", "offset": 10}
        )
        as_dict.normalize_batch_and_arguments()
        self.assertEqual(as_dict.session_params, {"id": "session1", "offset": 10})

        # List-of-dicts form.
        as_list = GenerateReqInput(
            text=["Hello", "World"],
            session_params=[{"id": "session1"}, {"id": "session2"}],
        )
        as_list.normalize_batch_and_arguments()
        self.assertEqual(
            as_list.session_params, [{"id": "session1"}, {"id": "session2"}]
        )

    def test_getitem_method(self):
        """Indexing a normalized batch yields the first prompt's own values."""
        batch = GenerateReqInput(
            text=["Hello", "World"],
            image_data=[["img1.jpg"], ["img2.jpg"]],
            audio_data=["audio1.mp3", "audio2.mp3"],
            sampling_params=[{"temp": 0.7}, {"temp": 0.8}],
            rid=["id1", "id2"],
            return_logprob=[True, False],
            logprob_start_len=[10, 5],
            top_logprobs_num=[5, 3],
            token_ids_logprob=[[7, 8, 9], [4, 5, 6]],
            stream=True,
            log_metrics=True,
            modalities=["image", "image"],
            lora_path=["path1", "path2"],
            custom_logit_processor=["processor1", "processor2"],
            return_hidden_states=True,
        )
        batch.normalize_batch_and_arguments()

        first = batch[0]

        # Attribute name -> value expected on the first item, checked in order.
        want_fields = {
            "text": "Hello",
            "image_data": ["img1.jpg"],
            "audio_data": "audio1.mp3",
            "sampling_params": {"temp": 0.7},
            "rid": "id1",
            "return_logprob": True,
            "logprob_start_len": 10,
            "top_logprobs_num": 5,
            "token_ids_logprob": [7, 8, 9],
            "stream": True,
            "log_metrics": True,
            "modalities": "image",
            "lora_path": "path1",
            "custom_logit_processor": "processor1",
            "return_hidden_states": True,
        }
        for field, want in want_fields.items():
            self.assertEqual(getattr(first, field), want)

    def test_regenerate_rid(self):
        """``regenerate_rid`` returns a new id and stores it on the request."""
        single = GenerateReqInput(text="Hello")
        single.normalize_batch_and_arguments()

        rid_before = single.rid
        rid_after = single.regenerate_rid()

        self.assertNotEqual(rid_before, rid_after)
        self.assertEqual(single.rid, rid_after)

    def test_error_cases(self):
        """Invalid prompt combinations raise ``ValueError``."""
        # None of text, input_ids or input_embeds supplied.
        with self.assertRaises(ValueError):
            empty = GenerateReqInput()
            empty.normalize_batch_and_arguments()

        # All three prompt forms supplied at once.
        with self.assertRaises(ValueError):
            overloaded = GenerateReqInput(
                text="Hello", input_ids=[1, 2, 3], input_embeds=[[0.1, 0.2]]
            )
            overloaded.normalize_batch_and_arguments()

    def test_multiple_input_formats(self):
        """Each prompt form on its own is recognized as a single request."""
        prompt_forms = (
            {"text": "Hello"},
            {"input_ids": [1, 2, 3]},
            {"input_embeds": [[0.1, 0.2]]},
        )
        for prompt_kwargs in prompt_forms:
            single = GenerateReqInput(**prompt_kwargs)
            single.normalize_batch_and_arguments()
            self.assertTrue(single.is_single)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
|