import time
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestTorchCompileMoe(CustomTestCase):
    """Accuracy and decode-throughput checks against a server started with
    ``--enable-torch-compile`` on the default small MoE test model."""

    @classmethod
    def setUpClass(cls):
        """Start one server process shared by every test in this class."""
        cls.model = DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--enable-torch-compile", "--torch-compile-max-bs", "8"],
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process tree started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        """Run the ``mmlu`` eval on 64 examples and require a score >= 0.50."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.50)

    def run_decode(self, max_new_tokens):
        """POST a fixed prompt to ``/generate`` and return the decoded JSON body.

        Args:
            max_new_tokens: Value forwarded as ``sampling_params.max_new_tokens``.

        Returns:
            The parsed JSON response.

        Raises:
            requests.HTTPError: If the server answers with a 4xx/5xx status.
        """
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                    # NOTE(review): presumably forces generation to run for the
                    # full max_new_tokens so the token count is fixed -- confirm.
                    "ignore_eos": True,
                },
            },
        )
        # Fail loudly on an error status: a fast error reply would otherwise
        # look like a very short decode and inflate the measured throughput.
        response.raise_for_status()
        return response.json()

    def test_throughput(self):
        """Time a single 256-token decode and require at least 285 tokens/s."""
        # Warmup: the result is irrelevant, only the side effect of a first
        # request matters.
        self.run_decode(16)

        max_tokens = 256

        # perf_counter is monotonic and high-resolution; time.time() can jump
        # if the system clock is adjusted mid-measurement.
        tic = time.perf_counter()
        res = self.run_decode(max_tokens)
        tok = time.perf_counter()
        print(f"{res=}")
        throughput = max_tokens / (tok - tic)
        self.assertGreaterEqual(throughput, 285)


if __name__ == "__main__":
    unittest.main()