faiss_rag_enterprise/app/core/embedding.py


from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np

from app.core.config import settings


class BGEEmbedding:
    """Wraps a BGE-style sentence-embedding model for batched encoding."""

    def __init__(self):
        self.device = torch.device(settings.DEVICE)
        self.tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)
        self.model = AutoModel.from_pretrained(settings.MODEL_NAME).to(self.device)
        self.model.eval()

    def encode(self, texts, batch_size=8):
        """Encode a list of texts into L2-normalized embeddings (numpy array)."""
        all_embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            with torch.no_grad():
                inputs = self.tokenizer(
                    batch_texts, padding=True, truncation=True, return_tensors="pt"
                ).to(self.device)
                outputs = self.model(**inputs)
                # Mean-pool token embeddings, then L2-normalize so that dot product
                # equals cosine similarity (suits FAISS inner-product indexes).
                embeddings = self.mean_pooling(outputs, inputs["attention_mask"])
                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
            all_embeddings.append(embeddings.cpu().numpy())
        return np.vstack(all_embeddings)

    def mean_pooling(self, model_output, attention_mask):
        # Average last hidden states over non-padding tokens only.
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )


# Module-level singleton shared by the rest of the app.
embedder = BGEEmbedding()
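

# Minimal usage sketch (assumes settings.MODEL_NAME points at a BGE checkpoint
# such as "BAAI/bge-small-en-v1.5"; the __main__ guard and example texts below
# are illustrative only, not part of the app wiring):
if __name__ == "__main__":
    vectors = embedder.encode(["What is FAISS?", "FAISS is a vector search library."])
    print(vectors.shape)        # (2, hidden_size), e.g. (2, 384) for bge-small
    print(vectors @ vectors.T)  # cosine similarities, since rows are L2-normalized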