# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import copy
import dataclasses
import json
import logging
import math
import os
import pickle
import signal
import sys
import threading
import time
import uuid
from collections import deque
from contextlib import nullcontext
from datetime import datetime
from enum import Enum
from http import HTTPStatus
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

import fastapi
import torch
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.aio_rwlock import RWLock
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchMultimodalOut,
    BatchStrOut,
    BatchTokenIDOut,
    BatchTokenizedEmbeddingReqInput,
    BatchTokenizedGenerateReqInput,
    CloseSessionReqInput,
    ConfigureLoggingReq,
    EmbeddingReqInput,
    FreezeGCReq,
    GenerateReqInput,
    HealthCheckOutput,
    MultiTokenizerWrapper,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    SessionParams,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
)
from sglang.srt.managers.mm_utils import TensorTransportMode
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
from sglang.srt.managers.scheduler import is_health_check_generate_req
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    configure_gc_warning,
    dataclass_to_string_truncated,
    freeze_gc,
    get_bool_env_var,
    get_origin_rid,
    get_zmq_socket,
    kill_process_tree,
)
from sglang.utils import TypeBasedDispatcher, get_exception_traceback

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)

@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    out_list: List[Dict[Any, Any]]
    finished: bool
    event: asyncio.Event
    obj: Union[GenerateReqInput, EmbeddingReqInput]

    # For metrics
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # For streaming output
    last_output_offset: int = 0

    # For incremental state update.
    # TODO(lianmin): do not initialize some lists if not needed.
    text: str = ""
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)

class TokenizerManager(TokenizerCommunicatorMixin):
|
|
"""TokenizerManager is a process that tokenizes the text."""
|
|
|
|
def __init__(
|
|
self,
|
|
server_args: ServerArgs,
|
|
port_args: PortArgs,
|
|
):
|
|
# Parse args
|
|
self.server_args = server_args
|
|
self.enable_metrics = server_args.enable_metrics
|
|
self.log_requests = server_args.log_requests
|
|
self.log_requests_level = server_args.log_requests_level
|
|
self.preferred_sampling_params = (
|
|
json.loads(server_args.preferred_sampling_params)
|
|
if server_args.preferred_sampling_params
|
|
else None
|
|
)
|
|
self.crash_dump_folder = server_args.crash_dump_folder
|
|
|
|
# Read model args
|
|
self.model_path = server_args.model_path
|
|
self.served_model_name = server_args.served_model_name
|
|
self.model_config = ModelConfig.from_server_args(server_args)
|
|
self.is_generation = self.model_config.is_generation
|
|
self.is_image_gen = self.model_config.is_image_gen
|
|
self.context_len = self.model_config.context_len
|
|
self.image_token_id = self.model_config.image_token_id
|
|
self.max_req_input_len = None # Will be set later in engine.py
|
|
|
|
if self.model_config.is_multimodal:
|
|
import_processors()
|
|
try:
|
|
_processor = get_processor(
|
|
server_args.tokenizer_path,
|
|
tokenizer_mode=server_args.tokenizer_mode,
|
|
trust_remote_code=server_args.trust_remote_code,
|
|
revision=server_args.revision,
|
|
use_fast=not server_args.disable_fast_image_processor,
|
|
)
|
|
except ValueError as e:
|
|
error_message = str(e)
|
|
if "does not have a slow version" in error_message:
|
|
logger.info(
|
|
f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
|
|
)
|
|
_processor = get_processor(
|
|
server_args.tokenizer_path,
|
|
tokenizer_mode=server_args.tokenizer_mode,
|
|
trust_remote_code=server_args.trust_remote_code,
|
|
revision=server_args.revision,
|
|
use_fast=True,
|
|
)
|
|
else:
|
|
raise e
|
|
transport_mode = _determine_tensor_transport_mode(self.server_args)
|
|
|
|
            # We want to parallelize the image pre-processing, so we create an executor for it.
            # We create mm_processor regardless of skip_tokenizer_init to make sure we still
            # encode images even with skip_tokenizer_init=True.
|
|
self.mm_processor = get_mm_processor(
|
|
self.model_config.hf_config, server_args, _processor, transport_mode
|
|
)
|
|
|
|
if server_args.skip_tokenizer_init:
|
|
self.tokenizer = self.processor = None
|
|
else:
|
|
self.processor = _processor
|
|
self.tokenizer = get_tokenizer_from_processor(self.processor)
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
else:
|
|
self.mm_processor = self.processor = None
|
|
|
|
if server_args.skip_tokenizer_init:
|
|
self.tokenizer = None
|
|
else:
|
|
self.tokenizer = get_tokenizer(
|
|
server_args.tokenizer_path,
|
|
tokenizer_mode=server_args.tokenizer_mode,
|
|
trust_remote_code=server_args.trust_remote_code,
|
|
revision=server_args.revision,
|
|
)
|
|
|
|
# Init inter-process communication
|
|
context = zmq.asyncio.Context(2)
|
|
self.recv_from_detokenizer = get_zmq_socket(
|
|
context, zmq.PULL, port_args.tokenizer_ipc_name, True
|
|
)
|
|
if self.server_args.tokenizer_worker_num > 1:
|
|
# Use tokenizer_worker_ipc_name in multi-tokenizer mode
|
|
self.send_to_scheduler = get_zmq_socket(
|
|
context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
|
|
)
|
|
else:
|
|
self.send_to_scheduler = get_zmq_socket(
|
|
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
|
|
)
|
|
|
|
# Request states
|
|
self.no_create_loop = False
|
|
self.rid_to_state: Dict[str, ReqState] = {}
|
|
self.asyncio_tasks = set()
|
|
|
|
# Health check
|
|
self.server_status = ServerStatus.Starting
|
|
self.gracefully_exit = False
|
|
self.last_receive_tstamp = 0
|
|
|
|
# Dumping
|
|
self.dump_requests_folder = "" # By default do not dump
|
|
self.dump_requests_threshold = 1000
|
|
self.dump_request_list: List[Tuple] = []
|
|
self.log_request_metadata = self.get_log_request_metadata()
|
|
self.crash_dump_request_list: deque[Tuple] = deque()
|
|
self.crash_dump_performed = False # Flag to ensure dump is only called once
|
|
|
|
# Session
|
|
self.session_futures = {} # session_id -> asyncio event
|
|
|
|
# Weight updates
|
|
# The event to notify the weight sync is finished.
|
|
self.model_update_lock = RWLock()
|
|
self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
|
|
None
|
|
)
|
|
self.is_pause = False
|
|
self.is_pause_cond = asyncio.Condition()
|
|
|
|
# LoRA
|
|
# Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
|
|
# The registry dynamically updates as adapters are loaded / unloaded during runtime. It
|
|
# serves as the source of truth for available adapters and maps user-friendly LoRA names
|
|
# to internally used unique LoRA IDs.
|
|
self.lora_registry = LoRARegistry(self.server_args.lora_paths)
|
|
# Lock to serialize LoRA update operations.
|
|
# Please note that, unlike `model_update_lock`, this does not block inference, allowing
|
|
# LoRA updates and inference to overlap.
|
|
self.lora_update_lock = asyncio.Lock()
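        # A sketch of the intended flow (see generate_request / _wait_one_response below):
        # a request with a lora_path acquire()s an adapter ID from the registry before it is
        # sent to the scheduler and release()s it when it finishes, so an adapter is only
        # unloaded after its in-flight requests have drained.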
|
|
|
|
self.disaggregation_mode = DisaggregationMode(
|
|
self.server_args.disaggregation_mode
|
|
)
|
|
self.bootstrap_server = start_disagg_service(self.server_args)
|
|
|
|
# For load balancing
|
|
self.current_load = 0
|
|
self.current_load_lock = asyncio.Lock()
|
|
|
|
# Metrics
|
|
if self.enable_metrics:
|
|
self.metrics_collector = TokenizerMetricsCollector(
|
|
server_args=server_args,
|
|
labels={
|
|
"model_name": self.server_args.served_model_name,
|
|
# TODO: Add lora name/path in the future,
|
|
},
|
|
bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
|
|
bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
|
|
bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
|
|
collect_tokens_histogram=self.server_args.collect_tokens_histogram,
|
|
)
|
|
|
|
# Configure GC warning
|
|
if self.server_args.gc_warning_threshold_secs > 0.0:
|
|
configure_gc_warning(self.server_args.gc_warning_threshold_secs)
|
|
|
|
self._result_dispatcher = TypeBasedDispatcher(
|
|
[
|
|
(
|
|
(
|
|
BatchStrOut,
|
|
BatchEmbeddingOut,
|
|
BatchTokenIDOut,
|
|
BatchMultimodalOut,
|
|
),
|
|
self._handle_batch_output,
|
|
),
|
|
(AbortReq, self._handle_abort_req),
|
|
(OpenSessionReqOutput, self._handle_open_session_req_output),
|
|
(
|
|
UpdateWeightFromDiskReqOutput,
|
|
self._handle_update_weights_from_disk_req_output,
|
|
),
|
|
(
|
|
FreezeGCReq,
|
|
lambda x: None,
|
|
                ),  # Ignore FreezeGCReq: the scheduler may skip the detokenizer and forward it back to the tokenizer manager.
|
|
(HealthCheckOutput, lambda x: None),
|
|
]
|
|
)
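        # handle_loop() receives replies from the detokenizer/scheduler and routes each one
        # to the matching handler above based on its type.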
|
|
|
|
self.init_communicators(server_args)
|
|
|
|
async def generate_request(
|
|
self,
|
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
|
request: Optional[fastapi.Request] = None,
|
|
):
|
|
created_time = time.time()
|
|
self.auto_create_handle_loop()
|
|
obj.normalize_batch_and_arguments()
|
|
|
|
if self.server_args.tokenizer_worker_num > 1:
|
|
# Modify rid, add worker_id
|
|
if isinstance(obj.rid, list):
|
|
# If it's an array, add worker_id prefix to each element
|
|
obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
|
|
else:
|
|
# If it's a single value, add worker_id prefix
|
|
obj.rid = f"{self.worker_id}_{obj.rid}"
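                # e.g. a (hypothetical) rid "abc123" becomes "3_abc123" on worker 3;
                # get_origin_rid() strips the prefix again when results are returned.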
|
|
|
|
if self.log_requests:
|
|
max_length, skip_names, _ = self.log_request_metadata
|
|
logger.info(
|
|
f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
|
|
)
|
|
|
|
async with self.is_pause_cond:
|
|
await self.is_pause_cond.wait_for(lambda: not self.is_pause)
|
|
|
|
async with self.model_update_lock.reader_lock:
|
|
if self.server_args.enable_lora and obj.lora_path:
|
|
# Look up the LoRA ID from the registry and start tracking ongoing LoRA requests.
|
|
obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
|
|
|
|
if obj.is_single:
|
|
tokenized_obj = await self._tokenize_one_request(obj)
|
|
state = self._send_one_request(obj, tokenized_obj, created_time)
|
|
async for response in self._wait_one_response(obj, state, request):
|
|
yield response
|
|
else:
|
|
async for response in self._handle_batch_request(
|
|
obj, request, created_time
|
|
):
|
|
yield response
|
|
|
|
async def _tokenize_one_request(
|
|
self,
|
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
|
):
|
|
"""Tokenize one request."""
|
|
# Tokenize
|
|
input_embeds = None
|
|
input_text = obj.text
|
|
token_type_ids = None
|
|
is_cross_encoder_request = (
|
|
isinstance(obj, EmbeddingReqInput) and obj.is_cross_encoder_request
|
|
)
|
|
if obj.input_embeds is not None:
|
|
if not self.server_args.disable_radix_cache:
|
|
raise ValueError(
|
|
"input_embeds is provided while disable_radix_cache is False. "
|
|
"Please add `--disable-radix-cache` when you launch the server "
|
|
"if you want to use input_embeds as inputs."
|
|
)
|
|
input_embeds = obj.input_embeds
|
|
input_ids = obj.input_ids
|
|
elif obj.input_ids is not None:
|
|
input_ids = obj.input_ids
|
|
else:
|
|
if self.tokenizer is None:
|
|
raise ValueError(
|
|
"The engine initialized with skip_tokenizer_init=True cannot "
|
|
"accept text prompts. Please provide input_ids or re-initialize "
|
|
"the engine with skip_tokenizer_init=False."
|
|
)
|
|
encoded = self.tokenizer(
|
|
input_text, return_token_type_ids=is_cross_encoder_request
|
|
)
|
|
|
|
input_ids = encoded["input_ids"]
|
|
if is_cross_encoder_request:
|
|
input_ids = encoded["input_ids"][0]
|
|
token_type_ids = encoded.get("token_type_ids", [None])[0]
|
|
|
|
if self.mm_processor and obj.contains_mm_input():
|
|
if not isinstance(obj.image_data, list):
|
|
obj.image_data = [obj.image_data]
|
|
if not isinstance(obj.audio_data, list):
|
|
obj.audio_data = [obj.audio_data]
|
|
mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
|
|
image_data=obj.image_data,
|
|
audio_data=obj.audio_data,
|
|
input_text=input_text or input_ids,
|
|
request_obj=obj,
|
|
max_req_input_len=self.max_req_input_len,
|
|
)
|
|
if mm_inputs and "input_ids" in mm_inputs:
|
|
input_ids = mm_inputs["input_ids"]
|
|
else:
|
|
mm_inputs = None
|
|
|
|
self._validate_one_request(obj, input_ids)
|
|
return self._create_tokenized_object(
|
|
obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
|
|
)
|
|
|
|
def _validate_one_request(
|
|
self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
|
|
) -> None:
|
|
"""Validates that the input token count and the requested token count doesn't exceed the model's context length."""
|
|
# FIXME: unify the length validation logic with the one in the scheduler.
|
|
_max_req_len = self.context_len
|
|
|
|
input_token_num = len(input_ids) if input_ids is not None else 0
|
|
if input_token_num >= self.context_len:
|
|
if self.server_args.allow_auto_truncate:
|
|
logger.warning(
|
|
f"The input ({input_token_num} tokens) is longer than the "
|
|
f"model's context length ({self.context_len} tokens). "
|
|
"Truncating the input."
|
|
)
|
|
del input_ids[_max_req_len:]
|
|
input_token_num = len(input_ids)
|
|
else:
|
|
raise ValueError(
|
|
f"The input ({input_token_num} tokens) is longer than the "
|
|
f"model's context length ({self.context_len} tokens)."
|
|
)
|
|
|
|
if isinstance(obj, EmbeddingReqInput) and self.is_generation:
|
|
raise ValueError(
|
|
"This model does not appear to be an embedding model by default. "
|
|
"Please add `--is-embedding` when launching the server or try another model."
|
|
)
|
|
|
|
# Check total tokens (input + max_new_tokens)
|
|
max_new_tokens = obj.sampling_params.get("max_new_tokens")
|
|
if (
|
|
max_new_tokens is not None
|
|
and (max_new_tokens + input_token_num) >= _max_req_len
|
|
):
|
|
if self.server_args.allow_auto_truncate:
|
|
logger.warning(
|
|
f"Requested token count ({input_token_num} input + {max_new_tokens} new) "
|
|
f"exceeds the model's context length ({self.context_len} tokens). "
|
|
"Truncating max_new_tokens."
|
|
)
|
|
obj.sampling_params["max_new_tokens"] = max(
|
|
0, _max_req_len - input_token_num
|
|
)
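                # Example with hypothetical numbers: context_len=4096, 4000 input tokens, and
                # max_new_tokens=200 -> max_new_tokens is clamped to 4096 - 4000 = 96.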
|
|
else:
|
|
total_tokens = max_new_tokens + input_token_num
|
|
error_msg = (
|
|
f"Requested token count exceeds the model's maximum context length "
|
|
f"of {self.context_len} tokens. You requested a total of {total_tokens} "
|
|
f"tokens: {input_token_num} tokens from the input messages and "
|
|
f"{max_new_tokens} tokens for the completion. Please reduce the number "
|
|
f"of tokens in the input messages or the completion to fit within the limit."
|
|
)
|
|
raise ValueError(error_msg)
|
|
|
|
if isinstance(obj, GenerateReqInput):
|
|
if (
|
|
obj.return_hidden_states
|
|
and not self.server_args.enable_return_hidden_states
|
|
):
|
|
raise ValueError(
|
|
"The server is not configured to return the hidden states. "
|
|
"Please set `--enable-return-hidden-states` to enable this feature."
|
|
)
|
|
if (
|
|
obj.custom_logit_processor
|
|
and not self.server_args.enable_custom_logit_processor
|
|
):
|
|
raise ValueError(
|
|
"The server is not configured to enable custom logit processor. "
|
|
"Please set `--enable-custom-logits-processor` to enable this feature."
|
|
)
|
|
|
|
def _validate_input_ids_in_vocab(
|
|
self, input_ids: List[int], vocab_size: int
|
|
) -> None:
|
|
if any(id >= vocab_size for id in input_ids):
|
|
raise ValueError(
|
|
f"The input_ids {input_ids} contains values greater than the vocab size ({vocab_size})."
|
|
)
|
|
|
|
def _create_tokenized_object(
|
|
self,
|
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
|
input_text: str,
|
|
input_ids: List[int],
|
|
        input_embeds: Optional[List[float]] = None,
|
|
mm_inputs: Optional[Dict] = None,
|
|
token_type_ids: Optional[List[int]] = None,
|
|
) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
|
|
"""Create a tokenized request object from common parameters."""
|
|
# Parse sampling parameters
|
|
# Note: if there are preferred sampling params, we use them if they are not
|
|
# explicitly passed in sampling_params
|
|
if self.preferred_sampling_params:
|
|
sampling_kwargs = {**self.preferred_sampling_params, **obj.sampling_params}
|
|
else:
|
|
sampling_kwargs = obj.sampling_params
|
|
sampling_params = SamplingParams(**sampling_kwargs)
|
|
sampling_params.normalize(self.tokenizer)
|
|
sampling_params.verify(self.model_config.vocab_size)
|
|
|
|
# Build return object
|
|
if isinstance(obj, GenerateReqInput):
|
|
session_params = (
|
|
SessionParams(**obj.session_params) if obj.session_params else None
|
|
)
|
|
|
|
tokenized_obj = TokenizedGenerateReqInput(
|
|
obj.rid,
|
|
input_text,
|
|
input_ids,
|
|
mm_inputs,
|
|
sampling_params,
|
|
obj.return_logprob,
|
|
obj.logprob_start_len,
|
|
obj.top_logprobs_num,
|
|
obj.token_ids_logprob,
|
|
obj.stream,
|
|
bootstrap_host=obj.bootstrap_host,
|
|
bootstrap_port=obj.bootstrap_port,
|
|
bootstrap_room=obj.bootstrap_room,
|
|
lora_id=obj.lora_id,
|
|
input_embeds=input_embeds,
|
|
session_params=session_params,
|
|
custom_logit_processor=obj.custom_logit_processor,
|
|
return_hidden_states=obj.return_hidden_states,
|
|
data_parallel_rank=obj.data_parallel_rank,
|
|
)
|
|
elif isinstance(obj, EmbeddingReqInput):
|
|
tokenized_obj = TokenizedEmbeddingReqInput(
|
|
obj.rid,
|
|
input_text,
|
|
input_ids,
|
|
mm_inputs,
|
|
token_type_ids,
|
|
sampling_params,
|
|
)
|
|
|
|
return tokenized_obj
|
|
|
|
async def _batch_tokenize_and_process(
|
|
self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
|
|
) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]:
|
|
"""Handle batch tokenization for text inputs only."""
|
|
logger.debug(f"Starting batch tokenization for {batch_size} text requests")
|
|
|
|
# Collect requests and texts
|
|
requests = [obj[i] for i in range(batch_size)]
|
|
texts = [req.text for req in requests]
|
|
|
|
# Batch tokenize all texts
|
|
encoded = self.tokenizer(texts)
|
|
input_ids_list = encoded["input_ids"]
|
|
|
|
# Process all requests
|
|
tokenized_objs = []
|
|
for i, req in enumerate(requests):
|
|
self._validate_one_request(obj[i], input_ids_list[i])
|
|
tokenized_objs.append(
|
|
self._create_tokenized_object(
|
|
req, req.text, input_ids_list[i], None, None
|
|
)
|
|
)
|
|
logger.debug(f"Completed batch processing for {batch_size} requests")
|
|
return tokenized_objs
|
|
|
|
def _validate_batch_tokenization_constraints(
|
|
self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
|
|
) -> None:
|
|
"""Validate constraints for batch tokenization processing."""
|
|
for i in range(batch_size):
|
|
if self.is_generation and obj[i].contains_mm_input():
|
|
raise ValueError(
|
|
"For multimodal input processing do not set `enable_tokenizer_batch_encode`."
|
|
)
|
|
if obj[i].input_ids is not None:
|
|
raise ValueError(
|
|
"Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`."
|
|
)
|
|
if obj[i].input_embeds is not None:
|
|
raise ValueError(
|
|
"Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
|
|
)
|
|
|
|
def _send_one_request(
|
|
self,
|
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
|
tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
|
|
created_time: Optional[float] = None,
|
|
):
|
|
self.send_to_scheduler.send_pyobj(tokenized_obj)
|
|
state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
|
|
self.rid_to_state[obj.rid] = state
|
|
return state
|
|
|
|
def _send_batch_request(
|
|
self,
|
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
|
tokenized_objs: List[
|
|
Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]
|
|
],
|
|
created_time: Optional[float] = None,
|
|
):
|
|
"""Send a batch of tokenized requests as a single batched request to the scheduler."""
|
|
if isinstance(tokenized_objs[0], TokenizedGenerateReqInput):
|
|
batch_req = BatchTokenizedGenerateReqInput(batch=tokenized_objs)
|
|
else:
|
|
batch_req = BatchTokenizedEmbeddingReqInput(batch=tokenized_objs)
|
|
|
|
self.send_to_scheduler.send_pyobj(batch_req)
|
|
|
|
# Create states for each individual request in the batch
|
|
for i, tokenized_obj in enumerate(tokenized_objs):
|
|
tmp_obj = obj[i]
|
|
state = ReqState(
|
|
[], False, asyncio.Event(), tmp_obj, created_time=created_time
|
|
)
|
|
self.rid_to_state[tmp_obj.rid] = state
|
|
|
|
async def _wait_one_response(
|
|
self,
|
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
|
state: ReqState,
|
|
request: Optional[fastapi.Request] = None,
|
|
):
|
|
"""Wait for the response of one request."""
|
|
while True:
|
|
try:
|
|
await asyncio.wait_for(state.event.wait(), timeout=4)
|
|
except asyncio.TimeoutError:
|
|
if (
|
|
request is not None
|
|
and not obj.background
|
|
and await request.is_disconnected()
|
|
):
|
|
# Abort the request for disconnected requests (non-streaming, waiting queue)
|
|
self.abort_request(obj.rid)
|
|
# Use exception to kill the whole call stack and asyncio task
|
|
raise ValueError(
|
|
f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
|
|
)
|
|
continue
|
|
|
|
out = state.out_list[-1]
|
|
|
|
state.out_list = []
|
|
if state.finished:
|
|
if self.log_requests:
|
|
max_length, skip_names, out_skip_names = self.log_request_metadata
|
|
if self.model_config.is_multimodal_gen:
|
|
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
|
|
else:
|
|
msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
|
|
logger.info(msg)
|
|
|
|
# Check if this was an abort/error created by scheduler
|
|
if isinstance(out["meta_info"].get("finish_reason"), dict):
|
|
finish_reason = out["meta_info"]["finish_reason"]
|
|
if (
|
|
finish_reason.get("type") == "abort"
|
|
and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
|
|
):
|
|
raise ValueError(finish_reason["message"])
|
|
|
|
if finish_reason.get("type") == "abort" and finish_reason.get(
|
|
"status_code"
|
|
) in (
|
|
HTTPStatus.SERVICE_UNAVAILABLE,
|
|
HTTPStatus.INTERNAL_SERVER_ERROR,
|
|
):
|
|
# This is an abort request initiated by scheduler.
|
|
# Delete the key to prevent resending abort request to the scheduler and
|
|
# to ensure aborted request state is cleaned up.
|
|
if state.obj.rid in self.rid_to_state:
|
|
del self.rid_to_state[state.obj.rid]
|
|
|
|
# Mark ongoing LoRA request as finished.
|
|
if self.server_args.enable_lora and state.obj.lora_path:
|
|
await self.lora_registry.release(state.obj.lora_id)
|
|
|
|
raise fastapi.HTTPException(
|
|
status_code=finish_reason["status_code"],
|
|
detail=finish_reason["message"],
|
|
)
|
|
yield out
|
|
break
|
|
|
|
state.event.clear()
|
|
|
|
if obj.stream:
|
|
yield out
|
|
else:
|
|
if (
|
|
request is not None
|
|
and not obj.background
|
|
and await request.is_disconnected()
|
|
):
|
|
# Abort the request for disconnected requests (non-streaming, running)
|
|
self.abort_request(obj.rid)
|
|
# Use exception to kill the whole call stack and asyncio task
|
|
raise ValueError(
|
|
f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
|
|
)
|
|
|
|
async def _handle_batch_request(
|
|
self,
|
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
|
request: Optional[fastapi.Request] = None,
|
|
created_time: Optional[float] = None,
|
|
):
|
|
batch_size = obj.batch_size
|
|
|
|
generators = []
|
|
rids = []
|
|
if getattr(obj, "parallel_sample_num", 1) == 1:
|
|
if self.server_args.enable_tokenizer_batch_encode:
|
|
# Validate batch tokenization constraints
|
|
self._validate_batch_tokenization_constraints(batch_size, obj)
|
|
|
|
tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
|
|
|
|
# Send as a single batched request
|
|
self._send_batch_request(obj, tokenized_objs, created_time)
|
|
|
|
# Set up generators for each request in the batch
|
|
for i in range(batch_size):
|
|
tmp_obj = obj[i]
|
|
generators.append(
|
|
self._wait_one_response(
|
|
tmp_obj, self.rid_to_state[tmp_obj.rid], request
|
|
)
|
|
)
|
|
rids.append(tmp_obj.rid)
|
|
else:
|
|
# Sequential tokenization and processing
|
|
with (
|
|
input_blocker_guard_region(send_to_scheduler=self.send_to_scheduler)
|
|
if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
|
|
else nullcontext()
|
|
):
|
|
for i in range(batch_size):
|
|
tmp_obj = obj[i]
|
|
tokenized_obj = await self._tokenize_one_request(tmp_obj)
|
|
state = self._send_one_request(
|
|
tmp_obj, tokenized_obj, created_time
|
|
)
|
|
generators.append(
|
|
self._wait_one_response(tmp_obj, state, request)
|
|
)
|
|
rids.append(tmp_obj.rid)
|
|
else:
|
|
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
|
|
if batch_size > 128:
|
|
logger.warning(
|
|
"Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
|
|
"The performance might be better if you just duplicate the requests n times or use "
|
|
"many threads to send them one by one with parallel sampling (n > 1)."
|
|
)
|
|
|
|
# Tokenize all requests
|
|
objs = [obj[i] for i in range(batch_size)]
|
|
tokenized_objs = await asyncio.gather(
|
|
*(self._tokenize_one_request(obj) for obj in objs)
|
|
)
|
|
|
|
# Cache the common prefix for parallel sampling
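            # (A copy of each request is sent first with max_new_tokens=0 and stream=False so
            # that the shared prefix is prefilled once before the n samples are dispatched.)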
|
|
for i in range(batch_size):
|
|
tmp_obj = copy.copy(objs[i])
|
|
tokenized_obj = copy.copy(tokenized_objs[i])
|
|
tokenized_obj.rid = tmp_obj.regenerate_rid()
|
|
tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
|
|
tokenized_obj.sampling_params.max_new_tokens = 0
|
|
tokenized_obj.stream = False
|
|
state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
|
await self._wait_one_response(tmp_obj, state, request).__anext__()
|
|
|
|
# Expand requests, assign new rids for them, and send them
|
|
for i in range(batch_size):
|
|
for _ in range(obj.parallel_sample_num):
|
|
tmp_obj = copy.copy(objs[i])
|
|
tokenized_obj = copy.copy(tokenized_objs[i])
|
|
tokenized_obj.rid = tmp_obj.regenerate_rid()
|
|
state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
|
generators.append(self._wait_one_response(tmp_obj, state, request))
|
|
rids.append(tmp_obj.rid)
|
|
|
|
# Wait for all requests
|
|
is_stream = hasattr(obj, "stream") and obj.stream
|
|
if not is_stream:
|
|
outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
|
|
yield outputs
|
|
else:
|
|
rid_to_index = {rid: i for i, rid in enumerate(rids)}
|
|
task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
|
|
while task_map:
|
|
done, _ = await asyncio.wait(
|
|
task_map.keys(), return_when=asyncio.FIRST_COMPLETED
|
|
)
|
|
|
|
for task in done:
|
|
gen = task_map.pop(task)
|
|
try:
|
|
result = task.result()
|
|
result["index"] = rid_to_index[result["meta_info"]["id"]]
|
|
yield result
|
|
new_task = asyncio.create_task(gen.__anext__())
|
|
task_map[new_task] = gen
|
|
except StopAsyncIteration:
|
|
pass
|
|
|
|
def abort_request(self, rid: str = "", abort_all: bool = False):
|
|
if not abort_all and rid not in self.rid_to_state:
|
|
return
|
|
req = AbortReq(rid, abort_all)
|
|
self.send_to_scheduler.send_pyobj(req)
|
|
|
|
if self.enable_metrics:
|
|
self.metrics_collector.observe_one_aborted_request()
|
|
|
|
async def pause_generation(self):
|
|
async with self.is_pause_cond:
|
|
self.is_pause = True
|
|
self.abort_request(abort_all=True)
|
|
|
|
async def continue_generation(self):
|
|
async with self.is_pause_cond:
|
|
self.is_pause = False
|
|
self.is_pause_cond.notify_all()
|
|
|
|
async def update_weights_from_disk(
|
|
self,
|
|
obj: UpdateWeightFromDiskReqInput,
|
|
request: Optional[fastapi.Request] = None,
|
|
) -> Tuple[bool, str]:
|
|
self.auto_create_handle_loop()
|
|
|
|
# default the load format to the server_args
|
|
if obj.load_format is None:
|
|
obj.load_format = self.server_args.load_format
|
|
logger.info("Start update_weights. Load format=%s", obj.load_format)
|
|
|
|
if obj.abort_all_requests:
|
|
self.abort_request(abort_all=True)
|
|
|
|
if True: # Keep this redundant check to simplify some internal code sync
|
|
# Hold the lock if it is not async. This means that weight sync
|
|
# cannot run while requests are in progress.
|
|
async with self.model_update_lock.writer_lock:
|
|
return await self._wait_for_model_update_from_disk(obj)
|
|
|
|
async def _wait_for_model_update_from_disk(
|
|
self, obj: UpdateWeightFromDiskReqInput
|
|
) -> Tuple[bool, str]:
|
|
if self.server_args.tokenizer_worker_num > 1:
|
|
obj = MultiTokenizerWrapper(self.worker_id, obj)
|
|
self.send_to_scheduler.send_pyobj(obj)
|
|
self.model_update_result = asyncio.Future()
|
|
if self.server_args.dp_size == 1:
|
|
result = await self.model_update_result
|
|
if result.success:
|
|
self.served_model_name = obj.model_path
|
|
self.server_args.model_path = obj.model_path
|
|
self.server_args.load_format = obj.load_format
|
|
self.model_path = obj.model_path
|
|
return result.success, result.message, result.num_paused_requests
|
|
else: # self.server_args.dp_size > 1
|
|
self.model_update_tmp = []
|
|
result = await self.model_update_result
|
|
|
|
all_success = all([r.success for r in result])
|
|
if all_success is True:
|
|
self.server_args.model_path = obj.model_path
|
|
self.server_args.load_format = obj.load_format
|
|
self.model_path = obj.model_path
|
|
all_message = [r.message for r in result]
|
|
all_message = " | ".join(all_message)
|
|
all_paused_requests = [r.num_paused_requests for r in result]
|
|
return all_success, all_message, all_paused_requests
|
|
|
|
async def open_session(
|
|
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
|
|
):
|
|
self.auto_create_handle_loop()
|
|
|
|
if obj.session_id is None:
|
|
obj.session_id = uuid.uuid4().hex
|
|
elif obj.session_id in self.session_futures:
|
|
return None
|
|
|
|
if self.server_args.tokenizer_worker_num > 1:
|
|
obj = MultiTokenizerWrapper(self.worker_id, obj)
|
|
self.send_to_scheduler.send_pyobj(obj)
|
|
|
|
self.session_futures[obj.session_id] = asyncio.Future()
|
|
session_id = await self.session_futures[obj.session_id]
|
|
del self.session_futures[obj.session_id]
|
|
return session_id
|
|
|
|
async def close_session(
|
|
self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
|
|
):
|
|
await self.send_to_scheduler.send_pyobj(obj)
|
|
|
|
def get_log_request_metadata(self):
|
|
max_length = None
|
|
skip_names = None
|
|
out_skip_names = None
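        # Verbosity summary (as implemented below): levels 0 and 1 hide long input/output
        # fields, level 2 truncates logged objects to 2048 characters, level 3 logs in full.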
|
|
if self.log_requests:
|
|
if self.log_requests_level == 0:
|
|
max_length = 1 << 30
|
|
skip_names = set(
|
|
[
|
|
"text",
|
|
"input_ids",
|
|
"input_embeds",
|
|
"image_data",
|
|
"audio_data",
|
|
"lora_path",
|
|
"sampling_params",
|
|
]
|
|
)
|
|
out_skip_names = set(
|
|
[
|
|
"text",
|
|
"output_ids",
|
|
"embedding",
|
|
]
|
|
)
|
|
elif self.log_requests_level == 1:
|
|
max_length = 1 << 30
|
|
skip_names = set(
|
|
[
|
|
"text",
|
|
"input_ids",
|
|
"input_embeds",
|
|
"image_data",
|
|
"audio_data",
|
|
"lora_path",
|
|
]
|
|
)
|
|
out_skip_names = set(
|
|
[
|
|
"text",
|
|
"output_ids",
|
|
"embedding",
|
|
]
|
|
)
|
|
elif self.log_requests_level == 2:
|
|
max_length = 2048
|
|
elif self.log_requests_level == 3:
|
|
max_length = 1 << 30
|
|
else:
|
|
raise ValueError(
|
|
f"Invalid --log-requests-level: {self.log_requests_level=}"
|
|
)
|
|
return max_length, skip_names, out_skip_names
|
|
|
|
def configure_logging(self, obj: ConfigureLoggingReq):
|
|
if obj.log_requests is not None:
|
|
self.log_requests = obj.log_requests
|
|
if obj.log_requests_level is not None:
|
|
self.log_requests_level = obj.log_requests_level
|
|
if obj.dump_requests_folder is not None:
|
|
self.dump_requests_folder = obj.dump_requests_folder
|
|
if obj.dump_requests_threshold is not None:
|
|
self.dump_requests_threshold = obj.dump_requests_threshold
|
|
if obj.crash_dump_folder is not None:
|
|
self.crash_dump_folder = obj.crash_dump_folder
|
|
logging.info(f"Config logging: {obj=}")
|
|
self.log_request_metadata = self.get_log_request_metadata()
|
|
|
|
async def freeze_gc(self):
|
|
"""Send a freeze_gc message to the scheduler first, then freeze locally."""
|
|
self.send_to_scheduler.send_pyobj(FreezeGCReq())
|
|
freeze_gc("Tokenizer Manager")
|
|
return None
|
|
|
|
def create_abort_task(self, obj: GenerateReqInput):
|
|
# Abort the request if the client is disconnected.
|
|
async def abort_request():
|
|
await asyncio.sleep(2)
|
|
if obj.is_single:
|
|
self.abort_request(obj.rid)
|
|
else:
|
|
for rid in obj.rid:
|
|
self.abort_request(rid)
|
|
|
|
background_tasks = BackgroundTasks()
|
|
background_tasks.add_task(abort_request)
|
|
return background_tasks
|
|
|
|
def auto_create_handle_loop(self):
|
|
if self.no_create_loop:
|
|
return
|
|
|
|
self.no_create_loop = True
|
|
loop = asyncio.get_event_loop()
|
|
self.asyncio_tasks.add(
|
|
loop.create_task(print_exception_wrapper(self.handle_loop))
|
|
)
|
|
|
|
self.event_loop = loop
|
|
|
|
# We cannot add signal handler when the tokenizer manager is not in
|
|
# the main thread due to the CPython limitation.
|
|
if threading.current_thread() is threading.main_thread():
|
|
signal_handler = SignalHandler(self)
|
|
loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
|
|
# Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
|
|
loop.add_signal_handler(
|
|
signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
|
|
)
|
|
else:
|
|
logger.warning(
|
|
"Signal handler is not added because the tokenizer manager is "
|
|
"not in the main thread. This disables graceful shutdown of the "
|
|
"tokenizer manager when SIGTERM is received."
|
|
)
|
|
self.asyncio_tasks.add(
|
|
loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
|
|
)
|
|
|
|
def dump_requests_before_crash(self):
|
|
if self.crash_dump_performed:
|
|
logger.info(
|
|
"SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
|
|
)
|
|
return
|
|
|
|
if not self.crash_dump_folder:
|
|
return
|
|
|
|
logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
|
|
self.crash_dump_performed = True
|
|
|
|
# Check if NFS directory is available
|
|
# expected_nfs_dir = "/" + self.crash_dump_folder.lstrip("/").split("/")[0]
|
|
# use_nfs_dir = os.path.isdir(expected_nfs_dir) and os.access(
|
|
# expected_nfs_dir, os.W_OK
|
|
# )
|
|
use_nfs_dir = False
|
|
if not use_nfs_dir:
|
|
logger.error(
|
|
f"Expected NFS directory is not available or writable. Uploading to GCS."
|
|
)
|
|
|
|
data_to_dump = []
|
|
if self.crash_dump_request_list:
|
|
data_to_dump.extend(self.crash_dump_request_list)
|
|
|
|
# Add unfinished requests from rid_to_state
|
|
unfinished_requests = []
|
|
for rid, state in self.rid_to_state.items():
|
|
if not state.finished:
|
|
unfinished_requests.append(
|
|
(
|
|
state.obj,
|
|
state.out_list[-1] if state.out_list else {},
|
|
state.created_time,
|
|
time.time(),
|
|
)
|
|
)
|
|
if unfinished_requests:
|
|
data_to_dump.extend(unfinished_requests)
|
|
|
|
if not data_to_dump:
|
|
return
|
|
|
|
object_name = f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl'
|
|
filename = os.path.join(
|
|
self.crash_dump_folder,
|
|
os.getenv("HOSTNAME", None),
|
|
object_name,
|
|
)
|
|
|
|
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
|
# Include server_args in the dump
|
|
data_to_dump_with_server_args = {
|
|
"server_args": self.server_args,
|
|
"requests": data_to_dump,
|
|
}
|
|
with open(filename, "wb") as f:
|
|
pickle.dump(data_to_dump_with_server_args, f)
|
|
logger.error(
|
|
f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
|
|
)
|
|
|
|
def _upload_file_to_gcs(bucket_name, source_file_path, object_name):
|
|
from google.cloud import storage
|
|
|
|
client = storage.Client()
|
|
bucket = client.bucket(bucket_name)
|
|
blob = bucket.blob(object_name)
|
|
blob.upload_from_filename(source_file_path, if_generation_match=0)
|
|
logger.error(
|
|
f"Successfully uploaded {source_file_path} to gs://{bucket_name}/{object_name}"
|
|
)
|
|
|
|
if not use_nfs_dir:
|
|
_upload_file_to_gcs(
|
|
"sglang_crash_dump",
|
|
filename,
|
|
os.getenv("HOSTNAME", None) + "/" + object_name,
|
|
)
|
|
|
|
async def sigterm_watchdog(self):
|
|
while not self.gracefully_exit:
|
|
await asyncio.sleep(5)
|
|
|
|
# Drain requests
|
|
while True:
|
|
remain_num_req = len(self.rid_to_state)
|
|
|
|
if self.server_status == ServerStatus.UnHealthy:
|
|
# if health check failed, we should exit immediately
|
|
logger.error(
|
|
"Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
|
|
remain_num_req,
|
|
)
|
|
self.dump_requests_before_crash()
|
|
break
|
|
|
|
elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
|
|
# if force shutdown flag set, exit immediately
|
|
logger.error(
|
|
"Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d",
|
|
remain_num_req,
|
|
)
|
|
break
|
|
|
|
logger.info(
|
|
f"Gracefully exiting... remaining number of requests {remain_num_req}"
|
|
)
|
|
if remain_num_req > 0:
|
|
await asyncio.sleep(5)
|
|
else:
|
|
self.dump_requests_before_crash()
|
|
break
|
|
|
|
kill_process_tree(os.getpid(), include_parent=True)
|
|
sys.exit(0)
|
|
|
|
async def handle_loop(self):
|
|
"""The event loop that handles requests"""
|
|
while True:
|
|
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
|
|
self._result_dispatcher(recv_obj)
|
|
self.last_receive_tstamp = time.time()
|
|
|
|
def _handle_batch_output(
|
|
self,
|
|
recv_obj: Union[
|
|
BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
|
|
],
|
|
):
|
|
for i, rid in enumerate(recv_obj.rids):
|
|
state = self.rid_to_state.get(rid, None)
|
|
if state is None:
|
|
logger.error(
|
|
f"Received output for {rid=} but the state was deleted in TokenizerManager."
|
|
)
|
|
continue
|
|
|
|
origin_rid = rid
|
|
if self.server_args.tokenizer_worker_num > 1:
|
|
origin_rid = get_origin_rid(rid)
|
|
# Build meta_info and return value
|
|
meta_info = {
|
|
"id": origin_rid,
|
|
"finish_reason": recv_obj.finished_reasons[i],
|
|
"prompt_tokens": recv_obj.prompt_tokens[i],
|
|
"weight_version": self.server_args.weight_version,
|
|
}
|
|
|
|
if getattr(state.obj, "return_logprob", False):
|
|
self.convert_logprob_style(
|
|
meta_info,
|
|
state,
|
|
state.obj.top_logprobs_num,
|
|
state.obj.token_ids_logprob,
|
|
state.obj.return_text_in_logprobs
|
|
and not self.server_args.skip_tokenizer_init,
|
|
recv_obj,
|
|
i,
|
|
)
|
|
|
|
if not isinstance(recv_obj, BatchEmbeddingOut):
|
|
meta_info.update(
|
|
{
|
|
"completion_tokens": recv_obj.completion_tokens[i],
|
|
"cached_tokens": recv_obj.cached_tokens[i],
|
|
}
|
|
)
|
|
|
|
if getattr(recv_obj, "output_hidden_states", None):
|
|
meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
|
|
|
|
if isinstance(recv_obj, BatchStrOut):
|
|
state.text += recv_obj.output_strs[i]
|
|
if state.obj.stream:
|
|
state.output_ids.extend(recv_obj.output_ids[i])
|
|
output_token_ids = state.output_ids[state.last_output_offset :]
|
|
state.last_output_offset = len(state.output_ids)
|
|
else:
|
|
state.output_ids.extend(recv_obj.output_ids[i])
|
|
output_token_ids = state.output_ids.copy()
|
|
|
|
out_dict = {
|
|
"text": state.text,
|
|
"output_ids": output_token_ids,
|
|
"meta_info": meta_info,
|
|
}
|
|
elif isinstance(recv_obj, BatchTokenIDOut):
|
|
if self.server_args.stream_output and state.obj.stream:
|
|
state.output_ids.extend(recv_obj.output_ids[i])
|
|
output_token_ids = state.output_ids[state.last_output_offset :]
|
|
state.last_output_offset = len(state.output_ids)
|
|
else:
|
|
state.output_ids.extend(recv_obj.output_ids[i])
|
|
output_token_ids = state.output_ids.copy()
|
|
|
|
out_dict = {
|
|
"output_ids": output_token_ids,
|
|
"meta_info": meta_info,
|
|
}
|
|
elif isinstance(recv_obj, BatchMultimodalOut):
|
|
raise NotImplementedError("BatchMultimodalOut not implemented")
|
|
else:
|
|
assert isinstance(recv_obj, BatchEmbeddingOut)
|
|
out_dict = {
|
|
"embedding": recv_obj.embeddings[i],
|
|
"meta_info": meta_info,
|
|
}
|
|
|
|
state.finished = recv_obj.finished_reasons[i] is not None
|
|
if state.finished:
|
|
if self.server_args.speculative_algorithm:
|
|
meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
|
|
state.finished_time = time.time()
|
|
meta_info["e2e_latency"] = state.finished_time - state.created_time
|
|
del self.rid_to_state[rid]
|
|
|
|
# Mark ongoing LoRA request as finished.
|
|
if self.server_args.enable_lora and state.obj.lora_path:
|
|
asyncio.create_task(self.lora_registry.release(state.obj.lora_id))
|
|
|
|
state.out_list.append(out_dict)
|
|
state.event.set()
|
|
|
|
# Log metrics and dump
|
|
if self.enable_metrics and state.obj.log_metrics:
|
|
self.collect_metrics(state, recv_obj, i)
|
|
if self.dump_requests_folder and state.finished and state.obj.log_metrics:
|
|
self.dump_requests(state, out_dict)
|
|
if self.crash_dump_folder and state.finished and state.obj.log_metrics:
|
|
self.record_request_for_crash_dump(state, out_dict)
|
|
|
|
def convert_logprob_style(
|
|
self,
|
|
meta_info: dict,
|
|
state: ReqState,
|
|
top_logprobs_num: int,
|
|
token_ids_logprob: List[int],
|
|
return_text_in_logprobs: bool,
|
|
recv_obj: BatchStrOut,
|
|
recv_obj_index: int,
|
|
):
|
|
if recv_obj.input_token_logprobs_val is None:
|
|
return
|
|
|
|
if len(recv_obj.input_token_logprobs_val) > 0:
|
|
state.input_token_logprobs_val.extend(
|
|
recv_obj.input_token_logprobs_val[recv_obj_index]
|
|
)
|
|
state.input_token_logprobs_idx.extend(
|
|
recv_obj.input_token_logprobs_idx[recv_obj_index]
|
|
)
|
|
state.output_token_logprobs_val.extend(
|
|
recv_obj.output_token_logprobs_val[recv_obj_index]
|
|
)
|
|
state.output_token_logprobs_idx.extend(
|
|
recv_obj.output_token_logprobs_idx[recv_obj_index]
|
|
)
|
|
meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
|
|
state.input_token_logprobs_val,
|
|
state.input_token_logprobs_idx,
|
|
return_text_in_logprobs,
|
|
)
|
|
meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
|
|
state.output_token_logprobs_val,
|
|
state.output_token_logprobs_idx,
|
|
return_text_in_logprobs,
|
|
)
|
|
|
|
if top_logprobs_num > 0:
|
|
if len(recv_obj.input_top_logprobs_val) > 0:
|
|
state.input_top_logprobs_val.extend(
|
|
recv_obj.input_top_logprobs_val[recv_obj_index]
|
|
)
|
|
state.input_top_logprobs_idx.extend(
|
|
recv_obj.input_top_logprobs_idx[recv_obj_index]
|
|
)
|
|
state.output_top_logprobs_val.extend(
|
|
recv_obj.output_top_logprobs_val[recv_obj_index]
|
|
)
|
|
state.output_top_logprobs_idx.extend(
|
|
recv_obj.output_top_logprobs_idx[recv_obj_index]
|
|
)
|
|
meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
|
|
state.input_top_logprobs_val,
|
|
state.input_top_logprobs_idx,
|
|
return_text_in_logprobs,
|
|
)
|
|
meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
|
|
state.output_top_logprobs_val,
|
|
state.output_top_logprobs_idx,
|
|
return_text_in_logprobs,
|
|
)
|
|
|
|
if token_ids_logprob is not None:
|
|
if len(recv_obj.input_token_ids_logprobs_val) > 0:
|
|
state.input_token_ids_logprobs_val.extend(
|
|
recv_obj.input_token_ids_logprobs_val[recv_obj_index]
|
|
)
|
|
state.input_token_ids_logprobs_idx.extend(
|
|
recv_obj.input_token_ids_logprobs_idx[recv_obj_index]
|
|
)
|
|
state.output_token_ids_logprobs_val.extend(
|
|
recv_obj.output_token_ids_logprobs_val[recv_obj_index]
|
|
)
|
|
state.output_token_ids_logprobs_idx.extend(
|
|
recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
|
|
)
|
|
meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
|
|
state.input_token_ids_logprobs_val,
|
|
state.input_token_ids_logprobs_idx,
|
|
return_text_in_logprobs,
|
|
)
|
|
meta_info["output_token_ids_logprobs"] = (
|
|
self.detokenize_top_logprobs_tokens(
|
|
state.output_token_ids_logprobs_val,
|
|
state.output_token_ids_logprobs_idx,
|
|
return_text_in_logprobs,
|
|
)
|
|
)
|
|
|
|
def detokenize_logprob_tokens(
|
|
self,
|
|
token_logprobs_val: List[float],
|
|
token_logprobs_idx: List[int],
|
|
decode_to_text: bool,
|
|
):
|
|
if not decode_to_text:
|
|
return [
|
|
(logprob, token_id, None)
|
|
for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
|
|
]
|
|
else:
|
|
assert self.tokenizer is not None
|
|
token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
|
|
return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))
|
|
|
|
def detokenize_top_logprobs_tokens(
|
|
self,
|
|
token_logprobs_val: List[float],
|
|
token_logprobs_idx: List[int],
|
|
decode_to_text: bool,
|
|
):
|
|
# TODO: The current implementation only batches the detokenization for top-k tokens per single position.
|
|
# We should batch all top-k tokens in all positions.
|
|
ret = []
|
|
for i in range(len(token_logprobs_val)):
|
|
if token_logprobs_val[i]:
|
|
ret.append(
|
|
self.detokenize_logprob_tokens(
|
|
token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
|
|
)
|
|
)
|
|
else:
|
|
ret.append(None)
|
|
return ret
|
|
|
|
def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
|
|
completion_tokens = (
|
|
recv_obj.completion_tokens[i]
|
|
if getattr(recv_obj, "completion_tokens", None)
|
|
else 0
|
|
)
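        # The first output of a request records time-to-first-token; later outputs record
        # inter-token latency over the tokens generated since the previous callback.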
|
|
|
|
if (
|
|
state.first_token_time == 0.0
|
|
and self.disaggregation_mode != DisaggregationMode.PREFILL
|
|
):
|
|
state.first_token_time = state.last_time = time.time()
|
|
state.last_completion_tokens = completion_tokens
|
|
self.metrics_collector.observe_time_to_first_token(
|
|
state.first_token_time - state.created_time
|
|
)
|
|
else:
|
|
num_new_tokens = completion_tokens - state.last_completion_tokens
|
|
if num_new_tokens:
|
|
new_time = time.time()
|
|
interval = new_time - state.last_time
|
|
self.metrics_collector.observe_inter_token_latency(
|
|
interval,
|
|
num_new_tokens,
|
|
)
|
|
state.last_time = new_time
|
|
state.last_completion_tokens = completion_tokens
|
|
|
|
if state.finished:
|
|
has_grammar = (
|
|
state.obj.sampling_params.get("json_schema", None)
|
|
or state.obj.sampling_params.get("regex", None)
|
|
or state.obj.sampling_params.get("ebnf", None)
|
|
or state.obj.sampling_params.get("structural_tag", None)
|
|
)
|
|
self.metrics_collector.observe_one_finished_request(
|
|
recv_obj.prompt_tokens[i],
|
|
completion_tokens,
|
|
recv_obj.cached_tokens[i],
|
|
state.finished_time - state.created_time,
|
|
has_grammar,
|
|
)
|
|
|
|
def dump_requests(self, state: ReqState, out_dict: dict):
|
|
self.dump_request_list.append(
|
|
(state.obj, out_dict, state.created_time, time.time())
|
|
)
|
|
|
|
if len(self.dump_request_list) >= self.dump_requests_threshold:
|
|
filename = os.path.join(
|
|
self.dump_requests_folder,
|
|
datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
|
|
)
|
|
self._dump_data_to_file(
|
|
data_list=self.dump_request_list,
|
|
filename=filename,
|
|
log_message=f"Dump {len(self.dump_request_list)} requests to {filename}",
|
|
)
|
|
self.dump_request_list = []
|
|
|
|
def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
|
|
current_time = time.time()
|
|
self.crash_dump_request_list.append(
|
|
(state.obj, out_dict, state.created_time, current_time)
|
|
)
|
|
# Remove requests older than 5 minutes based on finish time
|
|
while (
|
|
self.crash_dump_request_list
|
|
and current_time - self.crash_dump_request_list[0][3] >= 300
|
|
):
|
|
self.crash_dump_request_list.popleft()
|
|
|
|
def _dump_data_to_file(
|
|
self, data_list: List[Tuple], filename: str, log_message: str
|
|
):
|
|
logger.info(log_message)
|
|
to_dump_with_server_args = {
|
|
"server_args": self.server_args,
|
|
"requests": data_list.copy(),
|
|
}
|
|
|
|
def background_task():
|
|
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
|
with open(filename, "wb") as f:
|
|
pickle.dump(to_dump_with_server_args, f)
|
|
|
|
asyncio.create_task(asyncio.to_thread(background_task))
|
|
|
|
def _handle_abort_req(self, recv_obj):
|
|
if is_health_check_generate_req(recv_obj):
|
|
return
|
|
state = self.rid_to_state[recv_obj.rid]
|
|
origin_rid = recv_obj.rid
|
|
if self.server_args.tokenizer_worker_num > 1:
|
|
origin_rid = get_origin_rid(origin_rid)
|
|
state.finished = True
|
|
if recv_obj.finished_reason:
|
|
out = {
|
|
"meta_info": {
|
|
"id": recv_obj.rid,
|
|
"finish_reason": recv_obj.finished_reason,
|
|
},
|
|
}
|
|
else:
|
|
out = {
|
|
"text": "",
|
|
"meta_info": {
|
|
"id": origin_rid,
|
|
"finish_reason": {
|
|
"type": "abort",
|
|
"message": "Abort before prefill",
|
|
},
|
|
"prompt_tokens": 0,
|
|
"completion_tokens": 0,
|
|
},
|
|
}
|
|
state.out_list.append(out)
|
|
state.event.set()
|
|
|
|
def _handle_open_session_req_output(self, recv_obj):
|
|
self.session_futures[recv_obj.session_id].set_result(
|
|
recv_obj.session_id if recv_obj.success else None
|
|
)
|
|
|
|
def _handle_update_weights_from_disk_req_output(self, recv_obj):
|
|
if self.server_args.dp_size == 1:
|
|
self.model_update_result.set_result(recv_obj)
|
|
else: # self.server_args.dp_size > 1
|
|
self.model_update_tmp.append(recv_obj)
|
|
            # Set the future once all the results are received
|
|
if len(self.model_update_tmp) == self.server_args.dp_size:
|
|
self.model_update_result.set_result(self.model_update_tmp)
|
|
|
|
async def score_request(
|
|
self,
|
|
query: Optional[Union[str, List[int]]] = None,
|
|
items: Optional[Union[str, List[str], List[List[int]]]] = None,
|
|
label_token_ids: Optional[List[int]] = None,
|
|
apply_softmax: bool = False,
|
|
item_first: bool = False,
|
|
request: Optional[Any] = None,
|
|
) -> List[List[float]]:
|
|
"""
|
|
See Engine.score() for more details.
|
|
"""
|
|
if label_token_ids is None:
|
|
raise ValueError("label_token_ids must be provided")
|
|
|
|
if self.tokenizer is not None:
|
|
vocab_size = self.tokenizer.vocab_size
|
|
for token_id in label_token_ids:
|
|
if token_id >= vocab_size:
|
|
raise ValueError(
|
|
f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
|
|
)
|
|
|
|
batch_request = GenerateReqInput(
|
|
token_ids_logprob=label_token_ids,
|
|
return_logprob=True,
|
|
stream=False,
|
|
sampling_params={"max_new_tokens": 0},
|
|
)
|
|
|
|
# Handle string or tokenized query/items
|
|
if isinstance(query, str) and (
|
|
isinstance(items, str)
|
|
or (isinstance(items, list) and (not items or isinstance(items[0], str)))
|
|
):
|
|
# Both query and items are text
|
|
items_list = [items] if isinstance(items, str) else items
|
|
if item_first:
|
|
prompts = [f"{item}{query}" for item in items_list]
|
|
else:
|
|
prompts = [f"{query}{item}" for item in items_list]
|
|
|
|
batch_request.text = prompts
|
|
|
|
elif (
|
|
isinstance(query, list)
|
|
and isinstance(items, list)
|
|
and items
|
|
and isinstance(items[0], list)
|
|
):
|
|
# Both query and items are token IDs
|
|
if item_first:
|
|
input_ids_list = [item + query for item in items]
|
|
else:
|
|
input_ids_list = [query + item for item in items]
|
|
|
|
batch_request.input_ids = input_ids_list
|
|
else:
|
|
raise ValueError(
|
|
"Invalid combination of query/items types for score_request."
|
|
)
|
|
|
|
results = await self.generate_request(batch_request, request).__anext__()
|
|
scores = []
|
|
|
|
for result in results:
|
|
# Get logprobs for each token
|
|
logprobs = {}
|
|
|
|
# For scoring requests, we read from output_token_ids_logprobs since we want
|
|
# the logprobs for specific tokens mentioned in the label_token_ids at
|
|
# the next position after the last token in the prompt
|
|
output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
|
|
|
|
# Throw an error here if output_logprobs is None
|
|
if output_logprobs is None:
|
|
raise RuntimeError(
|
|
f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
|
|
"This usually indicates a problem with the scoring request or the backend output."
|
|
)
|
|
|
|
for logprob, token_id, _ in output_logprobs[0]:
|
|
if token_id in label_token_ids:
|
|
logprobs[token_id] = logprob
|
|
|
|
# Get scores in order of label_token_ids
|
|
score_list = [
|
|
logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
|
|
]
|
|
|
|
# Apply softmax to logprobs if needed
|
|
if apply_softmax:
|
|
score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
|
|
else:
|
|
# Convert logprobs to probabilities if not using softmax
|
|
score_list = [
|
|
math.exp(x) if x != float("-inf") else 0.0 for x in score_list
|
|
]
|
|
|
|
scores.append(score_list)
|
|
|
|
return scores
|
|
|
|
|
|
class ServerStatus(Enum):
|
|
Up = "Up"
|
|
Starting = "Starting"
|
|
UnHealthy = "UnHealthy"
|
|
|
|
|
|
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
|
|
is_cross_node = server_args.dist_init_addr
|
|
|
|
if is_cross_node:
|
|
# Fallback to default CPU transport for multi-node
|
|
return "default"
|
|
else:
|
|
return "cuda_ipc"
|
|
|
|
|
|
async def print_exception_wrapper(func):
|
|
"""
|
|
    Sometimes an asyncio function does not print its exception.
    We add another wrapper to log the exception and terminate the process.
|
|
"""
|
|
try:
|
|
await func()
|
|
except Exception:
|
|
traceback = get_exception_traceback()
|
|
logger.error(f"TokenizerManager hit an exception: {traceback}")
|
|
if hasattr(func, "__self__") and isinstance(func.__self__, TokenizerManager):
|
|
func.__self__.dump_requests_before_crash()
|
|
kill_process_tree(os.getpid(), include_parent=True)
|
|
sys.exit(1)
|
|
|
|
|
|
class SignalHandler:
|
|
def __init__(self, tokenizer_manager: TokenizerManager):
|
|
self.tokenizer_manager = tokenizer_manager
|
|
|
|
def sigterm_handler(self, signum=None, frame=None):
|
|
logger.warning(
|
|
f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
|
|
)
|
|
self.tokenizer_manager.gracefully_exit = True
|
|
|
|
def running_phase_sigquit_handler(self, signum=None, frame=None):
|
|
logger.error(
|
|
"Received sigquit from a child process. It usually means the child failed."
|
|
)
|
|
self.tokenizer_manager.dump_requests_before_crash()
|
|
kill_process_tree(os.getpid())
|
|
|
|
|
|
# Note: request abort handling logic
# We should handle all of the following cases correctly.
#
# | entrypoint | is_streaming | status        | abort engine    | cancel asyncio task | rid_to_state                |
# | ---------- | ------------ | ------------- | --------------- | ------------------- | --------------------------- |
# | http       | yes          | validation    | background task | fast api            | del in _handle_abort_req    |
# | http       | yes          | waiting queue | background task | fast api            | del in _handle_abort_req    |
# | http       | yes          | running       | background task | fast api            | del in _handle_batch_output |
# | http       | no           | validation    | http exception  | http exception      | del in _handle_abort_req    |
# | http       | no           | waiting queue | type 1          | type 1 exception    | del in _handle_abort_req    |
# | http       | no           | running       | type 3          | type 3 exception    | del in _handle_batch_output |
#
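#
# A minimal sketch of the streaming abort path (assuming the FastAPI wiring used by the HTTP
# server; the names below are illustrative, not from this file):
#   background_tasks = tokenizer_manager.create_abort_task(obj)
#   return StreamingResponse(stream_results(), background=background_tasks)
# FastAPI runs the background task after the client disconnects, and the task calls
# abort_request() for every rid in the request.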