241 lines
8.9 KiB
Python
241 lines
8.9 KiB
Python
# Copyright 2023-2024 SGLang Team
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""A tensor parallel worker."""
|
|
|
|
import dataclasses
|
|
import logging
|
|
import signal
|
|
import threading
|
|
from queue import Queue
|
|
from typing import Optional
|
|
|
|
import psutil
|
|
import torch
|
|
|
|
from sglang.srt.managers.io_struct import (
|
|
GetWeightsByNameReqInput,
|
|
InitWeightsUpdateGroupReqInput,
|
|
UpdateWeightFromDiskReqInput,
|
|
UpdateWeightsFromDistributedReqInput,
|
|
UpdateWeightsFromTensorReqInput,
|
|
)
|
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
|
from sglang.srt.server_args import ServerArgs
|
|
from sglang.srt.utils import DynamicGradMode, get_compiler_backend
|
|
from sglang.utils import get_exception_traceback
|
|
|
|
# Module-scoped logger; its name is this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@torch.compile(dynamic=True, backend=get_compiler_backend())
def resolve_future_token_ids(input_ids, future_token_ids_map):
    """Replace negative placeholder ids in ``input_ids`` in place.

    An entry ``-k`` (k > 0) is overwritten with ``future_token_ids_map[k]``.
    Non-negative entries are kept unchanged. Negated ids are clamped at 0
    before indexing, so the gather never uses a negative index.
    """
    is_placeholder = input_ids < 0
    lookup_index = torch.clamp(-input_ids, min=0)
    gathered = future_token_ids_map[lookup_index]
    input_ids[:] = torch.where(is_placeholder, gathered, input_ids)
|
|
|
|
|
|
class TpModelWorkerClient:
    """A tensor parallel model worker that runs forward passes on a background thread.

    ``forward_batch_generation`` (called by the scheduler) enqueues a batch and
    immediately returns negative placeholder ("future") token ids. A dedicated
    forward thread runs the model on ``self.forward_stream``. It writes the real
    token ids into ``self.future_token_ids_map``, and it resolves placeholders
    in each batch's ``input_ids`` against that map before the forward pass.
    Finished results are retrieved with ``resolve_batch_result``.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        nccl_port: int,
    ):
        """Load the model and start the background forward thread.

        Args:
            server_args: Server configuration, forwarded to ``TpModelWorker``.
            gpu_id: Device index for this worker.
            tp_rank: Tensor-parallel rank, forwarded to ``TpModelWorker``.
            dp_rank: Data-parallel rank or ``None``, forwarded to ``TpModelWorker``.
            nccl_port: Port forwarded to ``TpModelWorker``.
        """
        # Load the model
        self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
        self.max_running_requests = self.worker.max_running_requests
        self.device = self.worker.device
        self.gpu_id = gpu_id

        # Init future mappings.
        # The placeholder counter wraps modulo `future_token_ids_limit` (3x),
        # while the map is allocated at 5x. Presumably this headroom lets a batch
        # that starts just before the wrap point still write in bounds (assumes
        # batch size <= max_running_requests — TODO confirm).
        self.future_token_ids_ct = 0
        self.future_token_ids_limit = self.max_running_requests * 3
        self.future_token_ids_map = torch.empty(
            (self.max_running_requests * 5,), dtype=torch.int64, device=self.device
        )

        # Initialize every attribute the forward thread can touch *before* the
        # thread starts. Its exception handler uses `self.parent_process`, which
        # previously was only assigned after `start()`.
        self.parent_process = psutil.Process().parent()
        self.scheduler_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.scheduler_stream.synchronize = lambda: None  # No-op for CPU

        # Launch threads
        self.input_queue = Queue()
        self.output_queue = Queue()
        self.forward_stream = torch.get_device_module(self.device).Stream()
        self.forward_thread = threading.Thread(
            target=self.forward_thread_func,
        )
        self.forward_thread.start()

    def get_worker_info(self):
        return self.worker.get_worker_info()

    def get_pad_input_ids_func(self):
        return self.worker.get_pad_input_ids_func()

    def get_tp_cpu_group(self):
        return self.worker.get_tp_cpu_group()

    def get_attention_tp_cpu_group(self):
        return self.worker.get_attention_tp_cpu_group()

    def get_memory_pool(self):
        """Return ``(req_to_token_pool, token_to_kv_pool_allocator)`` of the model runner."""
        return (
            self.worker.model_runner.req_to_token_pool,
            self.worker.model_runner.token_to_kv_pool_allocator,
        )

    def get_kv_cache(self):
        return self.worker.model_runner.token_to_kv_pool

    def forward_thread_func(self):
        """Entry point of the forward thread.

        Runs the forward loop with ``self.forward_stream`` as the current stream.
        Any exception is logged, and SIGQUIT is sent to the parent process.
        """
        try:
            with torch.get_device_module(self.device).stream(self.forward_stream):
                self.forward_thread_func_()
        except Exception:
            traceback = get_exception_traceback()
            logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
            self.parent_process.send_signal(signal.SIGQUIT)

    @DynamicGradMode()
    def forward_thread_func_(self):
        """Consume ``(batch, future_token_ids_ct)`` items until a ``None`` batch arrives.

        For each batch this method:
        1. Resolves placeholder input ids.
        2. Runs the forward pass.
        3. Publishes the sampled ids into ``future_token_ids_map``.
        4. Starts non-blocking copies of the outputs to the CPU.
        5. Puts ``(copy_done, logits_output, next_token_ids)`` on ``output_queue``.
        """
        batch_pt = 0
        batch_lists = [None] * 2

        while True:
            model_worker_batch, future_token_ids_ct = self.input_queue.get()
            # `None` is the shutdown sentinel (see `__delete__`). Test identity,
            # not truthiness, so a real batch object can never be mistaken for it.
            if model_worker_batch is None:
                break

            # Keep a reference of model_worker_batch by storing it into a list.
            # Otherwise, the tensor members of model_worker_batch will be released
            # by pytorch and cause CUDA illegal memory access errors.
            batch_lists[batch_pt % 2] = model_worker_batch
            batch_pt += 1

            # Create event.
            # `launch_done` is handed to the worker's forward call (presumably set
            # once the forward work is launched — TODO confirm). `copy_done` marks
            # completion of the device->CPU copies below.
            self.launch_done = threading.Event()
            copy_done = torch.get_device_module(self.device).Event()

            # Resolve future tokens in the input
            input_ids = model_worker_batch.input_ids
            resolve_future_token_ids(input_ids, self.future_token_ids_map)

            # Run forward
            logits_output, next_token_ids = self.worker.forward_batch_generation(
                model_worker_batch, self.launch_done
            )

            # Update the future token ids map.
            # Offset by +1 to match the placeholders -(ct + 1) .. -(ct + bs)
            # handed out by `forward_batch_generation`.
            bs = len(model_worker_batch.seq_lens)
            self.future_token_ids_map[
                future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
            ] = next_token_ids

            # Copy results to the CPU
            if model_worker_batch.return_logprob:
                logits_output.next_token_logprobs = (
                    logits_output.next_token_logprobs.to("cpu", non_blocking=True)
                )
                if logits_output.input_token_logprobs is not None:
                    logits_output.input_token_logprobs = (
                        logits_output.input_token_logprobs.to("cpu", non_blocking=True)
                    )
            if logits_output.hidden_states is not None:
                logits_output.hidden_states = logits_output.hidden_states.to(
                    "cpu", non_blocking=True
                )
            next_token_ids = next_token_ids.to("cpu", non_blocking=True)
            copy_done.record()

            self.output_queue.put((copy_done, logits_output, next_token_ids))

    def resolve_batch_result(self, bid: int):
        """Block until the oldest queued batch result is on the CPU, then return it.

        Args:
            bid: Batch id. Not used by this implementation.

        Returns:
            ``(logits_output, next_token_ids)``. ``next_token_logprobs`` (if set)
            becomes a list, ``input_token_logprobs`` (if set) becomes a tuple, and
            ``next_token_ids`` becomes a list.
        """
        copy_done, logits_output, next_token_ids = self.output_queue.get()
        copy_done.synchronize()
        # NOTE(review): `self.launch_done` is reassigned by the forward thread for
        # every batch it dequeues, so this waits on the most recently picked-up
        # batch, not necessarily the one whose result was just dequeued — confirm
        # this cross-thread handoff is intended.
        self.launch_done.wait()

        if logits_output.next_token_logprobs is not None:
            logits_output.next_token_logprobs = (
                logits_output.next_token_logprobs.tolist()
            )
            if logits_output.input_token_logprobs is not None:
                logits_output.input_token_logprobs = tuple(
                    logits_output.input_token_logprobs.tolist()
                )
        next_token_ids = next_token_ids.tolist()
        return logits_output, next_token_ids

    def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
        """Enqueue a batch for the forward thread and return placeholder token ids.

        Returns:
            ``(None, future_next_token_ids)``. ``future_next_token_ids`` is the
            int64 tensor ``[-(ct + 1), ..., -(ct + bs)]``, where ``ct`` is the
            counter value before this call. The forward thread later stores the
            real ids at ``future_token_ids_map[ct + 1 : ct + bs + 1]``.
        """
        # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
        sampling_info = model_worker_batch.sampling_info
        sampling_info.update_penalties()
        model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
            sampling_info,
            sampling_info_done=threading.Event(),
            penalizer_orchestrator=None,
        )

        # A cuda stream sync here to avoid the cuda illegal memory access error.
        self.scheduler_stream.synchronize()

        # Push a new batch to the queue
        self.input_queue.put((model_worker_batch, self.future_token_ids_ct))

        # Allocate output future objects
        bs = len(model_worker_batch.seq_lens)
        future_next_token_ids = torch.arange(
            -(self.future_token_ids_ct + 1),
            -(self.future_token_ids_ct + 1 + bs),
            -1,
            dtype=torch.int64,
            device=self.device,
        )
        self.future_token_ids_ct = (
            self.future_token_ids_ct + bs
        ) % self.future_token_ids_limit
        return None, future_next_token_ids

    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        success, message = self.worker.update_weights_from_disk(recv_req)
        return success, message

    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        success, message = self.worker.init_weights_update_group(recv_req)
        return success, message

    def update_weights_from_distributed(
        self, recv_req: UpdateWeightsFromDistributedReqInput
    ):
        success, message = self.worker.update_weights_from_distributed(recv_req)
        return success, message

    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        success, message = self.worker.update_weights_from_tensor(recv_req)
        return success, message

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        return self.worker.get_weights_by_name(recv_req)

    def __delete__(self):
        """Ask the forward thread to exit by sending it the ``None`` sentinel.

        The original also put a sentinel on ``self.copy_queue``, an attribute
        this class never creates, so the call raised AttributeError. Only the
        input-queue sentinel is needed to stop the forward loop.

        Note: ``__delete__`` belongs to the descriptor protocol. It is not a
        destructor, and Python never calls it automatically for an instance, so
        callers must invoke it explicitly.
        """
        self.input_queue.put((None, None))
|