import os
import time
import torch
from typing import Any, Dict, List, Optional, Tuple, Union

from evalscope.utils.chat_service import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, Usage
from evalscope.utils.logger import get_logger
from evalscope.utils.model_utils import fix_do_sample_warning
from ..local_model import LocalModel
from .base_adapter import BaseModelAdapter

logger = get_logger()

class ChatGenerationModelAdapter(BaseModelAdapter):
    """
    Chat generation model adapter.
    """

    def __init__(self, model: LocalModel, **kwargs):
        super().__init__(model)

        self.generation_config = self._parse_generation_config(self.tokenizer, self.model)

        custom_generation_config = kwargs.pop('generation_config', None)
        custom_chat_template = kwargs.pop('chat_template', None)

        if custom_generation_config:
            logger.info('Updating generation config ...')
            self.generation_config.update(**custom_generation_config)

        if custom_chat_template:
            self.tokenizer.chat_template = custom_chat_template
            logger.info(f'Using custom chat template: {custom_chat_template}')

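    # Usage note (values below are illustrative, not exhaustive): `generation_config`
    # passed to __init__ should be a mapping of GenerationConfig fields, since it is
    # unpacked via `self.generation_config.update(**custom_generation_config)`, e.g.
    # {'max_new_tokens': 512, 'temperature': 0.7, 'do_sample': True}.
    # `chat_template` should be a Jinja chat-template string; it is assigned directly
    # to `tokenizer.chat_template` and later used by `apply_chat_template`.
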
    def _parse_generation_config(self, tokenizer, model):
        from modelscope import GenerationConfig

        # Start from the config attached to the model, falling back to greedy decoding.
        generation_config = getattr(model, 'generation_config', GenerationConfig(do_sample=False))

        try:
            # Prefer the generation config published on the model hub, if available.
            remote_config = GenerationConfig.from_pretrained(
                self.model_id, revision=self.model_revision, trust_remote_code=True)
            generation_config.update(**remote_config.to_dict())
        except Exception:
            logger.warning(f'Failed to get generation config of {self.model_id} from model hub, use default.')

        if isinstance(self.model_id, str) and os.path.exists(self.model_id):
            logger.warning(f'Got local model dir: {self.model_id}')

        # Fill in sensible defaults from the tokenizer.
        if tokenizer.eos_token_id is not None:
            generation_config.eos_token_id = tokenizer.eos_token_id
        if tokenizer.pad_token_id is not None:
            generation_config.pad_token_id = tokenizer.pad_token_id
        if generation_config.max_new_tokens is None:
            generation_config.max_new_tokens = 2048

        return generation_config

    def _model_generate(self,
                        formatted_prompts: List[str],
                        infer_cfg: Dict[str, Any] = None) -> Tuple[List[List[str]], List[int]]:
        """
        Args:
            formatted_prompts: The formatted prompts.
            infer_cfg: The inference configuration.
        Returns:
            A tuple of (responses, input_lengths): for each prompt, a list of
            `num_return_sequences` decoded completions, plus the prompt length in tokens.
        """
        if infer_cfg is None:
            infer_cfg = {}

        # Process infer_cfg
        num_return_sequences = infer_cfg.get('num_return_sequences', 1)
        if num_return_sequences > 1:
            # Returning multiple sequences per prompt requires sampling.
            infer_cfg['do_sample'] = True

        # stop settings
        stop = infer_cfg.get('stop', [])
        if stop:
            eos_token_id = self.tokenizer.encode(stop, add_special_tokens=False)[0]
        else:
            eos_token_id = self.tokenizer.eos_token_id

        if eos_token_id is not None:
            infer_cfg['eos_token_id'] = eos_token_id

        self.generation_config.update(**infer_cfg)
        fix_do_sample_warning(self.generation_config)

        # Get input ids
        inputs = self.tokenizer(
            formatted_prompts, return_tensors='pt', padding=True, truncation=True,
            padding_side='left').to(self.model.device)  # padding_side='left' is important for chat model
        input_ids = inputs['input_ids']

        # Run inference
        output_ids = self.model.generate(**inputs, generation_config=self.generation_config)

        # Decode output
        responses = []
        input_lengths = [len(self.tokenizer.encode(prompt)) for prompt in formatted_prompts]
        for i in range(0, len(output_ids), num_return_sequences):
            query_responses = []
            for j in range(num_return_sequences):
                output = output_ids[i + j]
                # Strip the (left-padded) prompt portion and decode only the generated tokens.
                response = self.tokenizer.decode(
                    output[len(input_ids[i // num_return_sequences]):], skip_special_tokens=True)
                query_responses.append(response)
            responses.append(query_responses)

        return responses, input_lengths

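    # Shape note (illustrative): with two prompts and num_return_sequences=2,
    # `responses` looks like [['a1', 'a2'], ['b1', 'b2']] and `input_lengths`
    # holds the token count of each encoded prompt, e.g. [57, 64].
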
    def _prepare_inputs(self, inputs: List[dict], infer_cfg: dict = {}) -> List[str]:
        """
        Prepare the inputs for the model.
        Args:
            inputs: The input data.
            infer_cfg: The inference configuration.
        Returns:
            The prepared prompt strings, one per input item: raw queries for
            non-chat models, chat-templated prompts otherwise.
        """
        queries = []
        system_prompts = []
        message_list = []

        for input_item in inputs:
            queries.append(input_item['data'][0])
            system_prompts.append(input_item.get('system_prompt', None))
            if input_item.get('messages', None):
                message_list.append(input_item.get('messages', None))

        # For non chat model, use the original queries as the input
        if self.tokenizer.chat_template is None:
            return queries

        # For chat model, use the messages as the input.
        # If message_list is empty, build messages from the queries and system prompts.
        if len(message_list) == 0:
            for i, query in enumerate(queries):
                messages = [ChatMessage(role='user', content=query)]
                if i < len(system_prompts) and system_prompts[i]:
                    messages = [ChatMessage(role='system', content=system_prompts[i])] + messages
                message_list.append(messages)

        # Format the messages
        formatted_prompts = []
        for messages in message_list:
            # apply chat template
            chat_template_kwargs = infer_cfg.get('chat_template_kwargs', None)
            if chat_template_kwargs is not None:
                prompts = self.tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True, **chat_template_kwargs)
            else:
                prompts = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            formatted_prompts.append(prompts)

        logger.debug(f'formatted_prompts: {formatted_prompts}')
        return formatted_prompts

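    # Expected shape of each item in `inputs` (illustrative; only 'data' is required by
    # the code above, 'system_prompt' and 'messages' are optional):
    #   {'data': ['What is the capital of France?'],
    #    'system_prompt': 'You are a helpful assistant.'}
    # For chat models, if 'messages' is present it is passed to `apply_chat_template`
    # as-is and takes precedence over 'data'/'system_prompt'.
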
    @torch.no_grad()
    def predict(self, inputs: List[dict], infer_cfg: Optional[dict] = {}) -> List[dict]:
        """
        Args:
            inputs: The input data.
            infer_cfg: The inference configuration.
        Returns:
            The prediction results: one ChatCompletionResponse dict
            (chat.completion format) per input item.
        """

        # Process inputs
        formatted_prompts = self._prepare_inputs(inputs, infer_cfg)

        # Run inference
        responses, input_lengths = self._model_generate(formatted_prompts, infer_cfg)

        # Process outputs
        results = []
        for response, input_length in zip(responses, input_lengths):
            choices_list = []
            completion_tokens = 0

            for index, one_response in enumerate(response):
                choice = ChatCompletionResponseChoice(
                    index=index, message=ChatMessage(content=one_response, role='assistant'), finish_reason='stop')
                choices_list.append(choice)

                completion_tokens += len(self.tokenizer.encode(one_response))

            usage = Usage(
                prompt_tokens=input_length,
                completion_tokens=completion_tokens,
                total_tokens=input_length + completion_tokens)

            res_d = ChatCompletionResponse(
                model=self.model_id,
                choices=choices_list,
                object='chat.completion',
                created=int(time.time()),
                usage=usage).model_dump(exclude_unset=True)

            results.append(res_d)

        return results
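

# Illustrative usage sketch (assumptions: `local_model` stands in for an already
# constructed evalscope `LocalModel`, whose loading call is not shown in this file;
# the config values below are placeholders):
#
#     adapter = ChatGenerationModelAdapter(
#         model=local_model,
#         generation_config={'max_new_tokens': 256, 'do_sample': True, 'temperature': 0.7},
#     )
#     outputs = adapter.predict(
#         inputs=[{'data': ['Briefly explain beam search.'],
#                  'system_prompt': 'You are a helpful assistant.'}],
#         infer_cfg={'num_return_sequences': 1},
#     )
#     print(outputs[0]['choices'][0]['message']['content'])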