from utils.config import Config

config_file = "configs/config.json"
cfg = Config.from_file(config_file)

import os
import io
import json

from models import VideoChat2_it_vicuna
from utils.easydict import EasyDict
import torch

from transformers import StoppingCriteria, StoppingCriteriaList

from PIL import Image
import numpy as np
from decord import VideoReader, cpu
import torchvision.transforms as T
from dataset.video_transforms import (
    GroupNormalize, GroupScale, GroupCenterCrop,
    Stack, ToTorchFormatTensor
)
from torchvision.transforms.functional import InterpolationMode

from torch.utils.data import Dataset
from torchvision import transforms

import matplotlib.pyplot as plt
from tqdm import tqdm

from IPython.display import Video, HTML

from peft import get_peft_model, LoraConfig, TaskType
import copy

# Load the stage-2 model.
cfg.model.vision_encoder.num_frames = 4
model = VideoChat2_it_vicuna(config=cfg.model)
model = model.to(torch.device(cfg.device))
model = model.eval().cuda()

# Add LoRA adapters so the stage-3 checkpoint can be loaded.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=False,
    r=16, lora_alpha=32, lora_dropout=0.
)
model.llama_model = get_peft_model(model.llama_model, peft_config)

state_dict = torch.load("checkpoints/videochat2_7b_stage3.pth", map_location="cpu")
if 'model' in state_dict.keys():
    msg = model.load_state_dict(state_dict['model'], strict=False)
else:
    msg = model.load_state_dict(state_dict, strict=False)
# print(msg)
model = model.eval().cuda()


def get_prompt(conv):
    ret = conv.system + conv.sep
    for role, message in conv.messages:
        if message:
            ret += role + ": " + message + conv.sep
        else:
            ret += role + ":"
    return ret


def get_prompt2(conv):
    # Same as get_prompt, but the final message is left without a trailing separator
    # so an answer prompt can be appended directly.
    ret = conv.system + conv.sep
    count = 0
    for role, message in conv.messages:
        count += 1
        if count == len(conv.messages):
            ret += role + ": " + message
        else:
            if message:
                ret += role + ": " + message + conv.sep
            else:
                ret += role + ":"
    return ret
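# The helpers above build VideoChat2-style plain-text prompts. With sep="###", the layout used
# later in this script is roughly:
#   <system>###Human: <Video><VideoHere></Video>\n###Human: <question>\n###Assistant:
# (illustrative sketch only; the exact string is whatever get_prompt/get_prompt2 assemble from
# the conversation constructed in infer_mvbench below).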
def get_context_emb(conv, model, img_list, answer_prompt=None, print_res=False):
    if answer_prompt:
        prompt = get_prompt2(conv)
    else:
        prompt = get_prompt(conv)
    if print_res:
        print(prompt)
    # '<VideoHere>' / '<ImageHere>' are the media placeholder tokens embedded in the prompt.
    if '<VideoHere>' in prompt:
        prompt_segs = prompt.split('<VideoHere>')
    else:
        prompt_segs = prompt.split('<ImageHere>')
    assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
    with torch.no_grad():
        seg_tokens = [
            model.llama_tokenizer(
                seg, return_tensors="pt", add_special_tokens=i == 0).to("cuda:0").input_ids
            # only add bos to the first seg
            for i, seg in enumerate(prompt_segs)
        ]
        seg_embs = [model.llama_model.base_model.model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
    return mixed_embs


def ask(text, conv):
    conv.messages.append([conv.roles[0], text + '\n'])


class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False


def answer(conv, model, img_list, do_sample=True, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
           repetition_penalty=1.0, length_penalty=1, temperature=1.0, answer_prompt=None, print_res=False):
    stop_words_ids = [
        torch.tensor([835]).to("cuda:0"),
        torch.tensor([2277, 29937]).to("cuda:0")]  # '###' can be encoded in two different ways.
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    conv.messages.append([conv.roles[1], answer_prompt])
    embs = get_context_emb(conv, model, img_list, answer_prompt=answer_prompt, print_res=print_res)
    with torch.no_grad():
        outputs = model.llama_model.generate(
            inputs_embeds=embs,
            max_new_tokens=max_new_tokens,
            stopping_criteria=stopping_criteria,
            num_beams=num_beams,
            do_sample=do_sample,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )
    output_token = outputs[0]
    if output_token[0] == 0:  # the model might output an unknown token at the beginning; remove it
        output_token = output_token[1:]
    if output_token[0] == 1:  # some outputs start with a <s> token; remove it
        output_token = output_token[1:]
    output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False)
    output_text = output_text.split('###')[0]  # remove the stop sign '###'
    output_text = output_text.split('Assistant:')[-1].strip()
    conv.messages[-1][1] = output_text
    return output_text, output_token.cpu().numpy()


def get_index(num_frames, num_segments):
    # Uniformly sample num_segments frame indices, one from the middle of each segment.
    seg_size = float(num_frames - 1) / num_segments
    start = int(seg_size / 2)
    offsets = np.array([
        start + int(np.round(seg_size * idx)) for idx in range(num_segments)
    ])
    return offsets


def load_video(video_path, num_segments=8, return_msg=False, resolution=224):
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    num_frames = len(vr)
    frame_indices = get_index(num_frames, num_segments)

    # transform
    crop_size = resolution
    scale_size = resolution
    input_mean = [0.48145466, 0.4578275, 0.40821073]
    input_std = [0.26862954, 0.26130258, 0.27577711]

    transform = T.Compose([
        GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),
        GroupCenterCrop(crop_size),
        Stack(),
        ToTorchFormatTensor(),
        GroupNormalize(input_mean, input_std)
    ])

    images_group = list()
    for frame_index in frame_indices:
        img = Image.fromarray(vr[frame_index].numpy())
        images_group.append(img)
    torch_imgs = transform(images_group)
    if return_msg:
        fps = float(vr.get_avg_fps())
        sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
        # " " should be added in the start and end
        msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
        return torch_imgs, msg
    else:
        return torch_imgs
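# Worked example of the sampling above (illustrative): for a 100-frame clip and 8 segments,
# seg_size = 99 / 8 = 12.375 and start = 6, so get_index(100, 8) returns
# array([ 6, 18, 31, 43, 56, 68, 80, 93]) -- one frame near the centre of each segment.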
def get_sinusoid_encoding_table(n_position=784, d_hid=1024, cur_frame=8, ckpt_num_frame=4, pre_n_position=784):
    '''Sinusoid position encoding table'''
    # TODO: make it with torch instead of numpy
    def get_position_angle_vec(position):
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    # generate checkpoint position embedding
    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(pre_n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    sinusoid_table = torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
    print(f"n_position: {n_position}")
    print(f"pre_n_position: {pre_n_position}")

    if n_position != pre_n_position:
        T = ckpt_num_frame  # checkpoint frame
        P = 14              # checkpoint size
        C = d_hid
        new_P = int((n_position // cur_frame) ** 0.5)  # testing size
        if new_P != 14:
            print(f'Pretraining uses 14x14, but current version is {new_P}x{new_P}')
            print(f'Interpolate the position embedding')
            sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
            sinusoid_table = sinusoid_table.reshape(-1, P, P, C).permute(0, 3, 1, 2)
            sinusoid_table = torch.nn.functional.interpolate(
                sinusoid_table, size=(new_P, new_P), mode='bicubic', align_corners=False)
            # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C
            sinusoid_table = sinusoid_table.permute(0, 2, 3, 1).reshape(-1, T, new_P, new_P, C)
            sinusoid_table = sinusoid_table.flatten(1, 3)  # B, THW, C

    if cur_frame != ckpt_num_frame:
        print(f'Pretraining uses 4 frames, but current frame is {cur_frame}')
        print(f'Interpolate the position embedding')
        T = ckpt_num_frame  # checkpoint frame
        new_T = cur_frame   # testing frame
        # interpolate
        P = int((n_position // cur_frame) ** 0.5)  # testing size
        C = d_hid
        sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
        sinusoid_table = sinusoid_table.permute(0, 2, 3, 4, 1).reshape(-1, C, T)  # BHW, C, T
        sinusoid_table = torch.nn.functional.interpolate(sinusoid_table, size=new_T, mode='linear')
        sinusoid_table = sinusoid_table.reshape(1, P, P, C, new_T).permute(0, 4, 1, 2, 3)  # B, T, H, W, C
        sinusoid_table = sinusoid_table.flatten(1, 3)  # B, THW, C

    return sinusoid_table
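# Shape sketch for the call made further below (resolution 224, patch size 16, 16 frames):
# each frame contributes (224 // 16) ** 2 = 196 tokens, so n_position = 196 * 16 = 3136, and
# the 4-frame, 14x14 checkpoint table (pre_n_position = 784) is interpolated along time:
#   get_sinusoid_encoding_table(n_position=3136, d_hid=1024, cur_frame=16).shape
#   # -> torch.Size([1, 3136, 1024])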
data_list = {
    "subPlot": ("8_sub_scene.json", "/MLVU_all/video/subPlot", "video", False),
    "summary": ("9_summary.json", "/MLVU_all/video/summary", "video", False)
}
data_dir = "/MLVU_all/json"


class MLVU(Dataset):
    def __init__(self, data_dir, data_list, num_segments=8, resolution=224):
        self.data_list = []
        for k, v in data_list.items():
            with open(os.path.join(data_dir, v[0]), 'r') as f:
                json_data = json.load(f)
            for data in json_data:
                self.data_list.append({
                    'task_type': k,
                    'prefix': v[1],
                    'data_type': v[2],
                    'data': data
                })

        self.decord_method = {
            'video': self.read_video
        }

        self.num_segments = num_segments

        # transform
        crop_size = resolution
        scale_size = resolution
        input_mean = [0.48145466, 0.4578275, 0.40821073]
        input_std = [0.26862954, 0.26130258, 0.27577711]
        self.transform = T.Compose([
            GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),
            GroupCenterCrop(crop_size),
            Stack(),
            ToTorchFormatTensor(),
            GroupNormalize(input_mean, input_std)
        ])

    def __str__(self):
        len_list = {}
        option_list = {}
        for data in self.data_list:
            if data['task_type'] not in len_list:
                len_list[data['task_type']] = 0
            len_list[data['task_type']] += 1
            if data['task_type'] not in option_list:
                option_list[data['task_type']] = 0
            option_list[data['task_type']] += len(data['data']['candidates'])

        correct = 0
        total = 0
        res = f"There are {len(self.data_list)} videos as follow:\n"
        for k, v in len_list.items():
            correct += len_list[k]
            total += option_list[k]
            res += f"{v} for {k} ({option_list[k]} options => {len_list[k]/option_list[k]*100:.2f}%)\n"
            correct = correct + 1 / option_list[k]
        res += f"Total random accuracy: {correct/total*100:.2f}%"
        return res.rstrip()

    def __len__(self):
        return len(self.data_list)

    def get_index(self, bound, fps, max_frame, first_idx=0):
        if bound:
            start, end = bound[0], bound[1]
        else:
            start, end = -100000, 100000
        start_idx = max(first_idx, round(start * fps))
        end_idx = min(round(end * fps), max_frame)
        seg_size = float(end_idx - start_idx) / self.num_segments
        frame_indices = np.array([
            int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
            for idx in range(self.num_segments)
        ])
        return frame_indices

    def read_video(self, video_path, bound=None):
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        max_frame = len(vr) - 1
        fps = float(vr.get_avg_fps())

        images_group = list()
        frame_indices = self.get_index(bound, fps, max_frame, first_idx=0)
        for frame_index in frame_indices:
            img = Image.fromarray(vr[frame_index].numpy())
            images_group.append(img)
        torch_imgs = self.transform(images_group)
        return torch_imgs

    def qa_template(self, data):
        question = f"{data['question']}"
        answer = data['answer']
        return question, answer

    def __getitem__(self, idx):
        decord_method = self.decord_method[self.data_list[idx]['data_type']]
        bound = None
        video_path = os.path.join(self.data_list[idx]['prefix'], self.data_list[idx]['data']['video'])
        torch_imgs = decord_method(video_path, bound)
        question, answer = self.qa_template(self.data_list[idx]['data'])

        return {
            'video': torch_imgs,
            'question': question,
            'answer': answer,
            'task_type': self.data_list[idx]['task_type'],
            'video_path': video_path
        }


# position embedding
num_frame = 16
resolution = 224
new_pos_emb = get_sinusoid_encoding_table(n_position=(resolution//16)**2*num_frame, cur_frame=num_frame)
model.vision_encoder.encoder.pos_embed = new_pos_emb

dataset = MLVU(data_dir, data_list, num_segments=num_frame, resolution=resolution)


def infer_mvbench(
        data_sample, system="",
        question_prompt='',   # added at the end of the question
        answer_prompt=None,   # added at the beginning of the answer
        return_prompt='',     # added at the beginning of the returned message
        system_q=False,       # whether to add the question to the system prompt for the QFormer
        print_res=False,
        system_llm=False
):
    video = data_sample["video"]
    TC, H, W = video.shape
    video = video.reshape(1, TC//3, 3, H, W).to("cuda:0")
    # video = torch.zeros((1, 16, 3, 224, 224)).to("cuda:0")

    video_list = []
    with torch.no_grad():
        video_emb, _ = model.encode_img(video, system + data_sample['question'])
    video_list.append(video_emb)

    chat = EasyDict({
        "system": system,
        "roles": ("Human", "Assistant"),
        "messages": [],
        "sep": "###"
    })

    # '<VideoHere>' is the placeholder that get_context_emb replaces with the video embedding.
    chat.messages.append([chat.roles[0], "<Video><VideoHere></Video>\n"])

    if system_llm:
        prompt = system + data_sample['question'] + question_prompt
    else:
        prompt = data_sample['question'] + question_prompt

    ask(prompt, chat)

    llm_message = answer(
        conv=chat, model=model, do_sample=False,
        img_list=video_list, max_new_tokens=1000,
        answer_prompt=answer_prompt, print_res=print_res
    )[0]
    # remove potential explanation
    # llm_message = llm_message.strip().split('\n')[0]
    return llm_message
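# Optional single-sample smoke test before the full sweep (assumes the MLVU paths above exist);
# it mirrors the call made in the loop below.
# sample = dataset[0]
# print(sample['task_type'], sample['question'])
# print(infer_mvbench(
#     sample,
#     system="Carefully watch this video and pay attention to every detail. "
#            "Based on your observations, answer the given questions.\n",
#     print_res=True))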
res_list_subplot = []
res_list_summary = []

for example in tqdm(dataset):
    task_type = example['task_type']
    pred = infer_mvbench(
        example,
        system="Carefully watch this video and pay attention to every detail. Based on your observations, answer the given questions.\n",
        question_prompt="",
        answer_prompt="",
        return_prompt='',
        system_q=False,
        print_res=False,
        system_llm=False
    )
    gt = example['answer']

    if task_type == "subPlot":
        result = {}
        result["video_name"] = example['video_path'].split("/")[-1]
        result['Q'] = example['question']
        result['A'] = gt
        result['pred'] = pred
        res_list_subplot.append(result)

    if task_type == "summary":
        result = {}
        result["video_name"] = example['video_path'].split("/")[-1]
        result['Q'] = example['question']
        result['A'] = gt
        result['pred'] = pred
        res_list_summary.append(result)

    print("pred", pred)

with open("subplot_all.json", "w") as f:
    json.dump(res_list_subplot, f)

with open("summary_all.json", "w") as f:
    json.dump(res_list_summary, f)
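# Optional sanity check of the dumped predictions (file names match the dumps above).
for name in ("subplot_all.json", "summary_all.json"):
    with open(name) as f:
        print(name, len(json.load(f)), "predictions")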