"""Test the SGLang server launched with TorchAO int4 weight-only quantization (int4wo-128)."""

import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestTorchAO(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        # Launch the server once for all tests with TorchAO int4 weight-only
        # quantization (group size 128) enabled.
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--torchao-config", "int4wo-128"],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        # Run a small MMLU eval against the quantized server and require a
        # minimum accuracy score.
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.60)

    def run_decode(self, max_new_tokens):
        # Send a single greedy-decoding request to the /generate endpoint.
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": max_new_tokens,
                },
                "ignore_eos": True,
            },
        )
        return response.json()

    def test_throughput(self):
        # Measure end-to-end decode throughput for a single request and
        # require a minimum tokens/s rate.
        import time

        max_tokens = 256

        tic = time.perf_counter()
        res = self.run_decode(max_tokens)
        tok = time.perf_counter()
        print(res["text"])
        throughput = max_tokens / (tok - tic)
        print(f"Throughput: {throughput} tokens/s")
        self.assertGreaterEqual(throughput, 210)


if __name__ == "__main__":
    unittest.main()