import os
import time
from typing import Any, Dict, List, Optional, Tuple, Union

import torch

from evalscope.utils.chat_service import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage
from evalscope.utils.io_utils import OutputsStructure
from evalscope.utils.logger import get_logger

from ..local_model import LocalModel
from .base_adapter import BaseModelAdapter

logger = get_logger()
|
class T2IModelAdapter(BaseModelAdapter):
    """
    Text-to-image model adapter.

    Runs a local text-to-image pipeline over prompt inputs, saves each
    generated image to disk under the task's predictions directory, and
    returns OpenAI-style chat-completion dicts whose message content is the
    saved image file path.
    """

    def __init__(self, model: LocalModel, **kwargs):
        """
        Args:
            model: The wrapped local model; forwarded to the base adapter.
            **kwargs: Must contain ``task_cfg`` — a task config object with
                ``work_dir`` and ``model_id`` attributes, used to build the
                image output directory.

        Raises:
            ValueError: If ``task_cfg`` is missing from ``kwargs``.
        """
        super().__init__(model)

        self.task_config = kwargs.get('task_cfg', None)
        # Raise explicitly rather than `assert`: assertions are stripped
        # under `python -O`, which would silently skip this validation.
        if self.task_config is None:
            raise ValueError('Task config is required for T2I model adapter.')

        # Images go to <work_dir>/<predictions_dir>/<model_id>/images
        self.save_path = os.path.join(self.task_config.work_dir, OutputsStructure.PREDICTIONS_DIR,
                                      self.task_config.model_id, 'images')
        os.makedirs(self.save_path, exist_ok=True)

    def _model_generate(self, prompt, infer_cfg=None) -> List:
        """
        Generate images from the model.

        Args:
            prompt: The input text prompt.
            infer_cfg: Optional inference configuration, passed through to the
                pipeline as keyword arguments.

        Returns:
            The generated images (the pipeline output's ``.images`` list;
            presumably PIL images, since each element supports ``.save()``).
        """
        infer_cfg = infer_cfg or {}

        sample = self.model(prompt=prompt, **infer_cfg).images
        return sample

    @torch.no_grad()
    def predict(self, inputs: List[dict], infer_cfg: Optional[dict] = None) -> List[dict]:
        """
        Generate and persist images for each input item.

        Args:
            inputs: Input items; each is a dict with the prompt at
                ``item['data'][0]`` and an identifier under ``'id'`` or
                ``'index'``.
            infer_cfg: Optional inference configuration forwarded to the model.

        Returns:
            One ChatCompletionResponse dict per input item, with each choice's
            message content set to a saved image file path.
        """
        results = []
        for input_item in inputs:
            prompt = input_item['data'][0]
            # Explicit `is None` check: `get('id') or get('index')` would
            # wrongly discard a falsy id such as 0 and fall back to 'index'.
            image_id = input_item.get('id')
            if image_id is None:
                image_id = input_item.get('index')

            samples = self._model_generate(prompt, infer_cfg)

            choices_list = []
            for index, sample in enumerate(samples):
                image_file_path = os.path.join(self.save_path, f'{image_id}_{index}.jpeg')
                sample.save(image_file_path)
                logger.debug(f'Saved image to {image_file_path}')

                choice = ChatCompletionResponseChoice(
                    index=index, message=ChatMessage(content=image_file_path, role='assistant'), finish_reason='stop')
                choices_list.append(choice)

            res_d = ChatCompletionResponse(
                model=self.model_id, choices=choices_list, object='images.generations',
                created=int(time.time())).model_dump(exclude_unset=True)

            results.append(res_d)

        return results