156 lines
5.0 KiB
Python
Executable File
156 lines
5.0 KiB
Python
Executable File
import os
|
|
import tempfile
|
|
import unittest
|
|
from pathlib import Path
|
|
from types import SimpleNamespace
|
|
|
|
import sglang as sgl
|
|
from sglang.srt.utils import kill_process_tree
|
|
from sglang.test.run_eval import run_eval
|
|
from sglang.test.test_utils import (
|
|
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
DEFAULT_URL_FOR_TEST,
|
|
CustomTestCase,
|
|
popen_launch_server,
|
|
)
|
|
|
|
|
|
class _BaseTestDynamicEPLB(CustomTestCase):
    """Shared fixture for the dynamic-EPLB server tests.

    ``setUpClass`` launches one server with ``popen_launch_server`` using the
    EPLB-related flags below, ``test_mmlu`` evaluates it, and
    ``tearDownClass`` kills the server's process tree. Concrete subclasses
    customise the launch by overriding ``extra_args``.
    """

    # Extra command-line flags appended after the common ones.
    extra_args = []

    @classmethod
    def setUpClass(cls):
        """Launch the server under test and store it on the class.

        Sets ``cls.model``, ``cls.base_url`` and ``cls.process``.

        Raises:
            unittest.SkipTest: when invoked on ``_BaseTestDynamicEPLB``
                itself rather than on a concrete subclass.
        """
        # The base class carries a ``test_*`` method, so the unittest loader
        # collects it like any other test case (the leading underscore does
        # not hide it). Running it would start a server with exactly the
        # configuration ``TestDynamicEPLBSimple`` already covers, so skip it.
        if cls is _BaseTestDynamicEPLB:
            raise unittest.SkipTest(
                "_BaseTestDynamicEPLB is an abstract base; run its subclasses"
            )

        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST

        server_args = [
            "--trust-remote-code",
            "--tp",
            "2",
            "--dp",
            "2",
            "--enable-dp-attention",
            "--enable-deepep-moe",
            "--deepep-mode",
            "normal",
            "--disable-cuda-graph",
            "--enable-eplb",
            "--ep-num-redundant-experts",
            "4",
            "--eplb-rebalance-num-iterations",
            "50",
            "--expert-distribution-recorder-buffer-size",
            "50",
            # TODO pr-chain: enable later
            # "--enable-expert-distribution-metrics",
            # TODO auto determine these flags
            "--expert-distribution-recorder-mode",
            "stat",
            "--ep-dispatch-algorithm",
            "static",
            *cls.extra_args,
        ]

        # Later keys win in a dict display, so the ambient environment is
        # unpacked first: the values this test requires must not be silently
        # replaced by whatever the calling shell happens to export.
        server_env = {
            **os.environ,
            "SGL_ENABLE_JIT_DEEPGEMM": "0",
            "SGLANG_EXPERT_LOCATION_UPDATER_CANARY": "1",
        }

        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=server_args,
            env=server_env,
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server started by ``setUpClass`` and its child processes."""
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        """Run the ``mmlu`` eval (64 examples, 32 threads); require score > 0.5."""
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        self.assertGreater(metrics["score"], 0.5)
|
|
|
|
|
|
class TestDynamicEPLBSimple(_BaseTestDynamicEPLB):
    """Dynamic-EPLB test using the inherited launch flags with no extras."""
|
|
|
|
|
|
class TestDynamicEPLBMultiChunk(_BaseTestDynamicEPLB):
    """Dynamic-EPLB test that adds ``--eplb-rebalance-layers-per-chunk 1``."""

    # Appended to the inherited server flags at launch time.
    extra_args = [
        "--eplb-rebalance-layers-per-chunk",
        "1",
    ]
|
|
|
|
|
|
class TestStaticEPLB(CustomTestCase):
    """Record an expert distribution with one engine, then boot a second
    engine from the dumped snapshot via ``init_expert_location``."""

    def test_save_expert_distribution_and_init_expert_location(self):
        """Dump a ``*.pt`` snapshot and start a new engine from it.

        Steps:
          1. Start an engine, record the expert distribution while
             generating, and dump it into a temporary directory.
          2. Shut that engine down and start a second one with
             ``init_expert_location`` pointing at the dumped snapshot.
          3. Check that both engines produce the expected completions.

        Each engine is shut down in a ``finally`` block so a failed assertion
        does not leave it running, and the environment variables set here are
        restored once the test finishes.
        """
        self._set_env_for_test("SGL_ENABLE_JIT_DEEPGEMM", "0")

        with tempfile.TemporaryDirectory() as tmp_dir:
            engine_kwargs = dict(
                model_path=DEFAULT_MLA_MODEL_NAME_FOR_TEST,
                trust_remote_code=True,
                ep_num_redundant_experts=4,
                enable_dp_attention=True,
                enable_deepep_moe=True,
                deepep_mode="normal",
                disable_cuda_graph=True,
                expert_distribution_recorder_mode="stat",
                tp_size=2,
                dp_size=2,
                log_level="info",
                # TODO pr-chain: enable later
                # enable_expert_distribution_metrics=True,
            )

            print("Action: start engine")
            # Must be set before the engine starts; it stays set for the
            # second engine as well, matching the original flow.
            self._set_env_for_test("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", tmp_dir)
            engine = sgl.Engine(
                **engine_kwargs,
                disable_overlap_schedule=True,
            )
            try:
                engine.start_expert_distribution_record()
                self._assert_engine_generate_correct(engine)

                print("Action: dump_expert_distribution_record")
                engine.dump_expert_distribution_record()

                # Sorted so the chosen snapshot is deterministic if more than
                # one file is written; an empty result fails with a clear
                # message instead of a bare IndexError.
                snapshot_paths = sorted(Path(tmp_dir).glob("*.pt"))
                self.assertTrue(
                    snapshot_paths,
                    f"no *.pt snapshot was written to {tmp_dir}",
                )
                snapshot_path = snapshot_paths[0]
                print(f"{snapshot_path=}")
            finally:
                print("Action: shutdown engine")
                engine.shutdown()
                del engine

            print("Action: start engine with init_expert_location")
            engine = sgl.Engine(
                **engine_kwargs,
                init_expert_location=str(snapshot_path),
                port=21000,
                # TODO auto determine these flags
                ep_dispatch_algorithm="static",
            )
            try:
                self._assert_engine_generate_correct(engine)
            finally:
                print("Action: shutdown engine")
                engine.shutdown()
                del engine

    def _set_env_for_test(self, name, value):
        """Set ``os.environ[name]`` and restore the prior state after the test.

        Restoration is registered with ``addCleanup`` (assumes
        ``CustomTestCase`` derives from ``unittest.TestCase``), so it runs
        after the test method returns, whether it passed or failed.
        """
        previous = os.environ.get(name)
        os.environ[name] = value

        def _restore():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous

        self.addCleanup(_restore)

    def _assert_engine_generate_correct(self, engine: sgl.Engine):
        """Generate for two fixed prompts and compare against fixed outputs.

        Uses ``max_new_tokens=8`` and ``temperature=0.0`` and asserts the
        ``"text"`` field of each result equals the hard-coded continuation.
        """
        output = engine.generate(
            prompt=["1+1=2, 2+2=4", "One plus one is two, two plus two is four"],
            sampling_params=dict(max_new_tokens=8, temperature=0.0),
        )
        print(f"engine.generate {output=}")
        self.assertEqual(
            [x["text"] for x in output],
            [", 4+4=8,", ", four plus four is eight, eight"],
        )
|
|
|
|
|
|
# Allow running this file directly as a script; unittest discovers the
# test cases defined in this module.
if __name__ == "__main__":
    unittest.main()
|