parent eaa79e566c · commit 4739fc615d
@@ -232,11 +232,7 @@ class SFTDataCollator:
        self.pad_to_length = pad_to_length
        assert self.tok.pad_token_id is not None

    def __call__(self, features):
        # if not features:
        #     raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
        #                        f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")

    def __call__(self, features):
        if not features:
            raise RuntimeError(
                f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
@@ -274,14 +270,6 @@ class SFTDataCollator:
                f"target_len={target_len} first_len={first_len}",
                flush=True
            )
        # Extra strict check: stop an empty batch from propagating any further
        # if not features:
        #     raise RuntimeError(
        #         f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
        #         f"Check dataset sharding/streaming."
        #     )
        # >>> DEBUG END

        return {
            "input_ids": torch.stack(batch_inp, dim=0),
            "attention_mask": torch.stack(batch_attn, dim=0),
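For orientation, here is a minimal sketch of the collator pattern these two hunks touch: an empty-batch guard plus right-padding to a fixed length before stacking. The class name `PadToLengthCollator`, the `labels`/`-100` handling, and the exact field layout are assumptions for illustration, not the actual `SFTDataCollator` implementation.

```python
import os
from typing import Dict, List, Optional

import torch


class PadToLengthCollator:
    """Sketch: right-pad each feature to a fixed length, then stack into a batch."""

    def __init__(self, tokenizer, pad_to_length: Optional[int] = None):
        self.tok = tokenizer
        self.pad_to_length = pad_to_length
        assert self.tok.pad_token_id is not None, "tokenizer needs a pad token"

    def __call__(self, features: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        if not features:
            # Fail loudly here instead of letting an empty batch crash later inside the model.
            raise RuntimeError(
                f"[FATAL][RANK={os.environ.get('RANK', '?')}] Empty batch reached collator."
            )
        target_len = self.pad_to_length or max(len(f["input_ids"]) for f in features)
        batch_inp, batch_attn, batch_labels = [], [], []
        for f in features:
            ids = f["input_ids"][:target_len]
            labels = list(f.get("labels", ids))[:target_len]
            pad = target_len - len(ids)
            batch_inp.append(torch.tensor(ids + [self.tok.pad_token_id] * pad))
            batch_attn.append(torch.tensor([1] * len(ids) + [0] * pad))
            batch_labels.append(torch.tensor(labels + [-100] * pad))
        return {
            "input_ids": torch.stack(batch_inp, dim=0),
            "attention_mask": torch.stack(batch_attn, dim=0),
            "labels": torch.stack(batch_labels, dim=0),
        }
```

Raising inside the collator surfaces sharding mistakes at the first empty batch, rather than as an opaque shape error in the model's forward pass.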
@@ -322,6 +310,7 @@ def parse_args():
                    help="for deepspeed/torchrun launcher; ignored by user code")
    ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
    ap.add_argument("--deepspeed", type=str, default=None)

    return ap.parse_args()

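The two new flags are typically forwarded straight into `TrainingArguments`; a minimal sketch of that wiring with placeholder values (the actual script builds these from `parse_args()` together with options not shown in this hunk):

```python
from transformers import TrainingArguments

# Placeholder values for illustration; the script derives them from parse_args().
training_args = TrainingArguments(
    output_dir="out/sft-debug",        # hypothetical path
    per_device_eval_batch_size=1,      # --per_device_eval_batch_size
    deepspeed=None,                    # --deepspeed: path to a DeepSpeed JSON config, or None
)
```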
@@ -451,31 +440,39 @@ def main():
            "Also check whether seq_len is too small, causing every sample to be clipped away."
        )

    # ====== Main training stream ======
    ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
    if world_size > 1 and len(files) >= world_size:
        # Multiple files: shard contiguously, file by file
        ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
        train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
    else:
        # Single file, or fewer files than ranks: round-robin samples by index modulo world size
        def ex_iter2():
            for i, ex in enumerate(ds_stream2):
                if i % max(world_size, 1) == rank:
                    yield ex
        train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
    # # ====== Main training stream ======
    # ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
    # if world_size > 1 and len(files) >= world_size:
    #     # Multiple files: shard contiguously, file by file
    #     ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
    #     train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
    # else:
    #     # Single file, or fewer files than ranks: round-robin samples by index modulo world size
    #     def ex_iter2():
    #         for i, ex in enumerate(ds_stream2):
    #             if i % max(world_size, 1) == rank:
    #                 yield ex
    #     train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)

    # ====== Consistency probe (same logic as above) ======
    # ====== Main training stream (no manual sharding; leave it to Accelerate/Trainer) ======
    ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
    train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)

    # # ====== Consistency probe (same logic as above) ======
    # ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
    # if world_size > 1 and len(files) >= world_size:
    #     ds_stream_probe2 = ds_stream_probe2.shard(num_shards=world_size, index=rank, contiguous=True)
    #     probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
    # else:
    #     def ex_iter2_probe():
    #         for i, ex in enumerate(ds_stream_probe2):
    #             if i % max(world_size, 1) == rank:
    #                 yield ex
    #     probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)

    # ====== Consistency probe (no sharding) ======
    ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
    if world_size > 1 and len(files) >= world_size:
        ds_stream_probe2 = ds_stream_probe2.shard(num_shards=world_size, index=rank, contiguous=True)
        probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
    else:
        def ex_iter2_probe():
            for i, ex in enumerate(ds_stream_probe2):
                if i % max(world_size, 1) == rank:
                    yield ex
        probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
    probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)

    def has_at_least(stream, n: int):
        it = iter(stream)
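The probe helper is cut off at the hunk boundary; only its first two lines appear above. A minimal sketch of what a `has_at_least` check over a lazy stream usually looks like (the body is an assumption; only the signature and the `it = iter(stream)` line come from the diff):

```python
from itertools import islice


def has_at_least(stream, n: int) -> bool:
    """Return True if the (possibly lazy/streaming) dataset yields at least n examples."""
    it = iter(stream)
    # Pull at most n items so a huge stream is never fully consumed just to count it.
    return sum(1 for _ in islice(it, n)) >= n
```

A probe like this is what lets the later "[FATAL] rank has no samples" checks run without materializing the whole stream.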
@@ -511,22 +508,6 @@ def main():
        )
        sys.exit(2)

    # if dist.is_available() and dist.is_initialized():
    #     # t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
    #     t = torch.tensor(
    #         local_ok,
    #         device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu")
    #     )
    #     dist.all_reduce(t, op=dist.ReduceOp.MIN)
    #     if t.item() == 0:
    #         if is_main_process():
    #             print("[FATAL] At least one rank has no samples. Reduce WORLD_SIZE or fix the sharding; this run will not start.", flush=True)
    #         dist.barrier()
    #         sys.exit(2)
    # else:
    #     if local_ok == 0:
    #         print("[FATAL] No samples on this node, exiting.", flush=True); sys.exit(2)

    # ---- Eval construction: prefer --eval_data_glob; fall back to eval_ratio sampling only if it is absent ----
    eval_dataset: Optional[Dataset] = None

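The commented-out block above is the cross-rank "no starving rank" guard. A condensed, runnable sketch of the same pattern, assuming `local_ok` is 1 when this rank found at least one sample and 0 otherwise (the helper name and placement are assumptions; the diff keeps this guard disabled):

```python
import sys

import torch
import torch.distributed as dist


def abort_if_any_rank_is_empty(local_ok: int, local_rank: int = -1) -> None:
    """All-reduce with MIN so every rank agrees whether any rank has zero samples."""
    if dist.is_available() and dist.is_initialized():
        device = f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu"
        t = torch.tensor(local_ok, device=device)
        dist.all_reduce(t, op=dist.ReduceOp.MIN)
        if t.item() == 0:
            # Exit on every rank together; otherwise the healthy ranks hang at the next collective.
            dist.barrier()
            sys.exit(2)
    elif local_ok == 0:
        sys.exit(2)
```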
@@ -595,7 +576,6 @@ def main():
            f"eval size {len(eval_dataset)} still not divisible by global_bs {global_bs}"

    # Safer while debugging the pipeline: don't force padding all the way to 4096
    # data_collator = SFTDataCollator(tokenizer, pad_to_length=None)
    data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)

    os.makedirs(args.output_dir, exist_ok=True)
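The assertion above requires the eval set to split evenly across ranks so no rank ends up with an empty batch. A minimal sketch of one way to enforce that by trimming the tail (this helper is an assumption for illustration and presumes an HF `datasets.Dataset` with `.select()`; the diff only shows the assertion message):

```python
def trim_eval_to_global_batch(eval_dataset, world_size: int, per_device_eval_batch_size: int):
    """Drop the tail of the eval set so every rank receives only full batches."""
    global_bs = world_size * per_device_eval_batch_size
    usable = (len(eval_dataset) // global_bs) * global_bs
    if usable == 0:
        raise ValueError(
            f"eval set has {len(eval_dataset)} samples, fewer than global batch size {global_bs}"
        )
    return eval_dataset.select(range(usable))
```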
@@ -661,14 +641,6 @@ def main():
        data_collator=data_collator
    )

    # trainer = Trainer(
    #     model=model,
    #     args=training_args,
    #     train_dataset=train_stream,
    #     eval_dataset=eval_dataset,
    #     processing_class=tokenizer,
    #     data_collator=data_collator
    # )
    trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))

    # No shared filesystem: check whether each node's local output_dir already contains checkpoint-*
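`CsvLossLogger` is referenced here but not defined in this hunk. A minimal sketch of what such a callback might look like (the column names and `on_log` behaviour are assumptions, not the actual implementation):

```python
import csv
import os

from transformers import TrainerCallback


class CsvLossLogger(TrainerCallback):
    """Sketch: append (step, loss, learning_rate) rows to a CSV file whenever the Trainer logs."""

    def __init__(self, csv_path: str):
        self.csv_path = csv_path
        os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
        with open(self.csv_path, "w", newline="") as f:
            csv.writer(f).writerow(["step", "loss", "learning_rate"])

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Only the main process writes, so multi-GPU runs don't interleave rows.
        if logs and "loss" in logs and state.is_world_process_zero:
            with open(self.csv_path, "a", newline="") as f:
                csv.writer(f).writerow([state.global_step, logs["loss"], logs.get("learning_rate")])
```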