# evalscope/perf/benchmark.py
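"""Core async benchmarking pipeline for the evalscope perf module.

Summarized from the code below: ``get_requests`` yields request payloads from
a fixed prompt or a dataset plugin, ``send_request`` posts each payload and
records per-request timings, ``statistic_benchmark_metric`` aggregates the
results and persists them to SQLite, and ``benchmark`` wires the three
together.
"""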
import asyncio
import json
import numpy as np
import platform
import sqlite3
import time
from http import HTTPStatus
from tqdm import tqdm
from typing import AsyncGenerator, Dict, List, Tuple
from evalscope.perf.arguments import Arguments
from evalscope.perf.http_client import AioHttpClient, test_connection
from evalscope.perf.plugin.registry import ApiRegistry, DatasetRegistry
from evalscope.perf.utils.benchmark_util import BenchmarkData, BenchmarkMetrics
from evalscope.perf.utils.db_util import create_result_table, get_result_db_path, insert_benchmark_data, summary_result
from evalscope.perf.utils.handler import add_signal_handlers, exception_handler
from evalscope.utils.logger import get_logger

logger = get_logger()

# Set once all requests have been sent and their results enqueued; signals the
# metrics consumer that it may stop draining the queue.
data_process_completed_event = asyncio.Event()


@exception_handler
async def get_requests(args: Arguments) -> AsyncGenerator[dict, None]:
    """Yield request payloads built from a fixed prompt or a dataset plugin."""
    query_generator_class = ApiRegistry(args.api)
    query_generator = query_generator_class(args.tokenizer_path)

    def load_prompt(prompt_path_or_text):
        # A leading '@' marks a file path; otherwise the string is the prompt itself.
        if prompt_path_or_text.startswith('@'):
            with open(prompt_path_or_text[1:], 'r', encoding='utf-8') as file:
                return file.read()
        return prompt_path_or_text

    async def generate_requests_from_prompt(messages):
        # Build the request once and repeat it `args.number` times.
        request = query_generator.build_request(messages, args)
        for _ in range(args.number):
            yield request

    async def generate_requests_from_dataset():
        message_generator_class = DatasetRegistry(args.dataset)
        message_generator = message_generator_class(args)
        dataset_messages = []
        try:
            for messages in message_generator:
                dataset_messages.append(messages)
        except StopIteration:
            pass
        if not dataset_messages:
            raise Exception('Dataset is empty!')
        # Cycle through the dataset until `args.number` requests have been yielded.
        count = 0
        dataset_index = 0
        while count < args.number:
            messages = dataset_messages[dataset_index]
            request = query_generator.build_request(messages, args)
            if request is not None:
                yield request
                count += 1
            dataset_index = (dataset_index + 1) % len(dataset_messages)

    if args.prompt:
        prompt = load_prompt(args.prompt)
        messages = [{'role': 'user', 'content': prompt}] if args.apply_chat_template else prompt
        generator = generate_requests_from_prompt(messages)
    elif args.dataset:
        generator = generate_requests_from_dataset()
    else:
        raise Exception('Either prompt or dataset is required!')

    async for request in generator:
        yield request
        if args.rate != -1:
            # Exponential inter-arrival times approximate a Poisson arrival
            # process with an average of `args.rate` requests per second.
            interval = np.random.exponential(1.0 / args.rate)
            await asyncio.sleep(interval)
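

# A minimal usage sketch (hypothetical helper, not part of this module):
# previewing the generated payloads, assuming an Arguments instance with
# either `prompt` or `dataset` set:
#
#   async def preview_requests(args: Arguments) -> None:
#       async for request in get_requests(args):
#           print(json.dumps(request, ensure_ascii=False))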


@exception_handler
async def send_request(
    semaphore: asyncio.Semaphore,
    request: dict,
    benchmark_data_queue: asyncio.Queue,
    args: Arguments,
):
    async with semaphore:
        client = AioHttpClient(args)
        async with client:
            benchmark_data = BenchmarkData(request=request)
            benchmark_data.start_time = time.perf_counter()
            collected_messages = []
            response_data = None  # defined up front so the except block cannot hit a NameError
            try:
                async for is_error, state_code, response_data in client.post(request):
                    if is_error or state_code != HTTPStatus.OK:
                        logger.error(f'Request: {request} failed, state_code: {state_code}, data: {response_data}')
                        benchmark_data.success = False
                        break
                    if response_data:
                        collected_messages.append(response_data)
                        benchmark_data.chunk_times.append(time.perf_counter())
                    benchmark_data.success = True
                    benchmark_data.update_gpu_usage()
            except Exception as e:
                if response_data:
                    collected_messages.append(response_data)
                benchmark_data.success = False
                logger.exception(e)
                logger.error(f'Request query: {request} exception')
            finally:
                benchmark_data.completed_time = time.perf_counter()
                benchmark_data.response_messages = collected_messages
                await benchmark_data_queue.put(benchmark_data)
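
# The timing fields recorded above are presumably consumed downstream (in
# BenchmarkMetrics / summary_result) roughly as:
#
#   first_chunk_latency = chunk_times[0] - start_time    # time to first chunk
#   total_latency       = completed_time - start_time    # end-to-end latency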


@exception_handler
async def statistic_benchmark_metric(benchmark_data_queue: asyncio.Queue, args: Arguments):
    metrics = BenchmarkMetrics(concurrency=args.parallel)
    api_plugin_class = ApiRegistry(args.api)
    api_plugin = api_plugin_class(args.tokenizer_path)
    result_db_path = get_result_db_path(args)
    collected_benchmark_data = []
    with tqdm(desc='Processing', total=args.number) as pbar:
        # Drain the queue until the producers have finished and nothing is left.
        while not (data_process_completed_event.is_set() and benchmark_data_queue.empty()):
            try:
                # Attempt to get benchmark data from the queue with a timeout
                benchmark_data = await asyncio.wait_for(benchmark_data_queue.get(), timeout=0.01)
                benchmark_data_queue.task_done()
            except asyncio.TimeoutError:
                # If timeout, continue to the next iteration
                continue
            # Update metrics based on the benchmark data
            metrics.update_metrics(benchmark_data, api_plugin)
            # Collect benchmark data for later database insertion
            collected_benchmark_data.append(benchmark_data)
            # Create a message with the updated metrics
            message = metrics.create_message()
            # Log the message to wandb/swanlab if the corresponding API key is provided
            if args.wandb_api_key:
                import wandb
                wandb.log(message)
            if args.swanlab_api_key:
                import swanlab
                swanlab.log(message)
            # Log the message to the logger every n queries
            if int(metrics.n_total_queries) % args.log_every_n_query == 0:
                msg = json.dumps(message, ensure_ascii=False, indent=2)
                logger.info(msg)
            pbar.update(1)  # Update the progress bar
    # Now perform database operations after all benchmark data has been processed
    with sqlite3.connect(result_db_path) as con:
        cursor = con.cursor()
        create_result_table(cursor)
        for benchmark_data in collected_benchmark_data:
            insert_benchmark_data(cursor, benchmark_data)
        con.commit()
    return metrics, result_db_path


@exception_handler
async def connect_test(args: Arguments) -> None:
    # Raises on failure; returns nothing on success.
    if (not args.no_test_connection) and (not await test_connection(args)):
        raise TimeoutError('Test connection failed')


@exception_handler
async def benchmark(args: Arguments) -> Tuple[Dict, Dict]:
    if platform.system() != 'Windows':
        # Graceful-shutdown signal handlers are only available on POSIX event loops.
        loop = asyncio.get_running_loop()
        add_signal_handlers(loop)
    # init queue
    benchmark_data_queue = asyncio.Queue()
    # reset event
    data_process_completed_event.clear()
    # test connection
    await connect_test(args)
    # start the metrics consumer
    statistic_benchmark_metric_task = asyncio.create_task(statistic_benchmark_metric(benchmark_data_queue, args))
    # start sending requests, with concurrency bounded by a semaphore of size `args.parallel`
    semaphore = asyncio.Semaphore(args.parallel)
    send_request_tasks: List[asyncio.Task] = []
    async for request in get_requests(args):
        task = asyncio.create_task(send_request(semaphore, request, benchmark_data_queue, args))
        send_request_tasks.append(task)
    await asyncio.gather(*send_request_tasks, return_exceptions=True)
    await benchmark_data_queue.join()
    # Signal the consumer that no more data will arrive, then wait for it to finish.
    data_process_completed_event.set()
    metrics, result_db_path = await statistic_benchmark_metric_task
    metrics_result, percentile_result = summary_result(args, metrics, result_db_path)
    return metrics_result, percentile_result
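

# A minimal entry-point sketch (hypothetical; Arguments' constructor fields are
# inferred from their usage in this module, not from its definition):
#
#   if __name__ == '__main__':
#       args = Arguments(...)  # e.g. api, parallel, number, and prompt or dataset
#       asyncio.run(benchmark(args))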