faiss_rag_enterprise/llama_index/vector_stores/lancedb.py

225 lines
7.2 KiB
Python

"""LanceDB vector store."""
import logging
from typing import Any, List, Optional
import numpy as np
from pandas import DataFrame
from llama_index.schema import (
BaseNode,
MetadataMode,
NodeRelationship,
RelatedNodeInfo,
TextNode,
)
from llama_index.vector_stores.types import (
MetadataFilters,
VectorStore,
VectorStoreQuery,
VectorStoreQueryResult,
)
from llama_index.vector_stores.utils import (
DEFAULT_DOC_ID_KEY,
DEFAULT_TEXT_KEY,
legacy_metadata_dict_to_node,
metadata_dict_to_node,
node_to_metadata_dict,
)
# Module-level logger; LanceDBVectorStore.query() uses it to report when a row
# cannot be parsed as a Node and the legacy parsing path is taken instead.
_logger: logging.Logger = logging.getLogger(__name__)
def _to_lance_filter(standard_filters: MetadataFilters) -> Any:
    """Translate standard metadata filters to a Lance ``where`` predicate.

    Every legacy (exact-match) filter becomes a ``key = value`` clause and the
    clauses are joined with ``AND``. String values are wrapped in double
    quotes; any double quote inside the value is doubled so it cannot close the
    quoted literal early and alter the predicate. Non-string values are
    rendered with ``str()``.

    NOTE(review): ``LanceDBVectorStore.add`` stores node metadata inside a
    nested ``metadata`` column, while the keys here are emitted as top-level
    column names -- confirm that filter keys resolve against the stored schema.

    Args:
        standard_filters: Generic llama_index metadata filters.

    Returns:
        The predicate string, or ``None`` when there are no filters, so that
        ``.where(...)`` applies no filtering instead of receiving ``""``.
    """
    clauses = []
    for metadata_filter in standard_filters.legacy_filters():
        key = metadata_filter.key
        value = metadata_filter.value
        if isinstance(value, str):
            # Standard SQL escaping: a quote inside a quoted literal is doubled.
            escaped = value.replace('"', '""')
            clauses.append(key + ' = "' + escaped + '"')
        else:
            clauses.append(key + " = " + str(value))
    return " AND ".join(clauses) if clauses else None
