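"""Integration tests for PD (prefill/decode) disaggregated routing under the
power_of_two policy: completions must be attributed to decode workers, and
traffic should skew toward the faster decode worker when the slow one is loaded.
"""
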
import collections
import concurrent.futures
import subprocess
import time

import pytest
import requests


@pytest.mark.integration
def test_pd_power_of_two_decode_attribution(router_manager, mock_workers):
    # Start two prefill and three decode mock workers via fixture
    _, prefill_urls_raw, prefill_ids = mock_workers(n=2)
    _, decode_urls_raw, decode_ids_list = mock_workers(n=3)
    prefill_urls = [(u, None) for u in prefill_urls_raw]
    decode_urls = list(decode_urls_raw)
    decode_ids = set(decode_ids_list)

    rh = router_manager.start_router(
        policy="power_of_two",
        pd_disaggregation=True,
        prefill_urls=prefill_urls,
        decode_urls=decode_urls,
        extra={"worker_startup_check_interval": 1},
    )

    # Send a batch of small completions through the router and record which
    # worker served each one.
    counts = collections.Counter()
    with requests.Session() as s:
        for i in range(30):
            r = s.post(
                f"{rh.url}/v1/completions",
                json={
                    "model": "test-model",
                    "prompt": f"p{i}",
                    "max_tokens": 1,
                    "stream": False,
                },
            )
            assert r.status_code == 200
            # Attribution must point at a decode worker, never a prefill worker.
            wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id")
            assert wid in decode_ids
            counts[wid] += 1

    # Power-of-two over three decode workers should spread 30 requests across
    # at least two of them.
    assert sum(1 for v in counts.values() if v > 0) >= 2


@pytest.mark.integration
def test_pd_power_of_two_skews_to_faster_decode(router_manager, mock_workers):
    # Start two prefill workers (fast)
    _, prefill_urls_raw, _ = mock_workers(n=2)

    # Start two decode workers: one slow, one fast
    _, [decode_slow_url], [slow_id] = mock_workers(
        n=1, args=["--latency-ms", "300"]
    )  # slower decode
    _, [decode_fast_url], [fast_id] = mock_workers(n=1)
    decode_urls_raw = [decode_slow_url, decode_fast_url]

    prefill_urls = [(u, None) for u in prefill_urls_raw]
    decode_urls = list(decode_urls_raw)

    rh = router_manager.start_router(
        policy="power_of_two",
        pd_disaggregation=True,
        prefill_urls=prefill_urls,
        decode_urls=decode_urls,
        extra={"worker_startup_check_interval": 1},
    )

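    # Fire a burst of concurrent warm-up requests through the router so both
    # decode workers have handled traffic before the measured run.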
    def _prime_call(i):
        try:
            requests.post(
                f"{rh.url}/v1/completions",
                json={
                    "model": "test-model",
                    "prompt": f"warm-{i}",
                    "max_tokens": 1,
                    "stream": False,
                },
                timeout=8,
            )
        except Exception:
            pass

    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
        list(ex.map(_prime_call, range(128)))
    time.sleep(2)

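    # Put sustained load directly on the slow decode worker (bypassing the
    # router) so the load-aware power-of-two choice should favor the fast one.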
    def _direct_decode_load(i):
        try:
            requests.post(
                f"{decode_slow_url}/v1/completions",
                json={
                    "model": "test-model",
                    "prompt": f"bg-{i}",
                    "max_tokens": 1,
                    "stream": False,
                },
                timeout=8,
            )
        except Exception:
            pass

    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
        list(ex.map(_direct_decode_load, range(128)))
    time.sleep(1)

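    # Measured run: send 200 concurrent completions through the router and
    # count which decode worker served each one.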
    def call(i):
        r = requests.post(
            f"{rh.url}/v1/completions",
            json={
                "model": "test-model",
                "prompt": f"p{i}",
                "max_tokens": 1,
                "stream": False,
            },
            timeout=8,
        )
        assert r.status_code == 200
        return r.headers.get("X-Worker-Id") or r.json().get("worker_id")

    counts = collections.Counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
        for wid in ex.map(call, range(200)):
            counts[wid] += 1

    # The loaded (slow) decode worker should have served fewer requests than
    # the fast one.
    assert counts[slow_id] < counts[fast_id], counts