sglang0.4.5.post1/python/sglang/srt/managers/utils.py

50 lines
1.6 KiB
Python

import json
import logging
import time
from collections import defaultdict
from http import HTTPStatus
from typing import Dict, List, Optional, Tuple
import torch
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
logger = logging.getLogger(__name__)
def validate_input_length(
    req: Req, max_req_input_len: int, allow_auto_truncate: bool
) -> Optional[str]:
    """Check a request's prompt length against a token threshold.

    A prompt whose token count is greater than *or equal to*
    ``max_req_input_len`` is treated as too long. What happens next depends on
    ``allow_auto_truncate``: the prompt is either cut down in place or the
    request is marked as aborted.

    Args:
        req: Request whose ``origin_input_ids`` are checked. It may be mutated:
            ``origin_input_ids`` is replaced by its first ``max_req_input_len``
            entries (truncation), or ``finished_reason`` is set to a
            ``FINISH_ABORT`` carrying HTTP 400 (rejection).
        max_req_input_len: Token-count threshold for the prompt.
        allow_auto_truncate: When true, an over-long prompt is truncated
            instead of rejected.

    Returns:
        ``None`` if the request may proceed (short enough, or truncated);
        otherwise the error message that was logged and attached to
        ``req.finished_reason``.
    """
    num_tokens = len(req.origin_input_ids)

    # Fast path: strictly below the threshold needs no action.
    # NOTE(review): the boundary is inclusive, so a prompt of exactly
    # max_req_input_len tokens falls through to one of the branches below
    # (with auto-truncate it is "truncated" to its own length).
    if num_tokens < max_req_input_len:
        return None

    if not allow_auto_truncate:
        error_msg = (
            f"Input length ({num_tokens} tokens) exceeds "
            f"the maximum allowed length ({max_req_input_len} tokens). "
            "Use a shorter input or enable --allow-auto-truncate."
        )
        logger.error(error_msg)
        # Mark the request as finished so it is answered with a 400 error
        # instead of being scheduled.
        req.finished_reason = FINISH_ABORT(
            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
        )
        return error_msg

    # Log before slicing so the message reports the original prompt length.
    logger.warning(
        "Request length is longer than the KV cache pool size or "
        "the max context length. Truncated. "
        f"{len(req.origin_input_ids)=}, {max_req_input_len=}."
    )
    req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
    return None