commit 746b8471ee
parent 9bb6ee9307
@@ -447,11 +447,15 @@ def main():
     # ====== Actual training stream + modulo sharding (sample count need not be divisible by world_size) ======
     ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
-    def ex_iter2():
-        for i, ex in enumerate(ds_stream2):
-            if i % max(world_size, 1) == rank:
-                yield ex
-    train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
+    if world_size > 1:
+        ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
+    train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
+
+    # def ex_iter2():
+    #     for i, ex in enumerate(ds_stream2):
+    #         if i % max(world_size, 1) == rank:
+    #             yield ex
+    # train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
 
     # ====== Consistency probe: any rank with no samples -> all ranks exit ======
     def has_one_sample(stream):
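A note on the change above: the replaced `ex_iter2` generator sharded at the example level (every rank scans the whole stream and keeps indices where `i % world_size == rank`), while `IterableDataset.shard(...)` as used here splits the stream by its underlying input shards. Below is a minimal runnable sketch of the example-level approach for reference; the file name and the rank/world_size values are placeholder assumptions, and `split_dataset_by_node` from `datasets.distributed` is the library utility that covers the same need.

```python
# Sketch of the modulo sharding the diff replaces. Assumes a local
# "train.jsonl" and hard-coded rank/world_size; in the real script
# these come from the torch.distributed environment.
from datasets import load_dataset

world_size, rank = 2, 0  # placeholders for the distributed setup

ds = load_dataset("json", data_files={"train": ["train.jsonl"]},
                  split="train", streaming=True)

def rank_local_examples(stream, world_size, rank):
    # Each rank reads the full stream but keeps every world_size-th
    # example, so the sample count need not divide evenly across ranks.
    for i, ex in enumerate(stream):
        if i % max(world_size, 1) == rank:
            yield ex

for ex in rank_local_examples(ds, world_size, rank):
    ...  # feed into the SFT dataset wrapper
```

The modulo variant costs extra I/O (every rank reads everything) but never leaves a rank empty as long as there are at least `world_size` samples; shard-based splitting is cheaper but can starve a rank when there are fewer input shards than ranks, which is presumably why the consistency probe below exists.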
@@ -462,11 +466,16 @@ def main():
         return 0
 
     ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
-    def ex_iter2_probe():
-        for i, ex in enumerate(ds_stream_probe2):
-            if i % max(world_size, 1) == rank:
-                yield ex
-    probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
+    if world_size > 1:
+        ds_stream_probe2 = ds_stream_probe2.shard(num_shards=world_size, index=rank, contiguous=True)
+    probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
+
+    # def ex_iter2_probe():
+    #     for i, ex in enumerate(ds_stream_probe2):
+    #         if i % max(world_size, 1) == rank:
+    #             yield ex
+    # probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
 
     local_ok = has_one_sample(probe_stream)
 
     if dist.is_available() and dist.is_initialized():
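The probe in this hunk has each rank check its own stream for at least one sample and then agree on a global verdict. A hedged sketch of how `local_ok` can be combined so that a single empty rank stops everyone, using an all-reduce with MIN; the device choice is an assumption (NCCL requires CUDA tensors), not the repo's exact code:

```python
import torch
import torch.distributed as dist

def everyone_has_data(local_ok: int) -> bool:
    # MIN over ranks: the result is 1 only if every rank reported 1,
    # so one sample-less rank flips the decision for the whole group.
    if dist.is_available() and dist.is_initialized():
        device = "cuda" if dist.get_backend() == "nccl" else "cpu"
        flag = torch.tensor([local_ok], dtype=torch.int64, device=device)
        dist.all_reduce(flag, op=dist.ReduceOp.MIN)
        return bool(flag.item())
    return bool(local_ok)  # single-process fallback
```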
@@ -614,8 +623,8 @@ def main():
         args=training_args,
         train_dataset=train_stream,
         eval_dataset=eval_dataset,
-        tokenizer=tokenizer,
-        # processing_class=tokenizer,
+        #tokenizer=tokenizer,
+        processing_class=tokenizer,
         data_collator=data_collator
     )
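The swap in this hunk tracks a transformers API rename: `Trainer`'s `tokenizer` argument was deprecated in favor of `processing_class` (around v4.46). A version-tolerant sketch, assuming the `model`, `training_args`, datasets, and collator from this diff are already in scope:

```python
import transformers
from packaging import version
from transformers import Trainer

# Pass the tokenizer under whichever keyword this transformers version
# expects; `processing_class` replaced the deprecated `tokenizer` kwarg.
tok_kwarg = (
    {"processing_class": tokenizer}
    if version.parse(transformers.__version__) >= version.parse("4.46.0")
    else {"tokenizer": tokenizer}
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_stream,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    **tok_kwarg,
)
```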