"""Helpers for tokenizing prompts and formatting TensorRT-LLM text completions."""
import time
|
|
import uuid
|
|
from typing import Any, Dict, Optional
|
|
|
|
import numpy as np
|
|
|
|
|
|
def parse_input(
    input_text: str, tokenizer: Any, end_id: int, remove_input_padding: bool
) -> Any:
    """Tokenize ``input_text`` and move the resulting ids to the GPU.

    Args:
        input_text: Prompt text to encode.
        tokenizer: Tokenizer exposing ``encode(text, add_special_tokens=...)``
            (HF-style interface -- TODO confirm against caller).
        end_id: Token id used as padding when padding is kept.
        remove_input_padding: When True, flatten the token batch into a
            single row; otherwise build an ``end_id``-padded tensor.

    Returns:
        Tuple of ``(input_ids, input_lengths)``, both int32 CUDA tensors.

    Raises:
        ImportError: If torch is not installed.
    """
    try:
        import torch
    except ImportError:
        raise ImportError("nvidia_tensorrt requires `pip install torch`.")

    # Single-prompt batch: one encoded sequence, no special tokens added.
    token_batches = [tokenizer.encode(input_text, add_special_tokens=False)]

    input_lengths = torch.tensor(
        [len(batch) for batch in token_batches], dtype=torch.int32, device="cuda"
    )
    if remove_input_padding:
        # Flatten all sequences into one contiguous row (shape (1, total_len)).
        flat_ids = np.concatenate(token_batches)
        input_ids = torch.tensor(
            flat_ids, dtype=torch.int32, device="cuda"
        ).unsqueeze(0)
    else:
        # Pad ragged sequences with end_id via a nested tensor.
        nested = torch.nested.nested_tensor(token_batches, dtype=torch.int32)
        input_ids = torch.nested.to_padded_tensor(nested, end_id).cuda()

    return input_ids, input_lengths
|
|
|
|
|
|
def remove_extra_eos_ids(outputs: Any, eos_id: int = 2) -> Any:
    """Collapse trailing EOS tokens in ``outputs`` to exactly one.

    The list is modified in place: every trailing ``eos_id`` token is
    removed, then a single ``eos_id`` is appended, so the sequence always
    ends with exactly one EOS token (even when ``outputs`` is empty).

    Args:
        outputs: Token-id list to normalize (mutated in place).
        eos_id: Token id treated as end-of-sequence. Defaults to 2, the
            value the original implementation hard-coded.

    Returns:
        The same list object, normalized.
    """
    # Pop from the tail directly instead of the original
    # reverse / pop(0) / reverse dance, which shifted the whole list
    # on every removal.
    while outputs and outputs[-1] == eos_id:
        outputs.pop()
    outputs.append(eos_id)
    return outputs
|
|
|
|
|
|
def get_output(
    output_ids: Any,
    input_lengths: Any,
    max_output_len: int,
    tokenizer: Any,
) -> Any:
    """Decode generated token ids into text.

    Args:
        output_ids: Token-id tensor indexed as ``[batch][beam][position]``
            -- presumably shape (batch, beams, seq_len); verify against the
            TRT runner.
        input_lengths: 1-D tensor of prompt lengths, one per batch item.
        max_output_len: Maximum number of generated tokens to decode.
        tokenizer: Tokenizer exposing ``decode``.

    Returns:
        Tuple of ``(output_text, outputs)`` for the LAST batch item and
        beam visited -- each iteration overwrites the previous result,
        matching the original behavior (effectively batch=1, beam=1 use).
    """
    beam_count = output_ids.size(1)
    output_text = ""
    outputs = None
    for batch_idx in range(input_lengths.size(0)):
        for beam_idx in range(beam_count):
            # Generated tokens start right after the prompt.
            start = input_lengths[batch_idx]
            end = start + max_output_len
            outputs = output_ids[batch_idx][beam_idx][start:end].tolist()
            outputs = remove_extra_eos_ids(outputs)
            output_text = tokenizer.decode(outputs)

    return output_text, outputs
|
|
|
|
|
|
def generate_completion_dict(
    text_str: str, model: Any, model_path: Optional[str]
) -> Dict:
    """
    Generate a dictionary for text completion details.

    The shape mirrors an OpenAI ``text_completion`` response; token usage
    counts are not tracked here and are reported as ``None``.

    Args:
        text_str: The generated completion text.
        model: Model identifier; used as the ``model`` field when not None.
        model_path: Fallback identifier used when ``model`` is None.

    Returns:
        dict: A dictionary containing completion details.
    """
    choice = {
        "text": text_str,
        "index": 0,
        "logprobs": None,
        "finish_reason": "stop",
    }
    usage = {
        "prompt_tokens": None,
        "completion_tokens": None,
        "total_tokens": None,
    }
    return {
        "id": f"cmpl-{uuid.uuid4()!s}",
        "object": "text_completion",
        "created": int(time.time()),
        "model": model if model is not None else model_path,
        "choices": [choice],
        "usage": usage,
    }
|