# adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py
|
|
import cv2 # isort:skip
|
|
|
|
import argparse
|
|
import gc
|
|
import os
|
|
from datetime import timedelta
|
|
|
|
import clip
|
|
import numpy as np
|
|
import pandas as pd
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from einops import rearrange
|
|
from torch.utils.data import DataLoader, DistributedSampler
|
|
from torchvision.datasets.folder import pil_loader
|
|
from tqdm import tqdm
|
|
|
|
from tools.datasets.utils import extract_frames, is_video
|
|
|
|
# Relative temporal positions (fractions of the video length in [0, 1]) at
# which frames are sampled, keyed by the number of frames to extract.
# e.g. 3 frames -> near the start, the middle, and near the end.
NUM_FRAMES_POINTS = {
    1: (0.5,),
    2: (0.25, 0.5),
    3: (0.1, 0.5, 0.9),
}
|
|
|
|
|
|
def merge_scores(gathered_list: list, meta: pd.DataFrame, column):
|
|
# reorder
|
|
indices_list = list(map(lambda x: x[0], gathered_list))
|
|
scores_list = list(map(lambda x: x[1], gathered_list))
|
|
|
|
flat_indices = []
|
|
for x in zip(*indices_list):
|
|
flat_indices.extend(x)
|
|
flat_scores = []
|
|
for x in zip(*scores_list):
|
|
flat_scores.extend(x)
|
|
flat_indices = np.array(flat_indices)
|
|
flat_scores = np.array(flat_scores)
|
|
|
|
# filter duplicates
|
|
unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True)
|
|
meta.loc[unique_indices, column] = flat_scores[unique_indices_idx]
|
|
|
|
# drop indices in meta not in unique_indices
|
|
meta = meta.loc[unique_indices]
|
|
return meta
|
|
|
|
|
|
class VideoTextDataset(torch.utils.data.Dataset):
    """Dataset over a meta CSV: yields stacked, preprocessed frames per row.

    Each item is ``dict(index=<row index>, images=<(N, C, H, W) tensor>)``
    where N is 1 for still images and ``len(self.points)`` for videos.
    """

    def __init__(self, meta_path, transform=None, num_frames=3):
        self.meta_path = meta_path
        self.meta = pd.read_csv(meta_path)
        self.transform = transform
        # Fractional sample positions for this frame count.
        self.points = NUM_FRAMES_POINTS[num_frames]

    def __getitem__(self, index):
        row = self.meta.iloc[index]
        path = row["path"]

        # Videos are sampled at self.points; a still image is a single frame.
        if is_video(path):
            frames = extract_frames(path, points=self.points, backend="opencv")
        else:
            frames = [pil_loader(path)]

        transformed = [self.transform(frame) for frame in frames]
        return dict(index=index, images=torch.stack(transformed))

    def __len__(self):
        return len(self.meta)
|
|
|
|
|
|
class MLP(nn.Module):
    """Regression head mapping an embedding of ``input_size`` dims to a scalar.

    The module sequence (Linear/Dropout ordering) matches the pretrained
    aesthetic-predictor checkpoint, so state_dict keys line up.
    """

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        stack = []
        # Three Linear+Dropout stages, then two bare Linear layers down to 1.
        for fan_in, fan_out, p_drop in (
            (self.input_size, 1024, 0.2),
            (1024, 128, 0.2),
            (128, 64, 0.1),
        ):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.Dropout(p_drop))
        stack.append(nn.Linear(64, 16))
        stack.append(nn.Linear(16, 1))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        return self.layers(x)
|
|
|
|
|
|
class AestheticScorer(nn.Module):
    """CLIP ViT-L/14 image encoder followed by the aesthetic MLP head.

    Exposes ``self.preprocess`` (CLIP's transform) so callers can build a
    matching dataset pipeline. The module is put in eval mode on construction.
    """

    def __init__(self, input_size, device):
        super().__init__()
        self.mlp = MLP(input_size)
        self.clip, self.preprocess = clip.load("ViT-L/14", device=device)

        self.eval()
        self.to(device)

    def forward(self, x):
        # Embed, L2-normalize along the feature dim, then regress to a score.
        embeddings = self.clip.encode_image(x)
        embeddings = F.normalize(embeddings, p=2, dim=-1)
        return self.mlp(embeddings.float())
