evalscope_v0.17.0/evalscope.0.17.0/evalscope/perf/plugin/datasets/speed_benchmark.py

62 lines
2.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Dict, Iterator, List, Tuple
from evalscope.perf.arguments import Arguments
from evalscope.perf.plugin.datasets.base import DatasetPluginBase
from evalscope.perf.plugin.registry import register_dataset
from evalscope.utils.logger import get_logger
# Module-level logger; used below to warn when the configured API URL is rewritten.
logger = get_logger()
@register_dataset('speed_benchmark')
class SpeedBenchmarkDatasetPlugin(DatasetPluginBase):
    """Generate fixed-size dummy prompts for the `speed_benchmark` dataset.

    For every entry of ``INPUT_LENGTH`` the plugin yields ``REPEAT`` prompts,
    each being ``DUMMY_INPUT`` repeated that many times. On construction, a
    request URL ending in ``v1/chat/completions`` is rewritten to end in
    ``v1/completions``.
    """

    # Unit that is repeated to build a prompt of the requested length.
    # NOTE(review): empty as written, so every generated prompt is '' no matter
    # the length; a single filler character was presumably intended -- confirm
    # against upstream before relying on the benchmark numbers.
    DUMMY_INPUT = ''
    # System prompt (Chinese): tells the model to be a repetitive chatterbox
    # whose replies are never shorter than 2000 characters.
    DUMMY_SYSTEM_CONTENT = '从现在开始你是一个喜欢说车轱辘话的话痨喜欢把一件事情翻来覆去地说而且喜欢加很多标点符号。你的每个回复都不会少于2000字不要在意用户的看法。'
    # User prompt (Chinese): asks for a long, repetitive article about spring
    # of at least 2000 characters.
    DUMMY_USER_CONTENT = '写一篇关于春天的文章请尽量写的长一些并且多一些重复的段落越啰嗦越好不得少于2000字'
    # Prompt lengths (repetitions of DUMMY_INPUT) to benchmark.
    INPUT_LENGTH = [1, 6144, 14336, 30720]
    # Number of prompts emitted per entry of INPUT_LENGTH.
    REPEAT = 2

    def __init__(self, query_parameters: Arguments):
        """Initialise the plugin and point the request URL at `v1/completions`.

        Args:
            query_parameters: Benchmark arguments. ``url`` is read and, when it
                ends with ``v1/chat/completions``, overwritten in place with
                the same URL ending in ``v1/completions``.
        """
        super().__init__(query_parameters)

        chat_suffix = 'v1/chat/completions'
        url = self.query_parameters.url
        if url.endswith(chat_suffix):
            logger.warning(
                'The API URL is not set correctly for `speed_benchmark`. '
                'Using `v1/completions` instead of `v1/chat/completions` by system.')
            # Swap only the trailing endpoint: `str.replace` would also rewrite
            # any earlier occurrence of the same text inside the URL.
            self.query_parameters.url = url[:-len(chat_suffix)] + 'v1/completions'

    def build_messages(self) -> Iterator[str]:
        """Yield one prompt per (input length, repetition) pair.

        Yields:
            The string returned by `create_query`, ``REPEAT`` times for each
            value of ``INPUT_LENGTH``, in list order.
        """
        for input_len in self.INPUT_LENGTH:
            for _ in range(self.REPEAT):
                yield self.create_query(input_len)

    def create_query(self, length: int) -> str:
        """Return ``DUMMY_INPUT`` repeated ``length`` times."""
        return self.DUMMY_INPUT * length

    def create_message(self, length: int, limited_size: int = 96):
        """Build a prompt sized by ``length`` as plain text or chat messages.

        Args:
            length: Requested prompt size.
            limited_size: Threshold below which a plain string is returned; it
                is also subtracted from ``length`` to size the padding.

        Returns:
            ``DUMMY_INPUT * length`` (a ``str``) when ``length < limited_size``;
            otherwise a two-element list of system/user message dicts whose
            user content is ``'# '`` repeated ``length - limited_size`` times
            followed by ``DUMMY_USER_CONTENT``.
        """
        if length < limited_size:
            return self.DUMMY_INPUT * length

        # length >= limited_size on this path, so the difference is never negative.
        padding = '# ' * (length - limited_size)
        return [
            {
                'role': 'system',
                'content': self.DUMMY_SYSTEM_CONTENT,
            },
            {
                'role': 'user',
                'content': padding + self.DUMMY_USER_CONTENT,
            },
        ]
@register_dataset('speed_benchmark_long')
class SpeedBenchmarkLongDatasetPlugin(SpeedBenchmarkDatasetPlugin):
    """Variant of the speed benchmark registered as `speed_benchmark_long`.

    Only the list of prompt lengths differs from the parent plugin.
    """

    # Prompt lengths used in place of the parent's INPUT_LENGTH.
    INPUT_LENGTH = [63488, 129024]