# mysora/tools/caption/caption_llava_next.py
# code modified based on https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/playground/demo/video_demo.py
import argparse
import base64
import csv
import math
import os
import warnings
import cv2
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from decord import VideoReader, cpu
from llava.constants import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import get_model_name_from_path, tokenizer_image_token
from llava.model.builder import load_pretrained_model
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoConfig, BitsAndBytesConfig
warnings.filterwarnings("ignore")
PAD_TOKEN_ID = 151643
# Function to initialize the distributed environment
def setup(rank, world_size):
    """Bind this process to GPU *rank* and join the NCCL process group."""
    print(f"Setting up process {rank} of {world_size}")
    # Single-node rendezvous: all ranks meet on localhost:12355.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    torch.cuda.set_device(rank)
    # Initialize the process group for communication
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
# Cleanup after inference is done
def cleanup():
    """Tear down the distributed process group once inference finishes."""
    dist.destroy_process_group()
class VideoDataset:
    """Map-style dataset yielding one prompt + preprocessed video per row of this rank's shard.

    Rows are strided across ranks: rank r owns df rows r, r + world_size, r + 2*world_size, ...
    """

    def __init__(self, df, args, rank, world_size, image_processor, model, tokenizer):
        self.df = df
        self.rank = rank
        self.world_size = world_size
        self.args = args
        self.image_processor = image_processor
        self.model = model
        self.tokenizer = tokenizer
        # Number of rows this rank is responsible for.
        self.length = len(range(self.rank, len(self.df), self.world_size))

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        row = self.df.iloc[self.rank + idx * self.world_size]
        video_path = row["path"]
        info = row
        question = self.args.prompt
        sample_set = {"video_name": video_path}
        if os.path.exists(video_path):
            frames, frame_time, video_time = load_video(video_path, self.args)
            frames = self.image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].half()
            sample_set["video"] = frames
            sample_set["frame_time"] = frame_time
            sample_set["video_time"] = video_time
        sample_set["info"] = info
        # NOTE(review): when the path is missing, frame_time/video_time are undefined below
        # and collate_fn expects the keys set above — presumably inputs are pre-filtered to
        # existing files; confirm upstream.
        if self.args.add_time_instruction:
            time_instruction = (
                f"The video lasts for {video_time:.2f} seconds, and "
                f"{self.args.for_get_frames_num} frames are uniformly sampled from it. "
                f"These frames are located at {frame_time}. "
                f"Please answer the following questions related to this video."
            )
            qs = f"{time_instruction}\n{question}"
        else:
            qs = question
        # Prepend the image placeholder token(s) expected by the model.
        if self.model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
        conv = conv_templates[self.args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        input_ids = tokenizer_image_token(
            prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        ).unsqueeze(0)
        sample_set["input_ids"] = input_ids
        # Real tokens get mask 1; pad positions get 0.
        sample_set["attention_masks"] = input_ids.ne(self.tokenizer.pad_token_id).long()
        return sample_set
def collate_fn(batch):
    """Left-pad token ids and attention masks to the batch maximum and gather per-sample fields.

    Samples without a "video" key (missing file) contribute no entry to "videos".
    """
    videos = [sample["video"] for sample in batch if "video" in sample]
    ids_list = [sample["input_ids"] for sample in batch]
    masks_list = [sample["attention_masks"] for sample in batch]
    longest = max(t.shape[1] for t in ids_list)
    # Pad on the LEFT (pad width goes before the sequence) so generation stays
    # right-aligned; pad token id is PAD_TOKEN_ID, mask positions get 0.
    padded_ids = torch.cat(
        [torch.nn.functional.pad(t, (longest - t.shape[1], 0), value=PAD_TOKEN_ID) for t in ids_list],
        dim=0,
    )
    padded_masks = torch.cat(
        [torch.nn.functional.pad(t, (longest - t.shape[1], 0), value=0) for t in masks_list],
        dim=0,
    )
    return {
        "input_ids": padded_ids,
        "attention_masks": padded_masks,
        "videos": videos,
        "video_names": [sample["video_name"] for sample in batch],
        "frame_times": [sample["frame_time"] for sample in batch],
        "video_times": [sample["video_time"] for sample in batch],
        "infos": [sample["info"] for sample in batch],
    }
def create_dataloader(df, args, rank, world_size, image_processor, model, tokenizer):
    """Build a non-shuffling DataLoader over this rank's shard of *df*."""
    dataset = VideoDataset(df, args, rank, world_size, image_processor, model, tokenizer)
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=args.num_workers,
    )
    return loader
def split_list(lst, n):
    """Split *lst* into at most n chunks of size ceil(len(lst) / n).

    Fix: an empty list previously produced chunk_size 0 and range(0, 0, 0),
    which raises ValueError (step must not be zero); now returns [].
    """
    if not lst:
        return []
    chunk_size = math.ceil(len(lst) / n)  # ceiling, so the last chunk may be short
    return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
