from typing import Optional

from typing_extensions import NotRequired, TypedDict

from llama_index.core.llms.types import ChatMessage

# Known context window sizes (in tokens) for models served by Xinference.
XINFERENCE_MODEL_SIZES = {
    "baichuan": 2048,
    "baichuan-chat": 2048,
    "wizardlm-v1.0": 2048,
    "vicuna-v1.3": 2048,
    "orca": 2048,
    "chatglm": 2048,
    "chatglm2": 8192,
    "llama-2-chat": 4096,
    "llama-2": 4096,
}


class ChatCompletionMessage(TypedDict):
    """A single chat message as passed to Xinference's chat completion endpoint."""

    role: str
    content: Optional[str]
    user: NotRequired[str]


def xinference_message_to_history(message: ChatMessage) -> ChatCompletionMessage:
    """Convert a llama-index ChatMessage into an Xinference ChatCompletionMessage."""
    return ChatCompletionMessage(role=message.role, content=message.content)


def xinference_modelname_to_contextsize(modelname: str) -> int:
    """Return the context window size for a known Xinference model name."""
    context_size = XINFERENCE_MODEL_SIZES.get(modelname, None)

    if context_size is None:
        raise ValueError(
            f"Unknown model: {modelname}. Please provide a valid Xinference model name. "
            "Known models are: " + ", ".join(XINFERENCE_MODEL_SIZES.keys())
        )

    return context_size
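

# --- Illustrative usage (not part of the original module) ---
# A minimal sketch of how the helpers above might be exercised. It assumes
# that ChatMessage accepts `role` and `content` keyword arguments; adjust to
# match the llama-index version in use.
if __name__ == "__main__":
    message = ChatMessage(role="user", content="Hello, Xinference!")
    print(xinference_message_to_history(message))
    print(xinference_modelname_to_contextsize("chatglm2"))  # expected: 8192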