# faiss_rag_enterprise/llama_index/embeddings/huggingface_optimum.py
#
# 196 lines
# 7.0 KiB
# Python

from typing import Any, List, Optional
from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding
from llama_index.embeddings.huggingface_utils import (
format_query,
format_text,
get_pooling_mode,
)
from llama_index.embeddings.pooling import Pooling
from llama_index.utils import infer_torch_device
class OptimumEmbedding(BaseEmbedding):
    """Embedding model that runs an Optimum ONNX Runtime feature extractor.

    The model and tokenizer are loaded from ``folder_name`` unless ready-made
    objects are supplied.  Token embeddings returned by the model are reduced
    to one vector per input with CLS or mean pooling and, optionally,
    L2-normalized.
    """

    folder_name: str = Field(description="Folder name to load from.")
    max_length: int = Field(description="Maximum length of input.")
    pooling: str = Field(description="Pooling strategy. One of ['cls', 'mean'].")
    # Declared as ``bool``: with a ``str`` annotation the validator may coerce
    # ``False`` into the truthy string "False", making normalization
    # impossible to switch off.
    normalize: bool = Field(default=True, description="Normalize embeddings or not.")
    query_instruction: Optional[str] = Field(
        default=None, description="Instruction to prepend to query text."
    )
    text_instruction: Optional[str] = Field(
        default=None, description="Instruction to prepend to text."
    )
    # Never passed to ``super().__init__``, so it needs an explicit default.
    cache_folder: Optional[str] = Field(
        default=None, description="Cache folder for huggingface files."
    )

    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    _device: Any = PrivateAttr()

    def __init__(
        self,
        folder_name: str,
        pooling: Optional[str] = None,
        max_length: Optional[int] = None,
        normalize: bool = True,
        query_instruction: Optional[str] = None,
        text_instruction: Optional[str] = None,
        model: Optional[Any] = None,
        tokenizer: Optional[Any] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        device: Optional[str] = None,
    ):
        """Load (or adopt) the model and tokenizer and resolve the settings.

        Args:
            folder_name: Location handed to ``from_pretrained`` for both the
                model and the tokenizer when they are not supplied.
            pooling: Pooling strategy name; when falsy it is looked up with
                ``get_pooling_mode``.
            max_length: Tokenizer truncation length; when ``None`` it is read
                from ``model.config.max_position_embeddings``.
            normalize: Whether embeddings are L2-normalized.
            query_instruction: Instruction forwarded to ``format_query``.
            text_instruction: Instruction forwarded to ``format_text``.
            model: Pre-built model used instead of loading from
                ``folder_name``.
            tokenizer: Pre-built tokenizer used instead of loading from
                ``folder_name``.
            embed_batch_size: Forwarded to the base class.
            callback_manager: Forwarded to the base class.
            device: Device stored on the instance; when falsy it is obtained
                from ``infer_torch_device()``.

        Raises:
            ImportError: If ``optimum`` or ``transformers`` cannot be imported.
            ValueError: If ``max_length`` is not given and cannot be read from
                the model config.
            NotImplementedError: If ``pooling`` is not a ``Pooling`` value.
        """
        try:
            from optimum.onnxruntime import ORTModelForFeatureExtraction
            from transformers import AutoTokenizer
        except ImportError as exc:
            raise ImportError(
                "OptimumEmbedding requires transformers to be installed.\n"
                "Please install transformers with "
                "`pip install transformers optimum[exporters]`."
            ) from exc

        # Explicit ``is None`` checks: a supplied object must be kept even if
        # it happens to be falsy (tokenizers define ``__len__``).
        if model is None:
            self._model = ORTModelForFeatureExtraction.from_pretrained(folder_name)
        else:
            self._model = model
        if tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(folder_name)
        else:
            self._tokenizer = tokenizer
        self._device = device or infer_torch_device()

        if max_length is None:
            try:
                max_length = int(self._model.config.max_position_embeddings)
            except Exception as exc:
                raise ValueError(
                    "Unable to find max_length from model config. "
                    "Please provide max_length."
                ) from exc

        if not pooling:
            # NOTE(review): this passes the optional ``model`` argument (often
            # None), not ``folder_name``; confirm what get_pooling_mode expects.
            pooling = get_pooling_mode(model)
        try:
            pooling = Pooling(pooling)
        except ValueError as exc:
            raise NotImplementedError(
                f"Pooling {pooling} unsupported, please pick one in"
                f" {[p.value for p in Pooling]}."
            ) from exc

        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            folder_name=folder_name,
            max_length=max_length,
            pooling=pooling,
            normalize=normalize,
            query_instruction=query_instruction,
            text_instruction=text_instruction,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier string for this class."""
        return "OptimumEmbedding"

    @classmethod
    def create_and_save_optimum_model(
        cls,
        model_name_or_path: str,
        output_path: str,
        export_kwargs: Optional[dict] = None,
    ) -> None:
        """Load a model with ``export=True`` and save it with its tokenizer.

        Args:
            model_name_or_path: Source passed to ``from_pretrained`` for both
                the model and the tokenizer.
            output_path: Directory both are written to via ``save_pretrained``.
            export_kwargs: Extra keyword arguments for the model's
                ``from_pretrained`` call.

        Raises:
            ImportError: If ``optimum`` or ``transformers`` cannot be imported.
        """
        try:
            from optimum.onnxruntime import ORTModelForFeatureExtraction
            from transformers import AutoTokenizer
        except ImportError as exc:
            raise ImportError(
                "OptimumEmbedding requires transformers to be installed.\n"
                "Please install transformers with "
                "`pip install transformers optimum[exporters]`."
            ) from exc

        export_kwargs = export_kwargs or {}
        model = ORTModelForFeatureExtraction.from_pretrained(
            model_name_or_path, export=True, **export_kwargs
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

        model.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)
        print(
            f"Saved optimum model to {output_path}. Use it with "
            f"`embed_model = OptimumEmbedding(folder_name='{output_path}')`."
        )

    def _mean_pooling(self, model_output: Any, attention_mask: Any) -> Any:
        """Mean Pooling - Take attention mask into account for correct averaging.

        Args:
            model_output: Model result whose first element holds the token
                embeddings.
            attention_mask: Tokenizer attention mask for the same batch.

        Returns:
            The mask-weighted sum of token embeddings over dimension 1 divided
            by the mask sum (clamped to at least 1e-9).
        """
        import torch

        # First element of model_output contains all token embeddings
        token_embeddings = model_output[0]
        # Put the mask on the embeddings' own device: the model output may not
        # live on the independently inferred ``self._device``.
        input_mask_expanded = (
            attention_mask.to(token_embeddings.device)
            .unsqueeze(-1)
            .expand(token_embeddings.size())
            .float()
        )
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp so an all-zero mask row cannot cause a division by zero.
        mask_totals = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / mask_totals

    def _cls_pooling(self, model_output: list) -> Any:
        """Use the CLS token as the pooling token."""
        # Index 0 along the second dimension of the token embeddings.
        return model_output[0][:, 0]

    def _embed(self, sentences: List[str]) -> List[List[float]]:
        """Embed sentences.

        Tokenizes with padding and truncation to ``max_length``, runs the
        model, pools the output and optionally L2-normalizes it.

        Args:
            sentences: Texts to embed as one batch.

        Returns:
            One embedding (list of floats) per input sentence.
        """
        encoded_input = self._tokenizer(
            sentences,
            padding=True,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        # pop token_type_ids
        encoded_input.pop("token_type_ids", None)

        model_output = self._model(**encoded_input)

        if self.pooling == "cls":
            embeddings = self._cls_pooling(model_output)
        else:
            # _mean_pooling aligns the mask with the output's device itself.
            embeddings = self._mean_pooling(
                model_output, encoded_input["attention_mask"]
            )

        if self.normalize:
            import torch

            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings.tolist()

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        query = format_query(query, self.model_name, self.query_instruction)
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding async."""
        # Delegates to the synchronous implementation.
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Get text embedding async."""
        # Delegates to the synchronous implementation.
        return self._get_text_embedding(text)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        text = format_text(text, self.model_name, self.text_instruction)
        return self._embed([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        texts = [
            format_text(text, self.model_name, self.text_instruction) for text in texts
        ]
        return self._embed(texts)