def get_chunk(lst, n, k):
    """Return the k-th (0-based) of the n chunks produced by split_list."""
    return split_list(lst, n)[k]
def parse_args():
    """
    Parse command-line arguments.

    Fix: --add_time_instruction previously used type=str with default False, so
    passing any value (even "False") produced a truthy string; it now uses the
    same string-to-bool parser as --overwrite / --load_8bit / --force_sample.
    """
    parser = argparse.ArgumentParser()
    # Define the command-line arguments
    parser.add_argument("--data_file", help="Path to the video dataset file.", required=True)
    parser.add_argument("--output_folder", help="Path to the output file.", required=True)
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--mm_resampler_type", type=str, default="spatial_pool")
    parser.add_argument("--mm_spatial_pool_stride", type=int, default=4)
    parser.add_argument("--mm_spatial_pool_out_channels", type=int, default=1024)
    parser.add_argument("--mm_spatial_pool_mode", type=str, default="average")
    parser.add_argument("--image_aspect_ratio", type=str, default="anyres")
    parser.add_argument(
        "--image_grid_pinpoints",
        type=str,
        default="[(224, 448), (224, 672), (224, 896), (448, 448), (448, 224), (672, 224), (896, 224)]",
    )
    parser.add_argument("--mm_patch_merge_type", type=str, default="spatial_unpad")
    parser.add_argument("--overwrite", type=lambda x: (str(x).lower() == "true"), default=True)
    parser.add_argument("--for_get_frames_num", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=10)
    parser.add_argument("--load_8bit", type=lambda x: (str(x).lower() == "true"), default=False)
    parser.add_argument("--prompt", type=str, default=None)
    parser.add_argument("--api_key", type=str, help="OpenAI API key")
    parser.add_argument("--mm_newline_position", type=str, default="no_token")
    parser.add_argument("--force_sample", type=lambda x: (str(x).lower() == "true"), default=False)
    parser.add_argument("--add_time_instruction", type=lambda x: (str(x).lower() == "true"), default=False)
    return parser.parse_args()
def load_video(video_path, args):
    """Sample frames from *video_path*.

    Returns a tuple (frames, frame_time, video_time):
      frames     — uint8 array of shape (num_frames, H, W, 3)
      frame_time — comma-joined timestamps like "0.00s,1.50s,..."
      video_time — total duration in seconds

    Fix: the for_get_frames_num == 0 short-circuit previously returned a bare
    array, but every caller unpacks three values — it now returns a matching
    (frames, "", 0.0) tuple.
    """
    if args.for_get_frames_num == 0:
        return np.zeros((1, 336, 336, 3)), "", 0.0
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    total_frame_num = len(vr)
    avg_fps = vr.get_avg_fps()
    video_time = total_frame_num / avg_fps
    # Default: one frame per (rounded) second of video.
    fps = round(avg_fps)
    frame_idx = [i for i in range(0, len(vr), fps)]
    frame_time = [i / fps for i in frame_idx]
    if len(frame_idx) > args.for_get_frames_num or args.force_sample:
        # Too many frames (or forced): resample uniformly to exactly
        # for_get_frames_num frames across the whole clip.
        sample_fps = args.for_get_frames_num
        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
        frame_idx = uniform_sampled_frames.tolist()
        frame_time = [i / avg_fps for i in frame_idx]
    frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
    spare_frames = vr.get_batch(frame_idx).asnumpy()
    return spare_frames, frame_time, video_time
def load_video_base64(path):
    """Read every frame of the video at *path* and return them as base64-encoded JPEG strings."""
    capture = cv2.VideoCapture(path)
    encoded_frames = []
    while capture.isOpened():
        ok, frame = capture.read()
        if not ok:
            # End of stream (or read failure) — stop decoding.
            break
        _, jpeg = cv2.imencode(".jpg", frame)
        encoded_frames.append(base64.b64encode(jpeg).decode("utf-8"))
    capture.release()
    return encoded_frames
