commit 4739fc615d (parent eaa79e566c)
@@ -232,11 +232,7 @@ class SFTDataCollator:
         self.pad_to_length = pad_to_length
         assert self.tok.pad_token_id is not None
 
     def __call__(self, features):
-        # if not features:
-        #     raise RuntimeError(f"EMPTY BATCH in collator on rank={os.environ.get('RANK','0')}. "
-        #                        f"Check sampler/sharding & make eval size >= world_size * per_device_eval_batch_size.")
-
         if not features:
             raise RuntimeError(
                 f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
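The deleted comment points at the usual cause of an empty batch: fewer eval samples than world_size * per_device_eval_batch_size. A pre-flight check along those lines could look like the sketch below (a sketch only; it assumes a sized eval_dataset and the WORLD_SIZE variable exported by the launcher, neither of which is shown in this diff):

import os

def check_eval_size(eval_dataset, per_device_eval_batch_size: int) -> None:
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    needed = world_size * per_device_eval_batch_size
    if eval_dataset is not None and len(eval_dataset) < needed:
        raise RuntimeError(
            f"eval set has {len(eval_dataset)} samples but at least {needed} are needed "
            f"(world_size={world_size} * per_device_eval_batch_size={per_device_eval_batch_size})"
        )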
@@ -274,14 +270,6 @@ class SFTDataCollator:
                 f"target_len={target_len} first_len={first_len}",
                 flush=True
             )
-        # Extra strict check: stop an empty batch from going any further
-        # if not features:
-        #     raise RuntimeError(
-        #         f"[FATAL][RANK={os.environ.get('RANK','?')}] Empty batch reached collator. "
-        #         f"Check dataset sharding/streaming."
-        #     )
-        # >>> DEBUG END
-
         return {
             "input_ids": torch.stack(batch_inp, dim=0),
             "attention_mask": torch.stack(batch_attn, dim=0),
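The return above stacks per-example tensors that were already padded to target_len. A self-contained sketch of that pad-and-stack step is shown below (the labels field, right-padding, and the -100 label masking are assumptions, not necessarily this repo's exact layout):

import torch

def pad_and_stack(features, pad_id: int, target_len: int):
    batch_inp, batch_attn, batch_lab = [], [], []
    for f in features:
        ids = f["input_ids"][:target_len]
        labs = f["labels"][:target_len]
        pad = target_len - len(ids)
        batch_inp.append(torch.tensor(ids + [pad_id] * pad, dtype=torch.long))
        batch_attn.append(torch.tensor([1] * len(ids) + [0] * pad, dtype=torch.long))
        batch_lab.append(torch.tensor(labs + [-100] * pad, dtype=torch.long))  # -100 is ignored by the loss
    return {
        "input_ids": torch.stack(batch_inp, dim=0),
        "attention_mask": torch.stack(batch_attn, dim=0),
        "labels": torch.stack(batch_lab, dim=0),
    }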
@@ -322,6 +310,7 @@ def parse_args():
                     help="for deepspeed/torchrun launcher; ignored by user code")
     ap.add_argument("--per_device_eval_batch_size", type=int, default=1)
     ap.add_argument("--deepspeed", type=str, default=None)
 
     return ap.parse_args()
 
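The argument carrying the "for deepspeed/torchrun launcher; ignored by user code" help text exists only so the launcher's injected flag doesn't break argparse; user code normally reads the rank information from the environment variables that torchrun/deepspeed export, as in this sketch:

import os

# set for every worker process by torchrun and the deepspeed launcher
local_rank = int(os.environ.get("LOCAL_RANK", "-1"))
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))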
@@ -451,31 +440,39 @@ def main():
             "Also check whether seq_len is too small, causing everything to be truncated."
         )
 
-    # ====== Actual training stream ======
-    ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
-    if world_size > 1 and len(files) >= world_size:
-        # Multiple files: contiguous sharding by file
-        ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
-        train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
-    else:
-        # Single file, or fewer files than ranks: round-robin samples by index modulo
-        def ex_iter2():
-            for i, ex in enumerate(ds_stream2):
-                if i % max(world_size, 1) == rank:
-                    yield ex
-        train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
+    # # ====== Actual training stream ======
+    # ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
+    # if world_size > 1 and len(files) >= world_size:
+    #     # Multiple files: contiguous sharding by file
+    #     ds_stream2 = ds_stream2.shard(num_shards=world_size, index=rank, contiguous=True)
+    #     train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
+    # else:
+    #     # Single file, or fewer files than ranks: round-robin samples by index modulo
+    #     def ex_iter2():
+    #         for i, ex in enumerate(ds_stream2):
+    #             if i % max(world_size, 1) == rank:
+    #                 yield ex
+    #     train_stream = QwenChatSFTDataset(ex_iter2(), tokenizer, seq_len=args.seq_len)
 
-    # ====== Consistency probe (same logic as above) ======
+    # ====== Actual training stream (no manual sharding; left to Accelerate/Trainer) ======
+    ds_stream2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
+    train_stream = QwenChatSFTDataset((ex for ex in ds_stream2), tokenizer, seq_len=args.seq_len)
+
+    # # ====== Consistency probe (same logic as above) ======
+    # ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
+    # if world_size > 1 and len(files) >= world_size:
+    #     ds_stream_probe2 = ds_stream_probe2.shard(num_shards=world_size, index=rank, contiguous=True)
+    #     probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
+    # else:
+    #     def ex_iter2_probe():
+    #         for i, ex in enumerate(ds_stream_probe2):
+    #             if i % max(world_size, 1) == rank:
+    #                 yield ex
+    #     probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
+
+    # ====== Consistency probe (no sharding) ======
     ds_stream_probe2 = load_dataset("json", data_files={"train": files}, split="train", streaming=True)
-    if world_size > 1 and len(files) >= world_size:
-        ds_stream_probe2 = ds_stream_probe2.shard(num_shards=world_size, index=rank, contiguous=True)
-        probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
-    else:
-        def ex_iter2_probe():
-            for i, ex in enumerate(ds_stream_probe2):
-                if i % max(world_size, 1) == rank:
-                    yield ex
-        probe_stream = QwenChatSFTDataset(ex_iter2_probe(), tokenizer, seq_len=args.seq_len)
+    probe_stream = QwenChatSFTDataset((ex for ex in ds_stream_probe2), tokenizer, seq_len=args.seq_len)
 
     def has_at_least(stream, n: int):
         it = iter(stream)
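The manual fallback commented out above round-robins samples across ranks by index; pulled out as a standalone sketch it reduces to the generator below (reading RANK/WORLD_SIZE from the launcher environment is an assumption about how this script obtains them). After this change the unsharded stream is handed straight to the Trainer, and per-rank batch dispatch is left to Accelerate.

import os

def modulo_shard(example_iter):
    """Yield only the examples whose index falls on this rank: i % world_size == rank."""
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    for i, ex in enumerate(example_iter):
        if i % max(world_size, 1) == rank:
            yield ex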
@@ -511,22 +508,6 @@ def main():
             )
             sys.exit(2)
 
-    # if dist.is_available() and dist.is_initialized():
-    #     # t = torch.tensor(local_ok, device=("cuda" if torch.cuda.is_available() else "cpu"))
-    #     t = torch.tensor(
-    #         local_ok,
-    #         device=(f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu")
-    #     )
-    #     dist.all_reduce(t, op=dist.ReduceOp.MIN)
-    #     if t.item() == 0:
-    #         if is_main_process():
-    #             print("[FATAL] At least one rank has no samples. Reduce WORLD_SIZE or fix the sharding; this run will not start.", flush=True)
-    #         dist.barrier()
-    #         sys.exit(2)
-    # else:
-    #     if local_ok == 0:
-    #         print("[FATAL] No samples on this node, exiting.", flush=True); sys.exit(2)
-
     # ---- Eval construction: prefer --eval_data_glob; only fall back to eval_ratio sampling ----
     eval_dataset: Optional[Dataset] = None
 
 
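The commented-out block deleted above implemented an "every rank has data" check via an all-reduce MIN; as a self-contained sketch (device handling simplified, local_ok assumed to be 1 when this rank produced at least one sample):

import torch
import torch.distributed as dist

def abort_if_any_rank_empty(local_ok: int, local_rank: int) -> None:
    if dist.is_available() and dist.is_initialized():
        device = f"cuda:{local_rank}" if torch.cuda.is_available() and local_rank >= 0 else "cpu"
        t = torch.tensor(local_ok, device=device)
        dist.all_reduce(t, op=dist.ReduceOp.MIN)  # becomes 0 iff at least one rank is empty
        if t.item() == 0:
            dist.barrier()
            raise SystemExit(2)
    elif local_ok == 0:
        raise SystemExit(2)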
@@ -595,7 +576,6 @@ def main():
             f"eval size {len(eval_dataset)} still not divisible by global_bs {global_bs}"
 
     # More robust during bring-up: don't force padding all the way to 4096
-    # data_collator = SFTDataCollator(tokenizer, pad_to_length=None)
     data_collator = SFTDataCollator(tokenizer, pad_to_length=args.seq_len)
 
     os.makedirs(args.output_dir, exist_ok=True)
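The assertion message above compares the eval size to global_bs; a sketch of how that quantity is typically derived (reading WORLD_SIZE from the environment is an assumption):

import os

def global_eval_batch_size(per_device_eval_batch_size: int) -> int:
    # one eval step consumes this many samples across all ranks
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return world_size * per_device_eval_batch_size

Keeping len(eval_dataset) a multiple of this value avoids short or empty per-rank eval batches.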
@@ -661,14 +641,6 @@ def main():
         data_collator=data_collator
     )
 
-    # trainer = Trainer(
-    #     model=model,
-    #     args=training_args,
-    #     train_dataset=train_stream,
-    #     eval_dataset=eval_dataset,
-    #     processing_class=tokenizer,
-    #     data_collator=data_collator
-    # )
     trainer.add_callback(CsvLossLogger(csv_path=os.path.join(args.output_dir, "loss.csv")))
 
     # No shared disk: check whether each node's local output_dir already has checkpoint-*
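The closing comment concerns resuming without a shared filesystem, where each node only sees its own checkpoint-* directories; one common way to resolve that locally is transformers' get_last_checkpoint helper, shown here as a sketch rather than this repo's confirmed approach:

import os
from transformers.trainer_utils import get_last_checkpoint

last_ckpt = None
if os.path.isdir(args.output_dir):
    # newest checkpoint-<step> directory in this node's local output_dir, or None
    last_ckpt = get_last_checkpoint(args.output_dir)
trainer.train(resume_from_checkpoint=last_ckpt)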