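# Evaluate VideoChat2 (Vicuna backbone, stage-3 LoRA checkpoint) on the MLVU
# "subPlot" and "summary" generation tasks: build the model, sample frames from
# each video, run free-form inference, and dump the predictions to JSON files.
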
from utils.config import Config

config_file = "configs/config.json"
cfg = Config.from_file(config_file)

import os
import io
import json
import copy

from models import VideoChat2_it_vicuna
from utils.easydict import EasyDict

import torch
from transformers import StoppingCriteria, StoppingCriteriaList
from peft import get_peft_model, LoraConfig, TaskType

from PIL import Image
import numpy as np
from decord import VideoReader, cpu

import torchvision.transforms as T
from dataset.video_transforms import (
    GroupNormalize, GroupScale, GroupCenterCrop,
    Stack, ToTorchFormatTensor
)
from torchvision.transforms.functional import InterpolationMode
from torchvision import transforms
from torch.utils.data import Dataset

import matplotlib.pyplot as plt
from tqdm import tqdm
from IPython.display import Video, HTML

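
# Model setup: build the stage-2 VideoChat2-Vicuna backbone from the config,
# wrap its LLM with LoRA adapters, then load the stage-3 checkpoint with
# strict=False so only matching keys are restored.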
# load the stage-2 model
cfg.model.vision_encoder.num_frames = 4
model = VideoChat2_it_vicuna(config=cfg.model)

model = model.to(torch.device(cfg.device))
model = model.eval().cuda()

# add LoRA adapters so the stage-3 weights can be loaded
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=False,
    r=16, lora_alpha=32, lora_dropout=0.
)
model.llama_model = get_peft_model(model.llama_model, peft_config)

state_dict = torch.load("checkpoints/videochat2_7b_stage3.pth", "cpu")

if 'model' in state_dict.keys():
    msg = model.load_state_dict(state_dict['model'], strict=False)
else:
    msg = model.load_state_dict(state_dict, strict=False)
# print("#########", msg)

model = model.eval().cuda()


def get_prompt(conv):
    ret = conv.system + conv.sep
    for role, message in conv.messages:
        if message:
            ret += role + ": " + message + conv.sep
        else:
            ret += role + ":"
    return ret


def get_prompt2(conv):
    ret = conv.system + conv.sep
    count = 0
    for role, message in conv.messages:
        count += 1
        if count == len(conv.messages):  # the last message (the answer prompt) gets no separator
            ret += role + ": " + message
        else:
            if message:
                ret += role + ": " + message + conv.sep
            else:
                ret += role + ":"
    return ret

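
# get_context_emb() splits the prompt at the <VideoHere>/<ImageHere> placeholder,
# embeds each text segment with the LLM's token embeddings, and interleaves the
# video embeddings between the segments to form a single input-embedding sequence.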
def get_context_emb(conv, model, img_list, answer_prompt=None, print_res=False):
    if answer_prompt:
        prompt = get_prompt2(conv)
    else:
        prompt = get_prompt(conv)
    if print_res:
        print(prompt)
    if '<VideoHere>' in prompt:
        prompt_segs = prompt.split('<VideoHere>')
    else:
        prompt_segs = prompt.split('<ImageHere>')
    assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
    with torch.no_grad():
        seg_tokens = [
            model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to("cuda:0").input_ids
            # only add BOS to the first segment
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [model.llama_model.base_model.model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
    return mixed_embs

def ask(text, conv):
    conv.messages.append([conv.roles[0], text + '\n'])


class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False

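
# answer() appends the assistant turn, builds the mixed text/video embeddings,
# generates with the LLM until '###' (the conversation separator) is produced,
# then strips special tokens and the stop sign from the decoded output.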
def answer(conv, model, img_list, do_sample=True, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
           repetition_penalty=1.0, length_penalty=1, temperature=1.0, answer_prompt=None, print_res=False):
    stop_words_ids = [
        torch.tensor([835]).to("cuda:0"),
        torch.tensor([2277, 29937]).to("cuda:0")]  # '###' can be encoded in two different ways
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    conv.messages.append([conv.roles[1], answer_prompt])
    embs = get_context_emb(conv, model, img_list, answer_prompt=answer_prompt, print_res=print_res)
    with torch.no_grad():
        outputs = model.llama_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=stopping_criteria,
            num_beams=num_beams,
            do_sample=do_sample,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
    output_token = outputs[0]
    if output_token[0] == 0:  # the model may emit an unknown token <unk> at the beginning; remove it
        output_token = output_token[1:]
    if output_token[0] == 1:  # a start token <s> may also appear at the beginning; remove it
        output_token = output_token[1:]
    output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False)
    output_text = output_text.split('###')[0]  # remove the stop sign '###'
    output_text = output_text.split('Assistant:')[-1].strip()
    conv.messages[-1][1] = output_text
    return output_text, output_token.cpu().numpy()

def get_index(num_frames, num_segments):
    seg_size = float(num_frames - 1) / num_segments
    start = int(seg_size / 2)
    offsets = np.array([
        start + int(np.round(seg_size * idx)) for idx in range(num_segments)
    ])
    return offsets

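
# load_video() uniformly samples num_segments frames with get_index() and applies
# CLIP-style resize/crop/normalize transforms, optionally returning a message
# describing which timestamps were sampled.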
def load_video(video_path, num_segments=8, return_msg=False, resolution=224):
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    num_frames = len(vr)
    frame_indices = get_index(num_frames, num_segments)

    # transform
    crop_size = resolution
    scale_size = resolution
    input_mean = [0.48145466, 0.4578275, 0.40821073]
    input_std = [0.26862954, 0.26130258, 0.27577711]

    transform = T.Compose([
        GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),
        GroupCenterCrop(crop_size),
        Stack(),
        ToTorchFormatTensor(),
        GroupNormalize(input_mean, input_std)
    ])

    images_group = list()
    for frame_index in frame_indices:
        img = Image.fromarray(vr[frame_index].numpy())
        images_group.append(img)
    torch_imgs = transform(images_group)
    if return_msg:
        fps = float(vr.get_avg_fps())
        sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
        # a leading and trailing space should be added when this message is inserted into the prompt
        msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
        return torch_imgs, msg
    else:
        return torch_imgs

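
# The vision encoder was pretrained with 14x14 patches over 4 frames, so the
# fixed sinusoidal position table is interpolated spatially (bicubic) and
# temporally (linear) whenever the evaluation resolution or frame count differs.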
def get_sinusoid_encoding_table(n_position=784, d_hid=1024, cur_frame=8, ckpt_num_frame=4, pre_n_position=784):
    ''' Sinusoid position encoding table '''
    # TODO: make it with torch instead of numpy
    def get_position_angle_vec(position):
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    # generate the checkpoint position embedding
    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(pre_n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    sinusoid_table = torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)

    print(f"n_position: {n_position}")
    print(f"pre_n_position: {pre_n_position}")

    if n_position != pre_n_position:
        T = ckpt_num_frame  # checkpoint frames
        P = 14  # checkpoint patch grid size
        C = d_hid
        new_P = int((n_position // cur_frame) ** 0.5)  # testing patch grid size
        if new_P != 14:
            print(f'Pretraining uses 14x14, but the current version is {new_P}x{new_P}')
            print('Interpolating the position embedding')
            sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
            sinusoid_table = sinusoid_table.reshape(-1, P, P, C).permute(0, 3, 1, 2)
            sinusoid_table = torch.nn.functional.interpolate(
                sinusoid_table, size=(new_P, new_P), mode='bicubic', align_corners=False)
            # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
            sinusoid_table = sinusoid_table.permute(0, 2, 3, 1).reshape(-1, T, new_P, new_P, C)
            sinusoid_table = sinusoid_table.flatten(1, 3)  # B, THW, C

    if cur_frame != ckpt_num_frame:
        print(f'Pretraining uses {ckpt_num_frame} frames, but the current setting uses {cur_frame} frames')
        print('Interpolating the position embedding')
        T = ckpt_num_frame  # checkpoint frames
        new_T = cur_frame  # testing frames
        # interpolate temporally
        P = int((n_position // cur_frame) ** 0.5)  # testing patch grid size
        C = d_hid
        sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
        sinusoid_table = sinusoid_table.permute(0, 2, 3, 4, 1).reshape(-1, C, T)  # BHW, C, T
        sinusoid_table = torch.nn.functional.interpolate(sinusoid_table, size=new_T, mode='linear')
        sinusoid_table = sinusoid_table.reshape(1, P, P, C, new_T).permute(0, 4, 1, 2, 3)  # B, T, H, W, C
        sinusoid_table = sinusoid_table.flatten(1, 3)  # B, THW, C

    return sinusoid_table

data_list = {
    "subPlot": ("8_sub_scene.json", "/MLVU_all/video/subPlot", "video", False),
    "summary": ("9_summary.json", "/MLVU_all/video/summary", "video", False)
}

data_dir = "/MLVU_all/json"

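
# MLVU wraps the task JSON files as a torch Dataset: each item loads the video
# with decord, samples num_segments frames, applies the transforms, and returns
# the frame tensor together with the question, answer, task type, and video path.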
class MLVU(Dataset):
    def __init__(self, data_dir, data_list, num_segments=8, resolution=224):
        self.data_list = []
        for k, v in data_list.items():
            with open(os.path.join(data_dir, v[0]), 'r') as f:
                json_data = json.load(f)
            for data in json_data:
                self.data_list.append({
                    'task_type': k,
                    'prefix': v[1],
                    'data_type': v[2],
                    'data': data
                })

        self.decord_method = {
            'video': self.read_video
        }
        self.num_segments = num_segments

        # transform
        crop_size = resolution
        scale_size = resolution
        input_mean = [0.48145466, 0.4578275, 0.40821073]
        input_std = [0.26862954, 0.26130258, 0.27577711]
        self.transform = T.Compose([
            GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),
            GroupCenterCrop(crop_size),
            Stack(),
            ToTorchFormatTensor(),
            GroupNormalize(input_mean, input_std)
        ])

    def __str__(self):
        len_list = {}
        option_list = {}
        for data in self.data_list:
            if data['task_type'] not in len_list:
                len_list[data['task_type']] = 0
            len_list[data['task_type']] += 1
            if data['task_type'] not in option_list:
                option_list[data['task_type']] = 0
            option_list[data['task_type']] += len(data['data']['candidates'])

        correct = 0
        total = 0
        res = f"There are {len(self.data_list)} videos as follows:\n"
        for k, v in len_list.items():
            correct += len_list[k]
            total += option_list[k]
            res += f"{v} for {k} ({option_list[k]} options => {len_list[k]/option_list[k]*100:.2f}%)\n"
            correct = correct + 1 / option_list[k]
        res += f"Total random accuracy: {correct/total*100:.2f}%"
        return res.rstrip()

    def __len__(self):
        return len(self.data_list)

    def get_index(self, bound, fps, max_frame, first_idx=0):
        if bound:
            start, end = bound[0], bound[1]
        else:
            start, end = -100000, 100000
        start_idx = max(first_idx, round(start * fps))
        end_idx = min(round(end * fps), max_frame)
        seg_size = float(end_idx - start_idx) / self.num_segments
        frame_indices = np.array([
            int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
            for idx in range(self.num_segments)
        ])
        return frame_indices

    def read_video(self, video_path, bound=None):
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        max_frame = len(vr) - 1
        fps = float(vr.get_avg_fps())

        images_group = list()
        frame_indices = self.get_index(bound, fps, max_frame, first_idx=0)
        for frame_index in frame_indices:
            img = Image.fromarray(vr[frame_index].numpy())
            images_group.append(img)
        torch_imgs = self.transform(images_group)
        return torch_imgs

    def qa_template(self, data):
        question = f"{data['question']}"
        answer = data['answer']
        return question, answer

    def __getitem__(self, idx):
        decord_method = self.decord_method[self.data_list[idx]['data_type']]
        bound = None
        video_path = os.path.join(self.data_list[idx]['prefix'], self.data_list[idx]['data']['video'])
        torch_imgs = decord_method(video_path, bound)
        question, answer = self.qa_template(self.data_list[idx]['data'])

        return {
            'video': torch_imgs,
            'question': question,
            'answer': answer,
            'task_type': self.data_list[idx]['task_type'],
            'video_path': video_path
        }

# position embedding
num_frame = 16
resolution = 224
new_pos_emb = get_sinusoid_encoding_table(n_position=(resolution // 16) ** 2 * num_frame, cur_frame=num_frame)
model.vision_encoder.encoder.pos_embed = new_pos_emb

dataset = MLVU(data_dir, data_list, num_segments=num_frame, resolution=resolution)

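
# Inference: the sampled frames are encoded by model.encode_img() (conditioned on
# the system prompt plus question), then a Human/Assistant conversation with the
# '###' separator is built and the answer is generated greedily (do_sample=False).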
def infer_mvbench(
    data_sample, system="",
    question_prompt='',    # appended to the end of the question
    answer_prompt=None,    # added at the beginning of the answer
    return_prompt='',      # added at the beginning of the returned message
    system_q=False,        # whether to add the question to the system prompt for the QFormer
    print_res=False,
    system_llm=False
):
    video = data_sample["video"]
    TC, H, W = video.shape
    video = video.reshape(1, TC // 3, 3, H, W).to("cuda:0")
    # video = torch.zeros((1, 16, 3, 224, 224)).to("cuda:0")

    video_list = []
    with torch.no_grad():
        video_emb, _ = model.encode_img(video, system + data_sample['question'])
    video_list.append(video_emb)

    chat = EasyDict({
        "system": system,
        "roles": ("Human", "Assistant"),
        "messages": [],
        "sep": "###"
    })

    chat.messages.append([chat.roles[0], "<Video><VideoHere></Video>\n"])

    if system_llm:
        prompt = system + data_sample['question'] + question_prompt
    else:
        prompt = data_sample['question'] + question_prompt

    ask(prompt, chat)

    llm_message = answer(
        conv=chat, model=model, do_sample=False,
        img_list=video_list, max_new_tokens=1000,
        answer_prompt=answer_prompt, print_res=print_res
    )[0]
    # remove potential explanation
    # llm_message = llm_message.strip().split('\n')[0]

    return llm_message

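
# Evaluation loop: run inference on every MLVU example, collect the predictions
# for the subPlot and summary tasks separately, and write them to JSON files
# for downstream scoring.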
res_list_subplot = []
res_list_summary = []
for example in tqdm(dataset):
    task_type = example['task_type']
    pred = infer_mvbench(
        example,
        system="Carefully watch this video and pay attention to every detail. Based on your observations, answer the given questions.\n",
        question_prompt="",
        answer_prompt="",
        return_prompt='',
        system_q=False,
        print_res=False,
        system_llm=False
    )
    gt = example['answer']

    if task_type == "subPlot":
        result = {}
        result["video_name"] = example['video_path'].split("/")[-1]
        result['Q'] = example['question']
        result['A'] = gt
        result['pred'] = pred
        res_list_subplot.append(result)

    if task_type == "summary":
        result = {}
        result["video_name"] = example['video_path'].split("/")[-1]
        result['Q'] = example['question']
        result['A'] = gt
        result['pred'] = pred
        res_list_summary.append(result)

    print("pred", pred)

with open("subplot_all.json", "w") as f:
    json.dump(res_list_subplot, f)

with open("summary_all.json", "w") as f:
    json.dump(res_list_summary, f)