# embed-bge-m3/FlagEmbedding/research/Long_LLM/activation_beacon/main/eval_passkey.py
#
# 268 lines
# 10 KiB
# Python

import re
import os
import json
import torch
import datasets
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from typing import List, Optional
from tqdm import tqdm
from fuzzywuzzy import fuzz
from accelerate import Accelerator
from transformers import HfArgumentParser
from transformers.utils import logging
from dataclasses import dataclass, field, asdict
from src import ModelArgs, DefaultDataCollator, FileLogger, get_model_and_tokenizer, makedirs, apply_chat_template
# Module-level logger obtained through transformers' logging utilities.
logger = logging.get_logger(name=__name__)
@dataclass
class Args(ModelArgs):
    """Command-line arguments for the passkey-retrieval evaluation.

    Extends ``ModelArgs`` (from ``src``) with the grid of context lengths and
    passkey depths to evaluate, the passkey size, the random seed and the
    output locations.
    """

    output_dir: str = field(
        default="data/results/passkey/",
        metadata={'help': 'The base directory for saving results and logs.'}
    )
    result_dir: Optional[str] = field(
        default=None,
        metadata={'help': 'The directory relative to output_dir for saving results.'}
    )
    min_length: int = field(
        default=8192,
        metadata={'help': 'Minimum context length in evaluation.'}
    )
    max_length: int = field(
        default=131072,
        metadata={'help': 'Maximum context length in evaluation.'}
    )
    # Used as the ``num`` argument of np.linspace(..., endpoint=True), i.e. a
    # count of evaluated points, not of intervals between them.
    num_length_interval: int = field(
        default=20,
        metadata={'help': 'Number of evenly spaced lengths (both ends included) between min_length and max_length.'}
    )
    # None means "derive the lengths from min_length/max_length/num_length_interval".
    test_length: Optional[List[int]] = field(
        default=None,
        metadata={'help': 'Specified evaluation lengths.'}
    )
    # Depths are percentages of the available insertion window (0 = start, 100 = end).
    min_depth: float = field(
        default=0,
        metadata={'help': 'Minimum pass key depth in the context.'}
    )
    max_depth: float = field(
        default=100,
        metadata={'help': 'Maximum pass key depth in the context.'}
    )
    num_depth_interval: int = field(
        default=10,
        metadata={'help': 'Number of evenly spaced depths (both ends included) between min_depth and max_depth.'}
    )
    # None means "derive the depths from min_depth/max_depth/num_depth_interval".
    test_depth: Optional[List[int]] = field(
        default=None,
        metadata={'help': 'Specified evaluation depths.'}
    )
    passkey_length: int = field(
        default=5,
        metadata={'help': 'How many numbers are in the passkey?'}
    )
    seed: int = field(
        default=123,
        metadata={'help': 'Random seed.'}
    )
    # NOTE(review): these presumably override generation defaults declared in
    # ModelArgs; main() does not pass them to generate() explicitly — confirm
    # get_model_and_tokenizer consumes them.
    do_sample: bool = False
    max_new_tokens: int = 50
