# faiss_rag_enterprise/llama_index/llms/monsterapi.py

from typing import Any, Callable, Dict, Optional, Sequence
from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
from llama_index.core.llms.types import (
    ChatMessage,
    ChatResponse,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.llms.base import llm_chat_callback, llm_completion_callback
from llama_index.llms.custom import CustomLLM
from llama_index.llms.generic_utils import completion_response_to_chat_response
from llama_index.types import BaseOutputParser, PydanticProgramMode

DEFAULT_MONSTER_TEMP = 0.75


class MonsterLLM(CustomLLM):
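    """Interface to LLMs hosted on MonsterAPI.

    Wraps the ``monsterapi`` client and exposes it through the
    llama_index ``CustomLLM`` interface: ``complete`` and ``chat`` for
    one-shot generation, and ``stream_complete`` for deploy models.
    """
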
    model: str = Field(description="The MonsterAPI model to use.")
    monster_api_key: Optional[str] = Field(description="The MonsterAPI key to use.")
    max_new_tokens: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The number of tokens to generate.",
        gt=0,
    )
    temperature: float = Field(
        default=DEFAULT_MONSTER_TEMP,
        description="The temperature to use for sampling.",
        ge=0.0,
        le=1.0,
    )
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The number of context tokens available to the LLM.",
        gt=0,
    )

    _client: Any = PrivateAttr()

    def __init__(
        self,
        model: str,
        base_url: str = "https://api.monsterapi.ai/v1",
        monster_api_key: Optional[str] = None,
        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
        temperature: float = DEFAULT_MONSTER_TEMP,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> None:
        self._client, available_llms = self.initialize_client(
            monster_api_key, base_url
        )
        # Check whether the requested model is supported.
        if model not in available_llms:
            error_message = (
                f"Model: {model} is not supported. "
                f"Supported models are {available_llms}. "
                "Newer models may require upgrading the client: "
                "`pip install --upgrade monsterapi`."
            )
            raise RuntimeError(error_message)
        super().__init__(
            model=model,
            monster_api_key=monster_api_key,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            context_window=context_window,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
            output_parser=output_parser,
        )

    def initialize_client(
        self, monster_api_key: Optional[str], base_url: Optional[str]
    ) -> Any:
        """Import the ``monsterapi`` client and list the enabled LLM models."""
        try:
            from monsterapi import client as MonsterClient
            from monsterapi.InputDataModels import MODEL_TYPES
        except ImportError:
            raise ImportError(
                "Could not import the MonsterAPI client library. "
                "Please install it with `pip install monsterapi`."
            )
        # Only models registered with type "LLM" can be used with this class.
        llm_models_enabled = [i for i, j in MODEL_TYPES.items() if j == "LLM"]
        return MonsterClient(monster_api_key, base_url), llm_models_enabled

    @classmethod
    def class_name(cls) -> str:
        return "MonsterLLM"

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_new_tokens,
            model_name=self.model,
        )

    def _get_input_dict(self, prompt: str, **kwargs: Any) -> Dict[str, Any]:
        return {
            "prompt": prompt,
            "temperature": self.temperature,
            "max_length": self.max_new_tokens,
            **kwargs,
        }

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        prompt = self.messages_to_prompt(messages)
        completion = self.complete(prompt, formatted=True, **kwargs)
        # complete() returns a CompletionResponse; convert it so callers
        # actually receive the annotated ChatResponse type.
        return completion_response_to_chat_response(completion)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, timeout: int = 100, **kwargs: Any
    ) -> CompletionResponse:
        """Generate a completion for ``prompt`` via MonsterAPI."""
        if not formatted:
            prompt = self.completion_to_prompt(prompt)
        stream = kwargs.pop("stream", False)
        if stream is True:
            raise NotImplementedError(
                "complete() cannot be used with stream=True; "
                "use stream_complete() instead."
            )
        # Validate input args against the input Pydantic model.
        input_dict = self._get_input_dict(prompt, **kwargs)
        result = self._client.generate(
            model=self.model, data=input_dict, timeout=timeout
        )
        if isinstance(result, Exception):
            raise result
        if isinstance(result, dict) and "error" in result:
            raise RuntimeError(result["error"])
        # The client may return {"text": [...]}, {"text": "..."}, or a list
        # of {"text": ...} dicts depending on the model type.
        if isinstance(result, dict) and "text" in result:
            if isinstance(result["text"], list):
                return CompletionResponse(text=result["text"][0])
            elif isinstance(result["text"], str):
                return CompletionResponse(text=result["text"])
        if isinstance(result, list):
            return CompletionResponse(text=result[0]["text"])
        raise RuntimeError("Unexpected response format; please contact MonsterAPI support!")

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        """Stream a completion for ``prompt``; only deploy models are supported."""
        if "deploy" not in self.model:
            raise NotImplementedError(
                "stream_complete() is currently only supported for deploy "
                "models; support for other models will be added soon."
            )
        # Validate input args against the input Pydantic model.
        input_dict = self._get_input_dict(prompt, **kwargs)
        input_dict["stream"] = True
        # Start the stream.
        result_stream = self._client.generate(model=self.model, data=input_dict)
        if isinstance(result_stream, Exception):
            raise result_stream
        if isinstance(result_stream, dict) and "error" in result_stream:
            raise RuntimeError(result_stream["error"])
        # Iterate over the generator; each chunk is a list whose first
        # element is the newly generated text.
        try:
            for result in result_stream:
                yield CompletionResponse(text=result[0])
        except StopIteration:
            pass
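

# A minimal usage sketch, assuming a valid MONSTER_API_KEY environment
# variable. The model name below is an illustrative assumption, not a
# guaranteed MonsterAPI model id; if it is rejected, pick one of the
# supported models listed in the RuntimeError raised by __init__.
if __name__ == "__main__":
    import os

    llm = MonsterLLM(
        model="microsoft/phi-2",  # assumed model id, for illustration only
        monster_api_key=os.environ.get("MONSTER_API_KEY"),
        temperature=0.75,
        max_new_tokens=256,
    )
    response = llm.complete("What is retrieval-augmented generation?")
    print(response.text)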