embed-bge-m3/FlagEmbedding/research/Reinforced_IR/inference/agent/vllm_instruct.py

import torch
from typing import List, Union

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest


class LLMInstructAgent():
    def __init__(
        self,
        generate_model_path: str = None,
        gpu_memory_utilization: float = 0.8,
        tensor_parallel_size: int = None
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(generate_model_path, trust_remote_code=True)
        self.llm = LLM(
            model=generate_model_path,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=True,
            # Use all visible GPUs unless a tensor_parallel_size is given explicitly.
            tensor_parallel_size=torch.cuda.device_count() if tensor_parallel_size is None else tensor_parallel_size,
            enable_lora=True,
            max_lora_rank=64,
        )
        self.model_name = generate_model_path.lower()

    def generate(
        self,
        prompts: Union[List[str], str] = None,
        temperature: float = 0,
        top_p: float = 1.0,
        max_tokens: int = 300,
        stop: List[str] = [],
        repetition_penalty: float = 1.1,
        lora_path: str = None,
        n: int = 1
    ):
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop=stop,
            n=n,
            repetition_penalty=repetition_penalty,
            # Llama-3 stop token ids (<|end_of_text|>, <|eot_id|>); other models rely on the `stop` strings.
            stop_token_ids=[128001, 128009] if 'llama' in self.model_name else None,
        )

        lora_request = None
        if lora_path is not None:
            lora_request = LoRARequest("lora_adapter", 1, lora_path)

        if isinstance(prompts, str):
            prompts = [prompts]

        # Wrap each prompt in the chat template, truncating the user message to 5120
        # characters; [1:] drops the leading BOS token so vLLM does not add a second one
        # when it re-tokenizes the prompt.
        prompts = [
            self.tokenizer.decode(
                self.tokenizer.apply_chat_template(
                    [{"role": "user", "content": t[:5120]}],
                    add_generation_prompt=True,
                )[1:]
            )
            for t in prompts
        ]

        outputs = self.llm.generate(prompts, sampling_params, lora_request=lora_request)

        if n > 1:
            # With multi-sample decoding, return all n completions per prompt.
            results = []
            for output in outputs:
                results.append([output.outputs[i].text.strip() for i in range(len(output.outputs))])
            return results

        return [output.outputs[0].text.strip() for output in outputs]
        # return [output.outputs[0].text.replace('<|start_header_id|>assistant<|end_header_id|>\n\n', '').replace('assistant<|end_header_id|>\n\n', '').replace('assistant<|end_header_id|>', '').strip() for output in outputs]

    def generate_direct(
        self,
        prompts: List = None,
        temperature: float = 0,
        top_p: float = 1.0,
        max_tokens: int = 300,
        stop: List[str] = [],
        repetition_penalty: float = 1.1,
        lora_path: str = None,
        n: int = 1
    ):
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stop=stop,
            n=n,
            repetition_penalty=repetition_penalty,
            stop_token_ids=[128001, 128009] if 'llama' in self.model_name else None,
        )

        lora_request = None
        if lora_path is not None:
            lora_request = LoRARequest("lora_adapter", 1, lora_path)

        if isinstance(prompts, str):
            prompts = [prompts]

        # Here each prompt is already a full chat message list; apply the template
        # directly and drop the leading BOS token as above.
        prompts = [
            self.tokenizer.decode(
                self.tokenizer.apply_chat_template(t, add_generation_prompt=True)[1:]
            )
            for t in prompts
        ]

        outputs = self.llm.generate(prompts, sampling_params, lora_request=lora_request)

        if n > 1:
            # Return all n sampled completions for each prompt.
            results = []
            for output in outputs:
                results.append([output.outputs[i].text.strip() for i in range(len(output.outputs))])
            return results

        return [output.outputs[0].text.strip() for output in outputs]
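

# A minimal usage sketch, not part of the original module: it shows how the two
# entry points differ. The model path below is a hypothetical placeholder for a
# Llama-3-style instruct checkpoint; swap in whatever checkpoint you actually serve.
if __name__ == "__main__":
    agent = LLMInstructAgent(
        generate_model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # hypothetical path
        gpu_memory_utilization=0.8,
    )

    # `generate` takes plain user strings and wraps each one in the chat template.
    answers = agent.generate(
        prompts=["Write one sentence describing dense retrieval."],
        max_tokens=128,
    )
    print(answers[0])

    # `generate_direct` takes pre-built chat message lists instead of raw strings.
    chats = [[{"role": "user", "content": "List two uses of text embeddings."}]]
    print(agent.generate_direct(prompts=chats, max_tokens=128)[0])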