from typing import Callable, List, Optional, Union

from llama_index.bridge.pydantic import Field
from llama_index.callbacks import CallbackManager
from llama_index.core.embeddings.base import (
    DEFAULT_EMBED_BATCH_SIZE,
    BaseEmbedding,
    Embedding,
)
from llama_index.embeddings.huggingface_utils import format_query, format_text

DEFAULT_URL = "http://127.0.0.1:8080"


class TextEmbeddingsInference(BaseEmbedding):
    base_url: str = Field(
        default=DEFAULT_URL,
        description="Base URL for the text embeddings service.",
    )
    query_instruction: Optional[str] = Field(
        description="Instruction to prepend to query text."
    )
    text_instruction: Optional[str] = Field(
        description="Instruction to prepend to text."
    )
    timeout: float = Field(
        default=60.0,
        description="Timeout in seconds for the request.",
    )
    truncate_text: bool = Field(
        default=True,
        description="Whether to truncate text or not when generating embeddings.",
    )
    auth_token: Optional[Union[str, Callable[[str], str]]] = Field(
        default=None,
        description="Authentication token or authentication token generating function for authenticated requests.",
    )

    def __init__(
        self,
        model_name: str,
        base_url: str = DEFAULT_URL,
        text_instruction: Optional[str] = None,
        query_instruction: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        timeout: float = 60.0,
        truncate_text: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        auth_token: Optional[Union[str, Callable[[str], str]]] = None,
    ):
        try:
            import httpx  # noqa
        except ImportError:
            raise ImportError(
                "TextEmbeddingsInference requires httpx to be installed.\n"
                "Please install httpx with `pip install httpx`."
            )

        super().__init__(
            base_url=base_url,
            model_name=model_name,
            text_instruction=text_instruction,
            query_instruction=query_instruction,
            embed_batch_size=embed_batch_size,
            timeout=timeout,
            truncate_text=truncate_text,
            callback_manager=callback_manager,
            auth_token=auth_token,
        )

    @classmethod
    def class_name(cls) -> str:
        return "TextEmbeddingsInference"

    def _call_api(self, texts: List[str]) -> List[List[float]]:
        import httpx

        headers = {"Content-Type": "application/json"}
        if self.auth_token is not None:
            if callable(self.auth_token):
                headers["Authorization"] = self.auth_token(self.base_url)
            else:
                headers["Authorization"] = self.auth_token
        json_data = {"inputs": texts, "truncate": self.truncate_text}

        with httpx.Client() as client:
            response = client.post(
                f"{self.base_url}/embed",
                headers=headers,
                json=json_data,
                timeout=self.timeout,
            )

        return response.json()

    async def _acall_api(self, texts: List[str]) -> List[List[float]]:
        import httpx

        headers = {"Content-Type": "application/json"}
        if self.auth_token is not None:
            if callable(self.auth_token):
                headers["Authorization"] = self.auth_token(self.base_url)
            else:
                headers["Authorization"] = self.auth_token
        json_data = {"inputs": texts, "truncate": self.truncate_text}

        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/embed",
                headers=headers,
                json=json_data,
                timeout=self.timeout,
            )

        return response.json()

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        query = format_query(query, self.model_name, self.query_instruction)
        return self._call_api([query])[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        text = format_text(text, self.model_name, self.text_instruction)
        return self._call_api([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings."""
        texts = [
            format_text(text, self.model_name, self.text_instruction) for text in texts
        ]
        return self._call_api(texts)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding async."""
        query = format_query(query, self.model_name, self.query_instruction)
        return (await self._acall_api([query]))[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Get text embedding async."""
        text = format_text(text, self.model_name, self.text_instruction)
        return (await self._acall_api([text]))[0]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        """Get text embeddings async."""
        texts = [
            format_text(text, self.model_name, self.text_instruction) for text in texts
        ]
        return await self._acall_api(texts)
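

# A minimal usage sketch, assuming a text-embeddings-inference server is already
# running at DEFAULT_URL and was launched with the (assumed) model
# "BAAI/bge-large-en-v1.5"; adjust the endpoint and model name to your setup.
# The public `get_query_embedding` / `get_text_embedding` helpers are inherited
# from `BaseEmbedding` and delegate to the private methods defined above.
if __name__ == "__main__":
    embed_model = TextEmbeddingsInference(
        model_name="BAAI/bge-large-en-v1.5",  # assumed model; match your server
        base_url=DEFAULT_URL,
        timeout=60.0,
    )

    # Embed a single query string (synchronous path).
    query_vector = embed_model.get_query_embedding("What is a vector database?")
    print(f"query embedding dimension: {len(query_vector)}")

    # Embed a single document chunk (synchronous path).
    text_vector = embed_model.get_text_embedding("Documents are split into nodes.")
    print(f"text embedding dimension: {len(text_vector)}")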