from typing import Any, List, Optional

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding
from llama_index.embeddings.huggingface_utils import (
    format_query,
    format_text,
    get_pooling_mode,
)
from llama_index.embeddings.pooling import Pooling
from llama_index.utils import infer_torch_device


class OptimumEmbedding(BaseEmbedding):
    """Embedding model backed by an ONNX export loaded via Hugging Face Optimum."""

    folder_name: str = Field(description="Folder name to load from.")
    max_length: int = Field(description="Maximum length of input.")
    pooling: str = Field(description="Pooling strategy. One of ['cls', 'mean'].")
    normalize: bool = Field(default=True, description="Normalize embeddings or not.")
    query_instruction: Optional[str] = Field(
        description="Instruction to prepend to query text."
    )
    text_instruction: Optional[str] = Field(
        description="Instruction to prepend to text."
    )
    cache_folder: Optional[str] = Field(
        description="Cache folder for Hugging Face files."
    )

    _model: Any = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    _device: Any = PrivateAttr()

    def __init__(
        self,
        folder_name: str,
        pooling: Optional[str] = None,
        max_length: Optional[int] = None,
        normalize: bool = True,
        query_instruction: Optional[str] = None,
        text_instruction: Optional[str] = None,
        model: Optional[Any] = None,
        tokenizer: Optional[Any] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        device: Optional[str] = None,
    ):
        try:
            from optimum.onnxruntime import ORTModelForFeatureExtraction
            from transformers import AutoTokenizer
        except ImportError:
            raise ImportError(
                "OptimumEmbedding requires transformers and optimum to be installed.\n"
                "Please install them with "
                "`pip install transformers optimum[exporters]`."
            )

        self._model = model or ORTModelForFeatureExtraction.from_pretrained(
            folder_name
        )
        self._tokenizer = tokenizer or AutoTokenizer.from_pretrained(folder_name)
        self._device = device or infer_torch_device()

        if max_length is None:
            try:
                max_length = int(self._model.config.max_position_embeddings)
            except Exception:
                raise ValueError(
                    "Unable to find max_length from model config. "
                    "Please provide max_length."
                )

        if not pooling:
            pooling = get_pooling_mode(model)
        try:
            pooling = Pooling(pooling)
        except ValueError as exc:
            raise NotImplementedError(
                f"Pooling {pooling} unsupported, please pick one in"
                f" {[p.value for p in Pooling]}."
            ) from exc

        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            folder_name=folder_name,
            max_length=max_length,
            pooling=pooling,
            normalize=normalize,
            query_instruction=query_instruction,
            text_instruction=text_instruction,
        )

    @classmethod
    def class_name(cls) -> str:
        return "OptimumEmbedding"

    @classmethod
    def create_and_save_optimum_model(
        cls,
        model_name_or_path: str,
        output_path: str,
        export_kwargs: Optional[dict] = None,
    ) -> None:
        try:
            from optimum.onnxruntime import ORTModelForFeatureExtraction
            from transformers import AutoTokenizer
        except ImportError:
            raise ImportError(
                "OptimumEmbedding requires transformers and optimum to be installed.\n"
                "Please install them with "
                "`pip install transformers optimum[exporters]`."
            )

        export_kwargs = export_kwargs or {}
        model = ORTModelForFeatureExtraction.from_pretrained(
            model_name_or_path, export=True, **export_kwargs
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

        model.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)
        print(
            f"Saved optimum model to {output_path}. Use it with "
            f"`embed_model = OptimumEmbedding(folder_name='{output_path}')`."
        )

    def _mean_pooling(self, model_output: Any, attention_mask: Any) -> Any:
        """Mean pooling - take the attention mask into account for correct averaging."""
        import torch

        # First element of model_output contains all token embeddings.
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def _cls_pooling(self, model_output: list) -> Any:
        """Use the CLS token as the pooled representation."""
        return model_output[0][:, 0]

    def _embed(self, sentences: List[str]) -> List[List[float]]:
        """Embed sentences."""
        encoded_input = self._tokenizer(
            sentences,
            padding=True,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        # Drop token_type_ids if present.
        encoded_input.pop("token_type_ids", None)

        model_output = self._model(**encoded_input)

        if self.pooling == "cls":
            embeddings = self._cls_pooling(model_output)
        else:
            embeddings = self._mean_pooling(
                model_output, encoded_input["attention_mask"].to(self._device)
            )

        if self.normalize:
            import torch

            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings.tolist()

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        query = format_query(query, self.model_name, self.query_instruction)
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding async."""
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Get text embedding async."""
        return self._get_text_embedding(text)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        text = format_text(text, self.model_name, self.text_instruction)
        return self._embed([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        texts = [
            format_text(text, self.model_name, self.text_instruction)
            for text in texts
        ]
        return self._embed(texts)
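

# A minimal usage sketch: export an ONNX model once with
# `create_and_save_optimum_model`, then load the exported folder for embedding.
# The checkpoint name "BAAI/bge-small-en-v1.5" and the "./bge_onnx" output
# directory below are illustrative assumptions, not part of this module.
if __name__ == "__main__":
    OptimumEmbedding.create_and_save_optimum_model(
        "BAAI/bge-small-en-v1.5", "./bge_onnx"
    )

    # Load the exported model and embed a sample string via the public
    # BaseEmbedding interface.
    embed_model = OptimumEmbedding(folder_name="./bge_onnx")
    embedding = embed_model.get_text_embedding("Hello, world!")
    print(f"Embedding dimension: {len(embedding)}")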