faiss_rag_enterprise/llama_index/embeddings/llm_rails.py

119 lines
3.7 KiB
Python

import logging
from typing import Any, List
import requests
from requests.adapters import HTTPAdapter, Retry
from llama_index.embeddings.base import BaseEmbedding
logger = logging.getLogger(__name__)
class LLMRailsEmbedding(BaseEmbedding):
    """LLMRails embedding models.

    This class provides an interface to generate embeddings using a model
    deployed in an LLMRails cluster. It requires the ``model_id`` of a model
    deployed in the cluster and an API key, which you can obtain from
    https://console.llmrails.com/api-keys.

    Attributes:
        model_id: Identifier sent as the ``"model"`` field of every
            embedding request.
        api_key: Key sent in the ``X-API-KEY`` header of every request.
        session: ``requests.Session`` used by the synchronous code path; it
            carries the API key header and a retrying HTTP adapter.
    """

    model_id: str
    api_key: str
    session: requests.Session

    @classmethod
    def class_name(cls) -> str:
        """Return the string identifier of this embedding class."""
        return "LLMRailsEmbedding"

    def __init__(
        self,
        api_key: str,
        model_id: str = "embedding-english-v1",  # or embedding-multi-v1
        **kwargs: Any,
    ):
        """Create the embedding client and its retrying HTTP session.

        Args:
            api_key (str): LLMRails API key, sent as the ``X-API-KEY`` header.
            model_id (str): Embedding model to query.
            **kwargs (Any): Forwarded unchanged to the base class constructor.
        """
        # Retry POST requests on connection/read failures and on gateway
        # errors (502/503/504), backing off between attempts.
        retry = Retry(
            total=3,
            connect=3,
            read=2,
            allowed_methods=["POST"],
            backoff_factor=2,
            status_forcelist=[502, 503, 504],
        )
        session = requests.Session()
        session.mount("https://api.llmrails.com", HTTPAdapter(max_retries=retry))
        # Update rather than replace: assigning a plain dict would discard the
        # session's default headers and its case-insensitive header mapping.
        session.headers.update({"X-API-KEY": api_key})
        super().__init__(model_id=model_id, api_key=api_key, session=session, **kwargs)

    def _get_embedding(self, text: str) -> List[float]:
        """Generate an embedding for a single text (synchronous).

        Args:
            text (str): The text to generate an embedding for.

        Returns:
            List[float]: The embedding for the input text.

        Raises:
            ValueError: If the API responds with an HTTP error status. Other
                ``requests`` errors are not translated and propagate as-is.
        """
        try:
            response = self.session.post(
                "https://api.llmrails.com/v1/embeddings",
                json={"input": [text], "model": self.model_id},
            )
            response.raise_for_status()
            return response.json()["data"][0]["embedding"]
        except requests.exceptions.HTTPError as e:
            logger.error("Error while embedding text %s.", e)
            # Chain the cause so the original HTTP error stays in the traceback.
            raise ValueError(f"Unable to embed given text {e}") from e

    async def _aget_embedding(self, text: str) -> List[float]:
        """Generate an embedding for a single text (asynchronous).

        Uses a short-lived ``httpx.AsyncClient`` instead of the synchronous
        ``requests`` session, so the API key header is passed explicitly.

        Args:
            text (str): The text to generate an embedding for.

        Returns:
            List[float]: The embedding for the input text.

        Raises:
            ImportError: If the optional ``httpx`` dependency is missing.
            ValueError: If the request fails with an ``httpx.HTTPError``.
        """
        try:
            import httpx
        except ImportError as exc:
            raise ImportError(
                "The httpx library is required to use the async version of "
                "this function. Install it with `pip install httpx`."
            ) from exc

        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    "https://api.llmrails.com/v1/embeddings",
                    headers={"X-API-KEY": self.api_key},
                    json={"input": [text], "model": self.model_id},
                )
                response.raise_for_status()
                return response.json()["data"][0]["embedding"]
        # Public exception name; avoids reaching into the private
        # ``httpx._exceptions`` module.
        except httpx.HTTPError as e:
            logger.error("Error while embedding text %s.", e)
            raise ValueError(f"Unable to embed given text {e}") from e

    def _get_text_embedding(self, text: str) -> List[float]:
        """Return the embedding for a document text."""
        return self._get_embedding(text)

    def _get_query_embedding(self, query: str) -> List[float]:
        """Return the embedding for a query string."""
        return self._get_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Asynchronously return the embedding for a query string."""
        return await self._aget_embedding(query)

    async def _aget_text_embedding(self, query: str) -> List[float]:
        """Asynchronously return the embedding for a document text."""
        return await self._aget_embedding(query)
# Alias exposing the class under the plural name as well
# (presumably kept for backward compatibility -- confirm against callers).
LLMRailsEmbeddings = LLMRailsEmbedding