# faiss_rag_enterprise/llama_index/embeddings/text_embeddings_inference.py

# 149 lines
# 5.2 KiB
# Python

from typing import Callable, List, Optional, Union
from llama_index.bridge.pydantic import Field
from llama_index.callbacks import CallbackManager
from llama_index.core.embeddings.base import (
DEFAULT_EMBED_BATCH_SIZE,
BaseEmbedding,
Embedding,
)
from llama_index.embeddings.huggingface_utils import format_query, format_text
# Base URL of the embeddings service used when the caller does not supply one.
DEFAULT_URL = "http://127.0.0.1:8080"
class TextEmbeddingsInference(BaseEmbedding):
    """Embedding model backed by a remote text embeddings HTTP service.

    Texts are POSTed as JSON (``{"inputs": [...], "truncate": bool}``) to
    ``<base_url>/embed`` and the decoded JSON response body is returned as the
    embeddings. Before being sent, query and document texts are passed through
    ``format_query`` / ``format_text`` together with the model name and the
    configured instruction.
    """

    base_url: str = Field(
        default=DEFAULT_URL,
        description="Base URL for the text embeddings service.",
    )
    query_instruction: Optional[str] = Field(
        default=None,
        description="Instruction to prepend to query text.",
    )
    text_instruction: Optional[str] = Field(
        default=None,
        description="Instruction to prepend to text.",
    )
    timeout: float = Field(
        default=60.0,
        description="Timeout in seconds for the request.",
    )
    truncate_text: bool = Field(
        default=True,
        description="Whether to truncate text or not when generating embeddings.",
    )
    auth_token: Optional[Union[str, Callable[[str], str]]] = Field(
        default=None,
        description="Authentication token or authentication token generating function for authenticated requests",
    )

    def __init__(
        self,
        model_name: str,
        base_url: str = DEFAULT_URL,
        text_instruction: Optional[str] = None,
        query_instruction: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        timeout: float = 60.0,
        truncate_text: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        auth_token: Optional[Union[str, Callable[[str], str]]] = None,
    ):
        """Validate that ``httpx`` is importable and initialise the model fields.

        Args:
            model_name: Model name; forwarded to the base class and later passed
                to ``format_query`` / ``format_text``.
            base_url: Base URL of the embeddings service.
            text_instruction: Instruction passed to ``format_text``.
            query_instruction: Instruction passed to ``format_query``.
            embed_batch_size: Batch size forwarded to the base class.
            timeout: Per-request timeout in seconds.
            truncate_text: Value sent as the ``truncate`` flag of each request.
            callback_manager: Callback manager forwarded to the base class.
            auth_token: Either the literal ``Authorization`` header value, or a
                callable that receives ``base_url`` and returns that value.

        Raises:
            ImportError: If ``httpx`` is not installed.
        """
        # Fail early, at construction time, rather than on the first request.
        try:
            import httpx  # noqa
        except ImportError as err:
            raise ImportError(
                "TextEmbeddingsInference requires httpx to be installed.\n"
                "Please install httpx with `pip install httpx`."
            ) from err

        super().__init__(
            base_url=base_url,
            model_name=model_name,
            text_instruction=text_instruction,
            query_instruction=query_instruction,
            embed_batch_size=embed_batch_size,
            timeout=timeout,
            truncate_text=truncate_text,
            callback_manager=callback_manager,
            auth_token=auth_token,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier string."""
        return "TextEmbeddingsInference"

    def _request_headers(self) -> dict:
        """Build the HTTP headers shared by the sync and async request paths."""
        headers = {"Content-Type": "application/json"}
        token = self.auth_token
        if token is not None:
            # A callable token is invoked with the base URL on every request,
            # so the caller can return a fresh value each time.
            headers["Authorization"] = (
                token(self.base_url) if callable(token) else token
            )
        return headers

    def _request_payload(self, texts: List[str]) -> dict:
        """Build the JSON body sent to the embed endpoint."""
        return {"inputs": texts, "truncate": self.truncate_text}

    def _embed_url(self) -> str:
        """Return the embed endpoint URL, tolerating a trailing slash in base_url."""
        return f"{self.base_url.rstrip('/')}/embed"

    def _call_api(self, texts: List[str]) -> List[List[float]]:
        """POST ``texts`` to the embed endpoint and return the decoded JSON body.

        Raises:
            httpx.HTTPStatusError: If the service answers with a 4xx/5xx status.
        """
        import httpx

        with httpx.Client() as client:
            response = client.post(
                self._embed_url(),
                headers=self._request_headers(),
                json=self._request_payload(texts),
                timeout=self.timeout,
            )
        # Without this check an error body would be returned to callers as if
        # it were a list of embeddings.
        response.raise_for_status()
        return response.json()

    async def _acall_api(self, texts: List[str]) -> List[List[float]]:
        """Async variant of ``_call_api``.

        Raises:
            httpx.HTTPStatusError: If the service answers with a 4xx/5xx status.
        """
        import httpx

        async with httpx.AsyncClient() as client:
            response = await client.post(
                self._embed_url(),
                headers=self._request_headers(),
                json=self._request_payload(texts),
                timeout=self.timeout,
            )
        # Same guard as the sync path: never hand an error body to callers.
        response.raise_for_status()
        return response.json()

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        query = format_query(query, self.model_name, self.query_instruction)
        return self._call_api([query])[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        text = format_text(text, self.model_name, self.text_instruction)
        return self._call_api([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        if not texts:
            # Nothing to embed; skip the network round trip.
            return []
        texts = [
            format_text(text, self.model_name, self.text_instruction) for text in texts
        ]
        return self._call_api(texts)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding async."""
        query = format_query(query, self.model_name, self.query_instruction)
        return (await self._acall_api([query]))[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Get text embedding async."""
        text = format_text(text, self.model_name, self.text_instruction)
        return (await self._acall_api([text]))[0]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """Get text embeddings async."""
        if not texts:
            # Nothing to embed; skip the network round trip.
            return []
        texts = [
            format_text(text, self.model_name, self.text_instruction) for text in texts
        ]
        return await self._acall_api(texts)