import gc
import os
import random
import re
from typing import List, Dict, Tuple, Union

import faiss
import numpy as np
import pytrec_eval
import torch
from tqdm import trange, tqdm
from transformers import AutoModel

from agent import GPTAgent, LLMAgent, LLMInstructAgent

|
def extract_numbers(s):
    """Extract every integer that appears in a string, in order of appearance."""
    numbers = re.findall(r'\d+', s)
    numbers = [int(num) for num in numbers]
    return numbers

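# Illustrative usage of extract_numbers (a minimal sketch; the ranking string
# below is a made-up example, not real model output):
def _example_extract_numbers():
    ranking = "The most relevant passages are: 3 > 1 > 2"
    print(extract_numbers(ranking))  # -> [3, 1, 2]
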
def get_distill_data(
        llm_for_rank=None,
        temperature: float = 0.0,
        top_p: float = 1.0,
        max_tokens: int = 1024,
        train_data: List = None,
        prompts: List[str] = None,
):
    """Distill an LLM ranker into training data: the top-ranked candidate becomes
    the positive and the remaining candidates become negatives with
    reciprocal-rank scores."""
    generated_rank_results = llm_for_rank.generate(
        prompts,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
    )

    for d, res in zip(train_data, generated_rank_results):
        res = extract_numbers(res)
        passages = []
        passages.extend(d['pos'])
        passages.extend(d['neg'])
        # If the LLM used 1-based indices, shift them to 0-based.
        if 0 not in res and len(passages) in res:
            res = [e - 1 for e in res]
        # Drop indices that fall outside the candidate list.
        except_res = [i for i in res if i >= len(passages)]
        for e in except_res:
            res.remove(e)
        # Append any candidates the LLM did not rank, keeping their original order.
        if len(res) < len(passages):
            print('incomplete LLM ranking:', res)
            for i in range(len(passages)):
                if i not in res:
                    res.append(i)

        d['pos'] = []
        d['neg'] = []
        for i in res[:1]:
            d['pos'].append(passages[i])
        for i in res[1:]:
            d['neg'].append(passages[i])
        d['pos_scores'] = [1]
        # Reciprocal-rank scores for the negatives: 1/1, 1/2, 1/3, ...
        d['neg_scores'] = [1 / (i + 1) for i in range(len(res) - 1)]

    return train_data

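# Minimal usage sketch for get_distill_data. _DummyRanker is a hypothetical
# stand-in for one of the agents imported from agent.py (e.g. LLMAgent); it only
# mimics the .generate(prompts, ...) interface this module assumes.
class _DummyRanker:
    def generate(self, prompts, temperature=0.0, top_p=1.0, max_tokens=1024):
        # Pretend the LLM ranked the three candidates as 2 > 1 > 3 (1-based).
        return ["2 > 1 > 3" for _ in prompts]


def _example_get_distill_data():
    train_data = [{'query': 'what is faiss?',
                   'pos': ['faiss is a library for similarity search.'],
                   'neg': ['pytrec_eval wraps trec_eval.', 'tqdm draws progress bars.']}]
    prompts = ['Rank the passages for the query ...']
    distilled = get_distill_data(llm_for_rank=_DummyRanker(), train_data=train_data, prompts=prompts)
    print(distilled[0]['pos'], distilled[0]['neg_scores'])
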
def generate_bge_train_data(
        retrieval_model,
        batch_size: int = 512,
        max_length: int = 512,
        queries_corpus: Union[List[dict], List[List[dict]]] = None,
        dtype: str = 'passage',
        corpus: List[str] = None,
        filter_data: bool = False,
        filter_num: int = 20,
        emb_save_path: str = None,
        ignore_prefix: bool = False,
        neg_type: str = 'hard'
):
    """Build (query, pos, neg) training examples for a BGE-style retriever by
    mining negatives from a dense search over the corpus."""
    if corpus is None:
        corpus = [d[dtype] for d in queries_corpus]
    queries = [d['query'] for d in queries_corpus]
    answers = [d['answer'] for d in queries_corpus]

    queries_emb = retrieval_model.encode_queries(queries, batch_size=batch_size,
                                                 max_length=max_length)
    # * 0.8 + retrieval_model.encode_corpus(answers, batch_size=batch_size, max_length=max_length) * 0.2
    answers_emb = retrieval_model.encode_corpus(answers, batch_size=batch_size, max_length=max_length)

    # Encode the corpus, optionally caching the embeddings on disk.
    if emb_save_path is not None:
        if os.path.exists(emb_save_path):
            if ignore_prefix:
                # The cached file excludes the first len(queries) passages; re-encode them.
                doc_emb = np.vstack(
                    (
                        retrieval_model.encode_corpus(corpus[: len(queries_emb)], batch_size=batch_size,
                                                      max_length=max_length),
                        np.load(emb_save_path)
                    )
                )
            else:
                doc_emb = np.load(emb_save_path)
        else:
            doc_emb = retrieval_model.encode_corpus(corpus, batch_size=batch_size, max_length=max_length)
            try:
                os.makedirs(os.path.dirname(emb_save_path), exist_ok=True)
            except OSError:
                pass
            if ignore_prefix:
                np.save(emb_save_path, doc_emb[len(queries_emb):])
            else:
                np.save(emb_save_path, doc_emb)
    else:
        doc_emb = retrieval_model.encode_corpus(corpus, batch_size=batch_size, max_length=max_length)

    print('len doc emb:', len(doc_emb))

    all_scores, all_indices = search(queries_emb, doc_emb, 2000)
    all_answers_scores, all_answers_indices = search(answers_emb, doc_emb, 2000)

    train_data = []

    # Rank at which each query retrieves its own passage (-1 if not in the top 2000).
    find_idxs = []
    for i in range(len(all_indices)):
        if i in list(all_indices[i]):
            find_idxs.append(list(all_indices[i]).index(i))
        else:
            find_idxs.append(-1)
    print('rank of each positive passage:', find_idxs)

    # Same, but searching with the answer embedding instead of the query embedding.
    answers_find_idxs = []
    for i in range(len(all_answers_indices)):
        if i in list(all_answers_indices[i]):
            answers_find_idxs.append(list(all_answers_indices[i]).index(i))
        else:
            answers_find_idxs.append(-1)

    for i in trange(len(queries), desc='generate train set'):

        if find_idxs[i] == -1:  # remove false pairs
            # continue
            # neg_ids = random.sample(list(range(len(corpus))), k=50)
            neg_ids = random.sample(list(all_indices[i][30:200]), k=50)
        else:
            # Start negatives at the first hit whose score drops below 95% of the positive's score.
            uses_idx = -1
            for j in range(find_idxs[i] + 1, 2000):
                if all_scores[i][j] <= all_scores[i][find_idxs[i]] * 0.95:
                    uses_idx = j
                    break
            if uses_idx == -1:
                # continue
                # neg_ids = random.sample(list(range(len(corpus))), k=50)
                neg_ids = random.sample(list(all_indices[i][30:200]), k=50)
            else:
                neg_ids = list(all_indices[i][uses_idx: uses_idx + 50])
                # neg_ids = list(all_indices[i][:50])
        if neg_type == 'random':
            neg_ids = random.sample(list(range(len(corpus))), k=50)
        elif neg_type == 'hard':
            # neg_ids = list(all_indices[i][:50])
            neg_ids = random.sample(list(all_indices[i][30:200]), k=50)
            tmp_ids = [(e, list(all_indices[i]).index(e)) for e in neg_ids]
            tmp_ids = sorted(tmp_ids, key=lambda x: x[1])
            neg_ids = [e[0] for e in tmp_ids]
        else:
            tmp_ids = [(e, list(all_indices[i]).index(e)) for e in neg_ids]
            tmp_ids = sorted(tmp_ids, key=lambda x: x[1])
            neg_ids = [e[0] for e in tmp_ids]

        if answers_find_idxs[i] == -1:  # remove false pairs
            # continue
            neg_answers_ids = random.sample(list(range(len(corpus))), k=50)
        else:
            uses_idx = -1
            for j in range(answers_find_idxs[i] + 1, 2000):
                if all_answers_scores[i][j] <= all_answers_scores[i][answers_find_idxs[i]] * 0.95:
                    uses_idx = j
                    break
            if uses_idx == -1:
                # continue
                neg_answers_ids = random.sample(list(range(len(corpus))), k=50)
            else:
                neg_answers_ids = list(all_answers_indices[i][uses_idx: uses_idx + 50])

        query = queries[i]
        answer = answers[i]
        pos = [corpus[i]]
        negs = [corpus[j] for j in neg_ids]
        while pos[0] in negs:
            negs.remove(pos[0])
        # Keep at most 15 unique negatives.
        new_negs = []
        for e in negs:
            if e not in new_negs and len(new_negs) < 15:
                new_negs.append(e)
        negs = new_negs

        negs_answer = [corpus[j] for j in neg_answers_ids]
        while pos[0] in negs_answer:
            negs_answer.remove(pos[0])
        new_negs_answer = []
        for e in negs_answer:
            if e not in new_negs_answer and len(new_negs_answer) < 15:
                new_negs_answer.append(e)
        negs_answer = new_negs_answer

        train_data.append(
            {
                'query': query,
                'answer': answer,
                'pos': pos,
                'neg': negs,
                'neg_answer': negs_answer
            }
        )

    if filter_data:
        # Keep only queries whose own passage is retrieved within the top filter_num hits.
        print('filtering train data, filter_num =', filter_num)
        new_train_data = []
        for i in range(len(all_indices)):
            if i in list(all_indices[i]):
                searched_idx = list(all_indices[i]).index(i)
            else:
                searched_idx = len(all_indices) + 999
            if searched_idx < filter_num:
                new_train_data.append(train_data[i])
        train_data = new_train_data

    print('len train data:', len(train_data))

    return train_data

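# Minimal usage sketch for generate_bge_train_data. _DummyEncoder is a
# hypothetical stand-in for the retrieval model; it only provides the
# encode_queries / encode_corpus interface this module relies on. Running the
# sketch still needs a GPU faiss build, since generate_bge_train_data calls
# search(); with a corpus smaller than the hard-coded top-2000 search, faiss
# pads missing hits with index -1, so this only exercises the data pipeline.
class _DummyEncoder:
    def __init__(self, dim: int = 32, seed: int = 0):
        self.rng = np.random.default_rng(seed)
        self.dim = dim

    def _encode(self, texts):
        emb = self.rng.normal(size=(len(texts), self.dim)).astype(np.float32)
        return emb / np.linalg.norm(emb, axis=1, keepdims=True)

    def encode_queries(self, texts, batch_size=512, max_length=512):
        return self._encode(texts)

    def encode_corpus(self, texts, batch_size=512, max_length=512):
        return self._encode(texts)


def _example_generate_bge_train_data():
    queries_corpus = [{'query': f'query {i}', 'answer': f'answer {i}', 'passage': f'passage {i}'}
                      for i in range(300)]
    data = generate_bge_train_data(_DummyEncoder(), batch_size=32, max_length=64,
                                   queries_corpus=queries_corpus, neg_type='hard')
    print(len(data), data[0].keys())
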
def generate_llm_dpo_train_data(
        queries_corpus_list: List[List[dict]] = None,
        search_dtype: str = 'answer',
        result_dtype: str = 'passage',
        retrieval_model: AutoModel = None,
        threshold: float = 0.95,
        batch_size: int = 512,
        max_length: int = 1024,
        use_rule1: bool = True
):
    """Build DPO preference pairs (prompt, chosen, rejected) by comparing how well
    each candidate generation retrieves the target passage. The raw queries and
    passages are taken from the last element of queries_corpus_list, so all
    elements are expected to share the same 'query' and passage fields."""
    data = []

    queries_list = []
    corpus = []
    raw_queries = []
    for qc in queries_corpus_list:
        raw_queries = [d['query'] for d in qc]
        if 'new_query' in qc[0].keys():
            queries_list.append([d['new_query'] for d in qc])
        else:
            queries_list.append([d[search_dtype] for d in qc])
        corpus = [d[result_dtype] for d in qc]

    doc_emb = retrieval_model.encode_corpus(corpus, batch_size=batch_size, max_length=max_length)
    raw_queries_emb = retrieval_model.encode_queries(raw_queries, batch_size=batch_size, max_length=max_length)
    # Row-wise dot product: score of each raw query against its own passage.
    raw_scores = np.einsum('ij,ij->i', raw_queries_emb, doc_emb)
    all_scores_list = []
    for queries in queries_list:
        # queries = ['Generate the topic about this passage: ' + q for q in queries]
        # Mix the raw query embedding with the candidate's embedding (0.8 / 0.2).
        queries_emb = raw_queries_emb * 0.8 + retrieval_model.encode_queries(queries, batch_size=batch_size,
                                                                             max_length=max_length) * 0.2
        # queries_emb = raw_queries_emb
        all_scores_list.append(np.einsum('ij,ij->i', queries_emb, doc_emb))

    for i in range(len(all_scores_list[0])):
        raw_score = raw_scores[i]
        all_scores = [e[i] for e in all_scores_list]
        items = [(idx, all_scores[idx]) for idx in range(len(all_scores))]
        sorted_idx = [idx for idx, _ in sorted(items, key=lambda x: x[1], reverse=False)]
        min_score = max(all_scores)
        for idx in sorted_idx:
            if abs(1 - all_scores[idx] / raw_score) < 0.1:
                min_score = all_scores[idx]
                break
        min_score = min(all_scores)
        max_score = max(all_scores)

        # Rule 1 additionally requires the best candidate to beat the raw query.
        if use_rule1:
            keep = max_score > raw_score and (max_score - raw_score * 0.8) * threshold >= (min_score - raw_score * 0.8)
        else:
            keep = (max_score - raw_score * 0.8) * threshold >= (min_score - raw_score * 0.8)

        if keep:
            # print('use')
            tmp = {
                'prompt': queries_corpus_list[0][i]['query'],
                'chosen': queries_corpus_list[all_scores.index(max_score)][i][search_dtype],
                'rejected': queries_corpus_list[all_scores.index(min_score)][i][search_dtype],
            }
            tmp['chosen_score'] = float(max_score / raw_score)
            tmp['rejected_score'] = float(min_score / raw_score)
            data.append(tmp)

    return data

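# Minimal usage sketch for generate_llm_dpo_train_data, reusing the hypothetical
# _DummyEncoder above: two candidate answer sets are scored against the passages
# and turned into chosen/rejected preference pairs.
def _example_generate_llm_dpo_train_data():
    qc_a = [{'query': f'query {i}', 'answer': f'candidate A {i}', 'passage': f'passage {i}'}
            for i in range(20)]
    qc_b = [{'query': f'query {i}', 'answer': f'candidate B {i}', 'passage': f'passage {i}'}
            for i in range(20)]
    pairs = generate_llm_dpo_train_data(queries_corpus_list=[qc_a, qc_b],
                                        retrieval_model=_DummyEncoder(),
                                        batch_size=8, max_length=64)
    print(len(pairs), pairs[:1])
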
def evaluate_mrr(qrels: Dict[str, Dict[str, int]],
                 results: Dict[str, Dict[str, float]],
                 k_values: List[int]) -> Dict[str, float]:
    """Compute MRR@k over TREC-style qrels and retrieval results."""
    MRR = {}

    for k in k_values:
        MRR[f"MRR@{k}"] = 0.0

    k_max, top_hits = max(k_values), {}

    for query_id, doc_scores in results.items():
        top_hits[query_id] = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]

    for query_id in top_hits:
        query_relevant_docs = set([doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0])
        for k in k_values:
            for rank, hit in enumerate(top_hits[query_id][0:k]):
                if hit[0] in query_relevant_docs:
                    MRR[f"MRR@{k}"] += 1.0 / (rank + 1)
                    break

    for k in k_values:
        MRR[f"MRR@{k}"] = round(MRR[f"MRR@{k}"] / len(qrels), 5)

    return MRR

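# Minimal sketch of evaluate_mrr on toy data: one query, only 'd2' relevant and
# ranked second, so MRR@1 = 0.0 and MRR@10 = 0.5.
def _example_evaluate_mrr():
    qrels = {'0': {'d1': 0, 'd2': 1}}
    results = {'0': {'d1': 0.9, 'd2': 0.8}}
    print(evaluate_mrr(qrels, results, k_values=[1, 10]))  # {'MRR@1': 0.0, 'MRR@10': 0.5}
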
def search(queries_emb, doc_emb, topk: int = 100):
    """Exact inner-product search with faiss, sharded across all visible GPUs."""
    gc.collect()
    torch.cuda.empty_cache()

    faiss_index = faiss.index_factory(doc_emb.shape[1], 'Flat', faiss.METRIC_INNER_PRODUCT)
    co = faiss.GpuMultipleClonerOptions()
    co.shard = True
    faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)

    doc_emb = doc_emb.astype(np.float32)
    faiss_index.train(doc_emb)
    faiss_index.add(doc_emb)

    dev_query_size = queries_emb.shape[0]
    all_scores = []
    all_indices = []
    # Search in batches of 32 queries to bound GPU memory usage.
    for i in tqdm(range(0, dev_query_size, 32), desc="Searching"):
        j = min(i + 32, dev_query_size)
        query_embedding = queries_emb[i: j]
        score, indice = faiss_index.search(query_embedding.astype(np.float32), k=topk)
        all_scores.append(score)
        all_indices.append(indice)

    all_scores = np.concatenate(all_scores, axis=0)
    all_indices = np.concatenate(all_indices, axis=0)

    return all_scores, all_indices

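# Minimal sketch of search() on random, L2-normalised embeddings (it assumes a
# GPU faiss build, since the index is cloned onto all visible GPUs):
def _example_search():
    rng = np.random.default_rng(0)
    docs = rng.normal(size=(1000, 64)).astype(np.float32)
    docs /= np.linalg.norm(docs, axis=1, keepdims=True)
    queries = docs[:5]  # each query should retrieve itself at rank 0
    scores, indices = search(queries, docs, topk=3)
    print(indices[:, 0])  # expected: [0 1 2 3 4]
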
def evaluate(metrics: List[str] = ['recall', 'mrr', 'ndcg'],
             k_values: List[int] = [1, 10],
             ground_truths: Dict[str, Dict[str, int]] = None,
             predicts: List = None,
             scores: List = None):
    """Score ranked predictions with pytrec_eval.

    predicts[i] and scores[i] hold the retrieved doc ids and their scores for
    query i; ground_truths maps query ids ('0', '1', ...) to {doc_id: relevance}."""

    # Convert the parallel predicts/scores lists into pytrec_eval's run format.
    retrieval_results = {}
    for i in range(len(predicts)):
        tmp = {}
        for j in range(len(predicts[i])):
            tmp[str(predicts[i][j])] = float(scores[i][j])
        retrieval_results[str(i)] = tmp

    ndcg = {}
    _map = {}
    recall = {}
    precision = {}

    for k in k_values:
        ndcg[f"NDCG@{k}"] = 0.0
        _map[f"MAP@{k}"] = 0.0
        recall[f"Recall@{k}"] = 0.0
        precision[f"Precision@{k}"] = 0.0

    map_string = "map_cut." + ",".join([str(k) for k in k_values])
    ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
    recall_string = "recall." + ",".join([str(k) for k in k_values])
    precision_string = "P." + ",".join([str(k) for k in k_values])
    evaluator = pytrec_eval.RelevanceEvaluator(ground_truths,
                                               {map_string, ndcg_string, recall_string, precision_string})

    scores = evaluator.evaluate(retrieval_results)

    for query_id in scores.keys():
        for k in k_values:
            ndcg[f"NDCG@{k}"] += scores[query_id]["ndcg_cut_" + str(k)]
            _map[f"MAP@{k}"] += scores[query_id]["map_cut_" + str(k)]
            recall[f"Recall@{k}"] += scores[query_id]["recall_" + str(k)]
            precision[f"Precision@{k}"] += scores[query_id]["P_" + str(k)]

    for k in k_values:
        ndcg[f"NDCG@{k}"] = round(ndcg[f"NDCG@{k}"] / len(scores), 5)
        _map[f"MAP@{k}"] = round(_map[f"MAP@{k}"] / len(scores), 5)
        recall[f"Recall@{k}"] = round(recall[f"Recall@{k}"] / len(scores), 5)
        precision[f"Precision@{k}"] = round(precision[f"Precision@{k}"] / len(scores), 5)

    mrr = evaluate_mrr(ground_truths, retrieval_results, k_values)

    data = {}

    if 'mrr' in metrics:
        data['mrr'] = mrr
    if 'recall' in metrics:
        data['recall'] = recall
    if 'ndcg' in metrics:
        data['ndcg'] = ndcg
    if 'map' in metrics:
        data['map'] = _map
    if 'precision' in metrics:
        data['precision'] = precision

    return data

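# Minimal sketch of evaluate() on toy data: query '0' ranks the relevant doc
# 'd2' second, so Recall@1 = 0 and Recall@10 = 1 (requires pytrec_eval).
def _example_evaluate():
    ground_truths = {'0': {'d1': 0, 'd2': 1}}
    predicts = [['d1', 'd2']]
    scores = [[0.9, 0.8]]
    print(evaluate(metrics=['recall', 'mrr', 'ndcg'], k_values=[1, 10],
                   ground_truths=ground_truths, predicts=predicts, scores=scores))
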
def evaluate_better(metrics: List[str] = ['recall', 'mrr', 'ndcg'],
                    k_values: List[int] = [1, 10],
                    ground_truths: Dict[str, Dict[str, int]] = None,
                    retrieval_results: Dict[str, Dict[str, float]] = None):
    """Same as evaluate(), but takes retrieval results already in pytrec_eval's
    {query_id: {doc_id: score}} run format."""
    ndcg = {}
    _map = {}
    recall = {}
    precision = {}

    for k in k_values:
        ndcg[f"NDCG@{k}"] = 0.0
        _map[f"MAP@{k}"] = 0.0
        recall[f"Recall@{k}"] = 0.0
        precision[f"Precision@{k}"] = 0.0

    map_string = "map_cut." + ",".join([str(k) for k in k_values])
    ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
    recall_string = "recall." + ",".join([str(k) for k in k_values])
    precision_string = "P." + ",".join([str(k) for k in k_values])
    evaluator = pytrec_eval.RelevanceEvaluator(ground_truths,
                                               {map_string, ndcg_string, recall_string, precision_string})

    scores = evaluator.evaluate(retrieval_results)

    for query_id in scores.keys():
        for k in k_values:
            ndcg[f"NDCG@{k}"] += scores[query_id]["ndcg_cut_" + str(k)]
            _map[f"MAP@{k}"] += scores[query_id]["map_cut_" + str(k)]
            recall[f"Recall@{k}"] += scores[query_id]["recall_" + str(k)]
            precision[f"Precision@{k}"] += scores[query_id]["P_" + str(k)]

    for k in k_values:
        ndcg[f"NDCG@{k}"] = round(ndcg[f"NDCG@{k}"] / len(scores), 5)
        _map[f"MAP@{k}"] = round(_map[f"MAP@{k}"] / len(scores), 5)
        recall[f"Recall@{k}"] = round(recall[f"Recall@{k}"] / len(scores), 5)
        precision[f"Precision@{k}"] = round(precision[f"Precision@{k}"] / len(scores), 5)

    mrr = evaluate_mrr(ground_truths, retrieval_results, k_values)

    data = {}

    if 'mrr' in metrics:
        data['mrr'] = mrr
    if 'recall' in metrics:
        data['recall'] = recall
    if 'ndcg' in metrics:
        data['ndcg'] = ndcg
    if 'map' in metrics:
        data['map'] = _map
    if 'precision' in metrics:
        data['precision'] = precision

    return data