# Evaluate Video-LLaVA on the MLVU multiple-choice benchmark and report
# per-task and average accuracy.
import json
import os

import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

from videollava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from videollava.conversation import conv_templates, SeparatorStyle
from videollava.model.builder import load_pretrained_model
from videollava.utils import disable_torch_init
from videollava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria


def get_prompt2(conv):
    # Like conv.get_prompt(), but leaves the final message without a trailing
    # separator so the model continues generating right after it.
    ret = conv.system + conv.sep
    count = 0
    for role, message in conv.messages:
        count += 1
        if count == len(conv.messages):
            ret += role + ": " + message
        else:
            if message:
                ret += role + ": " + message + conv.sep
            else:
                ret += role + ":"
    return ret
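
# Illustration of what get_prompt2 produces, assuming the "llava_v1" template
# used below (roles "USER"/"ASSISTANT", separator " "):
#
#   conv = conv_templates["llava_v1"].copy()
#   conv.append_message(conv.roles[0], "Which option is correct?")
#   conv.append_message(conv.roles[1], "Best Option: (")
#   get_prompt2(conv)
#   # -> '<system prompt> USER: Which option is correct? ASSISTANT: Best Option: ('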


class MLVU(Dataset):
    def __init__(self, data_dir, data_list, num_segments=16):
        # num_segments is only used by get_index below; the default of 16 is an
        # assumption -- the original snippet never set this attribute.
        self.num_segments = num_segments
        self.data_list = []
        for k, v in data_list.items():
            with open(os.path.join(data_dir, v[0]), 'r') as f:
                json_data = json.load(f)
            for data in json_data:
                self.data_list.append({
                    'task_type': k,
                    'prefix': v[1],     # directory containing this task's videos
                    'data_type': v[2],
                    'data': data
                })
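
        # Each loaded entry is expected to provide at least the fields read
        # elsewhere in this class (illustrative shape, not a real sample):
        #   {"question": "...", "candidates": ["...", "..."], "answer": "...", "video": "clip.mp4"}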

    def __str__(self):
        # Summarize the dataset: per-task counts and the chance-level accuracy
        # a uniform random guesser would get.
        len_list = {}
        option_list = {}
        for data in self.data_list:
            if data['task_type'] not in len_list:
                len_list[data['task_type']] = 0
            len_list[data['task_type']] += 1
            if data['task_type'] not in option_list:
                option_list[data['task_type']] = 0
            option_list[data['task_type']] += len(data['data']['candidates'])

        correct = 0
        total = 0
        res = f"There are {len(self.data_list)} videos as follow:\n"
        for k, v in len_list.items():
            # With n options per question, a random guess is right with
            # probability 1/n, i.e. (#questions)/(#options) per task.
            correct += len_list[k]
            total += option_list[k]
            res += f"{v} for {k} ({option_list[k]} options => {len_list[k]/option_list[k]*100:.2f}%)\n"
        res += f"Total random accuracy: {correct/total*100:.2f}%"
        return res.rstrip()
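
    # For example, a task with 200 questions and 800 candidate options in total
    # (4 per question) reports 200/800 = 25.00% chance accuracy.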

    def __len__(self):
        return len(self.data_list)

    def get_index(self, bound, fps, max_frame, first_idx=0):
        # Uniformly sample self.num_segments frame indices, one from the centre
        # of each equal-length segment. Not called in this script (the
        # Video-LLaVA video processor samples frames itself) but kept for reuse.
        if bound:
            start, end = bound[0], bound[1]
        else:
            # No bound given: span the whole clip.
            start, end = -100000, 100000
        start_idx = max(first_idx, round(start * fps))
        end_idx = min(round(end * fps), max_frame)
        seg_size = float(end_idx - start_idx) / self.num_segments
        frame_indices = np.array([
            int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
            for idx in range(self.num_segments)
        ])
        return frame_indices
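
    # Worked example (hypothetical numbers): with num_segments=4, fps=30,
    # max_frame=299 and bound=None, start_idx=0, end_idx=299, seg_size=74.75,
    # giving indices [37, 112, 187, 261] -- the centres of four equal segments.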

    def qa_template(self, data):
        # Render the question with lettered options and the answer as "(X) text".
        question = f"Question: {data['question']}\n"
        question += "Options:\n"
        answer = data['answer']
        answer_idx = -1
        for idx, c in enumerate(data['candidates']):
            question += f"({chr(ord('A') + idx)}) {c}\n"
            if c == answer:
                answer_idx = idx
        question = question.rstrip()
        answer = f"({chr(ord('A') + answer_idx)}) {answer}"
        return question, answer
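
    # For instance, data = {"question": "How many coins appear?",
    # "candidates": ["three", "four"], "answer": "four"} renders as
    #   Question: How many coins appear?
    #   Options:
    #   (A) three
    #   (B) four
    # and the returned answer string is "(B) four". (Illustrative values only.)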

    def __getitem__(self, idx):
        video_path = os.path.join(self.data_list[idx]['prefix'], self.data_list[idx]['data']['video'])
        question, answer = self.qa_template(self.data_list[idx]['data'])
        return {
            'video': video_path,
            'question': question,
            'answer': answer,
            'task_type': self.data_list[idx]['task_type']
        }


def check_ans(pred, gt):
    # gt looks like "(B) four"; extract the letter between the parentheses.
    index = gt.index("(")
    index2 = gt.index(")")
    gt_option = gt[index + 1:index2]

    # The prompt ends with "Best Option: (", so predictions typically look like
    # "B) four"; keep just the character before the first ")".
    if ")" in pred:
        index3 = pred.index(")")
        pred = pred[index3 - 1:index3]

    return pred == gt_option
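
# Illustrative behavior:
#   check_ans("B) four", "(B) four")   -> True
#   check_ans("B", "(B) four")         -> True   (no ")" in pred; compared as-is)
#   check_ans("A) three", "(B) four")  -> False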


def main():
    disable_torch_init()

    # task name -> (annotation file, video directory, data type)
    data_list = {
        "count": ("4_count.json", "/MLVU_all/video/count", "video"),
        "ego": ("3_ego.json", "/MLVU_all/video/ego", "video"),
        "needle": ("2_needle.json", "/MLVU_all/video/needle", "video"),
        "order": ("5_order.json", "/MLVU_all/video/order", "video"),
        "plotQA": ("1_plotQA.json", "/MLVU_all/video/plotQA", "video"),
        "anomaly_reco": ("6_anomaly_reco.json", "/MLVU_all/video/anomaly_reco", "video"),
        "topic_reasoning": ("7_topic_reasoning.json", "/MLVU_all/video/topic_reasoning", "video")
    }

    data_dir = "/MLVU_all/json"
    save_path = "./test_all_choice"
    result_path = "bench_all.json"

    dataset = MLVU(data_dir, data_list)

    model_path = 'LanguageBind/Video-LLaVA-7B'
    device = 'cuda:6'
    load_4bit, load_8bit = True, False
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, processor, _ = load_pretrained_model(model_path, None, model_name, load_8bit, load_4bit, device=device)
    video_processor = processor['video']
    conv_mode = "llava_v1"
    conv = conv_templates[conv_mode].copy()

    correct = 0
    total = 0
    res_list = []
    acc_dict = {}
    for example in tqdm(dataset):
        conv.messages = list()  # reset the conversation for each example
        task_type = example['task_type']
        if task_type not in acc_dict:
            acc_dict[task_type] = [0, 0]  # [correct, total]
        acc_dict[task_type][1] += 1
        total += 1
        video_path = example["video"]
        inp = example["question"] + "\nOnly give the best option."

        video_tensor = video_processor(video_path, return_tensors='pt')['pixel_values']
        if type(video_tensor) is list:
            tensor = [video.to(model.device, dtype=torch.float16) for video in video_tensor]
        else:
            tensor = video_tensor.to(model.device, dtype=torch.float16)

        # Prepend one image token per frame the video tower consumes.
        inp = ' '.join([DEFAULT_IMAGE_TOKEN] * model.get_video_tower().config.num_frames) + '\n' + inp
        conv.system = ("Carefully watch this video and pay attention to every detail. "
                       "Based on your observations, select the best option that accurately "
                       "addresses the question.")
        conv.append_message(conv.roles[0], inp)
        # Seed the assistant turn so the model answers in the "(X)" format.
        conv.append_message(conv.roles[1], "Best Option: (")
        # conv.get_prompt() would close the assistant turn; use get_prompt2 instead.
        prompt = get_prompt2(conv)

        # Use model.device rather than .cuda() so this works when device != cuda:0.
        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=tensor,
                do_sample=True,
                temperature=0.1,  # near-greedy sampling
                max_new_tokens=1024,
                use_cache=True,
                stopping_criteria=[stopping_criteria])

        # Decode only the newly generated tokens.
        pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        gt = example['answer']

        res_list.append({
            'pred': pred,
            'gt': gt,
            'question': example['question'],
            'question_type': example['task_type'],
            'video': example['video']
        })
        if check_ans(pred=pred, gt=gt):
            acc_dict[task_type][0] += 1
            correct += 1
        print(f"Part Acc: {acc_dict[task_type][0] / acc_dict[task_type][1] * 100 :.2f}%")
        print('-' * 30, task_type, '-' * 30)

    with open(f"{save_path}.json", "w") as f:
        json.dump({
            "acc_dict": acc_dict,
            "res_list": res_list
        }, f)

    final_res = dict()
    total = 0
    for k, v in acc_dict.items():
        final_res[k] = v[0] / v[1] * 100
        total += final_res[k]
    # Macro-average: the unweighted mean of the per-task accuracies.
    final_res['Avg'] = total / len(acc_dict)
    print(final_res)

    with open(result_path, "w") as f:
        json.dump(final_res, f)


if __name__ == '__main__':
    main()