# faiss_rag_enterprise/llama_index/vector_stores/rocksetdb.py
#
# 315 lines
# 12 KiB
# Python

from __future__ import annotations
from enum import Enum
from os import getenv
from time import sleep
from types import ModuleType
from typing import Any, List, Type, TypeVar
from llama_index.schema import BaseNode
from llama_index.vector_stores.types import (
VectorStore,
VectorStoreQuery,
VectorStoreQueryResult,
)
from llama_index.vector_stores.utils import (
DEFAULT_EMBEDDING_KEY,
DEFAULT_TEXT_KEY,
metadata_dict_to_node,
node_to_metadata_dict,
)
# Type variable bound to RocksetVectorStore; used to annotate ``cls`` in the
# ``with_new_collection`` alternate constructor.
T = TypeVar("T", bound="RocksetVectorStore")
def _get_rockset() -> ModuleType:
"""Gets the rockset module and raises an ImportError if
the rockset package hasn't been installed.
Returns:
rockset module (ModuleType)
"""
try:
import rockset
except ImportError:
raise ImportError("Please install rockset with `pip install rockset`")
return rockset
def _get_client(api_key: str | None, api_server: str | None, client: Any | None) -> Any:
    """Return the passed-in client if valid, else construct and return one.

    Args:
        api_key: Rockset API key. Falls back to the ``ROCKSET_API_KEY``
            environment variable.
        api_server: Rockset API server host. Falls back to the
            ``ROCKSET_API_SERVER`` environment variable.
        client: An existing ``rockset.RocksetClient`` to reuse.

    Returns:
        The rockset client object (rockset.RocksetClient)

    Raises:
        ImportError: if the rockset package is not installed.
        ValueError: if ``client`` is not a ``rockset.RocksetClient``, or if no
            client is given and no API key is available.
    """
    rockset = _get_rockset()
    if client is not None:
        # isinstance (not an exact type check) so subclasses are accepted.
        if not isinstance(client, rockset.RocksetClient):
            raise ValueError("Parameter `client` must be of type rockset.RocksetClient")
        return client

    resolved_api_key = api_key or getenv("ROCKSET_API_KEY")
    if not resolved_api_key:
        raise ValueError(
            "Parameter `client`, `api_key` or env var `ROCKSET_API_KEY` must be set"
        )
    return rockset.RocksetClient(
        api_key=resolved_api_key,
        host=api_server or getenv("ROCKSET_API_SERVER"),
    )
