# faiss_rag_enterprise/llama_index/vector_stores/elasticsearch.py
"""Elasticsearch vector store."""
import asyncio
import uuid
from logging import getLogger
from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast
import nest_asyncio
import numpy as np
from llama_index.bridge.pydantic import PrivateAttr
from llama_index.schema import BaseNode, MetadataMode, TextNode
from llama_index.vector_stores.types import (
    BasePydanticVectorStore,
    MetadataFilters,
    VectorStoreQuery,
    VectorStoreQueryMode,
    VectorStoreQueryResult,
)
from llama_index.vector_stores.utils import (
    metadata_dict_to_node,
    node_to_metadata_dict,
)

logger = getLogger(__name__)

DISTANCE_STRATEGIES = Literal[
    "COSINE",
    "DOT_PRODUCT",
    "EUCLIDEAN_DISTANCE",
]


def _get_elasticsearch_client(
    *,
    es_url: Optional[str] = None,
    cloud_id: Optional[str] = None,
    api_key: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
) -> Any:
    """Get AsyncElasticsearch client.

    Args:
        es_url: Elasticsearch URL.
        cloud_id: Elasticsearch cloud ID.
        api_key: Elasticsearch API key.
        username: Elasticsearch username.
        password: Elasticsearch password.

    Returns:
        AsyncElasticsearch client.

    Raises:
        ConnectionError: If the client cannot connect to Elasticsearch.
    """
    try:
        import elasticsearch
    except ImportError:
        raise ImportError(
            "Could not import elasticsearch python package. "
            "Please install it with `pip install elasticsearch`."
        )
    if es_url and cloud_id:
        raise ValueError(
            "Both es_url and cloud_id are defined. Please provide only one."
        )
    connection_params: Dict[str, Any] = {}
    if es_url:
        connection_params["hosts"] = [es_url]
    elif cloud_id:
        connection_params["cloud_id"] = cloud_id
    else:
        raise ValueError("Please provide either es_url or cloud_id.")
    if api_key:
        connection_params["api_key"] = api_key
    elif username and password:
        connection_params["basic_auth"] = (username, password)
    sync_es_client = elasticsearch.Elasticsearch(
        **connection_params,
        headers={"user-agent": ElasticsearchStore.get_user_agent()},
    )
    async_es_client = elasticsearch.AsyncElasticsearch(**connection_params)
    try:
        # Validate the connection with the sync client so we don't need a
        # running event loop just to call info().
        sync_es_client.info()
    except Exception as e:
        logger.error(f"Error connecting to Elasticsearch: {e}")
        raise
    return async_es_client
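

# Illustrative only: this helper is normally reached through ElasticsearchStore,
# but it can be exercised directly. The URL and credentials below are
# placeholder assumptions, not defaults of this module:
#
#     client = _get_elasticsearch_client(
#         es_url="http://localhost:9200",
#         username="elastic",
#         password="changeme",
#     )
#     # `client` is an AsyncElasticsearch; e.g. `await client.info()`.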


def _to_elasticsearch_filter(standard_filters: MetadataFilters) -> Dict[str, Any]:
    """Convert standard filters to Elasticsearch filter.

    Args:
        standard_filters: Standard llama-index filters.

    Returns:
        Elasticsearch filter.
    """
    filters = standard_filters.legacy_filters()
    if len(filters) == 1:
        filter = filters[0]
        return {
            "term": {
                f"metadata.{filter.key}.keyword": {
                    "value": filter.value,
                }
            }
        }
    else:
        operands = []
        for filter in filters:
            operands.append(
                {
                    "term": {
                        f"metadata.{filter.key}.keyword": {
                            "value": filter.value,
                        }
                    }
                }
            )
        return {"bool": {"must": operands}}


def _to_llama_similarities(scores: List[float]) -> List[float]:
    if scores is None or len(scores) == 0:
        return []
    scores_to_norm: np.ndarray = np.array(scores)
    # Shift by the max before exponentiating for numerical stability; the top
    # hit maps to 1.0 and lower scores decay toward 0.
    return np.exp(scores_to_norm - np.max(scores_to_norm)).tolist()
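# Worked example (illustrative):
#     _to_llama_similarities([2.0, 1.0, 0.0]) -> [1.0, ~0.368, ~0.135]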


class ElasticsearchStore(BasePydanticVectorStore):
    """Elasticsearch vector store.

    Args:
        index_name: Name of the Elasticsearch index.
        es_client: Optional. Pre-existing AsyncElasticsearch client.
        es_url: Optional. Elasticsearch URL.
        es_cloud_id: Optional. Elasticsearch cloud ID.
        es_api_key: Optional. Elasticsearch API key.
        es_user: Optional. Elasticsearch username.
        es_password: Optional. Elasticsearch password.
        text_field: Optional. Name of the Elasticsearch field that stores the text.
        vector_field: Optional. Name of the Elasticsearch field that stores the
            embedding.
        batch_size: Optional. Batch size for bulk indexing. Defaults to 200.
        distance_strategy: Optional. Distance strategy to use for similarity search.
            Defaults to "COSINE".

    Raises:
        ConnectionError: If the AsyncElasticsearch client cannot connect to
            Elasticsearch.
        ValueError: If neither es_client nor es_url nor es_cloud_id is provided.
    """

    stores_text: bool = True
    index_name: str
    es_client: Optional[Any]
    es_url: Optional[str]
    es_cloud_id: Optional[str]
    es_api_key: Optional[str]
    es_user: Optional[str]
    es_password: Optional[str]
    text_field: str = "content"
    vector_field: str = "embedding"
    batch_size: int = 200
    distance_strategy: Optional[DISTANCE_STRATEGIES] = "COSINE"

    _client = PrivateAttr()

    def __init__(
        self,
        index_name: str,
        es_client: Optional[Any] = None,
        es_url: Optional[str] = None,
        es_cloud_id: Optional[str] = None,
        es_api_key: Optional[str] = None,
        es_user: Optional[str] = None,
        es_password: Optional[str] = None,
        text_field: str = "content",
        vector_field: str = "embedding",
        batch_size: int = 200,
        distance_strategy: Optional[DISTANCE_STRATEGIES] = "COSINE",
    ) -> None:
        nest_asyncio.apply()

        if es_client is not None:
            self._client = es_client.options(
                headers={"user-agent": self.get_user_agent()}
            )
        elif es_url is not None or es_cloud_id is not None:
            self._client = _get_elasticsearch_client(
                es_url=es_url,
                username=es_user,
                password=es_password,
                cloud_id=es_cloud_id,
                api_key=es_api_key,
            )
        else:
            raise ValueError(
                "Either provide a pre-existing AsyncElasticsearch client or "
                "valid credentials for creating a new connection."
            )
        super().__init__(
            index_name=index_name,
            es_client=es_client,
            es_url=es_url,
            es_cloud_id=es_cloud_id,
            es_api_key=es_api_key,
            es_user=es_user,
            es_password=es_password,
            text_field=text_field,
            vector_field=vector_field,
            batch_size=batch_size,
            distance_strategy=distance_strategy,
        )

    @property
    def client(self) -> Any:
        """Get the async Elasticsearch client."""
        return self._client

    @staticmethod
    def get_user_agent() -> str:
        """Get the user agent for the Elasticsearch client."""
        import llama_index

        return f"llama_index-py-vs/{llama_index.__version__}"

    async def _create_index_if_not_exists(
        self, index_name: str, dims_length: Optional[int] = None
    ) -> None:
        """Create the Elasticsearch index if it doesn't already exist.

        Args:
            index_name: Name of the Elasticsearch index to create.
            dims_length: Length of the embedding vectors.
        """
        if await self.client.indices.exists(index=index_name):
            logger.debug(f"Index {index_name} already exists. Skipping creation.")
        else:
            if dims_length is None:
                raise ValueError(
                    "Cannot create index without specifying dims_length "
                    "when the index doesn't already exist. We infer "
                    "dims_length from the first embedding. Check that "
                    "you have provided an embedding function."
                )
            if self.distance_strategy == "COSINE":
                similarity_algo = "cosine"
            elif self.distance_strategy == "EUCLIDEAN_DISTANCE":
                similarity_algo = "l2_norm"
            elif self.distance_strategy == "DOT_PRODUCT":
                similarity_algo = "dot_product"
            else:
                raise ValueError(f"Similarity {self.distance_strategy} not supported.")
            index_settings = {
                "mappings": {
                    "properties": {
                        self.vector_field: {
                            "type": "dense_vector",
                            "dims": dims_length,
                            "index": True,
                            "similarity": similarity_algo,
                        },
                        self.text_field: {"type": "text"},
                        "metadata": {
                            "properties": {
                                "document_id": {"type": "keyword"},
                                "doc_id": {"type": "keyword"},
                                "ref_doc_id": {"type": "keyword"},
                            }
                        },
                    }
                }
            }
            logger.debug(
                f"Creating index {index_name} with mappings "
                f"{index_settings['mappings']}"
            )
            await self.client.indices.create(index=index_name, **index_settings)

    def add(
        self,
        nodes: List[BaseNode],
        *,
        create_index_if_not_exists: bool = True,
        **add_kwargs: Any,
    ) -> List[str]:
        """Add nodes to the Elasticsearch index.

        Args:
            nodes: List of nodes with embeddings.
            create_index_if_not_exists: Optional. Whether to create the
                Elasticsearch index if it doesn't already exist.
                Defaults to True.

        Returns:
            List of node IDs that were added to the index.

        Raises:
            ImportError: If the elasticsearch[async] python package is not installed.
            BulkIndexError: If async_bulk indexing fails.
        """
        return asyncio.get_event_loop().run_until_complete(
            self.async_add(nodes, create_index_if_not_exists=create_index_if_not_exists)
        )

    async def async_add(
        self,
        nodes: List[BaseNode],
        *,
        create_index_if_not_exists: bool = True,
        **add_kwargs: Any,
    ) -> List[str]:
        """Asynchronously add nodes to the Elasticsearch index.

        Args:
            nodes: List of nodes with embeddings.
            create_index_if_not_exists: Optional. Whether to create the
                Elasticsearch index if it doesn't already exist.
                Defaults to True.

        Returns:
            List of node IDs that were added to the index.

        Raises:
            ImportError: If the elasticsearch[async] python package is not installed.
            BulkIndexError: If async_bulk indexing fails.
        """
        try:
            from elasticsearch.helpers import BulkIndexError, async_bulk
        except ImportError:
            raise ImportError(
                "Could not import elasticsearch[async] python package. "
                "Please install it with `pip install 'elasticsearch[async]'`."
            )
        if len(nodes) == 0:
            return []
        if create_index_if_not_exists:
            dims_length = len(nodes[0].get_embedding())
            await self._create_index_if_not_exists(
                index_name=self.index_name, dims_length=dims_length
            )

        embeddings: List[List[float]] = []
        texts: List[str] = []
        metadatas: List[dict] = []
        ids: List[str] = []
        for node in nodes:
            ids.append(node.node_id)
            embeddings.append(node.get_embedding())
            texts.append(node.get_content(metadata_mode=MetadataMode.NONE))
            metadatas.append(node_to_metadata_dict(node, remove_text=True))

        requests = []
        return_ids = []
        for i, text in enumerate(texts):
            metadata = metadatas[i] if metadatas else {}
            _id = ids[i] if ids else str(uuid.uuid4())
            request = {
                "_op_type": "index",
                "_index": self.index_name,
                self.vector_field: embeddings[i],
                self.text_field: text,
                "metadata": metadata,
                "_id": _id,
            }
            requests.append(request)
            return_ids.append(_id)

        try:
            # A single bulk request both indexes the documents and reports
            # success/failure counts; the previous extra async_bulk call
            # indexed everything twice.
            success, failed = await async_bulk(
                self.client,
                requests,
                chunk_size=self.batch_size,
                stats_only=True,
                refresh=True,
            )
            logger.debug(f"Added {success} and failed to add {failed} texts to index")
            logger.debug(f"Added texts {ids} to index")
            return return_ids
        except BulkIndexError as e:
            logger.error(f"Error adding texts: {e}")
            first_error = e.errors[0].get("index", {}).get("error", {})
            logger.error(f"First error reason: {first_error.get('reason')}")
            raise

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Delete the nodes of a reference document from the Elasticsearch index.

        Args:
            ref_doc_id: ID of the reference document whose nodes should be deleted.
            delete_kwargs: Optional. Additional arguments to
                pass to Elasticsearch delete_by_query.

        Raises:
            Exception: If Elasticsearch delete_by_query fails.
        """
        return asyncio.get_event_loop().run_until_complete(
            self.adelete(ref_doc_id, **delete_kwargs)
        )

    async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Asynchronously delete the nodes of a reference document from the index.

        Args:
            ref_doc_id: ID of the reference document whose nodes should be deleted.
            delete_kwargs: Optional. Additional arguments to
                pass to AsyncElasticsearch delete_by_query.

        Raises:
            Exception: If AsyncElasticsearch delete_by_query fails.
        """
        try:
            async with self.client as client:
                res = await client.delete_by_query(
                    index=self.index_name,
                    query={"term": {"metadata.ref_doc_id": ref_doc_id}},
                    refresh=True,
                    **delete_kwargs,
                )
            if res["deleted"] == 0:
                logger.warning(f"Could not find text {ref_doc_id} to delete")
            else:
                logger.debug(f"Deleted text {ref_doc_id} from index")
        except Exception:
            logger.error(f"Error deleting text: {ref_doc_id}")
            raise

    def query(
        self,
        query: VectorStoreQuery,
        custom_query: Optional[
            Callable[[Dict, Union[VectorStoreQuery, None]], Dict]
        ] = None,
        es_filter: Optional[List[Dict]] = None,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Query the index for the top k most similar nodes.

        Args:
            query (VectorStoreQuery): query containing the embedding to search with.
            custom_query: Optional. Custom query function that takes in the
                Elasticsearch query body and returns a modified query body.
                This can be used to add additional query
                parameters to the Elasticsearch query.
            es_filter: Optional. Elasticsearch filter to apply to the
                query. If a filter is provided in the query itself,
                this filter will be ignored.

        Returns:
            VectorStoreQueryResult: Result of the query.

        Raises:
            Exception: If the Elasticsearch query fails.
        """
        return asyncio.get_event_loop().run_until_complete(
            self.aquery(query, custom_query, es_filter, **kwargs)
        )

    async def aquery(
        self,
        query: VectorStoreQuery,
        custom_query: Optional[
            Callable[[Dict, Union[VectorStoreQuery, None]], Dict]
        ] = None,
        es_filter: Optional[List[Dict]] = None,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Asynchronously query the index for the top k most similar nodes.

        Args:
            query (VectorStoreQuery): query containing the embedding to search with.
            custom_query: Optional. Custom query function that takes in the
                Elasticsearch query body and returns a modified query body.
                This can be used to add additional query
                parameters to the AsyncElasticsearch query.
            es_filter: Optional. Elasticsearch filter to apply to the
                query. If a filter is provided in the query itself,
                this filter will be ignored.

        Returns:
            VectorStoreQueryResult: Result of the query.

        Raises:
            Exception: If the AsyncElasticsearch query fails.
        """
        query_embedding = cast(List[float], query.query_embedding)

        es_query = {}
        if query.filters is not None and len(query.filters.legacy_filters()) > 0:
            filter = [_to_elasticsearch_filter(query.filters)]
        else:
            filter = es_filter or []

        if query.mode in (
            VectorStoreQueryMode.DEFAULT,
            VectorStoreQueryMode.HYBRID,
        ):
            es_query["knn"] = {
                "filter": filter,
                "field": self.vector_field,
                "query_vector": query_embedding,
                "k": query.similarity_top_k,
                "num_candidates": query.similarity_top_k * 10,
            }

        if query.mode in (
            VectorStoreQueryMode.TEXT_SEARCH,
            VectorStoreQueryMode.HYBRID,
        ):
            es_query["query"] = {
                "bool": {
                    "must": {"match": {self.text_field: {"query": query.query_str}}},
                    "filter": filter,
                }
            }

        if query.mode == VectorStoreQueryMode.HYBRID:
            es_query["rank"] = {"rrf": {}}

        if custom_query is not None:
            es_query = custom_query(es_query, query)
            logger.debug(f"Calling custom_query, Query body now: {es_query}")

        async with self.client as client:
            response = await client.search(
                index=self.index_name,
                **es_query,
                size=query.similarity_top_k,
                _source={"excludes": [self.vector_field]},
            )

        top_k_nodes = []
        top_k_ids = []
        top_k_scores = []
        hits = response["hits"]["hits"]
        for hit in hits:
            source = hit["_source"]
            metadata = source.get("metadata", None)
            text = source.get(self.text_field, None)
            node_id = hit["_id"]
            try:
                node = metadata_dict_to_node(metadata)
                node.text = text
            except Exception:
                # Legacy support for the old metadata format.
                logger.warning(
                    f"Could not parse metadata from hit {hit['_source']['metadata']}"
                )
                node_info = source.get("node_info")
                relationships = source.get("relationships")
                start_char_idx = None
                end_char_idx = None
                if isinstance(node_info, dict):
                    start_char_idx = node_info.get("start", None)
                    end_char_idx = node_info.get("end", None)
                node = TextNode(
                    text=text,
                    metadata=metadata,
                    id_=node_id,
                    start_char_idx=start_char_idx,
                    end_char_idx=end_char_idx,
                    relationships=relationships,
                )
            top_k_nodes.append(node)
            top_k_ids.append(node_id)
            # Hybrid (RRF) responses expose "_rank" instead of "_score".
            top_k_scores.append(hit.get("_rank", hit["_score"]))

        if query.mode == VectorStoreQueryMode.HYBRID:
            # Convert RRF ranks (1 = best) into descending, normalized scores;
            # e.g. ranks [1, 2, 3] -> [5/6, 4/6, 3/6]. Parentheses fix the
            # original operator-precedence bug (total_rank - rank / total_rank).
            total_rank = sum(top_k_scores)
            top_k_scores = [(total_rank - rank) / total_rank for rank in top_k_scores]

        return VectorStoreQueryResult(
            nodes=top_k_nodes,
            ids=top_k_ids,
            similarities=_to_llama_similarities(top_k_scores),
        )
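

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the library proper. It assumes a
    # local Elasticsearch at http://localhost:9200 (a placeholder, not a
    # default of this module) and uses hand-rolled 3-d vectors in place of
    # real model embeddings.
    store = ElasticsearchStore(
        index_name="llama-index-demo",
        es_url="http://localhost:9200",
    )
    demo_nodes = [
        TextNode(text="hello world", embedding=[0.1, 0.2, 0.3]),
        TextNode(text="goodbye world", embedding=[0.3, 0.2, 0.1]),
    ]
    store.add(demo_nodes)
    result = store.query(
        VectorStoreQuery(query_embedding=[0.1, 0.2, 0.3], similarity_top_k=1)
    )
    print(result.ids, result.similarities)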