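"""DRES-style dense embedding model for retrieval evaluation.

Wraps a Hugging Face encoder and tokenizer to embed queries and corpus
documents as dense vectors (CLS or mean pooling, optionally L2-normalized).
"""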
from typing import cast, Dict, List, Optional, Union

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, is_torch_npu_available

class FlagDRESModel:
    def __init__(
        self,
        model_name_or_path: Optional[str] = None,
        pooling_method: str = 'cls',
        normalize_embeddings: bool = True,
        query_instruction_for_retrieval: Optional[str] = None,
        batch_size: int = 256,
    ) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        self.query_instruction_for_retrieval = query_instruction_for_retrieval
        self.normalize_embeddings = normalize_embeddings
        self.pooling_method = pooling_method
        self.batch_size = batch_size

        # Prefer CUDA, then an Ascend NPU, and fall back to CPU.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif is_torch_npu_available():
            self.device = torch.device("npu")
        else:
            self.device = torch.device("cpu")
        self.model = self.model.to(self.device)

        # On multi-GPU machines, replicate the model across devices and scale
        # the total batch size so each replica keeps its per-device batch size.
        num_gpus = torch.cuda.device_count()
        if num_gpus > 1:
            self.model = torch.nn.DataParallel(self.model)
            self.batch_size = self.batch_size * num_gpus

    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
        '''
        Encode queries for a retrieval task.
        If an instruction for queries is configured, it is prepended to each query.
        '''
        if self.query_instruction_for_retrieval is not None:
            input_texts = ['{}{}'.format(self.query_instruction_for_retrieval, q) for q in queries]
        else:
            input_texts = queries
        return self.encode(input_texts)

    def encode_corpus(self, corpus: List[Union[Dict[str, str], str]], **kwargs) -> np.ndarray:
        '''
        Encode a corpus for a retrieval task.
        Entries may be plain strings or dicts with 'text' and an optional 'title'.
        '''
        if isinstance(corpus[0], dict):
            input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
        else:
            input_texts = corpus
        return self.encode(input_texts)

    @torch.no_grad()
    def encode(self, sentences: List[str], **kwargs) -> np.ndarray:
        self.model.eval()

        all_embeddings = []
        # Show a progress bar only for large inputs.
        for start_index in tqdm(range(0, len(sentences), self.batch_size),
                                desc="Batches", disable=len(sentences) < 256):
            sentences_batch = sentences[start_index:start_index + self.batch_size]
            inputs = self.tokenizer(
                sentences_batch,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=512,
            ).to(self.device)
            last_hidden_state = self.model(**inputs, return_dict=True).last_hidden_state
            embeddings = self.pooling(last_hidden_state, inputs['attention_mask'])
            if self.normalize_embeddings:
                embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
            embeddings = cast(torch.Tensor, embeddings)
            all_embeddings.append(embeddings.cpu().numpy())

        return np.concatenate(all_embeddings, axis=0)

    def pooling(self,
                last_hidden_state: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None):
        if self.pooling_method == 'cls':
            # Use the first ([CLS]) token's embedding.
            return last_hidden_state[:, 0]
        elif self.pooling_method == 'mean':
            # Average token embeddings, masking out padding positions.
            s = torch.sum(last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
            d = attention_mask.sum(dim=1, keepdim=True).float()
            return s / d
        else:
            raise NotImplementedError(f"Unsupported pooling method: {self.pooling_method}")
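
# A minimal usage sketch under assumptions: the checkpoint name and the query
# instruction below are illustrative examples, not fixed by this class.
if __name__ == "__main__":
    model = FlagDRESModel(
        model_name_or_path="BAAI/bge-small-en-v1.5",
        query_instruction_for_retrieval="Represent this sentence for searching relevant passages: ",
    )
    q_emb = model.encode_queries(["what is dense retrieval?"])
    p_emb = model.encode_corpus([
        {"title": "Dense retrieval",
         "text": "Dense retrieval encodes queries and documents as vectors."},
        {"title": "Sparse retrieval",
         "text": "Sparse methods such as BM25 rely on exact term matching."},
    ])
    # With normalized embeddings, inner products are cosine similarities.
    print(q_emb @ p_emb.T)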