def generate_sample(tokenizer, chat_template, context_length, passkey_depth, passkey_length, rng: Optional[np.random.Generator] = None):
    """Build one passkey-retrieval prompt.

    The user message is assembled at the token level from a task description,
    repeated filler sentences, a sentence carrying the passkey (inserted at
    ``passkey_depth`` percent of the available window) and a closing question.
    The token ids are then decoded back to text and wrapped with
    ``apply_chat_template``.

    Args:
        tokenizer: tokenizer exposing ``encode(text, add_special_tokens=..., max_length=..., truncation=...)``
            and ``decode(ids)`` (HF-style API presumably — confirm).
        chat_template: template identifier forwarded to ``apply_chat_template``.
        context_length: target number of tokens of the user message (before the chat template is applied).
        passkey_depth: passkey position in percent; 0 places it right after the
            description, 100 at the last position that still leaves room for the question.
        passkey_length: number of digits of the generated passkey.
        rng: generator used to draw the passkey. When ``None`` a fresh
            ``np.random.default_rng(42)`` is created for this call, instead of sharing
            one stateful generator across calls through the default argument.

    Returns:
        Tuple ``(inputs, prompt, passkey)``: the templated input text, the question text,
        and the passkey as a decimal string.

    Raises:
        ValueError: if ``context_length`` leaves no room to place the passkey sentence
            between the description and the question.
    """
    if rng is None:
        rng = np.random.default_rng(42)
    passkey = str(rng.integers(10**(passkey_length - 1), 10**passkey_length))

    # NOTE: the texts below are kept verbatim (including spelling); editing them
    # changes the evaluated inputs.
    description = "There is an important infomation hidden in the following context. Find the information and memorize it. I will quiz you about the important information there.\n"
    noises = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again." * (context_length // 10)
    information = f"\n\nThe pass key is {passkey}. Remember it. {passkey} is the pass key.\n\n"
    prompt = "\n\nWhat is the pass key?"

    # These pieces each appear exactly once in the sample.
    description_input_ids = tokenizer.encode(description, add_special_tokens=False)
    information_input_ids = tokenizer.encode(information, add_special_tokens=False)
    prompt_input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    description_length = len(description_input_ids)
    information_length = len(information_input_ids)
    prompt_length = len(prompt_input_ids)

    # The passkey sentence may start right after the description and must leave
    # room for itself and the question before context_length is reached.
    minimum_pos = description_length
    maximum_pos = context_length - prompt_length - information_length - 1
    # An empty window (maximum_pos < minimum_pos) would give negative noise lengths.
    # This subsumes the former checks `minimum_pos > context_length` and `maximum_pos < 0`.
    if maximum_pos < minimum_pos:
        raise ValueError(f"The length {context_length} is too small. Please increase interval!")
    passkey_pos = minimum_pos + round((maximum_pos - minimum_pos) * passkey_depth / 100)

    # Filler before and after the passkey, truncated to exactly fill the remaining budget.
    prefix_noise = tokenizer.encode(noises, max_length=passkey_pos - description_length, truncation=True, add_special_tokens=False)
    suffix_noise = tokenizer.encode(noises, max_length=context_length - passkey_pos - information_length - prompt_length, truncation=True, add_special_tokens=False)

    input_ids = description_input_ids + prefix_noise + information_input_ids + suffix_noise + prompt_input_ids
    # Decoding and re-templating means the final token count may differ slightly from context_length.
    inputs = tokenizer.decode(input_ids)
    inputs = apply_chat_template(chat_template, messages=[{'role': 'user', 'content': inputs}], tokenizer=tokenizer, add_generation_prompt=True).raw
    return inputs, prompt, passkey
# Matches the first run of digits in a model output (raw string: no invalid-escape warning).
_NUMBER_PATTERN = re.compile(r"\d+")


def _extract_number(text):
    """Return the first run of digits in ``text``, or ``""`` when there is none."""
    match = _NUMBER_PATTERN.search(text)
    return match.group() if match else ""


def _mean_grid(grid):
    """Average each ``grid[length][depth]`` score list, rounded to 2 decimals.

    Empty cells become NaN (rendered blank in the heatmap) instead of raising
    ``ZeroDivisionError``.
    """
    return {
        length: {
            depth: round(sum(scores) / len(scores), 2) if scores else float("nan")
            for depth, scores in row.items()
        }
        for length, row in grid.items()
    }


def _score_outputs(dataset, all_outputs, test_lengths, test_depths):
    """Compare decoded outputs with the target passkeys.

    Returns ``(results, accuracy, fuzzy_score)``. ``results[length][depth]`` lists
    ``{'target', 'prediction'}`` dicts. ``accuracy`` / ``fuzzy_score`` map
    ``length -> depth -> mean`` exact-match rate / fuzzy ratio in [0, 1].
    """
    accuracy = {l: {d: [] for d in test_depths} for l in test_lengths}
    fuzzy_score = {l: {d: [] for d in test_depths} for l in test_lengths}
    results = {l: {d: [] for d in test_depths} for l in test_lengths}
    for l, d, p, o in zip(dataset['length'], dataset['depth'], dataset['passkey'], all_outputs):
        prediction = _extract_number(o)
        results[l][d].append({'target': p, 'prediction': prediction})
        accuracy[l][d].append(float(p == prediction))
        fuzzy_score[l][d].append(round(fuzz.ratio(prediction, p) / 100, 2))
    return results, _mean_grid(accuracy), _mean_grid(fuzzy_score)


def _plot_heatmap(metric_key, metric_value, save_path):
    """Draw ``metric_value`` (length -> depth -> score) as a heatmap and save it as PNG to ``save_path``."""
    # Copied from https://github.com/gkamradt/LLMTest_NeedleInAHaystack/blob/main/viz/CreateVizFromLLMTesting.ipynb
    cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#F0496E", "#EBB839", "#0CD79F"])
    sns.set(rc={"figure.figsize": (17.5, 8), "axes.titlesize":14, "axes.labelsize":12}, style="whitegrid", palette="colorblind")
    # Outer dict keys (lengths) become columns, inner keys (depths) become rows.
    data = pd.DataFrame(metric_value)
    ax = sns.heatmap(
        data,
        cmap=cmap,
        vmin=0,
        vmax=1,
        fmt="g",
        linewidth=.5,
    )
    cbar = ax.collections[0].colorbar
    cbar.set_label(metric_key, size=14)
    plt.title('Passkey Retrieval')
    plt.xlabel('Context Length', fontsize=14)
    plt.ylabel('Depth Percent', fontsize=14)
    plt.xticks(rotation=45, fontsize=10)  # rotate to prevent label overlap
    plt.yticks(rotation=0, fontsize=10)
    plt.tight_layout()
    plt.savefig(save_path, format='png', bbox_inches='tight')
    plt.close()


