# faiss_rag_enterprise/llama_index/llms/openllm.py

import asyncio
import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
)

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.core.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.llms.generic_utils import (
    completion_response_to_chat_response,
)
from llama_index.llms.generic_utils import (
    messages_to_prompt as generic_messages_to_prompt,
)
from llama_index.llms.llm import LLM
from llama_index.types import PydanticProgramMode

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from typing import TypeVar

    M = TypeVar("M")
    T = TypeVar("T")

    Metadata = Any


class OpenLLM(LLM):
    """OpenLLM LLM."""
    model_id: str = Field(
        description=(
            "Model ID from the HuggingFace Hub, either a pretrained model ID or a "
            "local path. This is synonymous with the first argument of "
            "HuggingFace's `.from_pretrained`."
        )
    )
    model_version: Optional[str] = Field(
        description="Optional model version to save the model as."
    )
    model_tag: Optional[str] = Field(
        description="Optional tag to save to the BentoML store."
    )
    prompt_template: Optional[str] = Field(
        description="Optional prompt template to pass to this LLM."
    )
    backend: Optional[Literal["vllm", "pt"]] = Field(
        description=(
            "Optional backend to use for this LLM. By default, vLLM is used if it is "
            "available on the local system; otherwise it falls back to PyTorch."
        )
    )
    quantize: Optional[Literal["awq", "gptq", "int8", "int4", "squeezellm"]] = Field(
        description=(
            "Optional quantization method to use with this LLM. See the `--quantize` "
            "option of `openllm start` for more information."
        )
    )
    serialization: Literal["safetensors", "legacy"] = Field(
        description=(
            "Serialization format to save this LLM with. Defaults to 'safetensors', "
            "but falls back to the PyTorch pickle `.bin` format for some models."
        )
    )
    trust_remote_code: bool = Field(
        description=(
            "Optional flag to trust remote code. This is synonymous with "
            "Transformers' `trust_remote_code`. Defaults to False."
        )
    )

    if TYPE_CHECKING:
        from typing import Generic

        try:
            import openllm

            _llm: openllm.LLM[Any, Any]
        except ImportError:
            _llm: Any  # type: ignore[no-redef]
    else:
        _llm: Any = PrivateAttr()

    def __init__(
        self,
        model_id: str,
        model_version: Optional[str] = None,
        model_tag: Optional[str] = None,
        prompt_template: Optional[str] = None,
        backend: Optional[Literal["vllm", "pt"]] = None,
        *args: Any,
        quantize: Optional[Literal["awq", "gptq", "int8", "int4", "squeezellm"]] = None,
        serialization: Literal["safetensors", "legacy"] = "safetensors",
        trust_remote_code: bool = False,
        callback_manager: Optional[CallbackManager] = None,
        system_prompt: Optional[str] = None,
        messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
        completion_to_prompt: Optional[Callable[[str], str]] = None,
        pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
        **attrs: Any,
    ):
        try:
            import openllm
        except ImportError:
            raise ImportError(
                "OpenLLM is not installed. Please install OpenLLM via `pip install openllm`"
            )
        self._llm = openllm.LLM[Any, Any](
            model_id,
            model_version=model_version,
            model_tag=model_tag,
            prompt_template=prompt_template,
            system_message=system_prompt,
            backend=backend,
            quantize=quantize,
            serialisation=serialization,
            trust_remote_code=trust_remote_code,
            embedded=True,
            **attrs,
        )
        if messages_to_prompt is None:
            messages_to_prompt = self._tokenizer_messages_to_prompt

        # NOTE: We need to do this here to ensure model is saved and revision is set correctly.
        assert self._llm.bentomodel

        super().__init__(
            model_id=model_id,
            model_version=self._llm.revision,
            model_tag=str(self._llm.tag),
            prompt_template=prompt_template,
            backend=self._llm.__llm_backend__,
            quantize=self._llm.quantise,
            serialization=self._llm._serialisation,
            trust_remote_code=self._llm.trust_remote_code,
            callback_manager=callback_manager,
            system_prompt=system_prompt,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            pydantic_program_mode=pydantic_program_mode,
        )

    @classmethod
    def class_name(cls) -> str:
        return "OpenLLM"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata."""
        return LLMMetadata(
            num_output=self._llm.config["max_new_tokens"],
            model_name=self.model_id,
        )

    def _tokenizer_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        """Use the tokenizer to convert messages to a prompt. Fall back to the generic converter."""
        if hasattr(self._llm.tokenizer, "apply_chat_template"):
            return self._llm.tokenizer.apply_chat_template(
                [message.dict() for message in messages],
                tokenize=False,
                add_generation_prompt=True,
            )
        return generic_messages_to_prompt(messages)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        return asyncio.run(self.acomplete(prompt, **kwargs))

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        return asyncio.run(self.achat(messages, **kwargs))

    @property
    def _loop(self) -> asyncio.AbstractEventLoop:
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = asyncio.get_event_loop()
        return loop
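
    # The synchronous streaming methods below reuse the async generators: each
    # call pulls the next chunk by running the generator's __anext__() on the
    # event loop resolved by `_loop` above.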

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        generator = self.astream_complete(prompt, **kwargs)
        # Yield items from the async generator synchronously.
        while True:
            try:
                yield self._loop.run_until_complete(generator.__anext__())
            except StopAsyncIteration:
                break

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        generator = self.astream_chat(messages, **kwargs)
        # Yield items from the async generator synchronously.
        while True:
            try:
                yield self._loop.run_until_complete(generator.__anext__())
            except StopAsyncIteration:
                break

    @llm_chat_callback()
    async def achat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        response = await self.acomplete(self.messages_to_prompt(messages), **kwargs)
        return completion_response_to_chat_response(response)

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        response = await self._llm.generate(prompt, **kwargs)
        return CompletionResponse(
            text=response.outputs[0].text,
            raw=response.model_dump(),
            additional_kwargs={
                "prompt_token_ids": response.prompt_token_ids,
                "prompt_logprobs": response.prompt_logprobs,
                "finished": response.finished,
                "outputs": {
                    "token_ids": response.outputs[0].token_ids,
                    "cumulative_logprob": response.outputs[0].cumulative_logprob,
                    "logprobs": response.outputs[0].logprobs,
                    "finish_reason": response.outputs[0].finish_reason,
                },
            },
        )
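
    # The streaming variants below mirror acomplete(): chunks come from
    # OpenLLM's generate_iterator and expose the same per-output metadata
    # (token ids, logprobs, finish reason) through additional_kwargs.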

    @llm_chat_callback()
    async def astream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseAsyncGen:
        async for response_chunk in self.astream_complete(
            self.messages_to_prompt(messages), **kwargs
        ):
            yield completion_response_to_chat_response(response_chunk)

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        config = self._llm.config.model_construct_env(**kwargs)
        if config["n"] > 1:
            logger.warning("Currently only n=1 is supported.")

        # Use an independent list per output index (a `[[]] * n` literal would
        # alias the same list n times).
        texts: List[List[str]] = [[] for _ in range(config["n"])]

        async for response_chunk in self._llm.generate_iterator(prompt, **kwargs):
            for output in response_chunk.outputs:
                texts[output.index].append(output.text)
            yield CompletionResponse(
                text=response_chunk.outputs[0].text,
                delta=response_chunk.outputs[0].text,
                raw=response_chunk.model_dump(),
                additional_kwargs={
                    "prompt_token_ids": response_chunk.prompt_token_ids,
                    "prompt_logprobs": response_chunk.prompt_logprobs,
                    "finished": response_chunk.finished,
                    "outputs": {
                        "text": response_chunk.outputs[0].text,
                        "token_ids": response_chunk.outputs[0].token_ids,
                        "cumulative_logprob": response_chunk.outputs[0].cumulative_logprob,
                        "logprobs": response_chunk.outputs[0].logprobs,
                        "finish_reason": response_chunk.outputs[0].finish_reason,
                    },
                },
            )


class OpenLLMAPI(LLM):
    """OpenLLM Client interface. This is useful when interacting with a remote OpenLLM server."""
    address: Optional[str] = Field(
        description=(
            "OpenLLM server address. This can be set here or via the "
            "OPENLLM_ENDPOINT environment variable."
        )
    )
    timeout: int = Field(description="Timeout for sending requests.")
    max_retries: int = Field(description="Maximum number of retries.")
    api_version: Literal["v1"] = Field(description="OpenLLM server API version.")

    if TYPE_CHECKING:
        try:
            from openllm_client import AsyncHTTPClient, HTTPClient

            _sync_client: HTTPClient
            _async_client: AsyncHTTPClient
        except ImportError:
            _sync_client: Any  # type: ignore[no-redef]
            _async_client: Any  # type: ignore[no-redef]
    else:
        _sync_client: Any = PrivateAttr()
        _async_client: Any = PrivateAttr()

    def __init__(
        self,
        address: Optional[str] = None,
        timeout: int = 30,
        max_retries: int = 2,
        api_version: Literal["v1"] = "v1",
        **kwargs: Any,
    ):
        try:
            from openllm_client import AsyncHTTPClient, HTTPClient
        except ImportError:
            raise ImportError(
                f'"{type(self).__name__}" requires "openllm-client". Make sure to install with `pip install openllm-client`'
            )
        super().__init__(
            address=address,
            timeout=timeout,
            max_retries=max_retries,
            api_version=api_version,
            **kwargs,
        )
        self._sync_client = HTTPClient(
            address=address,
            timeout=timeout,
            max_retries=max_retries,
            api_version=api_version,
        )
        self._async_client = AsyncHTTPClient(
            address=address,
            timeout=timeout,
            max_retries=max_retries,
            api_version=api_version,
        )

    @classmethod
    def class_name(cls) -> str:
        return "OpenLLM_Client"

    @property
    def _server_metadata(self) -> "Metadata":
        return self._sync_client._metadata

    @property
    def _server_config(self) -> Dict[str, Any]:
        return self._sync_client._config

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            num_output=self._server_config["max_new_tokens"],
            model_name=self._server_metadata.model_id.replace("/", "--"),
        )
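
    # Prompt rendering is delegated to the server: the two helpers below ask it
    # to turn the chat messages into a single prompt string (with a generation
    # prompt appended), mirroring OpenLLM._tokenizer_messages_to_prompt above.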
    def _convert_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        return self._sync_client.helpers.messages(
            messages=[
                {"role": message.role, "content": message.content}
                for message in messages
            ],
            add_generation_prompt=True,
        )

    async def _async_messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        return await self._async_client.helpers.messages(
            messages=[
                {"role": message.role, "content": message.content}
                for message in messages
            ],
            add_generation_prompt=True,
        )

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        response = self._sync_client.generate(prompt, **kwargs)
        return CompletionResponse(
            text=response.outputs[0].text,
            raw=response.model_dump(),
            additional_kwargs={
                "prompt_token_ids": response.prompt_token_ids,
                "prompt_logprobs": response.prompt_logprobs,
                "finished": response.finished,
                "outputs": {
                    "token_ids": response.outputs[0].token_ids,
                    "cumulative_logprob": response.outputs[0].cumulative_logprob,
                    "logprobs": response.outputs[0].logprobs,
                    "finish_reason": response.outputs[0].finish_reason,
                },
            },
        )

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        for response_chunk in self._sync_client.generate_stream(prompt, **kwargs):
            yield CompletionResponse(
                text=response_chunk.text,
                delta=response_chunk.text,
                raw=response_chunk.model_dump(),
                additional_kwargs={"token_ids": response_chunk.token_ids},
            )

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        return completion_response_to_chat_response(
            self.complete(self._convert_messages_to_prompt(messages), **kwargs)
        )

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        for response_chunk in self.stream_complete(
            self._convert_messages_to_prompt(messages), **kwargs
        ):
            yield completion_response_to_chat_response(response_chunk)

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        response = await self._async_client.generate(prompt, **kwargs)
        return CompletionResponse(
            text=response.outputs[0].text,
            raw=response.model_dump(),
            additional_kwargs={
                "prompt_token_ids": response.prompt_token_ids,
                "prompt_logprobs": response.prompt_logprobs,
                "finished": response.finished,
                "outputs": {
                    "token_ids": response.outputs[0].token_ids,
                    "cumulative_logprob": response.outputs[0].cumulative_logprob,
                    "logprobs": response.outputs[0].logprobs,
                    "finish_reason": response.outputs[0].finish_reason,
                },
            },
        )

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        async for response_chunk in self._async_client.generate_stream(
            prompt, **kwargs
        ):
            yield CompletionResponse(
                text=response_chunk.text,
                delta=response_chunk.text,
                raw=response_chunk.model_dump(),
                additional_kwargs={"token_ids": response_chunk.token_ids},
            )

    @llm_chat_callback()
    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        return completion_response_to_chat_response(
            await self.acomplete(
                await self._async_messages_to_prompt(messages), **kwargs
            )
        )

    @llm_chat_callback()
    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        async for response_chunk in self.astream_complete(
            await self._async_messages_to_prompt(messages), **kwargs
        ):
            yield completion_response_to_chat_response(response_chunk)