79 lines
2.9 KiB
Python
79 lines
2.9 KiB
Python
import json
|
|
import os
|
|
import pathlib
|
|
import shutil
|
|
from argparse import ArgumentParser
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = ArgumentParser()
|
|
parser.add_argument("model_folder", type=str)
|
|
args = parser.parse_args()
|
|
|
|
folder = args.model_folder.rstrip(os.sep)
|
|
path = pathlib.Path(folder)
|
|
parent = path.parent
|
|
name = path.name
|
|
folder_extrapolation = os.path.join(parent, f"extrapolation-{name}")
|
|
folder_yarn_4 = os.path.join(parent, f"yarn-4-{name}")
|
|
folder_yarn_8 = os.path.join(parent, f"yarn-8-{name}")
|
|
|
|
if os.path.exists(folder_extrapolation):
|
|
shutil.rmtree(folder_extrapolation)
|
|
if os.path.exists(folder_yarn_4):
|
|
shutil.rmtree(folder_yarn_4)
|
|
if os.path.exists(folder_yarn_8):
|
|
shutil.rmtree(folder_yarn_8)
|
|
|
|
os.makedirs(folder_extrapolation)
|
|
os.makedirs(folder_yarn_4)
|
|
os.makedirs(folder_yarn_8)
|
|
|
|
for name in os.listdir(folder):
|
|
if name == "config.json":
|
|
with open(os.path.join(folder, name), "r", encoding="utf-8") as f:
|
|
config = json.load(f)
|
|
|
|
extrapolation_config = config.copy()
|
|
extrapolation_config["max_position_embeddings"] = extrapolation_config["max_position_embeddings"] * 8
|
|
if "sliding_window" in extrapolation_config and extrapolation_config["sliding_window"] is not None:
|
|
extrapolation_config["sliding_window"] = extrapolation_config["max_position_embeddings"]
|
|
with open(os.path.join(folder_extrapolation, name), "w", encoding="utf-8") as f:
|
|
json.dump(extrapolation_config, f)
|
|
|
|
yarn_4_config = config.copy()
|
|
yarn_4_config["rope_scaling"] = {
|
|
"type": "yarn",
|
|
"factor": 4,
|
|
"original_max_position_embeddings": yarn_4_config["max_position_embeddings"]
|
|
}
|
|
with open(os.path.join(folder_yarn_4, name), "w", encoding="utf-8") as f:
|
|
json.dump(yarn_4_config, f)
|
|
|
|
yarn_8_config = config.copy()
|
|
yarn_8_config["rope_scaling"] = {
|
|
"type": "yarn",
|
|
"factor": 8,
|
|
"original_max_position_embeddings": yarn_8_config["max_position_embeddings"]
|
|
}
|
|
with open(os.path.join(folder_yarn_8, name), "w", encoding="utf-8") as f:
|
|
json.dump(yarn_8_config, f)
|
|
|
|
else:
|
|
src = os.path.join(folder, name)
|
|
|
|
dest = os.path.join(folder_extrapolation, name)
|
|
if os.path.exists(dest):
|
|
os.remove(dest)
|
|
os.symlink(src, dest)
|
|
|
|
dest = os.path.join(folder_yarn_4, name)
|
|
if os.path.exists(dest):
|
|
os.remove(dest)
|
|
os.symlink(src, dest)
|
|
|
|
dest = os.path.join(folder_yarn_8, name)
|
|
if os.path.exists(dest):
|
|
os.remove(dest)
|
|
os.symlink(src, dest)
|