# Source: evalscope v0.17.0 — evalscope/models/model.py (190 lines, 6.2 KiB, Python)

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import time
from abc import ABC, abstractmethod
from typing import Any, List
from evalscope.utils.logger import get_logger
logger = get_logger()
class BaseModel(ABC):
    """Abstract base class shared by every evaluated model."""

    def __init__(self, model_cfg: dict, **kwargs):
        """
        Store the model configuration and any extra options.

        Args:
            model_cfg (dict): The model configuration; contents depend on the
                specific model. Example:
                {'model_id': 'modelscope/Llama-2-7b-chat-ms', 'revision': 'v1.0.0'}
            **kwargs: Additional keyword arguments, kept verbatim on the instance.
        """
        self.model_cfg: dict = model_cfg
        self.kwargs = kwargs

    @abstractmethod
    def predict(self, *args, **kwargs) -> Any:
        """Run model inference. Concrete subclasses must implement this."""
        raise NotImplementedError
class ChatBaseModel(BaseModel):
    """Base class for chat models whose I/O follows the OpenAI Chat Completions schema."""

    def __init__(self, model_cfg: dict, **kwargs):
        """
        Chat base model class. Depending on the specific model.

        Args:
            model_cfg (dict):
                {'model_id': 'modelscope/Llama-2-7b-chat-ms', 'revision': 'v1.0.0', 'device_map': 'auto'}
            **kwargs: Extra keyword arguments, forwarded to ``BaseModel``.
        """
        super().__init__(model_cfg=model_cfg, **kwargs)

    @abstractmethod
    def predict(self, inputs: dict, **kwargs) -> dict:
        """
        Model prediction func. The inputs and outputs are compatible with OpenAI Chat Completions APIs.
        Refer to: https://platform.openai.com/docs/guides/gpt/chat-completions-api
        # TODO: follow latest OpenAI API

        Args:
            inputs (dict): The input prompts and history. Input format:
                {'messages': [
                    {'role': 'system', 'content': 'You are a helpful assistant.'},
                    {'role': 'user', 'content': 'Who won the world series in 2020?'},
                    {'role': 'assistant', 'content': 'The Los Angeles Dodgers won the World Series in 2020.'},
                 ]
                 'history': [
                    {'role': 'system', 'content': 'Hello'},
                    {'role': 'user', 'content': 'Hi'}]
                }
            kwargs (dict): Could be inference configuration. Default: None.
                cfg format: {'max_length': 1024}

        Returns: The result format:
            {
                'choices': [
                    {
                        'index': 0,
                        'message': {
                            'content': 'The 2020 World Series was played in Texas at Globe Life Field in Arlington.',
                            'role': 'assistant'
                        }
                    }
                ],
                'created': 1677664795,
                # For models on the ModelScope or HuggingFace, concat model_id and revision with "-".
                'model': 'gpt-3.5-turbo-0613',
                'object': 'chat.completion',
                'usage': {
                    'completion_tokens': 17,
                    'prompt_tokens': 57,
                    'total_tokens': 74
                }
            }
        """
        raise NotImplementedError
# TODO: Remove this class after refactoring all models
class OpenAIModel(ChatBaseModel):
    """
    APIs of OpenAI models.
    Available models: gpt-3.5-turbo, gpt-4
    """

    # Number of attempts made against the OpenAI API before giving up.
    MAX_RETRIES = 3

    def __init__(self, model_cfg: dict, **kwargs):
        """
        Args:
            model_cfg (dict): Model configuration; may carry an 'api_key' entry,
                otherwise the OPENAI_API_KEY environment variable is used.
            **kwargs: Forwarded to ``ChatBaseModel``.
        """
        super().__init__(model_cfg=model_cfg, **kwargs)

        openai_api_key = os.environ.get('OPENAI_API_KEY', None)
        self.api_key = self.model_cfg.get('api_key', openai_api_key)

        if not self.api_key:
            # Logged (not raised) so the object can still be constructed;
            # the API call itself will fail later without a key.
            logger.error('OpenAI API key is not provided, please set it in environment variable OPENAI_API_KEY')

    def predict(self, model_id: str, inputs: dict, **kwargs) -> dict:
        """
        Call the OpenAI API for a single prompt.

        Args:
            model_id (str): OpenAI model name, e.g. 'gpt-3.5-turbo'.
            inputs (dict): Expects keys 'sys_prompt' and 'user_prompt' (both default '').
            **kwargs: Optional 'temperature' (default 0.2), 'max_tokens' (default 1024),
                'mode' (default 'chat.completion').

        Returns:
            dict: {'ans_text': str, 'model_id': str}; empty dict if every retry failed.

        Raises:
            ValueError: If ``mode`` is not 'chat.completion'.
        """
        sys_prompt: str = inputs.get('sys_prompt', '')
        user_prompt: str = inputs.get('user_prompt', '')
        temperature: float = kwargs.pop('temperature', 0.2)
        max_tokens: int = kwargs.pop('max_tokens', 1024)
        mode: str = kwargs.pop('mode', 'chat.completion')

        logger.info(f'Using OpenAI model_id: {model_id}')
        res = self._predict(
            model_id=model_id,
            sys_prompt=sys_prompt,
            user_prompt=user_prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            mode=mode)
        return res

    def _predict(
        self,
        model_id,
        sys_prompt,
        user_prompt,
        temperature,
        max_tokens,
        mode: str = 'chat.completion',
    ) -> dict:
        import openai

        # Validate the mode up front: an unsupported mode is a programming
        # error, not a transient API failure, so it must not be swallowed by
        # the retry loop below (previously it was retried 3 times and lost).
        if mode != 'chat.completion':
            raise ValueError(f'Invalid mode: {mode}')

        res = {}
        openai.api_key = self.api_key

        for attempt in range(self.MAX_RETRIES):
            try:
                resp = openai.ChatCompletion.create(
                    model=model_id,
                    messages=[{
                        'role': 'system',
                        'content': sys_prompt
                    }, {
                        'role': 'user',
                        'content': user_prompt
                    }],
                    temperature=temperature,
                    max_tokens=max_tokens)

                if resp:
                    ans_text = resp['choices'][0]['message']['content']
                    model_id = resp['model']
                else:
                    logger.warning(f'OpenAI GPT API call failed: got empty response '
                                   f'for input {sys_prompt} {user_prompt}')
                    ans_text = ''
                    model_id = ''

                res['ans_text'] = ans_text
                res['model_id'] = model_id
                return res
            except Exception as e:
                logger.warning(f'OpenAI API call failed: {e}')
                # Back off before retrying, but not after the final attempt —
                # sleeping there only delays the failure return.
                if attempt < self.MAX_RETRIES - 1:
                    time.sleep(3)

        logger.error(f'OpenAI API call failed after {self.MAX_RETRIES} retries')
        return res