diff --git a/scripts/diffusion/gen_smoke_data.py b/scripts/diffusion/gen_smoke_data.py index 6a9634e..679ce99 100644 --- a/scripts/diffusion/gen_smoke_data.py +++ b/scripts/diffusion/gen_smoke_data.py @@ -24,7 +24,7 @@ import csv import torch from PIL import Image -N_SAMPLES = 4 # 生成 4 条假样本 +N_SAMPLES = 16 # 8 GPU × 2 batch/GPU,保证每卡至少有数据 BASE_DIR = "/data/train-input/smoke_test" NULL_DIR = "/mnt/ddn/sora/tmp_load"