import gc
import json
import os
import time
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Sequence

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
from llama_index.llms.base import (
    ChatMessage,
    ChatResponse,
    CompletionResponse,
    LLMMetadata,
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.llms.custom import CustomLLM
from llama_index.llms.generic_utils import completion_response_to_chat_response
from llama_index.llms.nvidia_tensorrt_utils import (
    generate_completion_dict,
    get_output,
    parse_input,
)

EOS_TOKEN = 2
PAD_TOKEN = 2


class LocalTensorRTLLM(CustomLLM):
    """LLM backed by a locally built TensorRT-LLM engine.

    Loads a serialized TensorRT engine and its tokenizer from disk and runs
    inference through the TensorRT-LLM Python runtime on a local NVIDIA GPU.
    """

    model_path: Optional[str] = Field(description="The path to the trt engine.")
    temperature: float = Field(description="The temperature to use for sampling.")
    max_new_tokens: int = Field(
        description="The maximum number of tokens to generate."
    )
    context_window: int = Field(
        description="The maximum number of context tokens for the model."
    )
    messages_to_prompt: Callable = Field(
        description="The function to convert messages to a prompt.", exclude=True
    )
    completion_to_prompt: Callable = Field(
        description="The function to convert a completion to a prompt.", exclude=True
    )
    generate_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Kwargs used for generation."
    )
    model_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Kwargs used for model initialization."
    )
    verbose: bool = Field(description="Whether to print verbose output.")

    _model: Any = PrivateAttr()
    _model_config: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    _max_new_tokens: int = PrivateAttr()
    _sampling_config: Any = PrivateAttr()
    _verbose: bool = PrivateAttr()

    def __init__(
        self,
        model_path: Optional[str] = None,
        engine_name: Optional[str] = None,
        tokenizer_dir: Optional[str] = None,
        temperature: float = 0.1,
        max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        messages_to_prompt: Optional[Callable] = None,
        completion_to_prompt: Optional[Callable] = None,
        callback_manager: Optional[CallbackManager] = None,
        generate_kwargs: Optional[Dict[str, Any]] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
        verbose: bool = False,
    ) -> None:
        try:
            import torch
            from transformers import AutoTokenizer
        except ImportError:
            raise ImportError(
                "nvidia_tensorrt requires `pip install torch` and `pip install transformers`."
            )

        try:
            import tensorrt_llm
            from tensorrt_llm.runtime import ModelConfig, SamplingConfig
        except ImportError:
            raise ImportError(
                "Unable to import `tensorrt_llm` module. Please ensure you have "
                "`tensorrt_llm` installed in your environment. You can run "
                "`pip3 install tensorrt_llm -U --extra-index-url https://pypi.nvidia.com` to install."
            )

        model_kwargs = model_kwargs or {}
        model_kwargs.update({"n_ctx": context_window, "verbose": verbose})
        self._max_new_tokens = max_new_tokens
        self._verbose = verbose

        # Load the locally built engine if a path was provided.
        if model_path is not None:
            if not os.path.exists(model_path):
                raise ValueError(
                    "Provided model path does not exist. "
                    "Please check the path or provide a model_url to download."
                )
            else:
                engine_dir = model_path
                engine_dir_path = Path(engine_dir)
                config_path = engine_dir_path / "config.json"

                # Read the engine build configuration.
                with open(config_path) as f:
                    config = json.load(f)
                use_gpt_attention_plugin = config["plugin_config"][
                    "gpt_attention_plugin"
                ]
                remove_input_padding = config["plugin_config"]["remove_input_padding"]
                tp_size = config["builder_config"]["tensor_parallel"]
                pp_size = config["builder_config"]["pipeline_parallel"]
                world_size = tp_size * pp_size
                assert (
                    world_size == tensorrt_llm.mpi_world_size()
                ), f"Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})"
                num_heads = config["builder_config"]["num_heads"] // tp_size
                hidden_size = config["builder_config"]["hidden_size"] // tp_size
                vocab_size = config["builder_config"]["vocab_size"]
                num_layers = config["builder_config"]["num_layers"]
                num_kv_heads = config["builder_config"].get("num_kv_heads", num_heads)
                paged_kv_cache = config["plugin_config"]["paged_kv_cache"]
                if config["builder_config"].get("multi_query_mode", False):
                    tensorrt_llm.logger.warning(
                        "`multi_query_mode` config is deprecated. Please rebuild the engine."
                    )
                    num_kv_heads = 1
                num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size

                self._model_config = ModelConfig(
                    num_heads=num_heads,
                    num_kv_heads=num_kv_heads,
                    hidden_size=hidden_size,
                    vocab_size=vocab_size,
                    num_layers=num_layers,
                    gpt_attention_plugin=use_gpt_attention_plugin,
                    paged_kv_cache=paged_kv_cache,
                    remove_input_padding=remove_input_padding,
                )

                assert (
                    pp_size == 1
                ), "Python runtime does not support pipeline parallelism"

                runtime_rank = tensorrt_llm.mpi_rank()
                runtime_mapping = tensorrt_llm.Mapping(
                    world_size, runtime_rank, tp_size=tp_size, pp_size=pp_size
                )

                # TensorRT-LLM must run on a GPU.
                assert (
                    torch.cuda.is_available()
                ), "LocalTensorRTLLM requires an NVIDIA CUDA-enabled GPU to operate"
                torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
                self._tokenizer = AutoTokenizer.from_pretrained(
                    tokenizer_dir, legacy=False
                )
                self._sampling_config = SamplingConfig(
                    end_id=EOS_TOKEN,
                    pad_id=PAD_TOKEN,
                    num_beams=1,
                    temperature=temperature,
                )

                # Deserialize the engine and create a generation session.
                serialize_path = engine_dir_path / (engine_name if engine_name else "")
                with open(serialize_path, "rb") as f:
                    engine_buffer = f.read()
                decoder = tensorrt_llm.runtime.GenerationSession(
                    self._model_config, engine_buffer, runtime_mapping, debug_mode=False
                )
                self._model = decoder

        generate_kwargs = generate_kwargs or {}
        generate_kwargs.update(
            {"temperature": temperature, "max_tokens": max_new_tokens}
        )

        super().__init__(
            model_path=model_path,
            temperature=temperature,
            context_window=context_window,
            max_new_tokens=max_new_tokens,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            callback_manager=callback_manager,
            generate_kwargs=generate_kwargs,
            model_kwargs=model_kwargs,
            verbose=verbose,
        )

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "LocalTensorRTLLM"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_new_tokens,
            model_name=self.model_path,
        )

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.complete(prompt, formatted=True, **kwargs)
        return completion_response_to_chat_response(completion_response)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        try:
            import torch
        except ImportError:
            raise ImportError("nvidia_tensorrt requires `pip install torch`.")

        self.generate_kwargs.update({"stream": False})

        if not formatted:
            prompt = self.completion_to_prompt(prompt)

        # Tokenize the prompt and prepare the generation session.
        input_text = prompt
        input_ids, input_lengths = parse_input(
            input_text, self._tokenizer, EOS_TOKEN, self._model_config
        )

        max_input_length = torch.max(input_lengths).item()
        self._model.setup(
            input_lengths.size(0), max_input_length, self._max_new_tokens, 1
        )  # beam size is set to 1
        if self._verbose:
            start_time = time.time()

        output_ids = self._model.decode(input_ids, input_lengths, self._sampling_config)
        torch.cuda.synchronize()

        elapsed_time = -1.0
        if self._verbose:
            end_time = time.time()
            elapsed_time = end_time - start_time

        output_txt, output_token_ids = get_output(
            output_ids, input_lengths, self._max_new_tokens, self._tokenizer
        )

        if self._verbose:
            print(f"Input context length  : {input_ids.shape[1]}")
            print(f"Inference time        : {elapsed_time:.2f} seconds")
            print(f"Output context length : {len(output_token_ids)}")
            print(
                f"Inference token/sec   : {(len(output_token_ids) / elapsed_time):.2f}"
            )

        # Run garbage collection after inference to free GPU memory.
        torch.cuda.empty_cache()
        gc.collect()

        return CompletionResponse(
            text=output_txt,
            raw=generate_completion_dict(output_txt, self._model, self.model_path),
        )

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        raise NotImplementedError(
            "Nvidia TensorRT-LLM does not currently support streaming completion."
        )
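

# A minimal usage sketch (not part of the class above). The engine directory,
# engine file name, and tokenizer path below are placeholders and must point at
# an engine you have already built with TensorRT-LLM plus the tokenizer it was
# built for; the prompt-template helpers are simplistic stand-ins for a real,
# model-specific template. Requires an NVIDIA CUDA-enabled GPU.
if __name__ == "__main__":

    def messages_to_prompt(messages: Sequence[ChatMessage]) -> str:
        # Naive chat-to-prompt conversion; swap in the template your model expects.
        return (
            "\n".join(f"{m.role.value}: {m.content}" for m in messages)
            + "\nassistant: "
        )

    def completion_to_prompt(completion: str) -> str:
        # Pass the completion text through unchanged.
        return completion

    llm = LocalTensorRTLLM(
        model_path="/path/to/engine_dir",  # placeholder: directory containing config.json
        engine_name="model.engine",  # placeholder: serialized engine file in that directory
        tokenizer_dir="/path/to/tokenizer_dir",  # placeholder: tokenizer matching the engine
        messages_to_prompt=messages_to_prompt,
        completion_to_prompt=completion_to_prompt,
        max_new_tokens=256,
        verbose=True,
    )
    response = llm.complete("What is TensorRT-LLM?")
    print(response.text)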