# faiss_rag_enterprise/llama_index/llms/nvidia_triton.py
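"""NVIDIA Triton LLM integration for LlamaIndex.

Defines :class:`NvidiaTriton`, an ``LLM`` implementation that forwards
completion requests to a Triton Inference Server over gRPC using
``GrpcTritonClient``. Synchronous ``complete`` and ``chat`` are supported;
the streaming and async entry points raise ``NotImplementedError``.
"""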

import random
from typing import (
    Any,
    Dict,
    Optional,
    Sequence,
)

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.llms.base import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
    llm_chat_callback,
)
from llama_index.llms.generic_utils import (
    completion_to_chat_decorator,
)
from llama_index.llms.llm import LLM
from llama_index.llms.nvidia_triton_utils import GrpcTritonClient

DEFAULT_SERVER_URL = "localhost:8001"
DEFAULT_MAX_RETRIES = 3
DEFAULT_TIMEOUT = 60.0
DEFAULT_MODEL = "ensemble"
DEFAULT_TEMPERATURE = 1.0
DEFAULT_TOP_P = 0
DEFAULT_TOP_K = 1.0
DEFAULT_MAX_TOKENS = 100
DEFAULT_BEAM_WIDTH = 1
DEFAULT_REPETITION_PENALTY = 1.0
DEFAULT_LENGTH_PENALTY = 1.0
DEFAULT_REUSE_CLIENT = True
DEFAULT_TRITON_LOAD_MODEL = True


class NvidiaTriton(LLM):
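    """LlamaIndex LLM client for an NVIDIA Triton Inference Server.

    Sends completion requests to a Triton-hosted model (``ensemble`` by
    default) over gRPC. Chat calls are converted to completion calls;
    the streaming and async variants are not implemented.

    Illustrative usage (the server URL and model name are placeholders
    for your own deployment):

        llm = NvidiaTriton(server_url="localhost:8001", model="ensemble")
        print(llm.complete("Hello, Triton!").text)
    """
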
    server_url: str = Field(
        default=DEFAULT_SERVER_URL,
        description="The URL of the Triton inference server to use.",
    )
    model_name: str = Field(
        default=DEFAULT_MODEL,
        description="The name of the Triton hosted model this client should use",
    )
    temperature: Optional[float] = Field(
        default=DEFAULT_TEMPERATURE, description="Temperature to use for sampling"
    )
    top_p: Optional[float] = Field(
        default=DEFAULT_TOP_P, description="The top-p value to use for sampling"
    )
    top_k: Optional[float] = Field(
        default=DEFAULT_TOP_K, description="The top-k value to use for sampling"
    )
    tokens: Optional[int] = Field(
        default=DEFAULT_MAX_TOKENS,
        description="The maximum number of tokens to generate.",
    )
    beam_width: Optional[int] = Field(
        default=DEFAULT_BEAM_WIDTH,
        description="The number of beams to use during generation",
    )
    repetition_penalty: Optional[float] = Field(
        default=DEFAULT_REPETITION_PENALTY,
        description="The penalty to apply to repeated tokens",
    )
    length_penalty: Optional[float] = Field(
        default=DEFAULT_LENGTH_PENALTY,
        description="The penalty to apply to the length of the generated output",
    )
    max_retries: Optional[int] = Field(
        default=DEFAULT_MAX_RETRIES,
        description="Maximum number of attempts to retry Triton client invocation before erroring",
    )
    timeout: Optional[float] = Field(
        default=DEFAULT_TIMEOUT,
        description="Maximum time (seconds) allowed for a Triton client call before erroring",
    )
    reuse_client: Optional[bool] = Field(
        default=DEFAULT_REUSE_CLIENT,
        description="True for reusing the same client instance between invocations",
    )
    triton_load_model_call: Optional[bool] = Field(
        default=DEFAULT_TRITON_LOAD_MODEL,
        description="True if a Triton load model API call should be made before using the client",
    )

    _client: Optional[GrpcTritonClient] = PrivateAttr()

    def __init__(
        self,
        server_url: str = DEFAULT_SERVER_URL,
        model: str = DEFAULT_MODEL,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: float = DEFAULT_TOP_K,
        tokens: Optional[int] = DEFAULT_MAX_TOKENS,
        beam_width: int = DEFAULT_BEAM_WIDTH,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        length_penalty: float = DEFAULT_LENGTH_PENALTY,
        max_retries: int = DEFAULT_MAX_RETRIES,
        timeout: float = DEFAULT_TIMEOUT,
        reuse_client: bool = DEFAULT_REUSE_CLIENT,
        triton_load_model_call: bool = DEFAULT_TRITON_LOAD_MODEL,
        callback_manager: Optional[CallbackManager] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        additional_kwargs = additional_kwargs or {}

        super().__init__(
            server_url=server_url,
            model_name=model,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            tokens=tokens,
            beam_width=beam_width,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            max_retries=max_retries,
            timeout=timeout,
            reuse_client=reuse_client,
            triton_load_model_call=triton_load_model_call,
            callback_manager=callback_manager,
            additional_kwargs=additional_kwargs,
            **kwargs,
        )

        try:
            self._client = GrpcTritonClient(server_url)
        except ImportError as err:
            raise ImportError(
                "Could not import triton client python package. "
                "Please install it with `pip install tritonclient`."
            ) from err

    @property
    def _get_model_default_parameters(self) -> Dict[str, Any]:
        return {
            "tokens": self.tokens,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "repetition_penalty": self.repetition_penalty,
            "length_penalty": self.length_penalty,
            "beam_width": self.beam_width,
        }

    def _invocation_params(self, **kwargs: Any) -> Dict[str, Any]:
        return {**self._get_model_default_parameters, **kwargs}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get all the identifying parameters."""
        return {
            "server_url": self.server_url,
            "model_name": self.model_name,
        }

    def _get_client(self) -> Any:
        """Create or reuse a Triton client connection."""
        if not self.reuse_client:
            return GrpcTritonClient(self.server_url)

        if self._client is None:
            self._client = GrpcTritonClient(self.server_url)
        return self._client

    @property
    def metadata(self) -> LLMMetadata:
        """Gather and return metadata about the user-configured Triton-hosted LLM model."""
        return LLMMetadata(
            model_name=self.model_name,
        )

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        chat_fn = completion_to_chat_decorator(self.complete)
        return chat_fn(messages, **kwargs)

    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        raise NotImplementedError

    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        from tritonclient.utils import InferenceServerException

        client = self._get_client()

        invocation_params = self._get_model_default_parameters
        invocation_params.update(kwargs)
        invocation_params["prompt"] = [[prompt]]
        model_params = self._identifying_params
        model_params.update(kwargs)
        request_id = str(random.randint(1, 9999999))  # nosec

        if self.triton_load_model_call:
            client.load_model(model_params["model_name"])

        result_queue = client.request_streaming(
            model_params["model_name"], request_id, **invocation_params
        )

        # Drain the streaming result queue into a single completion string,
        # stopping the stream and re-raising on any server-side error.
        response = ""
        for token in result_queue:
            if isinstance(token, InferenceServerException):
                client.stop_stream(model_params["model_name"], request_id)
                raise token
            response = response + token

        return CompletionResponse(
            text=response,
        )

    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        raise NotImplementedError

    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        raise NotImplementedError

    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        raise NotImplementedError

    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        raise NotImplementedError

    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        raise NotImplementedError
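

# Minimal usage sketch (not part of the original module): assumes a Triton
# Inference Server is reachable at the default URL and serves the default
# "ensemble" model; both values are placeholders for your own deployment.
if __name__ == "__main__":
    llm = NvidiaTriton(
        server_url=DEFAULT_SERVER_URL,
        model=DEFAULT_MODEL,
        temperature=0.2,
        tokens=64,
    )
    completion = llm.complete("Summarize what a FAISS index is in one sentence.")
    print(completion.text)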