|
|
|
|
|
|
def parse_args():
    """Parse command-line options for the aesthetic-scoring script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("meta_path", type=str, help="Path to the input CSV file")
    # All integer-valued tuning knobs share one declaration shape.
    for flag, default, help_text in (
        ("--bs", 1024, "Batch size"),
        ("--num_workers", 16, "Number of workers"),
        ("--prefetch_factor", 3, "Prefetch factor"),
        ("--num_frames", 3, "Number of frames to extract"),
    ):
        parser.add_argument(flag, type=int, default=default, help=help_text)
    parser.add_argument("--skip_if_existing", action="store_true")
    return parser.parse_args()
|
|
|
|
|
|
def main():
    """Distributed entry point: score every sample in the meta CSV.

    Each rank scores its shard, writes a per-rank partial CSV, then rank 0
    gathers all (index, score) pairs and writes the merged CSV next to the
    input as ``<meta>_aes.csv``.
    """
    args = parse_args()

    # Bail out early if the input meta file is missing.
    meta_path = args.meta_path
    if not os.path.exists(meta_path):
        print(f"Meta file '{meta_path}' not found. Exit.")
        exit()

    # Output path: insert "_aes" before the extension of the input CSV.
    wo_ext, ext = os.path.splitext(meta_path)
    out_path = f"{wo_ext}_aes{ext}"
    if args.skip_if_existing and os.path.exists(out_path):
        print(f"Output meta file '{out_path}' already exists. Exit.")
        exit()

    # NCCL backend requires one CUDA device per rank; the long timeout covers
    # very large datasets.
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # build model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # 768 matches the ViT-L/14 image-embedding width loaded inside the scorer.
    model = AestheticScorer(768, device)
    # Only the MLP head is loaded from disk; CLIP weights come from clip.load.
    model.mlp.load_state_dict(torch.load("pretrained_models/aesthetic.pth", map_location=device))
    # Reuse CLIP's own preprocessing transform for the dataset.
    preprocess = model.preprocess

    # build dataset
    dataset = VideoTextDataset(args.meta_path, transform=preprocess, num_frames=args.num_frames)
    dataloader = DataLoader(
        dataset,
        batch_size=args.bs,
        num_workers=args.num_workers,
        # shuffle=False keeps row order deterministic; drop_last=False means the
        # sampler pads the last batch by repeating samples (deduped in merge_scores).
        sampler=DistributedSampler(
            dataset,
            num_replicas=dist.get_world_size(),
            rank=dist.get_rank(),
            shuffle=False,
            drop_last=False,
        ),
    )

    # compute aesthetic scores
    indices_list = []
    scores_list = []
    model.eval()
    # Only rank 0 shows the progress bar.
    for batch in tqdm(dataloader, disable=dist.get_rank() != 0):
        indices = batch["index"]
        images = batch["images"].to(device, non_blocking=True)

        # Fold the N frames of each sample into the batch dim for CLIP.
        B = images.shape[0]
        images = rearrange(images, "B N C H W -> (B N) C H W")

        # compute score
        with torch.no_grad():
            scores = model(images)

        # Un-fold and average the per-frame scores into one score per sample.
        scores = rearrange(scores, "(B N) 1 -> B N", B=B)
        scores = scores.mean(dim=1)
        scores_np = scores.to(torch.float32).cpu().numpy()

        indices_list.extend(indices.tolist())
        scores_list.extend(scores_np.tolist())

    # save local results
    # Each rank writes its own partial CSV under "parts/" so progress survives
    # a crash of a later collective step.
    meta_local = merge_scores([(indices_list, scores_list)], dataset.meta, column="aes")
    save_dir_local = os.path.join(os.path.dirname(out_path), "parts")
    os.makedirs(save_dir_local, exist_ok=True)
    out_path_local = os.path.join(
        save_dir_local, os.path.basename(out_path).replace(".csv", f"_part_{dist.get_rank()}.csv")
    )
    meta_local.to_csv(out_path_local, index=False)

    # wait for all ranks to finish data processing
    dist.barrier()

    # Free GPU/host memory before gathering the (potentially large) results.
    torch.cuda.empty_cache()
    gc.collect()
    gathered_list = [None] * dist.get_world_size()
    dist.all_gather_object(gathered_list, (indices_list, scores_list))
    # Rank 0 merges all ranks' scores and writes the final CSV.
    if dist.get_rank() == 0:
        meta_new = merge_scores(gathered_list, dataset.meta, column="aes")
        meta_new.to_csv(out_path, index=False)
        print(f"New meta with aesthetic scores saved to '{out_path}'.")


if __name__ == "__main__":
    main()
|