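"""Tests for KV cache event publishing.

Launches the sglang server with the ZMQ KV events publisher enabled, subscribes
to the "kv-events" topic, and checks that BlockStored/BlockRemoved events are
received, including per-rank tagging when DP attention is enabled.
"""
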
import time
import unittest

import msgspec
import requests
import zmq
from msgspec.msgpack import Decoder

from sglang.srt.disaggregation.kv_events import (
    AllBlocksCleared,
    BlockRemoved,
    BlockStored,
    EventBatch,
    KVCacheEvent,
    KVEventBatch,
)
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestKvEvents(CustomTestCase):
    def test_kv_events_enabled(self):
        """Test that kv events are sent and received by a subscriber when enabled"""

        # Launch kv events subscriber
        decoder = Decoder(type=KVEventBatch)
        context = zmq.Context()
        sub = context.socket(zmq.SUB)
        sub.connect("tcp://localhost:5557")
        topic = "kv-events"
        sub.setsockopt_string(zmq.SUBSCRIBE, topic)

        # Launch sglang server
        process = popen_launch_server(
            DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            DEFAULT_URL_FOR_TEST,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--kv-events-config",
                '{"publisher": "zmq", "topic": "kv-events"}',
                "--max-total-tokens",
                32,
                "--cuda-graph-max-bs",
                2,
                "--enable-dp-attention",
                "--dp-size",
                1,
            ],
        )

        try:
            # Make some requests to generate some kv events
            response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
            self.assertEqual(response.status_code, 200)

            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/generate",
                json={
                    "text": "The capital of France is",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 32,
                    },
                },
            )
            response = requests.post(
                f"{DEFAULT_URL_FOR_TEST}/generate",
                json={
                    "text": "The capital of Spain is",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 32,
                    },
                },
            )

            # Expected events. These may depend on the model used (meta-llama/Llama-3.2-1B-Instruct)
            expected_events = [
                # <begin> The capital city of France is
                BlockStored(
                    block_hashes=[-6650323075460941099],
                    parent_block_hash=5740354900026072187,
                    token_ids=[128000, 791, 6864, 3363, 315, 9822, 374],
                    block_size=7,
                    lora_id=None,
                ),
                # Paris. The Eiffel Tower
                BlockStored(
                    block_hashes=[-7584018293207282755],
                    parent_block_hash=-6650323075460941099,
                    token_ids=[12366, 13, 578, 469, 3168, 301, 22703],
                    block_size=7,
                    lora_id=None,
                ),
                BlockStored(
                    block_hashes=[-8753497827991233192],
                    parent_block_hash=5740354900026072187,
                    token_ids=[0],
                    block_size=1,
                    lora_id=None,
                ),
                BlockRemoved(block_hashes=[-6650323075460941099]),
                # <begin> The capital
                BlockStored(
                    block_hashes=[-2697055055087824455],
                    parent_block_hash=5740354900026072187,
                    token_ids=[128000, 791, 6864],
                    block_size=3,
                    lora_id=None,
                ),
                # city of France is
                BlockStored(
                    block_hashes=[-7505627135785778022],
                    parent_block_hash=-2697055055087824455,
                    token_ids=[3363, 315, 9822, 374],
                    block_size=4,
                    lora_id=None,
                ),
                # of France is
                BlockStored(
                    block_hashes=[-3861108700662737012],
                    parent_block_hash=-2697055055087824455,
                    token_ids=[315, 9822, 374],
                    block_size=3,
                    lora_id=None,
                ),
                BlockRemoved(block_hashes=[-7584018293207282755]),
                BlockRemoved(block_hashes=[-8753497827991233192]),
                BlockRemoved(block_hashes=[-7505627135785778022]),
                # Paris. The Eiffel Tower is located in Paris. The Eiffel Tower is a famous landmark in Paris
                BlockStored(
                    block_hashes=[-3064341286825792715],
                    parent_block_hash=-3861108700662737012,
                    token_ids=[
                        12366,
                        13,
                        578,
                        469,
                        3168,
                        301,
                        22703,
                        374,
                        7559,
                        304,
                        12366,
                        13,
                        578,
                        469,
                        3168,
                        301,
                        22703,
                        374,
                        264,
                        11495,
                        38350,
                        304,
                        12366,
                    ],
                    block_size=23,
                    lora_id=None,
                ),
                BlockRemoved(block_hashes=[-3861108700662737012]),
                # of
                BlockStored(
                    block_hashes=[6115672085296369592],
                    parent_block_hash=-2697055055087824455,
                    token_ids=[315],
                    block_size=1,
                    lora_id=None,
                ),
                # France is
                BlockStored(
                    block_hashes=[4208810872343132234],
                    parent_block_hash=6115672085296369592,
                    token_ids=[9822, 374],
                    block_size=2,
                    lora_id=None,
                ),
                # Spain is
                BlockStored(
                    block_hashes=[1675819893649989955],
                    parent_block_hash=6115672085296369592,
                    token_ids=[18157, 374],
                    block_size=2,
                    lora_id=None,
                ),
                BlockRemoved(block_hashes=[-3064341286825792715]),
                # Madrid. The capital of France is Paris. The capital of Italy is Rome. The capital of Spain is Madrid.
                BlockStored(
                    block_hashes=[-8505834929190027295],
                    parent_block_hash=1675819893649989955,
                    token_ids=[
                        25048,
                        13,
                        578,
                        6864,
                        315,
                        9822,
                        374,
                        12366,
                        13,
                        578,
                        6864,
                        315,
                        15704,
                        374,
                        22463,
                        13,
                        578,
                        6864,
                        315,
                        18157,
                        374,
                        25048,
                        13,
                    ],
                    block_size=23,
                    lora_id=None,
                ),
            ]

            # Get events
            events = []
            start = time.time()
            max_wait_s = 5
            while (
                len(events) < len(expected_events)
                and (time.time() - start) < max_wait_s
            ):
                # Each published message is a multipart frame: (topic, sequence, payload)
                _, seq_bytes, payload = sub.recv_multipart()
                event_batch = decoder.decode(payload)
                for event in event_batch.events:
                    events.append(event)

            for expected in expected_events:
                self.assertIn(expected, events)

        finally:
            kill_process_tree(process.pid)

    def test_kv_events_attn_dp(self):
        """Test that kv events are properly tagged with DP rank in attention DP mode"""

        # Launch multiple subscribers for different DP ranks
        decoder = Decoder(type=KVEventBatch)
        context = zmq.Context()

        # Subscribe to both DP rank endpoints
        sub_dp0 = context.socket(zmq.SUB)
        sub_dp0.connect("tcp://localhost:5557")  # DP rank 0
        topic = "kv-events"
        sub_dp0.setsockopt_string(zmq.SUBSCRIBE, topic)

        sub_dp1 = context.socket(zmq.SUB)
        sub_dp1.connect("tcp://localhost:5558")  # DP rank 1 (offset by rank)
        sub_dp1.setsockopt_string(zmq.SUBSCRIBE, topic)

        # Launch sglang server with DP attention enabled
        process = popen_launch_server(
            "silence09/DeepSeek-R1-Small-2layers",
            DEFAULT_URL_FOR_TEST,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--kv-events-config",
                '{"publisher": "zmq", "topic": "kv-events"}',
                "--max-total-tokens",
                64,
                "--cuda-graph-max-bs",
                4,
                "--enable-dp-attention",
                "--dp-size",
                2,
                "--tp-size",
                2,
            ],
        )

        try:
            # Make requests to generate events
            response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
            self.assertEqual(response.status_code, 200)

            # Send multiple requests to trigger events from both DP ranks
            for i in range(4):
                response = requests.post(
                    f"{DEFAULT_URL_FOR_TEST}/generate",
                    json={
                        "text": f"Request {i}: The capital of country {i} is",
                        "sampling_params": {
                            "temperature": 0,
                            "max_new_tokens": 16,
                        },
                    },
                )

            # Collect events from both DP ranks
            events_dp0 = []
            events_dp1 = []
            start = time.time()
            max_wait_s = 10
            min_events_per_rank = 3  # Expect at least a few events from each rank

            while (time.time() - start) < max_wait_s and (
                len(events_dp0) < min_events_per_rank
                or len(events_dp1) < min_events_per_rank
            ):
                # Check DP rank 0
                if sub_dp0.poll(timeout=100):  # 100ms timeout
                    _, seq_bytes, payload = sub_dp0.recv_multipart()
                    event_batch = decoder.decode(payload)
                    print(
                        f"DP Rank 0 - EventBatch: ts={event_batch.ts}, attn_dp_rank={event_batch.attn_dp_rank}"
                    )
                    self.assertEqual(
                        event_batch.attn_dp_rank,
                        0,
                        "DP rank 0 events should have attn_dp_rank=0",
                    )
                    for event in event_batch.events:
                        print(f" DP0 - {event}")
                        events_dp0.append(event)

                # Check DP rank 1
                if sub_dp1.poll(timeout=100):  # 100ms timeout
                    _, seq_bytes, payload = sub_dp1.recv_multipart()
                    event_batch = decoder.decode(payload)
                    print(
                        f"DP Rank 1 - EventBatch: ts={event_batch.ts}, attn_dp_rank={event_batch.attn_dp_rank}"
                    )
                    self.assertEqual(
                        event_batch.attn_dp_rank,
                        1,
                        "DP rank 1 events should have attn_dp_rank=1",
                    )
                    for event in event_batch.events:
                        print(f" DP1 - {event}")
                        events_dp1.append(event)

            # Verify we got events from both DP ranks
            print(f"Collected {len(events_dp0)} events from DP rank 0")
            print(f"Collected {len(events_dp1)} events from DP rank 1")

            self.assertGreaterEqual(
                len(events_dp0),
                min_events_per_rank,
                f"Expected at least {min_events_per_rank} events from DP rank 0",
            )
            self.assertGreaterEqual(
                len(events_dp1),
                min_events_per_rank,
                f"Expected at least {min_events_per_rank} events from DP rank 1",
            )

            # Verify event types are as expected
            for events in [events_dp0, events_dp1]:
                for event in events:
                    self.assertIsInstance(
                        event,
                        (BlockStored, BlockRemoved, AllBlocksCleared),
                        f"Event should be a KV cache event, got {type(event)}",
                    )

        finally:
            sub_dp0.close()
            sub_dp1.close()
            context.term()
            kill_process_tree(process.pid)


if __name__ == "__main__":
    unittest.main()