"""
Generate line data for line retrieval task.

Usage:
    python3 gen_data.py --number 1000 [--redirect-ratio 0.2]
"""
import argparse
import json
from collections import defaultdict

import numpy as np

try:
    from tqdm import tqdm
except ImportError:  # The progress bar is cosmetic; run without it if tqdm is absent.

    def tqdm(iterable, *args, **kwargs):
        """No-op stand-in for ``tqdm.tqdm`` that just returns the iterable."""
        return iterable


def generate_lines(random_words, num_lines, redirect_ratio):
    """Build a synthetic line-retrieval prompt, optionally with redirections.

    Every line gets a unique index made of two hyphen-joined words drawn from
    ``random_words`` and a zero-padded 6-digit REGISTER_CONTENT value. When
    ``redirect_ratio > 0``, that fraction of lines is rewritten to say its value
    is "the same as" another line. Redirect chains are then followed to the
    terminal line to compute each line's final value; a line whose chain runs
    into a cycle is recorded in ``contains_ring``.

    Args:
        random_words: Word list used to build line indices.
        num_lines: Number of lines to generate.
        redirect_ratio: Fraction of lines to turn into redirects (0 disables).

    Returns:
        dict with keys:
            ``prefix``: instruction text shown before the lines.
            ``suffix``: question text with three worked examples filled in.
            ``lines``: the rendered lines.
            ``indices``: each line's word-pair index.
            ``values``: each line's resolved final value (not meaningful for
                lines listed in ``contains_ring``).
            ``links``: per line, the chain of redirect targets (``[]`` for a
                direct line, ``None`` for a line whose chain hits a cycle).
            ``group_by_num_hoops``: hop count -> line numbers (ring lines
                excluded).
            ``contains_ring``: sorted line numbers whose chain hits a cycle.

    Raises:
        ValueError: if ``num_lines`` exceeds the number of distinct word pairs,
            or if there are too few suitable lines for the suffix examples.
    """
    prefix = (
        "Here is a list of lines, each with its corresponding REGISTER_CONTENT value. "
        "Please memorize them. Be prepared to provide the REGISTER_CONTENT value "
        "for a specific line index when I ask."
    )
    suffix = (
        "The list has ended. Please give the final REGISTER_CONTENT value for a "
        "specific line after resolving the redirections and references. "
        "For example, the REGISTER_CONTENT of Line __idx0__ is __val0__. "
        "The REGISTER_CONTENT of Line __idx1__ is __val1__. "
        "The REGISTER_CONTENT of Line __idx2__ is __val2__. "
        "The REGISTER_CONTENT of Line ??? is"
    )

    # Without this check the uniqueness loop below would spin forever. (It is a
    # necessary bound only: words containing "-" can collide further.)
    num_distinct_words = len(set(random_words))
    max_unique = num_distinct_words**2
    if num_lines > max_unique:
        raise ValueError(
            f"Cannot create {num_lines} unique line indices from "
            f"{num_distinct_words} distinct words (at most {max_unique})."
        )

    # ---- Raw lines -------------------------------------------------------
    visited_indices = {None}  # seeded with None so the while loop runs at least once
    lines = []
    redirects = []  # redirect target line number, or None for a direct line
    indices = []
    values = []
    for _ in tqdm(range(num_lines)):
        line_index = None
        while line_index in visited_indices:
            line_index = "-".join(np.random.choice(random_words, size=(2,)))
        visited_indices.add(line_index)
        # randint's upper bound is exclusive; 1_000_000 makes all of 000000-999999 reachable.
        line_value = f"{np.random.randint(low=0, high=1_000_000):06}"
        lines.append(f"Line {line_index}: The REGISTER_CONTENT is {line_value}.")
        redirects.append(None)
        indices.append(line_index)
        values.append(line_value)

    # ---- Redirects -------------------------------------------------------
    if redirect_ratio > 0:
        num_redirect_lines = int(len(lines) * redirect_ratio)
        redirect_indices = np.random.choice(
            np.arange(len(lines)), size=(num_redirect_lines,), replace=False
        )
        # tolist() yields plain ints, so links stay JSON-serializable.
        for i in redirect_indices.tolist():
            # Target is drawn from the first i*2+100 lines (capped at num_lines);
            # it may equal i or close a cycle -- both are detected below.
            target_idx = int(np.random.choice(min(i * 2 + 100, num_lines)))
            lines[i] = (
                f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
            )
            redirects[i] = target_idx

    # ---- Follow redirect chains and detect rings -------------------------
    links = [[] for _ in range(num_lines)]
    contains_ring = set()
    for i in range(num_lines):
        if redirects[i] is None:
            continue
        chain = []
        cur = i
        visited = set()
        while redirects[cur] is not None:
            visited.add(cur)
            chain.append(redirects[cur])
            cur = redirects[cur]
            if cur in visited:
                contains_ring.add(i)
                chain = None
                break
        # For a normal chain, cur is a direct line whose value is never
        # overwritten (only redirecting lines are). For a ring it is arbitrary.
        values[i] = values[cur]
        links[i] = chain

    # ---- Group by number of hops ----------------------------------------
    group_by_num_hoops = defaultdict(list)
    for i in range(num_lines):
        if i in contains_ring:
            continue
        group_by_num_hoops[len(links[i]) + 1].append(i)
    for num_links in sorted(group_by_num_hoops):
        print(f"#links: {num_links}, #lines: {len(group_by_num_hoops[num_links])}")

    # ---- Few-shot examples in the suffix --------------------------------
    def farthest_line(c):
        # Largest line number touched by c's chain: sorting by it favors
        # examples whose lines all appear early in the list.
        return max([c] + links[c])

    # .get() so that reading a missing group does not insert an empty entry
    # into the defaultdict that is returned (and serialized).
    hoop1_candidates = sorted(group_by_num_hoops.get(1, []), key=farthest_line)
    hoop2_candidates = sorted(group_by_num_hoops.get(2, []), key=farthest_line)

    # Example 0 is always a direct line; examples 1 and 2 are 2-hop lines when
    # at least two exist, otherwise further direct lines.
    use_hoop2 = len(hoop2_candidates) >= 2
    min_hoop1 = 6 if use_hoop2 else 11
    if len(hoop1_candidates) < min_hoop1:
        raise ValueError(
            f"Need at least {min_hoop1} direct (1-hop) lines for the suffix "
            f"examples, got {len(hoop1_candidates)}."
        )
    if use_hoop2:
        examples = [hoop1_candidates[5], hoop2_candidates[0], hoop2_candidates[1]]
    else:
        examples = [hoop1_candidates[5], hoop1_candidates[1], hoop1_candidates[10]]
    for slot, line_no in enumerate(examples):
        suffix = suffix.replace(f"__idx{slot}__", indices[line_no]).replace(
            f"__val{slot}__", values[line_no]
        )

    obj = {
        "prefix": prefix,
        "suffix": suffix,
        "lines": lines,
        "indices": indices,
        "values": values,
        "links": links,
        "group_by_num_hoops": group_by_num_hoops,
        "contains_ring": sorted(contains_ring),
    }
    return obj


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--number", type=int, required=True)
    parser.add_argument("--redirect-ratio", type=float, default=0.0)
    args = parser.parse_args()

    num_lines = args.number
    random_words_filename = "random_words.json"
    with open(random_words_filename, "r", encoding="utf-8") as fin:
        random_words = json.load(fin)

    np.random.seed(42)
    obj = generate_lines(random_words, num_lines, args.redirect_ratio)

    output_filename = f"lines_{num_lines}_{args.redirect_ratio:.1f}.json"
    with open(output_filename, "w", encoding="utf-8") as fout:
        json.dump(obj, fout, indent=2)