embed-bge-m3/FlagEmbedding/research/Reinforced_IR/inference/multi.py
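
# Shard prompt files across multiple GPUs for parallel LLM generation:
# one worker process per GPU generates completions for its slice of every
# input JSON file, and the per-GPU results are merged back afterwards.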

import os
import json
import shutil
import multiprocessing
from dataclasses import dataclass, field
from transformers import HfArgumentParser


@dataclass
class Args:
    generate_model_path: str = field(
        default='Meta-Llama-3-8B',
        metadata={"help": "Generate model path"}
    )
    generate_model_lora_path: str = field(
        default=None,
        metadata={"help": "Generate model LoRA path"}
    )
    temperature: float = field(
        default=0.8,
        metadata={"help": "Temperature for generation"}
    )
    gpu_memory_utilization: float = field(
        default=0.8,
        metadata={"help": "GPU memory utilization"}
    )
    top_p: float = field(
        default=1.0,
        metadata={"help": "Top p for generation"}
    )
    max_tokens: int = field(
        default=300,
        metadata={"help": "Max tokens for generation"}
    )
    model_type: str = field(
        default='llm_instruct',
        metadata={"help": "LLM model type"}
    )
    input_dir: str = field(
        default=None,
        metadata={"help": "Input directory", "required": True}
    )
    output_dir: str = field(
        default=None,
        metadata={"help": "Output directory", "required": True}
    )
    num_gpus: int = field(
        default=8,
        metadata={"help": "Number of GPUs"}
    )
    rm_tmp: bool = field(
        default=True,
        metadata={"help": "Remove temporary files"}
    )
    start_gpu: int = field(
        default=0,
        metadata={"help": "Index of the first GPU to use"}
    )
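
# HfArgumentParser turns each dataclass field above into a command-line flag
# of the same name (e.g. --input_dir, --num_gpus) and uses the "help"
# metadata for the --help text.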

def worker_function(device):
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]
    # Pin this worker to a single GPU (offset by start_gpu so a subset of the
    # machine's GPUs can be used) before torch and the agent backend are
    # imported, so the model loads on that device only.
    os.environ['CUDA_VISIBLE_DEVICES'] = f'{device + args.start_gpu}'
    import torch
    from agent import LLMInstructAgent, LLMAgent

    print("Available GPUs:", torch.cuda.device_count())
    print("Device:", device)

    num_splits = args.num_gpus
    input_dir = args.input_dir
    # Each worker writes its results into its own temporary split directory.
    output_split_dir = os.path.join(args.output_dir, f'tmp_split_{device}')
    os.makedirs(output_split_dir, exist_ok=True)

    if args.model_type == 'llm':
        llm = LLMAgent(generate_model_path=args.generate_model_path,
                       gpu_memory_utilization=args.gpu_memory_utilization)
    else:
        llm = LLMInstructAgent(generate_model_path=args.generate_model_path,
                               gpu_memory_utilization=args.gpu_memory_utilization)

    print("========================================")
    for file in os.listdir(input_dir):
        if not file.endswith('.json'):
            continue
        input_path = os.path.join(input_dir, file)
        output_path = os.path.join(output_split_dir, file)
        with open(input_path, 'r', encoding='utf-8') as f:
            prompt_list = json.load(f)

        # Take this worker's contiguous slice of the prompts; the last worker
        # absorbs the remainder, e.g. 10 prompts over 4 GPUs give the slices
        # [0:2], [2:5], [5:7], [7:10].
        length = len(prompt_list)
        start = device * length // num_splits
        if device == num_splits - 1:
            end = length
        else:
            end = (device + 1) * length // num_splits
        split_prompt_list = prompt_list[start:end]

        print('----------------------------------------')
        print(f"Processing {input_path} on device {device}, {args.model_type}")
        output_list = llm.generate(split_prompt_list,
                                   temperature=args.temperature,
                                   top_p=args.top_p,
                                   max_tokens=args.max_tokens,
                                   stop=[],
                                   # lora_path=args.generate_model_lora_path
                                   )
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(output_list, f, indent=4)
        print(f"Finished {input_path} on device {device}")

def merge(args: Args):
    num_splits = args.num_gpus
    input_dir = args.input_dir
    output_dir = args.output_dir
    output_split_dir_list = [os.path.join(output_dir, f'tmp_split_{i}')
                             for i in range(num_splits)]
    for file in os.listdir(input_dir):
        if not file.endswith('.json'):
            continue
        # Concatenate the per-GPU shards in worker order so the merged list
        # matches the original prompt order.
        merged_output_list = []
        for output_split_dir in output_split_dir_list:
            output_path = os.path.join(output_split_dir, file)
            with open(output_path, 'r', encoding='utf-8') as f:
                merged_output_list.extend(json.load(f))
        merged_output_path = os.path.join(output_dir, file)
        # If a merged file already exists (e.g. from a previous run), keep its
        # contents by appending them to the new results.
        if os.path.exists(merged_output_path):
            with open(merged_output_path, 'r', encoding='utf-8') as f:
                merged_output_list.extend(json.load(f))
        with open(merged_output_path, 'w', encoding='utf-8') as f:
            json.dump(merged_output_list, f, indent=4)
    if args.rm_tmp:
        for output_split_dir in output_split_dir_list:
            shutil.rmtree(output_split_dir)
            print(f"Removed {output_split_dir}")
    print("Finished merging")

if __name__ == "__main__":
    multiprocessing.set_start_method('spawn')
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]
    # One worker per GPU; each spawned process re-parses the CLI args itself.
    processes = []
    for i in range(args.num_gpus):
        process = multiprocessing.Process(target=worker_function, args=(i,))
        processes.append(process)
        process.start()
    for process in processes:
        process.join()
    merge(args=args)
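
# Illustrative invocation (directory paths below are placeholders, not from
# the original repo):
#   python multi.py \
#       --generate_model_path Meta-Llama-3-8B \
#       --input_dir ./prompt_json_dir \
#       --output_dir ./generation_output \
#       --num_gpus 8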