mysora/opensora/utils/config.py

214 lines
6.2 KiB
Python

import argparse
import ast
import json
import os
from datetime import datetime
import torch
from mmengine.config import Config
from .logger import is_distributed, is_main_process
def parse_args() -> tuple[str, list[str]]:
    """
    Parse the command line arguments.

    Only the positional ``config`` argument is declared on the parser. Every
    other token on the command line is left unparsed and handed back to the
    caller as raw strings.

    Returns:
        tuple[str, list[str]]: The path to the configuration file and the
            list of remaining, unparsed command line tokens
            (for example ``["--lr", "0.1"]``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("config", type=str, help="model config file path")
    # parse_known_args returns (namespace, leftover tokens); the leftovers are
    # a plain list of strings, not an argparse.Namespace.
    args, unknown_args = parser.parse_known_args()
    return args.config, unknown_args
def read_config(config_path: str) -> Config:
    """
    Load a configuration file.

    Args:
        config_path (str): The path to the configuration file.

    Returns:
        Config: The object returned by ``Config.fromfile`` for that path.
    """
    return Config.fromfile(config_path)
def parse_configs() -> Config:
    """
    Build the configuration object from the command line.

    Reads the config file named on the command line, applies the remaining
    command line tokens on top of it, and records the config file path on the
    result as ``config_path``.

    Returns:
        Config: The configuration object.
    """
    config_path, cli_overrides = parse_args()
    cfg = merge_args(read_config(config_path), cli_overrides)
    cfg.config_path = config_path

    # hard-coded for spatial compression: when the config sets
    # ``ae_spatial_compression``, mirror it into the process environment.
    if cfg.get("ae_spatial_compression", None) is not None:
        os.environ["AE_SPATIAL_COMPRESSION"] = str(cfg.ae_spatial_compression)
    return cfg
def merge_args(cfg: Config, args: list[str]) -> Config:
    """
    Merge command line overrides into the configuration object (in place).

    ``args`` is read as consecutive ``--key value`` pairs. Dashes in the key
    are turned into underscores and dots address nested entries, so
    ``--model.from-pretrained x`` sets ``cfg["model"]["from_pretrained"]``.
    A trailing token without a partner is ignored because the pairs are
    formed with ``zip``.

    Value conversion rules:
        * the literal ``none`` (any case) becomes ``None``;
        * if the leaf key already exists and holds a plain scalar (e.g. int,
          float, str), the new value is cast to that same type;
        * if the leaf key holds a bool, ``None`` or a list/tuple/dict, or does
          not exist yet, the value is converted with ``auto_convert``.

    Args:
        cfg (Config): The configuration object.
        args (list[str]): The leftover command line tokens.

    Returns:
        Config: The same configuration object, after the overrides.

    Raises:
        AssertionError: If a key does not start with ``--`` or an
            intermediate key of a dotted path is missing.
    """
    for k, v in zip(args[::2], args[1::2]):
        assert k.startswith("--"), f"Invalid argument: {k}"
        k = k[2:].replace("-", "_")
        k_split = k.split(".")
        leaf = k_split[-1]

        # Walk down to the container that owns the leaf. Each level must be
        # checked against the container reached so far, not the config root,
        # otherwise paths deeper than one level are rejected.
        target = cfg
        for key in k_split[:-1]:
            assert key in target, f"Key {key} not found in config"
            target = target[key]

        if v.lower() == "none":
            v = None
        elif leaf in target:
            # Look up the leaf name (not the full dotted key) so nested
            # overrides also keep the type of the value they replace.
            current = target[leaf]
            if current is None or isinstance(current, (bool, list, tuple, dict)):
                # type(None)(v) raises, bool("false") is True, and list("ab")
                # splits into characters -- parse the text instead.
                v = auto_convert(v)
            else:
                v = type(current)(v)
        else:
            v = auto_convert(v)
        target[leaf] = v
    return cfg
def auto_convert(
    value: str,
) -> int | float | bool | list | dict | tuple | set | str | None:
    """
    Automatically convert a string to the appropriate Python data type,
    including int, float, bool, list, dict, etc.

    Conversions are tried in this order: empty string (returned unchanged),
    ``none`` / ``true`` / ``false`` (case-insensitive), ``int``, ``float``,
    then any Python literal accepted by ``ast.literal_eval``. If nothing
    matches, the original string is returned.

    Args:
        value (str): The string to convert.

    Returns:
        The converted value, or ``value`` itself when no conversion applies.
    """
    # Handle empty string
    if value == "":
        return value

    # Handle None and boolean values
    lower_value = value.lower()
    if lower_value == "none":
        return None
    if lower_value == "true":
        return True
    if lower_value == "false":
        return False

    # Try int before float so that "3" stays an integer
    for numeric_type in (int, float):
        try:
            return numeric_type(value)
        except ValueError:
            pass

    # Try to convert the string to a list, dict, tuple, etc.
    # literal_eval signals malformed input with more than ValueError and
    # SyntaxError: e.g. "{[1]: 2}" raises TypeError (unhashable key). Treat
    # all of its documented failure modes as "not a literal".
    try:
        return ast.literal_eval(value)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        pass

    # If all attempts fail, return the original string
    return value
def sync_string(value: str):
    """
    Synchronize a string across all processes.

    Outside of a distributed run the input is returned unchanged. Otherwise
    the UTF-8 bytes of ``value`` are placed in a fixed 256-byte, zero-filled
    CUDA buffer, the buffer is broadcast from rank 0, and the received bytes
    are decoded with trailing NUL padding removed.

    NOTE: payloads longer than the 256-byte buffer are not handled here.
    """
    if not is_distributed():
        return value

    capacity = 256
    payload = list(value.encode("utf-8"))

    buffer = torch.zeros(capacity, dtype=torch.uint8).cuda()
    buffer[: len(payload)] = torch.tensor(payload, dtype=torch.uint8)
    torch.distributed.broadcast(buffer, 0)

    # The unused tail of the buffer is zero bytes; strip it after decoding.
    received = buffer.cpu().numpy().tobytes()
    return received.decode("utf-8").rstrip("\x00")
def create_experiment_workspace(
    output_dir: str, model_name: str = None, config: dict = None, exp_name: str = None
) -> tuple[str, str]:
    """
    Create a folder for experiment tracking.

    Args:
        output_dir: The path to the output directory.
        model_name: The name of the model; only used to build the default
            experiment name.
        config: The configuration that is dumped as JSON into ``config.txt``
            inside the experiment folder.
        exp_name: The given name of the experiment. If None, the name is a
            ``%y%m%d_%H%M%S`` timestamp (passed through ``sync_string``),
            followed by ``-<model_name>`` with ``/`` replaced by ``-`` when a
            model name is given.

    Returns:
        tuple[str, str]: The experiment name and the experiment directory.
    """
    if exp_name is None:
        # Default name: synchronized timestamp plus an optional model suffix
        timestamp = sync_string(datetime.now().strftime("%y%m%d_%H%M%S"))
        suffix = "" if model_name is None else "-" + model_name.replace("/", "-")
        exp_name = f"{timestamp}{suffix}"

    exp_dir = f"{output_dir}/{exp_name}"

    # Only the main process touches the filesystem
    if not is_main_process():
        return exp_name, exp_dir

    os.makedirs(exp_dir, exist_ok=True)
    # Save the config
    with open(f"{exp_dir}/config.txt", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)
    return exp_name, exp_dir
def config_to_name(cfg: Config) -> str:
    """
    Derive a flat name from the config's ``_filename``.

    Every ``configs/`` and ``.py`` substring is removed, then each remaining
    ``/`` is replaced by ``_``.

    Args:
        cfg (Config): The configuration object; its ``_filename`` is read.

    Returns:
        str: The derived name.
    """
    name = cfg._filename
    # Applied in order: the path separators are rewritten last so that the
    # "configs/" substring can still be matched first.
    for old, new in (("configs/", ""), (".py", ""), ("/", "_")):
        name = name.replace(old, new)
    return name
def parse_alias(cfg: Config) -> Config:
    """
    Copy top-level shortcut keys into their nested locations.

    For each of ``resolution``, ``guidance``, ``guidance_img``, ``num_steps``,
    ``num_frames`` and ``aspect_ratio`` that is set (not None) at the top
    level, the value is written to ``cfg.sampling_option`` under the same
    name, cast to float or int where listed below. A top-level ``ckpt_path``
    is written to ``cfg.model.from_pretrained``.

    Args:
        cfg (Config): The configuration object; modified in place.

    Returns:
        Config: The same configuration object.
    """
    # (alias name, cast applied before storing; None means store as-is)
    sampling_aliases = (
        ("resolution", None),
        ("guidance", float),
        ("guidance_img", float),
        ("num_steps", int),
        ("num_frames", int),
        ("aspect_ratio", None),
    )
    for alias, cast in sampling_aliases:
        if cfg.get(alias, None) is None:
            continue
        raw = getattr(cfg, alias)
        setattr(cfg.sampling_option, alias, raw if cast is None else cast(raw))

    if cfg.get("ckpt_path", None) is not None:
        cfg.model.from_pretrained = cfg.ckpt_path
    return cfg