116 lines
3.9 KiB
Python
116 lines
3.9 KiB
Python
"""Jina embeddings file."""
|
|
|
|
from typing import Any, List, Optional
|
|
|
|
import requests
|
|
|
|
from llama_index.bridge.pydantic import Field, PrivateAttr
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding
|
|
from llama_index.llms.generic_utils import get_from_param_or_env
|
|
|
|
# NOTE(review): not referenced elsewhere in the visible part of this file;
# presumably the upper bound on inputs per request -- confirm before relying
# on it.
MAX_BATCH_SIZE = 2048

# HTTP endpoint that both the synchronous and asynchronous embedding calls
# POST to.
API_URL = "https://api.jina.ai/v1/embeddings"
|
|
|
|
|
|
class JinaEmbedding(BaseEmbedding):
    """JinaAI class for embeddings.

    POSTs texts to the Jina AI embeddings endpoint (``API_URL``) and returns
    the resulting vectors ordered by the ``index`` field of each result.

    Args:
        model (str): Model for embedding.
            Defaults to `jina-embeddings-v2-base-en`
        embed_batch_size (int): Batch size forwarded to ``BaseEmbedding``.
        api_key (Optional[str]): Jina AI API key. Resolved through
            ``get_from_param_or_env`` with the ``JINAAI_API_KEY``
            environment variable name and an empty-string default.
        callback_manager (Optional[CallbackManager]): Forwarded to
            ``BaseEmbedding``.
        **kwargs: Any additional keyword arguments forwarded to
            ``BaseEmbedding``.
    """

    api_key: str = Field(default=None, description="The JinaAI API key.")
    model: str = Field(
        default="jina-embeddings-v2-base-en",
        description="The model to use when calling Jina AI API",
    )

    # Reusable HTTP session for the synchronous code path; created in __init__.
    _session: Any = PrivateAttr()

    def __init__(
        self,
        model: str = "jina-embeddings-v2-base-en",
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        api_key: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> None:
        """Initialise the model fields and the authenticated HTTP session."""
        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model=model,
            api_key=api_key,
            **kwargs,
        )
        self.api_key = get_from_param_or_env("api_key", api_key, "JINAAI_API_KEY", "")
        self.model = model
        self._session = requests.Session()
        # Use the *resolved* key (self.api_key), not the raw ``api_key``
        # argument: the argument may be None when the key is supplied through
        # the environment, which would produce a "Bearer None" header.
        self._session.headers.update(
            {
                "Authorization": f"Bearer {self.api_key}",
                "Accept-Encoding": "identity",
            }
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this embedding class."""
        return "JinaAIEmbedding"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        # Queries are embedded exactly like any other text.
        return self._get_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """The asynchronous version of _get_query_embedding."""
        return await self._aget_text_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        # Delegate to the batch method with a single-element batch.
        return self._get_text_embeddings([text])[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        result = await self._aget_text_embeddings([text])
        return result[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        Args:
            texts: Texts to embed in a single request.

        Returns:
            One embedding per result, ordered by each result's ``index``.

        Raises:
            RuntimeError: If the response body contains no ``data`` key. The
                message is the response's ``detail`` value when present,
                otherwise the whole response body.
        """
        # Call Jina AI Embedding API
        resp = self._session.post(  # type: ignore
            API_URL, json={"input": texts, "model": self.model}
        ).json()
        if "data" not in resp:
            # Fall back to the full payload so a response without "detail"
            # still surfaces as RuntimeError rather than an unrelated KeyError.
            raise RuntimeError(resp.get("detail", resp))

        embeddings = resp["data"]

        # Sort resulting embeddings by index
        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore

        # Return just the embeddings
        return [result["embedding"] for result in sorted_embeddings]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings.

        Args:
            texts: Texts to embed in a single request.

        Returns:
            One embedding per result, ordered by each result's ``index``.

        Raises:
            aiohttp.ClientResponseError: If the HTTP status is 400 or above.
            RuntimeError: If a successful response contains no ``data`` key.
        """
        # Imported lazily so aiohttp is only needed on the async code path.
        import aiohttp

        async with aiohttp.ClientSession(trust_env=True) as session:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Accept-Encoding": "identity",
            }
            async with session.post(
                f"{API_URL}",
                json={"input": texts, "model": self.model},
                headers=headers,
            ) as response:
                # Check the status before decoding so that a non-JSON error
                # body reports the HTTP failure, not a content-type error.
                response.raise_for_status()
                resp = await response.json()

        if "data" not in resp:
            # Mirror the synchronous path's error handling.
            raise RuntimeError(resp.get("detail", resp))

        embeddings = resp["data"]

        # Sort resulting embeddings by index
        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore

        # Return just the embeddings
        return [result["embedding"] for result in sorted_embeddings]
|