# faiss_rag_enterprise/llama_index/vector_stores/weaviate.py

# 356 lines
# 11 KiB
# Python

"""Weaviate vector store index.

An index that is built on top of an existing vector store.
"""
import logging
from typing import Any, Dict, List, Optional, cast
from uuid import uuid4
from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.schema import BaseNode
from llama_index.vector_stores.types import (
BasePydanticVectorStore,
MetadataFilters,
VectorStoreQuery,
VectorStoreQueryMode,
VectorStoreQueryResult,
)
from llama_index.vector_stores.utils import DEFAULT_TEXT_KEY
from llama_index.vector_stores.weaviate_utils import (
add_node,
class_schema_exists,
create_default_schema,
get_all_properties,
get_node_similarity,
parse_get_response,
to_node,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Message used when the optional `weaviate` dependency cannot be imported.
import_err_msg = (
    "`weaviate` package not found, "
    "please run `pip install weaviate-client`"
)
def _transform_weaviate_filter_condition(condition: str) -> str:
    """Translate a standard metadata filter condition to a Weaviate operator.

    Args:
        condition: Standard filter condition, ``"and"`` or ``"or"``.

    Returns:
        The matching Weaviate ``where`` operator name: ``"And"`` or ``"Or"``.

    Raises:
        ValueError: If ``condition`` is neither ``"and"`` nor ``"or"``.
    """
    # Equality comparison (not a dict lookup) is deliberate: it also accepts
    # values that merely compare equal to these strings, e.g. str-based enum
    # members, whose hash may differ from the hash of their value.
    if condition == "and":
        return "And"
    if condition == "or":
        return "Or"
    raise ValueError(f"Filter condition {condition} not supported")
def _transform_weaviate_filter_operator(operator: str) -> str:
    """Translate a standard metadata filter operator to a Weaviate operator.

    Args:
        operator: Standard comparison operator, one of ``"!="``, ``"=="``,
            ``">"``, ``"<"``, ``">="`` or ``"<="``.

    Returns:
        The matching Weaviate ``where`` operator name (e.g. ``"Equal"``).

    Raises:
        ValueError: If ``operator`` is not one of the supported operators.
    """
    operator_names = (
        ("!=", "NotEqual"),
        ("==", "Equal"),
        (">", "GreaterThan"),
        ("<", "LessThan"),
        (">=", "GreaterThanEqual"),
        ("<=", "LessThanEqual"),
    )
    # Scan with ``==`` instead of hashing into a dict so that values which
    # only compare equal to the symbol (e.g. str-based enum members) match.
    for symbol, weaviate_name in operator_names:
        if operator == symbol:
            return weaviate_name
    raise ValueError(f"Filter operator {operator} not supported")
def _to_weaviate_filter(standard_filters: MetadataFilters) -> Dict[str, Any]:
    """Convert standard metadata filters into a Weaviate ``where`` filter dict.

    Each filter becomes ``{"path": key, "operator": <op>, <value_type>: value}``
    where the value type key is chosen from the Python type of the value:

    * ``bool`` -> ``valueBoolean``
    * ``int`` / ``float`` -> ``valueNumber``
    * numeric-looking ``str`` that ``float()`` can parse -> ``valueNumber``
      (the value is converted to ``float``)
    * anything else -> ``valueText``

    The input filters are never modified.

    Args:
        standard_filters: Filters to convert; ``condition`` defaults to
            ``"and"`` when unset.

    Returns:
        ``{}`` when there are no filters, the single filter dict when there
        is exactly one, otherwise ``{"operands": [...], "operator": <cond>}``.

    Raises:
        ValueError: If the condition or any filter operator is unsupported.
    """
    # Validate the condition first, even when there are no filters.
    condition = _transform_weaviate_filter_condition(
        standard_filters.condition or "and"
    )
    if not standard_filters.filters:
        return {}

    operands: List[Dict[str, Any]] = []
    for metadata_filter in standard_filters.filters:
        # Work on a local copy of the value so the caller's filter object is
        # left untouched.
        value = metadata_filter.value
        value_type = "valueText"
        if isinstance(value, bool):
            # bool is a subclass of int; check it first so that booleans are
            # not sent as numbers.
            value_type = "valueBoolean"
        elif isinstance(value, (int, float)):
            value_type = "valueNumber"
        elif isinstance(value, str) and value.isnumeric():
            # str.isnumeric() also accepts characters such as "½" or "²"
            # that float() rejects; those stay plain text.
            try:
                value = float(value)
            except ValueError:
                pass
            else:
                value_type = "valueNumber"
        operands.append(
            {
                "path": metadata_filter.key,
                "operator": _transform_weaviate_filter_operator(
                    metadata_filter.operator
                ),
                value_type: value,
            }
        )

    if len(operands) == 1:
        # If there is only one filter, return it directly
        return operands[0]
    return {"operands": operands, "operator": condition}
class WeaviateVectorStore(BasePydanticVectorStore):
    """Weaviate vector store.

    In this vector store, embeddings and docs are stored within a
    Weaviate collection.

    During query time, the index uses Weaviate to query for the top
    k most similar nodes.

    Args:
        weaviate_client (weaviate.Client): WeaviateClient
            instance from `weaviate-client` package
        index_name (Optional[str]): name for Weaviate classes
    """

    stores_text: bool = True

    index_name: str
    url: Optional[str]
    text_key: str
    # Attribute dict of the auth object (empty when no auth is configured).
    auth_config: Dict[str, Any] = Field(default_factory=dict)
    # Extra keyword arguments forwarded to the Weaviate client constructor.
    client_kwargs: Dict[str, Any] = Field(default_factory=dict)

    _client = PrivateAttr()

    def __init__(
        self,
        weaviate_client: Optional[Any] = None,
        class_prefix: Optional[str] = None,
        index_name: Optional[str] = None,
        text_key: str = DEFAULT_TEXT_KEY,
        auth_config: Optional[Any] = None,
        client_kwargs: Optional[Dict[str, Any]] = None,
        url: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params.

        Args:
            weaviate_client: Existing client to use. When ``None`` a client is
                created from ``url``, ``auth_config`` and ``client_kwargs``.
            class_prefix: Deprecated; when given, the index name becomes
                ``f"{class_prefix}_Node"``.
            index_name: Name of the Weaviate class. Must start with a capital
                letter. A random ``LlamaIndex_<hex>`` name is used if omitted.
            text_key: Property name under which node text is stored.
            auth_config: Auth object, or a dict of keyword arguments for
                ``AuthApiKey``. ``None`` or an empty dict means no auth.
            client_kwargs: Extra keyword arguments for the client constructor.
            url: URL of the Weaviate instance.
            **kwargs: Accepted for compatibility; not used here.

        Raises:
            ImportError: If the ``weaviate`` package is not installed.
            ValueError: If the index name does not start with a capital letter.
        """
        try:
            import weaviate  # noqa
            from weaviate import AuthApiKey, Client
        except ImportError as err:
            raise ImportError(import_err_msg) from err

        if weaviate_client is None:
            if isinstance(auth_config, dict):
                # An empty dict is the serialized "no auth" default; only a
                # non-empty dict carries AuthApiKey arguments.
                auth_config = AuthApiKey(**auth_config) if auth_config else None

            client_kwargs = client_kwargs or {}
            self._client = Client(
                url=url, auth_client_secret=auth_config, **client_kwargs
            )
        else:
            self._client = cast(Client, weaviate_client)

        if class_prefix is not None:
            logger.warning("class_prefix is deprecated, please use index_name")
            # legacy, kept for backward compatibility
            index_name = f"{class_prefix}_Node"

        index_name = index_name or f"LlamaIndex_{uuid4().hex}"
        # validate that the index name starts with a capital letter
        if not index_name[0].isupper():
            raise ValueError(
                "Index name must start with a capital letter, e.g. 'LlamaIndex'"
            )

        # create default schema if does not exist
        if not class_schema_exists(self._client, index_name):
            create_default_schema(self._client, index_name)

        # ``auth_config`` may already be a plain dict here (``from_params``
        # passes one together with a ready-made client); a dict has no
        # ``__dict__`` attribute, so only unwrap real auth objects.
        if not auth_config:
            auth_config_dict: Dict[str, Any] = {}
        elif isinstance(auth_config, dict):
            auth_config_dict = auth_config
        else:
            auth_config_dict = auth_config.__dict__

        super().__init__(
            url=url,
            index_name=index_name,
            text_key=text_key,
            auth_config=auth_config_dict,
            client_kwargs=client_kwargs or {},
        )

    @classmethod
    def from_params(
        cls,
        url: str,
        auth_config: Any,
        index_name: Optional[str] = None,
        text_key: str = DEFAULT_TEXT_KEY,
        client_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> "WeaviateVectorStore":
        """Create WeaviateVectorStore from config.

        Args:
            url: URL of the Weaviate instance.
            auth_config: Auth object, a dict of ``AuthApiKey`` keyword
                arguments, or ``None`` / empty dict for no auth.
            index_name: Name of the Weaviate class.
            text_key: Property name under which node text is stored.
            client_kwargs: Extra keyword arguments for the client constructor.
            **kwargs: Forwarded to the class constructor.

        Returns:
            A store wrapping a newly created client.

        Raises:
            ImportError: If the ``weaviate`` package is not installed.
        """
        try:
            import weaviate  # noqa
            from weaviate import AuthApiKey, Client  # noqa
        except ImportError as err:
            raise ImportError(import_err_msg) from err

        if isinstance(auth_config, dict):
            # Same convention as ``__init__``: empty dict means "no auth".
            auth_config = AuthApiKey(**auth_config) if auth_config else None

        client_kwargs = client_kwargs or {}
        weaviate_client = Client(
            url=url, auth_client_secret=auth_config, **client_kwargs
        )

        return cls(
            weaviate_client=weaviate_client,
            url=url,
            auth_config=auth_config.__dict__ if auth_config else {},
            client_kwargs=client_kwargs,
            index_name=index_name,
            text_key=text_key,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the name of this vector store class."""
        return "WeaviateVectorStore"

    @property
    def client(self) -> Any:
        """Get client."""
        return self._client

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        """Add nodes to index.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings

        Returns:
            The ``node_id`` of every node, in input order.
        """
        ids = [r.node_id for r in nodes]

        # All nodes are queued on one client batch context.
        with self._client.batch as batch:
            for node in nodes:
                add_node(
                    self._client,
                    node,
                    self.index_name,
                    batch=batch,
                    text_key=self.text_key,
                )
        return ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Delete nodes using with ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.
            **delete_kwargs: May contain ``filter``, an additional Weaviate
                ``where`` filter that is AND-ed with the ``ref_doc_id`` match.
        """
        where_filter = {
            "path": ["ref_doc_id"],
            "operator": "Equal",
            "valueText": ref_doc_id,
        }
        extra_filter = delete_kwargs.get("filter")
        if extra_filter is not None:
            where_filter = {
                "operator": "And",
                "operands": [where_filter, extra_filter],  # type: ignore
            }

        # NOTE: a single fetch is issued, so at most 10,000 matching objects
        # are deleted per call.
        query = (
            self._client.query.get(self.index_name)
            .with_additional(["id"])
            .with_where(where_filter)
            .with_limit(10000)  # 10,000 is the max weaviate can fetch
        )

        query_result = query.do()
        parsed_result = parse_get_response(query_result)
        entries = parsed_result[self.index_name]
        for entry in entries:
            self._client.data_object.delete(entry["_additional"]["id"], self.index_name)

    def _build_where_filter(
        self, query: VectorStoreQuery, extra_filter: Optional[Any]
    ) -> Optional[Any]:
        """Combine every query constraint into one Weaviate ``where`` filter.

        Constraints come from ``query.doc_ids``, ``query.node_ids`` and either
        ``query.filters`` or, only when that is ``None``, ``extra_filter``.

        Returns:
            ``None`` when there is nothing to filter on, the sole clause when
            there is one, otherwise an ``And`` over all clauses.
        """
        clauses: List[Any] = []

        # list of documents to constrain search
        if query.doc_ids:
            clauses.append(
                {
                    "operator": "Or",
                    "operands": [
                        {"path": ["doc_id"], "operator": "Equal", "valueText": doc_id}
                        for doc_id in query.doc_ids
                    ],
                }
            )

        if query.node_ids:
            clauses.append(
                {
                    "operator": "Or",
                    "operands": [
                        {"path": ["id"], "operator": "Equal", "valueText": node_id}
                        for node_id in query.node_ids
                    ],
                }
            )

        if query.filters is not None:
            metadata_filter = _to_weaviate_filter(query.filters)
            # An empty filter list converts to {}; skip it rather than
            # sending an empty where clause.
            if metadata_filter:
                clauses.append(metadata_filter)
        elif extra_filter is not None:
            clauses.append(extra_filter)

        if not clauses:
            return None
        if len(clauses) == 1:
            return clauses[0]
        return {"operator": "And", "operands": clauses}

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query: Query parameters (embedding, mode, filters, top k, ...).
            **kwargs: May contain ``filter``, a raw Weaviate ``where`` filter
                used only when ``query.filters`` is ``None``.

        Returns:
            Matching nodes with their ids and similarities, at most
            ``query.similarity_top_k`` of them.
        """
        all_properties = get_all_properties(self._client, self.index_name)

        # build query
        query_builder = self._client.query.get(self.index_name, all_properties)

        # Apply one combined filter so that doc-id, node-id and metadata
        # constraints all take effect together.
        # NOTE(review): weaviate-client v3's builder is believed to keep only
        # one where filter per query — confirm against the client docs.
        where_filter = self._build_where_filter(query, kwargs.get("filter"))
        if where_filter:
            query_builder = query_builder.with_where(where_filter)

        query_builder = query_builder.with_additional(
            ["id", "vector", "distance", "score"]
        )

        vector = query.query_embedding
        similarity_key = "distance"
        if query.mode == VectorStoreQueryMode.DEFAULT:
            logger.debug("Using vector search")
            if vector is not None:
                query_builder = query_builder.with_near_vector(
                    {
                        "vector": vector,
                    }
                )
        elif query.mode == VectorStoreQueryMode.HYBRID:
            logger.debug("Using hybrid search with alpha %s", query.alpha)
            # Hybrid results are ranked by "score" instead of "distance".
            similarity_key = "score"
            if vector is not None and query.query_str:
                query_builder = query_builder.with_hybrid(
                    query=query.query_str,
                    alpha=query.alpha,
                    vector=vector,
                )

        query_builder = query_builder.with_limit(query.similarity_top_k)
        logger.debug("Using limit of %s", query.similarity_top_k)

        # execute query
        query_result = query_builder.do()

        # parse results
        parsed_result = parse_get_response(query_result)
        entries = parsed_result[self.index_name]

        similarities = []
        nodes: List[BaseNode] = []
        node_ids = []
        for i, entry in enumerate(entries):
            # Defensive cap in addition to the limit set on the query.
            if i >= query.similarity_top_k:
                break
            similarities.append(get_node_similarity(entry, similarity_key))
            nodes.append(to_node(entry, text_key=self.text_key))
            node_ids.append(nodes[-1].node_id)

        return VectorStoreQueryResult(
            nodes=nodes, ids=node_ids, similarities=similarities
        )