# Code adapted from https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/playground/demo/video_demo.py
import argparse
import base64
import csv
import math
import os
import warnings

import cv2
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from decord import VideoReader, cpu
from llava.constants import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import get_model_name_from_path, tokenizer_image_token
from llava.model.builder import load_pretrained_model
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoConfig, BitsAndBytesConfig

warnings.filterwarnings("ignore")

# Fallback pad token id for Qwen-based checkpoints whose tokenizer ships
# without a pad token (151643 is Qwen2's "<|endoftext|>").
PAD_TOKEN_ID = 151643
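# Example launch (script name, paths, and model/conv-mode values below are
# illustrative, not prescribed by this file). The input CSV is expected to
# have a "path" column pointing at video files:
#
#   python run_video_inference.py \
#       --data_file /path/to/videos.csv \
#       --output_folder /path/to/outputs \
#       --model-path lmms-lab/LLaVA-NeXT-Video-7B-Qwen2 \
#       --conv-mode qwen_1_5 \
#       --for_get_frames_num 32
#
# --conv-mode must match the checkpoint's conversation template.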
# Function to initialize the distributed environment
def setup(rank, world_size):
    print(f"Setting up process {rank} of {world_size}")
    # Single-node setup: the rendezvous address and port are hardcoded.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    torch.cuda.set_device(rank)
    # Initialize the process group for communication
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)


# Cleanup after inference is done
def cleanup():
    dist.destroy_process_group()
class VideoDataset:
    """Shards the rows of the input dataframe across ranks and prepares model inputs."""

    def __init__(self, df, args, rank, world_size, image_processor, model, tokenizer):
        self.df = df
        self.rank = rank
        self.world_size = world_size
        self.args = args
        self.image_processor = image_processor
        self.model = model
        self.tokenizer = tokenizer
        # Each rank processes rows rank, rank + world_size, rank + 2 * world_size, ...
        self.length = len(range(self.rank, len(self.df), self.world_size))

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        row = self.df.iloc[self.rank + idx * self.world_size]
        video_path = row["path"]
        # Keep the row as a plain list so the generated text can be appended later.
        info = row.tolist()
        sample_set = {}
        question = self.args.prompt
        sample_set["video_name"] = video_path

        # Fail fast on missing files: collate_fn below assumes every sample
        # carries the video and timing keys.
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"Video not found: {video_path}")
        video, frame_time, video_time = load_video(video_path, self.args)
        video = self.image_processor.preprocess(video, return_tensors="pt")["pixel_values"].half()
        sample_set["video"] = video
        sample_set["frame_time"] = frame_time
        sample_set["video_time"] = video_time
        sample_set["info"] = info

        if self.args.add_time_instruction:
            time_instruction = (
                f"The video lasts for {video_time:.2f} seconds, and "
                f"{self.args.for_get_frames_num} frames are uniformly sampled from it. "
                f"These frames are located at {frame_time}. "
                f"Please answer the following questions related to this video."
            )
            qs = f"{time_instruction}\n{question}"
        else:
            qs = question
        if self.model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

        conv = conv_templates[self.args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0)
        sample_set["input_ids"] = input_ids
        attention_masks = input_ids.ne(self.tokenizer.pad_token_id).long()
        sample_set["attention_masks"] = attention_masks

        return sample_set
def collate_fn(batch):
    # Collate a batch: left-pad input_ids and attention masks to the longest
    # sequence, and gather the per-sample metadata into lists.
    videos = [item["video"] for item in batch if "video" in item]
    input_ids = [item["input_ids"] for item in batch]
    max_len = max(item.shape[1] for item in input_ids)
    # Left-pad with PAD_TOKEN_ID so all sequences end at the same position.
    input_ids = [torch.nn.functional.pad(item, (max_len - item.shape[1], 0), value=PAD_TOKEN_ID) for item in input_ids]
    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = [item["attention_masks"] for item in batch]
    attention_masks = [torch.nn.functional.pad(item, (max_len - item.shape[1], 0), value=0) for item in attention_masks]
    attention_masks = torch.cat(attention_masks, dim=0)

    video_names = [item["video_name"] for item in batch]
    frame_times = [item["frame_time"] for item in batch]
    video_times = [item["video_time"] for item in batch]
    infos = [item["info"] for item in batch]

    return {
        "input_ids": input_ids,
        "attention_masks": attention_masks,
        "videos": videos,
        "video_names": video_names,
        "frame_times": frame_times,
        "video_times": video_times,
        "infos": infos,
    }


def create_dataloader(df, args, rank, world_size, image_processor, model, tokenizer):
    dataset = VideoDataset(df, args, rank, world_size, image_processor, model, tokenizer)
    return DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=args.num_workers
    )
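# Illustrative collate_fn behavior (the shapes are hypothetical): two samples
# with input_ids of shapes (1, 5) and (1, 8) are each left-padded to length 8
# with PAD_TOKEN_ID and concatenated into a (2, 8) batch; the attention masks
# are 0 over the padded prefix, which is what left-padded decoder-only
# generation expects.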
def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks."""
    chunk_size = math.ceil(len(lst) / n)  # ceiling division, so no element is dropped
    return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return the k-th of n chunks of lst."""
    chunks = split_list(lst, n)
    return chunks[k]
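# Note on the helpers above: with ceiling-sized chunks, split_list can return
# fewer than n chunks, e.g. split_list(list(range(5)), 4) == [[0, 1], [2, 3], [4]]
# (3 chunks, not 4). Neither helper is referenced elsewhere in this script;
# they appear to be carried over from the upstream demo.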
def parse_args():
    """
    Parse command-line arguments.
    """
    parser = argparse.ArgumentParser()

    # Define the command-line arguments
    parser.add_argument("--data_file", help="Path to the input CSV of videos.", required=True)
    parser.add_argument("--output_folder", help="Directory where per-rank output CSVs are written.", required=True)
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--mm_resampler_type", type=str, default="spatial_pool")
    parser.add_argument("--mm_spatial_pool_stride", type=int, default=4)
    parser.add_argument("--mm_spatial_pool_out_channels", type=int, default=1024)
    parser.add_argument("--mm_spatial_pool_mode", type=str, default="average")
    parser.add_argument("--image_aspect_ratio", type=str, default="anyres")
    parser.add_argument(
        "--image_grid_pinpoints",
        type=str,
        default="[(224, 448), (224, 672), (224, 896), (448, 448), (448, 224), (672, 224), (896, 224)]",
    )
    parser.add_argument("--mm_patch_merge_type", type=str, default="spatial_unpad")
    # Boolean flags are parsed from "true"/"false" strings; a plain type=bool
    # would treat any non-empty string (including "false") as True.
    parser.add_argument("--overwrite", type=lambda x: (str(x).lower() == "true"), default=True)
    parser.add_argument("--for_get_frames_num", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=10)
    parser.add_argument("--load_8bit", type=lambda x: (str(x).lower() == "true"), default=False)
    parser.add_argument("--prompt", type=str, default=None)
    parser.add_argument("--api_key", type=str, help="OpenAI API key")
    parser.add_argument("--mm_newline_position", type=str, default="no_token")
    parser.add_argument("--force_sample", type=lambda x: (str(x).lower() == "true"), default=False)
    parser.add_argument("--add_time_instruction", type=lambda x: (str(x).lower() == "true"), default=False)
    return parser.parse_args()
def load_video(video_path, args):
    if args.for_get_frames_num == 0:
        # Return a dummy frame plus empty metadata so callers can always unpack three values.
        return np.zeros((1, 336, 336, 3)), "", 0
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    total_frame_num = len(vr)
    video_time = total_frame_num / vr.get_avg_fps()
    fps = round(vr.get_avg_fps())
    # Default sampling: roughly one frame per second.
    frame_idx = [i for i in range(0, len(vr), fps)]
    frame_time = [i / fps for i in frame_idx]
    if len(frame_idx) > args.for_get_frames_num or args.force_sample:
        # Resample: take exactly for_get_frames_num frames, uniformly spaced.
        sample_fps = args.for_get_frames_num
        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
        frame_idx = uniform_sampled_frames.tolist()
        frame_time = [i / vr.get_avg_fps() for i in frame_idx]
    # Always render timestamps as a single string, e.g. "0.00s,3.30s,...".
    frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
    spare_frames = vr.get_batch(frame_idx).asnumpy()

    return spare_frames, frame_time, video_time
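# Worked example for load_video (hypothetical input): for a 10 s video at
# 30 fps (300 frames) with for_get_frames_num=4 and force_sample=True,
# np.linspace(0, 299, 4, dtype=int) gives frame_idx == [0, 99, 199, 299],
# so frame_time == "0.00s,3.30s,6.63s,9.97s" and video_time == 10.0.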
def load_video_base64(path):
    video = cv2.VideoCapture(path)

    base64Frames = []
    while video.isOpened():
        success, frame = video.read()
        if not success:
            break
        _, buffer = cv2.imencode(".jpg", frame)
        base64Frames.append(base64.b64encode(buffer).decode("utf-8"))

    video.release()
    return base64Frames
def run_inference(rank, world_size, args):
    """
    Run video inference over this rank's shard of the input CSV and write the
    results to a per-rank output CSV.

    Args:
        rank: Index of this process (one per GPU).
        world_size: Total number of processes.
        args: Command-line arguments.
    """
    setup(rank, world_size)

    device = torch.device(f"cuda:{rank}")
    # Initialize the model
    model_name = get_model_name_from_path(args.model_path)
    # Set model configuration parameters if they exist
    if args.overwrite:
        overwrite_config = {}
        overwrite_config["mm_spatial_pool_mode"] = args.mm_spatial_pool_mode
        overwrite_config["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
        overwrite_config["mm_newline_position"] = args.mm_newline_position

        cfg_pretrained = AutoConfig.from_pretrained(args.model_path)

        if "qwen" not in args.model_path.lower():
            if "224" in cfg_pretrained.mm_vision_tower:
                # suppose the length of text tokens is around 1000, from bo's report
                least_token_number = args.for_get_frames_num * (16 // args.mm_spatial_pool_stride) ** 2 + 1000
            else:
                least_token_number = args.for_get_frames_num * (24 // args.mm_spatial_pool_stride) ** 2 + 1000

            scaling_factor = math.ceil(least_token_number / 4096)
            if scaling_factor >= 2:
                if "vicuna" in cfg_pretrained._name_or_path.lower():
                    print(float(scaling_factor))
                    overwrite_config["rope_scaling"] = {"factor": float(scaling_factor), "type": "linear"}
                overwrite_config["max_sequence_length"] = 4096 * scaling_factor
                overwrite_config["tokenizer_model_max_length"] = 4096 * scaling_factor

        if args.load_8bit:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True, bnb_8bit_compute_dtype=torch.bfloat16)
            tokenizer, model, image_processor, context_len = load_pretrained_model(
                args.model_path,
                args.model_base,
                model_name,
                device_map=device,
                quantization_config=quantization_config,
                overwrite_config=overwrite_config,
            )
        else:
            tokenizer, model, image_processor, context_len = load_pretrained_model(
                args.model_path, args.model_base, model_name, device_map=device, overwrite_config=overwrite_config
            )
    else:
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            args.model_path, args.model_base, model_name, device_map=device
        )

    if tokenizer.pad_token_id is None:
        if "qwen" in tokenizer.name_or_path.lower():
            # Qwen tokenizers ship without a pad token; fall back to PAD_TOKEN_ID.
            tokenizer.pad_token_id = PAD_TOKEN_ID

    if args.batch_size > 1:
        # Decoder-only models need left padding for batched generation.
        tokenizer.padding_side = "left"
        model.config.tokenizer_padding_side = "left"

    # The model config, when present, takes precedence over the command-line flags.
    args.force_sample = getattr(model.config, "force_sample", None) or False
    args.add_time_instruction = getattr(model.config, "add_time_instruction", None) or False

    df = pd.read_csv(args.data_file)
    data_name = os.path.splitext(os.path.basename(args.data_file))[0]
    column_names = df.columns.to_list()
    if "text" not in column_names:
        column_names.append("text")
    text_column_index = column_names.index("text")

    # Each rank writes its own shard of the results.
    output_file = os.path.join(args.output_folder, f"{data_name}_{rank}.csv")
    with open(output_file, "w", newline="") as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(column_names)

        dataloader = create_dataloader(df, args, rank, world_size, image_processor, model, tokenizer)
        for batch in tqdm(dataloader):
            videos = [item.to(device) for item in batch["videos"]]
            input_ids = batch["input_ids"].to(device)
            attention_masks = batch["attention_masks"].to(device)
            infos = batch["infos"]
            stop_str = "###"
            with torch.inference_mode():
                modalities = ["video"] * len(videos)
                # Greedy decoding: temperature and top_p have no effect when
                # do_sample=False.
                output_ids = model.generate(
                    inputs=input_ids,
                    images=videos,
                    attention_mask=attention_masks,
                    modalities=modalities,
                    do_sample=False,
                    temperature=0.0,
                    max_new_tokens=1024,
                    top_p=0.1,
                    num_beams=1,
                    use_cache=True,
                )

            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            # Truncate each response at the stop string and strip surrounding whitespace.
            outputs = [output.split(stop_str)[0].strip() for output in outputs]
            if len(infos[0]) < len(column_names):
                # The input rows had no "text" column yet: append the generated text.
                for i in range(len(infos)):
                    infos[i].append(outputs[i])
            else:
                # Overwrite the existing "text" column with the generated text.
                for i in range(len(infos)):
                    infos[i][text_column_index] = outputs[i]

            # write to csv
            for row in infos:
                csvwriter.writerow(row)

    cleanup()
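# Post-processing sketch: each rank writes its own CSV shard, so the shards
# need to be merged afterwards. This helper is illustrative and not part of
# the original script; it assumes the f"{data_name}_{rank}.csv" naming used
# in run_inference above.
def merge_rank_outputs(output_folder, data_name):
    """Concatenate the per-rank output shards into one CSV (a minimal sketch)."""
    import glob

    shards = sorted(glob.glob(os.path.join(output_folder, f"{data_name}_[0-9]*.csv")))
    merged = pd.concat([pd.read_csv(s) for s in shards], ignore_index=True)
    merged.to_csv(os.path.join(output_folder, f"{data_name}_merged.csv"), index=False)
    return merged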
if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running inference on {world_size} GPUs.")
    args = parse_args()

    # Spawn one process per GPU
    mp.spawn(run_inference, args=(world_size, args), nprocs=world_size, join=True)