"""Redis Vector store index.
An index that is built on top of an existing vector store.
"""
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import fsspec

from llama_index.bridge.pydantic import PrivateAttr
from llama_index.readers.redis.utils import (
    TokenEscaper,
    array_to_buffer,
    check_redis_modules_exist,
    convert_bytes,
    get_redis_query,
)
from llama_index.schema import (
    BaseNode,
    MetadataMode,
    NodeRelationship,
    RelatedNodeInfo,
    TextNode,
)
from llama_index.vector_stores.types import (
    BasePydanticVectorStore,
    MetadataFilters,
    VectorStoreQuery,
    VectorStoreQueryResult,
)
from llama_index.vector_stores.utils import metadata_dict_to_node, node_to_metadata_dict

_logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from redis.client import Redis as RedisType
    from redis.commands.search.field import VectorField


class RedisVectorStore(BasePydanticVectorStore):
    """Redis vector store built on the RediSearch module (FLAT or HNSW indexes)."""

    stores_text = True
    stores_node = True
    flat_metadata = False

    _tokenizer: Any = PrivateAttr()
    _redis_client: Any = PrivateAttr()
    _prefix: str = PrivateAttr()
    _index_name: str = PrivateAttr()
    _index_args: Dict[str, Any] = PrivateAttr()
    _metadata_fields: List[str] = PrivateAttr()
    _overwrite: bool = PrivateAttr()
    _vector_field: str = PrivateAttr()
    _vector_key: str = PrivateAttr()

    def __init__(
        self,
        index_name: str,
        index_prefix: str = "llama_index",
        prefix_ending: str = "/vector",
        index_args: Optional[Dict[str, Any]] = None,
        metadata_fields: Optional[List[str]] = None,
        redis_url: str = "redis://localhost:6379",
        overwrite: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize RedisVectorStore.

        For index arguments that can be passed to RediSearch, see
        https://redis.io/docs/stack/search/reference/vectors/

        The index arguments will depend on the index type chosen. There
        are two available index types:
        - FLAT: a flat index that uses brute-force search
        - HNSW: a hierarchical navigable small world graph index

        Args:
            index_name (str): Name of the index.
            index_prefix (str): Prefix for the index. Defaults to "llama_index".
                The actual prefix used by Redis will be
                "{index_prefix}{prefix_ending}".
            prefix_ending (str): Prefix ending for the index. Be careful when
                changing this: https://github.com/jerryjliu/llama_index/pull/6665.
                Defaults to "/vector".
            index_args (Dict[str, Any]): Arguments for the index. Defaults to None.
            metadata_fields (List[str]): List of metadata fields to store in the index
                (only supports TAG fields).
            redis_url (str): URL for the redis instance.
                Defaults to "redis://localhost:6379".
            overwrite (bool): Whether to overwrite the index if it already exists.
                Defaults to False.
            kwargs (Any): Additional arguments to pass to the redis client.

        Raises:
            ValueError: If redis-py is not installed.
            ValueError: If RediSearch is not installed.

        Examples:
            >>> from llama_index.vector_stores.redis import RedisVectorStore
            >>> # Create a RedisVectorStore with an HNSW index
            >>> vector_store = RedisVectorStore(
            >>>     index_name="my_index",
            >>>     index_prefix="llama_index",
            >>>     index_args={
            >>>         "algorithm": "HNSW",
            >>>         "m": 16,
            >>>         "ef_construction": 200,
            >>>         "distance_metric": "cosine",
            >>>     },
            >>>     redis_url="redis://localhost:6379/",
            >>>     overwrite=True)
"""
        try:
            import redis
        except ImportError:
            raise ValueError(
                "Could not import redis python package. "
                "Please install it with `pip install redis`."
            )

        try:
            # connect to redis from url
            self._redis_client = redis.from_url(redis_url, **kwargs)
            # check if redis has redisearch module installed
            check_redis_modules_exist(self._redis_client)
        except ValueError as e:
            raise ValueError(f"Redis failed to connect: {e}")

        # index identifiers
        self._prefix = index_prefix + prefix_ending
        self._index_name = index_name
        self._index_args = index_args if index_args is not None else {}
        self._metadata_fields = metadata_fields if metadata_fields is not None else []
        self._overwrite = overwrite
        self._vector_field = str(self._index_args.get("vector_field", "vector"))
        self._vector_key = str(self._index_args.get("vector_key", "vector"))
        self._tokenizer = TokenEscaper()
        super().__init__()

    @property
    def client(self) -> "RedisType":
        """Return the redis client instance."""
        return self._redis_client

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Add nodes to the index.

        Args:
            nodes (List[BaseNode]): List of nodes with embeddings.

        Returns:
            List[str]: List of ids of the documents added to the index.

        Raises:
            ValueError: If the index already exists and overwrite is False.
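
        Examples:
            >>> # A sketch of typical usage; `vector_store` and the embedded
            >>> # `nodes` are assumed to exist (see the constructor example).
            >>> # One id is returned per node written.
            >>> ids = vector_store.add(nodes)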
"""
        # check to see if empty document list was passed
        if len(nodes) == 0:
            return []

        # set vector dim for creation if index doesn't exist
        self._index_args["dims"] = len(nodes[0].get_embedding())

        if self._index_exists():
            if self._overwrite:
                self.delete_index()
                self._create_index()
            else:
                _logger.info(f"Adding document to existing index {self._index_name}")
        else:
            self._create_index()

        ids = []
        for node in nodes:
            mapping = {
                "id": node.node_id,
                "doc_id": node.ref_doc_id,
                "text": node.get_content(metadata_mode=MetadataMode.NONE),
                self._vector_key: array_to_buffer(node.get_embedding()),
            }
            additional_metadata = node_to_metadata_dict(
                node, remove_text=True, flat_metadata=self.flat_metadata
            )
            mapping.update(additional_metadata)

            ids.append(node.node_id)
            key = "_".join([self._prefix, str(node.node_id)])
            self._redis_client.hset(key, mapping=mapping)  # type: ignore

        _logger.info(f"Added {len(ids)} documents to index {self._index_name}")
        return ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Delete nodes with the given ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.
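
        Examples:
            >>> # A sketch, assuming nodes from a source document with
            >>> # ref_doc_id "doc-123" were previously added via this store.
            >>> vector_store.delete("doc-123")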
"""
        # use tokenizer to escape dashes in query
        query_str = "@doc_id:{%s}" % self._tokenizer.escape(ref_doc_id)

        # find all documents that match a doc_id
        results = self._redis_client.ft(self._index_name).search(query_str)
        if len(results.docs) == 0:
            # don't raise an error but warn the user that document wasn't found
            # could be a result of eviction policy
            _logger.warning(
                f"Document with doc_id {ref_doc_id} not found "
                f"in index {self._index_name}"
            )
            return

        for doc in results.docs:
            self._redis_client.delete(doc.id)
        _logger.info(
            f"Deleted {len(results.docs)} documents from index {self._index_name}"
        )

    def delete_index(self) -> None:
        """Delete the index and all documents."""
        _logger.info(f"Deleting index {self._index_name}")
        self._redis_client.ft(self._index_name).dropindex(delete_documents=True)

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Query the index.

        Args:
            query (VectorStoreQuery): query object

        Returns:
            VectorStoreQueryResult: query result

        Raises:
            ValueError: If query.query_embedding is None.
            redis.exceptions.RedisError: If there is an error querying the index.
            redis.exceptions.TimeoutError: If there is a timeout querying the index.
            ValueError: If no documents are found when querying the index.
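
        Examples:
            >>> # A sketch of a filtered query. The embedding values and the
            >>> # "author" metadata field are illustrative and assume the field
            >>> # was passed via `metadata_fields` at construction time.
            >>> from llama_index.vector_stores.types import (
            >>>     ExactMatchFilter,
            >>>     MetadataFilters,
            >>>     VectorStoreQuery,
            >>> )
            >>> result = vector_store.query(
            >>>     VectorStoreQuery(
            >>>         query_embedding=[0.1] * 1536,
            >>>         similarity_top_k=2,
            >>>         filters=MetadataFilters(
            >>>             filters=[ExactMatchFilter(key="author", value="jane")]
            >>>         ),
            >>>     )
            >>> )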
"""
        from redis.exceptions import RedisError
        from redis.exceptions import TimeoutError as RedisTimeoutError

        return_fields = [
            "id",
            "doc_id",
            "text",
            self._vector_key,
            "vector_score",
            "_node_content",
        ]

        filters = _to_redis_filters(query.filters) if query.filters is not None else "*"
        _logger.info(f"Using filters: {filters}")

        redis_query = get_redis_query(
            return_fields=return_fields,
            top_k=query.similarity_top_k,
            vector_field=self._vector_field,
            filters=filters,
        )

        if not query.query_embedding:
            raise ValueError("Query embedding is required for querying.")

        query_params = {
            "vector": array_to_buffer(query.query_embedding),
        }
        _logger.info(f"Querying index {self._index_name}")

        try:
            results = self._redis_client.ft(self._index_name).search(
                redis_query, query_params=query_params  # type: ignore
            )
        except RedisTimeoutError as e:
            _logger.error(f"Query timed out on {self._index_name}: {e}")
            raise
        except RedisError as e:
            _logger.error(f"Error querying {self._index_name}: {e}")
            raise

        if len(results.docs) == 0:
            raise ValueError(
                f"No docs found on index '{self._index_name}' with "
                f"prefix '{self._prefix}' and filters '{filters}'. "
                "* Did you originally create the index with a different prefix? "
                "* Did you index your metadata fields when you created the index?"
            )

        ids = []
        nodes = []
        scores = []
        for doc in results.docs:
            try:
                node = metadata_dict_to_node({"_node_content": doc._node_content})
                node.text = doc.text
            except Exception:
                # TODO: Legacy support for old metadata format
                node = TextNode(
                    text=doc.text,
                    id_=doc.id,
                    embedding=None,
                    relationships={
                        NodeRelationship.SOURCE: RelatedNodeInfo(node_id=doc.doc_id)
                    },
                )
            ids.append(doc.id.replace(self._prefix + "_", ""))
            nodes.append(node)
            scores.append(1 - float(doc.vector_score))
        _logger.info(f"Found {len(nodes)} results for query with id {ids}")

        return VectorStoreQueryResult(nodes=nodes, ids=ids, similarities=scores)

    def persist(
        self,
        persist_path: str,
        fs: Optional[fsspec.AbstractFileSystem] = None,
        in_background: bool = True,
    ) -> None:
        """Persist the vector store to disk.

        Args:
            persist_path (str): Path to persist the vector store to. (doesn't apply)
            in_background (bool, optional): Persist in background. Defaults to True.
            fs (fsspec.AbstractFileSystem, optional): Filesystem to persist to.
                (doesn't apply)

        Raises:
            redis.exceptions.RedisError: If there is an error
                persisting the index to disk.
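
        Examples:
            >>> # A sketch; the Redis server itself writes the snapshot, so
            >>> # persist_path and fs are ignored by this store.
            >>> vector_store.persist(persist_path="", in_background=False)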
"""
        from redis.exceptions import RedisError

        try:
            if in_background:
                _logger.info("Saving index to disk in background")
                self._redis_client.bgsave()
            else:
                _logger.info("Saving index to disk")
                self._redis_client.save()
        except RedisError as e:
            _logger.error(f"Error saving index to disk: {e}")
            raise

    def _create_index(self) -> None:
        # should never be called outside class and hence should not raise ImportError
        from redis.commands.search.field import TagField, TextField
        from redis.commands.search.indexDefinition import IndexDefinition, IndexType

        # Create Index
        default_fields = [
            TextField("text", weight=1.0),
            TagField("doc_id", sortable=False),
            TagField("id", sortable=False),
        ]
        # add vector field to list of index fields. Create lazily to allow user
        # to specify index and search attributes in creation.
        fields = [
            *default_fields,
            self._create_vector_field(self._vector_field, **self._index_args),
        ]

        # add metadata fields to list of index fields or we won't be able to search them
        for metadata_field in self._metadata_fields:
            # TODO: allow addition of text fields as metadata
            # TODO: make sure we're preventing overwriting other keys (e.g. text,
            #   doc_id, id, and other vector fields)
            fields.append(TagField(metadata_field, sortable=False))

        _logger.info(f"Creating index {self._index_name}")
        self._redis_client.ft(self._index_name).create_index(
            fields=fields,
            definition=IndexDefinition(
                prefix=[self._prefix], index_type=IndexType.HASH
            ),  # TODO support JSON
        )
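
    # For reference, with the default fields and a FLAT vector field the call
    # above is roughly equivalent to the following RediSearch command
    # (illustrative; the actual values depend on index_args):
    #   FT.CREATE <index_name> ON HASH PREFIX 1 <prefix> SCHEMA
    #       text TEXT WEIGHT 1.0 doc_id TAG id TAG
    #       vector VECTOR FLAT 10 TYPE FLOAT32 DIM <dims> DISTANCE_METRIC COSINE
    #       INITIAL_CAP 20000 BLOCK_SIZE 1000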

    def _index_exists(self) -> bool:
        # use FT._LIST to check if index exists
        indices = convert_bytes(self._redis_client.execute_command("FT._LIST"))
        return self._index_name in indices

    def _create_vector_field(
        self,
        name: str,
        dims: int = 1536,
        algorithm: str = "FLAT",
        datatype: str = "FLOAT32",
        distance_metric: str = "COSINE",
        initial_cap: int = 20000,
        block_size: int = 1000,
        m: int = 16,
        ef_construction: int = 200,
        ef_runtime: int = 10,
        epsilon: float = 0.8,
        **kwargs: Any,
    ) -> "VectorField":
        """Create a RediSearch VectorField.

        Args:
            name (str): The name of the field.
            algorithm (str): The algorithm used to index the vector.
            dims (int): The dimensionality of the vector.
            datatype (str): The type of the vector. default: FLOAT32
            distance_metric (str): The distance metric used to compare vectors.
            initial_cap (int): The initial capacity of the index.
            block_size (int): The block size of the index.
            m (int): The number of outgoing edges in the HNSW graph.
            ef_construction (int): Number of maximum allowed potential outgoing edges
                candidates for each node in the graph,
                during the graph building.
            ef_runtime (int): The number of maximum top candidates to hold during the
                KNN search.
            epsilon (float): Relative factor that sets the boundaries in which an
                HNSW range query may search for candidates.

        Returns:
            A RediSearch VectorField.
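
        Examples:
            >>> # A sketch of a direct call with illustrative parameters; in
            >>> # normal use this helper is driven by the `index_args` passed
            >>> # to the constructor.
            >>> field = vector_store._create_vector_field(
            >>>     "vector", dims=1536, algorithm="HNSW", m=16)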
"""
        from redis import DataError
        from redis.commands.search.field import VectorField

        try:
            if algorithm.upper() == "HNSW":
                return VectorField(
                    name,
                    "HNSW",
                    {
                        "TYPE": datatype.upper(),
                        "DIM": dims,
                        "DISTANCE_METRIC": distance_metric.upper(),
                        "INITIAL_CAP": initial_cap,
                        "M": m,
                        "EF_CONSTRUCTION": ef_construction,
                        "EF_RUNTIME": ef_runtime,
                        "EPSILON": epsilon,
                    },
                )
            else:
                return VectorField(
                    name,
                    "FLAT",
                    {
                        "TYPE": datatype.upper(),
                        "DIM": dims,
                        "DISTANCE_METRIC": distance_metric.upper(),
                        "INITIAL_CAP": initial_cap,
                        "BLOCK_SIZE": block_size,
                    },
                )
        except DataError as e:
            raise ValueError(
                f"Failed to create Redis index vector field with error: {e}"
            )


# currently only supports exact tag match - {} denotes a tag
# must create the index with the correct metadata field before using a field as a
#   filter, or it will return no results
def _to_redis_filters(metadata_filters: MetadataFilters) -> str:
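    """Convert MetadataFilters into the RediSearch tag-filter syntax.

    Example (illustrative; plain alphanumeric values pass through the
    TokenEscaper unchanged):

    >>> from llama_index.vector_stores.types import ExactMatchFilter, MetadataFilters
    >>> _to_redis_filters(
    >>>     MetadataFilters(filters=[ExactMatchFilter(key="author", value="jane")]))
    '(@author:{jane})'
    """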
    tokenizer = TokenEscaper()

    filter_strings = []
    for filter in metadata_filters.legacy_filters():
        # wrap the value in braces (tag syntax) so the filter is treated as an
        #   exact match
        filter_string = f"@{filter.key}:{{{tokenizer.escape(str(filter.value))}}}"
        filter_strings.append(filter_string)

    joined_filter_strings = " & ".join(filter_strings)
    return f"({joined_filter_strings})"