# faiss_rag_enterprise/llama_index/llms/nvidia_tensorrt_utils.py
# (file-viewer metadata: 96 lines, 2.5 KiB, Python)

import time
import uuid
from typing import Any, Dict, Optional
import numpy as np
def parse_input(
    input_text: str,
    tokenizer: Any,
    end_id: int,
    remove_input_padding: bool,
    *,
    device: str = "cuda",
) -> Any:
    """
    Tokenize ``input_text`` and build the id / length tensors for generation.

    Args:
        input_text: The prompt to encode.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            that returns a sequence of token ids.
        end_id: Value used to pad the id tensor when
            ``remove_input_padding`` is false.
        remove_input_padding: When true, the token sequences are concatenated
            into one flat row and a leading dimension of size 1 is added.
            When false, a padded tensor is built from the token sequences.
        device: Torch device the returned tensors are placed on. Defaults to
            ``"cuda"``, which is what this function always used before the
            parameter existed; pass e.g. ``"cuda:1"`` to target another GPU.

    Returns:
        A ``(input_ids, input_lengths)`` tuple of ``torch.int32`` tensors,
        both located on ``device``.

    Raises:
        ImportError: If ``torch`` is not installed.
    """
    try:
        import torch
    except ImportError as err:
        # Chain the original failure so the real import problem stays visible.
        raise ImportError("nvidia_tensorrt requires `pip install torch`.") from err

    # Only a single prompt is encoded, but it is kept in a list so the code
    # below works on "a batch of sequences".
    input_tokens = [tokenizer.encode(input_text, add_special_tokens=False)]

    input_lengths = torch.tensor(
        [len(tokens) for tokens in input_tokens], dtype=torch.int32, device=device
    )

    if remove_input_padding:
        flat_ids = np.concatenate(input_tokens)
        input_ids = torch.tensor(
            flat_ids, dtype=torch.int32, device=device
        ).unsqueeze(0)
    else:
        # The nested tensor is created and padded first, then moved to the
        # target device.
        nested = torch.nested.nested_tensor(input_tokens, dtype=torch.int32)
        input_ids = torch.nested.to_padded_tensor(nested, end_id).to(device)

    return input_ids, input_lengths
def remove_extra_eos_ids(outputs: Any, eos_id: int = 2) -> Any:
    """
    Collapse any run of trailing EOS ids in ``outputs`` to exactly one.

    The list is modified in place and the same list object is returned.
    A list with no trailing EOS id (including an empty list) simply gets one
    appended.

    Args:
        outputs: Mutable list of token ids.
        eos_id: Token id treated as end-of-sequence. Defaults to ``2``, the
            value that was previously hard-coded here.

    Returns:
        ``outputs`` itself, now ending in a single ``eos_id``.
    """
    # Pop from the tail: each pop is O(1), whereas the earlier
    # reverse / pop(0) / reverse approach shifted the whole list per removal.
    while outputs and outputs[-1] == eos_id:
        outputs.pop()
    outputs.append(eos_id)
    return outputs
def get_output(
    output_ids: Any,
    input_lengths: Any,
    max_output_len: int,
    tokenizer: Any,
) -> Any:
    """
    Decode generated token ids into text.

    ``output_ids`` is indexed as ``[batch][beam][position]``. For every
    batch entry and beam, the positions from ``input_lengths[batch]`` up to
    ``input_lengths[batch] + max_output_len`` are taken, converted to a list,
    passed through ``remove_extra_eos_ids`` and decoded with
    ``tokenizer.decode``.

    Each iteration overwrites the previous result, so what is returned is
    the text and token list of the last beam of the last batch entry.

    Returns:
        A ``(output_text, outputs)`` tuple; ``("", None)`` when there is no
        batch entry or no beam to iterate over.
    """
    beam_count = output_ids.size(1)
    batch_count = input_lengths.size(0)

    decoded_text = ""
    token_ids = None

    # Flattened walk over every (batch, beam) pair, batch-major.
    index_pairs = (
        (batch_idx, beam_idx)
        for batch_idx in range(batch_count)
        for beam_idx in range(beam_count)
    )
    for batch_idx, beam_idx in index_pairs:
        # Generated tokens start right after the prompt tokens.
        start = input_lengths[batch_idx]
        stop = input_lengths[batch_idx] + max_output_len
        generated = output_ids[batch_idx][beam_idx][start:stop]
        token_ids = remove_extra_eos_ids(generated.tolist())
        decoded_text = tokenizer.decode(token_ids)

    return decoded_text, token_ids
def generate_completion_dict(
    text_str: str, model: Any, model_path: Optional[str]
) -> Dict:
    """
    Wrap a completed text in a completion-response dictionary.

    Args:
        text_str: The generated text, stored as the single choice's ``text``.
        model: Value reported under ``"model"``; used whenever it is not
            ``None``.
        model_path: Fallback reported under ``"model"`` when ``model`` is
            ``None``.

    Returns:
        dict: A dictionary with the keys ``id`` (``"cmpl-"`` followed by a
        random UUID4), ``object``, ``created`` (current Unix time in whole
        seconds), ``model``, ``choices`` (one entry) and ``usage`` (all
        token counts are ``None``).
    """
    response_id: str = "cmpl-" + str(uuid.uuid4())
    timestamp: int = int(time.time())

    if model is None:
        reported_model = model_path
    else:
        reported_model = model

    sole_choice = {
        "text": text_str,
        "index": 0,
        "logprobs": None,
        "finish_reason": "stop",
    }
    # Token accounting is not computed here; the fields are left unset.
    token_usage = {
        "prompt_tokens": None,
        "completion_tokens": None,
        "total_tokens": None,
    }

    return {
        "id": response_id,
        "object": "text_completion",
        "created": timestamp,
        "model": reported_model,
        "choices": [sole_choice],
        "usage": token_usage,
    }