"""Elasticsearch vector store."""

import asyncio
import uuid
from logging import getLogger
from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast

import nest_asyncio
import numpy as np

from llama_index.bridge.pydantic import PrivateAttr
from llama_index.schema import BaseNode, MetadataMode, TextNode
from llama_index.vector_stores.types import (
    BasePydanticVectorStore,
    MetadataFilters,
    VectorStoreQuery,
    VectorStoreQueryMode,
    VectorStoreQueryResult,
)
from llama_index.vector_stores.utils import metadata_dict_to_node, node_to_metadata_dict

logger = getLogger(__name__)

DISTANCE_STRATEGIES = Literal[
    "COSINE",
    "DOT_PRODUCT",
    "EUCLIDEAN_DISTANCE",
]

def _get_elasticsearch_client(
    *,
    es_url: Optional[str] = None,
    cloud_id: Optional[str] = None,
    api_key: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
) -> Any:
    """Get AsyncElasticsearch client.

    Args:
        es_url: Elasticsearch URL.
        cloud_id: Elasticsearch cloud ID.
        api_key: Elasticsearch API key.
        username: Elasticsearch username.
        password: Elasticsearch password.

    Returns:
        AsyncElasticsearch client.

    Raises:
        ConnectionError: If the Elasticsearch client cannot connect to Elasticsearch.
    """
    try:
        import elasticsearch
    except ImportError:
        raise ImportError(
            "Could not import elasticsearch python package. "
            "Please install it with `pip install elasticsearch`."
        )

    if es_url and cloud_id:
        raise ValueError(
            "Both es_url and cloud_id are defined. Please provide only one."
        )

    connection_params: Dict[str, Any] = {}

    if es_url:
        connection_params["hosts"] = [es_url]
    elif cloud_id:
        connection_params["cloud_id"] = cloud_id
    else:
        raise ValueError("Please provide either es_url or cloud_id.")

    if api_key:
        connection_params["api_key"] = api_key
    elif username and password:
        connection_params["basic_auth"] = (username, password)

    sync_es_client = elasticsearch.Elasticsearch(
        **connection_params, headers={"user-agent": ElasticsearchStore.get_user_agent()}
    )
    async_es_client = elasticsearch.AsyncElasticsearch(**connection_params)
    try:
        sync_es_client.info()  # use the sync client so we don't have to 'await'
    except Exception as e:
        logger.error(f"Error connecting to Elasticsearch: {e}")
        raise

    return async_es_client
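
# A minimal sketch of how the factory above is invoked (illustrative only; the
# URL and credentials are placeholder assumptions, not values this module
# defines):
#
#     client = _get_elasticsearch_client(
#         es_url="http://localhost:9200",
#         username="elastic",
#         password="changeme",
#     )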

def _to_elasticsearch_filter(standard_filters: MetadataFilters) -> Dict[str, Any]:
    """Convert standard filters to Elasticsearch filter.

    Args:
        standard_filters: Standard llama-index filters.

    Returns:
        Elasticsearch filter.
    """
    filters = standard_filters.legacy_filters()
    if len(filters) == 1:
        filter = filters[0]
        return {
            "term": {
                f"metadata.{filter.key}.keyword": {
                    "value": filter.value,
                }
            }
        }
    else:
        operands = []
        for filter in filters:
            operands.append(
                {
                    "term": {
                        f"metadata.{filter.key}.keyword": {
                            "value": filter.value,
                        }
                    }
                }
            )
        return {"bool": {"must": operands}}
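
# A worked example of the conversion above (illustrative; "author" is a
# hypothetical metadata key, not a field this module defines):
#
#     from llama_index.vector_stores.types import ExactMatchFilter
#
#     filters = MetadataFilters(filters=[ExactMatchFilter(key="author", value="bob")])
#     _to_elasticsearch_filter(filters)
#     # -> {"term": {"metadata.author.keyword": {"value": "bob"}}}
#
#     # With two or more filters, the term clauses are AND-ed under bool/must:
#     # -> {"bool": {"must": [{"term": ...}, {"term": ...}]}}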

def _to_llama_similarities(scores: List[float]) -> List[float]:
    if scores is None or len(scores) == 0:
        return []

    scores_to_norm: np.ndarray = np.array(scores)
    return np.exp(scores_to_norm - np.max(scores_to_norm)).tolist()
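
# A worked example of the normalization above: the max score maps to 1.0 and
# lower scores decay exponentially (this is the unnormalized softmax numerator):
#
#     _to_llama_similarities([2.0, 1.0, 0.0])
#     # -> [1.0, 0.36787944117144233, 0.1353352832366127]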

class ElasticsearchStore(BasePydanticVectorStore):
    """Elasticsearch vector store.

    Args:
        index_name: Name of the Elasticsearch index.
        es_client: Optional. Pre-existing AsyncElasticsearch client.
        es_url: Optional. Elasticsearch URL.
        es_cloud_id: Optional. Elasticsearch cloud ID.
        es_api_key: Optional. Elasticsearch API key.
        es_user: Optional. Elasticsearch username.
        es_password: Optional. Elasticsearch password.
        text_field: Optional. Name of the Elasticsearch field that stores the text.
        vector_field: Optional. Name of the Elasticsearch field that stores the
            embedding.
        batch_size: Optional. Batch size for bulk indexing. Defaults to 200.
        distance_strategy: Optional. Distance strategy to use for similarity search.
            Defaults to "COSINE".

    Raises:
        ConnectionError: If the AsyncElasticsearch client cannot connect to
            Elasticsearch.
        ValueError: If neither es_client nor es_url nor es_cloud_id is provided.
    """

    stores_text: bool = True
    index_name: str
    es_client: Optional[Any]
    es_url: Optional[str]
    es_cloud_id: Optional[str]
    es_api_key: Optional[str]
    es_user: Optional[str]
    es_password: Optional[str]
    text_field: str = "content"
    vector_field: str = "embedding"
    batch_size: int = 200
    distance_strategy: Optional[DISTANCE_STRATEGIES] = "COSINE"

    _client = PrivateAttr()

    def __init__(
        self,
        index_name: str,
        es_client: Optional[Any] = None,
        es_url: Optional[str] = None,
        es_cloud_id: Optional[str] = None,
        es_api_key: Optional[str] = None,
        es_user: Optional[str] = None,
        es_password: Optional[str] = None,
        text_field: str = "content",
        vector_field: str = "embedding",
        batch_size: int = 200,
        distance_strategy: Optional[DISTANCE_STRATEGIES] = "COSINE",
    ) -> None:
        nest_asyncio.apply()

        if es_client is not None:
            self._client = es_client.options(
                headers={"user-agent": self.get_user_agent()}
            )
        elif es_url is not None or es_cloud_id is not None:
            self._client = _get_elasticsearch_client(
                es_url=es_url,
                username=es_user,
                password=es_password,
                cloud_id=es_cloud_id,
                api_key=es_api_key,
            )
        else:
            raise ValueError(
                "Either provide a pre-existing AsyncElasticsearch client or valid "
                "credentials for creating a new connection."
            )
        super().__init__(
            index_name=index_name,
            es_client=es_client,
            es_url=es_url,
            es_cloud_id=es_cloud_id,
            es_api_key=es_api_key,
            es_user=es_user,
            es_password=es_password,
            text_field=text_field,
            vector_field=vector_field,
            batch_size=batch_size,
            distance_strategy=distance_strategy,
        )

    @property
    def client(self) -> Any:
        """Get the async Elasticsearch client."""
        return self._client

    @staticmethod
    def get_user_agent() -> str:
        """Get the user agent for the Elasticsearch client."""
        import llama_index

        return f"llama_index-py-vs/{llama_index.__version__}"

    async def _create_index_if_not_exists(
        self, index_name: str, dims_length: Optional[int] = None
    ) -> None:
        """Create the Elasticsearch index if it doesn't already exist.

        Args:
            index_name: Name of the Elasticsearch index to create.
            dims_length: Length of the embedding vectors.
        """
        if await self.client.indices.exists(index=index_name):
            logger.debug(f"Index {index_name} already exists. Skipping creation.")
        else:
            if dims_length is None:
                raise ValueError(
                    "Cannot create index without specifying dims_length "
                    "when the index doesn't already exist. We infer "
                    "dims_length from the first embedding. Check that "
                    "you have provided an embedding function."
                )

            if self.distance_strategy == "COSINE":
                similarityAlgo = "cosine"
            elif self.distance_strategy == "EUCLIDEAN_DISTANCE":
                similarityAlgo = "l2_norm"
            elif self.distance_strategy == "DOT_PRODUCT":
                similarityAlgo = "dot_product"
            else:
                raise ValueError(f"Similarity {self.distance_strategy} not supported.")

            index_settings = {
                "mappings": {
                    "properties": {
                        self.vector_field: {
                            "type": "dense_vector",
                            "dims": dims_length,
                            "index": True,
                            "similarity": similarityAlgo,
                        },
                        self.text_field: {"type": "text"},
                        "metadata": {
                            "properties": {
                                "document_id": {"type": "keyword"},
                                "doc_id": {"type": "keyword"},
                                "ref_doc_id": {"type": "keyword"},
                            }
                        },
                    }
                }
            }

            logger.debug(
                f"Creating index {index_name} with mappings {index_settings['mappings']}"
            )
            await self.client.indices.create(index=index_name, **index_settings)

    def add(
        self,
        nodes: List[BaseNode],
        *,
        create_index_if_not_exists: bool = True,
        **add_kwargs: Any,
    ) -> List[str]:
        """Add nodes to the Elasticsearch index.

        Args:
            nodes: List of nodes with embeddings.
            create_index_if_not_exists: Optional. Whether to create the
                Elasticsearch index if it doesn't already exist.
                Defaults to True.

        Returns:
            List of node IDs that were added to the index.

        Raises:
            ImportError: If the elasticsearch[async] python package is not installed.
            BulkIndexError: If async_bulk indexing fails.
        """
        return asyncio.get_event_loop().run_until_complete(
            self.async_add(nodes, create_index_if_not_exists=create_index_if_not_exists)
        )
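
    # Note: the sync wrappers in this class (add/delete/query) reuse the running
    # event loop via run_until_complete; nest_asyncio.apply() in __init__ is
    # what makes this safe when a loop is already running (e.g. in a notebook).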

    async def async_add(
        self,
        nodes: List[BaseNode],
        *,
        create_index_if_not_exists: bool = True,
        **add_kwargs: Any,
    ) -> List[str]:
        """Asynchronously add nodes to the Elasticsearch index.

        Args:
            nodes: List of nodes with embeddings.
            create_index_if_not_exists: Optional. Whether to create the
                Elasticsearch index if it doesn't already exist.
                Defaults to True.

        Returns:
            List of node IDs that were added to the index.

        Raises:
            ImportError: If the elasticsearch[async] python package is not installed.
            BulkIndexError: If async_bulk indexing fails.
        """
        try:
            from elasticsearch.helpers import BulkIndexError, async_bulk
        except ImportError:
            raise ImportError(
                "Could not import elasticsearch[async] python package. "
                "Please install it with `pip install 'elasticsearch[async]'`."
            )

        if len(nodes) == 0:
            return []

        if create_index_if_not_exists:
            dims_length = len(nodes[0].get_embedding())
            await self._create_index_if_not_exists(
                index_name=self.index_name, dims_length=dims_length
            )

        embeddings: List[List[float]] = []
        texts: List[str] = []
        metadatas: List[dict] = []
        ids: List[str] = []
        for node in nodes:
            ids.append(node.node_id)
            embeddings.append(node.get_embedding())
            texts.append(node.get_content(metadata_mode=MetadataMode.NONE))
            metadatas.append(node_to_metadata_dict(node, remove_text=True))

        requests = []
        return_ids = []

        for i, text in enumerate(texts):
            metadata = metadatas[i] if metadatas else {}
            _id = ids[i] if ids else str(uuid.uuid4())
            request = {
                "_op_type": "index",
                "_index": self.index_name,
                self.vector_field: embeddings[i],
                self.text_field: text,
                "metadata": metadata,
                "_id": _id,
            }
            requests.append(request)
            return_ids.append(_id)

        try:
            success, failed = await async_bulk(
                self.client,
                requests,
                chunk_size=self.batch_size,
                stats_only=True,
                refresh=True,
            )
            logger.debug(f"Added {success} and failed to add {failed} texts to index")

            logger.debug(f"added texts {ids} to index")
            return return_ids
        except BulkIndexError as e:
            logger.error(f"Error adding texts: {e}")
            firstError = e.errors[0].get("index", {}).get("error", {})
            logger.error(f"First error reason: {firstError.get('reason')}")
            raise

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Delete a node from the Elasticsearch index.

        Args:
            ref_doc_id: ID of the node to delete.
            delete_kwargs: Optional. Additional arguments to
                pass to Elasticsearch delete_by_query.

        Raises:
            Exception: If Elasticsearch delete_by_query fails.
        """
        return asyncio.get_event_loop().run_until_complete(
            self.adelete(ref_doc_id, **delete_kwargs)
        )

    async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Asynchronously delete a node from the Elasticsearch index.

        Args:
            ref_doc_id: ID of the node to delete.
            delete_kwargs: Optional. Additional arguments to
                pass to AsyncElasticsearch delete_by_query.

        Raises:
            Exception: If AsyncElasticsearch delete_by_query fails.
        """
        try:
            async with self.client as client:
                res = await client.delete_by_query(
                    index=self.index_name,
                    query={"term": {"metadata.ref_doc_id": ref_doc_id}},
                    refresh=True,
                    **delete_kwargs,
                )
            if res["deleted"] == 0:
                logger.warning(f"Could not find text {ref_doc_id} to delete")
            else:
                logger.debug(f"Deleted text {ref_doc_id} from index")
        except Exception:
            logger.error(f"Error deleting text: {ref_doc_id}")
            raise

    def query(
        self,
        query: VectorStoreQuery,
        custom_query: Optional[
            Callable[[Dict, Union[VectorStoreQuery, None]], Dict]
        ] = None,
        es_filter: Optional[List[Dict]] = None,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Query the index for the top k most similar nodes.

        Args:
            query (VectorStoreQuery): query object containing the embedding,
                mode, and filters.
            custom_query: Optional. Custom query function that takes in the es
                query body and returns a modified query body.
                This can be used to add additional query
                parameters to the Elasticsearch query.
            es_filter: Optional. Elasticsearch filter to apply to the
                query. If a filter is provided in the query object,
                this filter will be ignored.

        Returns:
            VectorStoreQueryResult: Result of the query.

        Raises:
            Exception: If the Elasticsearch query fails.
        """
        return asyncio.get_event_loop().run_until_complete(
            self.aquery(query, custom_query, es_filter, **kwargs)
        )
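
    # A minimal sketch of a `custom_query` hook (hypothetical example;
    # `min_score` is a standard Elasticsearch search option, not something this
    # class sets by default):
    #
    #     def my_custom_query(es_query: Dict, query: Optional[VectorStoreQuery]) -> Dict:
    #         es_query["min_score"] = 0.5  # drop weak hits
    #         return es_query
    #
    #     store.query(vs_query, custom_query=my_custom_query)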

    async def aquery(
        self,
        query: VectorStoreQuery,
        custom_query: Optional[
            Callable[[Dict, Union[VectorStoreQuery, None]], Dict]
        ] = None,
        es_filter: Optional[List[Dict]] = None,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Asynchronously query the index for the top k most similar nodes.

        Args:
            query (VectorStoreQuery): query object containing the embedding,
                mode, and filters.
            custom_query: Optional. Custom query function that takes in the es
                query body and returns a modified query body.
                This can be used to add additional query
                parameters to the AsyncElasticsearch query.
            es_filter: Optional. AsyncElasticsearch filter to apply to the
                query. If a filter is provided in the query object,
                this filter will be ignored.

        Returns:
            VectorStoreQueryResult: Result of the query.

        Raises:
            Exception: If the AsyncElasticsearch query fails.
        """
        query_embedding = cast(List[float], query.query_embedding)

        es_query = {}

        if query.filters is not None and len(query.filters.legacy_filters()) > 0:
            filter = [_to_elasticsearch_filter(query.filters)]
        else:
            filter = es_filter or []

        if query.mode in (
            VectorStoreQueryMode.DEFAULT,
            VectorStoreQueryMode.HYBRID,
        ):
            es_query["knn"] = {
                "filter": filter,
                "field": self.vector_field,
                "query_vector": query_embedding,
                "k": query.similarity_top_k,
                "num_candidates": query.similarity_top_k * 10,
            }

        if query.mode in (
            VectorStoreQueryMode.TEXT_SEARCH,
            VectorStoreQueryMode.HYBRID,
        ):
            es_query["query"] = {
                "bool": {
                    "must": {"match": {self.text_field: {"query": query.query_str}}},
                    "filter": filter,
                }
            }

        if query.mode == VectorStoreQueryMode.HYBRID:
            # Combine the knn and text clauses with reciprocal rank fusion.
            es_query["rank"] = {"rrf": {}}

        if custom_query is not None:
            es_query = custom_query(es_query, query)
            logger.debug(f"Calling custom_query, Query body now: {es_query}")

        async with self.client as client:
            response = await client.search(
                index=self.index_name,
                **es_query,
                size=query.similarity_top_k,
                _source={"excludes": [self.vector_field]},
            )

        top_k_nodes = []
        top_k_ids = []
        top_k_scores = []
        hits = response["hits"]["hits"]
        for hit in hits:
            source = hit["_source"]
            metadata = source.get("metadata", None)
            text = source.get(self.text_field, None)
            node_id = hit["_id"]

            try:
                node = metadata_dict_to_node(metadata)
                node.text = text
            except Exception:
                # Legacy support for old metadata format
                logger.warning(
                    f"Could not parse metadata from hit {hit['_source']['metadata']}"
                )
                node_info = source.get("node_info")
                relationships = source.get("relationships")
                start_char_idx = None
                end_char_idx = None
                if isinstance(node_info, dict):
                    start_char_idx = node_info.get("start", None)
                    end_char_idx = node_info.get("end", None)

                node = TextNode(
                    text=text,
                    metadata=metadata,
                    id_=node_id,
                    start_char_idx=start_char_idx,
                    end_char_idx=end_char_idx,
                    relationships=relationships,
                )
            top_k_nodes.append(node)
            top_k_ids.append(node_id)
            # RRF responses carry a rank rather than a score; prefer the rank
            # when present and fall back to _score otherwise.
            top_k_scores.append(hit.get("_rank", hit["_score"]))

        if query.mode == VectorStoreQueryMode.HYBRID:
            # Convert ranks so that lower (better) ranks map to higher scores.
            total_rank = sum(top_k_scores)
            top_k_scores = [total_rank - rank / total_rank for rank in top_k_scores]

        return VectorStoreQueryResult(
            nodes=top_k_nodes,
            ids=top_k_ids,
            similarities=_to_llama_similarities(top_k_scores),
        )
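
# A minimal end-to-end sketch of using ElasticsearchStore (illustrative only:
# the URL, index name, embedding dimension, and node contents are placeholder
# assumptions, and a reachable Elasticsearch cluster is required):
#
#     store = ElasticsearchStore(
#         index_name="my_index",
#         es_url="http://localhost:9200",
#     )
#     nodes = [TextNode(text="hello world", embedding=[0.1] * 1536)]
#     store.add(nodes)  # index mapping is created here, dims inferred as 1536
#     results = store.query(
#         VectorStoreQuery(query_embedding=[0.1] * 1536, similarity_top_k=1)
#     )
#     print(results.ids, results.similarities)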