# embed-bge-m3/FlagEmbedding/research/Long_LLM/activation_beacon/main/vllm_symlink.py
#
# 79 lines
# 2.9 KiB
# Python

import json
import os
import pathlib
import shutil
from argparse import ArgumentParser
def _write_json(path, payload):
    """Serialize *payload* as JSON to *path* (UTF-8, overwriting any file)."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f)


def build_vllm_variants(model_folder):
    """Create three sibling copies of a model folder with altered configs.

    Next to ``model_folder`` (named ``<name>``) the following directories are
    (re)created from scratch:

    * ``extrapolation-<name>``: ``max_position_embeddings`` multiplied by 8;
      a non-null ``sliding_window`` is set to that same enlarged value.
    * ``yarn-4-<name>`` / ``yarn-8-<name>``: ``rope_scaling`` set to type
      ``"yarn"`` with factor 4 / 8 and ``original_max_position_embeddings``
      taken from the source config.

    Only ``config.json`` is rewritten; every other entry of the source folder
    is exposed in each variant through a symlink, so nothing else is copied.
    (Presumably the variants are meant to be loaded by vLLM, per the script
    name -- this function itself only touches the filesystem.)

    Args:
        model_folder: Path to the source model directory. Trailing path
            separators are ignored. May be relative or absolute.

    Returns:
        A tuple ``(extrapolation_dir, yarn_4_dir, yarn_8_dir)`` with the
        paths of the created directories.

    Raises:
        KeyError: If ``config.json`` has no ``max_position_embeddings`` key.
        OSError: If a directory cannot be removed/created or a symlink
            cannot be made.
    """
    folder = model_folder.rstrip(os.sep)
    source_path = pathlib.Path(folder)
    parent = source_path.parent
    model_name = source_path.name

    extrapolation_dir = os.path.join(parent, f"extrapolation-{model_name}")
    yarn_dirs = {
        4: os.path.join(parent, f"yarn-4-{model_name}"),
        8: os.path.join(parent, f"yarn-8-{model_name}"),
    }
    variant_dirs = (extrapolation_dir, yarn_dirs[4], yarn_dirs[8])

    # Start every variant from an empty directory so stale links/configs from
    # a previous run cannot survive.
    for variant_dir in variant_dirs:
        if os.path.exists(variant_dir):
            shutil.rmtree(variant_dir)
        os.makedirs(variant_dir)

    # A symlink's target is resolved relative to the directory that holds the
    # link, not the current working directory. Anchoring the source to the
    # cwd keeps links valid when ``model_folder`` is relative; os.path.join
    # returns ``folder`` unchanged when it is already absolute.
    link_source_root = os.path.join(os.getcwd(), folder)

    for entry in os.listdir(folder):
        if entry != "config.json":
            link_target = os.path.join(link_source_root, entry)
            # Variant directories were just recreated, so no destination can
            # already exist here.
            for variant_dir in variant_dirs:
                os.symlink(link_target, os.path.join(variant_dir, entry))
            continue

        with open(os.path.join(folder, entry), "r", encoding="utf-8") as f:
            config = json.load(f)
        original_length = config["max_position_embeddings"]

        extrapolation_config = config.copy()
        extrapolation_config["max_position_embeddings"] = original_length * 8
        if extrapolation_config.get("sliding_window") is not None:
            extrapolation_config["sliding_window"] = extrapolation_config["max_position_embeddings"]
        _write_json(os.path.join(extrapolation_dir, entry), extrapolation_config)

        for factor, yarn_dir in yarn_dirs.items():
            yarn_config = config.copy()
            yarn_config["rope_scaling"] = {
                "type": "yarn",
                "factor": factor,
                "original_max_position_embeddings": original_length,
            }
            _write_json(os.path.join(yarn_dir, entry), yarn_config)

    return variant_dirs


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("model_folder", type=str)
    args = parser.parse_args()
    build_vllm_variants(args.model_folder)