"""Cassandra / Astra DB Vector store index.

An index based on a DB table with vector search capabilities,
powered by the cassIO library.
"""
|
|
|
|
import logging
|
|
from typing import Any, Dict, Iterable, List, Optional, TypeVar, cast
|
|
|
|
from llama_index.indices.query.embedding_utils import (
|
|
get_top_k_mmr_embeddings,
|
|
)
|
|
from llama_index.schema import BaseNode, MetadataMode
|
|
from llama_index.vector_stores.types import (
|
|
ExactMatchFilter,
|
|
MetadataFilters,
|
|
VectorStore,
|
|
VectorStoreQuery,
|
|
VectorStoreQueryMode,
|
|
VectorStoreQueryResult,
|
|
)
|
|
from llama_index.vector_stores.utils import (
|
|
metadata_dict_to_node,
|
|
node_to_metadata_dict,
|
|
)
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
DEFAULT_MMR_PREFETCH_FACTOR = 4.0
|
|
DEFAULT_INSERTION_BATCH_SIZE = 20
|
|
|
|
T = TypeVar("T")
|
|
|
|
|
|
def _batch_iterable(iterable: Iterable[T], batch_size: int) -> Iterable[Iterable[T]]:
|
|
this_batch = []
|
|
for entry in iterable:
|
|
this_batch.append(entry)
|
|
if len(this_batch) == batch_size:
|
|
yield this_batch
|
|
this_batch = []
|
|
if this_batch:
|
|
yield this_batch
|
|
|
|
|
|
class CassandraVectorStore(VectorStore):
    """
    Cassandra Vector Store.

    An abstraction of a Cassandra table with
    vector-similarity-search. Documents, and their embeddings, are stored
    in a Cassandra table and a vector-capable index is used for searches.
    The table does not need to exist beforehand: if necessary it will
    be created behind the scenes.

    All Cassandra operations are done through the CassIO library.

    Note: in recent versions, only `table` and `embedding_dimension` can be
    passed positionally. Please revise your code if needed.
    This is to accommodate for a leaner usage, whereby the DB connection
    is set globally through a `cassio.init(...)` call: then, the DB details
    are not to be specified anymore when creating a vector store, unless
    desired.

    Args:
        table (str): table name to use. If not existing, it will be created.
        embedding_dimension (int): length of the embedding vectors in use.
        session (optional, cassandra.cluster.Session): the Cassandra session
            to use.
            Can be omitted, or equivalently set to None, to use the
            DB connection set globally through cassio.init() beforehand.
        keyspace (optional, str): name of the Cassandra keyspace to work in
            Can be omitted, or equivalently set to None, to use the
            DB connection set globally through cassio.init() beforehand.
        ttl_seconds (optional, int): expiration time for inserted entries.
            Default is no expiration (None).
        insertion_batch_size (optional, int): how many vectors are inserted
            concurrently, for use by bulk inserts. Defaults to 20.
    """

    stores_text: bool = True
    flat_metadata: bool = True

    def __init__(
        self,
        table: str,
        embedding_dimension: int,
        *,
        session: Optional[Any] = None,
        keyspace: Optional[str] = None,
        ttl_seconds: Optional[int] = None,
        insertion_batch_size: int = DEFAULT_INSERTION_BATCH_SIZE,
    ) -> None:
        import_err_msg = (
            "`cassio` package not found, please run `pip install --upgrade cassio`"
        )
        try:
            from cassio.table import ClusteredMetadataVectorCassandraTable
        except ImportError:
            raise ImportError(import_err_msg)

        self._session = session
        self._keyspace = keyspace
        self._table = table
        self._embedding_dimension = embedding_dimension
        self._ttl_seconds = ttl_seconds
        self._insertion_batch_size = insertion_batch_size

        _logger.debug("Creating the Cassandra table")
        # The table is clustered by ref_doc_id (the partition key) so that
        # all nodes of one document live in the same partition and can be
        # deleted together (see `delete`).
        self.vector_table = ClusteredMetadataVectorCassandraTable(
            session=self._session,
            keyspace=self._keyspace,
            table=self._table,
            vector_dimension=self._embedding_dimension,
            primary_key_type=["TEXT", "TEXT"],
            # a conservative choice here, to make everything searchable
            # except the bulky "_node_content" key (it'd make little sense to):
            metadata_indexing=("default_to_searchable", ["_node_content"]),
        )

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        """Add nodes to index.

        Args:
            nodes: List[BaseNode]: list of node with embeddings

        Returns:
            List[str]: the node_ids of the inserted nodes, in input order.
        """
        node_ids = []
        node_contents = []
        node_metadatas = []
        node_embeddings = []
        for node in nodes:
            # Text is stored in its own column (`body_blob`), so it is
            # stripped from the metadata dict before serialization.
            metadata = node_to_metadata_dict(
                node,
                remove_text=True,
                flat_metadata=self.flat_metadata,
            )
            node_ids.append(node.node_id)
            node_contents.append(node.get_content(metadata_mode=MetadataMode.NONE))
            node_metadatas.append(metadata)
            node_embeddings.append(node.get_embedding())

        _logger.debug(f"Adding {len(node_ids)} rows to table")
        # Concurrent batching of inserts: launch up to
        # `insertion_batch_size` async writes, then wait for all of them
        # before starting the next batch.
        insertion_tuples = zip(node_ids, node_contents, node_metadatas, node_embeddings)
        for insertion_batch in _batch_iterable(
            insertion_tuples, batch_size=self._insertion_batch_size
        ):
            futures = []
            for (
                node_id,
                node_content,
                node_metadata,
                node_embedding,
            ) in insertion_batch:
                node_ref_doc_id = node_metadata["ref_doc_id"]
                futures.append(
                    self.vector_table.put_async(
                        row_id=node_id,
                        body_blob=node_content,
                        vector=node_embedding,
                        metadata=node_metadata,
                        partition_id=node_ref_doc_id,
                        ttl_seconds=self._ttl_seconds,
                    )
                )
            # Block until the whole batch has landed (and surface any
            # insertion error) before moving on.
            for future in futures:
                _ = future.result()

        return node_ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes using with ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.

        """
        _logger.debug("Deleting a document from the Cassandra table")
        # ref_doc_id is the partition key, so this removes every node row
        # belonging to that document in one operation.
        self.vector_table.delete_partition(
            partition_id=ref_doc_id,
        )

    @property
    def client(self) -> Any:
        """Return the underlying cassIO vector table object."""
        return self.vector_table

    @staticmethod
    def _query_filters_to_dict(query_filters: MetadataFilters) -> Dict[str, Any]:
        """Convert metadata filters to the flat key/value dict cassIO expects.

        Raises:
            NotImplementedError: if any filter is not an `ExactMatchFilter`.
        """
        # Validate and consume the SAME list: `legacy_filters()` may convert
        # entries, so validating it while building the dict from the raw
        # `query_filters.filters` (as a previous revision did) could let the
        # two disagree.
        legacy_filters = query_filters.legacy_filters()
        if any(not isinstance(f, ExactMatchFilter) for f in legacy_filters):
            raise NotImplementedError("Only `ExactMatchFilter` filters are supported")
        return {f.key: f.value for f in legacy_filters}

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """
        Query index for top k most similar nodes.

        Supported query modes: 'default' (most similar vectors) and 'mmr'.

        Args:
            query (VectorStoreQuery): the basic query definition. Defines:
                mode (VectorStoreQueryMode): one of the supported modes
                query_embedding (List[float]): query embedding to search against
                similarity_top_k (int): top k most similar nodes
                mmr_threshold (Optional[float]): this is the 0-to-1 MMR lambda.
                    If present, takes precedence over the kwargs parameter.
                    Ignored unless for MMR queries.

        Args for query.mode == 'mmr' (ignored otherwise):
            mmr_threshold (Optional[float]): this is the 0-to-1 lambda for MMR.
                Note that in principle mmr_threshold could come in the query
            mmr_prefetch_factor (Optional[float]): factor applied to top_k
                for prefetch pool size. Defaults to 4.0
            mmr_prefetch_k (Optional[int]): prefetch pool size. This cannot be
                passed together with mmr_prefetch_factor

        Raises:
            NotImplementedError: for unsupported query modes.
            ValueError: if both mmr_prefetch_factor and mmr_prefetch_k are given.
        """
        _available_query_modes = [
            VectorStoreQueryMode.DEFAULT,
            VectorStoreQueryMode.MMR,
        ]
        if query.mode not in _available_query_modes:
            raise NotImplementedError(f"Query mode {query.mode} not available.")
        #
        query_embedding = cast(List[float], query.query_embedding)

        # metadata filtering
        if query.filters is not None:
            query_metadata = self._query_filters_to_dict(query.filters)
        else:
            query_metadata = {}

        _logger.debug(
            f"Running ANN search on the Cassandra table (query mode: {query.mode})"
        )
        if query.mode == VectorStoreQueryMode.DEFAULT:
            matches = list(
                self.vector_table.metric_ann_search(
                    vector=query_embedding,
                    n=query.similarity_top_k,
                    metric="cos",
                    metric_threshold=None,
                    metadata=query_metadata,
                )
            )
            top_k_scores = [match["distance"] for match in matches]
        elif query.mode == VectorStoreQueryMode.MMR:
            # Querying a larger number of vectors and then doing MMR on them.
            if (
                kwargs.get("mmr_prefetch_factor") is not None
                and kwargs.get("mmr_prefetch_k") is not None
            ):
                raise ValueError(
                    "'mmr_prefetch_factor' and 'mmr_prefetch_k' "
                    "cannot coexist in a call to query()"
                )
            else:
                # Explicit prefetch_k wins over the multiplicative factor;
                # either way never prefetch fewer than similarity_top_k.
                if kwargs.get("mmr_prefetch_k") is not None:
                    prefetch_k0 = int(kwargs["mmr_prefetch_k"])
                else:
                    prefetch_k0 = int(
                        query.similarity_top_k
                        * kwargs.get("mmr_prefetch_factor", DEFAULT_MMR_PREFETCH_FACTOR)
                    )
            prefetch_k = max(prefetch_k0, query.similarity_top_k)
            #
            prefetch_matches = list(
                self.vector_table.metric_ann_search(
                    vector=query_embedding,
                    n=prefetch_k,
                    metric="cos",
                    metric_threshold=None,  # this is not `mmr_threshold`
                    metadata=query_metadata,
                )
            )
            #
            mmr_threshold = query.mmr_threshold or kwargs.get("mmr_threshold")
            # Pair each prefetched embedding with its positional index, so
            # the MMR-selected indices can be mapped back to full matches.
            if prefetch_matches:
                pf_match_indices, pf_match_embeddings = zip(
                    *enumerate(match["vector"] for match in prefetch_matches)
                )
            else:
                pf_match_indices, pf_match_embeddings = [], []
            pf_match_indices = list(pf_match_indices)
            pf_match_embeddings = list(pf_match_embeddings)
            mmr_similarities, mmr_indices = get_top_k_mmr_embeddings(
                query_embedding,
                pf_match_embeddings,
                similarity_top_k=query.similarity_top_k,
                embedding_ids=pf_match_indices,
                mmr_threshold=mmr_threshold,
            )
            #
            matches = [prefetch_matches[mmr_index] for mmr_index in mmr_indices]
            top_k_scores = mmr_similarities

        # Re-hydrate nodes from the stored metadata and text columns.
        top_k_nodes = []
        top_k_ids = []
        for match in matches:
            node = metadata_dict_to_node(match["metadata"])
            node.set_content(match["body_blob"])
            top_k_nodes.append(node)
            top_k_ids.append(match["row_id"])

        return VectorStoreQueryResult(
            nodes=top_k_nodes,
            similarities=top_k_scores,
            ids=top_k_ids,
        )
|