""" Usage: python3 -m unittest test_srt_backend.TestSRTBackend.test_gen_min_new_tokens python3 -m unittest test_srt_backend.TestSRTBackend.test_hellaswag_select """ import unittest import sglang as sgl from sglang.test.test_programs import ( test_decode_int, test_decode_json_regex, test_dtype_gen, test_expert_answer, test_few_shot_qa, test_gen_min_new_tokens, test_hellaswag_select, test_mt_bench, test_parallel_decoding, test_regex, test_select, test_stream, test_tool_use, ) from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, CustomTestCase class TestSRTBackend(CustomTestCase): backend = None @classmethod def setUpClass(cls): cls.backend = sgl.Runtime( model_path=DEFAULT_MODEL_NAME_FOR_TEST, cuda_graph_max_bs=4 ) sgl.set_default_backend(cls.backend) @classmethod def tearDownClass(cls): cls.backend.shutdown() def test_few_shot_qa(self): test_few_shot_qa() def test_mt_bench(self): test_mt_bench() def test_select(self): test_select(check_answer=False) def test_decode_int(self): test_decode_int() def test_decode_json_regex(self): test_decode_json_regex() def test_expert_answer(self): test_expert_answer() def test_tool_use(self): test_tool_use() def test_parallel_decoding(self): test_parallel_decoding() def test_stream(self): test_stream() def test_regex(self): test_regex() def test_dtype_gen(self): test_dtype_gen() def test_hellaswag_select(self): # Run twice to capture more bugs for _ in range(2): accuracy, latency = test_hellaswag_select() self.assertGreater(accuracy, 0.60) def test_gen_min_new_tokens(self): test_gen_min_new_tokens() if __name__ == "__main__": unittest.main()