parent 746b8471ee
commit c9a00b6208
@@ -445,36 +445,39 @@ def main():
             "Also check whether seq_len is too small, causing all samples to be trimmed away."
         )
 
-    # ====== Main training stream + modulo sharding (sample count need not be divisible by world_size) ======
+    # ====== Main training stream ======
     ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
-    if world_size > 1:
+    if world_size > 1 and len(files) >= world_size:
+        # Multiple files: contiguous sharding by file
         ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
-    train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
+        train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
+    else:
+        # Single file or too few files: round-robin samples by modulo
+        def ex_iter2():
+            for i, ex in enumerate(ds_stream2):
+                if i % max(world_size, 1) == rank:
+                    yield ex
+        train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
 
-    # def ex_iter2():
-    #     for i, ex in enumerate(ds_stream2):
-    #         if i % max(world_size, 1) == rank:
-    #             yield ex
-    # train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
 
-    # ====== Consistency probe: any rank without samples -> all ranks exit ======
     def has_one_sample(stream):
         it = iter(stream)
         try:
             next(it); return 1
         except StopIteration:
             return 0
 
+    # ====== Consistency probe (same logic as above) ======
     ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
-    if world_size > 1:
+    if world_size > 1 and len(files) >= world_size:
         ds_stream_probe2 = ds_stream_probe2.shard(num_shards=world_size, index=rank, contiguous=True)
-    probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
+        probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
+    else:
+        def ex_iter2_probe():
+            for i, ex in enumerate(ds_stream_probe2):
+                if i % max(world_size, 1) == rank:
+                    yield ex
+        probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
 
-    # def ex_iter2_probe():
-    #     for i, ex in enumerate(ds_stream_probe2):
-    #         if i % max(world_size, 1) == rank:
-    #             yield ex
-    # probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
 
     local_ok = has_one_sample(probe_stream)
-
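Note on the sharding change: as the new comments in the hunk say, contiguous sharding of the stream works at file granularity, so a rank can end up with an empty shard when len(files) < world_size; the new else branch therefore falls back to per-sample round-robin. A minimal standalone sketch of the two strategies, using toy file names and plain Python iterables rather than the repo's datasets/QwenChatSFTDataset objects (the slice arithmetic is illustrative, not necessarily the exact split the datasets library performs):

def contiguous_file_shards(files, world_size, rank):
    # Each rank takes a contiguous slice of the file list, roughly what
    # ds.shard(num_shards=world_size, index=rank, contiguous=True) does per file.
    per_rank, extra = divmod(len(files), world_size)
    start = rank * per_rank + min(rank, extra)
    end = start + per_rank + (1 if rank < extra else 0)
    return files[start:end]

def modulo_round_robin(samples, world_size, rank):
    # Every rank walks the full stream but keeps only its residue class,
    # mirroring ex_iter2() in the hunk above.
    for i, ex in enumerate(samples):
        if i % max(world_size, 1) == rank:
            yield ex

files = [f"part-{i}.jsonl" for i in range(4)]       # hypothetical file names
print(contiguous_file_shards(files, 2, 0))          # ['part-0.jsonl', 'part-1.jsonl']
print(contiguous_file_shards(files, 2, 1))          # ['part-2.jsonl', 'part-3.jsonl']
print(list(modulo_round_robin(range(7), 2, 1)))     # [1, 3, 5]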
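Note on the consistency probe: the hunk only computes local_ok on each rank; the cross-rank agreement implied by the removed comment ("any rank without samples -> all ranks exit") happens outside this diff. A sketch of how such a flag is typically combined, assuming torch.distributed is already initialized; all_ranks_have_data is a hypothetical helper, not code from this repo, and with the NCCL backend the flag tensor must live on the rank's GPU:

import torch
import torch.distributed as dist

def all_ranks_have_data(local_ok: int) -> bool:
    # MIN-reduce the per-rank flag: every rank sees 0 if ANY rank has no
    # samples, so all ranks can bail out together instead of deadlocking
    # in the first training collective.
    flag = torch.tensor([local_ok], dtype=torch.int64)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(flag, op=dist.ReduceOp.MIN)
    return bool(flag.item())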
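Why the probe rebuilds ds_stream_probe2 instead of peeking at train_stream: pulling one item from a Python iterator consumes it, so probing the training stream directly would silently drop its first sample. A two-line illustration:

it = iter([{"id": 0}, {"id": 1}, {"id": 2}])
next(it)         # what has_one_sample() does to the probe stream
print(list(it))  # [{'id': 1}, {'id': 2}] -> element 0 is gone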