embed-bge-m3/FlagEmbedding/research/C_MTEB/MLDR/sparse_retrieval/step1-search_results.py

159 lines
5.7 KiB
Python

"""
python step1-search_results.py \
--encoder BAAI/bge-m3 \
--languages ar de en es fr hi it ja ko pt ru th zh \
--encoded_query_and_corpus_save_dir ./encoded_query-and-corpus \
--result_save_dir ./search_results \
--threads 16 \
--hits 1000
"""
import os
from dataclasses import dataclass, field
from transformers import HfArgumentParser
@dataclass
class ModelArgs:
encoder: str = field(
default="BAAI/bge-m3",
metadata={'help': 'Name or path of encoder'}
)
@dataclass
class EvalArgs:
languages: str = field(
default="en",
metadata={'help': 'Languages to evaluate. Avaliable languages: ar de en es fr hi it ja ko pt ru th zh',
"nargs": "+"}
)
encoded_query_and_corpus_save_dir: str = field(
default='./encoded_query-and-corpus',
metadata={'help': 'Dir to save encoded queries and corpus. Encoded queries and corpus are saved in `save_dir/{encoder_name}/{lang}/query_embd.tsv` and `save_dir/{encoder_name}/{lang}/corpus/corpus_embd.jsonl`, individually.'}
)
result_save_dir: str = field(
default='./search_results',
metadata={'help': 'Dir to saving results. Search results will be saved to `result_save_dir/{encoder_name}/{lang}.txt`'}
)
batch_size: int = field(
default=32,
metadata={'help': 'Batch size to use during search'}
)
threads: int = field(
default=1,
metadata={'help': 'Maximum threads to use during search'}
)
hits: int = field(
default=1000,
metadata={'help': 'Number of hits'}
)
overwrite: bool = field(
default=False,
metadata={'help': 'Whether to overwrite embedding'}
)
def check_languages(languages):
if isinstance(languages, str):
languages = [languages]
avaliable_languages = ['ar', 'de', 'en', 'es', 'fr', 'hi', 'it', 'ja', 'ko', 'pt', 'ru', 'th', 'zh']
for lang in languages:
if lang not in avaliable_languages:
raise ValueError(f"Language `{lang}` is not supported. Avaliable languages: {avaliable_languages}")
return languages
def generate_index(lang: str, corpus_embd_dir: str, index_save_dir: str, threads: int=12):
cmd = f"python -m pyserini.index.lucene \
--language {lang} \
--collection JsonVectorCollection \
--input {corpus_embd_dir} \
--index {index_save_dir} \
--generator DefaultLuceneDocumentGenerator \
--threads {threads} \
--impact --pretokenized --optimize \
"
os.system(cmd)
def search_and_save_results(index_save_dir: str, query_embd_path: str, result_save_path: str, batch_size: int = 32, threads: int = 12, hits: int = 1000):
cmd = f"python -m pyserini.search.lucene \
--index {index_save_dir} \
--topics {query_embd_path} \
--output {result_save_path} \
--output-format trec \
--batch {batch_size} \
--threads {threads} \
--hits {hits} \
--impact \
"
os.system(cmd)
def main():
parser = HfArgumentParser([ModelArgs, EvalArgs])
model_args, eval_args = parser.parse_args_into_dataclasses()
model_args: ModelArgs
eval_args: EvalArgs
languages = check_languages(eval_args.languages)
if model_args.encoder[-1] == '/':
model_args.encoder = model_args.encoder[:-1]
encoder = model_args.encoder
if os.path.basename(encoder).startswith('checkpoint-'):
encoder = os.path.dirname(encoder) + '_' + os.path.basename(encoder)
print("==================================================")
print("Start generating search results with model:")
print(model_args.encoder)
print('Generate search results of following languages: ', languages)
for lang in languages:
print("**************************************************")
print(f"Start searching results of {lang} ...")
result_save_path = os.path.join(eval_args.result_save_dir, os.path.basename(encoder), f"{lang}.txt")
if not os.path.exists(os.path.dirname(result_save_path)):
os.makedirs(os.path.dirname(result_save_path))
if os.path.exists(result_save_path) and not eval_args.overwrite:
print(f'Search results of {lang} already exists. Skip...')
continue
encoded_query_and_corpus_save_dir = os.path.join(eval_args.encoded_query_and_corpus_save_dir, os.path.basename(encoder), lang)
if not os.path.exists(encoded_query_and_corpus_save_dir):
raise FileNotFoundError(f"{encoded_query_and_corpus_save_dir} not found")
corpus_embd_dir = os.path.join(encoded_query_and_corpus_save_dir, 'corpus')
index_save_dir = os.path.join(eval_args.encoded_query_and_corpus_save_dir, os.path.basename(encoder), lang, 'index')
if os.path.exists(index_save_dir) and not eval_args.overwrite:
print(f'Index of {lang} already exists')
else:
generate_index(
lang=lang,
corpus_embd_dir=corpus_embd_dir,
index_save_dir=index_save_dir,
threads=eval_args.threads
)
query_embd_path = os.path.join(encoded_query_and_corpus_save_dir, 'query_embd.tsv')
search_and_save_results(
index_save_dir=index_save_dir,
query_embd_path=query_embd_path,
result_save_path=result_save_path,
batch_size=eval_args.batch_size,
threads=eval_args.threads,
hits=eval_args.hits
)
print("==================================================")
print("Finish generating search results with model:")
print(model_args.encoder)
if __name__ == "__main__":
main()