293 lines
12 KiB
Python
293 lines
12 KiB
Python
import argparse
|
|
import csv
|
|
import os
|
|
import warnings
|
|
from datetime import timedelta
|
|
|
|
import pandas as pd
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch.utils.data import Dataset
|
|
from tqdm import tqdm
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
from utils import read_file
|
|
|
|
# Best-effort backup of this script. The previous os.system("cp ...") call
# broke on paths containing spaces and failed silently when ~/backup did not
# exist; use shutil with an explicit guard instead.
import shutil

_backup_dir = os.path.expanduser("~/backup")
try:
    os.makedirs(_backup_dir, exist_ok=True)
    shutil.copy(__file__, _backup_dir)
except OSError:
    pass  # backup is optional by design — never abort the run over it

# Silence library warnings and disable tokenizer thread parallelism, which
# conflicts with fork-based DataLoader workers.
warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# `record` annotates uncaught exceptions with rank info for torchelastic.
from torch.distributed.elastic.multiprocessing.errors import record
|
|
|
|
|
|
class CSVTextDataset(Dataset):
    """Dataset over a CSV file that must contain a ``text`` column.

    Supports manual sharding across distributed ranks via
    :meth:`set_rank_and_world_size`, which slices the underlying frame
    in place so each rank only iterates its own rows.
    """

    def __init__(self, csv_path):
        """Load the CSV and validate that a ``text`` column exists."""
        self.df = pd.read_csv(csv_path)
        assert "text" in self.df.columns, "text column not found in the csv file"

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Reject out-of-range (and negative) indices explicitly; .iloc would
        # otherwise wrap negative indices around to the end of the frame.
        if idx < 0 or idx >= len(self.df):
            raise IndexError
        return self.df.iloc[idx]

    def set_rank_and_world_size(self, rank, world_size):
        """Restrict this dataset to the contiguous shard owned by ``rank``.

        Rows are split evenly; the last rank absorbs the remainder so that
        every row is covered exactly once across all ranks.
        """
        self.rank = rank
        self.world_size = world_size
        self.data_per_gpu = len(self) // world_size
        self.start_index = rank * self.data_per_gpu
        self.end_index = (rank + 1) * self.data_per_gpu if rank != world_size - 1 else len(self)
        self.df = self.df.iloc[self.start_index : self.end_index]

    def write_to_csv(self, output_file, data, new_key):
        """Write this rank's shard to ``output_file`` with ``data`` appended
        as a new column named ``new_key``.

        Fixes over the previous version: the header was built with pandas
        ``Index + [new_key]`` (element-wise string concatenation, not an
        append), ``csv.writer`` has no ``close()`` method, and the file
        handle was never closed. A context manager handles the file now.
        """
        with open(output_file, "w", newline="") as fh:
            writer = csv.writer(fh)
            writer.writerow(list(self.df.columns) + [new_key])
            for index, row in self.df.iterrows():
                # self.df is already sliced to this rank's shard, but the
                # original RangeIndex labels survive slicing, so keep the
                # range check as a safety net.
                if index < self.start_index or index >= self.end_index:
                    continue
                writer.writerow([*row, data[index - self.start_index]])
|
|
|
|
|
|
def pad_left(sequences, padding_value=0):
    """Left-pad a list of 1-D tensors to a common length and stack them.

    Args:
        sequences: list of 1-D tensors, possibly of different lengths.
        padding_value: scalar used to fill the left side of shorter tensors.

    Returns:
        A tensor of shape ``(len(sequences), max_len)`` where each row is the
        corresponding input sequence right-aligned.
    """
    target_len = max(seq.size(0) for seq in sequences)

    def _left_pad_one(seq):
        # Prepend exactly enough fill values to reach target_len.
        fill = torch.full(
            (target_len - seq.size(0),), padding_value, dtype=seq.dtype
        ).to(seq.device)
        return torch.cat([fill, seq], dim=0)

    return torch.stack([_left_pad_one(seq) for seq in sequences])