class RocksetVectorStore(VectorStore):
    """Vector store backed by a Rockset collection.

    Each node is stored as one Rockset document with these fields:

    - ``_id``: the node id.
    - ``embedding_col``: the node embedding.
    - ``metadata_col``: the output of ``node_to_metadata_dict`` for the node.
    """

    stores_text: bool = True
    is_embedding_query: bool = True
    flat_metadata: bool = False

    class DistanceFunc(Enum):
        """Rockset SQL functions used to score vectors in ``query``."""

        COSINE_SIM = "COSINE_SIM"
        EUCLIDEAN_DIST = "EUCLIDEAN_DIST"
        DOT_PRODUCT = "DOT_PRODUCT"

    def __init__(
        self,
        collection: str,
        client: Any | None = None,
        text_key: str = DEFAULT_TEXT_KEY,
        embedding_col: str = DEFAULT_EMBEDDING_KEY,
        metadata_col: str = "metadata",
        workspace: str = "commons",
        api_server: str | None = None,
        api_key: str | None = None,
        distance_func: DistanceFunc = DistanceFunc.COSINE_SIM,
    ) -> None:
        """Rockset Vector Store Data container.

        Args:
            collection (str): The name of the collection of vectors
            client (Optional[Any]): Rockset client object
            text_key (str): The key to the text of nodes
                (default: llama_index.vector_stores.utils.DEFAULT_TEXT_KEY)
            embedding_col (str): The DB column containing embeddings
                (default: llama_index.vector_stores.utils.DEFAULT_EMBEDDING_KEY)
            metadata_col (str): The DB column containing node metadata
                (default: "metadata")
            workspace (str): The workspace containing the collection of vectors
                (default: "commons")
            api_server (Optional[str]): The Rockset API server to use
            api_key (Optional[str]): The Rockset API key to use
            distance_func (RocksetVectorStore.DistanceFunc): The metric to measure
                vector relationship. The function name string
                (e.g. "EUCLIDEAN_DIST") is also accepted.
                (default: RocksetVectorStore.DistanceFunc.COSINE_SIM)

        Raises:
            ValueError: if ``distance_func`` is not a known distance function,
                or if no usable client/API key is available.
        """
        self.rockset = _get_rockset()
        self.rs = _get_client(api_key, api_server, client)
        self.workspace = workspace
        self.collection = collection
        self.text_key = text_key
        self.embedding_col = embedding_col
        self.metadata_col = metadata_col
        # Enum(value) returns members unchanged and also accepts the raw name.
        self.distance_func = self.DistanceFunc(distance_func)
        # Euclidean distance: smaller means closer. The other functions are
        # similarities: larger means closer.
        self.distance_order = (
            "ASC"
            if self.distance_func is self.DistanceFunc.EUCLIDEAN_DIST
            else "DESC"
        )
        try:
            self.rs.set_application("llama_index")
        except AttributeError:
            # set_application method does not exist.
            # rockset version < 2.1.0
            pass

    @property
    def client(self) -> Any:
        """The underlying Rockset client object."""
        return self.rs

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Stores vectors in the collection.

        Args:
            nodes (List[BaseNode]): List of nodes with embeddings

        Returns:
            Stored node IDs (List[str]). An empty list is returned without
            contacting Rockset when ``nodes`` is empty.
        """
        if not nodes:
            return []
        documents = [
            {
                self.embedding_col: node.get_embedding(),
                "_id": node.node_id,
                self.metadata_col: node_to_metadata_dict(
                    node, text_field=self.text_key
                ),
            }
            for node in nodes
        ]
        response = self.rs.Documents.add_documents(
            collection=self.collection,
            workspace=self.workspace,
            data=documents,
        )
        return [row["_id"] for row in response.data]

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Deletes nodes stored in the collection by their ref_doc_id.

        Args:
            ref_doc_id (str): The ref_doc_id of the document
                whose nodes are to be deleted
        """
        rows = self.rs.sql(
            f"""
            SELECT
                _id
            FROM
                "{self.workspace}"."{self.collection}" x
            WHERE
                x.{self.metadata_col}.ref_doc_id=:ref_doc_id
            """,
            params={"ref_doc_id": ref_doc_id},
        ).results
        if not rows:
            # Nothing matched; skip the no-op delete request.
            return
        self.rs.Documents.delete_documents(
            collection=self.collection,
            workspace=self.workspace,
            data=[
                self.rockset.models.DeleteDocumentsRequestData(id=row["_id"])
                for row in rows
            ],
        )

    def _build_query_sql(
        self, query: VectorStoreQuery, similarity_col: str, has_embedding: bool
    ) -> tuple[str, dict[str, Any]]:
        """Build the SQL text and bound parameters for ``query``."""
        params: dict[str, Any] = {}
        conditions: List[str] = []

        if query.node_ids:
            # Bind ids as parameters instead of interpolating them into SQL.
            id_terms: List[str] = []
            for i, node_id in enumerate(query.node_ids):
                name = f"node_id_{i}"
                params[name] = node_id
                id_terms.append(f"_id=:{name}")
            conditions.append("(" + " OR ".join(id_terms) + ")")

        legacy_filters = query.filters.legacy_filters() if query.filters else []
        if legacy_filters:
            # Indexed parameter names so repeated keys don't overwrite each other.
            filter_terms: List[str] = []
            for i, metadata_filter in enumerate(legacy_filters):
                name = f"filter_{i}"
                params[name] = metadata_filter.value
                filter_terms.append(
                    f"x.{self.metadata_col}.{metadata_filter.key}=:{name}"
                )
            conditions.append("(" + " AND ".join(filter_terms) + ")")

        select_cols = f"_id, {self.metadata_col}"
        if has_embedding:
            embedding_literal = str([float(v) for v in query.query_embedding])
            select_cols += (
                f", {self.distance_func.value}("
                f"{embedding_literal}, {self.embedding_col}"
                f") AS {similarity_col}"
            )

        sql_parts = [
            f"SELECT {select_cols}",
            f'FROM "{self.workspace}"."{self.collection}" x',
        ]
        if conditions:
            sql_parts.append("WHERE " + " AND ".join(conditions))
        if has_embedding:
            # The similarity column only exists when an embedding is supplied.
            sql_parts.append(f"ORDER BY {similarity_col} {self.distance_order}")
        sql_parts.append(f"LIMIT {query.similarity_top_k}")
        return "\n".join(sql_parts), params

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Gets nodes relevant to a query.

        Args:
            query (llama_index.vector_stores.types.VectorStoreQuery): The query
            similarity_col (Optional[str]): The column to select the
                similarity/distance score as (default: "_similarity")

        Returns:
            query results (llama_index.vector_stores.types.VectorStoreQueryResult).
            When the query has no embedding, results are unordered and
            ``similarities`` is None.
        """
        similarity_col = kwargs.get("similarity_col", "_similarity")
        has_embedding = (
            query.query_embedding is not None and len(query.query_embedding) > 0
        )
        sql, params = self._build_query_sql(query, similarity_col, has_embedding)
        res = self.rs.sql(sql, params=params)

        similarities: List[float] | None = [] if has_embedding else None
        nodes: List[BaseNode] = []
        ids: List[str] = []
        for row in res.results:
            if similarities is not None:
                similarities.append(row[similarity_col])
            nodes.append(metadata_dict_to_node(row[self.metadata_col]))
            ids.append(row["_id"])
        return VectorStoreQueryResult(similarities=similarities, nodes=nodes, ids=ids)

    @classmethod
    def with_new_collection(
        cls: Type[T], dimensions: int | None = None, **rockset_vector_store_args: Any
    ) -> T:
        """Creates a new collection and returns its RocksetVectorStore.

        Blocks, polling every 0.1s with no timeout, until Rockset reports the
        collection status as "READY".

        Args:
            dimensions (Optional[int]): The length of the vectors to enforce
                in the collection's ingest transformation. By default, the
                collection will do no vector enforcement.
            collection (str): The name of the collection to be created
            client (Optional[Any]): Rockset client object
            workspace (str): The workspace containing the collection to be
                created (default: "commons")
            text_key (str): The key to the text of nodes
                (default: llama_index.vector_stores.utils.DEFAULT_TEXT_KEY)
            embedding_col (str): The DB column containing embeddings
                (default: llama_index.vector_stores.utils.DEFAULT_EMBEDDING_KEY)
            metadata_col (str): The DB column containing node metadata
                (default: "metadata")
            api_server (Optional[str]): The Rockset API server to use
            api_key (Optional[str]): The Rockset API key to use
            distance_func (RocksetVectorStore.DistanceFunc): The metric to measure
                vector relationship
                (default: RocksetVectorStore.DistanceFunc.COSINE_SIM)
        """
        client = rockset_vector_store_args["client"] = _get_client(
            api_key=rockset_vector_store_args.get("api_key"),
            api_server=rockset_vector_store_args.get("api_server"),
            client=rockset_vector_store_args.get("client"),
        )
        collection = rockset_vector_store_args.get("collection")
        # None-valued args are dropped before calling ``cls`` below, so a None
        # here means "use the constructor default" as well.
        workspace = rockset_vector_store_args.get("workspace")
        if workspace is None:
            workspace = "commons"
        # Must match the constructor's parameter name (``embedding_col``).
        embedding_col = rockset_vector_store_args.get("embedding_col")
        if embedding_col is None:
            embedding_col = DEFAULT_EMBEDDING_KEY

        collection_args: dict[str, Any] = {"workspace": workspace, "name": collection}
        if dimensions:
            collection_args[
                "field_mapping_query"
            ] = _get_rockset().model.field_mapping_query.FieldMappingQuery(
                sql=f"""
                SELECT
                    *, VECTOR_ENFORCE(
                        {embedding_col},
                        {dimensions},
                        'float'
                    ) AS {embedding_col}
                FROM
                    _input
                """
            )

        client.Collections.create_s3_collection(**collection_args)  # create collection
        # Wait until the collection is ready; look it up in the workspace it
        # was created in.
        # TODO: add async, non-blocking method collection creation
        while (
            client.Collections.get(collection=collection, workspace=workspace).data.status
            != "READY"
        ):
            sleep(0.1)

        return cls(
            **{
                key: value
                for key, value in rockset_vector_store_args.items()
                if value is not None
            }
        )