embed-bge-m3/FlagEmbedding/research/Reinforced_IR/data_generation/utils.py

import re
import random
import os
import faiss
import numpy as np
import pytrec_eval
import torch
import gc
from transformers import AutoModel
from tqdm import trange, tqdm
from typing import List, Dict, Tuple, Union
from agent import GPTAgent, LLMAgent, LLMInstructAgent


def extract_numbers(s):
    """Extract every integer that appears in a string, in order of appearance."""
    numbers = re.findall(r'\d+', s)
    numbers = [int(num) for num in numbers]
    return numbers
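

# Illustrative only (the sample string below is an assumption, not from the original
# code): `extract_numbers` is used further down to parse an LLM ranking response such
# as "[2] > [1] > [3]" into the list [2, 1, 3] before it is turned into pos/neg labels.
def _example_extract_numbers():
    return extract_numbers("[2] > [1] > [3]")  # -> [2, 1, 3]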


def get_distill_data(
        llm_for_rank=None,
        temperature: float = 0.0,
        top_p: float = 1.0,
        max_tokens: int = 1024,
        train_data: List = None,
        prompts: List[str] = None,
):
    """Re-label train_data with an LLM ranker: the top-ranked passage becomes the
    positive, the remaining passages become negatives with 1/(rank+1) teacher scores."""
    generated_rank_results = llm_for_rank.generate(
        prompts,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
    )
    for d, res in zip(train_data, generated_rank_results):
        res = extract_numbers(res)
        passages = []
        passages.extend(d['pos'])
        passages.extend(d['neg'])
        # If the LLM used 1-based indices, shift them to 0-based.
        if 0 not in res and len(passages) in res:
            res = [e - 1 for e in res]
        # Drop indices that fall outside the passage list.
        except_res = [i for i in res if i >= len(passages)]
        for e in except_res:
            res.remove(e)
        # If some passages were not ranked at all, append them at the end.
        if len(res) < len(passages):
            print(res)
            for i in range(len(passages)):
                if i not in res:
                    res.append(i)
        d['pos'] = []
        d['neg'] = []
        for i in res[:1]:
            d['pos'].append(passages[i])
        for i in res[1:]:
            d['neg'].append(passages[i])
        d['pos_scores'] = [1]
        d['neg_scores'] = [1 / (i + 1) for i in range(len(res) - 1)]
    return train_data
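

# Shape sketch (assumed example values, not from the original pipeline): each item in
# `train_data` carries one positive and several negatives, and `prompts` holds the
# matching ranking prompt for the LLM. After the call, 'pos' holds the top-ranked
# passage, 'neg' the remaining passages in rank order, and the scores follow the
# 1/(rank+1) scheme implemented above.
def _example_distill_item():
    return {
        'query': 'what is dense retrieval?',
        'pos': ['Dense retrieval encodes queries and passages into vectors ...'],
        'neg': ['Unrelated passage A', 'Unrelated passage B'],
    }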


def generate_bge_train_data(
        retrieval_model,
        batch_size: int = 512,
        max_length: int = 512,
        queries_corpus: Union[List[dict], List[List[dict]]] = None,
        dtype: str = 'passage',
        corpus: List[str] = None,
        filter_data: bool = False,
        filter_num: int = 20,
        emb_save_path: str = None,
        ignore_prefix: bool = False,
        neg_type: str = 'hard'
):
    """Mine positives and (hard) negatives for retriever training from
    (query, answer, passage) triples, optionally caching the corpus embeddings."""
    if corpus is None:
        corpus = [d[dtype] for d in queries_corpus]
    queries = [d['query'] for d in queries_corpus]
    answers = [d['answer'] for d in queries_corpus]

    queries_emb = retrieval_model.encode_queries(queries, batch_size=batch_size,
                                                 max_length=max_length)
    # * 0.8 + retrieval_model.encode_corpus(answers, batch_size=batch_size, max_length=max_length) * 0.2
    answers_emb = retrieval_model.encode_corpus(answers, batch_size=batch_size, max_length=max_length)

    # Encode the corpus, reusing a cached embedding file when it exists.
    if emb_save_path is not None:
        if os.path.exists(emb_save_path):
            if ignore_prefix:
                doc_emb = np.vstack(
                    (
                        retrieval_model.encode_corpus(corpus[: len(queries_emb)], batch_size=batch_size,
                                                      max_length=max_length),
                        np.load(emb_save_path)
                    )
                )
            else:
                doc_emb = np.load(emb_save_path)
        else:
            doc_emb = retrieval_model.encode_corpus(corpus, batch_size=batch_size, max_length=max_length)
            try:
                os.makedirs('/'.join(emb_save_path.split('/')[:-1]), exist_ok=True)
            except:
                pass
            if ignore_prefix:
                np.save(emb_save_path, doc_emb[len(queries_emb):])
            else:
                np.save(emb_save_path, doc_emb)
    else:
        doc_emb = retrieval_model.encode_corpus(corpus, batch_size=batch_size, max_length=max_length)
    print('len doc emb:', len(doc_emb))

    all_scores, all_indices = search(queries_emb, doc_emb, 2000)
    _, all_answers_indices = search(answers_emb, doc_emb, 2000)

    train_data = []

    # For each query i, find the rank of its own passage (corpus row i) in the
    # retrieval results; -1 means it was not retrieved within the top 2000.
    find_idxs = []
    for i in range(len(all_indices)):
        if i in list(all_indices[i]):
            find_idxs.append(list(all_indices[i]).index(i))
        else:
            find_idxs.append(-1)
    print(find_idxs)

    # Same lookup, but searching with the answer embedding instead of the query.
    answers_find_idxs = []
    for i in range(len(all_answers_indices)):
        if i in list(all_answers_indices[i]):
            answers_find_idxs.append(list(all_answers_indices[i]).index(i))
        else:
            answers_find_idxs.append(-1)
    for i in trange(len(queries), desc='generate train set'):
        if find_idxs[i] == -1:  # remove false pairs
            # continue
            # neg_ids = random.sample(list(range(len(corpus))), k=50)
            neg_ids = random.sample(list(all_indices[i][30:200]), k=50)
        else:
            # Skip candidates scoring within 5% of the positive, then take the
            # next 50 retrieved passages as negatives.
            uses_idx = -1
            for j in range(find_idxs[i] + 1, 2000):
                if all_scores[i][j] <= all_scores[i][find_idxs[i]] * 0.95:
                    uses_idx = j
                    break
            if uses_idx == -1:
                # continue
                # neg_ids = random.sample(list(range(len(corpus))), k=50)
                neg_ids = random.sample(list(all_indices[i][30:200]), k=50)
            else:
                neg_ids = list(all_indices[i][uses_idx: uses_idx + 50])
                # neg_ids = list(all_indices[i][:50])

        if neg_type == 'random':
            neg_ids = random.sample(list(range(len(corpus))), k=50)
        elif neg_type == 'hard':
            # neg_ids = list(all_indices[i][:50])
            neg_ids = random.sample(list(all_indices[i][30:200]), k=50)
            tmp_ids = [(e, list(all_indices[i]).index(e)) for e in neg_ids]
            tmp_ids = sorted(tmp_ids, key=lambda x: x[1])
            neg_ids = [e[0] for e in tmp_ids]
        else:
            tmp_ids = [(e, list(all_indices[i]).index(e)) for e in neg_ids]
            tmp_ids = sorted(tmp_ids, key=lambda x: x[1])
            neg_ids = [e[0] for e in tmp_ids]

        if answers_find_idxs[i] == -1:  # remove false pairs
            # continue
            neg_answers_ids = random.sample(list(range(len(corpus))), k=50)
        else:
            uses_idx = -1
            for j in range(answers_find_idxs[i] + 1, 2000):
                if all_scores[i][j] <= all_scores[i][answers_find_idxs[i]] * 0.95:
                    uses_idx = j
                    break
            if uses_idx == -1:
                # continue
                neg_answers_ids = random.sample(list(range(len(corpus))), k=50)
            else:
                neg_answers_ids = list(all_answers_indices[i][uses_idx: uses_idx + 50])
        query = queries[i]
        answer = answers[i]
        pos = [corpus[i]]

        negs = [corpus[j] for j in neg_ids]
        while pos[0] in negs:
            negs.remove(pos[0])
        new_negs = []
        for e in negs:
            if e not in new_negs and len(new_negs) < 15:
                new_negs.append(e)
        negs = new_negs

        negs_answer = [corpus[j] for j in neg_answers_ids]
        while pos[0] in negs_answer:
            negs_answer.remove(pos[0])
        new_negs_answer = []
        for e in negs_answer:
            if e not in new_negs_answer and len(new_negs_answer) < 15:
                new_negs_answer.append(e)
        negs_answer = new_negs_answer

        train_data.append(
            {
                'query': query,
                'answer': answer,
                'pos': pos,
                'neg': negs,
                'neg_answer': negs_answer
            }
        )
    if filter_data:
        print(filter_data)
        new_train_data = []
        for i in range(len(all_indices)):
            if i in list(all_indices[i]):
                searched_idx = list(all_indices[i]).index(i)
            else:
                searched_idx = len(all_indices) + 999
            # Keep only queries whose own passage is retrieved within the top `filter_num`.
            if searched_idx < filter_num:
                new_train_data.append(train_data[i])
        train_data = new_train_data
        print(len(train_data))

    return train_data
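

# Usage sketch, not the original training entry point: `retrieval_model` only needs
# FlagEmbedding-style `encode_queries` / `encode_corpus` methods returning numpy
# arrays. `_ToyEncoder` and `_example_generate_bge_train_data` below are hypothetical
# helpers added for illustration; a real run passes the generated corpus and a trained
# retriever, and `search` requires a GPU build of faiss.
class _ToyEncoder:
    """Hypothetical stand-in encoder producing random L2-normalised embeddings."""

    def __init__(self, dim: int = 32, seed: int = 0):
        self.dim = dim
        self.rng = np.random.default_rng(seed)

    def _encode(self, texts, batch_size: int = 512, max_length: int = 512):
        emb = self.rng.normal(size=(len(texts), self.dim)).astype(np.float32)
        return emb / np.linalg.norm(emb, axis=1, keepdims=True)

    def encode_queries(self, texts, **kwargs):
        return self._encode(texts, **kwargs)

    def encode_corpus(self, texts, **kwargs):
        return self._encode(texts, **kwargs)


def _example_generate_bge_train_data():
    # 3000 synthetic (query, answer, passage) triples, so the top-2000 search above
    # always has enough candidates.
    queries_corpus = [
        {'query': f'query {i}', 'answer': f'answer {i}', 'passage': f'passage {i}'}
        for i in range(3000)
    ]
    return generate_bge_train_data(
        retrieval_model=_ToyEncoder(),
        queries_corpus=queries_corpus,
        batch_size=64,
        neg_type='hard',
    )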


def generate_llm_dpo_train_data(
        queries_corpus_list: List[List[dict]] = None,
        search_dtype: str = 'answer',
        result_dtype: str = 'passage',
        retrieval_model: AutoModel = None,
        threshold: float = 0.95,
        batch_size: int = 512,
        max_length: int = 1024,
        use_rule1: bool = True
):
    """Build DPO pairs: for each query, the candidate generation with the highest
    blended retrieval score becomes 'chosen' and the lowest becomes 'rejected'."""
    data = []

    queries_list = []
    corpus = []
    raw_queries = []
    for qc in queries_corpus_list:
        raw_queries = [d['query'] for d in qc]
        if 'new_query' in qc[0].keys():
            queries_list.append([d['new_query'] for d in qc])
        else:
            queries_list.append([d[search_dtype] for d in qc])
        corpus = [d[result_dtype] for d in qc]

    doc_emb = retrieval_model.encode_corpus(corpus, batch_size=batch_size, max_length=max_length)
    raw_queries_emb = retrieval_model.encode_queries(raw_queries, batch_size=batch_size, max_length=max_length)
    # Row-wise inner product between each raw query and its own passage.
    raw_scores = np.einsum('ij,ij->i', raw_queries_emb, doc_emb)

    all_scores_list = []
    for queries in queries_list:
        # queries = ['Generate the topic about this passage: ' + q for q in queries]
        # Blend the raw query embedding with the candidate generation's embedding.
        queries_emb = raw_queries_emb * 0.8 + retrieval_model.encode_queries(queries, batch_size=batch_size,
                                                                             max_length=max_length) * 0.2
        # queries_emb = raw_queries_emb
        all_scores_list.append(np.einsum('ij,ij->i', queries_emb, doc_emb))

    for i in range(len(all_scores_list[0])):
        raw_score = raw_scores[i]
        all_scores = [e[i] for e in all_scores_list]

        items = [(idx, all_scores[idx]) for idx in range(len(all_scores))]
        sorted_idx = [idx for idx, _ in sorted(items, key=lambda x: x[1], reverse=False)]
        min_score = max(all_scores)
        for idx in sorted_idx:
            if abs(1 - all_scores[idx] / raw_score) < 0.1:
                min_score = all_scores[idx]
                break
        min_score = min(all_scores)
        max_score = max(all_scores)

        if use_rule1:
            if max_score > raw_score and (max_score - raw_score * 0.8) * threshold >= (min_score - raw_score * 0.8):
                # print('use')
                tmp = {
                    'prompt': queries_corpus_list[0][i]['query'],
                    'chosen': queries_corpus_list[all_scores.index(max_score)][i][search_dtype],
                    'rejected': queries_corpus_list[all_scores.index(min_score)][i][search_dtype],
                }
                tmp['chosen_score'] = float(max_score / raw_score)
                tmp['rejected_score'] = float(min_score / raw_score)
                data.append(tmp)
        else:
            if (max_score - raw_score * 0.8) * threshold >= (min_score - raw_score * 0.8):
                # print('use')
                tmp = {
                    'prompt': queries_corpus_list[0][i]['query'],
                    'chosen': queries_corpus_list[all_scores.index(max_score)][i][search_dtype],
                    'rejected': queries_corpus_list[all_scores.index(min_score)][i][search_dtype],
                }
                tmp['chosen_score'] = float(max_score / raw_score)
                tmp['rejected_score'] = float(min_score / raw_score)
                data.append(tmp)

    return data
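

# Input-format sketch (assumed, inferred from the loop above): each element of
# `queries_corpus_list` holds one candidate generation per query, aligned by index
# across the lists, e.g. two different LLM answers for the same queries. The names
# and values below are illustrative only.
def _example_dpo_inputs():
    candidates_a = [{'query': 'q1', 'answer': 'candidate answer A', 'passage': 'p1'}]
    candidates_b = [{'query': 'q1', 'answer': 'candidate answer B', 'passage': 'p1'}]
    # generate_llm_dpo_train_data([candidates_a, candidates_b], retrieval_model=...)
    # would emit {'prompt': 'q1', 'chosen': ..., 'rejected': ...} records whenever the
    # score-gap rule above is satisfied.
    return [candidates_a, candidates_b]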


def evaluate_mrr(qrels: Dict[str, Dict[str, int]],
                 results: Dict[str, Dict[str, float]],
                 k_values: List[int]) -> Tuple[Dict[str, float]]:
    MRR = {}

    for k in k_values:
        MRR[f"MRR@{k}"] = 0.0

    k_max, top_hits = max(k_values), {}

    for query_id, doc_scores in results.items():
        top_hits[query_id] = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]

    for query_id in top_hits:
        query_relevant_docs = set([doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0])
        for k in k_values:
            for rank, hit in enumerate(top_hits[query_id][0:k]):
                if hit[0] in query_relevant_docs:
                    MRR[f"MRR@{k}"] += 1.0 / (rank + 1)
                    break

    for k in k_values:
        MRR[f"MRR@{k}"] = round(MRR[f"MRR@{k}"] / len(qrels), 5)

    return MRR
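

# Worked example with toy qrels/results (values are illustrative): query "0" has its
# only relevant document "d2" at rank 2, so MRR@1 = 0.0 and MRR@10 = 0.5.
def _example_evaluate_mrr():
    qrels = {'0': {'d2': 1}}
    results = {'0': {'d1': 0.9, 'd2': 0.8, 'd3': 0.1}}
    return evaluate_mrr(qrels, results, k_values=[1, 10])  # -> {'MRR@1': 0.0, 'MRR@10': 0.5}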


def search(queries_emb, doc_emb, topk: int = 100):
    gc.collect()
    torch.cuda.empty_cache()

    # Exact inner-product search with faiss, sharded across all visible GPUs.
    faiss_index = faiss.index_factory(doc_emb.shape[1], 'Flat', faiss.METRIC_INNER_PRODUCT)
    co = faiss.GpuMultipleClonerOptions()
    co.shard = True
    faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)

    doc_emb = doc_emb.astype(np.float32)
    faiss_index.train(doc_emb)
    faiss_index.add(doc_emb)

    dev_query_size = queries_emb.shape[0]
    all_scores = []
    all_indices = []
    # Query in batches of 32 to bound GPU memory usage.
    for i in tqdm(range(0, dev_query_size, 32), desc="Searching"):
        j = min(i + 32, dev_query_size)
        query_embedding = queries_emb[i: j]
        score, indice = faiss_index.search(query_embedding.astype(np.float32), k=topk)
        all_scores.append(score)
        all_indices.append(indice)

    all_scores = np.concatenate(all_scores, axis=0)
    all_indices = np.concatenate(all_indices, axis=0)
    return all_scores, all_indices
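

# Minimal sketch of calling `search` on its own; the shapes are assumptions, and a GPU
# build of faiss is required because the index is cloned to all visible GPUs.
def _example_search():
    rng = np.random.default_rng(0)
    doc_emb = rng.normal(size=(1000, 64)).astype(np.float32)
    queries_emb = rng.normal(size=(8, 64)).astype(np.float32)
    scores, indices = search(queries_emb, doc_emb, topk=10)
    # Both arrays have shape (8, 10): inner-product scores and corpus row ids,
    # ordered from most to least similar.
    return scores, indices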


def evaluate(metrics: List[str] = ['recall', 'mrr', 'ndcg'],
             k_values: List[int] = [1, 10],
             ground_truths: Dict[str, Dict[str, int]] = None,
             predicts: List = None,
             scores: List = None):
    # Build a pytrec_eval-style run: query ids are the stringified row indices.
    retrieval_results = {}
    for i in range(len(predicts)):
        tmp = {}
        for j in range(len(predicts[0])):
            tmp[str(predicts[i][j])] = float(scores[i][j])
        retrieval_results[str(i)] = tmp

    ndcg = {}
    _map = {}
    recall = {}
    precision = {}

    for k in k_values:
        ndcg[f"NDCG@{k}"] = 0.0
        _map[f"MAP@{k}"] = 0.0
        recall[f"Recall@{k}"] = 0.0
        precision[f"Precision@{k}"] = 0.0

    map_string = "map_cut." + ",".join([str(k) for k in k_values])
    ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
    recall_string = "recall." + ",".join([str(k) for k in k_values])
    precision_string = "P." + ",".join([str(k) for k in k_values])

    evaluator = pytrec_eval.RelevanceEvaluator(ground_truths,
                                               {map_string, ndcg_string, recall_string, precision_string})
    scores = evaluator.evaluate(retrieval_results)

    for query_id in scores.keys():
        for k in k_values:
            ndcg[f"NDCG@{k}"] += scores[query_id]["ndcg_cut_" + str(k)]
            _map[f"MAP@{k}"] += scores[query_id]["map_cut_" + str(k)]
            recall[f"Recall@{k}"] += scores[query_id]["recall_" + str(k)]
            precision[f"Precision@{k}"] += scores[query_id]["P_" + str(k)]

    for k in k_values:
        ndcg[f"NDCG@{k}"] = round(ndcg[f"NDCG@{k}"] / len(scores), 5)
        _map[f"MAP@{k}"] = round(_map[f"MAP@{k}"] / len(scores), 5)
        recall[f"Recall@{k}"] = round(recall[f"Recall@{k}"] / len(scores), 5)
        precision[f"Precision@{k}"] = round(precision[f"Precision@{k}"] / len(scores), 5)

    mrr = evaluate_mrr(ground_truths, retrieval_results, k_values)

    data = {}
    if 'mrr' in metrics:
        data['mrr'] = mrr
    if 'recall' in metrics:
        data['recall'] = recall
    if 'ndcg' in metrics:
        data['ndcg'] = ndcg
    if 'map' in metrics:
        data['map'] = _map
    if 'precision' in metrics:
        data['precision'] = precision
    return data
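

# Toy call (illustrative values): `predicts`/`scores` are per-query ranked doc ids and
# their scores, and the keys of `ground_truths` must be the stringified row indices,
# because `evaluate` keys its run dict by str(i).
def _example_evaluate():
    ground_truths = {'0': {'d2': 1}, '1': {'d0': 1}}
    predicts = [['d1', 'd2', 'd3'], ['d0', 'd4', 'd5']]
    scores = [[0.9, 0.8, 0.1], [0.95, 0.4, 0.2]]
    return evaluate(metrics=['recall', 'mrr', 'ndcg'], k_values=[1, 10],
                    ground_truths=ground_truths, predicts=predicts, scores=scores)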


def evaluate_better(metrics: List[str] = ['recall', 'mrr', 'ndcg'],
                    k_values: List[int] = [1, 10],
                    ground_truths: Dict[str, Dict[str, int]] = None,
                    retrieval_results: Dict[str, Dict[str, float]] = None):
    ndcg = {}
    _map = {}
    recall = {}
    precision = {}

    for k in k_values:
        ndcg[f"NDCG@{k}"] = 0.0
        _map[f"MAP@{k}"] = 0.0
        recall[f"Recall@{k}"] = 0.0
        precision[f"Precision@{k}"] = 0.0

    map_string = "map_cut." + ",".join([str(k) for k in k_values])
    ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
    recall_string = "recall." + ",".join([str(k) for k in k_values])
    precision_string = "P." + ",".join([str(k) for k in k_values])

    evaluator = pytrec_eval.RelevanceEvaluator(ground_truths,
                                               {map_string, ndcg_string, recall_string, precision_string})
    scores = evaluator.evaluate(retrieval_results)

    for query_id in scores.keys():
        for k in k_values:
            ndcg[f"NDCG@{k}"] += scores[query_id]["ndcg_cut_" + str(k)]
            _map[f"MAP@{k}"] += scores[query_id]["map_cut_" + str(k)]
            recall[f"Recall@{k}"] += scores[query_id]["recall_" + str(k)]
            precision[f"Precision@{k}"] += scores[query_id]["P_" + str(k)]

    for k in k_values:
        ndcg[f"NDCG@{k}"] = round(ndcg[f"NDCG@{k}"] / len(scores), 5)
        _map[f"MAP@{k}"] = round(_map[f"MAP@{k}"] / len(scores), 5)
        recall[f"Recall@{k}"] = round(recall[f"Recall@{k}"] / len(scores), 5)
        precision[f"Precision@{k}"] = round(precision[f"Precision@{k}"] / len(scores), 5)

    mrr = evaluate_mrr(ground_truths, retrieval_results, k_values)

    data = {}
    if 'mrr' in metrics:
        data['mrr'] = mrr
    if 'recall' in metrics:
        data['recall'] = recall
    if 'ndcg' in metrics:
        data['ndcg'] = ndcg
    if 'map' in metrics:
        data['map'] = _map
    if 'precision' in metrics:
        data['precision'] = precision
    return data
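

# `evaluate_better` skips the id/score flattening and takes a pytrec_eval-style run
# directly; a toy call mirroring _example_evaluate above (values are illustrative):
def _example_evaluate_better():
    ground_truths = {'0': {'d2': 1}}
    retrieval_results = {'0': {'d1': 0.9, 'd2': 0.8, 'd3': 0.1}}
    return evaluate_better(metrics=['recall', 'mrr', 'ndcg'], k_values=[1, 10],
                           ground_truths=ground_truths, retrieval_results=retrieval_results)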