88 lines
2.8 KiB
Python
88 lines
2.8 KiB
Python
import asyncio
|
|
import os
|
|
import re
|
|
import unittest
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
from sglang.srt.utils import kill_process_tree
|
|
from sglang.test.test_utils import (
|
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
DEFAULT_URL_FOR_TEST,
|
|
STDERR_FILENAME,
|
|
STDOUT_FILENAME,
|
|
CustomTestCase,
|
|
popen_launch_server,
|
|
send_concurrent_generate_requests,
|
|
send_generate_requests,
|
|
)
|
|
|
|
|
|
class TestMaxQueuedRequests(CustomTestCase):
    """Exercise request throttling on a server capped at one running and one queued request.

    One server is launched for the whole class with ``--max-running-requests 1``
    and ``--max-queued-requests 1``. Its stdout/stderr are redirected into
    ``STDOUT_FILENAME`` / ``STDERR_FILENAME`` so that the log-based test can
    inspect the counters the server printed.
    """

    @classmethod
    def setUpClass(cls):
        """Open the log files and launch the rate-limited server."""
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST

        cls.stdout = open(STDOUT_FILENAME, "w")
        cls.stderr = open(STDERR_FILENAME, "w")

        try:
            cls.process = popen_launch_server(
                cls.model,
                cls.base_url,
                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                other_args=(
                    "--max-running-requests",  # Enforce max request concurrency is 1
                    "1",
                    "--max-queued-requests",  # Enforce max queued request number is 1
                    "1",
                ),
                return_stdout_stderr=(cls.stdout, cls.stderr),
            )
        except BaseException:
            # unittest skips tearDownClass when setUpClass raises, so the
            # handles must be released here. The log files themselves are
            # left on disk so the launch failure can be inspected.
            cls.stdout.close()
            cls.stderr.close()
            raise

    @classmethod
    def tearDownClass(cls):
        """Stop the server, then close and delete the captured log files."""
        try:
            kill_process_tree(cls.process.pid)
        finally:
            # Release and remove the logs even if stopping the server failed.
            cls.stdout.close()
            cls.stderr.close()
            os.remove(STDOUT_FILENAME)
            os.remove(STDERR_FILENAME)

    def test_max_queued_requests_validation_with_serial_requests(self):
        """Verify request is not throttled when the max concurrency is 1."""
        status_codes = send_generate_requests(
            self.base_url,
            num_requests=10,
        )

        for index, status_code in enumerate(status_codes):
            # request shouldn't be throttled
            self.assertEqual(
                status_code,
                200,
                msg=f"request #{index} returned an unexpected status",
            )

    def test_max_queued_requests_validation_with_concurrent_requests(self):
        """Verify request throttling with concurrent requests."""
        status_codes = asyncio.run(
            send_concurrent_generate_requests(self.base_url, num_requests=10)
        )

        # At least one request must be served and at least one rejected ...
        self.assertIn(200, status_codes)
        self.assertIn(503, status_codes)
        # ... and no status other than those two may appear.
        unexpected = [code for code in status_codes if code not in (200, 503)]
        self.assertEqual(
            unexpected, [], msg="only 200 and 503 responses are expected"
        )

    def test_max_running_requests_and_max_queued_request_validation(self):
        """Verify running request and queued request numbers based on server logs."""
        rr_pattern = re.compile(r"#running-req:\s*(\d+)")
        qr_pattern = re.compile(r"#queue-req:\s*(\d+)")

        with open(STDERR_FILENAME) as log:
            for line_no, line in enumerate(log, start=1):
                rr_match = rr_pattern.search(line)
                qr_match = qr_pattern.search(line)
                if rr_match:
                    self.assertLessEqual(
                        int(rr_match.group(1)),
                        1,
                        msg=f"stderr line {line_no}: {line.rstrip()}",
                    )
                if qr_match:
                    self.assertLessEqual(
                        int(qr_match.group(1)),
                        1,
                        msg=f"stderr line {line_no}: {line.rstrip()}",
                    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: hand control to unittest's command-line runner,
    # which discovers and runs the test cases defined in this module.
    unittest.main()
|