@torch.no_grad()
def main():
    """Run the passkey-retrieval evaluation end to end.

    Parses ``Args``, builds one prompt per (length, depth) pair, generates
    answers with the model, and on process 0 writes ``results.json``,
    ``config.json`` and one heatmap PNG per metric into the result directory. It
    also logs the metrics to ``metrics.log`` in ``output_dir``.
    """
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]

    accelerator = Accelerator(cpu=args.cpu)
    model, tokenizer = get_model_and_tokenizer(args, device=accelerator.device)

    if args.test_length is None:
        test_lengths = np.linspace(args.min_length, args.max_length, args.num_length_interval, endpoint=True).astype(int).tolist()
    else:
        test_lengths = args.test_length

    if args.test_depth is None:
        test_depths = np.linspace(args.min_depth, args.max_depth, args.num_depth_interval, endpoint=True).astype(int).tolist()
    else:
        test_depths = args.test_depth

    rng_state = np.random.default_rng(args.seed)

    all_inputs = []
    for length in tqdm(test_lengths, desc="Constructing Data"):
        for depth in test_depths:
            inputs, prompt, passkey = generate_sample(
                tokenizer=tokenizer,
                chat_template=args.chat_template,
                context_length=length,
                passkey_depth=depth,
                passkey_length=args.passkey_length,
                rng=rng_state
            )
            all_inputs.append({'inputs': inputs, 'prompt': prompt, 'passkey': passkey, 'length': length, 'depth': depth})

    dataset = datasets.Dataset.from_list(all_inputs)
    dataloader = torch.utils.data.DataLoader(
        # length, depth and passkey are only needed for scoring, not for generation
        dataset.remove_columns(['length', 'depth', 'passkey']),
        batch_size=args.batch_size,
        collate_fn=DefaultDataCollator(tokenizer),
        pin_memory=not args.cpu,
    )
    # NOTE: prepare dataloader so the data moves to GPU automatically
    dataloader = accelerator.prepare(dataloader)
    accelerator.wait_for_everyone()

    all_outputs = []
    for x in tqdm(dataloader, desc="Evaluating"):
        x.pop("prompt")
        inputs = x.pop("inputs")
        # TODO: retrieval
        # NOTE: important to reset memory for every batch
        if hasattr(model, "memory"):
            model.memory.reset()
        # NOTE(review): no padding is requested, so batch_size > 1 with prompts of
        # different token lengths will likely fail here — confirm intended batch_size=1.
        inputs = tokenizer(inputs, return_tensors="pt").to(model.device)
        output = model.generate(**inputs)
        if isinstance(output, torch.Tensor):
            # Keep only the newly generated tokens.
            output = output[:, inputs['input_ids'].shape[1]:]
            output = tokenizer.batch_decode(output, skip_special_tokens=True)
        # Otherwise generate() is assumed to have returned decoded strings already.
        if accelerator.num_processes > 1:
            output = accelerator.gather_for_metrics(output)
        all_outputs.extend(output)

    if accelerator.process_index == 0:
        results, accuracy, fuzzy_score = _score_outputs(dataset, all_outputs, test_lengths, test_depths)

        # result_dir is optional (defaults to None); os.path.join(..., None) would raise TypeError.
        if args.result_dir is None:
            result_dir = args.output_dir
        else:
            result_dir = os.path.join(args.output_dir, args.result_dir)

        with open(makedirs(os.path.join(result_dir, "results.json")), "w", encoding='utf-8') as f:
            json.dump(results, f)
        # also save config
        args.save(os.path.join(result_dir, "config.json"))

        metrics = {'accuracy': accuracy, 'fuzz': fuzzy_score}
        file_logger = FileLogger(makedirs(os.path.join(args.output_dir, "metrics.log")))
        file_logger.log(metrics, Args=asdict(args))

        for metric_key, metric_value in metrics.items():
            _plot_heatmap(metric_key, metric_value, os.path.join(result_dir, f"{metric_key}.png"))
if __name__ == "__main__":
    # Script entry point: run the evaluation only when executed directly.
    main()