# mysora/opensora/utils/misc.py
# (439 lines, 14 KiB, Python — miscellaneous training utilities)
import os
import time
from collections import OrderedDict
from collections.abc import Sequence
from contextlib import nullcontext
import numpy as np
import psutil
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.cluster.dist_coordinator import DistCoordinator
from torch.utils.tensorboard import SummaryWriter
from opensora.acceleration.parallel_states import get_data_parallel_group
from .logger import log_message
def create_tensorboard_writer(exp_dir: str) -> SummaryWriter:
    """
    Create a TensorBoard writer that logs under ``<exp_dir>/tensorboard``.

    Args:
        exp_dir (str): The experiment directory under which logs are stored.

    Returns:
        SummaryWriter: A writer pointed at the ``tensorboard`` subdirectory.
    """
    tb_dir = exp_dir + "/tensorboard"
    # Make sure the target directory exists before the writer opens files in it.
    os.makedirs(tb_dir, exist_ok=True)
    return SummaryWriter(tb_dir)
# ======================================================
# Memory
# ======================================================
GIGABYTE = 1024**3  # bytes per GiB; used to report memory sizes as "GB" in logs
def log_cuda_memory(stage: str = None):
    """
    Log the currently allocated CUDA memory in GB.

    Args:
        stage (str): Optional label for the training stage; appended to the
            message when given.
    """
    message = "CUDA memory usage"
    if stage is not None:
        message = f"{message} at {stage}"
    # Lazy %-formatting: the value is passed as an argument to log_message.
    log_message(message + ": %.1f GB", torch.cuda.memory_allocated() / GIGABYTE)
def log_cuda_max_memory(stage: str = None):
    """
    Log the peak CUDA memory usage (allocated and reserved) in GB.

    Args:
        stage (str): Optional label for the training stage; omitted from the
            message when None.
    """
    torch.cuda.synchronize()
    max_memory_allocated = torch.cuda.max_memory_allocated()
    max_memory_reserved = torch.cuda.max_memory_reserved()
    # BUG FIX: the old code did `"..." + stage`, which raised TypeError for the
    # default stage=None; build the suffix the same way as log_cuda_memory.
    # Also removed the duplicated "max memory max memory" wording.
    suffix = f" at {stage}" if stage is not None else ""
    log_message("CUDA max memory allocated" + suffix + ": %.1f GB", max_memory_allocated / GIGABYTE)
    log_message("CUDA max memory reserved" + suffix + ": %.1f GB", max_memory_reserved / GIGABYTE)
# ======================================================
# Number of parameters
# ======================================================
def get_model_numel(model: torch.nn.Module) -> tuple[int, int]:
    """
    Count a model's parameters.

    Args:
        model (torch.nn.Module): The model to inspect.

    Returns:
        tuple[int, int]: (total parameter count, trainable parameter count).
    """
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable
def log_model_params(model: nn.Module):
    """
    Log the total and trainable parameter counts of a model.

    Args:
        model (torch.nn.Module): The model whose parameters are reported.
    """
    total, trainable = get_model_numel(model)
    cls_name = model.__class__.__name__
    log_message(f"[{cls_name}] Number of parameters: {format_numel_str(total)}")
    log_message(f"[{cls_name}] Number of trainable parameters: {format_numel_str(trainable)}")
# ======================================================
# String
# ======================================================
def format_numel_str(numel: int) -> str:
    """
    Render an element count as a human-readable string using binary units.

    Args:
        numel (int): The number of elements.

    Returns:
        str: e.g. "1.50 K", "2.00 M", "3.25 B", or the plain number below 1024.
    """
    # Largest unit first so the first matching threshold wins.
    for threshold, suffix in ((1024**3, "B"), (1024**2, "M"), (1024, "K")):
        if numel >= threshold:
            return f"{numel / threshold:.2f} {suffix}"
    return f"{numel}"
