faiss_rag_enterprise/llama_index/vector_stores/tair.py

"""Tair Vector store index.
An index that is built on top of Alibaba Cloud's Tair database.
"""
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from llama_index.schema import (
BaseNode,
MetadataMode,
NodeRelationship,
RelatedNodeInfo,
TextNode,
)
from llama_index.vector_stores.types import (
MetadataFilters,
VectorStore,
VectorStoreQuery,
VectorStoreQueryResult,
)
from llama_index.vector_stores.utils import node_to_metadata_dict
_logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from tair import Tair


def _to_filter_expr(filters: MetadataFilters) -> str:
    """Convert MetadataFilters into a Tair filter expression string."""
    conditions = []
    for f in filters.legacy_filters():
        value = str(f.value)
        if isinstance(f.value, str):
            value = '"' + value + '"'
        conditions.append(f"{f.key}=={value}")
    return "&&".join(conditions)


class TairVectorStore(VectorStore):
    """Tair vector store.

    Stores node text, embeddings, and metadata in an Alibaba Cloud Tair
    vector index.
    """

    stores_text = True
    stores_node = True
    flat_metadata = False

    def __init__(
        self,
        tair_url: str,
        index_name: str,
        index_type: str = "HNSW",
        index_args: Optional[Dict[str, Any]] = None,
        overwrite: bool = False,
        **kwargs: Any,
    ) -> None:
"""Initialize TairVectorStore.
Two index types are available: FLAT & HNSW.
index args for HNSW:
- ef_construct
- M
- ef_search
Detailed info for these arguments can be found here:
https://www.alibabacloud.com/help/en/tair/latest/tairvector#section-c76-ull-5mk
Args:
index_name (str): Name of the index.
index_type (str): Type of the index. Defaults to 'HNSW'.
index_args (Dict[str, Any]): Arguments for the index. Defaults to None.
tair_url (str): URL for the Tair instance.
overwrite (bool): Whether to overwrite the index if it already exists.
Defaults to False.
kwargs (Any): Additional arguments to pass to the Tair client.
Raises:
ValueError: If tair-py is not installed
ValueError: If failed to connect to Tair instance
Examples:
>>> from llama_index.vector_stores.tair import TairVectorStore
>>> # Create a TairVectorStore
>>> vector_store = TairVectorStore(
>>> tair_url="redis://{username}:{password}@r-bp****************.\
redis.rds.aliyuncs.com:{port}",
>>> index_name="my_index",
>>> index_type="HNSW",
>>> index_args={"M": 16, "ef_construct": 200},
>>> overwrite=True)
"""
        try:
            from tair import Tair, tairvector  # noqa
        except ImportError:
            raise ValueError(
                "Could not import tair-py python package. "
                "Please install it with `pip install tair`."
            )
        try:
            self._tair_client = Tair.from_url(tair_url, **kwargs)
        except ValueError as e:
            raise ValueError(f"Tair failed to connect: {e}")

        # index identifiers
        self._index_name = index_name
        self._index_type = index_type
        self._metric_type = "L2"
        self._overwrite = overwrite
        self._index_args = {}
        self._query_args = {}
        if index_type == "HNSW":
            index_args = index_args or {}
            self._index_args = {
                "ef_construct": index_args.get("ef_construct", 500),
                "M": index_args.get("M", 24),
            }
            self._query_args = {"ef_search": index_args.get("ef_search", 400)}

    @property
    def client(self) -> "Tair":
        """Return the Tair client instance."""
        return self._tair_client

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Add nodes to the index.

        Args:
            nodes (List[BaseNode]): List of nodes with embeddings.

        Returns:
            List[str]: List of ids of the documents added to the index.
        """
        # check to see if an empty document list was passed
        if len(nodes) == 0:
            return []

        # set vector dim for creation if index doesn't exist
        self.dim = len(nodes[0].get_embedding())

        if self._index_exists():
            if self._overwrite:
                self.delete_index()
                self._create_index()
            else:
                _logger.info(f"Adding document to existing index {self._index_name}")
        else:
            self._create_index()

        ids = []
        for node in nodes:
            attributes = {
                "id": node.node_id,
                "doc_id": node.ref_doc_id,
                "text": node.get_content(metadata_mode=MetadataMode.NONE),
            }
            metadata_dict = node_to_metadata_dict(
                node, remove_text=True, flat_metadata=self.flat_metadata
            )
            attributes.update(metadata_dict)

            ids.append(node.node_id)
            # keys are laid out as "{ref_doc_id}#{node_id}" so that all nodes of a
            # document can later be found with a prefix scan (see `delete`)
            self._tair_client.tvs_hset(
                self._index_name,
                f"{node.ref_doc_id}#{node.node_id}",
                vector=node.get_embedding(),
                is_binary=False,
                **attributes,
            )
        _logger.info(f"Added {len(ids)} documents to index {self._index_name}")
        return ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Delete all nodes belonging to a document.

        Args:
            ref_doc_id (str): Id of the document whose nodes are deleted.
        """
        it = self._tair_client.tvs_scan(self._index_name, f"{ref_doc_id}#*")
        for k in it:
            self._tair_client.tvs_del(self._index_name, k)

    def delete_index(self) -> None:
        """Delete the index and all documents."""
        _logger.info(f"Deleting index {self._index_name}")
        self._tair_client.tvs_del_index(self._index_name)

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Query the index.

        Args:
            query (VectorStoreQuery): query object

        Returns:
            VectorStoreQueryResult: query result

        Raises:
            ValueError: If query.query_embedding is None.
        """
        filter_expr = None
        if query.filters is not None:
            filter_expr = _to_filter_expr(query.filters)

        if not query.query_embedding:
            raise ValueError("Query embedding is required for querying.")

        _logger.info(f"Querying index {self._index_name}")

        # copy the stored query args so a per-call ef_search override does not
        # permanently change the defaults
        query_args = dict(self._query_args)
        if self._index_type == "HNSW" and "ef_search" in kwargs:
            query_args["ef_search"] = kwargs["ef_search"]

        results = self._tair_client.tvs_knnsearch(
            self._index_name,
            query.similarity_top_k,
            query.query_embedding,
            False,
            filter_str=filter_expr,
            **query_args,
        )
        results = [(k.decode(), float(s)) for k, s in results]

        ids = []
        nodes = []
        scores = []
        pipe = self._tair_client.pipeline(transaction=False)
        for key, score in results:
            scores.append(score)
            pipe.tvs_hmget(self._index_name, key, "id", "doc_id", "text")
        metadatas = pipe.execute()
        for i, m in enumerate(metadatas):
            # TODO: properly get the _node_content
            node_id = m[0].decode()
            node = TextNode(
                text=m[2].decode(),
                id_=node_id,
                embedding=None,
                relationships={
                    NodeRelationship.SOURCE: RelatedNodeInfo(node_id=m[1].decode())
                },
            )
            ids.append(node_id)
            nodes.append(node)
        _logger.info(f"Found {len(nodes)} results for query with id {ids}")

        return VectorStoreQueryResult(nodes=nodes, ids=ids, similarities=scores)

    def _create_index(self) -> None:
        try:
            from tair import tairvector
        except ImportError:
            raise ValueError(
                "Could not import tair-py python package. "
                "Please install it with `pip install tair`."
            )
        _logger.info(f"Creating index {self._index_name}")
        self._tair_client.tvs_create_index(
            self._index_name,
            self.dim,
            distance_type=self._metric_type,
            index_type=self._index_type,
            data_type=tairvector.DataType.Float32,
            **self._index_args,
        )

    def _index_exists(self) -> bool:
        index = self._tair_client.tvs_get_index(self._index_name)
        return index is not None
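

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not executed on import). It assumes a
# reachable Tair instance at the placeholder URL below and embeddings of a
# matching dimensionality; the URL, index name, and 4-dimensional vectors are
# made-up values for demonstration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    store = TairVectorStore(
        tair_url="redis://user:password@localhost:6379",  # placeholder URL
        index_name="demo_index",
        index_type="HNSW",
        index_args={"M": 16, "ef_construct": 200},
        overwrite=True,
    )

    # A toy node with a 4-dimensional embedding and a SOURCE relationship so
    # that ref_doc_id resolves to "doc-1"; real code would use embeddings
    # produced by an embedding model.
    node = TextNode(
        text="hello tair",
        embedding=[0.1, 0.2, 0.3, 0.4],
        relationships={NodeRelationship.SOURCE: RelatedNodeInfo(node_id="doc-1")},
    )
    store.add([node])

    # kNN query with an embedding of the same dimensionality.
    result = store.query(
        VectorStoreQuery(query_embedding=[0.1, 0.2, 0.3, 0.4], similarity_top_k=1)
    )
    print(result.ids, result.similarities)

    # Remove the document's nodes, then drop the index entirely.
    store.delete("doc-1")
    store.delete_index()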