50 lines
1.6 KiB
Python
50 lines
1.6 KiB
Python
import json
|
|
import logging
|
|
import time
|
|
from collections import defaultdict
|
|
from http import HTTPStatus
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import torch
|
|
|
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
|
|
|
|
# Module-level logger keyed on this module's dotted name, so its records
# slot into the standard logging hierarchy under the package.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def validate_input_length(
    req: Req, max_req_input_len: int, allow_auto_truncate: bool
) -> Optional[str]:
    """Validate and potentially truncate a request's input length.

    An input is accepted only when it is strictly shorter than
    ``max_req_input_len`` tokens. Inputs at or above that length are either
    truncated in place (so that they then satisfy the same bound) or
    rejected, depending on ``allow_auto_truncate``.

    Args:
        req: The request whose ``origin_input_ids`` are validated. It is
            mutated: ``origin_input_ids`` is replaced when truncating, and
            ``finished_reason`` is set to a ``FINISH_ABORT`` when rejecting.
        max_req_input_len: Exclusive upper bound on the number of input
            tokens.
        allow_auto_truncate: If True, over-long inputs are truncated instead
            of rejected.

    Returns:
        An error message if the input was rejected; ``None`` if the input was
        already within bounds or has been truncated to fit.
    """
    input_len = len(req.origin_input_ids)
    if input_len < max_req_input_len:
        return None

    if allow_auto_truncate:
        logger.warning(
            "Request length is longer than the KV cache pool size or "
            "the max context length. Truncated. "
            f"{len(req.origin_input_ids)=}, {max_req_input_len=}."
        )
        # Keep one token fewer than the limit so the truncated input passes
        # the acceptance check above (len < max_req_input_len); slicing to
        # exactly ``max_req_input_len`` would produce a length this function
        # itself rejects. Clamp at 0 so a non-positive limit never becomes a
        # negative slice bound, which would drop tokens from the END instead.
        keep = max(max_req_input_len - 1, 0)
        req.origin_input_ids = req.origin_input_ids[:keep]
        return None

    error_msg = (
        f"Input length ({input_len} tokens) exceeds "
        f"the maximum allowed length ({max_req_input_len} tokens). "
        "Use a shorter input or enable --allow-auto-truncate."
    )
    logger.error(error_msg)
    # Mark the request as aborted so it is surfaced to the client as a
    # 400-class error rather than being scheduled.
    req.finished_reason = FINISH_ABORT(
        error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
    )
    return error_msg
|