evalscope_v0.17.0/evalscope.0.17.0/evalscope/third_party/toolbench_static/toolbench_static.py

53 lines
2.1 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from copy import deepcopy
from typing import Union
from evalscope.third_party.toolbench_static.eval import EvalArgs, run_eval
from evalscope.third_party.toolbench_static.infer import InferArgs, run_infer
from evalscope.utils import get_logger
from evalscope.utils.deprecation_utils import deprecated
from evalscope.utils.io_utils import json_to_dict, yaml_to_dict
logger = get_logger()
@deprecated(since='0.15.1', remove_in='0.18.0', alternative='Native implementation of ToolBench')
def run_task(task_cfg: Union[str, dict]):
    """Run ToolBench-static inference, then evaluation, for every domain.

    Inference is executed for all domains first; evaluation starts only after
    every inference run has returned.

    Args:
        task_cfg: Either an already-loaded config dict, or a path to a
            ``.yaml`` / ``.yml`` / ``.json`` file (the extension is matched
            case-insensitively). The config must contain ``infer_args``
            (with ``data_path`` and ``output_dir``) and ``eval_args`` (with
            ``input_path`` and ``output_path``); the remaining keys are
            forwarded unchanged to ``InferArgs`` / ``EvalArgs``.

    Raises:
        ValueError: If ``task_cfg`` is a string that does not end with a
            supported extension.
    """
    domains = ('in_domain', 'out_of_domain')

    # Bind the loaded config to a separate local instead of re-annotating the
    # parameter, so the declared ``Union[str, dict]`` type stays consistent.
    if isinstance(task_cfg, str):
        lowered = task_cfg.lower()
        if lowered.endswith(('.yaml', '.yml')):
            cfg: dict = yaml_to_dict(task_cfg)
        elif lowered.endswith('.json'):
            cfg = json_to_dict(task_cfg)
        else:
            raise ValueError(f'Unsupported file format: {task_cfg}, should be yaml or json file.')
    else:
        cfg = task_cfg

    # Run inference for each domain
    infer_args: dict = cfg['infer_args']
    for domain in domains:
        # Deep-copy so the per-domain path overrides never leak into the
        # caller's config or into the next domain's arguments.
        domain_infer_args = deepcopy(infer_args)
        domain_infer_args['data_path'] = os.path.join(infer_args['data_path'], f'{domain}.json')
        domain_infer_args['output_dir'] = os.path.join(infer_args['output_dir'], domain)

        task_infer_args = InferArgs(**domain_infer_args)
        print(f'**Run infer config: {task_infer_args}')
        run_infer(task_infer_args)

    # Run evaluation for each domain
    eval_args: dict = cfg['eval_args']
    for domain in domains:
        domain_eval_args = deepcopy(eval_args)
        domain_eval_args['input_path'] = os.path.join(eval_args['input_path'], domain)
        domain_eval_args['output_path'] = os.path.join(eval_args['output_path'], domain)

        task_eval_args = EvalArgs(**domain_eval_args)
        print(f'**Run eval config: {task_eval_args}')
        run_eval(task_eval_args)
if __name__ == '__main__':
    # Script entry point: run the task with the default JSON config, given as
    # a relative path. A YAML config (e.g. 'config_default.yaml') can be
    # passed to run_task in the same way.
    run_task(task_cfg='config_default.json')