91 lines
2.3 KiB
Python
91 lines
2.3 KiB
Python
import logging
|
|
import os
|
|
|
|
import torch.distributed as dist
|
|
|
|
|
|
def is_distributed() -> bool:
    """
    Tell whether this process was started in a distributed setting.

    Detection relies solely on the ``WORLD_SIZE`` environment variable being
    present; its value is not inspected.

    Returns:
        bool: True if ``WORLD_SIZE`` is set in the environment, False otherwise.
    """
    # Environment values are always strings, so a membership test is
    # equivalent to checking that the looked-up value is not None.
    return "WORLD_SIZE" in os.environ
|
|
|
|
|
|
def is_main_process() -> bool:
    """
    Check if the current process is the main process.

    A run that is not distributed (per ``is_distributed``) always counts as
    the main process. In a distributed run the rank comes from the initialized
    default process group; when no process group has been initialized yet, the
    ``RANK`` environment variable is consulted instead, and a missing or empty
    ``RANK`` is treated as rank 0.

    Returns:
        bool: True if the current process is the main process, False otherwise.

    Raises:
        ValueError: If the ``RANK`` fallback is used and its value is not an
            integer.
    """
    if not is_distributed():
        return True

    # dist.get_rank() raises when the default process group has not been
    # initialized (e.g. this helper is called before init_process_group()),
    # so only ask torch for the rank once a group actually exists.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0

    # No process group yet: fall back to the rank exported in the environment.
    return int(os.environ.get("RANK") or 0) == 0
|
|
|
|
|
|
def get_world_size() -> int:
    """
    Get the number of processes in the distributed setting.

    Returns 1 when the run is not distributed (per ``is_distributed``). In a
    distributed run the size comes from the initialized default process group;
    when no process group has been initialized yet, the ``WORLD_SIZE``
    environment variable is parsed instead (an empty value is treated as 1).

    Returns:
        int: The number of processes.

    Raises:
        ValueError: If the ``WORLD_SIZE`` fallback is used and its value is
            not an integer.
    """
    if not is_distributed():
        return 1

    # dist.get_world_size() raises when the default process group has not
    # been initialized, so only query torch once a group actually exists.
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()

    # No process group yet: fall back to the size exported in the environment.
    return int(os.environ.get("WORLD_SIZE") or 1)
|
|
|
|
|
|
def create_logger(logging_dir: str = None) -> logging.Logger:
    """
    Create a logger that writes to a log file and stdout. Only the main process
    configures logging; other processes get a logger with a ``NullHandler``.

    On the main process the root logger is configured through
    ``logging.basicConfig`` at INFO level. When ``logging_dir`` is given, a
    stream handler and a file handler writing to ``<logging_dir>/log.txt`` are
    installed, and the directory is created first if it does not exist.

    Args:
        logging_dir (str): The directory to save the log file. If None, no
            file handler is installed.

    Returns:
        logging.Logger: The logger named after this module.
    """
    logger = logging.getLogger(__name__)

    if not is_main_process():
        # Attach at most one NullHandler, however many times this is called,
        # so repeated calls do not keep growing the handler list.
        if not any(isinstance(h, logging.NullHandler) for h in logger.handlers):
            logger.addHandler(logging.NullHandler())
        return logger

    config = {}
    if logging_dir is not None:
        # FileHandler does not create missing parent directories; make sure
        # the target exists. An empty string means the current directory.
        if logging_dir:
            os.makedirs(logging_dir, exist_ok=True)
        config["handlers"] = [
            logging.StreamHandler(),
            logging.FileHandler(os.path.join(logging_dir, "log.txt")),
        ]

    logging.basicConfig(
        level=logging.INFO,
        # The timestamp is wrapped in ANSI escape codes (blue, then reset).
        format="[\033[34m%(asctime)s\033[0m] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        **config,
    )

    if logging_dir is not None:
        logger.info("Experiment directory created at %s", logging_dir)
    return logger
|
|
|
|
|
|
def log_message(*args, level: str = "info"):
    """
    Emit a message through this module's logger, or print it.

    Args:
        *args: Forwarded unchanged to the selected logger method (or to
            ``print`` when ``level`` is ``"print"``).
        level (str): One of ``"info"``, ``"warning"``, ``"error"`` or
            ``"print"``.

    Raises:
        ValueError: If ``level`` is not one of the supported values.
    """
    logger = logging.getLogger(__name__)

    # "print" bypasses the logging machinery entirely.
    if level == "print":
        print(*args)
        return

    if level not in ("info", "warning", "error"):
        raise ValueError(f"Invalid logging level: {level}")

    # The remaining level names match the Logger method names one-to-one.
    emit = getattr(logger, level)
    emit(*args)
|