102 lines
3.5 KiB
Python
102 lines
3.5 KiB
Python
from typing import Any, List, Optional
|
|
|
|
from llama_index.bridge.pydantic import Field, PrivateAttr
|
|
from llama_index.callbacks import CallbackManager
|
|
from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding
|
|
from llama_index.embeddings.huggingface_utils import (
|
|
DEFAULT_INSTRUCT_MODEL,
|
|
get_query_instruct_for_model_name,
|
|
get_text_instruct_for_model_name,
|
|
)
|
|
|
|
|
|
class InstructorEmbedding(BaseEmbedding):
|
|
query_instruction: Optional[str] = Field(
|
|
description="Instruction to prepend to query text."
|
|
)
|
|
text_instruction: Optional[str] = Field(
|
|
description="Instruction to prepend to text."
|
|
)
|
|
cache_folder: Optional[str] = Field(
|
|
description="Cache folder for huggingface files."
|
|
)
|
|
|
|
_model: Any = PrivateAttr()
|
|
|
|
def __init__(
|
|
self,
|
|
model_name: str = DEFAULT_INSTRUCT_MODEL,
|
|
query_instruction: Optional[str] = None,
|
|
text_instruction: Optional[str] = None,
|
|
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
|
|
cache_folder: Optional[str] = None,
|
|
device: Optional[str] = None,
|
|
callback_manager: Optional[CallbackManager] = None,
|
|
):
|
|
try:
|
|
from InstructorEmbedding import INSTRUCTOR
|
|
except ImportError:
|
|
raise ImportError(
|
|
"InstructorEmbedding requires instructor to be installed.\n"
|
|
"Please install transformers with `pip install InstructorEmbedding`."
|
|
)
|
|
self._model = INSTRUCTOR(model_name, cache_folder=cache_folder, device=device)
|
|
|
|
super().__init__(
|
|
embed_batch_size=embed_batch_size,
|
|
callback_manager=callback_manager,
|
|
model_name=model_name,
|
|
query_instruction=query_instruction,
|
|
text_instruction=text_instruction,
|
|
cache_folder=cache_folder,
|
|
)
|
|
|
|
@classmethod
|
|
def class_name(cls) -> str:
|
|
return "InstructorEmbedding"
|
|
|
|
def _format_query_text(self, query_text: str) -> List[str]:
|
|
"""Format query text."""
|
|
instruction = self.text_instruction
|
|
|
|
if instruction is None:
|
|
instruction = get_query_instruct_for_model_name(self.model_name)
|
|
|
|
return [instruction, query_text]
|
|
|
|
def _format_text(self, text: str) -> List[str]:
|
|
"""Format text."""
|
|
instruction = self.text_instruction
|
|
|
|
if instruction is None:
|
|
instruction = get_text_instruct_for_model_name(self.model_name)
|
|
|
|
return [instruction, text]
|
|
|
|
def _embed(self, instruct_sentence_pairs: List[List[str]]) -> List[List[float]]:
|
|
"""Embed sentences."""
|
|
return self._model.encode(instruct_sentence_pairs).tolist()
|
|
|
|
def _get_query_embedding(self, query: str) -> List[float]:
|
|
"""Get query embedding."""
|
|
query_pair = self._format_query_text(query)
|
|
return self._embed([query_pair])[0]
|
|
|
|
async def _aget_query_embedding(self, query: str) -> List[float]:
|
|
"""Get query embedding async."""
|
|
return self._get_query_embedding(query)
|
|
|
|
async def _aget_text_embedding(self, text: str) -> List[float]:
|
|
"""Get text embedding async."""
|
|
return self._get_text_embedding(text)
|
|
|
|
def _get_text_embedding(self, text: str) -> List[float]:
|
|
"""Get text embedding."""
|
|
text_pair = self._format_text(text)
|
|
return self._embed([text_pair])[0]
|
|
|
|
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
|
|
"""Get text embeddings."""
|
|
text_pairs = [self._format_text(text) for text in texts]
|
|
return self._embed(text_pairs)
|