def format_duration(seconds: int) -> str:
    """
    Format a duration in seconds as a compact "Xd Xh Xm Xs" string.

    Args:
        seconds (int): Duration in whole seconds.

    Returns:
        str: Non-zero components joined by spaces; "0s" when the duration
            has no non-zero component.
    """
    days, rem = divmod(seconds, 86400)
    hours, rem = divmod(rem, 3600)
    minutes, secs = divmod(rem, 60)
    parts = [
        f"{value}{unit}"
        for value, unit in ((days, "d"), (hours, "h"), (minutes, "m"))
        if value > 0
    ]
    # Seconds are always shown when present, or when everything else is zero.
    if secs > 0 or not parts:
        parts.append(f"{secs}s")
    return " ".join(parts)
# ======================================================
# PyTorch
# ======================================================
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """
    All-reduce a tensor across the data-parallel group and divide by the
    group's world size in place, yielding the group mean.

    Args:
        tensor (torch.Tensor): Tensor to reduce; modified in place.

    Returns:
        torch.Tensor: The same tensor, now holding the mean across the group.
    """
    dist.all_reduce(tensor=tensor, group=get_data_parallel_group())
    tensor.div_(dist.get_world_size(group=get_data_parallel_group()))
    return tensor
def all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor:
    """
    All-reduce (sum) a tensor across the data-parallel group in place.

    Args:
        tensor (torch.Tensor): Tensor to reduce; modified in place.

    Returns:
        torch.Tensor: The same tensor, now holding the sum across the group.
    """
    dist.all_reduce(tensor=tensor, group=get_data_parallel_group())
    return tensor
def to_tensor(data: torch.Tensor | np.ndarray | Sequence | int | float) -> torch.Tensor:
"""Convert objects of various python types to :obj:`torch.Tensor`.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`, :class:`int` and :class:`float`.
Args:
data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
be converted.
Returns:
torch.Tensor: The converted tensor.
"""
if isinstance(data, torch.Tensor):
return data
elif isinstance(data, np.ndarray):
return torch.from_numpy(data)
elif isinstance(data, Sequence) and not isinstance(data, str):
return torch.tensor(data)
elif isinstance(data, int):
return torch.LongTensor([data])
elif isinstance(data, float):
return torch.FloatTensor([data])
else:
raise TypeError(f"type {type(data)} cannot be converted to tensor.")
def to_ndarray(data: torch.Tensor | np.ndarray | Sequence | int | float) -> np.ndarray:
"""Convert objects of various python types to :obj:`numpy.ndarray`.
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`, :class:`int` and :class:`float`.
Args:
data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
be converted.
Returns:
numpy.ndarray: The converted ndarray.
"""
if isinstance(data, torch.Tensor):
return data.numpy()
elif isinstance(data, np.ndarray):
return data
elif isinstance(data, Sequence):
return np.array(data)
elif isinstance(data, int):
return np.ndarray([data], dtype=int)
elif isinstance(data, float):
return np.array([data], dtype=float)
else:
raise TypeError(f"type {type(data)} cannot be converted to ndarray.")
def to_torch_dtype(dtype: str | torch.dtype) -> torch.dtype:
"""
Convert a string or a torch.dtype to a torch.dtype.
Args:
dtype (str | torch.dtype): The input dtype.
Returns:
torch.dtype: The converted dtype.
"""
if isinstance(dtype, torch.dtype):
return dtype
elif isinstance(dtype, str):
dtype_mapping = {
"float64": torch.float64,
"float32": torch.float32,
"float16": torch.float16,
"fp32": torch.float32,
"fp16": torch.float16,
"half": torch.float16,
"bf16": torch.bfloat16,
}
if dtype not in dtype_mapping:
raise ValueError(f"Unsupported dtype {dtype}")
dtype = dtype_mapping[dtype]
return dtype
else:
raise ValueError(f"Unsupported dtype {dtype}")
# ======================================================
# Profile
# ======================================================
class Timer:
    """Context manager that measures the wall-clock time of a code region.

    Optionally synchronizes CUDA and blocks on a distributed barrier or a
    ColossalAI coordinator before taking each timestamp, so pending GPU work
    across ranks is included in the measurement.
    """

    def __init__(self, name, log=False, barrier=False, coordinator: DistCoordinator | None = None):
        # name: label used in the optional log line.
        # log: when True, print the elapsed time on exit.
        # barrier: when True, call dist.barrier() before both timestamps.
        # coordinator: optional DistCoordinator blocked on before the end timestamp.
        self.name = name
        self.start_time = None
        self.end_time = None
        self.log = log
        self.barrier = barrier
        self.coordinator = coordinator

    @property
    def elapsed_time(self) -> float:
        # Only meaningful after __exit__ has run (end_time is set there).
        return self.end_time - self.start_time

    def __enter__(self):
        # Drain outstanding CUDA work so the start timestamp is accurate.
        torch.cuda.synchronize()
        if self.barrier:
            dist.barrier()
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.coordinator is not None:
            self.coordinator.block_all()
        # Drain CUDA work again so the region's GPU time is included.
        torch.cuda.synchronize()
        if self.barrier:
            dist.barrier()
        self.end_time = time.time()
        if self.log:
            print(f"Elapsed time for {self.name}: {self.elapsed_time:.2f} s")
