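"""Passkey retrieval evaluation.

Builds synthetic contexts of varying lengths, hides a random numeric passkey
at varying depths, asks the model to recall it, and reports exact-match
accuracy and fuzzy-match scores as needle-in-a-haystack heatmaps.
"""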
import re
import os
import json
import torch
import datasets
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

from typing import List, Optional
from tqdm import tqdm
from fuzzywuzzy import fuzz
from accelerate import Accelerator
from transformers import HfArgumentParser
from transformers.utils import logging
from dataclasses import dataclass, field, asdict

from src import ModelArgs, DefaultDataCollator, FileLogger, get_model_and_tokenizer, makedirs, apply_chat_template

logger = logging.get_logger(__name__)


@dataclass
class Args(ModelArgs):
    output_dir: str = field(
        default="data/results/passkey/",
        metadata={'help': 'The base directory for saving results and logs.'}
    )
    result_dir: Optional[str] = field(
        default=None,
        metadata={'help': 'The directory relative to output_dir for saving results.'}
    )

    min_length: int = field(
        default=8192,
        metadata={'help': 'Minimum context length in evaluation.'}
    )
    max_length: int = field(
        default=131072,
        metadata={'help': 'Maximum context length in evaluation.'}
    )
    num_length_interval: int = field(
        default=20,
        metadata={'help': 'Number of intervals between min_length and max_length.'}
    )
    test_length: List[int] = field(
        default=None,
        metadata={'help': 'Specified evaluation lengths.'}
    )

    min_depth: float = field(
        default=0,
        metadata={'help': 'Minimum passkey depth in the context.'}
    )
    max_depth: float = field(
        default=100,
        metadata={'help': 'Maximum passkey depth in the context.'}
    )
    num_depth_interval: int = field(
        default=10,
        metadata={'help': 'Number of intervals between min_depth and max_depth.'}
    )
    test_depth: List[int] = field(
        default=None,
        metadata={'help': 'Specified evaluation depths.'}
    )

    passkey_length: int = field(
        default=5,
        metadata={'help': 'How many digits are in the passkey?'}
    )
    seed: int = field(
        default=123,
        metadata={'help': 'Random seed.'}
    )

    do_sample: bool = False
    max_new_tokens: int = 50


def generate_sample(tokenizer, chat_template, context_length, passkey_depth, passkey_length, rng: np.random.Generator = np.random.default_rng(42)):
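    """Construct one passkey-retrieval sample.

    Hides a random passkey of `passkey_length` digits at roughly `passkey_depth`
    percent of a `context_length`-token context padded with repeated noise
    sentences, then wraps the whole input with the chat template.

    Returns the templated input string, the question prompt, and the passkey.
    """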
    passkey = str(rng.integers(10 ** (passkey_length - 1), 10 ** passkey_length))
    description = "There is important information hidden in the following context. Find the information and memorize it. I will quiz you about the important information there.\n"
    noises = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again." * (context_length // 10)
    information = f"\n\nThe pass key is {passkey}. Remember it. {passkey} is the pass key.\n\n"
    prompt = "\n\nWhat is the pass key?"

    # these inputs are encoded only once
    description_input_ids = tokenizer.encode(description, add_special_tokens=False)
    information_input_ids = tokenizer.encode(information, add_special_tokens=False)
    prompt_input_ids = tokenizer.encode(prompt, add_special_tokens=False)

    description_length = len(description_input_ids)
    information_length = len(information_input_ids)
    prompt_length = len(prompt_input_ids)

    # the passkey must leave room for the information and the prompt
    minimum_pos = description_length
    maximum_pos = context_length - prompt_length - information_length - 1
    if minimum_pos > context_length or maximum_pos < 0:
        raise ValueError(f"The context length {context_length} is too small to fit the passkey sample. Please increase it!")

    # place the passkey at the requested depth (a percentage) within the valid range
    passkey_pos = minimum_pos + round((maximum_pos - minimum_pos) * passkey_depth / 100)

    # truncate the noise to exactly fill the space before and after the passkey
    prefix_noise = tokenizer.encode(noises, max_length=passkey_pos - description_length, truncation=True, add_special_tokens=False)
    suffix_noise = tokenizer.encode(noises, max_length=context_length - passkey_pos - information_length - prompt_length, truncation=True, add_special_tokens=False)

    input_ids = sum([description_input_ids, prefix_noise, information_input_ids, suffix_noise, prompt_input_ids], [])
    inputs = tokenizer.decode(input_ids)

    inputs = apply_chat_template(chat_template, messages=[{'role': 'user', 'content': inputs}], tokenizer=tokenizer, add_generation_prompt=True).raw

    return inputs, prompt, passkey


@torch.no_grad()
def main():
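    """Build the evaluation grid, run generation, and save metrics and heatmaps."""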
    parser = HfArgumentParser([Args])
    args: Args = parser.parse_args_into_dataclasses()[0]

    accelerator = Accelerator(cpu=args.cpu)
    model, tokenizer = get_model_and_tokenizer(args, device=accelerator.device)

    # evenly spaced context lengths and depths, unless explicit values are given
    if args.test_length is None:
        test_lengths = np.linspace(args.min_length, args.max_length, args.num_length_interval, endpoint=True).astype(int).tolist()
    else:
        test_lengths = args.test_length

    if args.test_depth is None:
        test_depths = np.linspace(args.min_depth, args.max_depth, args.num_depth_interval, endpoint=True).astype(int).tolist()
    else:
        test_depths = args.test_depth

    rng_state = np.random.default_rng(args.seed)

    all_inputs = []
    for length in tqdm(test_lengths, desc="Constructing Data"):
        for depth in test_depths:
            inputs, prompt, passkey = generate_sample(
                tokenizer=tokenizer,
                chat_template=args.chat_template,
                context_length=length,
                passkey_depth=depth,
                passkey_length=args.passkey_length,
                rng=rng_state
            )
            all_inputs.append({'inputs': inputs, 'prompt': prompt, 'passkey': passkey, 'length': length, 'depth': depth})

    dataset = datasets.Dataset.from_list(all_inputs)
    dataloader = torch.utils.data.DataLoader(
        # length, depth, and passkey are not used in the forward computation
        dataset.remove_columns(['length', 'depth', 'passkey']),
        batch_size=args.batch_size,
        collate_fn=DefaultDataCollator(tokenizer),
        pin_memory=not args.cpu,
    )

    # NOTE: prepare the dataloader so the data is moved to the GPU automatically
    dataloader = accelerator.prepare(dataloader)

    accelerator.wait_for_everyone()

    all_outputs = []

    for x in tqdm(dataloader, desc="Evaluating"):
        prompt = x.pop("prompt")
        inputs = x.pop("inputs")
        # TODO: retrieval

        # NOTE: important to reset memory for every batch
        if hasattr(model, "memory"):
            model.memory.reset()

        inputs = tokenizer(inputs, return_tensors="pt").to(model.device)

        output = model.generate(**inputs)

        if isinstance(output, torch.Tensor):
            # (1, max_new_tokens): strip the prompt tokens, keep only the generation
            output = output[:, inputs['input_ids'].shape[1]:]
            output = tokenizer.batch_decode(output, skip_special_tokens=True)
        elif isinstance(output, list):
            # already a list of decoded strings
            pass

        if accelerator.num_processes > 1:
            output = accelerator.gather_for_metrics(output)

        all_outputs.extend(output)

    if accelerator.process_index == 0:
        accuracy = {l: {d: [] for d in test_depths} for l in test_lengths}
        fuzzy_score = {l: {d: [] for d in test_depths} for l in test_lengths}
        results = {l: {d: [] for d in test_depths} for l in test_lengths}

        for l, d, p, o in zip(dataset['length'], dataset['depth'], dataset['passkey'], all_outputs):
            # extract the first run of digits from the model output
            match = re.search(r"\d+", o)
            o = match.group() if match else ""
            results[l][d].append({'target': p, 'prediction': o})

            acc = float(p == o)
            score = round(fuzz.ratio(o, p) / 100, 2)

            accuracy[l][d].append(acc)
            fuzzy_score[l][d].append(score)

        # average over samples in each (length, depth) cell
        for l, lv in accuracy.items():
            for d, dv in lv.items():
                accuracy[l][d] = round(sum(dv) / len(dv), 2)

        for l, lv in fuzzy_score.items():
            for d, dv in lv.items():
                fuzzy_score[l][d] = round(sum(dv) / len(dv), 2)

        result_dir = os.path.join(args.output_dir, args.result_dir)
        with open(makedirs(os.path.join(result_dir, "results.json")), "w", encoding='utf-8') as f:
            json.dump(results, f)
        # also save config
        args.save(os.path.join(result_dir, "config.json"))

        metrics = {'accuracy': accuracy, 'fuzz': fuzzy_score}
        file_logger = FileLogger(makedirs(os.path.join(args.output_dir, "metrics.log")))
        file_logger.log(metrics, Args=asdict(args))

        for metric_key, metric_value in metrics.items():
            # Copied from https://github.com/gkamradt/LLMTest_NeedleInAHaystack/blob/main/viz/CreateVizFromLLMTesting.ipynb
            cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#F0496E", "#EBB839", "#0CD79F"])
            # create the heatmap (rows are depths, columns are context lengths)
            sns.set(rc={"figure.figsize": (17.5, 8), "axes.titlesize": 14, "axes.labelsize": 12}, style="whitegrid", palette="colorblind")
            data = pd.DataFrame(metric_value)
            ax = sns.heatmap(
                data,
                cmap=cmap,
                vmin=0,
                vmax=1,
                fmt="g",
                linewidth=.5,
            )
            cbar = ax.collections[0].colorbar
            cbar.set_label(metric_key, size=14)

            # more aesthetics
            plt.title('Passkey Retrieval')
            plt.xlabel('Context Length', fontsize=14)
            plt.ylabel('Depth Percent', fontsize=14)
            plt.xticks(rotation=45, fontsize=10)  # rotate the x-axis labels to prevent overlap
            plt.yticks(rotation=0, fontsize=10)  # keep the y-axis labels horizontal
            plt.tight_layout()  # fit everything neatly into the figure area
            # save one heatmap per metric to result_dir
            plt.savefig(os.path.join(result_dir, f"{metric_key}.png"), format='png', bbox_inches='tight')
            plt.close()

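# Example invocation (a sketch: the script filename and --model_name_or_path
# are assumptions here; ModelArgs in this repo's src package is assumed to
# expose the usual model options, while --test_length, --test_depth, and
# --result_dir come from the Args dataclass above via HfArgumentParser):
#
#   torchrun --nproc_per_node 8 eval_passkey.py \
#       --model_name_or_path meta-llama/Llama-2-7b-chat-hf \
#       --test_length 8192 32768 131072 \
#       --test_depth 0 50 100 \
#       --result_dir llama2-7b-chat
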
if __name__ == "__main__":
    main()