|
|
|
|
|
|
@record
def main(args):
    """Distributed extraction of objects/actions from a CSV's ``text`` column.

    Each rank processes its contiguous shard of rows with an instruction-tuned
    LLM and writes a per-rank CSV under a ``parts/`` subdirectory; rank 0
    concatenates all shards into a single collated file at the end.
    """
    # ======================================================
    # 1. init environment
    # ======================================================
    # Generous timeout: ranks finish at very different times and only meet
    # again at the final barrier.
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # ======================================================
    # 2. Prep rank-wise dataloader
    # ======================================================
    dataframe = read_file(args.input)  # full frame; used only for the header columns
    print("read data from {}".format(args.input))
    dataset = CSVTextDataset(args.input)
    dataset.set_rank_and_world_size(dist.get_rank(), dist.get_world_size())

    # Optional remote-debugger hook: attach ptvsd on rank 2 when
    # DEBUG_ADDRESS is set in the environment.
    if os.getenv("DEBUG_ADDRESS") is not None and dist.get_rank() == 2:
        import ptvsd

        print("waiting for debugger attachment")
        ptvsd.enable_attach(address=("localhost", int(os.getenv("DEBUG_ADDRESS"))), redirect_output=True)
        ptvsd.wait_for_attach()

    # Per-rank shard files go into a "parts/" subdirectory next to the input;
    # the collated file keeps the input's original directory.
    output_prefix = os.path.splitext(args.input)[0]
    # Initialize up front so the rank-0 collation below cannot hit a
    # NameError when the input path contains no '/'.
    output_prefix_total = output_prefix
    last_slash_index = output_prefix.rfind("/")
    if last_slash_index != -1:
        output_prefix = output_prefix[: last_slash_index + 1] + "parts" + "/" + output_prefix[last_slash_index + 1 :]

    output_file = output_prefix + "_rank{}".format(dist.get_rank()) + "_objects_actions.csv"
    # Make sure the parts directory exists before opening the shard file.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    output_file_handle = open(output_file, "w")
    writer = csv.writer(output_file_handle)
    columns = list(dataframe.columns) + ["objects", "actions"]
    writer.writerow(columns)

    print("the processed data saved on this rank will be saved to {}".format(output_file))

    def collate_fn(batch):
        # Keep the raw pandas rows; tokenization happens later per batch.
        return batch

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        shuffle=False,
    )

    # ======================================================
    # 3. process using llama3 and prompt
    # ======================================================
    print("Using model with the id {}".format(args.model_id))
    model_id = args.model_id
    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        # one model replica per GPU, mapped by local rank
        device_map=dist.get_rank() % torch.cuda.device_count(),
    )
    dist.barrier()
    print("======== Process data using LLAMA3 ========")

    def extract_batch(texts, prompt):
        """Run ``prompt`` as the system message over ``texts``; return the
        decoded completions, one per input text."""
        input_ids_list = [
            tokenizer.apply_chat_template(
                [{"role": "system", "content": prompt}, {"role": "user", "content": text}],
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(model.device)[0]
            for text in texts
        ]

        attention_mask_list = [
            torch.ones(input_ids.shape, dtype=torch.long, device=model.device) for input_ids in input_ids_list
        ]

        # NOTE(review): pad_sequence pads on the RIGHT even though the
        # tokenizer was created with padding_side="left" and a left-padding
        # helper (pad_left) exists but is unused — confirm which is intended.
        input_ids_batch = torch.nn.utils.rnn.pad_sequence(
            input_ids_list, batch_first=True, padding_value=tokenizer.eos_token_id
        )
        attention_mask_batch = torch.nn.utils.rnn.pad_sequence(attention_mask_list, batch_first=True, padding_value=0)

        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        outputs = model.generate(
            input_ids_batch,
            max_new_tokens=512,
            attention_mask=attention_mask_batch,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=terminators,
        )

        responses = []
        for i in range(len(texts)):
            # Strip the prompt tokens; decode only the generated tail.
            response = outputs[i][input_ids_list[i].shape[-1] :]
            responses.append(tokenizer.decode(response, skip_special_tokens=True))
        return responses

    print("Processing starting...")
    prompt_objects = (
        "You are a AI assistant to extract objects from user's text. "
        "For example: user: 'In this video a dog is running around. In addition, a person is laughing at the dog.', you produce a list of objects separated by ',' and wrapped by '[' and ']': '[dog, person]' "
    )
    prompt_actions = (
        "You are a AI assistant to extract actions from user's text. "
        "For example: user: 'In this video a dog is running around. In addition, a person is laughing at the dog.', you produce a list of actions separated by ',' and wrapped by '[' and ']': '[run, laugh]' "
    )

    print("prompt_objects: {}".format(prompt_objects))
    print("prompt_actions: {}".format(prompt_actions))

    for _, batch in enumerate(tqdm(dataloader)):
        # get the text column from the batch
        texts = [batch[i]["text"] for i in range(len(batch))]
        # Extract keywords using both prompts
        list_actions = extract_batch(texts, prompt_actions)
        list_objects = extract_batch(texts, prompt_objects)

        for idx, (actions, objects) in enumerate(zip(list_actions, list_objects)):
            try:
                # Parse "[a, b, c]" (optionally single-quoted) into a list.
                if actions.startswith("'") and actions.endswith("'"):
                    actions = actions[1:-1]
                actions_start = actions.find("[")
                actions_end = actions.find("]")
                actions = actions[actions_start + 1 : actions_end]
                actions = [act.strip() for act in actions.split(",") if act.strip()]

                if objects.startswith("'") and objects.endswith("'"):
                    objects = objects[1:-1]
                objects_start = objects.find("[")
                objects_end = objects.find("]")
                objects = objects[objects_start + 1 : objects_end]
                objects = [obj.strip() for obj in objects.split(",") if obj.strip()]
            except Exception:  # was a bare except:; keep the best-effort fallback
                actions = "NONE_FOUND"
                objects = "NONE_FOUND"

            row = batch[idx]
            writer.writerow([*row, objects, actions])

    output_file_handle.close()
    dist.barrier()

    if dist.get_rank() == 0:
        collated_file = output_prefix_total + "_objects_actions.csv"
        print("All ranks are finished. Collating the processed data to {}".format(collated_file))

        csv_files = [
            output_prefix + "_rank{}".format(i) + "_objects_actions.csv" for i in range(dist.get_world_size())
        ]
        dataframes = []
        for file in csv_files:
            df = pd.read_csv(file)
            # Rows where generation produced nothing parse as NaN; normalize
            # them to the same sentinel used during extraction.
            df["objects"] = df["objects"].fillna("NONE_FOUND")
            df["actions"] = df["actions"].fillna("NONE_FOUND")
            dataframes.append(df)
        combined_df = pd.concat(dataframes, ignore_index=True)
        combined_df.to_csv(collated_file, index=False)
        print("Collated data saved to {}".format(collated_file))

    # terminate distributed env
    dist.destroy_process_group()
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI: positional input CSV plus model/batch options.
    cli = argparse.ArgumentParser()
    cli.add_argument("input", type=str, help="Path to the input CSV file")
    cli.add_argument("--model-id", default="meta-llama/Meta-Llama-3-8B-Instruct")
    cli.add_argument("--batch_size", type=int, default=32)
    main(cli.parse_args())
|