"""
|
|
# 1. Output Search Results with BM25
|
|
python bm25_baseline.py
|
|
|
|
# 2. Print and Save Evaluation Results
|
|
python step2-eval_sparse_mldr.py \
|
|
--encoder bm25 \
|
|
--languages ar de es fr hi it ja ko pt ru th en zh \
|
|
--search_result_save_dir ./search_results \
|
|
--qrels_dir ../qrels \
|
|
--eval_result_save_dir ./eval_results \
|
|
--metrics ndcg@10
|
|
"""
import os

import datasets
from tqdm import tqdm


def generate_corpus(lang: str, corpus_save_dir: str):
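    """Download the MLDR corpus for `lang` and save it as a Pyserini
    JsonCollection: one JSON object per line with 'id' and 'contents',
    the field names DefaultLuceneDocumentGenerator expects.
    Skips the download if the corpus file already exists.
    """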
    corpus_save_path = os.path.join(corpus_save_dir, 'corpus.jsonl')
    if os.path.exists(corpus_save_path):
        return

    corpus = datasets.load_dataset('Shitao/MLDR', f'corpus-{lang}', split='corpus')

    corpus_list = [{'id': e['docid'], 'contents': e['text']} for e in tqdm(corpus, desc="Generating corpus")]
    corpus = datasets.Dataset.from_list(corpus_list)

    corpus.to_json(corpus_save_path, force_ascii=False)


def generate_queries(lang: str, queries_save_dir: str, split: str = 'test'):
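    """Dump the MLDR queries for `lang` (given split) to a TSV topics file
    that Pyserini's search CLI can read: one `query_id <tab> query` per line.
    Tabs and newlines inside the query text are replaced with spaces first,
    so each query is guaranteed to stay on a single line.
    """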
    queries_save_path = os.path.join(queries_save_dir, f"{lang}.tsv")
    if os.path.exists(queries_save_path):
        return

    dataset = datasets.load_dataset('Shitao/MLDR', lang, split=split)

    queries_list = []
    for data in dataset:
        queries_list.append({
            'id': data['query_id'],
            'content': data['query'].replace('\n', ' ').replace('\t', ' ')
        })
    with open(queries_save_path, 'w', encoding='utf-8') as f:
        for query in queries_list:
            assert '\n' not in query['content'] and '\t' not in query['content']
            line = f"{query['id']}\t{query['content']}"
            f.write(line + '\n')


def index(lang: str, corpus_save_dir: str, index_save_dir: str):
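    """Build a Lucene index over the JsonCollection in `corpus_save_dir`,
    shelling out to `pyserini.index.lucene`. The --language flag selects a
    language-specific analyzer so tokenization matches the target language.
    """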
    cmd = f"python -m pyserini.index.lucene \
        --language {lang} \
        --collection JsonCollection \
        --input {corpus_save_dir} \
        --index {index_save_dir} \
        --generator DefaultLuceneDocumentGenerator \
        --threads 1 --optimize \
    "
    os.system(cmd)


def search(index_save_dir: str, queries_save_dir: str, lang: str, result_save_path: str):
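    """Run BM25 retrieval over the index with `pyserini.search.lucene`,
    writing the top-1000 hits per query to `result_save_path` (Pyserini's
    default TREC run format, which the step-2 evaluation script consumes).
    """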
    queries_save_path = os.path.join(queries_save_dir, f"{lang}.tsv")
    cmd = f"python -m pyserini.search.lucene \
        --language {lang} \
        --index {index_save_dir} \
        --topics {queries_save_path} \
        --output {result_save_path} \
        --bm25 \
        --hits 1000 \
        --batch-size 128 \
        --threads 16 \
    "
    os.system(cmd)


def main():
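    """For each MLDR language: generate the corpus, build the Lucene index,
    generate the test queries, and save the BM25 search results to
    ./search_results/bm25/{lang}.txt for the step-2 evaluation script.
    """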
    bm25_dir = './bm25_baseline'

    result_save_dir = os.path.join('./search_results', 'bm25')
    if not os.path.exists(result_save_dir):
        os.makedirs(result_save_dir)

    for lang in ['ar', 'de', 'en', 'es', 'fr', 'hi', 'it', 'ja', 'ko', 'pt', 'ru', 'th', 'zh']:
        save_dir = os.path.join(bm25_dir, lang)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        corpus_save_dir = os.path.join(save_dir, 'corpus')
        if not os.path.exists(corpus_save_dir):
            os.makedirs(corpus_save_dir)
        generate_corpus(lang, corpus_save_dir)

        index_save_dir = os.path.join(save_dir, 'index')
        if not os.path.exists(index_save_dir):
            os.makedirs(index_save_dir)
        index(lang, corpus_save_dir, index_save_dir)

        generate_queries(lang, save_dir, split='test')

        result_save_path = os.path.join(result_save_dir, f'{lang}.txt')
        search(index_save_dir, save_dir, lang, result_save_path)


if __name__ == '__main__':
    main()