97 lines
3.1 KiB
Python
97 lines
3.1 KiB
Python
import argparse
|
|
import os
|
|
|
|
import matplotlib.pyplot as plt
|
|
import pandas as pd
|
|
|
|
|
|
def read_file(input_path):
    """Load a tabular dataset into a pandas DataFrame.

    The reader is picked from the file suffix, compared case-insensitively
    so that ``data.CSV`` is handled the same way as ``data.csv``.

    Args:
        input_path: Location of the dataset, as a string or ``os.PathLike``.

    Returns:
        The DataFrame produced by ``pd.read_csv`` or ``pd.read_parquet``.

    Raises:
        NotImplementedError: If the suffix is neither ``.csv`` nor
            ``.parquet``.
    """
    # Normalise only for the suffix test; pandas still gets the caller's value.
    suffix_probe = os.fspath(input_path).lower()
    if suffix_probe.endswith(".csv"):
        return pd.read_csv(input_path)
    if suffix_probe.endswith(".parquet"):
        return pd.read_parquet(input_path)
    raise NotImplementedError(f"Unsupported file format: {input_path}")
|
|
|
|
|
|
def parse_args():
    """Parse the command line of this script.

    Returns:
        An ``argparse.Namespace`` with ``input`` (positional dataset path)
        and ``save_img`` (value of ``--save-img``, ``"samples/infos/"`` when
        the option is not given).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "input",
        type=str,
        help="Path to the input dataset",
    )
    cli.add_argument(
        "--save-img",
        type=str,
        default="samples/infos/",
        help="Path to save the image",
    )
    namespace = cli.parse_args()
    return namespace
|
|
|
|
|
|
def plot_data(data, column, bins, name):
    """Save a histogram of one column of ``data`` as an image file.

    Args:
        data: DataFrame that contains ``column``.
        column: Name of the column to histogram.
        bins: Bin count forwarded to ``DataFrame.hist``.
        name: Output image path. Its parent directory is created when the
            path has one; a bare file name is written to the working
            directory.
    """
    plt.clf()
    data.hist(column=column, bins=bins)
    # dirname() is "" for a bare file name and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    out_dir = os.path.dirname(name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(name)
    # DataFrame.hist opens a figure of its own when no ``ax`` is passed;
    # close it so repeated calls do not pile up open figures.
    plt.close()
    print(f"Saved {name}")
|
|
|
|
|
|
def plot_categorical_data(data, column, name):
    """Save a bar chart of the value counts of one column as an image file.

    Args:
        data: DataFrame that contains ``column``.
        column: Name of the column whose distinct values are counted.
        name: Output image path. Its parent directory is created when the
            path has one; a bare file name is written to the working
            directory.
    """
    plt.clf()
    data[column].value_counts().plot(kind="bar")
    # dirname() is "" for a bare file name and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    out_dir = os.path.dirname(name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(name)
    print(f"Saved {name}")
|
|
|
|
|
|
# Columns that main() plots, in plotting order. The value is the histogram
# bin count; None marks a column that is drawn as a bar chart of value counts.
COLUMNS = dict(
    num_frames=100,
    resolution=100,
    text_len=100,
    aes=100,
    match=100,
    flow=100,
    cmotion=None,
)
|
|
|
|
|
|
def _save_column_plots(frame, prefix, save_dir, skip=()):
    """Write one ``<prefix>_<column>.png`` per COLUMNS entry found in ``frame``."""
    for column, bins in COLUMNS.items():
        if column not in frame.columns or column in skip:
            continue
        target = os.path.join(save_dir, f"{prefix}_{column}.png")
        if bins is None:
            plot_categorical_data(frame, column, target)
        else:
            plot_data(frame, column, bins, target)


def main(args):
    """Print summary statistics for a dataset and optionally save plots.

    Rows whose ``num_frames`` equals 1 are reported as images and all other
    rows as videos. When the dataset has no ``num_frames`` column, every row
    is reported as a video instead of failing with ``KeyError``.

    Args:
        args: Namespace with ``input`` (dataset path) and ``save_img``
            (directory for the plots; a falsy value disables plotting).
    """
    data = read_file(args.input)

    if "num_frames" in data.columns:
        image_index = data["num_frames"] == 1
    else:
        # Without frame counts nothing can be identified as a single image.
        image_index = pd.Series(False, index=data.index)

    # === Image Data Info ===
    if image_index.any():
        print("=== Image Data Info ===")
        img_data = data[image_index]
        print(f"Number of images: {len(img_data)}")
        print(img_data.head())
        print(img_data.describe())
        if args.save_img:
            # num_frames is constant (1) for images, and cmotion is not
            # plotted for them.
            _save_column_plots(img_data, "image", args.save_img, skip=("num_frames", "cmotion"))

    # === Video Data Info ===
    if not image_index.all():
        print("=== Video Data Info ===")
        video_data = data[~image_index]
        print(f"Number of videos: {len(video_data)}")
        if "num_frames" in video_data.columns:
            total_num_frames = video_data["num_frames"].sum()
            print(f"Number of frames: {total_num_frames}")
            DEFAULT_FPS = 30
            total_hours = total_num_frames / DEFAULT_FPS / 3600
            # int() truncates, so anything under one hour is reported as 0.
            print(f"Total hours (30 FPS): {int(total_hours)}")
        print(video_data.head())
        print(video_data.describe())
        if args.save_img:
            _save_column_plots(video_data, "video", args.save_img)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line and pass it to main().
    main(parse_args())
|