# Source: embed-bge-m3/FlagEmbedding/research/C_MTEB/MKQA/sparse_retrieval/bm25_baseline.py
# (128 lines, 4.1 KiB, Python)

"""
# 1. Output Search Results with BM25
python bm25_baseline.py
# 2. Print and Save Evaluation Results
python step2-eval_sparse_mkqa.py \
--encoder bm25 \
--languages ar fi ja ko ru es sv he th da de fr it nl pl pt hu vi ms km no tr zh_cn zh_hk zh_tw \
--search_result_save_dir ./search_results \
--qa_data_dir ../qa_data \
--eval_result_save_dir ./eval_results \
--metrics recall@20 recall@100 \
--threads 32
"""
import os
import subprocess
import sys

import datasets
from tqdm import tqdm

sys.path.append("..")
from utils.normalize_text import normalize
def generate_corpus(corpus_save_path: str):
    """Export the BeIR/nq corpus as a Pyserini JsonCollection file.

    Each passage becomes an ``{"id": ..., "contents": ...}`` record whose
    contents are the title and text joined by a newline, lower-cased and then
    passed through ``normalize``.

    Args:
        corpus_save_path: Destination JSON-lines file. Generation is skipped
            when a non-empty file already exists at this path.
    """
    # Same rule as generate_queries: a zero-byte file is not accepted as a
    # finished corpus.
    if os.path.exists(corpus_save_path) and os.path.getsize(corpus_save_path) > 0:
        print("Corpus already exists. Skip generating ...")
        return
    corpus = datasets.load_dataset('BeIR/nq', 'corpus')['corpus']
    corpus_list = []
    for data in tqdm(corpus, desc="Generating corpus"):
        content = normalize(f"{data['title']}\n{data['text']}".lower())
        corpus_list.append({"id": str(data['_id']), "contents": content})
    datasets.Dataset.from_list(corpus_list).to_json(corpus_save_path, force_ascii=False)
def generate_queries(qa_data_dir: str, lang: str, queries_save_dir: str):
    """Write the MKQA questions of one language as a Pyserini TSV topics file.

    Reads ``{qa_data_dir}/{lang}.jsonl`` and writes one line per record to
    ``{queries_save_dir}/{lang}.tsv``: the record's ``id``, a tab, and its
    ``question``. Generation is skipped if that file already exists and is
    non-empty.

    Args:
        qa_data_dir: Directory holding the per-language ``.jsonl`` QA files.
        lang: Language code, used as the file stem for input and output.
        queries_save_dir: Directory the ``.tsv`` topics file is written to.
    """
    queries_save_path = os.path.join(queries_save_dir, f"{lang}.tsv")
    if os.path.exists(queries_save_path) and os.path.getsize(queries_save_path) > 0:
        return
    queries_path = os.path.join(qa_data_dir, f"{lang}.jsonl")
    queries = datasets.load_dataset('json', data_files=queries_path)['train']
    # Each output line must be exactly "<id><TAB><text>". A tab or line break
    # inside a question would break that layout, so map them to spaces.
    sanitize = str.maketrans({'\t': ' ', '\n': ' ', '\r': ' '})
    # Build every line before opening the file. If reading the records fails
    # partway, no partial (non-empty) file is left behind for later runs to skip.
    lines = [
        f"{str(data['id'])}\t{str(data['question']).translate(sanitize)}\n"
        for data in queries
    ]
    with open(queries_save_path, 'w', encoding='utf-8') as f:
        f.writelines(lines)
def index(corpus_save_dir: str, index_save_dir: str, threads: int = 1):
    """Build a Lucene index over the JSON corpus with Pyserini.

    Runs ``pyserini.index.lucene`` in a subprocess, using the interpreter that
    runs this script.

    Args:
        corpus_save_dir: Directory containing the JsonCollection corpus files.
        index_save_dir: Directory the Lucene index is written to.
        threads: Number of indexing threads passed to Pyserini (default 1,
            as in the original baseline).

    Raises:
        subprocess.CalledProcessError: If the indexing command exits with a
            non-zero status.
    """
    cmd = [
        sys.executable, "-m", "pyserini.index.lucene",
        "--collection", "JsonCollection",
        "--input", corpus_save_dir,
        "--index", index_save_dir,
        "--generator", "DefaultLuceneDocumentGenerator",
        "--threads", str(threads),
        "--storePositions", "--storeDocvectors", "--storeRaw",
    ]
    # An argument list needs no shell quoting, so paths with spaces are safe.
    # check=True raises on failure instead of moving on to search a missing
    # or incomplete index.
    subprocess.run(cmd, check=True)
def search(index_save_dir: str, queries_save_dir: str, lang: str, result_save_path: str,
           hits: int = 1000, batch_size: int = 128, threads: int = 16):
    """Run Pyserini BM25 search for one language's TSV topics file.

    Args:
        index_save_dir: Directory of the Lucene index built by ``index``.
        queries_save_dir: Directory holding ``{lang}.tsv`` topics files.
        lang: Language code selecting the topics file.
        result_save_path: File the search results are written to.
        hits: Number of hits retrieved per query (default 1000).
        batch_size: Queries per search batch (default 128).
        threads: Number of search threads (default 16).

    Raises:
        subprocess.CalledProcessError: If the search command exits with a
            non-zero status.
    """
    queries_save_path = os.path.join(queries_save_dir, f"{lang}.tsv")
    # Note: Use `--lang {lang}` will cause the performance degradation, since the query and corpus are in different languages.
    cmd = [
        sys.executable, "-m", "pyserini.search.lucene",
        "--index", index_save_dir,
        "--topics", queries_save_path,
        "--output", result_save_path,
        "--bm25",
        "--hits", str(hits),
        "--batch-size", str(batch_size),
        "--threads", str(threads),
    ]
    # An argument list avoids shell quoting problems. check=True raises on a
    # failed search instead of silently leaving the result file missing.
    subprocess.run(cmd, check=True)
def main():
    """Run the BM25 baseline end to end.

    Exports the corpus, builds the Lucene index, then for each MKQA language
    writes a TSV topics file and searches it. Results are saved to
    ``./search_results/bm25/{lang}.txt``; intermediate artifacts (corpus,
    index, queries) live under ``./bm25_baseline``.
    """
    bm25_dir = './bm25_baseline'
    qa_data_dir = '../qa_data'

    result_save_dir = os.path.join('./search_results', 'bm25')
    # exist_ok=True replaces the check-then-create pattern and avoids its race.
    os.makedirs(result_save_dir, exist_ok=True)

    corpus_save_dir = os.path.join(bm25_dir, 'corpus')
    os.makedirs(corpus_save_dir, exist_ok=True)
    corpus_save_path = os.path.join(corpus_save_dir, 'corpus.jsonl')
    generate_corpus(corpus_save_path)

    index_save_dir = os.path.join(bm25_dir, 'index')
    os.makedirs(index_save_dir, exist_ok=True)
    # Unlike the corpus and query files, indexing is re-run on every invocation.
    index(corpus_save_dir, index_save_dir)

    queries_save_dir = os.path.join(bm25_dir, 'queries')
    os.makedirs(queries_save_dir, exist_ok=True)

    languages = ['ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
    for lang in languages:
        generate_queries(qa_data_dir, lang, queries_save_dir)
        result_save_path = os.path.join(result_save_dir, f'{lang}.txt')
        search(index_save_dir, queries_save_dir, lang, result_save_path)
if __name__ == "__main__":
    # Run the full pipeline only when executed as a script, not on import.
    main()