"""Langchain Embedding Wrapper Module."""

from typing import TYPE_CHECKING, List, Optional

from llama_index.bridge.pydantic import PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding

if TYPE_CHECKING:
    from llama_index.bridge.langchain import Embeddings as LCEmbeddings


class LangchainEmbedding(BaseEmbedding):
    """External embeddings (taken from Langchain).

    Args:
        langchain_embeddings (langchain.embeddings.Embeddings): Langchain
            embeddings class.
    """

    _langchain_embedding: "LCEmbeddings" = PrivateAttr()
    _async_not_implemented_warned: bool = PrivateAttr(default=False)

    def __init__(
        self,
        langchain_embeddings: "LCEmbeddings",
        model_name: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
    ):
        # Attempt to infer a useful model name from the wrapped embedding.
        if model_name is None:
            if hasattr(langchain_embeddings, "model_name"):
                model_name = langchain_embeddings.model_name
            elif hasattr(langchain_embeddings, "model"):
                model_name = langchain_embeddings.model
            else:
                model_name = type(langchain_embeddings).__name__

        self._langchain_embedding = langchain_embeddings
        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model_name=model_name,
        )

    @classmethod
    def class_name(cls) -> str:
        return "LangchainEmbedding"

    def _async_not_implemented_warn_once(self) -> None:
        """Warn once that async embeddings are unavailable for this model."""
        if not self._async_not_implemented_warned:
            print("Async embedding not available, falling back to sync method.")
            self._async_not_implemented_warned = True

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self._langchain_embedding.embed_query(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding asynchronously, falling back to sync if unsupported."""
        try:
            return await self._langchain_embedding.aembed_query(query)
        except NotImplementedError:
            # Warn the user that the sync method is being used instead.
            self._async_not_implemented_warn_once()
            return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Get text embedding asynchronously, falling back to sync if unsupported."""
        try:
            embeds = await self._langchain_embedding.aembed_documents([text])
            return embeds[0]
        except NotImplementedError:
            # Warn the user that the sync method is being used instead.
            self._async_not_implemented_warn_once()
            return self._get_text_embedding(text)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self._langchain_embedding.embed_documents([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        return self._langchain_embedding.embed_documents(texts)
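

# Usage sketch (illustrative, not part of the original module): any LangChain
# ``Embeddings`` implementation can be wrapped so it works wherever LlamaIndex
# expects a ``BaseEmbedding``. ``OpenAIEmbeddings`` below is only an example and
# assumes the optional ``langchain`` package and an OpenAI API key are available.
#
#     from langchain.embeddings import OpenAIEmbeddings
#
#     embed_model = LangchainEmbedding(OpenAIEmbeddings())
#     query_vector = embed_model.get_query_embedding("hello world")
#     doc_vector = embed_model.get_text_embedding("hello world document")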