140 lines
4.8 KiB
Python
140 lines
4.8 KiB
Python
"""
|
|
Generate line data for line retrieval task.
|
|
|
|
Usage:
|
|
python3 gen_data.py --number 1000
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
from collections import defaultdict
|
|
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
|
|
|
|
def generate_lines(random_words, num_lines, redirect_ratio):
    """Generate a synthetic line-retrieval prompt, optionally with redirections.

    Each line is labelled with a unique ``word-word`` index and carries a
    zero-padded 6-digit REGISTER_CONTENT value. A ``redirect_ratio`` fraction
    of the lines is then rewritten to say "the same as Line <other>", which
    forms chains (and possibly cycles) of references.

    Randomness comes from the global ``np.random`` state; seed it before
    calling for reproducible output.

    Args:
        random_words: Sequence of words; two are drawn (with replacement) and
            joined with ``-`` to form each line label.
        num_lines: Number of lines to generate.
        redirect_ratio: Fraction of lines turned into redirects. Values
            ``<= 0`` disable redirection.

    Returns:
        A JSON-serializable dict with the keys:
        ``prefix`` / ``suffix`` (prompt text, the suffix filled with three
        few-shot examples), ``lines`` (rendered line strings), ``indices``
        (line labels), ``values`` (value per line, resolved through
        redirects), ``links`` (per line, the list of positions followed to
        resolve it; ``None`` for lines caught in a cycle),
        ``group_by_num_hoops`` (line positions grouped by
        ``len(links) + 1``) and ``contains_ring`` (sorted positions of lines
        whose redirect chain runs into a cycle).

    Raises:
        ValueError: If ``random_words`` cannot produce ``num_lines`` unique
            labels, or if no directly valued line exists to serve as a
            few-shot example.
    """
    prefix = "Here is a list of lines, each with its corresponding REGISTER_CONTENT value. Please memorize them. Be prepared to provide the REGISTER_CONTENT value for a specific line index when I ask."
    suffix = "The list has ended. Please give the final REGISTER_CONTENT value for a specific line after resolving the redirections and references. For example, the REGISTER_CONTENT of Line __idx0__ is __val0__. The REGISTER_CONTENT of Line __idx1__ is __val1__. The REGISTER_CONTENT of Line __idx2__ is __val2__. The REGISTER_CONTENT of Line ??? is"

    # Labels are ordered pairs drawn with replacement, so only
    # len(unique words) ** 2 distinct labels exist. Without this guard the
    # rejection-sampling loop below would spin forever.
    max_unique_labels = len(set(random_words)) ** 2
    if num_lines > max_unique_labels:
        raise ValueError(
            f"Cannot generate {num_lines} unique line labels from "
            f"{max_unique_labels} possible word pairs."
        )

    # Raw lines
    visited_indices = {None}  # seeded with None so the first draw always happens

    lines = []
    redirects = []  # redirects[i] is the target position, or None for a plain line
    indices = []
    values = []
    for _ in tqdm(range(num_lines)):
        line_index = None
        while line_index in visited_indices:
            line_index = "-".join(np.random.choice(random_words, size=(2,)))
        visited_indices.add(line_index)

        # NOTE: `high` is exclusive, so values range over 000000..999998.
        line_value = np.random.randint(low=0, high=999999)
        line_value = f"{line_value:06}"

        lines.append(f"Line {line_index}: The REGISTER_CONTENT is {line_value}.")
        redirects.append(None)
        indices.append(line_index)
        values.append(line_value)

    # Add redirect
    if redirect_ratio > 0:
        num_redirect_lines = int(len(lines) * redirect_ratio)
        redirect_indices = np.random.choice(
            np.arange(len(lines)), size=(num_redirect_lines,), replace=False
        )
        for i in redirect_indices:
            # Convert NumPy integers to plain ints: they end up inside
            # `links`, and the json module cannot serialize numpy.int64.
            i = int(i)
            target_idx = int(np.random.choice(min(i * 2 + 100, num_lines)))
            lines[i] = (
                f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
            )
            redirects[i] = target_idx

    # Build links and find sources
    links = [[] for _ in range(num_lines)]
    contains_ring = set()
    for i in range(num_lines):
        if redirects[i] is None:
            continue

        tmp_link = []
        cur = i
        visited = set()
        while redirects[cur] is not None:
            visited.add(cur)
            tmp_link.append(redirects[cur])
            cur = redirects[cur]

            if cur in visited:
                # The chain revisits a position: it can never be resolved.
                contains_ring.add(i)
                tmp_link = None
                break
        # For a resolvable chain `cur` is the plain line at its end. For a
        # ring line this copies a value from inside the cycle, which carries
        # no meaning; such lines are excluded from the groups below.
        values[i] = values[cur]
        links[i] = tmp_link

    # Group by num_links
    group_by_num_hoops = defaultdict(list)
    for i in range(num_lines):
        if i in contains_ring:
            continue
        group_by_num_hoops[len(links[i]) + 1].append(i)

    for num_links in sorted(group_by_num_hoops):
        print(f"#links: {num_links}, #lines: {len(group_by_num_hoops[num_links])}")

    # Append few-shot examples
    def _sorted_candidates(num_hoops):
        # Order candidates by the largest line position involved in
        # resolving them (the line itself or any line it links through).
        candidates = list(group_by_num_hoops[num_hoops])
        candidates.sort(key=lambda c: max([c] + links[c]))
        return candidates

    def _pick(candidates, position):
        # Clamp to the last candidate so small inputs do not raise IndexError.
        return candidates[min(position, len(candidates) - 1)]

    hoop1_candidates = _sorted_candidates(1)
    hoop2_candidates = _sorted_candidates(2)
    if not hoop1_candidates:
        raise ValueError(
            "No directly valued line is available to build few-shot examples."
        )

    if hoop2_candidates:
        examples = [
            _pick(hoop1_candidates, 5),
            _pick(hoop2_candidates, 0),
            _pick(hoop2_candidates, 1),
        ]
    else:
        examples = [
            _pick(hoop1_candidates, 5),
            _pick(hoop1_candidates, 1),
            _pick(hoop1_candidates, 10),
        ]
    for slot, i in enumerate(examples):
        suffix = suffix.replace(f"__idx{slot}__", indices[i]).replace(
            f"__val{slot}__", values[i]
        )

    obj = {
        "prefix": prefix,
        "suffix": suffix,
        "lines": lines,
        "indices": indices,
        "values": values,
        "links": links,
        "group_by_num_hoops": group_by_num_hoops,
        "contains_ring": sorted(contains_ring),
    }
    return obj
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: read the word list, generate the data set
    # with a fixed seed, and write it to a JSON file in the working directory.
    parser = argparse.ArgumentParser()
    # Required: without it `num_lines` would be None and generation would
    # fail later with an unhelpful TypeError.
    parser.add_argument(
        "--number", type=int, required=True, help="number of lines to generate"
    )
    parser.add_argument(
        "--redirect-ratio",
        type=float,
        default=0.0,
        help="fraction of lines that redirect to another line",
    )
    args = parser.parse_args()

    num_lines = args.number

    random_words_filename = "random_words.json"
    # Use a context manager so the input file handle is closed promptly.
    with open(random_words_filename, "r") as fin:
        random_words = json.load(fin)

    # Fixed seed so repeated runs with the same arguments give the same data.
    np.random.seed(42)
    obj = generate_lines(random_words, num_lines, args.redirect_ratio)

    output_filename = f"lines_{num_lines}_{args.redirect_ratio:.1f}.json"
    with open(output_filename, "w") as fout:
        json.dump(obj, fout, indent=2)
|