import argparse
import json
import os
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

from evalscope.constants import DEFAULT_WORK_DIR
from evalscope.utils import BaseArgument


@dataclass
class Arguments(BaseArgument):
    """Arguments for benchmarking an LLM service's performance."""

    # Model and API
    model: str  # Model name or path
    model_id: Optional[str] = None  # Model identifier
    attn_implementation: Optional[str] = None  # Attention implementation, only for local inference
    api: str = 'openai'  # API to be used (default: 'openai')
    tokenizer_path: Optional[str] = None  # Path to the tokenizer
    port: int = 8877  # Port number for the local API server

    # Connection settings
    url: str = 'http://127.0.0.1:8877/v1/chat/completions'  # URL for the API connection
    headers: Dict[str, Any] = field(default_factory=dict)  # Custom headers
    connect_timeout: int = 600  # Connection timeout in seconds
    read_timeout: int = 600  # Read timeout in seconds
    api_key: Optional[str] = None  # API key, sent as a Bearer token when set
    no_test_connection: bool = False  # Skip the connection test before starting the benchmark

    # Performance and parallelism
    number: Union[int, List[int]] = 1000  # Number of requests to be made
    parallel: Union[int, List[int]] = 1  # Number of parallel requests
    rate: int = -1  # Rate limit for requests (default: -1, no limit)

    # Logging and debugging
    log_every_n_query: int = 10  # Log every N queries
    debug: bool = False  # Debug mode
    wandb_api_key: Optional[str] = None  # WandB API key for logging
    swanlab_api_key: Optional[str] = None  # SwanLab API key for logging
    name: Optional[str] = None  # Name for the run

    # Output settings
    outputs_dir: str = DEFAULT_WORK_DIR

    # Prompt settings
    max_prompt_length: int = 131072  # Maximum length of the prompt
    min_prompt_length: int = 0  # Minimum length of the prompt
    prefix_length: int = 0  # Length of the prefix, only for the random dataset
    prompt: Optional[str] = None  # The prompt text
    query_template: Optional[str] = None  # Template for the query
    apply_chat_template: Optional[bool] = None  # Whether to apply the chat template

    # Dataset settings
    dataset: str = 'openqa'  # Dataset type (default: 'openqa')
    dataset_path: Optional[str] = None  # Path to the dataset

    # Response settings
    frequency_penalty: Optional[float] = None  # Frequency penalty for the response
    repetition_penalty: Optional[float] = None  # Repetition penalty for the response
    logprobs: Optional[bool] = None  # Whether to log probabilities
    max_tokens: Optional[int] = 2048  # Maximum number of tokens in the response
    min_tokens: Optional[int] = None  # Minimum number of tokens in the response
    n_choices: Optional[int] = None  # Number of response choices
    seed: Optional[int] = 0  # Random seed for reproducibility
    stop: Optional[List[str]] = None  # Stop sequences for the response
    stop_token_ids: Optional[List[str]] = None  # Stop token IDs for the response
    stream: Optional[bool] = True  # Whether to stream the response
    temperature: float = 0.0  # Temperature setting for the response
    top_p: Optional[float] = None  # Top-p (nucleus) sampling setting for the response
    top_k: Optional[int] = None  # Top-k sampling setting for the response
    extra_args: Optional[Dict[str, Any]] = None  # Extra arguments

    def __post_init__(self):
        # Set the default headers
        self.headers = self.headers or {}  # Default to an empty dictionary
        if self.api_key:
            # The API key is used as a Bearer token
            self.headers['Authorization'] = f'Bearer {self.api_key}'

        # Derive the model ID from the model name or path
        self.model_id = os.path.basename(self.model)

        # Set the URL based on the dataset type when serving locally
        if self.api.startswith('local'):
            if self.dataset.startswith('speed_benchmark'):
                self.url = f'http://127.0.0.1:{self.port}/v1/completions'
            else:
                self.url = f'http://127.0.0.1:{self.port}/v1/chat/completions'

        # Infer the apply_chat_template flag from the URL when not set explicitly
        if self.apply_chat_template is None:
            self.apply_chat_template = self.url.strip('/').endswith('chat/completions')

        # Normalize number and parallel to lists so they can be iterated pairwise
        if isinstance(self.number, int):
            self.number = [self.number]
        if isinstance(self.parallel, int):
            self.parallel = [self.parallel]
        assert len(self.number) == len(
            self.parallel
        ), f'The length of number and parallel should be the same, but got number: {self.number} and parallel: {self.parallel}'  # noqa: E501
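

# Illustrative sketch (kept as a comment so it does not run at import time):
# constructing `Arguments` directly mirrors what `parse_args()` yields from the
# CLI. The model path and API key below are placeholders, not shipped defaults.
#
#   args = Arguments(
#       model='/models/qwen2.5-7b-instruct',  # model_id becomes 'qwen2.5-7b-instruct'
#       api_key='sk-...',                     # injected as 'Authorization: Bearer sk-...'
#       number=[100, 200],                    # one benchmark run per (number, parallel) pair
#       parallel=[1, 4],
#   )
#   assert args.apply_chat_template is True   # URL ends with 'chat/completions'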


class ParseKVAction(argparse.Action):

    def __call__(self, parser, namespace, values, option_string=None):
        if not values:
            setattr(namespace, self.dest, {})
        else:
            try:
                kv_dict = {}
                for kv in values:
                    parts = kv.split('=', 1)  # only split on the first '='
                    if len(parts) != 2:
                        raise ValueError(f'Invalid key-value pair: {kv}')
                    key, value = parts
                    kv_dict[key.strip()] = value.strip()
                setattr(namespace, self.dest, kv_dict)
            except ValueError as e:
                parser.error(f'Error parsing key-value pairs: {e}')
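

# Illustrative sketch (kept as a comment): with `nargs='+'` and this action,
# each KEY=VALUE token is collected into a dict. Values may themselves contain
# '=' because only the first one is split on. Header values are placeholders.
#
#   demo = argparse.ArgumentParser()
#   demo.add_argument('--headers', nargs='+', action=ParseKVAction)
#   ns = demo.parse_args(['--headers', 'Authorization=Bearer sk-...', 'X-Trace-Id=abc=123'])
#   # ns.headers == {'Authorization': 'Bearer sk-...', 'X-Trace-Id': 'abc=123'}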


def add_argument(parser: argparse.ArgumentParser):
    # yapf: disable
    # Model and API
    parser.add_argument('--model', type=str, required=True, help='The test model name.')
    parser.add_argument('--attn-implementation', required=False, default=None, help='Attention implementation')
    parser.add_argument('--api', type=str, default='openai', help='Specify the service API')
    parser.add_argument(
        '--tokenizer-path', type=str, required=False, default=None, help='Specify the tokenizer weight path')

    # Connection settings
    parser.add_argument('--url', type=str, default='http://127.0.0.1:8877/v1/chat/completions')
    parser.add_argument('--port', type=int, default=8877, help='The port for local inference')
    parser.add_argument('--headers', nargs='+', dest='headers', action=ParseKVAction, help='Extra HTTP headers')
    parser.add_argument('--api-key', type=str, required=False, default=None, help='The API key for authentication')
    parser.add_argument('--connect-timeout', type=int, default=600, help='The network connection timeout')
    parser.add_argument('--read-timeout', type=int, default=600, help='The network read timeout')
    parser.add_argument('--no-test-connection', action='store_true', default=False, help='Do not test the connection before starting the benchmark')  # noqa: E501

    # Performance and parallelism
    parser.add_argument('-n', '--number', type=int, default=1000, nargs='+', help='How many requests to be made')
    parser.add_argument('--parallel', type=int, default=1, nargs='+', help='Set the number of concurrent requests, default 1')  # noqa: E501
    parser.add_argument('--rate', type=int, default=-1, help='Number of requests per second (default -1, no limit)')

    # Logging and debugging
    parser.add_argument('--log-every-n-query', type=int, default=10, help='Log every n queries')
    parser.add_argument('--debug', action='store_true', default=False, help='Debug request sending')
    parser.add_argument('--wandb-api-key', type=str, default=None, help='The wandb API key')
    parser.add_argument('--swanlab-api-key', type=str, default=None, help='The swanlab API key')
    parser.add_argument('--name', type=str, help='The name for the wandb/swanlab run and the result db')

    # Prompt settings
    parser.add_argument('--max-prompt-length', type=int, default=sys.maxsize, help='Maximum input prompt length')
    parser.add_argument('--min-prompt-length', type=int, default=0, help='Minimum input prompt length')
    parser.add_argument('--prefix-length', type=int, default=0, help='The prefix length')
    parser.add_argument('--prompt', type=str, required=False, default=None, help='Specify the request prompt')
    parser.add_argument('--query-template', type=str, default=None, help='Specify the query template')
    parser.add_argument(
        '--apply-chat-template', action=argparse.BooleanOptionalAction, default=None, help='Apply the chat template to the prompt')  # noqa: E501

    # Output settings
    parser.add_argument('--outputs-dir', help='Outputs dir.', default='outputs')

    # Dataset settings
    parser.add_argument('--dataset', type=str, default='openqa', help='Specify the dataset')
    parser.add_argument('--dataset-path', type=str, required=False, help='Path to the dataset file')

    # Response settings
    parser.add_argument('--frequency-penalty', type=float, help='The frequency_penalty value', default=None)
    parser.add_argument('--repetition-penalty', type=float, help='The repetition_penalty value', default=None)
    parser.add_argument('--logprobs', action='store_true', help='Return log probabilities', default=None)
    parser.add_argument(
        '--max-tokens', type=int, help='The maximum number of tokens that can be generated', default=2048)
    parser.add_argument(
        '--min-tokens', type=int, help='The minimum number of tokens that can be generated', default=None)
    parser.add_argument('--n-choices', type=int, help='How many completion choices to generate', default=None)
    parser.add_argument('--seed', type=int, help='The random seed', default=0)
    parser.add_argument('--stop', nargs='*', help='The stop tokens', default=None)
    parser.add_argument('--stop-token-ids', nargs='*', help='Set the stop token IDs', default=None)
    parser.add_argument('--stream', action=argparse.BooleanOptionalAction, help='Stream output with SSE', default=True)
    parser.add_argument('--temperature', type=float, help='The sample temperature', default=0.0)
    parser.add_argument('--top-p', type=float, help='Sampling top p', default=None)
    parser.add_argument('--top-k', type=int, help='Sampling top k', default=None)
    parser.add_argument('--extra-args', type=json.loads, default='{}', help='Extra arguments, should be in JSON format')
    # yapf: enable


def parse_args():
    parser = argparse.ArgumentParser(description='Benchmark LLM service performance.')
    add_argument(parser)
    return parser.parse_args()
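

# Illustrative sketch (kept as a comment): parsing a benchmark command line.
# The model name and argv values are placeholders, not shipped defaults.
#
#   import sys
#   sys.argv = ['perf', '--model', 'my-model', '--number', '100', '200', '--parallel', '1', '4']
#   ns = parse_args()
#   # ns.number == [100, 200]; ns.parallel == [1, 4]; ns.stream is True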