faiss_rag_enterprise/llama_index/vector_stores/bagel.py

184 lines
5.2 KiB
Python

import logging
import math
from typing import Any, List
from llama_index.schema import BaseNode, MetadataMode, TextNode
from llama_index.vector_stores.types import (
MetadataFilters,
VectorStore,
VectorStoreQuery,
VectorStoreQueryResult,
)
from llama_index.vector_stores.utils import (
legacy_metadata_dict_to_node,
metadata_dict_to_node,
node_to_metadata_dict,
)
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
def _to_bagel_filter(standard_filters: MetadataFilters) -> dict:
    """
    Convert standard metadata filters into a Bagel filter mapping.

    Every entry yielded by ``standard_filters.legacy_filters()`` contributes
    one ``key -> value`` pair. If the same key appears more than once, the
    value of the last occurrence is kept.
    """
    return {
        legacy.key: legacy.value for legacy in standard_filters.legacy_filters()
    }
class BagelVectorStore(VectorStore):
    """
    Vector store for Bagel.

    Thin adapter around a Bagel ``Cluster``: nodes are written with their
    ids, embeddings, metadata and text, and queries are answered through the
    cluster's ``find`` call.
    """

    # support for Bagel specific parameters
    # Node text is written to the cluster as ``documents`` (see ``add``).
    stores_text: bool = True
    # Forwarded to ``node_to_metadata_dict`` when serializing node metadata.
    flat_metadata: bool = True

    def __init__(self, collection: Any, **kwargs: Any) -> None:
        """
        Initialize BagelVectorStore.

        Args:
            collection: Bagel collection (must be a bagel ``Cluster``).
            **kwargs: Additional arguments; accepted but not used here.

        Raises:
            ImportError: If the ``bagel`` package cannot be imported.
            ValueError: If ``collection`` is not a bagel ``Cluster``.
        """
        # Imported lazily so the module can be loaded without bagel installed.
        try:
            from bagel.api.Cluster import Cluster
        except ImportError as err:
            raise ImportError(
                "Bagel is not installed. Please install bagel."
            ) from err

        if not isinstance(collection, Cluster):
            raise ValueError("Collection must be a bagel Cluster.")

        self._collection = collection

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """
        Add a list of nodes with embeddings to the vector store.

        Args:
            nodes: List of nodes with embeddings.
            add_kwargs: Additional arguments; accepted but not used here.

        Returns:
            List of the node ids that were written, in input order.

        Raises:
            ValueError: If no collection is set.
        """
        # Explicit None check: the guard must not depend on how the cluster
        # object itself evaluates in a boolean context.
        if self._collection is None:
            raise ValueError("collection not set")

        # Nothing to write; avoid a backend call with empty lists.
        if not nodes:
            return []

        ids = []
        embeddings = []
        metadatas = []
        documents = []
        for node in nodes:
            ids.append(node.node_id)
            embeddings.append(node.get_embedding())
            # Text is stored separately in ``documents``, so it is stripped
            # from the metadata payload.
            metadatas.append(
                node_to_metadata_dict(
                    node,
                    remove_text=True,
                    flat_metadata=self.flat_metadata,
                )
            )
            documents.append(node.get_content(metadata_mode=MetadataMode.NONE) or "")

        self._collection.add(
            ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents
        )
        return ids

    def delete(self, ref_doc_id: str, **kwargs: Any) -> None:
        """
        Delete a document from the vector store.

        Looks up every entry whose ``doc_id`` metadata equals ``ref_doc_id``
        and deletes the matching ids.

        Args:
            ref_doc_id: Reference document id.
            kwargs: Additional arguments; accepted but not used here.

        Raises:
            ValueError: If no collection is set.
        """
        if self._collection is None:
            raise ValueError("collection not set")

        results = self._collection.get(where={"doc_id": ref_doc_id})
        # Only call delete when the lookup actually returned ids, so that a
        # delete with an empty id list is never sent to the backend.
        if results and "ids" in results and results["ids"]:
            self._collection.delete(ids=results["ids"])

    @property
    def client(self) -> Any:
        """
        Get the Bagel cluster.
        """
        return self._collection

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """
        Query the vector store.

        Args:
            query: Query to run. Its ``filters``, ``query_embedding`` and
                ``similarity_top_k`` fields are used.
            kwargs: Additional arguments forwarded to the cluster's ``find``
                call. A ``where`` entry is used as the metadata filter when
                ``query.filters`` is not set.

        Returns:
            Query result with the matched nodes, their similarity scores and
            their ids.

        Raises:
            ValueError: If no collection is set, or if both ``query.filters``
                and a ``where`` keyword argument are given.
        """
        if self._collection is None:
            raise ValueError("collection not set")

        if query.filters is not None:
            if "where" in kwargs:
                raise ValueError("Cannot specify both filters and where")
            where = _to_bagel_filter(query.filters)
        else:
            # pop (not get): ``kwargs`` is expanded into ``find`` below, and
            # leaving ``where`` in it would pass the keyword twice and raise
            # TypeError.
            where = kwargs.pop("where", {})

        results = self._collection.find(
            query_embeddings=query.query_embedding,
            where=where,
            n_results=query.similarity_top_k,
            **kwargs,
        )

        logger.debug("query results: %s", results)

        nodes = []
        similarities = []
        ids = []

        # Each result field is indexed with [0]: only the first (and only)
        # query's hits are read.
        for node_id, text, metadata, distance in zip(
            results["ids"][0],
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        ):
            try:
                node = metadata_dict_to_node(metadata)
                node.set_content(text)
            except Exception:
                # NOTE: deprecated legacy logic for backward compatibility
                metadata, node_info, relationships = legacy_metadata_dict_to_node(
                    metadata
                )
                node = TextNode(
                    text=text,
                    id_=node_id,
                    metadata=metadata,
                    start_char_idx=node_info.get("start", None),
                    end_char_idx=node_info.get("end", None),
                    relationships=relationships,
                )

            # NOTE(review): this score grows with distance; confirm it matches
            # the distance metric Bagel returns before changing it.
            similarity = 1.0 - math.exp(-distance)

            nodes.append(node)
            similarities.append(similarity)
            ids.append(node_id)

            logger.debug("node: %s", node)
            logger.debug("similarity: %s", similarity)
            logger.debug("id: %s", node_id)

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)