# EverlyAI LLM integration for llama_index.
from typing import Any, Callable, Dict, Optional, Sequence
|
|
|
|
from llama_index.callbacks import CallbackManager
|
|
from llama_index.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE
|
|
from llama_index.core.llms.types import ChatMessage, LLMMetadata
|
|
from llama_index.llms.everlyai_utils import everlyai_modelname_to_contextsize
|
|
from llama_index.llms.generic_utils import get_from_param_or_env
|
|
from llama_index.llms.openai import OpenAI
|
|
from llama_index.types import BaseOutputParser, PydanticProgramMode
|
|
|
|
# Base URL of EverlyAI's OpenAI-compatible hosted inference endpoint;
# passed to OpenAI.__init__ as api_base so all requests target EverlyAI.
EVERLYAI_API_BASE = "https://everlyai.xyz/hosted"

# Model used when the caller does not specify one explicitly.
DEFAULT_MODEL = "meta-llama/Llama-2-7b-chat-hf"
|
|
|
|
|
|
class EverlyAI(OpenAI):
    """LLM wrapper for EverlyAI's hosted, OpenAI-compatible API.

    A thin subclass of ``OpenAI`` that pins the client to the EverlyAI
    endpoint (``EVERLYAI_API_BASE``) and resolves the API key from the
    ``EverlyAI_API_KEY`` environment variable when one is not passed
    explicitly.
    """

    def __init__(
        self,
        model: str = DEFAULT_MODEL,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_NUM_OUTPUTS,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        max_retries: int = 10,
        api_key: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        """Create an EverlyAI client.

        Args:
            model: Hosted model identifier (defaults to ``DEFAULT_MODEL``).
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Maximum tokens to generate per completion.
            additional_kwargs: Extra keyword arguments for the underlying API.
            max_retries: Number of request retries before giving up.
            api_key: EverlyAI API key; falls back to the ``EverlyAI_API_KEY``
                environment variable when omitted.
            callback_manager: Callback manager for instrumentation hooks.
            system_prompt: Optional system prompt prepended to requests.
            messages_to_prompt: Optional converter from chat messages to a
                single prompt string.
            completion_to_prompt: Optional converter from a completion prompt
                to the final prompt string.
            pydantic_program_mode: Structured-output program mode.
            output_parser: Optional parser applied to raw model output.
        """
        # Substitute fresh defaults for the optional collaborators.
        if additional_kwargs is None:
            additional_kwargs = {}
        if callback_manager is None:
            callback_manager = CallbackManager([])

        # Prefer the explicit argument; otherwise read the environment.
        api_key = get_from_param_or_env("api_key", api_key, "EverlyAI_API_KEY")

        super().__init__(
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            # Always route traffic to EverlyAI's hosted endpoint.
            api_base=EVERLYAI_API_BASE,
            api_key=api_key,
            additional_kwargs=additional_kwargs,
            max_retries=max_retries,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the serialization identifier for this LLM class."""
        return "EverlyAI_LLM"

    @property
    def metadata(self) -> LLMMetadata:
        """Describe this model: context window, output budget, chat mode."""
        context_window = everlyai_modelname_to_contextsize(self.model)
        return LLMMetadata(
            context_window=context_window,
            num_output=self.max_tokens,
            is_chat_model=True,
            model_name=self.model,
        )

    @property
    def _is_chat_model(self) -> bool:
        # EverlyAI-hosted models are always driven via the chat endpoint.
        return True
|