This commit is contained in:
parent
5a633d4b1c
commit
43111064cc
|
|
@ -853,15 +853,17 @@ def main():
|
||||||
|
|
||||||
|
|
||||||
# ====== 正式训练流(不做任何手动分片,交给 Accelerate/Trainer)======
|
# ====== 正式训练流(不做任何手动分片,交给 Accelerate/Trainer)======
|
||||||
ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)\
|
# ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)\
|
||||||
.shuffle(buffer_size=50000, seed=args.seed)
|
# .shuffle(buffer_size=50000, seed=args.seed)
|
||||||
|
|
||||||
|
ex_iter = endless_examples(files, args.seed, buf=50000)
|
||||||
|
|
||||||
# 先尝试 datasets 的无限流;没有就用我们自己的无限生成器
|
# 先尝试 datasets 的无限流;没有就用我们自己的无限生成器
|
||||||
try:
|
# try:
|
||||||
ds_stream2 = ds_stream2.repeat() # ★ 若可用:官方无限流
|
# ds_stream2 = ds_stream2.repeat() # ★ 若可用:官方无限流
|
||||||
ex_iter = (ex for ex in ds_stream2) # ★ 统一用 ex_iter 作为上游
|
# ex_iter = (ex for ex in ds_stream2) # ★ 统一用 ex_iter 作为上游
|
||||||
except AttributeError:
|
# except AttributeError:
|
||||||
ex_iter = endless_examples(files, args.seed, buf=50000) # ★ 兜底:自制无限流
|
# ex_iter = endless_examples(files, args.seed, buf=50000) # ★ 兜底:自制无限流
|
||||||
|
|
||||||
# 关键:这里一定要用 ex_iter,而不是重新从 ds_stream2 取一次
|
# 关键:这里一定要用 ex_iter,而不是重新从 ds_stream2 取一次
|
||||||
train_stream = QwenChatSFTDataset(ex_iter, tokenizer, seq_len=args.seq_len)
|
train_stream = QwenChatSFTDataset(ex_iter, tokenizer, seq_len=args.seq_len)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue