# evalscope/models/adapters/t2i_adapter.py
import os
import time
import torch
from typing import Any, Dict, List, Optional, Tuple, Union
from evalscope.utils.chat_service import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage
from evalscope.utils.io_utils import OutputsStructure
from evalscope.utils.logger import get_logger
from ..local_model import LocalModel
from .base_adapter import BaseModelAdapter
logger = get_logger()
class T2IModelAdapter(BaseModelAdapter):
    """
    Text-to-image model adapter.

    Wraps a local text-to-image (diffusion-style) model, saves every generated
    image under the task's predictions directory, and returns OpenAI-style
    chat-completion payloads whose message content is the saved image path.
    """

    def __init__(self, model: LocalModel, **kwargs):
        """
        Args:
            model: The local text-to-image model to wrap.
            **kwargs: Must contain ``task_cfg``, the task configuration
                providing ``work_dir`` and ``model_id`` used to build the
                image output directory.

        Raises:
            ValueError: If ``task_cfg`` is missing or ``None``.
        """
        super().__init__(model)
        self.task_config = kwargs.get('task_cfg', None)
        # Explicit raise instead of `assert`: asserts are stripped when Python
        # runs with -O, which would silently skip this required validation.
        if self.task_config is None:
            raise ValueError('Task config is required for T2I model adapter.')
        self.save_path = os.path.join(self.task_config.work_dir, OutputsStructure.PREDICTIONS_DIR,
                                      self.task_config.model_id, 'images')
        os.makedirs(self.save_path, exist_ok=True)

    def _model_generate(self, prompt, infer_cfg: Optional[dict] = None) -> List:
        """
        Generate images from the model.

        Args:
            prompt: The input prompt.
            infer_cfg: The inference configuration, forwarded to the
                underlying model call as keyword arguments.

        Returns:
            The list of generated images (the ``.images`` attribute of the
            model's output).
        """
        infer_cfg = infer_cfg or {}
        sample = self.model(prompt=prompt, **infer_cfg).images
        return sample

    @torch.no_grad()
    def predict(self, inputs: List[dict], infer_cfg: Optional[dict] = None) -> List[dict]:
        """
        Generate and save images for each input item.

        Args:
            inputs: Input items; each carries the prompt at
                ``item['data'][0]`` and an ``'id'`` (falling back to
                ``'index'``) used to name the saved image files.
            infer_cfg: The inference configuration.

        Returns:
            One ``ChatCompletionResponse`` dict per input item; each choice's
            message content is the file path of a saved image.
        """
        results = []
        for input_item in inputs:
            prompt = input_item['data'][0]
            # Explicit None check: `get('id') or get('index')` would wrongly
            # fall through to 'index' for falsy-but-valid ids such as 0 or ''.
            image_id = input_item.get('id')
            if image_id is None:
                image_id = input_item.get('index')
            samples = self._model_generate(prompt, infer_cfg)

            choices_list = []
            for index, sample in enumerate(samples):
                image_file_path = os.path.join(self.save_path, f'{image_id}_{index}.jpeg')
                sample.save(image_file_path)
                logger.debug(f'Saved image to {image_file_path}')
                choice = ChatCompletionResponseChoice(
                    index=index, message=ChatMessage(content=image_file_path, role='assistant'), finish_reason='stop')
                choices_list.append(choice)

            res_d = ChatCompletionResponse(
                model=self.model_id, choices=choices_list, object='images.generations',
                created=int(time.time())).model_dump(exclude_unset=True)
            results.append(res_d)
        return results