from typing import Any, List, Optional

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding
from llama_index.embeddings.huggingface_utils import (
    format_query,
    format_text,
    get_pooling_mode,
)
from llama_index.embeddings.pooling import Pooling
from llama_index.utils import infer_torch_device


class OptimumEmbedding(BaseEmbedding):
    """Embedding model backed by an ONNX model exported with Optimum."""

    folder_name: str = Field(description="Folder name to load from.")
    max_length: int = Field(description="Maximum length of input.")
    pooling: str = Field(description="Pooling strategy. One of ['cls', 'mean'].")
    normalize: bool = Field(default=True, description="Normalize embeddings or not.")
    query_instruction: Optional[str] = Field(
        description="Instruction to prepend to query text."
    )
    text_instruction: Optional[str] = Field(
        description="Instruction to prepend to text."
    )
    cache_folder: Optional[str] = Field(
        description="Cache folder for huggingface files."
    )

    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    _device: Any = PrivateAttr()

    def __init__(
        self,
        folder_name: str,
        pooling: Optional[str] = None,
        max_length: Optional[int] = None,
        normalize: bool = True,
        query_instruction: Optional[str] = None,
        text_instruction: Optional[str] = None,
        model: Optional[Any] = None,
        tokenizer: Optional[Any] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        device: Optional[str] = None,
    ):
        try:
            from optimum.onnxruntime import ORTModelForFeatureExtraction
            from transformers import AutoTokenizer
        except ImportError:
            raise ImportError(
                "OptimumEmbedding requires transformers to be installed.\n"
                "Please install transformers with "
                "`pip install transformers optimum[exporters]`."
            )

        self._model = model or ORTModelForFeatureExtraction.from_pretrained(folder_name)
        self._tokenizer = tokenizer or AutoTokenizer.from_pretrained(folder_name)
        self._device = device or infer_torch_device()

        if max_length is None:
            try:
                max_length = int(self._model.config.max_position_embeddings)
            except Exception:
                raise ValueError(
                    "Unable to find max_length from model config. "
                    "Please provide max_length."
                )

        if not pooling:
            pooling = get_pooling_mode(model)
        try:
            pooling = Pooling(pooling)
        except ValueError as exc:
            raise NotImplementedError(
                f"Pooling {pooling} unsupported, please pick one in"
                f" {[p.value for p in Pooling]}."
            ) from exc

        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            folder_name=folder_name,
            max_length=max_length,
            pooling=pooling,
            normalize=normalize,
            query_instruction=query_instruction,
            text_instruction=text_instruction,
        )

    @classmethod
    def class_name(cls) -> str:
        return "OptimumEmbedding"

    @classmethod
    def create_and_save_optimum_model(
        cls,
        model_name_or_path: str,
        output_path: str,
        export_kwargs: Optional[dict] = None,
    ) -> None:
        try:
            from optimum.onnxruntime import ORTModelForFeatureExtraction
            from transformers import AutoTokenizer
        except ImportError:
            raise ImportError(
                "OptimumEmbedding requires transformers to be installed.\n"
                "Please install transformers with "
                "`pip install transformers optimum[exporters]`."
            )

        export_kwargs = export_kwargs or {}
        model = ORTModelForFeatureExtraction.from_pretrained(
            model_name_or_path, export=True, **export_kwargs
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

        model.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)
        print(
            f"Saved optimum model to {output_path}. Use it with "
            f"`embed_model = OptimumEmbedding(folder_name='{output_path}')`."
        )

    def _mean_pooling(self, model_output: Any, attention_mask: Any) -> Any:
        """Mean Pooling - Take attention mask into account for correct averaging."""
        import torch

        # First element of model_output contains all token embeddings
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def _cls_pooling(self, model_output: list) -> Any:
        """Use the CLS token as the pooling token."""
        return model_output[0][:, 0]

    def _embed(self, sentences: List[str]) -> List[List[float]]:
        """Embed sentences."""
        encoded_input = self._tokenizer(
            sentences,
            padding=True,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        # pop token_type_ids
        encoded_input.pop("token_type_ids", None)

        model_output = self._model(**encoded_input)

        if self.pooling == "cls":
            embeddings = self._cls_pooling(model_output)
        else:
            embeddings = self._mean_pooling(
                model_output, encoded_input["attention_mask"].to(self._device)
            )

        if self.normalize:
            import torch

            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings.tolist()

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        query = format_query(query, self.model_name, self.query_instruction)
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding async."""
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Get text embedding async."""
        return self._get_text_embedding(text)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        text = format_text(text, self.model_name, self.text_instruction)
        return self._embed([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        texts = [
            format_text(text, self.model_name, self.text_instruction) for text in texts
        ]
        return self._embed(texts)
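

# Hedged usage sketch (illustrative, not part of the original module): this shows
# one way to exercise the class above, following the hint printed by
# `create_and_save_optimum_model`. The model id "BAAI/bge-small-en-v1.5" and the
# output folder "./bge_onnx" are placeholder choices, and `get_text_embedding` is
# the public wrapper that `BaseEmbedding` provides around `_get_text_embedding`.
if __name__ == "__main__":
    # One-time export: convert a Hugging Face checkpoint to ONNX via Optimum and
    # save the model plus tokenizer into a local folder.
    OptimumEmbedding.create_and_save_optimum_model(
        "BAAI/bge-small-en-v1.5", "./bge_onnx"
    )

    # Load the exported model back from that folder and embed a piece of text.
    embed_model = OptimumEmbedding(folder_name="./bge_onnx")
    embedding = embed_model.get_text_embedding("Hello, world!")
    print(f"Embedding dimension: {len(embedding)}")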