sglang0.4.5.post1/python/sglang/srt/managers/data_parallel_controller.py

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A controller that dispatches requests to multiple data parallel workers."""
import logging
import multiprocessing as mp
import signal
import threading
from enum import Enum, auto

import psutil
import setproctitle
import zmq

from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.managers.io_struct import (
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)


class LoadBalanceMethod(Enum):
"""Load balance method."""
ROUND_ROBIN = auto()
SHORTEST_QUEUE = auto()
@classmethod
def from_str(cls, method: str):
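        # e.g. LoadBalanceMethod.from_str("round_robin") -> LoadBalanceMethod.ROUND_ROBIN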
method = method.upper()
try:
return cls[method]
except KeyError as exc:
raise ValueError(f"Invalid load balance method: {method}") from exc


class DataParallelController:
    """A controller that dispatches requests to multiple data parallel workers."""

    def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
# Parse args
self.max_total_num_tokens = None
self.server_args = server_args
self.port_args = port_args
self.load_balance_method = LoadBalanceMethod.from_str(
server_args.load_balance_method
)
# Init inter-process communication
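        # Use one ZMQ I/O thread for the tokenizer socket plus one per DP worker socket.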
self.context = zmq.Context(1 + server_args.dp_size)
if server_args.node_rank == 0:
self.recv_from_tokenizer = get_zmq_socket(
self.context, zmq.PULL, port_args.scheduler_input_ipc_name, False
)
# Dispatch method
self.round_robin_counter = 0
dispatch_lookup = {
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
}
self.dispatching = dispatch_lookup[self.load_balance_method]
# Launch data parallel workers
self.scheduler_procs = []
self.workers = [None] * server_args.dp_size
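        # With DP attention enabled, all DP ranks share a single tensor parallel group;
        # otherwise each DP rank launches its own TP group from a separate thread.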
if server_args.enable_dp_attention:
dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
self.control_message_step = server_args.tp_size
else:
dp_port_args = self.launch_dp_schedulers(server_args, port_args)
self.control_message_step = 1
# Only node rank 0 runs the real data parallel controller that dispatches the requests.
if server_args.node_rank == 0:
for dp_rank in range(server_args.dp_size):
self.workers[dp_rank] = get_zmq_socket(
self.context,
zmq.PUSH,
dp_port_args[dp_rank].scheduler_input_ipc_name,
True,
)
self.max_req_input_len = None
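
    # Launch one TP scheduler group per DP rank, each from its own thread, and block
    # until every group is ready; returns the per-rank PortArgs.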
def launch_dp_schedulers(self, server_args, port_args):
base_gpu_id = 0
threads = []
sockets = []
dp_port_args = []
ready_events = []
for dp_rank in range(server_args.dp_size):
tmp_port_args = PortArgs.init_new(server_args)
tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
dp_port_args.append(tmp_port_args)
            # This port was checked to be free in PortArgs.init_new.
            # Hold it here so that the next DP worker gets a different port.
sockets.append(bind_port(tmp_port_args.nccl_port))
ready_event = threading.Event()
ready_events.append(ready_event)
# Create a thread for each worker
thread = threading.Thread(
target=self.launch_tensor_parallel_group_thread,
args=(server_args, tmp_port_args, base_gpu_id, dp_rank, ready_event),
)
threads.append(thread)
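            # Advance the base GPU id so the next DP rank's TP group uses a disjoint set of GPUs.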
base_gpu_id += server_args.tp_size * server_args.gpu_id_step
# Free all sockets before starting the threads to launch TP workers
for sock in sockets:
sock.close()
# Start all threads
for thread in threads:
thread.start()
for event in ready_events:
event.wait()
return dp_port_args
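
    # Thread target: launch the TP group for one DP rank, signal readiness, then spin forever.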
def launch_tensor_parallel_group_thread(
self,
server_args: ServerArgs,
port_args: PortArgs,
base_gpu_id: int,
dp_rank: int,
ready_event: threading.Event,
):
self.launch_tensor_parallel_group(server_args, port_args, base_gpu_id, dp_rank)
ready_event.set()
        # This thread must stay alive; otherwise the `kill_itself_when_parent_died`
        # function in scheduler.py will kill the scheduler.
while True:
pass
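
    # DP attention mode: launch a single TP group shared by all DP ranks, while still
    # allocating a separate set of ZMQ ports for every DP rank.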
def launch_dp_attention_schedulers(self, server_args, port_args):
self.launch_tensor_parallel_group(server_args, port_args, 0, None)
dp_port_args = []
for dp_rank in range(server_args.dp_size):
dp_port_args.append(PortArgs.init_new(server_args, dp_rank))
return dp_port_args
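
    # Launch one scheduler process per TP rank hosted on this node. With DP attention,
    # each TP rank's DP rank (and its ports) is derived from the TP rank.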
def launch_tensor_parallel_group(
self,
server_args: ServerArgs,
port_args: PortArgs,
base_gpu_id: int,
dp_rank: int,
):
if not server_args.enable_dp_attention:
logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
# Launch tensor parallel scheduler processes
scheduler_pipe_readers = []
tp_size_per_node = server_args.tp_size // server_args.nnodes
tp_rank_range = range(
tp_size_per_node * server_args.node_rank,
tp_size_per_node * (server_args.node_rank + 1),
)
for tp_rank in tp_rank_range:
rank_port_args = port_args
if server_args.enable_dp_attention:
# dp attention has different sharding logic
_, _, dp_rank = compute_dp_attention_world_info(
server_args.enable_dp_attention,
tp_rank,
server_args.tp_size,
server_args.dp_size,
)
# compute zmq ports for this dp rank
rank_port_args = PortArgs.init_new(server_args, dp_rank)
                # Data parallelism reuses the tensor parallelism group,
                # so all DP ranks must use the same NCCL port.
                rank_port_args.nccl_port = port_args.nccl_port
reader, writer = mp.Pipe(duplex=False)
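            # GPU index = server-wide base offset + this group's base GPU id + per-rank spacing within the node.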
gpu_id = (
server_args.base_gpu_id
+ base_gpu_id
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
)
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
)
proc.start()
self.scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
# Wait for model to finish loading
scheduler_info = []
for i in range(len(scheduler_pipe_readers)):
scheduler_info.append(scheduler_pipe_readers[i].recv())
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]

    def round_robin_scheduler(self, req):
        self.workers[self.round_robin_counter].send_pyobj(req)
        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)

    def shortest_queue_scheduler(self, input_requests):
        raise NotImplementedError()

    def event_loop(self):
while True:
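            # Drain every request currently queued from the tokenizer (non-blocking), then poll again.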
while True:
try:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break
if isinstance(
recv_req,
(
TokenizedGenerateReqInput,
TokenizedEmbeddingReqInput,
),
):
self.dispatching(recv_req)
else:
                    # Send other control messages to the first worker of each TP group
for worker in self.workers[:: self.control_message_step]:
worker.send_pyobj(recv_req)


def run_data_parallel_controller_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer,
):
setproctitle.setproctitle("sglang::data_parallel_controller")
configure_logger(server_args)
parent_process = psutil.Process().parent()
try:
controller = DataParallelController(server_args, port_args)
pipe_writer.send(
{
"status": "ready",
"max_total_num_tokens": controller.max_total_num_tokens,
"max_req_input_len": controller.max_req_input_len,
}
)
if server_args.node_rank == 0:
controller.event_loop()
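        # event_loop() never returns on node rank 0; other node ranks just wait here for
        # their local scheduler processes to exit.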
for proc in controller.scheduler_procs:
proc.join()
logger.error(
f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
)
except Exception:
traceback = get_exception_traceback()
logger.error(f"DataParallelController hit an exception: {traceback}")
parent_process.send_signal(signal.SIGQUIT)