# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import os
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

from evalscope.constants import (DEFAULT_DATASET_CACHE_DIR, DEFAULT_WORK_DIR, EvalBackend, EvalStage, EvalType, HubType,
                                 JudgeStrategy, ModelTask, OutputType)
from evalscope.models import CustomModel, DummyCustomModel
from evalscope.utils.argument_utils import BaseArgument, parse_int_or_float
from evalscope.utils.io_utils import dict_to_yaml, gen_hash
from evalscope.utils.logger import get_logger

logger = get_logger()


@dataclass
class TaskConfig(BaseArgument):
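    """Configuration container for an EvalScope evaluation task.

    Groups model, dataset, generation, evaluation, caching, runtime and
    LLM-judge settings in one dataclass; ``__post_init__`` fills in sensible
    defaults for anything left unset.
    """
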
    # Model-related arguments
    model: Union[str, 'CustomModel', None] = None
    model_id: Optional[str] = None
    model_args: Dict = field(default_factory=dict)
    model_task: str = ModelTask.TEXT_GENERATION

    # Template-related arguments
    template_type: Optional[str] = None  # Deprecated, will be removed in v1.0.0.
    chat_template: Optional[str] = None

    # Dataset-related arguments
    datasets: List[str] = field(default_factory=list)
    dataset_args: Dict = field(default_factory=dict)
    dataset_dir: str = DEFAULT_DATASET_CACHE_DIR
    dataset_hub: str = HubType.MODELSCOPE

    # Generation configuration arguments
    generation_config: Dict = field(default_factory=dict)

    # Evaluation-related arguments
    eval_type: str = EvalType.CHECKPOINT
    eval_backend: str = EvalBackend.NATIVE
    eval_config: Union[str, Dict, None] = None
    stage: str = EvalStage.ALL
    limit: Optional[Union[int, float]] = None
    eval_batch_size: Optional[int] = None

    # Cache and working directory arguments
    mem_cache: bool = False  # Deprecated, will be removed in v1.0.0.
    use_cache: Optional[str] = None
    work_dir: str = DEFAULT_WORK_DIR
    outputs: Optional[str] = None  # Deprecated, will be removed in v1.0.0.

    # Debug and runtime mode arguments
    ignore_errors: bool = False
    debug: bool = False
    dry_run: bool = False
    seed: Optional[int] = 42
    api_url: Optional[str] = None  # Only used for server model
    api_key: Optional[str] = 'EMPTY'  # Only used for server model
    timeout: Optional[float] = None  # Only used for server model
    stream: bool = False  # Only used for server model

    # LLMJudge arguments
    judge_strategy: str = JudgeStrategy.AUTO
    judge_worker_num: int = 1
    judge_model_args: Optional[Dict] = field(default_factory=dict)
    analysis_report: bool = False

    def __post_init__(self):
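        """Fill in derived defaults after dataclass initialization: dummy model,
        model_id, eval batch size, limit, generation config and model args."""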
        if self.model is None:
            self.model = DummyCustomModel()
            self.eval_type = EvalType.CUSTOM

        if (not self.model_id) and self.model:
            if isinstance(self.model, CustomModel):
                self.model_id = self.model.config.get('model_id', 'custom_model')
            else:
                self.model_id = os.path.basename(self.model).rstrip(os.sep)
            # fix path error, see http://github.com/modelscope/evalscope/issues/377
            self.model_id = self.model_id.replace(':', '-')

        # Set default eval_batch_size based on eval_type
        if self.eval_batch_size is None:
            self.eval_batch_size = 8 if self.eval_type == EvalType.SERVICE else 1

        # Post process limit
        if self.limit is not None:
            self.limit = parse_int_or_float(self.limit)

        # Set default generation_config and model_args
        self.__init_default_generation_config()
        self.__init_default_model_args()

    def __init_default_generation_config(self):
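        """Populate ``generation_config`` with defaults when the user did not
        provide one, keyed on ``model_task`` and ``eval_type``."""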
        if self.generation_config:
            return
        if self.model_task == ModelTask.IMAGE_GENERATION:
            self.generation_config = {
                'height': 1024,
                'width': 1024,
                'num_inference_steps': 50,
                'guidance_scale': 9.0,
            }
        elif self.model_task == ModelTask.TEXT_GENERATION:
            if self.eval_type == EvalType.CHECKPOINT:
                self.generation_config = {
                    'max_length': 2048,
                    'max_new_tokens': 512,
                    'do_sample': False,
                    'top_k': 50,
                    'top_p': 1.0,
                    'temperature': 1.0,
                }
            elif self.eval_type == EvalType.SERVICE:
                self.generation_config = {
                    'max_tokens': 2048,
                    'temperature': 0.0,
                }

    def __init_default_model_args(self):
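        """Populate ``model_args`` with default checkpoint-loading arguments
        when the user did not provide any."""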
        if self.model_args:
            return
        if self.model_task == ModelTask.TEXT_GENERATION:
            if self.eval_type == EvalType.CHECKPOINT:
                self.model_args = {
                    'revision': 'master',
                    'precision': 'torch.float16',
                }

    def update(self, other: Union['TaskConfig', dict]):
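        """Overwrite fields in place with values from another TaskConfig or a dict."""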
        if isinstance(other, TaskConfig):
            other = other.to_dict()
        self.__dict__.update(other)

    def dump_yaml(self, output_dir: str):
        """Dump the task configuration to a YAML file."""
        task_cfg_file = os.path.join(output_dir, f'task_config_{gen_hash(str(self), bits=6)}.yaml')
        try:
            logger.info(f'Dump task config to {task_cfg_file}')
            dict_to_yaml(self.to_dict(), task_cfg_file)
        except Exception as e:
            logger.warning(f'Failed to dump overall task config: {e}')

    def to_dict(self):
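        """Return a shallow dict copy of the config; a CustomModel instance is
        replaced with its class name so the result stays serializable."""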
        result = self.__dict__.copy()
        if isinstance(self.model, CustomModel):
            result['model'] = self.model.__class__.__name__
        return result


def parse_task_config(task_cfg) -> TaskConfig:
    """Parse task configuration from various formats into a TaskConfig object."""
    if isinstance(task_cfg, TaskConfig):
        logger.info('Args: Task config is provided with TaskConfig type.')
    elif isinstance(task_cfg, dict):
        logger.info('Args: Task config is provided with dictionary type.')
        task_cfg = TaskConfig.from_dict(task_cfg)
    elif isinstance(task_cfg, Namespace):
        logger.info('Args: Task config is provided with CommandLine type.')
        task_cfg = TaskConfig.from_args(task_cfg)
    elif isinstance(task_cfg, str):
        extension = os.path.splitext(task_cfg)[-1]
        logger.info(f'Args: Task config is provided with {extension} file type.')
        if extension in ['.yaml', '.yml']:
            task_cfg = TaskConfig.from_yaml(task_cfg)
        elif extension == '.json':
            task_cfg = TaskConfig.from_json(task_cfg)
        else:
            raise ValueError('Args: Unsupported file extension.')
    else:
        raise ValueError('Args: Please provide a valid task config.')
    return task_cfg
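

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module): builds a TaskConfig
    # from a plain dict via parse_task_config. The model path and dataset name
    # below are placeholders, not values prescribed by this file.
    example_cfg = parse_task_config({
        'model': 'Qwen/Qwen2-0.5B-Instruct',  # placeholder model id/path
        'datasets': ['gsm8k'],  # placeholder dataset name
        'limit': 10,
    })
    logger.info(f'model_id={example_cfg.model_id}, '
                f'eval_batch_size={example_cfg.eval_batch_size}, '
                f'generation_config={example_cfg.generation_config}')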