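"""Run data-parallel LLM generation across multiple GPUs.

Every *.json prompt file in --input_dir is split into one contiguous slice
per GPU; a separate process per GPU generates completions for its slice and
writes them to a temporary per-GPU directory under --output_dir, after which
merge() concatenates the slices back into a single file per input.
"""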
import os
import json
import shutil
import multiprocessing
from typing import Optional

from dataclasses import dataclass, field
from transformers import HfArgumentParser


@dataclass
class Args:
    generate_model_path: str = field(
        default='Meta-Llama-3-8B',
        metadata={"help": "Generate model path"}
    )
    generate_model_lora_path: Optional[str] = field(
        default=None,
        metadata={"help": "Generate model LoRA path"}
    )
    temperature: float = field(
        default=0.8,
        metadata={"help": "Temperature for generation"}
    )
    gpu_memory_utilization: float = field(
        default=0.8,
        metadata={"help": "GPU memory utilization"}
    )
    top_p: float = field(
        default=1.0,
        metadata={"help": "Top p for generation"}
    )
    max_tokens: int = field(
        default=300,
        metadata={"help": "Max tokens for generation"}
    )
    model_type: str = field(
        default='llm_instruct',
        metadata={"help": "LLM model type"}
    )
    input_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Input directory", "required": True}
    )
    output_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Output directory", "required": True}
    )
    num_gpus: int = field(
        default=8,
        metadata={"help": "Number of GPUs"}
    )
    rm_tmp: bool = field(
        default=True,
        metadata={"help": "Remove temporary files"}
    )
    start_gpu: int = field(
        default=0,
        metadata={"help": "Index of the first physical GPU to use"}
    )


def worker_function(device):
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]

    # Map the 0-based split index `device` onto the physical GPU
    # `start_gpu + device`, and bind this process to that GPU alone.
    os.environ['CUDA_VISIBLE_DEVICES'] = f'{args.start_gpu + device}'

    # Deferred imports: torch and the agents must be imported only after
    # CUDA_VISIBLE_DEVICES is set, so the worker sees exactly one GPU.
    import torch
    from agent import LLMInstructAgent, LLMAgent

    print("Available GPUs:", torch.cuda.device_count())
    print("Device:", device)
    num_splits = args.num_gpus

    input_dir = args.input_dir
    output_split_dir = os.path.join(args.output_dir, f'tmp_split_{device}')

    os.makedirs(output_split_dir, exist_ok=True)

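    # `agent` is an external module, not shown here. From how it is used in
    # this script, both classes are assumed to load the model at
    # `generate_model_path` and expose a `generate(prompts, ...)` method
    # returning one output per prompt; judging by the names, `LLMAgent` does
    # plain completion and `LLMInstructAgent` instruction-style prompting.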
    if args.model_type == 'llm':
        llm = LLMAgent(generate_model_path=args.generate_model_path,
                       gpu_memory_utilization=args.gpu_memory_utilization)
    else:
        llm = LLMInstructAgent(generate_model_path=args.generate_model_path,
                               gpu_memory_utilization=args.gpu_memory_utilization)

    print("========================================")
    for file in os.listdir(input_dir):
        if not file.endswith('.json'):
            continue

        input_path = os.path.join(input_dir, file)
        output_path = os.path.join(output_split_dir, file)

        with open(input_path, 'r', encoding='utf-8') as f:
            prompt_list = json.load(f)

        # Each worker takes a contiguous slice of the prompt list; the last
        # worker also absorbs the remainder when the split is uneven.
        length = len(prompt_list)
        start = int(device) * length // num_splits
        if int(device) == num_splits - 1:
            end = length
        else:
            end = (int(device) + 1) * length // num_splits

        split_prompt_list = prompt_list[start:end]

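        # Worked example (hypothetical numbers): with 100 prompts and
        # num_splits = 8, worker 0 takes [0:12], worker 1 takes [12:25],
        # ..., and worker 7 takes [87:100], so the slices tile the list
        # exactly once with no overlap.
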
        print('----------------------------------------')
        print(f"Processing {input_path} on device {device}, {args.model_type}")

        output_list = llm.generate(split_prompt_list,
                                   temperature=args.temperature,
                                   top_p=args.top_p,
                                   max_tokens=args.max_tokens,
                                   stop=[],
                                   # lora_path=args.generate_model_lora_path
                                   )

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(output_list, f, indent=4)

        print(f"Finished {input_path} on device {device}")


def merge(args: Args):
    num_splits = args.num_gpus

    input_dir = args.input_dir
    output_dir = args.output_dir
    output_split_dir_list = [os.path.join(output_dir, f'tmp_split_{i}')
                             for i in range(num_splits)]

    for file in os.listdir(input_dir):
        if not file.endswith('.json'):
            continue

        # Concatenate the per-GPU splits in worker order, so the merged file
        # preserves the original prompt order.
        merged_output_list = []
        for output_split_dir in output_split_dir_list:
            output_path = os.path.join(output_split_dir, file)
            with open(output_path, 'r', encoding='utf-8') as f:
                merged_output_list.extend(json.load(f))

        # If a merged file already exists from an earlier run, append its
        # contents rather than overwriting them.
        merged_output_path = os.path.join(output_dir, file)
        if os.path.exists(merged_output_path):
            with open(merged_output_path, 'r', encoding='utf-8') as f:
                merged_output_list.extend(json.load(f))
        with open(merged_output_path, 'w', encoding='utf-8') as f:
            json.dump(merged_output_list, f, indent=4)

    if args.rm_tmp:
        for output_split_dir in output_split_dir_list:
            shutil.rmtree(output_split_dir)
            print(f"Removed {output_split_dir}")

    print("Finished merging")


if __name__ == "__main__":
    # Use 'spawn' so each worker starts a fresh interpreter; CUDA state
    # does not survive a fork.
    multiprocessing.set_start_method('spawn')
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]

    # One worker process per GPU; each worker re-parses the CLI args itself.
    processes = []
    for i in range(args.num_gpus):
        process = multiprocessing.Process(target=worker_function, args=(i,))
        processes.append(process)
        process.start()

    for process in processes:
        process.join()

    merge(args=args)
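
# Example invocation (a sketch; the script name and paths are illustrative):
#
#   python generate_multi_gpu.py \
#       --generate_model_path Meta-Llama-3-8B \
#       --model_type llm_instruct \
#       --input_dir prompts/ \
#       --output_dir outputs/ \
#       --num_gpus 8
#
# Each worker i writes its slice of every prompts/*.json file to
# outputs/tmp_split_<i>/; merge() then concatenates the eight splits into
# outputs/<file> and, with rm_tmp left True, deletes the tmp_split_* dirs.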