This commit is contained in:
parent 552caf31f1
commit 9225b4b327
@@ -253,21 +253,73 @@ def main():
     args = parse_args()
     set_seed(args.seed)

+    # -------- Debug print helper (every rank prints) --------
+    host = socket.gethostname()
+    def dbg(msg):
+        print(
+            f"[dbg][host={host} RANK={os.environ.get('RANK','0')} "
+            f"LOCAL_RANK={os.environ.get('LOCAL_RANK', str(args.local_rank))}] {msg}",
+            flush=True
+        )
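+    # (flush=True matters here: without it, buffered stdout from a rank that
+    # hangs or is killed can vanish, hiding exactly the evidence we want.)
+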
+    # Versions & launch args & key environment variables
+    import transformers as hf
+    try:
+        import deepspeed as ds
+        ds_ver = ds.__version__
+    except Exception:
+        ds_ver = "n/a"
+    dbg(f"torch={torch.__version__}, transformers={hf.__version__}, deepspeed={ds_ver}")
+    dbg(f"args={args}")
+    dbg("ENV: WORLD_SIZE=%s RANK=%s LOCAL_RANK=%s MASTER_ADDR=%s MASTER_PORT=%s CUDA_VISIBLE_DEVICES=%s" % (
+        os.environ.get("WORLD_SIZE"),
+        os.environ.get("RANK"),
+        os.environ.get("LOCAL_RANK", str(args.local_rank)),
+        os.environ.get("MASTER_ADDR"),
+        os.environ.get("MASTER_PORT"),
+        os.environ.get("CUDA_VISIBLE_DEVICES"),
+    ))
+    dbg(f"cuda_available={torch.cuda.is_available()} device_count={torch.cuda.device_count()}")
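+    # (WORLD_SIZE/RANK/MASTER_ADDR/MASTER_PORT are what init_method="env://"
+    # reads; printing them per rank surfaces mismatched launches immediately.)
+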
     # ---- Initialize distributed (used by the consistency probes) ----
     world_size = int(os.environ.get("WORLD_SIZE", "1"))
     rank = int(os.environ.get("RANK", "0"))
     local_rank = int(os.environ.get("LOCAL_RANK", str(args.local_rank)))
+    dbg(f"pre-init: world_size={world_size}, rank={rank}, local_rank={local_rank}")
+
     if world_size > 1 and dist.is_available() and not dist.is_initialized():
         backend = "nccl" if torch.cuda.is_available() else "gloo"
+        dbg(f"init_process_group backend={backend} via env://")
         dist.init_process_group(backend=backend, init_method="env://")
+    else:
+        dbg(f"skip init_process_group: world_size>1? {world_size>1}, dist_available={dist.is_available()}, already_init={dist.is_initialized()}")
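+    # (nccl drives the GPU collectives; gloo is the CPU fallback, which keeps
+    # this probe usable on CPU-only debug nodes.)
+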
     if torch.cuda.is_available() and local_rank >= 0:
         torch.cuda.set_device(local_rank)
+        dbg(f"set_device({local_rank}); current_device={torch.cuda.current_device()} "
+            f"name={torch.cuda.get_device_name(torch.cuda.current_device())}")
+    else:
+        dbg("no cuda or invalid local_rank; not calling set_device")
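+    # (set_device pins this process's default CUDA device; echoing
+    # current_device() back confirms the rank-to-GPU mapping.)
+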
+    if dist.is_available() and dist.is_initialized():
+        try:
+            dbg(f"dist.get_backend()={dist.get_backend()} "
+                f"dist.get_world_size()={dist.get_world_size()} dist.get_rank()={dist.get_rank()}")
+        except Exception as e:
+            dbg(f"dist query error: {e}")
+
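+    # (The guard plus try/except keeps the probe from crashing the run: these
+    # getters can raise if the default process group never came up cleanly.)
+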
     # 1) First, patch the tokenizer's pad token
     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True, trust_remote_code=True)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token

     tokenizer.model_max_length = args.seq_len
+    dbg(f"tokenizer.pad_token_id={tokenizer.pad_token_id} "
+        f"pad_token={repr(tokenizer.pad_token)} model_max_length={tokenizer.model_max_length}")
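+    # (Reusing eos as pad is the usual fallback for decoder-only checkpoints
+    # that ship without a pad token.)
+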
     # 2) Then load the model
     model = AutoModelForCausalLM.from_pretrained(
@@ -277,6 +329,10 @@ def main():
         trust_remote_code=True
     )
+
+    dbg(f"model loaded: dtype={next(model.parameters()).dtype} "
+        f"use_cache={getattr(model.config,'use_cache',None)} "
+        f"pad_token_id={getattr(model.config,'pad_token_id',None)}")
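+    # (next(model.parameters()).dtype is a cheap spot-check that the dtype
+    # requested in the elided from_pretrained kwargs actually took effect.)
+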
     # 3) pad / alibi and related config
     model.config.pad_token_id = tokenizer.pad_token_id
     model.config.use_cache = False
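
A minimal standalone sketch of the rank-tagged logging pattern this commit adds; the script name and the torchrun launch line are illustrative assumptions, not part of the commit:

    # debug_probe.py -- run with: torchrun --nproc_per_node=2 debug_probe.py
    import os
    import socket

    import torch
    import torch.distributed as dist

    host = socket.gethostname()

    def dbg(msg):
        # Tag every line with host and rank so interleaved output stays attributable.
        print(f"[dbg][host={host} RANK={os.environ.get('RANK', '0')}] {msg}", flush=True)

    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if world_size > 1 and dist.is_available() and not dist.is_initialized():
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend, init_method="env://")

    dbg(f"cuda_available={torch.cuda.is_available()} world_size={world_size}")
    if dist.is_initialized():
        dbg(f"backend={dist.get_backend()} rank={dist.get_rank()}")
        dist.destroy_process_group()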