# mysora/tools/scoring/ocr/inference.py


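"""Distributed OCR scoring over a dataset meta CSV.

Runs an MMOCR DBNet++ text detector on one frame per video (or on each image)
listed in the meta CSV and records the number of detected text instances
(detection score > 0.3) in a new "ocr" column, written to "<meta>_ocr.csv".

The launch command below is an assumption based on the colossalai /
torch.distributed setup in main(), e.g.:

    torchrun --nproc_per_node=8 tools/scoring/ocr/inference.py /path/to/meta.csv
"""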
import argparse
import os

import colossalai
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
from mmengine import Config
from mmengine.dataset import default_collate
from mmengine.registry import DefaultScope
from mmocr.datasets import PackTextDetInputs
from mmocr.registry import MODELS
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.datasets.folder import pil_loader
from torchvision.transforms import CenterCrop, Compose, Resize
from tqdm import tqdm

from tools.datasets.utils import extract_frames, is_video


def merge_scores(gathered_list: list, meta: pd.DataFrame):
    """Write per-sample OCR scores gathered from all ranks into `meta["ocr"]`."""
    # Interleave the per-rank (indices, scores) lists back into global sample order.
    indices_list = [x[0] for x in gathered_list]
    scores_list = [x[1] for x in gathered_list]
    flat_indices = []
    for x in zip(*indices_list):
        flat_indices.extend(x)
    flat_scores = []
    for x in zip(*scores_list):
        flat_scores.extend(x)
    flat_indices = np.array(flat_indices)
    flat_scores = np.array(flat_scores)
    # Drop duplicate samples introduced by DistributedSampler padding.
    unique_indices, unique_indices_idx = np.unique(flat_indices, return_index=True)
    meta.loc[unique_indices, "ocr"] = flat_scores[unique_indices_idx]
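
# Illustration of merge_scores with hypothetical values: for two ranks with
# gathered_list = [([0, 2], [5, 0]), ([1, 3], [2, 7])], zip-interleaving
# restores global order (indices [0, 1, 2, 3], scores [5, 2, 0, 7]) before
# the scores are written into meta["ocr"].
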
class VideoTextDataset(torch.utils.data.Dataset):
    def __init__(self, meta_path, transform=None):
        self.meta_path = meta_path
        self.meta = pd.read_csv(meta_path)
        # Default preprocessing: resize so the short side is 1024, then
        # center-crop to the fixed 1024x1024 detector input.
        if transform is None:
            transform = Compose(
                [
                    Resize(1024),
                    CenterCrop(1024),
                ]
            )
        self.transform = transform
        self.formatting = PackTextDetInputs(meta_keys=["scale_factor"])

    def __getitem__(self, index):
        row = self.meta.iloc[index]
        path = row["path"]
        if is_video(path):
            # Score a single representative frame (frame 10) per video.
            img = extract_frames(path, frame_inds=[10], backend="opencv")[0]
        else:
            img = pil_loader(path)
        img = self.transform(img)
        img_array = np.array(img)[:, :, ::-1].copy()  # RGB -> BGR for MMOCR
        results = {
            "img": img_array,
            "scale_factor": 1.0,
        }
        results = self.formatting(results)
        results["index"] = index
        return results

    def __len__(self):
        return len(self.meta)
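
# Note: after PackTextDetInputs and default_collate, a batch looks roughly like
# {"inputs": ..., "data_samples": [...], "index": Tensor[B]}; the exact keys and
# shapes are an assumption here and depend on the MMOCR/mmengine versions.
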
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("meta_path", type=str, help="Path to the input CSV file")
    parser.add_argument("--bs", type=int, default=16, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=16, help="Number of workers")
    parser.add_argument("--skip_if_existing", action="store_true", help="Skip if the output CSV already exists")
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    meta_path = args.meta_path
    if not os.path.exists(meta_path):
        print(f"Meta file '{meta_path}' not found. Exit.")
        exit()

    wo_ext, ext = os.path.splitext(meta_path)
    out_path = f"{wo_ext}_ocr{ext}"
    if args.skip_if_existing and os.path.exists(out_path):
        print(f"Output meta file '{out_path}' already exists. Exit.")
        exit()

    cfg = Config.fromfile("./tools/scoring/ocr/dbnetpp.py")

    # Initialize the distributed environment (reads rank/world size from torchrun).
    colossalai.launch_from_torch({})
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    DefaultScope.get_instance("ocr", scope_name="mmocr")  # use the mmocr registry by default

    # build model
    model = MODELS.build(cfg.model)
    model.init_weights()
    model.to(device)  # also sets data_preprocessor._device
    print("==> Model built.")

    # build dataset; each rank sees a disjoint shard via DistributedSampler
    dataset = VideoTextDataset(meta_path=meta_path)
    dataloader = DataLoader(
        dataset,
        batch_size=args.bs,
        num_workers=args.num_workers,
        sampler=DistributedSampler(
            dataset,
            num_replicas=dist.get_world_size(),
            rank=dist.get_rank(),
            shuffle=False,
            drop_last=False,
        ),
        collate_fn=default_collate,
    )
    print("==> Dataloader built.")

    # compute scores: the OCR score is the number of detected text instances
    # with detection confidence above 0.3
    dataset.meta["ocr"] = np.nan
    indices_list = []
    scores_list = []
    model.eval()
    for data in tqdm(dataloader, disable=dist.get_rank() != 0):
        indices_i = data["index"]
        indices_list.extend(indices_i.tolist())
        del data["index"]
        with torch.no_grad():
            pred = model.test_step(data)  # test_step moves data to the model device
        num_texts_i = [(x.pred_instances.scores > 0.3).sum().item() for x in pred]
        scores_list.extend(num_texts_i)

    # gather (indices, scores) from all ranks and let rank 0 write the CSV
    gathered_list = [None] * dist.get_world_size()
    dist.all_gather_object(gathered_list, (indices_list, scores_list))
    if dist.get_rank() == 0:
        merge_scores(gathered_list, dataset.meta)
        dataset.meta.to_csv(out_path, index=False)
        print(f"New meta (shape={dataset.meta.shape}) with ocr results saved to '{out_path}'.")
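
# The output CSV mirrors the input meta plus an "ocr" column holding the
# per-sample text-instance count (stored as float, since the column is
# initialized with NaN).
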
if __name__ == "__main__":
    main()