import random
from typing import (
    Any,
    Dict,
    Optional,
    Sequence,
)

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.llms.base import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
    llm_chat_callback,
)
from llama_index.llms.generic_utils import (
    completion_to_chat_decorator,
)
from llama_index.llms.llm import LLM
from llama_index.llms.nvidia_triton_utils import GrpcTritonClient

DEFAULT_SERVER_URL = "localhost:8001"
DEFAULT_MAX_RETRIES = 3
DEFAULT_TIMEOUT = 60.0
DEFAULT_MODEL = "ensemble"
DEFAULT_TEMPERATURE = 1.0
DEFAULT_TOP_P = 0
DEFAULT_TOP_K = 1.0
DEFAULT_MAX_TOKENS = 100
DEFAULT_BEAM_WIDTH = 1
DEFAULT_REPETITION_PENALTY = 1.0
DEFAULT_LENGTH_PENALTY = 1.0
DEFAULT_REUSE_CLIENT = True
DEFAULT_TRITON_LOAD_MODEL = True

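# NOTE: with the defaults above (top_k=1, top_p=0) decoding is effectively
# greedy; raise top_k/top_p and temperature to get sampled outputs.
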
class NvidiaTriton(LLM):
    """Client for an LLM hosted on an Nvidia Triton inference server."""

    server_url: str = Field(
        default=DEFAULT_SERVER_URL,
        description="The URL of the Triton inference server to use.",
    )
    model_name: str = Field(
        default=DEFAULT_MODEL,
        description="The name of the Triton hosted model this client should use",
    )
    temperature: Optional[float] = Field(
        default=DEFAULT_TEMPERATURE, description="Temperature to use for sampling"
    )
    top_p: Optional[float] = Field(
        default=DEFAULT_TOP_P, description="The top-p value to use for sampling"
    )
    top_k: Optional[float] = Field(
        default=DEFAULT_TOP_K, description="The top-k value to use for sampling"
    )
    tokens: Optional[int] = Field(
        default=DEFAULT_MAX_TOKENS,
        description="The maximum number of tokens to generate.",
    )
    beam_width: Optional[int] = Field(
        default=DEFAULT_BEAM_WIDTH, description="The beam width to use for decoding"
    )
    repetition_penalty: Optional[float] = Field(
        default=DEFAULT_REPETITION_PENALTY,
        description="The penalty to apply to repeated tokens",
    )
    length_penalty: Optional[float] = Field(
        default=DEFAULT_LENGTH_PENALTY,
        description="The penalty to apply to the length of the output",
    )
    max_retries: Optional[int] = Field(
        default=DEFAULT_MAX_RETRIES,
        description="Maximum number of attempts to retry Triton client invocation before erroring",
    )
    timeout: Optional[float] = Field(
        default=DEFAULT_TIMEOUT,
        description="Maximum time (seconds) allowed for a Triton client call before erroring",
    )
    reuse_client: Optional[bool] = Field(
        default=DEFAULT_REUSE_CLIENT,
        description="True for reusing the same client instance between invocations",
    )
    triton_load_model_call: Optional[bool] = Field(
        default=DEFAULT_TRITON_LOAD_MODEL,
        description="True if a Triton load model API call should be made before using the client",
    )

    _client: Optional[GrpcTritonClient] = PrivateAttr()

    def __init__(
        self,
        server_url: str = DEFAULT_SERVER_URL,
        model: str = DEFAULT_MODEL,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: float = DEFAULT_TOP_K,
        tokens: Optional[int] = DEFAULT_MAX_TOKENS,
        beam_width: int = DEFAULT_BEAM_WIDTH,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        length_penalty: float = DEFAULT_LENGTH_PENALTY,
        max_retries: int = DEFAULT_MAX_RETRIES,
        timeout: float = DEFAULT_TIMEOUT,
        reuse_client: bool = DEFAULT_REUSE_CLIENT,
        triton_load_model_call: bool = DEFAULT_TRITON_LOAD_MODEL,
        callback_manager: Optional[CallbackManager] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        additional_kwargs = additional_kwargs or {}

        super().__init__(
            server_url=server_url,
            model_name=model,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            tokens=tokens,
            beam_width=beam_width,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            max_retries=max_retries,
            timeout=timeout,
            reuse_client=reuse_client,
            triton_load_model_call=triton_load_model_call,
            callback_manager=callback_manager,
            additional_kwargs=additional_kwargs,
            **kwargs,
        )

        try:
            self._client = GrpcTritonClient(server_url)
        except ImportError as err:
            raise ImportError(
                "Could not import triton client python package. "
                "Please install it with `pip install tritonclient`."
            ) from err

    @property
    def _get_model_default_parameters(self) -> Dict[str, Any]:
        return {
            "tokens": self.tokens,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "repetition_penalty": self.repetition_penalty,
            "length_penalty": self.length_penalty,
            "beam_width": self.beam_width,
        }

    def _invocation_params(self, **kwargs: Any) -> Dict[str, Any]:
        return {**self._get_model_default_parameters, **kwargs}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get all the identifying parameters."""
        return {
            "server_url": self.server_url,
            "model_name": self.model_name,
        }

    def _get_client(self) -> Any:
        """Create or reuse a Triton client connection."""
        if not self.reuse_client:
            return GrpcTritonClient(self.server_url)

        if self._client is None:
            self._client = GrpcTritonClient(self.server_url)
        return self._client

    @property
    def metadata(self) -> LLMMetadata:
        """Gather and return metadata about the user-configured Triton LLM model."""
        return LLMMetadata(
            model_name=self.model_name,
        )

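    # Chat is implemented by converting the message sequence into a single
    # completion prompt via completion_to_chat_decorator and delegating to
    # complete() below.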
    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        chat_fn = completion_to_chat_decorator(self.complete)
        return chat_fn(messages, **kwargs)

    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        raise NotImplementedError

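    # complete() optionally asks Triton to load the model first, then issues
    # a streaming gRPC request and accumulates the streamed tokens into one
    # CompletionResponse. Stream errors are re-raised after stopping the stream.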
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        from tritonclient.utils import InferenceServerException

        client = self._get_client()

        invocation_params = self._get_model_default_parameters
        invocation_params.update(kwargs)
        invocation_params["prompt"] = [[prompt]]
        model_params = self._identifying_params
        model_params.update(kwargs)
        request_id = str(random.randint(1, 9999999))  # nosec

        if self.triton_load_model_call:
            client.load_model(model_params["model_name"])

        result_queue = client.request_streaming(
            model_params["model_name"], request_id, **invocation_params
        )

        response = ""
        for token in result_queue:
            if isinstance(token, InferenceServerException):
                client.stop_stream(model_params["model_name"], request_id)
                raise token
            response = response + token

        return CompletionResponse(
            text=response,
        )

    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        raise NotImplementedError

    async def achat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponse:
        raise NotImplementedError

    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        raise NotImplementedError

    async def astream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseAsyncGen:
        raise NotImplementedError

    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        raise NotImplementedError
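
# --- Usage sketch (illustrative only) ---
# A minimal example, assuming a Triton server hosting a TensorRT-LLM
# "ensemble" model is reachable at localhost:8001; the server URL and
# model name here are assumptions, not requirements of this class.
#
#     llm = NvidiaTriton(server_url="localhost:8001", model="ensemble")
#     completion = llm.complete("What is the fastest land animal?")
#     print(completion.text)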