class Timers:
    """Lazy registry of named :class:`Timer` objects.

    When ``record_time`` is False, lookups return a ``nullcontext`` instead,
    so call sites can write ``with timers["name"]:`` unconditionally.
    """

    def __init__(self, record_time: bool, record_barrier: bool = False, coordinator: DistCoordinator | None = None):
        self.timers = OrderedDict()
        self.record_time = record_time
        self.record_barrier = record_barrier
        self.coordinator = coordinator

    def __getitem__(self, name: str) -> Timer:
        """Return the timer for ``name``, creating it on first access."""
        if name not in self.timers:
            if self.record_time:
                entry = Timer(name, barrier=self.record_barrier, coordinator=self.coordinator)
            else:
                entry = nullcontext()
            self.timers[name] = entry
        return self.timers[name]

    def to_dict(self):
        """Map ``time_debug/<name>`` to each timer's elapsed seconds."""
        result = {}
        for name, timer in self.timers.items():
            result[f"time_debug/{name}"] = timer.elapsed_time
        return result

    def to_str(self, epoch: int, step: int) -> str:
        """Build a one-line summary of all timers for the given epoch/step."""
        pieces = [f"Rank {dist.get_rank()} | Epoch {epoch} | Step {step} | "]
        pieces.extend(f"{name}: {timer.elapsed_time:.2f} s | " for name, timer in self.timers.items())
        return "".join(pieces)
def is_pipeline_enabled(plugin_type: str, plugin_config: dict) -> bool:
    """Return True when the hybrid plugin is configured with pipeline parallelism (pp_size > 1)."""
    if plugin_type != "hybrid":
        return False
    return plugin_config.get("pp_size", 1) > 1
def is_log_process(plugin_type: str, plugin_config: dict) -> bool:
    """Whether this rank should emit logs.

    With pipeline parallelism the last rank holds the final pipeline stage
    (and the loss), so it logs; otherwise rank 0 logs.
    """
    if is_pipeline_enabled(plugin_type, plugin_config):
        log_rank = dist.get_world_size() - 1
    else:
        log_rank = 0
    return dist.get_rank() == log_rank
class NsysRange:
    """Context manager wrapping a code region in an NVTX range for Nsight Systems."""

    def __init__(self, range_name: str):
        # Label shown for this range in the Nsight Systems timeline.
        self.range_name = range_name

    def __enter__(self):
        # Open the named NVTX range; matched by the pop in __exit__.
        torch.cuda.nvtx.range_push(self.range_name)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.cuda.nvtx.range_pop()