def run_inference(rank, world_size, args):
    """
    Caption every video in this rank's shard of args.data_file and write a CSV shard.

    Each rank loads the model on its own GPU, processes rows rank, rank + world_size, ...
    and writes <output_folder>/<data_name>_<rank>.csv with the generated caption in the
    "text" column.

    Args:
        rank: GPU / process index.
        world_size: total number of processes.
        args: parsed command-line arguments.

    Fix: the original had two byte-identical model.generate(...) branches keyed on
    '"mistral" in cfg_pretrained._name_or_path' — and cfg_pretrained is only defined
    when --overwrite is true, so that check was a latent NameError. A single call
    covers both cases.
    """
    setup(rank, world_size)
    device = torch.device(f"cuda:{rank}")
    # Initialize the model
    model_name = get_model_name_from_path(args.model_path)
    # Set model configuration parameters if they exist
    if args.overwrite:
        overwrite_config = {}
        overwrite_config["mm_spatial_pool_mode"] = args.mm_spatial_pool_mode
        overwrite_config["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
        overwrite_config["mm_newline_position"] = args.mm_newline_position
        cfg_pretrained = AutoConfig.from_pretrained(args.model_path)
        if "qwen" not in args.model_path.lower():
            # Estimate needed context: visual tokens per frame after pooling,
            # plus ~1000 text tokens (per bo's report).
            if "224" in cfg_pretrained.mm_vision_tower:
                least_token_number = args.for_get_frames_num * (16 // args.mm_spatial_pool_stride) ** 2 + 1000
            else:
                least_token_number = args.for_get_frames_num * (24 // args.mm_spatial_pool_stride) ** 2 + 1000
            scaling_factor = math.ceil(least_token_number / 4096)
            if scaling_factor >= 2:
                if "vicuna" in cfg_pretrained._name_or_path.lower():
                    print(float(scaling_factor))
                    overwrite_config["rope_scaling"] = {"factor": float(scaling_factor), "type": "linear"}
                overwrite_config["max_sequence_length"] = 4096 * scaling_factor
                overwrite_config["tokenizer_model_max_length"] = 4096 * scaling_factor
        if args.load_8bit:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True, bnb_8bit_compute_dtype=torch.bfloat16)
            tokenizer, model, image_processor, context_len = load_pretrained_model(
                args.model_path,
                args.model_base,
                model_name,
                device_map=device,
                quantization_config=quantization_config,
                overwrite_config=overwrite_config,
            )
        else:
            tokenizer, model, image_processor, context_len = load_pretrained_model(
                args.model_path, args.model_base, model_name, device_map=device, overwrite_config=overwrite_config
            )
    else:
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            args.model_path, args.model_base, model_name, device_map=device
        )
    if tokenizer.pad_token_id is None:
        if "qwen" in tokenizer.name_or_path.lower():
            # Qwen tokenizers ship without a pad token; use the known endoftext id.
            tokenizer.pad_token_id = PAD_TOKEN_ID
    if args.batch_size > 1:
        # Batched generation requires left padding (collate_fn pads on the left).
        tokenizer.padding_side = "left"
        model.config.tokenizer_padding_side = "left"
    # Model-config overrides take precedence over the CLI flags.
    if getattr(model.config, "force_sample", None) is not None:
        args.force_sample = model.config.force_sample
    else:
        args.force_sample = False
    if getattr(model.config, "add_time_instruction", None) is not None:
        args.add_time_instruction = model.config.add_time_instruction
    else:
        args.add_time_instruction = False
    df = pd.read_csv(args.data_file)
    data_name = os.path.basename(args.data_file).split(".csv")[0]
    column_names = df.columns.to_list()
    if "text" not in column_names:
        column_names.append("text")
    text_column_index = column_names.index("text")
    output_file = os.path.join(args.output_folder, f"{data_name}_{rank}.csv")
    with open(output_file, "w", newline="") as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(column_names)
        dataloader = create_dataloader(df, args, rank, world_size, image_processor, model, tokenizer)
        for batch in tqdm(dataloader):
            videos = [item.to(device) for item in batch["videos"]]
            input_ids = batch["input_ids"].to(device)
            attention_masks = batch["attention_masks"].to(device)
            infos = batch["infos"]
            stop_str = "###"
            with torch.inference_mode():
                modalities = ["video"] * len(videos)
                # Greedy decoding (do_sample=False); temperature/top_p are inert here.
                output_ids = model.generate(
                    inputs=input_ids,
                    images=videos,
                    attention_mask=attention_masks,
                    modalities=modalities,
                    do_sample=False,
                    temperature=0.0,
                    max_new_tokens=1024,
                    top_p=0.1,
                    num_beams=1,
                    use_cache=True,
                )
            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            # Truncate at the stop marker and strip whitespace.
            outputs = [output.split(stop_str)[0].strip() for output in outputs]
            # NOTE(review): infos entries look like pandas Series from __getitem__;
            # Series.append was removed in pandas 2.x — confirm the pinned pandas
            # version supports this append/assign pattern.
            if len(infos[0]) < len(column_names):
                for i in range(len(infos)):
                    infos[i].append(outputs[i])
            else:
                for i in range(len(infos)):
                    infos[i][text_column_index] = outputs[i]
            # write to csv
            for row in infos:
                csvwriter.writerow(row)
    cleanup()
if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running inference on {world_size} GPUs.")
    args = parse_args()
    # One worker process per visible GPU; each writes its own CSV shard.
    mp.spawn(run_inference, args=(world_size, args), nprocs=world_size, join=True)