# Source: faiss_rag_enterprise/llama_index/embeddings/jinaai.py (116 lines, 3.9 KiB, Python)

"""Jina embeddings file."""
from typing import Any, List, Optional
import requests
from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks.base import CallbackManager
from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding
from llama_index.llms.generic_utils import get_from_param_or_env
# Declared batch-size ceiling; nothing in the visible part of this file reads it.
MAX_BATCH_SIZE = 2048
# Jina AI embeddings endpoint that both the sync and async request paths POST to.
API_URL = "https://api.jina.ai/v1/embeddings"
class JinaEmbedding(BaseEmbedding):
    """Embedding client backed by the Jina AI embeddings HTTP API.

    Synchronous calls go through a reusable ``requests.Session`` created in
    ``__init__``; asynchronous calls open a fresh ``aiohttp.ClientSession``
    per request.

    Args:
        model (str): Model for embedding.
            Defaults to `jina-embeddings-v2-base-en`
        embed_batch_size (int): Forwarded unchanged to ``BaseEmbedding``.
        api_key (Optional[str]): Jina AI API key. When not given it is
            resolved by ``get_from_param_or_env`` using the
            ``JINAAI_API_KEY`` environment variable, with ``""`` as fallback.
        callback_manager (Optional[CallbackManager]): Forwarded unchanged to
            ``BaseEmbedding``.
    """

    api_key: str = Field(default=None, description="The JinaAI API key.")
    model: str = Field(
        default="jina-embeddings-v2-base-en",
        description="The model to use when calling Jina AI API",
    )

    # Shared HTTP session used by the synchronous request path only.
    _session: Any = PrivateAttr()

    def __init__(
        self,
        model: str = "jina-embeddings-v2-base-en",
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        api_key: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> None:
        """Initialise the model fields and the authenticated HTTP session."""
        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model=model,
            api_key=api_key,
            **kwargs,
        )
        self.api_key = get_from_param_or_env("api_key", api_key, "JINAAI_API_KEY", "")
        self.model = model
        self._session = requests.Session()
        # Use the *resolved* key: the raw ``api_key`` argument is None whenever
        # the key is supplied through the environment, which would otherwise
        # produce an "Authorization: Bearer None" header.
        self._session.headers.update(
            {
                "Authorization": f"Bearer {self.api_key}",
                "Accept-Encoding": "identity",
            }
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier string for this embedding class."""
        return "JinaAIEmbedding"

    @staticmethod
    def _sorted_vectors(items: List[Any]) -> List[List[float]]:
        """Return the ``embedding`` of each result item, ordered by ``index``."""
        ordered = sorted(items, key=lambda item: item["index"])  # type: ignore
        return [item["embedding"] for item in ordered]

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self._get_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """The asynchronous version of _get_query_embedding."""
        return await self._aget_text_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self._get_text_embeddings([text])[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        result = await self._aget_text_embeddings([text])
        return result[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        Args:
            texts: Strings to embed in a single API request.

        Returns:
            One embedding per input, ordered by the ``index`` field of the
            response items.

        Raises:
            RuntimeError: If the decoded response has no ``data`` entry. The
                message is the response's ``detail`` value when present,
                otherwise the whole decoded payload.
        """
        # Call Jina AI Embedding API
        resp = self._session.post(  # type: ignore
            API_URL, json={"input": texts, "model": self.model}
        ).json()
        if "data" not in resp:
            # Fall back to the full payload so a missing "detail" key does not
            # mask the real failure with a KeyError.
            detail = resp.get("detail", resp) if isinstance(resp, dict) else resp
            raise RuntimeError(detail)
        return self._sorted_vectors(resp["data"])

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings.

        Args:
            texts: Strings to embed in a single API request.

        Returns:
            One embedding per input, ordered by the ``index`` field of the
            response items.

        Raises:
            aiohttp.ClientResponseError: If the HTTP status signals an error.
        """
        import aiohttp

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Accept-Encoding": "identity",
        }
        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.post(
                f"{API_URL}",
                json={"input": texts, "model": self.model},
                headers=headers,
            ) as response:
                # Check the status before decoding so an error reply surfaces
                # as an HTTP error instead of failing inside JSON decoding.
                response.raise_for_status()
                resp = await response.json()
        return self._sorted_vectors(resp["data"])