# mysora/tools/scoring/aesthetic/inference.py
# adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py
import cv2 # isort:skip
import argparse
import gc
import os
from datetime import timedelta
import clip
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.datasets.folder import pil_loader
from tqdm import tqdm
from tools.datasets.utils import extract_frames, is_video
# Relative temporal positions (fractions of the clip length) at which frames
# are sampled from a video, keyed by the requested number of frames.
NUM_FRAMES_POINTS = {
    1: (0.5,),
    2: (0.25, 0.5),
    3: (0.1, 0.5, 0.9),
}
def merge_scores(gathered_list: list, meta: pd.DataFrame, column):
    """Merge per-rank (indices, scores) pairs into *meta* under *column*.

    Each element of ``gathered_list`` is an ``(indices, scores)`` pair produced
    by one rank. Samples are interleaved rank-by-rank (undoing the round-robin
    split done by ``DistributedSampler``), duplicate indices introduced by
    sampler padding are dropped (first occurrence wins), and only the rows that
    actually received a score are kept in the returned frame.
    """
    per_rank_indices = [pair[0] for pair in gathered_list]
    per_rank_scores = [pair[1] for pair in gathered_list]
    # zip(*...) walks the ranks in lockstep, restoring the original dataset order.
    flat_indices = np.array([i for group in zip(*per_rank_indices) for i in group])
    flat_scores = np.array([s for group in zip(*per_rank_scores) for s in group])
    # DistributedSampler (drop_last=False) pads by repeating samples; keep one
    # score per row by taking the first occurrence of each index.
    unique_indices, first_pos = np.unique(flat_indices, return_index=True)
    meta.loc[unique_indices, column] = flat_scores[first_pos]
    # Restrict the result to rows that were actually scored.
    return meta.loc[unique_indices]
class VideoTextDataset(torch.utils.data.Dataset):
    """Dataset yielding stacked, preprocessed frames for each row of a meta CSV.

    For video files, frames are sampled at fixed relative positions within the
    clip; for still images, the single image is used as-is.
    """

    def __init__(self, meta_path, transform=None, num_frames=3):
        self.meta_path = meta_path
        self.meta = pd.read_csv(meta_path)
        self.transform = transform
        # Relative positions within the clip at which frames are extracted.
        self.points = NUM_FRAMES_POINTS[num_frames]

    def __getitem__(self, index):
        row = self.meta.iloc[index]
        path = row["path"]
        if is_video(path):
            frames = extract_frames(path, points=self.points, backend="opencv")
        else:
            frames = [pil_loader(path)]
        # Apply the preprocessing transform and stack into one (N, C, H, W) tensor.
        frames = torch.stack([self.transform(img) for img in frames])
        return dict(index=index, images=frames)

    def __len__(self):
        return len(self.meta)
