51 lines
1.4 KiB
Python
51 lines
1.4 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import torch
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Dict, List, Union
|
|
|
|
|
|
class CustomModel(ABC):
    """Abstract base class for user-defined models.

    ``predict`` is abstract, so this class cannot be instantiated directly;
    concrete subclasses must implement it.

    Args:
        config (dict): Model configuration. Stored unchanged on ``self.config``.
        **kwargs: Extra keyword arguments. Stored unchanged on ``self.kwargs``.
    """

    def __init__(self, config: dict, **kwargs):
        self.config = config
        self.kwargs = kwargs

    @abstractmethod
    @torch.no_grad()
    def predict(self, prompts: List[str], **kwargs) -> List[Dict[str, Any]]:
        """Model prediction function for a batch of prompts.

        NOTE: ``@torch.no_grad()`` here wraps only this abstract stub. A
        subclass override replaces the attribute entirely, so the decorator is
        NOT inherited. Implementations that want autograd disabled should
        apply ``@torch.no_grad()`` to their own ``predict``.

        Args:
            prompts (List[str]): The input batch of prompts to predict.
            **kwargs: Extra, implementation-specific options.

        Returns:
            List[Dict[str, Any]]: The prediction results for the batch.
            Presumably there is one dict per prompt, in input order; confirm
            against implementations. Each element mirrors an OpenAI-style
            ``chat.completion`` object::

                [
                    {
                        'choices': [
                            {
                                'index': 0,
                                'message': {
                                    'content': 'xxx',
                                    'role': 'assistant'
                                }
                            }
                        ],
                        'created': 1677664795,
                        'model': 'gpt-3.5-turbo-0613',  # should be model_id
                        'object': 'chat.completion',
                        'usage': {
                            'completion_tokens': 17,
                            'prompt_tokens': 57,
                            'total_tokens': 74
                        }
                    },
                    ...
                ]

        Raises:
            NotImplementedError: Always, when this base implementation is
                called (e.g. via ``super().predict(...)``).
        """
        raise NotImplementedError
|