# embed-bge-m3/FlagEmbedding/research/MLVU/evaluation/models/videochat2/open_bench.py
# Evaluate VideoChat2 (Vicuna, stage-3 LoRA) on the open-ended MLVU tasks: sub-plot and summary.

from utils.config import Config
config_file = "configs/config.json"
cfg = Config.from_file(config_file)
import os
import io
import json
from models import VideoChat2_it_vicuna
from utils.easydict import EasyDict
import torch
from transformers import StoppingCriteria, StoppingCriteriaList
from PIL import Image
import numpy as np
from decord import VideoReader, cpu
import torchvision.transforms as T
from dataset.video_transforms import (
GroupNormalize, GroupScale, GroupCenterCrop,
Stack, ToTorchFormatTensor
)
from torchvision.transforms.functional import InterpolationMode
from torch.utils.data import Dataset
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
from IPython.display import Video, HTML
from peft import get_peft_model, LoraConfig, TaskType
import copy
# load stage2 model
cfg.model.vision_encoder.num_frames = 4
model = VideoChat2_it_vicuna(config=cfg.model)
model = model.to(torch.device(cfg.device))
model = model.eval().cuda()
# add lora to run stage3 model
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, inference_mode=False,
r=16, lora_alpha=32, lora_dropout=0.
)
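# NOTE: r=16 / lora_alpha=32 are assumed to match the settings used when videochat2_7b_stage3.pth
# was trained; because load_state_dict is called with strict=False below, missing or unexpected
# LoRA keys are only reported through `msg`, so inspect that value when debugging a bad load.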
model.llama_model = get_peft_model(model.llama_model, peft_config)
state_dict = torch.load("checkpoints/videochat2_7b_stage3.pth", "cpu")
if 'model' in state_dict.keys():
msg = model.load_state_dict(state_dict['model'], strict=False)
else:
msg = model.load_state_dict(state_dict, strict=False)
# print("#########",msg)
model = model.eval().cuda()
def get_prompt(conv):
ret = conv.system + conv.sep
for role, message in conv.messages:
if message:
ret += role + ": " + message + conv.sep
else:
ret += role + ":"
return ret
def get_prompt2(conv):
ret = conv.system + conv.sep
count = 0
for role, message in conv.messages:
count += 1
if count == len(conv.messages):
ret += role + ": " + message
else:
if message:
ret += role + ": " + message + conv.sep
else:
ret += role + ":"
return ret
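# get_prompt closes every message with the "###" separator, so the prompt ends in "Assistant:"
# when the last message is empty; get_prompt2 leaves the final message (the answer prompt)
# unterminated so that generation continues it directly.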
def get_context_emb(conv, model, img_list, answer_prompt=None, print_res=False):
if answer_prompt:
prompt = get_prompt2(conv)
else:
prompt = get_prompt(conv)
print("#############")
print("prompt********",prompt)
print("##############")
if print_res:
print(prompt)
if '<VideoHere>' in prompt:
prompt_segs = prompt.split('<VideoHere>')
else:
prompt_segs = prompt.split('<ImageHere>')
assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
with torch.no_grad():
seg_tokens = [
model.llama_tokenizer(
seg, return_tensors="pt", add_special_tokens=i == 0).to("cuda:0").input_ids
# only add bos to the first seg
for i, seg in enumerate(prompt_segs)
]
seg_embs = [model.llama_model.base_model.model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
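    # Interleave the text-segment embeddings with the video embeddings,
    # i.e. [seg_0, video_0, seg_1, video_1, ..., seg_last], then concatenate along the token axis.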
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
mixed_embs = torch.cat(mixed_embs, dim=1)
return mixed_embs
def ask(text, conv):
conv.messages.append([conv.roles[0], text + '\n'])
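# StoppingCriteriaSub halts generation once the tail of the (single) generated sequence matches
# any of the stop-word id sequences; the `encounters` argument is accepted for API compatibility
# but not used.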
class StoppingCriteriaSub(StoppingCriteria):
def __init__(self, stops=[], encounters=1):
super().__init__()
self.stops = stops
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
for stop in self.stops:
if torch.all((stop == input_ids[0][-len(stop):])).item():
return True
return False
def answer(conv, model, img_list, do_sample=True, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
repetition_penalty=1.0, length_penalty=1, temperature=1.0, answer_prompt=None, print_res=False):
stop_words_ids = [
torch.tensor([835]).to("cuda:0"),
torch.tensor([2277, 29937]).to("cuda:0")] # '###' can be encoded in two different ways.
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
conv.messages.append([conv.roles[1], answer_prompt])
embs = get_context_emb(conv, model, img_list, answer_prompt=answer_prompt, print_res=print_res)
with torch.no_grad():
outputs = model.llama_model.generate(
inputs_embeds=embs,
max_new_tokens=max_new_tokens,
stopping_criteria=stopping_criteria,
num_beams=num_beams,
do_sample=do_sample,
min_length=min_length,
top_p=top_p,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
temperature=temperature,
)
output_token = outputs[0]
    if output_token[0] == 0:  # the model may emit an <unk> token (id 0) at the beginning; remove it
        output_token = output_token[1:]
    if output_token[0] == 1:  # likewise remove a leading <s> start token (id 1) if present
        output_token = output_token[1:]
    output_text = model.llama_tokenizer.decode(output_token, skip_special_tokens=True)
output_text = output_text.split('###')[0] # remove the stop sign '###'
output_text = output_text.split('Assistant:')[-1].strip()
conv.messages[-1][1] = output_text
return output_text, output_token.cpu().numpy()
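# get_index (below) picks one frame near the centre of each of num_segments equal spans;
# e.g. num_frames=100, num_segments=8 yields offsets of roughly [6, 18, 31, 43, 56, 68, 80, 93].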
def get_index(num_frames, num_segments):
seg_size = float(num_frames - 1) / num_segments
start = int(seg_size / 2)
offsets = np.array([
start + int(np.round(seg_size * idx)) for idx in range(num_segments)
])
return offsets
def load_video(video_path, num_segments=8, return_msg=False, resolution=224):
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
num_frames = len(vr)
frame_indices = get_index(num_frames, num_segments)
# transform
crop_size = resolution
scale_size = resolution
input_mean = [0.48145466, 0.4578275, 0.40821073]
input_std = [0.26862954, 0.26130258, 0.27577711]
transform = T.Compose([
GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),
GroupCenterCrop(crop_size),
Stack(),
ToTorchFormatTensor(),
GroupNormalize(input_mean, input_std)
])
images_group = list()
for frame_index in frame_indices:
img = Image.fromarray(vr[frame_index].numpy())
images_group.append(img)
torch_imgs = transform(images_group)
if return_msg:
fps = float(vr.get_avg_fps())
sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
# " " should be added in the start and end
msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
return torch_imgs, msg
else:
return torch_imgs
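# Example usage (hypothetical path):
#   frames, msg = load_video("demo.mp4", num_segments=16, return_msg=True)
#   frames has shape (num_segments * 3, resolution, resolution) after Stack/ToTorchFormatTensor.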
def get_sinusoid_encoding_table(n_position=784, d_hid=1024, cur_frame=8, ckpt_num_frame=4, pre_n_position=784):
''' Sinusoid position encoding table '''
# TODO: make it with torch instead of numpy
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
# generate checkpoint position embedding
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(pre_n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
sinusoid_table = torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
print(f"n_position: {n_position}")
print(f"pre_n_position: {pre_n_position}")
if n_position != pre_n_position:
T = ckpt_num_frame # checkpoint frame
P = 14 # checkpoint size
C = d_hid
new_P = int((n_position // cur_frame) ** 0.5) # testing size
if new_P != 14:
print(f'Pretraining uses 14x14, but current version is {new_P}x{new_P}')
print(f'Interpolate the position embedding')
sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
sinusoid_table = sinusoid_table.reshape(-1, P, P, C).permute(0, 3, 1, 2)
sinusoid_table = torch.nn.functional.interpolate(
sinusoid_table, size=(new_P, new_P), mode='bicubic', align_corners=False)
# BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
sinusoid_table = sinusoid_table.permute(0, 2, 3, 1).reshape(-1, T, new_P, new_P, C)
sinusoid_table = sinusoid_table.flatten(1, 3) # B, THW, C
if cur_frame != ckpt_num_frame:
print(f'Pretraining uses 4 frames, but current frame is {cur_frame}')
print(f'Interpolate the position embedding')
T = ckpt_num_frame # checkpoint frame
new_T = cur_frame # testing frame
# interpolate
P = int((n_position // cur_frame) ** 0.5) # testing size
C = d_hid
sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
sinusoid_table = sinusoid_table.permute(0, 2, 3, 4, 1).reshape(-1, C, T) # BHW, C, T
sinusoid_table = torch.nn.functional.interpolate(sinusoid_table, size=new_T, mode='linear')
sinusoid_table = sinusoid_table.reshape(1, P, P, C, new_T).permute(0, 4, 1, 2, 3) # B, T, H, W, C
sinusoid_table = sinusoid_table.flatten(1, 3) # B, THW, C
return sinusoid_table
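# Each data_list entry maps a task name to (annotation json, video directory, data type, bound flag);
# the final flag presumably mirrors the temporal-bound switch of the MVBench loader and is not
# read anywhere in this script.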
data_list = {
    "subPlot": ("8_sub_scene.json", "/MLVU_all/video/subPlot", "video", False),
    "summary": ("9_summary.json", "/MLVU_all/video/summary", "video", False)
}
data_dir = "/MLVU_all/json"
class MLVU(Dataset):
def __init__(self, data_dir, data_list, num_segments=8, resolution=224):
self.data_list = []
for k, v in data_list.items():
with open(os.path.join(data_dir, v[0]), 'r') as f:
json_data = json.load(f)
for data in json_data:
self.data_list.append({
'task_type': k,
'prefix': v[1],
'data_type': v[2],
'data': data
})
self.decord_method = {
'video': self.read_video
}
self.num_segments = num_segments
# transform
crop_size = resolution
scale_size = resolution
input_mean = [0.48145466, 0.4578275, 0.40821073]
input_std = [0.26862954, 0.26130258, 0.27577711]
self.transform = T.Compose([
GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),
GroupCenterCrop(crop_size),
Stack(),
ToTorchFormatTensor(),
GroupNormalize(input_mean, input_std)
])
def __str__(self):
len_list = {}
option_list = {}
for data in self.data_list:
if data['task_type'] not in len_list:
len_list[data['task_type']] = 0
len_list[data['task_type']] += 1
if data['task_type'] not in option_list:
option_list[data['task_type']] = 0
option_list[data['task_type']] += len(data['data']['candidates'])
correct = 0
total = 0
res = f"There are {len(self.data_list)} videos as follow:\n"
for k, v in len_list.items():
correct += len_list[k]
total += option_list[k]
res += f"{v} for {k} ({option_list[k]} options => {len_list[k]/option_list[k]*100:.2f}%)\n"
correct = correct + 1 / option_list[k]
res += f"Total random accuracy: {correct/total*100:.2f}%"
return res.rstrip()
def __len__(self):
return len(self.data_list)
def get_index(self, bound, fps, max_frame, first_idx=0):
if bound:
start, end = bound[0], bound[1]
else:
start, end = -100000, 100000
start_idx = max(first_idx, round(start * fps))
end_idx = min(round(end * fps), max_frame)
seg_size = float(end_idx - start_idx) / self.num_segments
frame_indices = np.array([
int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
for idx in range(self.num_segments)
])
return frame_indices
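    # With bound=None the whole clip [first_idx, max_frame] is covered; indices land near the
    # midpoint of each of num_segments equal spans.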
def read_video(self, video_path, bound=None):
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
max_frame = len(vr) - 1
fps = float(vr.get_avg_fps())
images_group = list()
frame_indices = self.get_index(bound, fps, max_frame, first_idx=0)
for frame_index in frame_indices:
img = Image.fromarray(vr[frame_index].numpy())
images_group.append(img)
torch_imgs = self.transform(images_group)
return torch_imgs
def qa_template(self, data):
question = f"{data['question']}"
answer = data['answer']
return question, answer
def __getitem__(self, idx):
decord_method = self.decord_method[self.data_list[idx]['data_type']]
bound = None
video_path = os.path.join(self.data_list[idx]['prefix'], self.data_list[idx]['data']['video'])
torch_imgs = decord_method(video_path, bound)
question, answer = self.qa_template(self.data_list[idx]['data'])
return {
'video': torch_imgs,
'question': question,
'answer': answer,
'task_type': self.data_list[idx]['task_type'],
'video_path': video_path
}
# position embedding
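# With resolution 224 and num_frame 16, n_position = (224 // 16) ** 2 * 16 = 3136 while the
# checkpoint table covers 14 * 14 * 4 = 784 positions, so only the temporal axis of the positional
# encoding is interpolated (4 -> 16 frames); the 14x14 spatial grid stays unchanged.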
num_frame = 16
resolution = 224
new_pos_emb = get_sinusoid_encoding_table(n_position=(resolution//16)**2*num_frame, cur_frame=num_frame)
model.vision_encoder.encoder.pos_embed = new_pos_emb
dataset = MLVU(data_dir, data_list, num_segments=num_frame, resolution=resolution)
def infer_mvbench(
data_sample, system="",
        question_prompt='',  # appended to the end of the question
        answer_prompt=None,  # prepended to the beginning of the answer
        return_prompt='',  # intended prefix for the returned message (not used in this script)
        system_q=False,  # whether to add the question to the QFormer system prompt (not used in this script)
print_res=False,
system_llm=False
):
video = data_sample["video"]
TC, H, W = video.shape
video = video.reshape(1, TC//3, 3, H, W).to("cuda:0")
# video = torch.zeros((1, 16, 3, 224, 224)).to("cuda:0")
video_list = []
with torch.no_grad():
video_emb, _ = model.encode_img(video, system + data_sample['question'])
video_list.append(video_emb)
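    # encode_img runs the visual backbone and QFormer; the system text plus question is passed as
    # the second argument, presumably serving as the QFormer instruction in the instruction-tuned
    # VideoChat2 setup.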
chat = EasyDict({
"system": system,
"roles": ("Human", "Assistant"),
"messages": [],
"sep": "###"
})
chat.messages.append([chat.roles[0], f"<Video><VideoHere></Video>\n"])
if system_llm:
prompt = system + data_sample['question'] + question_prompt
else:
prompt = data_sample['question'] + question_prompt
ask(prompt, chat)
llm_message = answer(
conv=chat, model=model, do_sample=False,
img_list=video_list, max_new_tokens=1000,
answer_prompt=answer_prompt, print_res=print_res
)[0]
# remove potential explanation
# llm_message = llm_message.strip().split('\n')[0]
return llm_message
res_list_subplot = []
res_list_summary = []
for example in tqdm(dataset):
task_type = example['task_type']
pred = infer_mvbench(
example,
system="Carefully watch this video and pay attention to every detail. Based on your observations, answer the given questions.\n",
question_prompt="",
answer_prompt="",
return_prompt='',
system_q=False,
print_res=False,
system_llm=False
)
gt = example['answer']
if task_type=="subPlot":
result={}
result["video_name"]=example['video_path'].split("/")[-1]
result['Q']=example['question']
result['A']=gt
result['pred']=pred
res_list_subplot.append(result)
if task_type=="summary":
result={}
result["video_name"]=example['video_path'].split("/")[-1]
result['Q']=example['question']
result['A']=gt
result['pred']=pred
res_list_summary.append(result)
# print("####################3")
# print("question",result['Q'])
print("pred",pred)
# print("gt",gt)
# print("##################3")
with open(f"subplot_all.json", "w") as f:
json.dump(res_list_subplot, f)
with open(f"summary_all.json", "w") as f:
json.dump(res_list_summary, f)