class MLP(nn.Module):
    """Regression head mapping a CLIP image embedding to a scalar score.

    NOTE: the layer layout — including the absence of activation functions —
    must match the pretrained aesthetic-predictor checkpoint, so the
    ``Sequential`` indices (state-dict keys ``layers.0`` … ``layers.7``) are
    load-bearing and must not change.
    """

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        stages = [
            nn.Linear(self.input_size, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        # One scalar per input row.
        return self.layers(x)
class AestheticScorer(nn.Module):
    """Aesthetic scorer: CLIP ViT-L/14 image encoder followed by an MLP head."""

    def __init__(self, input_size, device):
        super().__init__()
        self.mlp = MLP(input_size)
        # clip.load also returns the matching image-preprocessing transform,
        # which callers reuse when building the dataset.
        self.clip, self.preprocess = clip.load("ViT-L/14", device=device)
        self.eval()
        self.to(device)

    def forward(self, x):
        embedding = self.clip.encode_image(x)
        # L2-normalize and cast (CLIP may emit half precision) before the head.
        embedding = F.normalize(embedding, p=2, dim=-1).float()
        return self.mlp(embedding)
def parse_args():
    """Parse command-line arguments for the aesthetic-scoring job."""
    parser = argparse.ArgumentParser()
    parser.add_argument("meta_path", type=str, help="Path to the input CSV file")
    parser.add_argument("--bs", type=int, default=1024, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=16, help="Number of workers")
    parser.add_argument("--prefetch_factor", type=int, default=3, help="Prefetch factor")
    parser.add_argument("--num_frames", type=int, default=3, help="Number of frames to extract")
    parser.add_argument("--skip_if_existing", action="store_true")
    return parser.parse_args()
def main():
    """Distributed entry point: score every sample listed in the meta CSV and
    write a new CSV with an added "aes" (aesthetic score) column.

    Each rank scores its shard, saves a per-rank part file, then rank 0 gathers
    all (index, score) pairs and writes the merged CSV next to the input.
    """
    args = parse_args()
    meta_path = args.meta_path
    if not os.path.exists(meta_path):
        print(f"Meta file '{meta_path}' not found. Exit.")
        exit()
    # Output path: e.g. data.csv -> data_aes.csv
    wo_ext, ext = os.path.splitext(meta_path)
    out_path = f"{wo_ext}_aes{ext}"
    if args.skip_if_existing and os.path.exists(out_path):
        print(f"Output meta file '{out_path}' already exists. Exit.")
        exit()
    # One process per GPU; generous timeout so slow ranks don't get dropped.
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    # build model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # 768 = CLIP ViT-L/14 image-embedding dimension.
    model = AestheticScorer(768, device)
    model.mlp.load_state_dict(torch.load("pretrained_models/aesthetic.pth", map_location=device))
    preprocess = model.preprocess
    # build dataset
    dataset = VideoTextDataset(args.meta_path, transform=preprocess, num_frames=args.num_frames)
    loader_kwargs = dict(
        dataset=dataset,
        batch_size=args.bs,
        num_workers=args.num_workers,
        sampler=DistributedSampler(
            dataset,
            num_replicas=dist.get_world_size(),
            rank=dist.get_rank(),
            shuffle=False,
            drop_last=False,  # keep every sample; padding duplicates are removed in merge_scores
        ),
    )
    # FIX: --prefetch_factor was parsed but never forwarded to the DataLoader.
    # It is only accepted when worker processes are used (num_workers > 0).
    if args.num_workers > 0:
        loader_kwargs["prefetch_factor"] = args.prefetch_factor
    dataloader = DataLoader(**loader_kwargs)
    # compute aesthetic scores
    indices_list = []
    scores_list = []
    model.eval()
    for batch in tqdm(dataloader, disable=dist.get_rank() != 0):
        indices = batch["index"]
        images = batch["images"].to(device, non_blocking=True)
        B = images.shape[0]
        # Fold the per-sample frame dimension into the batch for CLIP.
        images = rearrange(images, "B N C H W -> (B N) C H W")
        # compute score
        with torch.no_grad():
            scores = model(images)
        # Average per-frame scores back into one score per sample.
        scores = rearrange(scores, "(B N) 1 -> B N", B=B)
        scores = scores.mean(dim=1)
        scores_np = scores.to(torch.float32).cpu().numpy()
        indices_list.extend(indices.tolist())
        scores_list.extend(scores_np.tolist())
    # save local (per-rank) results so partial progress survives a crash
    meta_local = merge_scores([(indices_list, scores_list)], dataset.meta, column="aes")
    save_dir_local = os.path.join(os.path.dirname(out_path), "parts")
    os.makedirs(save_dir_local, exist_ok=True)
    out_path_local = os.path.join(
        save_dir_local, os.path.basename(out_path).replace(".csv", f"_part_{dist.get_rank()}.csv")
    )
    meta_local.to_csv(out_path_local, index=False)
    # wait for all ranks to finish data processing
    dist.barrier()
    torch.cuda.empty_cache()
    gc.collect()
    # Gather every rank's (indices, scores) pair and merge on rank 0 only.
    gathered_list = [None] * dist.get_world_size()
    dist.all_gather_object(gathered_list, (indices_list, scores_list))
    if dist.get_rank() == 0:
        meta_new = merge_scores(gathered_list, dataset.meta, column="aes")
        meta_new.to_csv(out_path, index=False)
        print(f"New meta with aesthetic scores saved to '{out_path}'.")
if __name__ == "__main__":
main()