def _to_llama_similarities(results: DataFrame) -> List[float]:
keys = results.keys()
normalized_similarities: np.ndarray
if "score" in keys:
normalized_similarities = np.exp(results["score"] - np.max(results["score"]))
elif "_distance" in keys:
normalized_similarities = np.exp(-results["_distance"])
else:
normalized_similarities = np.linspace(1, 0, len(results))
return normalized_similarities.tolist()
class LanceDBVectorStore(VectorStore):
    """
    The LanceDB Vector Store.

    Stores text and embeddings in LanceDB. The vector store will open an existing
    LanceDB dataset or create the dataset if it does not exist.

    Args:
        uri (str, required): Location where LanceDB will store its files.
        table_name (str, optional): The table name where the embeddings will be stored.
            Defaults to "vectors".
        vector_column_name (str, optional): The vector column name in the table if
            different from default. Used both when writing embeddings in ``add``
            and when searching/reading them in ``query``.
            Defaults to "vector", in keeping with lancedb convention.
        nprobes (int, optional): The number of probes used.
            A higher number makes search more accurate but also slower.
            Defaults to 20.
        refine_factor (int, optional): Refine the results by reading extra elements
            and re-ranking them in memory.
            Defaults to None.
        text_key (str, optional): Column read for the node text when ``query``
            falls back to legacy row parsing. Defaults to ``DEFAULT_TEXT_KEY``.
        doc_id_key (str, optional): Column read for the source document id when
            ``query`` falls back to legacy row parsing.
            Defaults to ``DEFAULT_DOC_ID_KEY``.

    Raises:
        ImportError: Unable to import `lancedb`.

    Returns:
        LanceDBVectorStore: VectorStore that supports creating LanceDB datasets and
        querying it.
    """

    stores_text = True
    flat_metadata: bool = True

    def __init__(
        self,
        uri: str,
        table_name: str = "vectors",
        vector_column_name: str = "vector",
        nprobes: int = 20,
        refine_factor: Optional[int] = None,
        text_key: str = DEFAULT_TEXT_KEY,
        doc_id_key: str = DEFAULT_DOC_ID_KEY,
        **kwargs: Any,
    ) -> None:
        """Open a LanceDB connection at ``uri`` and record the store settings.

        Extra ``kwargs`` are accepted but not used.

        Raises:
            ImportError: If the ``lancedb`` package is not installed.
        """
        import_err_msg = "`lancedb` package not found, please run `pip install lancedb`"
        try:
            import lancedb
        except ImportError as err:
            # Chain the original error so the underlying cause stays visible.
            raise ImportError(import_err_msg) from err

        self.connection = lancedb.connect(uri)
        self.uri = uri
        self.table_name = table_name
        self.vector_column_name = vector_column_name
        self.nprobes = nprobes
        self.text_key = text_key
        self.doc_id_key = doc_id_key
        self.refine_factor = refine_factor

    @property
    def client(self) -> Any:
        """Get client: the LanceDB connection opened in ``__init__``."""
        return self.connection

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        """Write nodes to the table, creating the table if it does not exist.

        Each node becomes one row with its id, source document id, embedding
        (stored under ``self.vector_column_name`` so that ``query`` searches
        the same column), raw text, and flattened metadata.

        Args:
            nodes: Nodes to store.

        Returns:
            The node ids, in input order. An empty ``nodes`` list returns
            ``[]`` without touching the database.
        """
        if not nodes:
            # Nothing to write. This also avoids calling create_table with zero
            # rows, which presumably gives LanceDB nothing to derive a schema from.
            return []

        data = []
        ids = []
        for node in nodes:
            metadata = node_to_metadata_dict(
                node, remove_text=False, flat_metadata=self.flat_metadata
            )
            data.append(
                {
                    "id": node.node_id,
                    "doc_id": node.ref_doc_id,
                    self.vector_column_name: node.get_embedding(),
                    "text": node.get_content(metadata_mode=MetadataMode.NONE),
                    "metadata": metadata,
                }
            )
            ids.append(node.node_id)

        if self.table_name in self.connection.table_names():
            self.connection.open_table(self.table_name).add(data)
        else:
            self.connection.create_table(self.table_name, data)
        return ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete all nodes whose source document id equals ``ref_doc_id``.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.
        """
        table = self.connection.open_table(self.table_name)
        # ``add`` stores the source document id in the "doc_id" column, so the
        # predicate must target that column. Embedded double quotes are doubled
        # so the id cannot terminate the quoted literal early.
        escaped_id = ref_doc_id.replace('"', '""')
        table.delete('doc_id = "' + escaped_id + '"')

    def query(
        self,
        query: VectorStoreQuery,
        **kwargs: Any,
    ) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        A filter may be given either through ``query.filters`` or as a raw
        LanceDB ``where`` string in ``kwargs``, but not both.

        Raises:
            ValueError: If both ``query.filters`` and ``kwargs["where"]`` are set.
        """
        if query.filters is not None:
            if "where" in kwargs:
                raise ValueError(
                    "Cannot specify filter via both query and kwargs. "
                    "Use kwargs only for lancedb specific items that are "
                    "not supported via the generic query interface."
                )
            # An empty filter set yields no predicate; pass None rather than "".
            where = _to_lance_filter(query.filters) or None
        else:
            where = kwargs.pop("where", None)

        table = self.connection.open_table(self.table_name)
        lance_query = (
            table.search(
                query=query.query_embedding,
                vector_column_name=self.vector_column_name,
            )
            .limit(query.similarity_top_k)
            .where(where)
            .nprobes(self.nprobes)
        )
        if self.refine_factor is not None:
            lance_query.refine_factor(self.refine_factor)

        results = lance_query.to_pandas()
        nodes = []
        for _, item in results.iterrows():
            try:
                node = metadata_dict_to_node(item.metadata)
                node.embedding = list(item[self.vector_column_name])
            except Exception:
                # deprecated legacy logic for backward compatibility
                _logger.debug(
                    "Failed to parse Node metadata, fallback to legacy logic."
                )
                if "metadata" in item:
                    metadata, node_info, _relation = legacy_metadata_dict_to_node(
                        item.metadata, text_key=self.text_key
                    )
                else:
                    metadata, node_info = {}, {}
                node = TextNode(
                    text=item[self.text_key] or "",
                    id_=item.id,
                    metadata=metadata,
                    start_char_idx=node_info.get("start", None),
                    end_char_idx=node_info.get("end", None),
                    relationships={
                        NodeRelationship.SOURCE: RelatedNodeInfo(
                            node_id=item[self.doc_id_key]
                        ),
                    },
                )

            nodes.append(node)

        return VectorStoreQueryResult(
            nodes=nodes,
            similarities=_to_llama_similarities(results),
            ids=results["id"].tolist(),
        )