import logging
import time
from collections import defaultdict
from typing import Dict, List, Tuple

import torch

logger = logging.getLogger(__name__)


# Global expert distribution recording.
class ExpertDistributionRecorder:
    # This class is a singleton: every construction returns the same instance.
    def __new__(cls):
        if not hasattr(cls, "instance"):
            cls.instance = super().__new__(cls)
        return cls.instance
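
    # Because of __new__ above, every construction returns the same object,
    # so all call sites share one recorder:
    #
    #     assert ExpertDistributionRecorder() is ExpertDistributionRecorder()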

    def __init__(self):
        # Skip re-initialization so that constructing the singleton again
        # does not silently wipe an in-progress record.
        if hasattr(self, "_expert_distribution_record"):
            return
        # Keys are layer ids (dict length == number of layers); each list holds
        # one entry per recorded token; each tuple holds that token's top-k
        # expert ids (tuple length == k).
        self._expert_distribution_record: Dict[int, List[Tuple[int, ...]]] = (
            defaultdict(list)
        )
        self._record = False
        self._current_layer_id = "UNKNOWN"
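
    # Illustrative shape of the record after two tokens pass layer 0 with k=2
    # (values made up):
    #
    #     {0: [(3, 7), (1, 3)]}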

    def set_current_layer(self, layer_idx):
        self._current_layer_id = layer_idx

    def record_new_token(self, topk_ids):
        if not self._record:
            return
        # Start the device-to-host copy asynchronously, then synchronize
        # before reading the CPU tensor: converting to numpy before the copy
        # finishes can read stale values.
        topk_ids_cpu = topk_ids.to("cpu", non_blocking=True)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        for i in topk_ids_cpu.numpy().tolist():
            self._expert_distribution_record[self._current_layer_id].append(tuple(i))
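
    # Sketch of the intended per-layer call pattern in a MoE forward pass
    # (`layer_id` and `topk_ids` are placeholders, not defined in this file):
    #
    #     recorder = ExpertDistributionRecorder()
    #     recorder.set_current_layer(layer_id)    # once per MoE layer
    #     recorder.record_new_token(topk_ids)     # int tensor, shape [num_tokens, k]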

    def reset(self):
        """Reset the expert distribution recorder."""
        logger.info("Resetting expert distribution record...")
        self._record = False
        self._expert_distribution_record.clear()
        self._current_layer_id = "UNKNOWN"

    def start_record(self):
        """Start recording the expert distribution. Reset the recorder and set the recording flag to True."""
        if self._record:
            logger.warning(
                "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
            )
        self.reset()
        self._record = True

    def stop_record(self):
        """Stop recording the expert distribution. Set the recording flag to False."""
        if not self._record:
            logger.warning(
                "SGLang server has not been recording expert ids. Did you forget to start recording by sending a request to the `/start_expert_distribution_record` endpoint?"
            )
        self._record = False
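
    # The warnings above name HTTP endpoints. A minimal sketch of how a server
    # might expose them (FastAPI syntax is an assumption; this file contains no
    # actual server wiring):
    #
    #     recorder = ExpertDistributionRecorder()
    #
    #     @app.post("/start_expert_distribution_record")
    #     async def start_expert_distribution_record():
    #         recorder.start_record()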

    def dump_record(self):
        """Dump the expert distribution record to a CSV file, then reset the recorder."""
        # Aggregate per-layer expert counts from the raw per-token records.
        results = {}
        for layer_idx, layer_record in self._expert_distribution_record.items():
            results[layer_idx] = defaultdict(int)
            for token_record in layer_record:
                for expert_idx in token_record:
                    results[layer_idx][expert_idx] += 1
        with open(
            f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv",
            "w",
        ) as fd:
            fd.write("layer_id,expert_id,count\n")
            for layer_idx, layer_results in results.items():
                for expert_idx, count in layer_results.items():
                    fd.write(f"{layer_idx},{expert_idx},{count}\n")
        self.reset()
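

# --- Usage sketch (illustrative, not part of the original module) ---
# Records two fake top-k routing decisions and dumps them to CSV. The
# single-process "gloo" group exists only because dump_record() embeds
# torch.distributed.get_rank() in the output filename.
if __name__ == "__main__":
    import torch.distributed as dist

    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29500",
        rank=0,
        world_size=1,
    )
    recorder = ExpertDistributionRecorder()
    recorder.start_record()
    recorder.set_current_layer(0)
    # Pretend the router chose experts (3, 7) and (1, 3) for two tokens (k=2).
    recorder.record_new_token(torch.tensor([[3, 7], [1, 3]]))
    recorder.stop_record()
    # Writes e.g. expert_distribution_rank0_timestamp<...>.csv with rows like:
    #   layer_id,expert_id,count
    #   0,3,2
    recorder.dump_record()
    dist.destroy_process_group()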