import json
import logging
import time
from collections import defaultdict
from http import HTTPStatus
from typing import Dict, List, Optional, Tuple

import torch

from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req

logger = logging.getLogger(__name__)


def validate_input_length(
    req: Req, max_req_input_len: int, allow_auto_truncate: bool
) -> Optional[str]:
    """Check a request's prompt length against the limit, truncating or aborting.

    A prompt is treated as too long when its token count is greater than
    *or equal to* ``max_req_input_len``.

    Args:
        req: Request whose ``origin_input_ids`` are checked. It is mutated in
            place: either ``origin_input_ids`` is cut down to its first
            ``max_req_input_len`` entries, or ``finished_reason`` is set to a
            ``FINISH_ABORT`` carrying the error message.
        max_req_input_len: Length threshold for the input.
        allow_auto_truncate: If true, over-long inputs are truncated (with a
            warning) instead of being rejected.

    Returns:
        The error message when the request is rejected; ``None`` when the
        input is within the limit or was truncated.
    """
    # Fast path: nothing to do for inputs strictly below the threshold.
    if len(req.origin_input_ids) < max_req_input_len:
        return None

    if not allow_auto_truncate:
        # Rejection path: log, mark the request as aborted, and hand the
        # message back to the caller.
        error_msg = (
            f"Input length ({len(req.origin_input_ids)} tokens) exceeds "
            f"the maximum allowed length ({max_req_input_len} tokens). "
            f"Use a shorter input or enable --allow-auto-truncate."
        )
        logger.error(error_msg)
        req.finished_reason = FINISH_ABORT(
            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
        )
        return error_msg

    # Truncation path: warn with the pre-truncation length, then keep only
    # the leading ``max_req_input_len`` tokens.
    logger.warning(
        "Request length is longer than the KV cache pool size or "
        "the max context length. Truncated. "
        f"{len(req.origin_input_ids)=}, {max_req_input_len=}."
    )
    req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
    return None