mysora/tools/scoring/deduplicate/inference.py

177 lines
6.1 KiB
Python

import argparse
import os
from datetime import timedelta
import clip
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.datasets.folder import pil_loader
from tqdm import tqdm
from tools.datasets.utils import extract_frames, is_video
class VideoDataset(torch.utils.data.Dataset):
    """Dataset yielding one preprocessed image per row of a meta CSV.

    The CSV is read with pandas and must provide a ``path`` column. Rows whose
    path is a video (according to ``is_video``) are represented by the frame
    extracted at relative position 0.5; all other rows are loaded as images.

    Args:
        meta_path: Path to the CSV file describing the samples.
        transform: Callable applied to the loaded image before it is returned.
    """

    def __init__(self, meta_path, transform):
        self.transform = transform
        self.meta_path = meta_path
        self.meta = pd.read_csv(meta_path)

    def __len__(self):
        """Return the number of rows in the meta table."""
        return len(self.meta)

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            dict with keys ``index`` (the requested index), ``images`` (the
            transformed image) and ``img_size`` (``str()`` of the image's
            ``size`` attribute, taken before the transform is applied).
        """
        sample_path = self.meta.iloc[index]["path"]
        frame = (
            extract_frames(sample_path, points=[0.5], backend="opencv")[0]
            if is_video(sample_path)
            else pil_loader(sample_path)
        )
        # Stringified so that the default collate keeps one entry per sample.
        original_size = str(frame.size)  # W, H
        return {"index": index, "images": self.transform(frame), "img_size": original_size}
def parse_args():
    """Parse the command-line arguments of the deduplication script.

    Returns:
        argparse.Namespace with ``meta_path`` (positional, str), ``bs`` (int,
        default 128), ``thresh`` (float, default 0.98) and ``num_workers``
        (int, default 16).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("meta_path", type=str, help="Path to the input CSV file")

    # (flag, type, default, help) for every optional argument.
    option_specs = (
        ("--bs", int, 128, "Batch size"),
        ("--thresh", float, 0.98, "similarity thresh"),
        ("--num_workers", int, 16, "Number of workers"),
    )
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    return cli.parse_args()
def _find_similar_pairs(feats, thresh, chunk=1000):
    """Find all pairs of rows in ``feats`` whose dot product is >= ``thresh``.

    The similarity matrix is evaluated ``chunk`` rows at a time so that the
    full N x N matrix is never materialized.

    Args:
        feats: 2-D tensor with one feature vector per row. The caller passes
            L2-normalized features, so the dot product is a cosine similarity.
        thresh: Minimum similarity for a pair to be reported.
        chunk: Number of rows processed per matmul.

    Returns:
        Tuple ``(pairs, scores)``: ``pairs`` is an int array of shape [2, M]
        holding row positions ``(i, j)`` with ``i < j``; ``scores`` holds the
        M matching similarities.
    """
    pair_chunks = []
    score_chunks = []
    for start in tqdm(range(0, feats.shape[0], chunk)):
        # Rows [start, start + chunk) against columns [start, N): pairs with an
        # earlier column were already produced when that column was a row.
        sim_mat = torch.matmul(feats[start : start + chunk], feats[start:].T).cpu().numpy()
        # Threshold first, then keep the strict upper triangle. Zeroing the
        # lower triangle before comparing would turn those zeros into hits
        # (self-pairs and reversed pairs) whenever thresh <= 0.
        hit_mask = np.triu(sim_mat >= thresh, k=1)
        hits = np.nonzero(hit_mask)
        pair_chunks.append(np.stack(hits) + start)  # local -> global positions, [2, M_i]
        score_chunks.append(sim_mat[hits])

    if not pair_chunks:
        return np.zeros((2, 0), dtype=np.int64), np.zeros((0,), dtype=np.float32)
    return np.concatenate(pair_chunks, axis=1), np.concatenate(score_chunks)


def _select_duplicates(pairs, scores, img_sizes):
    """Decide which rows are duplicates and which earlier row each one matches.

    For a pair ``(i, j)`` with ``i < j`` the later row ``j`` is the duplicate.
    Pairs whose two images differ in size are ignored. When a row matches
    several earlier rows, the one with the highest similarity is kept.

    Args:
        pairs: Int array of shape [2, M] of row positions.
        scores: Similarities aligned with ``pairs``.
        img_sizes: Sequence giving the image size of every row position.

    Returns:
        Three parallel lists sorted by descending similarity: positions of the
        duplicates, positions of the rows they duplicate, and the similarities.
    """
    best_match = {}
    for first, second, score in zip(pairs[0].tolist(), pairs[1].tolist(), scores.tolist()):
        # only count pairs with the same size
        if img_sizes[first] != img_sizes[second]:
            continue
        if second not in best_match or best_match[second][1] < score:
            best_match[second] = (first, score)

    ranked = sorted(best_match.items(), key=lambda item: item[1][1], reverse=True)
    dup_pos = [later for later, _ in ranked]
    kept_pos = [match[0] for _, match in ranked]
    sims = [match[1] for _, match in ranked]
    return dup_pos, kept_pos, sims


