"""
|
|
python step0-generate_embedding.py \
|
|
--encoder BAAI/bge-m3 \
|
|
--index_save_dir ./corpus-index \
|
|
--max_passage_length 512 \
|
|
--batch_size 256 \
|
|
--fp16 \
|
|
--pooling_method cls \
|
|
--normalize_embeddings True
|
|
"""
|
|
import os
import sys

import faiss
import datasets
import numpy as np
from tqdm import tqdm
from pprint import pprint
from FlagEmbedding import FlagModel
from dataclasses import dataclass, field
from transformers import HfArgumentParser

# Make the sibling `utils` package importable when running this script directly.
sys.path.append("..")

from utils.normalize_text import normalize


@dataclass
class ModelArgs:
    encoder: str = field(
        default="BAAI/bge-m3",
        metadata={'help': 'Name or path of encoder'}
    )
    fp16: bool = field(
        default=True,
        metadata={'help': 'Use fp16 in inference?'}
    )
    pooling_method: str = field(
        default='cls',
        metadata={'help': "Pooling method. Available methods: 'cls', 'mean'"}
    )
    normalize_embeddings: bool = field(
        default=True,
        metadata={'help': 'Normalize embeddings or not'}
    )
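
# Note: 'cls' pooling takes the final-layer [CLS] token embedding as the passage
# vector, while 'mean' averages all token embeddings. bge-style encoders such as
# BAAI/bge-m3 are typically used with 'cls' pooling, hence the default above.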


@dataclass
class EvalArgs:
    index_save_dir: str = field(
        default='./corpus-index',
        metadata={'help': 'Dir to save index. Corpus index will be saved to `index_save_dir/{encoder_name}/index`. Corpus ids will be saved to `index_save_dir/{encoder_name}/docid`.'}
    )
    max_passage_length: int = field(
        default=512,
        metadata={'help': 'Max passage length.'}
    )
    batch_size: int = field(
        default=256,
        metadata={'help': 'Inference batch size.'}
    )
    overwrite: bool = field(
        default=False,
        metadata={'help': 'Whether to overwrite existing embeddings'}
    )


def get_model(model_args: ModelArgs):
    """Load the encoder as a FlagEmbedding FlagModel."""
    model = FlagModel(
        model_args.encoder,
        pooling_method=model_args.pooling_method,
        normalize_embeddings=model_args.normalize_embeddings,
        use_fp16=model_args.fp16
    )
    return model
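
# The same model can later embed queries for retrieval. A minimal sketch, assuming
# FlagModel's query-side encoder (not exercised by this script):
#
#   query_embeddings = model.encode_queries(
#       ["example query"], batch_size=256, max_length=512
#   )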


def parse_corpus(corpus: datasets.Dataset):
    """Lowercase and normalize each passage, keeping its id and title-prefixed text."""
    corpus_list = []
    for data in tqdm(corpus, desc="Generating corpus"):
        _id = str(data['_id'])
        # Concatenate title and body, then apply the shared text normalization.
        content = f"{data['title']}\n{data['text']}".lower()
        content = normalize(content)
        corpus_list.append({"id": _id, "content": content})

    corpus = datasets.Dataset.from_list(corpus_list)
    return corpus


def generate_index(model: FlagModel, corpus: datasets.Dataset, max_passage_length: int=512, batch_size: int=256):
    """Encode the corpus and build an exact inner-product Faiss index over it."""
    corpus_embeddings = model.encode_corpus(corpus["content"], batch_size=batch_size, max_length=max_passage_length)
    dim = corpus_embeddings.shape[-1]

    # With normalized embeddings, inner product is equivalent to cosine similarity.
    # "Flat" is a brute-force (exact) index; train() is a no-op for it.
    faiss_index = faiss.index_factory(dim, "Flat", faiss.METRIC_INNER_PRODUCT)
    corpus_embeddings = corpus_embeddings.astype(np.float32)
    faiss_index.train(corpus_embeddings)
    faiss_index.add(corpus_embeddings)
    return faiss_index, list(corpus["id"])
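
# For corpora too large for brute-force search, an approximate index could be
# swapped in via the factory string. A minimal sketch (hypothetical alternative,
# not what this script does; the nlist/nprobe values would need tuning):
#
#   faiss_index = faiss.index_factory(dim, "IVF4096,Flat", faiss.METRIC_INNER_PRODUCT)
#   faiss_index.train(corpus_embeddings)  # training is required for IVF indexes
#   faiss_index.nprobe = 64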


def save_result(index: faiss.Index, docid: list, index_save_dir: str):
    """Write the Faiss index and the line-aligned docid list to index_save_dir."""
    docid_save_path = os.path.join(index_save_dir, 'docid')
    index_save_path = os.path.join(index_save_dir, 'index')
    # One docid per line, in the same order the vectors were added to the index.
    with open(docid_save_path, 'w', encoding='utf-8') as f:
        for _id in docid:
            f.write(str(_id) + '\n')
    faiss.write_index(index, index_save_path)
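
# A minimal sketch of how a downstream step might load and query the saved
# artifacts (assumed usage, not part of this script; `query_embeddings` is a
# hypothetical float32 array of shape (num_queries, dim)):
#
#   index = faiss.read_index(os.path.join(index_save_dir, 'index'))
#   with open(os.path.join(index_save_dir, 'docid'), 'r', encoding='utf-8') as f:
#       docids = [line.strip() for line in f]
#   scores, indices = index.search(query_embeddings, 10)
#   top_docids = [[docids[i] for i in row] for row in indices]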


def main():
    parser = HfArgumentParser([ModelArgs, EvalArgs])
    model_args, eval_args = parser.parse_args_into_dataclasses()
    model_args: ModelArgs
    eval_args: EvalArgs

    # Strip a trailing slash so os.path.basename() below yields the model name.
    if model_args.encoder[-1] == '/':
        model_args.encoder = model_args.encoder[:-1]

    model = get_model(model_args=model_args)

    # For intermediate checkpoints, fold the run name into the save-dir name,
    # e.g. `runs/my-run/checkpoint-1000` -> `runs/my-run_checkpoint-1000`.
    encoder = model_args.encoder
    if os.path.basename(encoder).startswith('checkpoint-'):
        encoder = os.path.dirname(encoder) + '_' + os.path.basename(encoder)

    print("==================================================")
    print("Start generating embeddings with model:")
    print(model_args.encoder)

    index_save_dir = os.path.join(eval_args.index_save_dir, os.path.basename(encoder))
    if not os.path.exists(index_save_dir):
        os.makedirs(index_save_dir)
    # Skip work that has already been done unless --overwrite is set.
    if os.path.exists(os.path.join(index_save_dir, 'index')) and not eval_args.overwrite:
        print('Embedding already exists. Skip...')
        return

    # The corpus is hard-coded to the BeIR Natural Questions collection; each row
    # carries the '_id', 'title', and 'text' fields consumed by parse_corpus().
    corpus = datasets.load_dataset("BeIR/nq", 'corpus')['corpus']
    corpus = parse_corpus(corpus=corpus)

    index, docid = generate_index(
        model=model,
        corpus=corpus,
        max_passage_length=eval_args.max_passage_length,
        batch_size=eval_args.batch_size
    )
    save_result(index, docid, index_save_dir)
print("==================================================")
|
|
print("Finish generating embeddings with following model:")
|
|
pprint(model_args.encoder)
|
|
|
|
|
|


if __name__ == "__main__":
    main()