sglang_v0.5.2/sglang/sgl-router/py_test/integration/test_pd_routing.py

128 lines
3.8 KiB
Python

import collections
import concurrent.futures
import subprocess
import time
import pytest
import requests
@pytest.mark.integration
def test_pd_power_of_two_decode_attribution(router_manager, mock_workers):
    """Check that PD-routed ``power_of_two`` requests are attributed to decode workers.

    Starts two prefill and three decode mock workers, sends 30 sequential
    non-streaming completions through the router, and asserts that:

    * every request returns HTTP 200,
    * every response reports a worker id belonging to a decode worker, and
    * at least two distinct decode workers served traffic.
    """
    # Start two prefill and three decode mock workers via the fixture. Each
    # call returns a 3-tuple; the 2nd/3rd items are the worker URLs and ids.
    _, prefill_urls_raw, _ = mock_workers(n=2)
    _, decode_urls_raw, decode_ids_list = mock_workers(n=3)
    # Prefill workers are passed to the router as (url, None) pairs.
    prefill_urls = [(u, None) for u in prefill_urls_raw]
    decode_urls = list(decode_urls_raw)
    decode_ids = set(decode_ids_list)

    rh = router_manager.start_router(
        policy="power_of_two",
        pd_disaggregation=True,
        prefill_urls=prefill_urls,
        decode_urls=decode_urls,
        extra={"worker_startup_check_interval": 1},
    )

    counts = collections.Counter()
    with requests.Session() as s:
        for i in range(30):
            r = s.post(
                f"{rh.url}/v1/completions",
                json={
                    "model": "test-model",
                    "prompt": f"p{i}",
                    "max_tokens": 1,
                    "stream": False,
                },
                # Bound each request so an unresponsive router fails the
                # test instead of hanging it indefinitely.
                timeout=8,
            )
            assert r.status_code == 200, r.text
            # The serving worker id may come back as a header or in the body.
            wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id")
            assert wid in decode_ids, (wid, decode_ids)
            counts[wid] += 1

    # The Counter only contains ids that were seen at least once, so its
    # length is the number of distinct decode workers that served traffic.
    assert len(counts) >= 2, counts
@pytest.mark.integration
def test_pd_power_of_two_skews_to_faster_decode(router_manager, mock_workers):
    """Check that ``power_of_two`` routes fewer requests to a slow decode worker.

    Setup is two prefill workers and two decode workers; one decode worker is
    started with ``--latency-ms 300``. The router is first warmed up with 128
    best-effort requests. Then 128 best-effort requests are sent directly to
    the slow decode worker, bypassing the router. Finally, 200 concurrent
    completions go through the router, and the slow worker must have served
    fewer of them than the fast one.
    """
    request_timeout = 8  # seconds; applied to every HTTP call in this test

    # Two prefill workers with default settings.
    _, prefill_urls_raw, _ = mock_workers(n=2)
    # Two decode workers: one with 300 ms of added latency, one without.
    _, [decode_slow_url], [slow_id] = mock_workers(
        n=1, args=["--latency-ms", "300"]
    )  # slower decode
    _, [decode_fast_url], [fast_id] = mock_workers(n=1)
    # Prefill workers are passed to the router as (url, None) pairs.
    prefill_urls = [(u, None) for u in prefill_urls_raw]
    decode_urls = [decode_slow_url, decode_fast_url]

    rh = router_manager.start_router(
        policy="power_of_two",
        pd_disaggregation=True,
        prefill_urls=prefill_urls,
        decode_urls=decode_urls,
        extra={"worker_startup_check_interval": 1},
    )

    def _payload(prompt):
        """Build the minimal non-streaming completion request body."""
        return {
            "model": "test-model",
            "prompt": prompt,
            "max_tokens": 1,
            "stream": False,
        }

    def _best_effort_post(base_url, prompt):
        """POST one completion to ``base_url`` and ignore transport failures."""
        try:
            requests.post(
                f"{base_url}/v1/completions",
                json=_payload(prompt),
                timeout=request_timeout,
            )
        except requests.RequestException:
            # This traffic is load generation only: timeouts and connection
            # errors must not fail the test. Other exceptions are real bugs
            # and are allowed to propagate.
            pass

    def _fan_out(fn, n):
        """Run ``fn(0..n-1)`` on 32 threads and return the results in order."""
        with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
            return list(ex.map(fn, range(n)))

    # Warm-up traffic through the router.
    _fan_out(lambda i: _best_effort_post(rh.url, f"warm-{i}"), 128)
    time.sleep(2)

    # Extra traffic sent straight to the slow decode worker (not via router).
    _fan_out(lambda i: _best_effort_post(decode_slow_url, f"bg-{i}"), 128)
    time.sleep(1)

    def _routed_call(i):
        """Send one completion through the router; return the serving worker id."""
        r = requests.post(
            f"{rh.url}/v1/completions",
            json=_payload(f"p{i}"),
            timeout=request_timeout,
        )
        assert r.status_code == 200, r.text
        # The serving worker id may come back as a header or in the body.
        return r.headers.get("X-Worker-Id") or r.json().get("worker_id")

    # Executor.map re-raises any exception (including failed asserts) from a
    # worker thread when results are collected, failing the test here.
    counts = collections.Counter(_fan_out(_routed_call, 200))
    assert counts[slow_id] < counts[fast_id], counts