def main():
    """Remove near-duplicate samples from a meta CSV using CLIP image features.

    Every rank encodes its shard of the dataset with CLIP ViT-L/14, the
    normalized features are all-gathered, and rank 0 flags a sample as a
    duplicate when an earlier sample of the same image size has a similarity
    of at least ``--thresh``. Rank 0 writes two files next to the input:
    ``<name>_dedup<ext>`` (rows kept) and ``<name>_dup<ext>`` (rows removed,
    with the matched path in ``path_dup`` and the similarity in ``sim``).

    Raises:
        SystemExit: with status 1 when the meta file does not exist.
    """
    args = parse_args()

    meta_path = args.meta_path
    if not os.path.exists(meta_path):
        print(f"Meta file '{meta_path}' not found. Exit.")
        # Non-zero status so that launchers and shell pipelines see the failure;
        # raising SystemExit also avoids relying on the site-provided exit().
        raise SystemExit(1)

    wo_ext, ext = os.path.splitext(meta_path)
    out_path_dedup = f"{wo_ext}_dedup{ext}"
    out_path_dup = f"{wo_ext}_dup{ext}"

    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    try:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        torch.cuda.set_device(rank % torch.cuda.device_count())

        # build model
        device = "cuda" if torch.cuda.is_available() else "cpu"  # Note: do not use torch.device('cuda')!!!
        model, preprocess = clip.load("ViT-L/14", device=device)
        # preprocess: resize shorter size to 224 by keeping ar, then center crop

        # build dataset
        dataset = VideoDataset(meta_path=meta_path, transform=preprocess)
        if len(dataset) == 0:
            # Without samples torch.cat() below would fail on an empty list.
            # The check gives the same answer on every rank, so no rank is
            # left waiting in a collective.
            if rank == 0:
                print(f"Meta file '{meta_path}' has no rows. Exit.")
            return

        dataloader = DataLoader(
            dataset,
            batch_size=args.bs,
            num_workers=args.num_workers,
            sampler=DistributedSampler(
                dataset,
                num_replicas=world_size,
                rank=rank,
                shuffle=False,
                drop_last=False,
            ),
        )

        # encode images and store the feature
        print("Begin to generate features")
        model.eval()
        feat_list = []  # store the feature
        indices_list = []  # store the indices
        img_size_list = []
        for batch in tqdm(dataloader, disable=rank != 0):
            indices = batch["index"]
            imgs = batch["images"].to(device, non_blocking=True)
            img_size = batch["img_size"]
            with torch.no_grad():
                feat_img = model.encode_image(imgs)
                feat_img = F.normalize(feat_img, dim=1)
            feat_list.append(feat_img)
            indices_list.extend(indices.tolist())
            img_size_list.extend(img_size)
        feats = torch.cat(feat_list, dim=0)

        # all_gather
        feats_gathered = [torch.zeros_like(feats, device=device) for _ in range(world_size)]
        dist.all_gather(feats_gathered, feats)
        feats_all = torch.cat(feats_gathered)

        indices_gathered = [None for _ in range(world_size)]
        dist.all_gather_object(indices_gathered, indices_list)
        indices_all = np.array([x for sub in indices_gathered for x in sub])

        img_size_gathered = [None for _ in range(world_size)]
        dist.all_gather_object(img_size_gathered, img_size_list)
        img_size_all = [x for sub in img_size_gathered for x in sub]

        # The same dataset index can appear more than once in the gathered
        # lists; keep the first occurrence of each, ordered by dataset index.
        indices_unique, indices_unique_idx = np.unique(indices_all, return_index=True)
        feats_unique = feats_all[torch.from_numpy(indices_unique_idx)]
        img_size_unique = [img_size_all[x] for x in indices_unique_idx]

        if rank == 0:
            # get similarity scores
            pairs, scores = _find_similar_pairs(feats_unique, args.thresh)
            dup_pos, kept_pos, sim_scores_dup = _select_duplicates(pairs, scores, img_size_unique)

            # Positions refer to rows of feats_unique; translate them to dataset
            # row indices explicitly instead of assuming the two coincide.
            dup_inds = indices_unique[np.asarray(dup_pos, dtype=np.int64)].tolist()
            sim_inds = indices_unique[np.asarray(kept_pos, dtype=np.int64)].tolist()
            remain_inds = sorted(set(indices_unique.tolist()) - set(dup_inds))

            # save
            meta_unique = dataset.meta.iloc[remain_inds]
            meta_unique.to_csv(out_path_dedup, index=False)
            print(f"New meta without duplication saved to '{out_path_dedup}'.")

            meta_dup = dataset.meta.iloc[dup_inds].copy().reset_index(drop=True)
            path_dup = dataset.meta.iloc[sim_inds]["path"].copy().reset_index(drop=True)
            meta_dup["path_dup"] = path_dup
            meta_dup["sim"] = sim_scores_dup
            meta_dup.to_csv(out_path_dup, index=False)
            print(f"New meta with duplicated samples saved to '{out_path_dup}'.")
    finally:
        # Release the process group on every rank and on every exit path.
        dist.destroy_process_group()
# Script entry point: run the deduplication only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()