import time
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse

from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_pd_server,
)


class TestDisaggregationPPAccuracy(TestDisaggregationBase):
    """GSM8K accuracy check for a prefill/decode disaggregated deployment.

    The prefill server is launched with ``--pp-size 2`` and the decode server
    with ``--base-gpu-id 2``; requests are sent to the load balancer address.
    """

    @classmethod
    def setUpClass(cls):
        """Derive the endpoints, start both servers, then start the balancer.

        Ports are taken from ``DEFAULT_URL_FOR_TEST``: the load balancer uses
        the base port, prefill uses base + 100 and decode uses base + 200.
        ``wait_server_ready`` and ``launch_lb`` are not defined in this class
        (presumably inherited from ``TestDisaggregationBase``).
        """
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST

        endpoint = urlparse(DEFAULT_URL_FOR_TEST)
        cls.base_host = endpoint.hostname
        port_text = str(endpoint.port)

        # Ports are stored as strings; the offsets keep the three services
        # on distinct ports of the same host.
        cls.lb_port = port_text
        cls.prefill_port = str(int(port_text) + 100)
        cls.decode_port = str(int(port_text) + 200)

        url_prefix = f"http://{cls.base_host}:"
        cls.prefill_url = url_prefix + cls.prefill_port
        cls.decode_url = url_prefix + cls.decode_port
        cls.lb_url = url_prefix + cls.lb_port
        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")

        # Kick off both servers first (non-blocking) ...
        cls.start_prefill()
        cls.start_decode()

        # ... then block until each one answers on its health endpoint.
        for server_url in (cls.prefill_url, cls.decode_url):
            cls.wait_server_ready(server_url + "/health")

        cls.launch_lb()

    @classmethod
    def start_prefill(cls):
        """Launch the prefill server and keep its handle in ``process_prefill``."""
        prefill_args = ["--trust-remote-code"]
        prefill_args += ["--disaggregation-mode", "prefill"]
        prefill_args += ["--tp-size", "1"]
        prefill_args += ["--pp-size", "2"]
        prefill_args += ["--disaggregation-ib-device", "mlx5_roce0,mlx5_roce1"]
        prefill_args += ["--disable-overlap-schedule"]

        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode server and keep its handle in ``process_decode``."""
        decode_args = ["--trust-remote-code"]
        decode_args += ["--disaggregation-mode", "decode"]
        decode_args += ["--tp", "1"]
        # NOTE(review): presumably GPU 2 is chosen so decode does not share
        # devices with the two-stage prefill server -- confirm.
        decode_args += ["--base-gpu-id", "2"]
        decode_args += ["--disaggregation-ib-device", "mlx5_roce2"]

        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    def test_gsm8k(self):
        """Run few-shot GSM8K through the load balancer and check accuracy."""
        eval_config = {
            "num_shots": 5,
            "data_path": None,
            "num_questions": 200,
            "max_new_tokens": 512,
            "parallel": 128,
            "host": f"http://{self.base_host}",
            "port": int(self.lb_port),
        }
        metrics = run_eval(SimpleNamespace(**eval_config))
        print(f"{metrics=}")

        self.assertGreater(metrics["accuracy"], 0.24)

        # Wait a little bit so that the memory check happens.
        time.sleep(5)


if __name__ == "__main__":
    unittest.main()