faiss_rag_enterprise/llama_index/vector_stores/dynamodb.py

149 lines
5.1 KiB
Python

"""DynamoDB vector store index."""
from __future__ import annotations
from logging import getLogger
from typing import Any, Dict, List, cast
from llama_index.indices.query.embedding_utils import (
get_top_k_embeddings,
get_top_k_embeddings_learner,
)
from llama_index.schema import BaseNode
from llama_index.storage.kvstore.dynamodb_kvstore import DynamoDBKVStore
from llama_index.vector_stores.types import (
VectorStore,
VectorStoreQuery,
VectorStoreQueryMode,
VectorStoreQueryResult,
)
# Module-level logger, named after this module.
logger = getLogger(__name__)
# Namespace used to build the collection names when the caller supplies none.
DEFAULT_NAMESPACE = "vector_store"
# Query modes that DynamoDBVectorStore.query hands to
# get_top_k_embeddings_learner instead of get_top_k_embeddings.
LEARNER_MODES = {
VectorStoreQueryMode.SVM,
VectorStoreQueryMode.LINEAR_REGRESSION,
VectorStoreQueryMode.LOGISTIC_REGRESSION,
}
class DynamoDBVectorStore(VectorStore):
    """DynamoDB Vector Store.

    In this vector store, embeddings are stored within a DynamoDB table.
    This class was implemented with reference to SimpleVectorStore.

    Every node is written as two key-value records, both keyed by the node
    id: one in the ``{namespace}/embedding`` collection holding the embedding
    and one in the ``{namespace}/text_id_to_doc_id`` collection holding the id
    of the node's source document.

    Args:
        dynamodb_kvstore (DynamoDBKVStore): data store
        namespace (Optional[str]): namespace; ``DEFAULT_NAMESPACE`` is used
            when it is omitted or empty.

    """

    stores_text: bool = False

    def __init__(
        self, dynamodb_kvstore: DynamoDBKVStore, namespace: str | None = None
    ) -> None:
        """Initialize params."""
        self._kvstore = dynamodb_kvstore
        namespace = namespace or DEFAULT_NAMESPACE
        # Two collections per namespace: embeddings, and node id -> doc id.
        self._collection_embedding = f"{namespace}/embedding"
        self._collection_text_id_to_doc_id = f"{namespace}/text_id_to_doc_id"
        # Field name under which the payload of every record is stored.
        self._key_value = "value"

    @classmethod
    def from_table_name(
        cls, table_name: str, namespace: str | None = None
    ) -> DynamoDBVectorStore:
        """Load from DynamoDB table name.

        Args:
            table_name (str): name passed to ``DynamoDBKVStore.from_table_name``.
            namespace (Optional[str]): namespace forwarded to the constructor.

        Returns:
            DynamoDBVectorStore: a store backed by the given table.

        """
        dynamodb_kvstore = DynamoDBKVStore.from_table_name(table_name=table_name)
        return cls(dynamodb_kvstore=dynamodb_kvstore, namespace=namespace)

    @property
    def client(self) -> None:
        """Get client.

        This store does not expose a client object; the value is always None.
        """
        return None

    def get(self, text_id: str) -> List[float]:
        """Get embedding.

        Args:
            text_id (str): id of the node whose embedding is requested.

        Returns:
            List[float]: the embedding stored for ``text_id``.

        Raises:
            KeyError: if the key-value store returns no item for ``text_id``.

        """
        item = self._kvstore.get(key=text_id, collection=self._collection_embedding)
        if item is None:
            # Fail with a descriptive error instead of an opaque TypeError
            # from subscripting None.
            raise KeyError(
                f"No embedding found for text_id {text_id!r} "
                f"in collection {self._collection_embedding!r}"
            )
        item = cast(Dict[str, List[float]], item)
        return item[self._key_value]

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Add nodes to index.

        Args:
            nodes (List[BaseNode]): nodes to store; each one must be able to
                provide an embedding through ``get_embedding()``.
            **add_kwargs: accepted for interface compatibility; unused.

        Returns:
            List[str]: ids of the stored nodes, in input order.

        """
        response = []
        for node in nodes:
            node_id = node.node_id
            self._kvstore.put(
                key=node_id,
                val={self._key_value: node.get_embedding()},
                collection=self._collection_embedding,
            )
            # Remember the source document so delete() can find this node.
            self._kvstore.put(
                key=node_id,
                val={self._key_value: node.ref_doc_id},
                collection=self._collection_text_id_to_doc_id,
            )
            response.append(node_id)
        return response

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Delete nodes using with ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.
            **delete_kwargs: accepted for interface compatibility; unused.

        """
        doc_ids_by_text_id = self._kvstore.get_all(
            collection=self._collection_text_id_to_doc_id
        )
        # Collect the matching node ids first, then remove both records of
        # each node.
        text_ids_to_delete = [
            text_id
            for text_id, item in doc_ids_by_text_id.items()
            if item[self._key_value] == ref_doc_id
        ]
        for text_id in text_ids_to_delete:
            self._kvstore.delete(key=text_id, collection=self._collection_embedding)
            self._kvstore.delete(
                key=text_id, collection=self._collection_text_id_to_doc_id
            )

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Get nodes for response.

        Loads every embedding of the namespace, optionally restricts them to
        ``query.node_ids``, and ranks them against ``query.query_embedding``.

        Args:
            query (VectorStoreQuery): query to run.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            VectorStoreQueryResult: similarities and ids of the top matches.

        Raises:
            ValueError: if ``query.filters`` is set or ``query.mode`` is
                neither the default mode nor one of ``LEARNER_MODES``.

        """
        if query.filters is not None:
            raise ValueError(
                "Metadata filters not implemented for DynamoDBVectorStore yet."
            )

        # TODO: consolidate with get_query_text_embedding_similarities
        stored = self._kvstore.get_all(collection=self._collection_embedding)
        # An empty or missing node_ids means "no restriction".
        allowed_ids = set(query.node_ids) if query.node_ids else None

        # Single pass keeps node_ids and embeddings aligned by position.
        node_ids: List[str] = []
        embeddings: List[List[float]] = []
        for node_id, item in stored.items():
            if allowed_ids is None or node_id in allowed_ids:
                node_ids.append(node_id)
                embeddings.append(item[self._key_value])

        query_embedding = cast(List[float], query.query_embedding)

        if query.mode in LEARNER_MODES:
            top_similarities, top_ids = get_top_k_embeddings_learner(
                query_embedding=query_embedding,
                embeddings=embeddings,
                similarity_top_k=query.similarity_top_k,
                embedding_ids=node_ids,
            )
        elif query.mode == VectorStoreQueryMode.DEFAULT:
            top_similarities, top_ids = get_top_k_embeddings(
                query_embedding=query_embedding,
                embeddings=embeddings,
                similarity_top_k=query.similarity_top_k,
                embedding_ids=node_ids,
            )
        else:
            raise ValueError(f"Invalid query mode: {query.mode}")

        return VectorStoreQueryResult(similarities=top_similarities, ids=top_ids)