# evalscope_v0.17.0/evalscope.0.17.0/evalscope/third_party/longbench_write/infer.py
# (130 lines, 4.9 KiB, Python)

# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) ZhipuAI, Inc. and its affiliates.
import json
import numpy as np
import os
import random
import torch
from typing import List
from evalscope.third_party.longbench_write.tools.openai_api import OpenaiApi
from evalscope.third_party.longbench_write.utils import count_words
from evalscope.utils import get_logger
# Module-wide logger obtained from evalscope's logging utilities.
logger = get_logger()
# Default value for run_infer's `proc_num`, which is forwarded to
# OpenaiApi.generate_simple as `num_proc` (presumably the number of parallel
# request workers — confirm against OpenaiApi).
DEFAULT_PROC_NUM = 8
# NOTE(review): this string follows other statements, so Python does not treat
# it as the module docstring; consider moving it to the top of the file.
"""
This script is used to generate predictions for the LongWriter model.
Refer to https://github.com/THUDM/LongWriter for more details.
"""
def seed_everything(seed):
    """Seed every random number generator used here so runs are reproducible.

    Seeds Python's ``random``, NumPy's global generator, and PyTorch (CPU, the
    current CUDA device and all CUDA devices). It also puts cuDNN into
    deterministic mode with benchmark autotuning disabled.

    Args:
        seed: Seed value applied to every generator.
    """
    # Pure-Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch CPU generator, then the current CUDA device and every CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Give up cuDNN autotuning in exchange for run-to-run determinism.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
def _load_jsonl(path: str) -> List[dict]:
    """Read a JSON Lines file into a list of dicts, skipping blank lines."""
    with open(path, encoding='utf-8') as f:
        # Blank lines (e.g. a stray empty line at EOF) would make json.loads fail.
        return [json.loads(line) for line in f if line.strip()]


def _write_predictions(path: str, data_list: List[dict], results: List[str]) -> None:
    """Attach each response (and its word count) to its sample and write them as JSON Lines.

    The samples in ``data_list`` are updated in place with ``response`` and
    ``response_length`` (the first value returned by ``count_words``).
    """
    with open(path, 'w', encoding='utf-8') as f:
        for sample, response in zip(data_list, results):
            sample['response_length'], _ = count_words(response)
            sample['response'] = response
            f.write(json.dumps(sample, ensure_ascii=False) + '\n')


def run_infer(model: str,
              data_path: str,
              output_dir: str,
              api_config: dict,
              generation_kwargs: dict = None,
              enable: bool = True,
              proc_num: int = DEFAULT_PROC_NUM):
    """
    Process inference for LongWriter model.

    Args:
        model: The model id of the LongWriter model on ModelScope, or local model path.
        data_path: The path to the data file (JSON Lines; each record needs a `prompt` field).
        output_dir: The output directory for the predictions.
        api_config: The configuration for the OpenAI API inference. `None` is treated as `{}`.
            Attributes:
                `openai_api_key`: The OpenAI API key. Default is None for custom model serving.
                `openai_api_base`: The OpenAI API base URL.
                    Default is 'http://127.0.0.1:8000/v1/chat/completions'.
                `is_chat`: Whether to chat. Default is True.
                `verbose`: Whether to print verbose information. Default is False.
        generation_kwargs: The generation arguments for the model. If None, defaults to
            max_new_tokens=32768, temperature=0.5, repetition_penalty=1.0. If a dict is given,
            any missing key falls back to max_new_tokens=4096, temperature=0.0,
            repetition_penalty=1.0.
        enable: Whether to run infer process. When False, inference is skipped and the
            expected prediction file path is returned unchanged.
        proc_num: Passed to the API client's `generate_simple` as `num_proc`.

    Returns:
        Path of the predictions file: `<output_dir>/<model id with path separators
        replaced by '__'>/pred.jsonl`.

    Raises:
        ValueError: If the input data file contains no samples.
        RuntimeError: If the number of predictions differs from the number of inputs.
    """
    # Flatten the model id / local path into a single directory name.
    model_id_path: str = os.path.join(output_dir, model.strip(os.sep).replace(os.sep, '__'))
    output_pred_file: str = f'{model_id_path}/pred.jsonl'

    if not enable:
        logger.warning('*** Skip `infer` stage ***')
        return output_pred_file

    seed_everything(42)

    if generation_kwargs is None:
        generation_kwargs = {
            'max_new_tokens': 32768,
            'temperature': 0.5,
            'repetition_penalty': 1.0,
        }
    if api_config is None:
        api_config = {}

    # Prepare inputs
    logger.info(f'>>Input data path: {data_path}')
    # TODO: add load data from ms
    data_list = _load_jsonl(data_path)
    if not data_list:
        raise ValueError(f'No samples found in input data file: {data_path}')
    logger.info(f'Input example: {data_list[0]}')

    api_client = OpenaiApi(
        model=model,
        # Previously hard-coded to None, which silently ignored a configured key.
        openai_api_key=api_config.get('openai_api_key'),
        openai_api_base=api_config.get('openai_api_base',
                                       'http://127.0.0.1:8000/v1/chat/completions'),
        max_new_tokens=generation_kwargs.get('max_new_tokens', 4096),
        temperature=generation_kwargs.get('temperature', 0.0),
        repetition_penalty=generation_kwargs.get('repetition_penalty', 1.0),
        is_chat=api_config.get('is_chat', True),
        verbose=api_config.get('verbose', False),
    )

    # TODO: refine generate_simple
    results: List[str] = api_client.generate_simple(
        inputs=[example['prompt'] for example in data_list],
        num_proc=proc_num,
    )
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(results) != len(data_list):
        raise RuntimeError(
            f'Error: The number of predictions {len(results)} is not equal to '
            f'the number of inputs {len(data_list)}.')
    logger.info(f'Finish generating predictions with {len(data_list)} samples for {model}')

    # Outputs
    os.makedirs(model_id_path, exist_ok=True)
    _write_predictions(output_pred_file, data_list, results)
    logger.info(f'Predictions are saved in {output_pred_file}')
    return output_pred_file
if __name__ == '__main__':
    # Example model ids: ZhipuAI/LongWriter-glm4-9b, ZhipuAI/LongWriter-llama3.1-8b
    demo_api_config = {
        'openai_api_key': None,
        'openai_api_base': 'http://127.0.0.1:8000/v1/chat/completions',
        'is_chat': True,
        'verbose': True,
    }
    demo_generation_kwargs = {
        'max_new_tokens': 32768,
        'temperature': 0.5,
    }
    run_infer(
        model='ZhipuAI/LongWriter-glm4-9b',
        data_path='resources/longbench_write.jsonl',
        output_dir='outputs',
        api_config=demo_api_config,
        generation_kwargs=demo_generation_kwargs,
    )