from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np

from app.core.config import settings

class BGEEmbedding:
    """Thin wrapper around a BGE-style sentence-embedding model."""

    def __init__(self):
        self.device = torch.device(settings.DEVICE)
        self.tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)
        self.model = AutoModel.from_pretrained(settings.MODEL_NAME).to(self.device)
        self.model.eval()  # inference only: disables dropout and other training behavior

    def encode(self, texts, batch_size=8):
        """Embed a list of texts; returns an (N, hidden_size) array of L2-normalized vectors."""
        all_embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            with torch.no_grad():
                inputs = self.tokenizer(
                    batch_texts, padding=True, truncation=True, return_tensors="pt"
                ).to(self.device)
                outputs = self.model(**inputs)
                # Mean-pool token embeddings, then L2-normalize so cosine similarity
                # reduces to a dot product.
                embeddings = self.mean_pooling(outputs, inputs["attention_mask"])
                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
                all_embeddings.append(embeddings.cpu().numpy())
        return np.vstack(all_embeddings)

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings, ignoring padded positions via the attention mask."""
        token_embeddings = model_output[0]  # last_hidden_state: (batch, seq_len, hidden)
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

# Module-level instance, created once at import time.
embedder = BGEEmbedding()
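
# Example usage (a sketch; assumes settings.MODEL_NAME points at a BGE-style
# checkpoint such as "BAAI/bge-small-en-v1.5" and settings.DEVICE is "cpu" or "cuda"):
#
#     vectors = embedder.encode(["first passage", "second passage"])
#     vectors.shape        # (2, hidden_size); rows are L2-normalized
#     vectors @ vectors.T  # cosine-similarity matrix, since rows are unit-length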