class NsysProfiler:
    """Drive an NVIDIA Nsight Systems capture window via cudaProfilerStart/Stop.

    Launch the program under ``nsys profile -w true -t cuda,nvtx,osrt,cudnn,cublas
    --capture-range=cudaProfilerApi --capture-range-end=stop-shutdown -o <report>``
    so that capture begins after ``warmup_steps`` calls to :meth:`step` and ends
    ``num_steps`` steps later.

    Useful extra flags:
        * ``--stats=true`` to generate summary statistics.
        * ``-s none --cudabacktrace=none`` to disable stack traces.
        * ``-s process-tree --cudabacktrace=all`` to enable stack traces.
        * ``--record_time True --record_barrier True`` (script args) to enable timers.
    """

    def __init__(self, warmup_steps: int = 0, num_steps: int = 1, enabled: bool = True):
        # warmup_steps: steps to skip before starting the capture.
        # num_steps: number of steps to capture.
        # enabled: when False, step() and range() are no-ops.
        self.warmup_steps = warmup_steps
        self.num_steps = num_steps
        self.enabled = enabled
        self.current_step = 0

    def step(self):
        """Advance one step, toggling the CUDA profiler at the window edges."""
        if not self.enabled:
            return
        self.current_step += 1
        if self.current_step == self.warmup_steps:
            torch.cuda.cudart().cudaProfilerStart()
        elif self.current_step >= self.warmup_steps + self.num_steps:
            torch.cuda.cudart().cudaProfilerStop()

    def range(self, range_name: str) -> NsysRange:
        """Return an NVTX range context manager (a no-op when disabled)."""
        if not self.enabled:
            return nullcontext()
        return NsysRange(range_name)
class ProfilerContext:
    """Step-driven wrapper around :class:`torch.profiler.profile`.

    The profiler is entered lazily on the first call to :meth:`step`, and
    traces are written for TensorBoard under ``save_path``. Once
    ``(wait + warmup + active) * repeat`` steps have been recorded the
    profiler is closed and the process exits.
    """

    def __init__(
        self,
        save_path: str = "./log",
        record_shapes: bool = False,
        with_stack: bool = True,
        wait: int = 1,
        warmup: int = 1,
        active: int = 1,
        repeat: int = 1,
        enable: bool = True,
        **kwargs,
    ):
        self.enable = enable
        self.step_cnt = 0
        self.total_steps = (wait + warmup + active) * repeat
        self.prof = None
        if not enable:
            return
        self.prof = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat),
            record_shapes=record_shapes,
            with_stack=with_stack,
            on_trace_ready=torch.profiler.tensorboard_trace_handler(save_path),
            **kwargs,
        )

    def step(self):
        """Advance the profiler; enters it on the first call, exits the process when done."""
        if not self.enable:
            return
        if self.step_cnt == 0:
            # Enter the profiler context lazily on the first profiled step.
            self.prof.__enter__()
        self.prof.step()
        self.step_cnt += 1
        if self.is_profile_end():
            self.prof.__exit__(None, None, None)
            # NOTE: deliberately terminates the whole process once profiling completes.
            exit(0)

    def is_profile_end(self):
        """Return True once the scheduled number of profiler steps has elapsed."""
        return self.step_cnt >= self.total_steps
def get_process_mem():
    """Return this process's resident set size (RSS) in GiB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / 1024**3
def get_total_mem():
    """Return the system-wide used memory in GiB."""
    used_bytes = psutil.virtual_memory().used
    return used_bytes / 1024**3
def print_mem(prefix: str = ""):
    """Print per-process RSS and system-wide used memory, tagged with the dist rank.

    Args:
        prefix (str): Optional tag inserted into the printed line.
    """
    proc_gb = get_process_mem()
    total_gb = get_total_mem()
    rank = dist.get_rank()
    # flush=True so interleaved multi-rank output appears promptly.
    print(
        f"[{rank}] {prefix} process memory: {proc_gb:.2f} GB, total memory: {total_gb:.2f} GB",
        flush=True,
    )