525 lines
20 KiB
Python
525 lines
20 KiB
Python
import functools
|
|
import json
|
|
import operator
|
|
import os
|
|
import re
|
|
import shutil
|
|
from glob import glob
|
|
from typing import Dict, Optional
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.nn as nn
|
|
from colossalai.booster import Booster
|
|
from colossalai.checkpoint_io import GeneralCheckpointIO
|
|
from colossalai.utils.safetensors import save as async_save
|
|
from colossalai.zero.low_level import LowLevelZeroOptimizer
|
|
from huggingface_hub import hf_hub_download
|
|
from safetensors.torch import load_file
|
|
from tensornvme.async_file_io import AsyncFileWriter
|
|
from torch.optim import Optimizer
|
|
from torch.optim.lr_scheduler import _LRScheduler
|
|
|
|
from opensora.acceleration.parallel_states import get_data_parallel_group
|
|
|
|
from .logger import log_message
|
|
|
|
hf_endpoint = os.environ.get("HF_ENDPOINT")
|
|
if hf_endpoint is None:
|
|
hf_endpoint = "https://huggingface.co"
|
|
os.environ["TENSORNVME_DEBUG"] = "1"
|
|
|
|
|
|
def load_from_hf_hub(repo_path: str, cache_dir: str = None) -> str:
|
|
"""
|
|
Loads a checkpoint from the Hugging Face Hub.
|
|
|
|
Args:
|
|
repo_path (str): The path to the checkpoint on the Hugging Face Hub.
|
|
cache_dir (str): The directory to cache the downloaded checkpoint.
|
|
|
|
Returns:
|
|
str: The path to the downloaded checkpoint.
|
|
"""
|
|
repo_id = "/".join(repo_path.split("/")[:-1])
|
|
repo_file = repo_path.split("/")[-1]
|
|
ckpt_path = hf_hub_download(repo_id=repo_id, filename=repo_file, cache_dir=cache_dir)
|
|
return ckpt_path
|
|
|
|
|
|
def load_from_sharded_state_dict(model: nn.Module, ckpt_path: str, model_name: str = "model", strict=False):
|
|
"""
|
|
Loads a model from a sharded checkpoint.
|
|
|
|
Args:
|
|
model (nn.Module): The model to load the checkpoint into.
|
|
ckpt_path (str): The path to the checkpoint.
|
|
model_name (str): The name of the model in the checkpoint.
|
|
strict (bool): Whether to strictly enforce that the keys in the checkpoint match the keys in the model.
|
|
"""
|
|
ckpt_io = GeneralCheckpointIO()
|
|
ckpt_io.load_model(model, os.path.join(ckpt_path, model_name), strict=strict)
|
|
|
|
|
|
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
|
|
"""
|
|
Prints a warning if there are missing or unexpected keys when loading a model.
|
|
|
|
Args:
|
|
missing (list[str]): The missing keys.
|
|
unexpected (list[str]): The unexpected keys.
|
|
"""
|
|
if len(missing) > 0 and len(unexpected) > 0:
|
|
log_message(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
|
log_message("\n" + "-" * 79 + "\n")
|
|
log_message(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
|
|
elif len(missing) > 0:
|
|
log_message(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
|
elif len(unexpected) > 0:
|
|
log_message(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
|
|
else:
|
|
log_message("Model loaded successfully")
|
|
|
|
|
|
def load_checkpoint(
|
|
model: nn.Module,
|
|
path: str,
|
|
cache_dir: str = None,
|
|
device_map: torch.device | str = "cpu",
|
|
cai_model_name: str = "model",
|
|
strict: bool = False,
|
|
rename_keys: dict = None, # rename keys in the checkpoint to support fine-tuning with a different model architecture; map old_key_prefix to new_key_prefix
|
|
) -> nn.Module:
|
|
"""
|
|
Loads a checkpoint into model from a path. Support three types of checkpoints:
|
|
1. huggingface safetensors
|
|
2. local .pt or .pth
|
|
3. colossalai sharded checkpoint
|
|
|
|
Args:
|
|
model (nn.Module): The model to load the checkpoint into.
|
|
path (str): The path to the checkpoint.
|
|
cache_dir (str): The directory to cache the downloaded checkpoint.
|
|
device_map (torch.device | str): The device to map the checkpoint to.
|
|
cai_model_name (str): The name of the model in the checkpoint.
|
|
|
|
Returns:
|
|
nn.Module: The model with the loaded checkpoint.
|
|
"""
|
|
if not os.path.exists(path):
|
|
log_message(f"Checkpoint not found at {path}, trying to download from Hugging Face Hub")
|
|
path = load_from_hf_hub(path, cache_dir)
|
|
assert os.path.exists(path), f"Could not find checkpoint at {path}"
|
|
|
|
log_message(f"Loading checkpoint from {path}")
|
|
if path.endswith(".safetensors"):
|
|
ckpt = load_file(path, device='cpu')
|
|
|
|
if rename_keys is not None:
|
|
# rename keys in the loaded state_dict with old_key_prefix to with new_key_prefix.
|
|
renamed_ckpt = {}
|
|
for old_key, v in ckpt.items():
|
|
new_key = old_key
|
|
for old_key_prefix, new_key_prefix in rename_keys.items():
|
|
if old_key_prefix in old_key:
|
|
new_key = old_key.replace(old_key_prefix, new_key_prefix)
|
|
print(f"Renamed {old_key} to {new_key} in the loaded state_dict")
|
|
break
|
|
renamed_ckpt[new_key] = v
|
|
ckpt = renamed_ckpt
|
|
|
|
missing, unexpected = model.load_state_dict(ckpt, strict=strict)
|
|
print_load_warning(missing, unexpected)
|
|
elif path.endswith(".pt") or path.endswith(".pth"):
|
|
ckpt = torch.load(path, map_location=device_map)
|
|
missing, unexpected = model.load_state_dict(ckpt, strict=strict)
|
|
print_load_warning(missing, unexpected)
|
|
else:
|
|
assert os.path.isdir(path), f"Invalid checkpoint path: {path}"
|
|
load_from_sharded_state_dict(model, path, model_name=cai_model_name, strict=strict)
|
|
return model
|
|
|
|
|
|
def rm_checkpoints(
|
|
save_dir: str,
|
|
keep_n_latest: int = 0,
|
|
):
|
|
"""
|
|
Remove old checkpoints.
|
|
|
|
Args:
|
|
save_dir (str): The directory to save the checkpoints.
|
|
keep_n_latest (int): The number of latest checkpoints to keep.
|
|
"""
|
|
if keep_n_latest <= 0 or dist.get_rank() != 0:
|
|
return
|
|
files = glob(os.path.join(save_dir, "epoch*-global_step*"))
|
|
files = sorted(
|
|
files, key=lambda s: tuple(map(int, re.search(r"epoch(\d+)-global_step(\d+)", s).groups())), reverse=True
|
|
)
|
|
to_remove = files[keep_n_latest:]
|
|
for f in to_remove:
|
|
# shutil.rmtree(f)
|
|
for item in glob(os.path.join(f, "*")):
|
|
if os.path.isdir(item):
|
|
dir_name = os.path.basename(item)
|
|
if dir_name != "eval":
|
|
shutil.rmtree(item)
|
|
else:
|
|
os.remove(item)
|
|
|
|
|
|
def model_sharding(model: torch.nn.Module, device: torch.device = None):
|
|
"""
|
|
Sharding the model parameters across multiple GPUs.
|
|
|
|
Args:
|
|
model (torch.nn.Module): The model to shard.
|
|
device (torch.device): The device to shard the model to.
|
|
"""
|
|
global_rank = dist.get_rank()
|
|
world_size = dist.get_world_size()
|
|
for _, param in model.named_parameters():
|
|
if device is None:
|
|
device = param.device
|
|
padding_size = (world_size - param.numel() % world_size) % world_size
|
|
if padding_size > 0:
|
|
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
|
|
else:
|
|
padding_param = param.data.view(-1)
|
|
splited_params = padding_param.split(padding_param.numel() // world_size)
|
|
splited_params = splited_params[global_rank]
|
|
param.data = splited_params.to(device)
|
|
|
|
|
|
def model_gathering(model: torch.nn.Module, model_shape_dict: dict, pinned_state_dict: dict) -> None:
|
|
"""
|
|
Gather the model parameters from multiple GPUs.
|
|
|
|
Args:
|
|
model (torch.nn.Module): The model to gather.
|
|
model_shape_dict (dict): The shape of the model parameters.
|
|
device (torch.device): The device to gather the model to.
|
|
"""
|
|
global_rank = dist.get_rank()
|
|
global_size = dist.get_world_size()
|
|
params = set()
|
|
for name, param in model.named_parameters():
|
|
params.add(name)
|
|
all_params = [torch.empty_like(param.data) for _ in range(global_size)]
|
|
dist.all_gather(all_params, param.data, group=dist.group.WORLD)
|
|
if int(global_rank) == 0:
|
|
all_params = torch.cat(all_params)
|
|
gathered_param = remove_padding(all_params, model_shape_dict[name]).view(model_shape_dict[name])
|
|
pinned_state_dict[name].copy_(gathered_param)
|
|
if int(global_rank) == 0:
|
|
for k, v in model.state_dict(keep_vars=True).items():
|
|
if k not in params:
|
|
pinned_state_dict[k].copy_(v)
|
|
|
|
dist.barrier()
|
|
|
|
|
|
def remove_padding(tensor: torch.Tensor, original_shape: tuple) -> torch.Tensor:
|
|
"""
|
|
Remove padding from a tensor.
|
|
|
|
Args:
|
|
tensor (torch.Tensor): The tensor to remove padding from.
|
|
original_shape (tuple): The original shape of the tensor.
|
|
"""
|
|
return tensor[: functools.reduce(operator.mul, original_shape)]
|
|
|
|
|
|
def record_model_param_shape(model: torch.nn.Module) -> dict:
|
|
"""
|
|
Record the shape of the model parameters.
|
|
|
|
Args:
|
|
model (torch.nn.Module): The model to record the parameter shape of.
|
|
|
|
Returns:
|
|
dict: The shape of the model parameters.
|
|
"""
|
|
param_shape = {}
|
|
for name, param in model.named_parameters():
|
|
param_shape[name] = param.shape
|
|
return param_shape
|
|
|
|
|
|
def load_json(file_path: str) -> dict:
|
|
"""
|
|
Load a JSON file.
|
|
|
|
Args:
|
|
file_path (str): The path to the JSON file.
|
|
|
|
Returns:
|
|
dict: The loaded JSON file.
|
|
"""
|
|
with open(file_path, "r", encoding="utf-8") as f:
|
|
return json.load(f)
|
|
|
|
|
|
def save_json(data, file_path: str):
|
|
"""
|
|
Save a dictionary to a JSON file.
|
|
|
|
Args:
|
|
data: The dictionary to save.
|
|
file_path (str): The path to save the JSON file.
|
|
"""
|
|
with open(file_path, "w", encoding="utf-8") as f:
|
|
json.dump(data, f, indent=4)
|
|
|
|
|
|
def _prepare_ema_pinned_state_dict(model: nn.Module, ema_shape_dict: dict):
|
|
ema_pinned_state_dict = dict()
|
|
for name, p in model.named_parameters():
|
|
ema_pinned_state_dict[name] = torch.empty(ema_shape_dict[name], pin_memory=True, device="cpu", dtype=p.dtype)
|
|
sd = model.state_dict(keep_vars=True)
|
|
# handle buffers
|
|
for k, v in sd.items():
|
|
if k not in ema_pinned_state_dict:
|
|
ema_pinned_state_dict[k] = torch.empty(v.shape, pin_memory=True, device="cpu", dtype=v.dtype)
|
|
|
|
return ema_pinned_state_dict
|
|
|
|
|
|
def _search_valid_path(path: str) -> str:
|
|
if os.path.exists(f"{path}.safetensors"):
|
|
return f"{path}.safetensors"
|
|
elif os.path.exists(f"{path}.pt"):
|
|
return f"{path}.pt"
|
|
return path
|
|
|
|
|
|
def master_weights_gathering(model: torch.nn.Module, optimizer: LowLevelZeroOptimizer, pinned_state_dict: dict) -> None:
|
|
"""
|
|
Gather the model parameters from multiple GPUs.
|
|
|
|
Args:
|
|
model (torch.nn.Module): The model to gather.
|
|
model_shape_dict (dict): The shape of the model parameters.
|
|
device (torch.device): The device to gather the model to.
|
|
"""
|
|
w2m = optimizer.get_working_to_master_map()
|
|
for name, param in model.named_parameters():
|
|
master_p = w2m[id(param)]
|
|
zero_pg = optimizer.param_to_pg[param]
|
|
world_size = dist.get_world_size(zero_pg)
|
|
all_params = [torch.empty_like(master_p) for _ in range(world_size)]
|
|
dist.all_gather(all_params, master_p, group=zero_pg)
|
|
if dist.get_rank() == 0:
|
|
all_params = torch.cat(all_params)
|
|
gathered_param = remove_padding(all_params, param.shape).view(param.shape)
|
|
pinned_state_dict[name].copy_(gathered_param)
|
|
|
|
dist.barrier()
|
|
|
|
|
|
def load_master_weights(model: torch.nn.Module, optimizer: LowLevelZeroOptimizer, state_dict: dict) -> None:
|
|
pg = get_data_parallel_group(get_mixed_dp_pg=True)
|
|
world_size = dist.get_world_size(pg)
|
|
rank = dist.get_rank(pg)
|
|
w2m = optimizer.get_working_to_master_map()
|
|
for name, param in model.named_parameters():
|
|
master_p = w2m[id(param)]
|
|
state = state_dict[name].view(-1)
|
|
padding_size = len(master_p) * world_size - len(state)
|
|
state = torch.nn.functional.pad(state, [0, padding_size])
|
|
target_chunk = state.chunk(world_size)[rank].to(master_p.dtype)
|
|
master_p[: len(target_chunk)].copy_(target_chunk)
|
|
|
|
|
|
class CheckpointIO:
|
|
def __init__(self, n_write_entries: int = 32):
|
|
self.n_write_entries = n_write_entries
|
|
self.writer: Optional[AsyncFileWriter] = None
|
|
self.pinned_state_dict: Optional[Dict[str, torch.Tensor]] = None
|
|
self.master_pinned_state_dict: Optional[Dict[str, torch.Tensor]] = None
|
|
self.master_writer: Optional[AsyncFileWriter] = None
|
|
|
|
def _sync_io(self):
|
|
if self.writer is not None:
|
|
self.writer.synchronize()
|
|
self.writer = None
|
|
if self.master_writer is not None:
|
|
self.master_writer.synchronize()
|
|
self.master_writer = None
|
|
|
|
def __del__(self):
|
|
self._sync_io()
|
|
|
|
def _prepare_pinned_state_dict(self, ema: nn.Module, ema_shape_dict: dict):
|
|
if self.pinned_state_dict is None and dist.get_rank() == 0:
|
|
self.pinned_state_dict = _prepare_ema_pinned_state_dict(ema, ema_shape_dict)
|
|
|
|
def _prepare_master_pinned_state_dict(self, model: nn.Module, optimizer: LowLevelZeroOptimizer):
|
|
if self.master_pinned_state_dict is None and dist.get_rank() == 0:
|
|
sd = {}
|
|
w2m = optimizer.get_working_to_master_map()
|
|
for n, p in model.named_parameters():
|
|
master_p = w2m[id(p)]
|
|
sd[n] = torch.empty(p.shape, dtype=master_p.dtype, pin_memory=True, device="cpu")
|
|
self.master_pinned_state_dict = sd
|
|
|
|
def save(
|
|
self,
|
|
booster: Booster,
|
|
save_dir: str,
|
|
model: nn.Module = None,
|
|
ema: nn.Module = None,
|
|
optimizer: Optimizer = None,
|
|
lr_scheduler: _LRScheduler = None,
|
|
sampler=None,
|
|
epoch: int = None,
|
|
step: int = None,
|
|
global_step: int = None,
|
|
batch_size: int = None,
|
|
lora: bool = False,
|
|
actual_update_step: int = None,
|
|
ema_shape_dict: dict = None,
|
|
async_io: bool = True,
|
|
include_master_weights: bool = False,
|
|
) -> str:
|
|
"""
|
|
Save a checkpoint.
|
|
|
|
Args:
|
|
booster (Booster): The Booster object.
|
|
save_dir (str): The directory to save the checkpoint to.
|
|
model (nn.Module): The model to save the checkpoint from.
|
|
ema (nn.Module): The EMA model to save the checkpoint from.
|
|
optimizer (Optimizer): The optimizer to save the checkpoint from.
|
|
lr_scheduler (_LRScheduler): The learning rate scheduler to save the checkpoint from.
|
|
sampler: The sampler to save the checkpoint from.
|
|
epoch (int): The epoch of the checkpoint.
|
|
step (int): The step of the checkpoint.
|
|
global_step (int): The global step of the checkpoint.
|
|
batch_size (int): The batch size of the checkpoint.
|
|
lora (bool): Whether the model is trained with LoRA.
|
|
|
|
Returns:
|
|
str: The path to the saved checkpoint
|
|
"""
|
|
self._sync_io()
|
|
save_dir = os.path.join(save_dir, f"epoch{epoch}-global_step{actual_update_step}")
|
|
os.environ["TENSORNVME_DEBUG_LOG"] = os.path.join(save_dir, "async_file_io.log")
|
|
if model is not None:
|
|
if not lora:
|
|
os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)
|
|
booster.save_model(
|
|
model,
|
|
os.path.join(save_dir, "model"),
|
|
shard=True,
|
|
use_safetensors=True,
|
|
size_per_shard=4096,
|
|
use_async=async_io,
|
|
)
|
|
else:
|
|
os.makedirs(os.path.join(save_dir, "lora"), exist_ok=True)
|
|
booster.save_lora_as_pretrained(model, os.path.join(save_dir, "lora"))
|
|
if optimizer is not None:
|
|
booster.save_optimizer(
|
|
optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096, use_async=async_io
|
|
)
|
|
if include_master_weights:
|
|
self._prepare_master_pinned_state_dict(model, optimizer)
|
|
master_weights_gathering(model, optimizer, self.master_pinned_state_dict)
|
|
if lr_scheduler is not None:
|
|
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
|
|
if ema is not None:
|
|
self._prepare_pinned_state_dict(ema, ema_shape_dict)
|
|
model_gathering(ema, ema_shape_dict, self.pinned_state_dict)
|
|
if dist.get_rank() == 0:
|
|
running_states = {
|
|
"epoch": epoch,
|
|
"step": step,
|
|
"global_step": global_step,
|
|
"batch_size": batch_size,
|
|
"actual_update_step": actual_update_step,
|
|
}
|
|
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
|
|
|
if ema is not None:
|
|
if async_io:
|
|
self.writer = async_save(os.path.join(save_dir, "ema.safetensors"), self.pinned_state_dict)
|
|
else:
|
|
torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt"))
|
|
|
|
if sampler is not None:
|
|
# only for VariableVideoBatchSampler
|
|
torch.save(sampler.state_dict(step), os.path.join(save_dir, "sampler"))
|
|
|
|
if optimizer is not None and include_master_weights:
|
|
self.master_writer = async_save(
|
|
os.path.join(save_dir, "master.safetensors"), self.master_pinned_state_dict
|
|
)
|
|
|
|
dist.barrier()
|
|
return save_dir
|
|
|
|
def load(
|
|
self,
|
|
booster: Booster,
|
|
load_dir: str,
|
|
model: nn.Module = None,
|
|
ema: nn.Module = None,
|
|
optimizer: Optimizer = None,
|
|
lr_scheduler: _LRScheduler = None,
|
|
sampler=None,
|
|
strict: bool = False,
|
|
include_master_weights: bool = False,
|
|
) -> tuple[int, int]:
|
|
"""
|
|
Load a checkpoint.
|
|
|
|
Args:
|
|
booster (Booster): The Booster object.
|
|
load_dir (str): The directory to load the checkpoint from.
|
|
model (nn.Module): The model to load the checkpoint into.
|
|
ema (nn.Module): The EMA model to load the checkpoint into.
|
|
optimizer (Optimizer): The optimizer to load the checkpoint into.
|
|
lr_scheduler (_LRScheduler): The learning rate scheduler to load the checkpoint into.
|
|
sampler: The sampler to load the checkpoint into.
|
|
|
|
Returns:
|
|
tuple[int, int]: The epoch and step of the checkpoint.
|
|
"""
|
|
assert os.path.exists(load_dir), f"Checkpoint directory {load_dir} does not exist"
|
|
assert os.path.exists(os.path.join(load_dir, "running_states.json")), "running_states.json does not exist"
|
|
|
|
running_states = load_json(os.path.join(load_dir, "running_states.json"))
|
|
if model is not None:
|
|
booster.load_model(
|
|
model,
|
|
_search_valid_path(os.path.join(load_dir, "model")),
|
|
strict=strict,
|
|
low_cpu_mem_mode=False,
|
|
num_threads=32,
|
|
)
|
|
if ema is not None:
|
|
if os.path.exists(os.path.join(load_dir, "ema.safetensors")):
|
|
ema_state_dict = load_file(os.path.join(load_dir, "ema.safetensors"))
|
|
else:
|
|
ema_state_dict = torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu"))
|
|
# ema is not boosted, so we don't use booster.load_model
|
|
ema.load_state_dict(ema_state_dict, strict=strict, assign=True)
|
|
|
|
if optimizer is not None:
|
|
booster.load_optimizer(
|
|
optimizer, os.path.join(load_dir, "optimizer"), low_cpu_mem_mode=False, num_threads=32
|
|
)
|
|
if include_master_weights:
|
|
master_state_dict = load_file(os.path.join(load_dir, "master.safetensors"))
|
|
load_master_weights(model, optimizer, master_state_dict)
|
|
if lr_scheduler is not None:
|
|
booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
|
|
if sampler is not None:
|
|
sampler.load_state_dict(torch.load(os.path.join(load_dir, "sampler")))
|
|
|
|
dist.barrier()
|
|
|
|
return (running_states["epoch"], running_states["step"])
|