# evalscope_v0.17.0/evalscope.0.17.0/evalscope/models/adapters/base_adapter.py
#
# 81 lines
# 3.1 KiB
# Python

import torch
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional, Union
from evalscope.constants import EvalType, OutputType
from evalscope.utils.logger import get_logger
from ..custom import CustomModel
from ..local_model import LocalModel
logger = get_logger()
if TYPE_CHECKING:
from evalscope.benchmarks import DataAdapter
from evalscope.config import TaskConfig
class BaseModelAdapter(ABC):
    """Abstract base class shared by all model adapters.

    It copies the handles a concrete adapter needs out of a ``LocalModel`` or
    ``CustomModel`` (or out of ``kwargs`` when no model object is given).
    Subclasses must implement :meth:`predict`.
    """

    def __init__(self, model: Optional[Union[LocalModel, CustomModel]], **kwargs):
        """Populate adapter attributes from ``model``.

        Args:
            model: ``None``, a ``LocalModel`` or a ``CustomModel``.
            **kwargs: Only ``model_cfg`` is read, and only when ``model`` is ``None``.

        Raises:
            ValueError: If ``model`` is of any other type.
        """
        # No model object: the config (possibly None) comes from kwargs.
        if model is None:
            self.model_cfg = kwargs.get('model_cfg', None)
            return

        # Local model: mirror its weights, identity, device, tokenizer and config.
        if isinstance(model, LocalModel):
            self.model = model.model
            self.model_id = model.model_id
            self.model_revision = model.model_revision
            self.device = model.device
            self.tokenizer = model.tokenizer
            self.model_cfg = model.model_cfg
            return

        # Custom model: only its ``config`` is kept, stored as ``model_cfg``.
        if isinstance(model, CustomModel):
            self.model_cfg = model.config
            return

        raise ValueError(f'Unsupported model type: {type(model)}')

    @abstractmethod
    @torch.no_grad()
    def predict(self, *args, **kwargs) -> Any:
        """Run inference. Concrete adapters must override this.

        NOTE(review): ``torch.no_grad()`` only wraps this abstract stub. An
        override in a subclass replaces the function entirely and does not
        inherit the no-grad context; it must apply the decorator itself.
        """
        raise NotImplementedError
def initialize_model_adapter(task_cfg: 'TaskConfig', benchmark: 'DataAdapter', base_model: 'LocalModel'):
    """Initialize the model adapter based on the task configuration.

    Three cases are handled, in this order:

    * ``task_cfg.eval_type == EvalType.CUSTOM``: ``task_cfg.model`` must be a
      ``CustomModel`` and is wrapped in a ``CustomModelAdapter``.
    * Service evaluation (``EvalType.SERVICE``, or any non-``None``
      ``task_cfg.api_url``): a server adapter is built from the API settings
      in ``task_cfg``.
    * Otherwise: a local adapter is built around ``base_model``. The adapter
      name is ``benchmark.model_adapter``, or ``benchmark.output_types[0]`` if
      that name is not listed in ``benchmark.output_types``.

    Args:
        task_cfg: Task configuration. Reads ``eval_type``, ``model``, ``api_url``,
            ``api_key``, ``seed``, ``timeout``, ``stream``,
            ``generation_config`` and ``chat_template``.
        benchmark: Data adapter. Reads ``model_adapter``, ``output_types`` and
            ``name``.
        base_model: Loaded local model, used only on the local branch.

    Returns:
        The constructed model adapter instance.

    Raises:
        ValueError: If the eval type is custom but ``task_cfg.model`` is not a
            ``CustomModel``.
    """
    if task_cfg.eval_type == EvalType.CUSTOM:
        if not isinstance(task_cfg.model, CustomModel):
            raise ValueError(f'Expected evalscope.models.custom.CustomModel, but got {type(task_cfg.model)}.')
        # Imported lazily, as in the original layout of this module.
        from evalscope.models import CustomModelAdapter
        return CustomModelAdapter(custom_model=task_cfg.model)

    from ..register import get_model_adapter

    # The adapter class is looked up by name; start from the benchmark's choice.
    model_adapter_cls_str = benchmark.model_adapter

    if task_cfg.eval_type == EvalType.SERVICE or task_cfg.api_url is not None:
        # Keep the benchmark's adapter only if its name already contains
        # 'server'; otherwise use the generic 'server' adapter.
        if 'server' not in model_adapter_cls_str:
            model_adapter_cls_str = 'server'
        model_adapter_cls = get_model_adapter(model_adapter_cls_str)
        return model_adapter_cls(
            api_url=task_cfg.api_url,
            model_id=task_cfg.model,
            api_key=task_cfg.api_key,
            seed=task_cfg.seed,
            timeout=task_cfg.timeout,
            stream=task_cfg.stream,
        )

    # Local inference: fall back to the benchmark's first output type when the
    # requested adapter is not one the benchmark supports.
    if model_adapter_cls_str not in benchmark.output_types:
        # Trailing space after the first sentence: the two f-strings are
        # concatenated and previously ran together ("...name.Using ...").
        logger.warning(f'Output type {model_adapter_cls_str} is not supported for benchmark {benchmark.name}. '
                       f'Using {benchmark.output_types[0]} instead.')
        model_adapter_cls_str = benchmark.output_types[0]
    model_adapter_cls = get_model_adapter(model_adapter_cls_str)
    return model_adapter_cls(
        model=base_model,
        generation_config=task_cfg.generation_config,
        chat_template=task_cfg.chat_template,
        task_cfg=task_cfg)