faiss_rag_enterprise/llama_index/embeddings/gradient.py

import logging
from typing import Any, List, Optional

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.core.embeddings.base import (
    DEFAULT_EMBED_BATCH_SIZE,
    BaseEmbedding,
    Embedding,
)

logger = logging.getLogger(__name__)

# For the bge models that Gradient AI provides, it is suggested to prepend
# the instruction below to queries used for retrieval.
# Reference: https://huggingface.co/BAAI/bge-large-en-v1.5#model-list
QUERY_INSTRUCTION_FOR_RETRIEVAL = (
    "Represent this sentence for searching relevant passages:"
)
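# The query-embedding methods below prepend it, producing strings like:
#   "Represent this sentence for searching relevant passages: What is RAG?"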

GRADIENT_EMBED_BATCH_SIZE: int = 32_768


class GradientEmbedding(BaseEmbedding):
    """GradientAI embedding models.

    This class provides an interface for generating embeddings with a model
    deployed in Gradient AI. On initialization it requires the slug
    (`gradient_model_slug`) of a model deployed in the cluster.

    Note:
        Requires the `gradientai` package to be available in the PYTHONPATH.
        It can be installed with `pip install gradientai`.
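
    Example:
        A minimal, illustrative sketch ("bge-large" is a placeholder slug;
        credentials fall back to the `GRADIENT_ACCESS_TOKEN` and
        `GRADIENT_WORKSPACE_ID` environment variables)::

            embed_model = GradientEmbedding(gradient_model_slug="bge-large")
            vector = embed_model.get_query_embedding("What is RAG?")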
"""

    embed_batch_size: int = Field(default=GRADIENT_EMBED_BATCH_SIZE, gt=0)

    _gradient: Any = PrivateAttr()
    _model: Any = PrivateAttr()

    @classmethod
    def class_name(cls) -> str:
        return "GradientEmbedding"

    def __init__(
        self,
        *,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        gradient_model_slug: str,
        gradient_access_token: Optional[str] = None,
        gradient_workspace_id: Optional[str] = None,
        gradient_host: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
"""Initializes the GradientEmbedding class.
During the initialization the `gradientai` package is imported. Using the access token,
workspace id and the slug of the model, the model is fetched from Gradient AI and prepared to use.
Args:
embed_batch_size (int, optional): The batch size for embedding generation. Defaults to 10,
must be > 0 and <= 100.
gradient_model_slug (str): The model slug of the model in the Gradient AI account.
gradient_access_token (str, optional): The access token of the Gradient AI account, if
`None` read from the environment variable `GRADIENT_ACCESS_TOKEN`.
gradient_workspace_id (str, optional): The workspace ID of the Gradient AI account, if `None`
read from the environment variable `GRADIENT_WORKSPACE_ID`.
gradient_host (str, optional): The host of the Gradient AI API. Defaults to None, which
means the default host is used.
Raises:
ImportError: If the `gradientai` package is not available in the PYTHONPATH.
ValueError: If the model cannot be fetched from Gradient AI.
"""
        if embed_batch_size <= 0:
            raise ValueError(f"Embed batch size {embed_batch_size} must be > 0.")

        # Import lazily so the dependency is only required when this class is used.
        try:
            import gradientai
        except ImportError:
            raise ImportError("GradientEmbedding requires `pip install gradientai`.")

        self._gradient = gradientai.Gradient(
            access_token=gradient_access_token,
            workspace_id=gradient_workspace_id,
            host=gradient_host,
        )

        try:
            self._model = self._gradient.get_embeddings_model(slug=gradient_model_slug)
        except gradientai.openapi.client.exceptions.UnauthorizedException as e:
            logger.error(f"Error while loading model {gradient_model_slug}.")
            # Release the client before surfacing the error.
            self._gradient.close()
            raise ValueError("Unable to fetch the requested embeddings model") from e

        super().__init__(
            embed_batch_size=embed_batch_size, model_name=gradient_model_slug, **kwargs
        )

    async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """Embed the input sequence of text asynchronously."""
        inputs = [{"input": text} for text in texts]

        # Await the coroutine before reading `.embeddings`; without the
        # parentheses this would try to await the attribute of a coroutine.
        result = (await self._model.aembed(inputs=inputs)).embeddings

        return [e.embedding for e in result]

    def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """Embed the input sequence of text."""
        inputs = [{"input": text} for text in texts]

        result = self._model.embed(inputs=inputs).embeddings

        return [e.embedding for e in result]
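
    # Note: external callers typically go through BaseEmbedding's public
    # get_text_embedding_batch(), which splits the input into chunks of
    # embed_batch_size before delegating to the two methods above.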

    def _get_text_embedding(self, text: str) -> Embedding:
        """Alias for _get_text_embeddings() with a single text input."""
        return self._get_text_embeddings([text])[0]

    async def _aget_text_embedding(self, text: str) -> Embedding:
        """Alias for _aget_text_embeddings() with a single text input."""
        embedding = await self._aget_text_embeddings([text])
        return embedding[0]

    async def _aget_query_embedding(self, query: str) -> Embedding:
        # Prepend the bge retrieval instruction to the query (see note above).
        embedding = await self._aget_text_embeddings(
            [f"{QUERY_INSTRUCTION_FOR_RETRIEVAL} {query}"]
        )
        return embedding[0]

    def _get_query_embedding(self, query: str) -> Embedding:
        return self._get_text_embeddings(
            [f"{QUERY_INSTRUCTION_FOR_RETRIEVAL} {query}"]
        )[0]
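

# Usage sketch for the async path, under stated assumptions: "bge-large" is a
# placeholder slug for an embeddings model deployed in your Gradient account,
# and credentials are read from the GRADIENT_ACCESS_TOKEN and
# GRADIENT_WORKSPACE_ID environment variables. aget_query_embedding() is the
# public async entry point inherited from BaseEmbedding.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        embed_model = GradientEmbedding(gradient_model_slug="bge-large")
        vector = await embed_model.aget_query_embedding("What is RAG?")
        print(f"embedding dimension: {len(vector)}")

    asyncio.run(_demo())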