# evalscope_v0.17.0/evalscope.0.17.0/evalscope/models/adapters/chat_adapter.py


import os
import time
import torch
from typing import Any, Dict, List, Optional, Tuple, Union
from evalscope.utils.chat_service import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, Usage
from evalscope.utils.logger import get_logger
from evalscope.utils.model_utils import fix_do_sample_warning
from ..local_model import LocalModel
from .base_adapter import BaseModelAdapter
logger = get_logger()


class ChatGenerationModelAdapter(BaseModelAdapter):
    """
    Chat generation model adapter.
    """

    def __init__(self, model: LocalModel, **kwargs):
        super().__init__(model)

        self.generation_config = self._parse_generation_config(self.tokenizer, self.model)

        custom_generation_config = kwargs.pop('generation_config', None)
        custom_chat_template = kwargs.pop('chat_template', None)

        if custom_generation_config:
            logger.info('Updating generation config ...')
            self.generation_config.update(**custom_generation_config)

        if custom_chat_template:
            self.tokenizer.chat_template = custom_chat_template
            logger.info(f'Using custom chat template: {custom_chat_template}')
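
    # Example: optional overrides accepted by the constructor (a minimal sketch;
    # the concrete values below are illustrative, not library defaults):
    #
    #   adapter = ChatGenerationModelAdapter(
    #       model=local_model,
    #       generation_config={'max_new_tokens': 512, 'do_sample': True, 'temperature': 0.7},
    #       chat_template=my_jinja_template_str,  # hypothetical Jinja chat-template string
    #   )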

    def _parse_generation_config(self, tokenizer, model):
        from modelscope import GenerationConfig

        generation_config = getattr(model, 'generation_config', GenerationConfig(do_sample=False))

        try:
            remote_config = GenerationConfig.from_pretrained(
                self.model_id, revision=self.model_revision, trust_remote_code=True)
            generation_config.update(**remote_config.to_dict())
        except Exception:
            logger.warning(f'Failed to get generation config of {self.model_id} from model hub, use default.')
            if isinstance(self.model_id, str) and os.path.exists(self.model_id):
                logger.warning(f'Got local model dir: {self.model_id}')

        if tokenizer.eos_token_id is not None:
            generation_config.eos_token_id = tokenizer.eos_token_id
        if tokenizer.pad_token_id is not None:
            generation_config.pad_token_id = tokenizer.pad_token_id
        if generation_config.max_new_tokens is None:
            generation_config.max_new_tokens = 2048

        return generation_config
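
    # Resolution order applied above, summarized: the config attached to the model ->
    # the config fetched from the model hub (when available) -> the tokenizer's
    # eos/pad token ids -> a 2048 max_new_tokens fallback.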

    def _model_generate(self,
                        formatted_prompts: List[str],
                        infer_cfg: Optional[Dict[str, Any]] = None) -> Tuple[List[List[str]], List[int]]:
        """
        Args:
            formatted_prompts: The formatted prompts.
            infer_cfg: The inference configuration.

        Returns:
            A tuple of (responses per prompt, prompt token lengths).
        """
        if infer_cfg is None:
            infer_cfg = {}

        # Process infer_cfg
        num_return_sequences = infer_cfg.get('num_return_sequences', 1)
        if num_return_sequences > 1:
            infer_cfg['do_sample'] = True

        # Stop settings: use the first token of the stop string as the eos token
        stop = infer_cfg.get('stop', [])
        if stop:
            stop_str = stop[0] if isinstance(stop, (list, tuple)) else stop
            eos_token_id = self.tokenizer.encode(stop_str, add_special_tokens=False)[0]
        else:
            eos_token_id = self.tokenizer.eos_token_id
        if eos_token_id is not None:
            infer_cfg['eos_token_id'] = eos_token_id

        self.generation_config.update(**infer_cfg)
        fix_do_sample_warning(self.generation_config)

        # Get input ids
        inputs = self.tokenizer(
            formatted_prompts, return_tensors='pt', padding=True, truncation=True,
            padding_side='left').to(self.model.device)  # padding_side='left' is important for chat models
        input_ids = inputs['input_ids']

        # Run inference
        output_ids = self.model.generate(**inputs, generation_config=self.generation_config)

        # Decode output
        responses = []
        input_lengths = [len(self.tokenizer.encode(prompt)) for prompt in formatted_prompts]
        for i in range(0, len(output_ids), num_return_sequences):
            query_responses = []
            for j in range(num_return_sequences):
                output = output_ids[i + j]
                # Strip the (left-padded) prompt tokens and decode only the generated continuation
                response = self.tokenizer.decode(
                    output[len(input_ids[i // num_return_sequences]):], skip_special_tokens=True)
                query_responses.append(response)
            responses.append(query_responses)

        return responses, input_lengths
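
    # Example `infer_cfg` for `_model_generate` (illustrative values; keys other than
    # 'stop' and 'num_return_sequences' are forwarded to the generation config, so only
    # valid generation fields apply):
    #
    #   infer_cfg = {
    #       'max_new_tokens': 256,
    #       'do_sample': True,
    #       'temperature': 0.7,
    #       'num_return_sequences': 2,
    #       'stop': ['\n\n'],
    #   }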

    def _prepare_inputs(self, inputs: List[dict], infer_cfg: Optional[dict] = None) -> List[str]:
        """
        Prepare the inputs for the model.

        Args:
            inputs: The input data.
            infer_cfg: The inference configuration.

        Returns:
            The formatted prompt strings.
        """
        infer_cfg = infer_cfg or {}

        queries = []
        system_prompts = []
        message_list = []
        for input_item in inputs:
            queries.append(input_item['data'][0])
            system_prompts.append(input_item.get('system_prompt', None))
            if input_item.get('messages', None):
                message_list.append(input_item.get('messages', None))

        # For a non-chat model, use the original queries as the input
        if self.tokenizer.chat_template is None:
            return queries

        # For a chat model, use the messages as the input;
        # if message_list is empty, build messages from the queries
        if len(message_list) == 0:
            for i, query in enumerate(queries):
                messages = [ChatMessage(role='user', content=query)]
                if i < len(system_prompts) and system_prompts[i]:
                    messages = [ChatMessage(role='system', content=system_prompts[i])] + messages
                message_list.append(messages)

        # Format the messages by applying the chat template
        formatted_prompts = []
        for messages in message_list:
            chat_template_kwargs = infer_cfg.get('chat_template_kwargs', None)
            if chat_template_kwargs is not None:
                prompts = self.tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True, **chat_template_kwargs)
            else:
                prompts = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            formatted_prompts.append(prompts)

        logger.debug(f'formatted_prompts: {formatted_prompts}')
        return formatted_prompts
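
    # Example input item consumed by `_prepare_inputs` (fields as read above;
    # 'system_prompt' and 'messages' are optional, and 'messages' takes precedence
    # over 'data' for chat models):
    #
    #   {
    #       'data': ['What is the capital of France?'],
    #       'system_prompt': 'You are a helpful assistant.',
    #       'messages': [{'role': 'user', 'content': 'What is the capital of France?'}],
    #   }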

    @torch.no_grad()
    def predict(self, inputs: List[dict], infer_cfg: Optional[dict] = None) -> List[dict]:
        """
        Args:
            inputs: The input data.
            infer_cfg: The inference configuration.

        Returns:
            The prediction results, one chat-completion-style dict per input.
        """
        # Process inputs
        formatted_prompts = self._prepare_inputs(inputs, infer_cfg)

        # Run inference
        responses, input_lengths = self._model_generate(formatted_prompts, infer_cfg)

        # Process outputs
        results = []
        for response, input_length in zip(responses, input_lengths):
            choices_list = []
            completion_tokens = 0
            for index, one_response in enumerate(response):
                choice = ChatCompletionResponseChoice(
                    index=index, message=ChatMessage(content=one_response, role='assistant'), finish_reason='stop')
                choices_list.append(choice)
                completion_tokens += len(self.tokenizer.encode(one_response))

            usage = Usage(
                prompt_tokens=input_length,
                completion_tokens=completion_tokens,
                total_tokens=input_length + completion_tokens)

            res_d = ChatCompletionResponse(
                model=self.model_id,
                choices=choices_list,
                object='chat.completion',
                created=int(time.time()),
                usage=usage).model_dump(exclude_unset=True)
            results.append(res_d)

        return results
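

# Example usage (a minimal sketch; the LocalModel constructor arguments shown here
# are assumptions and may differ from the actual evalscope API):
#
#   from evalscope.models.local_model import LocalModel
#
#   local_model = LocalModel(model_id='Qwen/Qwen2.5-0.5B-Instruct')  # hypothetical arguments
#   adapter = ChatGenerationModelAdapter(model=local_model)
#   results = adapter.predict(
#       inputs=[{'data': ['Write a haiku about autumn.']}],
#       infer_cfg={'max_new_tokens': 64, 'do_sample': True, 'temperature': 0.7},
#   )
#   print(results[0]['choices'][0]['message']['content'])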