inference/sglang/test/lang/test_srt_backend.py

87 lines
1.9 KiB
Python

"""
Usage:
python3 -m unittest test_srt_backend
python3 -m unittest test_srt_backend.TestSRTBackend.test_gen_min_new_tokens
python3 -m unittest test_srt_backend.TestSRTBackend.test_hellaswag_select
"""
import unittest
import sglang as sgl
from sglang.test.test_programs import (
test_decode_int,
test_decode_json_regex,
test_dtype_gen,
test_expert_answer,
test_few_shot_qa,
test_gen_min_new_tokens,
test_hellaswag_select,
test_mt_bench,
test_parallel_decoding,
test_regex,
test_select,
test_stream,
test_tool_use,
)
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, CustomTestCase
class TestSRTBackend(CustomTestCase):
    """Run the shared ``sglang.test.test_programs`` suite against an SRT runtime.

    One ``sgl.Runtime`` is started for the whole class in ``setUpClass``.
    It is installed as the sglang default backend and shut down in
    ``tearDownClass``. Each test method delegates to the program of the same
    name imported from ``sglang.test.test_programs``.
    """

    # Shared runtime used by every test in the class; set in setUpClass.
    backend = None

    @classmethod
    def setUpClass(cls):
        """Start one runtime for the test model and make it the default backend."""
        cls.backend = sgl.Runtime(
            # NOTE(review): presumably limits CUDA-graph capture to batch
            # sizes <= 4 to keep test startup cheap -- confirm in sgl.Runtime.
            model_path=DEFAULT_MODEL_NAME_FOR_TEST,
            cuda_graph_max_bs=4,
        )
        sgl.set_default_backend(cls.backend)

    @classmethod
    def tearDownClass(cls):
        """Shut down the runtime started in setUpClass."""
        cls.backend.shutdown()

    def test_few_shot_qa(self):
        test_few_shot_qa()

    def test_mt_bench(self):
        test_mt_bench()

    def test_select(self):
        # NOTE(review): check_answer=False presumably skips validating the
        # selected answer and only exercises the select path -- confirm.
        test_select(check_answer=False)

    def test_decode_int(self):
        test_decode_int()

    def test_decode_json_regex(self):
        test_decode_json_regex()

    def test_expert_answer(self):
        test_expert_answer()

    def test_tool_use(self):
        test_tool_use()

    def test_parallel_decoding(self):
        test_parallel_decoding()

    def test_stream(self):
        test_stream()

    def test_regex(self):
        test_regex()

    def test_dtype_gen(self):
        test_dtype_gen()

    def test_hellaswag_select(self):
        """Run the HellaSwag select program twice; each run needs accuracy > 0.60."""
        # Run twice to capture more bugs (issues that only show on a repeat run).
        num_runs = 2
        for run in range(num_runs):
            # The program returns (accuracy, latency); only accuracy is checked.
            accuracy, _latency = test_hellaswag_select()
            self.assertGreater(
                accuracy,
                0.60,
                msg=f"hellaswag_select accuracy too low on run {run + 1} of {num_runs}",
            )

    def test_gen_min_new_tokens(self):
        test_gen_min_new_tokens()
if __name__ == "__main__":
    # Direct execution: hand control to unittest's command-line test runner.
    unittest.main()