# faiss_rag_enterprise/llama_index/vector_stores/myscale.py

"""MyScale vector store.
An index that is built on top of an existing MyScale cluster.
"""

import json
import logging
from typing import Any, Dict, List, Optional, cast

from llama_index.readers.myscale import (
    MyScaleSettings,
    escape_str,
    format_list_to_string,
)
from llama_index.schema import (
    BaseNode,
    MetadataMode,
    NodeRelationship,
    RelatedNodeInfo,
    TextNode,
)
from llama_index.service_context import ServiceContext
from llama_index.utils import iter_batch
from llama_index.vector_stores.types import (
    VectorStore,
    VectorStoreQuery,
    VectorStoreQueryMode,
    VectorStoreQueryResult,
)

logger = logging.getLogger(__name__)


class MyScaleVectorStore(VectorStore):
"""MyScale Vector Store.
In this vector store, embeddings and docs are stored within an existing
MyScale cluster.
During query time, the index uses MyScale to query for the top
k most similar nodes.
Args:
myscale_client (httpclient): clickhouse-connect httpclient of
an existing MyScale cluster.
table (str, optional): The name of the MyScale table
where data will be stored. Defaults to "llama_index".
database (str, optional): The name of the MyScale database
where data will be stored. Defaults to "default".
index_type (str, optional): The type of the MyScale vector index.
Defaults to "IVFFLAT".
metric (str, optional): The metric type of the MyScale vector index.
Defaults to "cosine".
batch_size (int, optional): the size of documents to insert. Defaults to 32.
index_params (dict, optional): The index parameters for MyScale.
Defaults to None.
search_params (dict, optional): The search parameters for a MyScale query.
Defaults to None.
service_context (ServiceContext, optional): Vector store service context.
Defaults to None
"""
    stores_text: bool = True
    _index_existed: bool = False
    metadata_column: str = "metadata"
    AMPLIFY_RATIO_LE5 = 100
    AMPLIFY_RATIO_GT5 = 20
    AMPLIFY_RATIO_GT50 = 10

    def __init__(
        self,
        myscale_client: Optional[Any] = None,
        table: str = "llama_index",
        database: str = "default",
        index_type: str = "MSTG",
        metric: str = "cosine",
        batch_size: int = 32,
        index_params: Optional[dict] = None,
        search_params: Optional[dict] = None,
        service_context: Optional[ServiceContext] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        import_err_msg = """
            `clickhouse_connect` package not found,
            please run `pip install clickhouse-connect`
        """
        try:
            from clickhouse_connect.driver.httpclient import HttpClient
        except ImportError:
            raise ImportError(import_err_msg)

        if myscale_client is None:
            raise ValueError("Missing MyScale client!")

        self._client = cast(HttpClient, myscale_client)
        self.config = MyScaleSettings(
            table=table,
            database=database,
            index_type=index_type,
            metric=metric,
            batch_size=batch_size,
            index_params=index_params,
            search_params=search_params,
            **kwargs,
        )

        # schema column name, type, and construct format method
        self.column_config: Dict = {
            "id": {"type": "String", "extract_func": lambda x: x.node_id},
            "doc_id": {"type": "String", "extract_func": lambda x: x.ref_doc_id},
            "text": {
                "type": "String",
                "extract_func": lambda x: escape_str(
                    x.get_content(metadata_mode=MetadataMode.NONE) or ""
                ),
            },
            "vector": {
                "type": "Array(Float32)",
                "extract_func": lambda x: format_list_to_string(x.get_embedding()),
            },
            "node_info": {
                "type": "JSON",
                "extract_func": lambda x: json.dumps(x.node_info),
            },
            "metadata": {
                "type": "JSON",
                "extract_func": lambda x: json.dumps(x.metadata),
            },
        }

        if service_context is not None:
            service_context = cast(ServiceContext, service_context)
            dimension = len(
                service_context.embed_model.get_query_embedding("try this out")
            )
            self._create_index(dimension)

    @property
    def client(self) -> Any:
        """Get client."""
        return self._client

    def _create_index(self, dimension: int) -> None:
        index_params = (
            ", " + ",".join([f"'{k}={v}'" for k, v in self.config.index_params.items()])
            if self.config.index_params
            else ""
        )
        schema_ = f"""
            CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
                {",".join([f'{k} {v["type"]}' for k, v in self.column_config.items()])},
                CONSTRAINT vector_length CHECK length(vector) = {dimension},
                VECTOR INDEX {self.config.table}_index vector TYPE
                {self.config.index_type}('metric_type={self.config.metric}'{index_params})
            ) ENGINE = MergeTree ORDER BY id
        """
        self.dim = dimension
        self._client.command("SET allow_experimental_object_type=1")
        self._client.command(schema_)
        self._index_existed = True
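
    # For reference, with dimension=3 and the defaults above, the DDL built by
    # _create_index comes out roughly as (whitespace trimmed; illustrative only):
    #
    #   CREATE TABLE IF NOT EXISTS default.llama_index(
    #       id String,doc_id String,text String,vector Array(Float32),
    #       node_info JSON,metadata JSON,
    #       CONSTRAINT vector_length CHECK length(vector) = 3,
    #       VECTOR INDEX llama_index_index vector TYPE MSTG('metric_type=cosine')
    #   ) ENGINE = MergeTree ORDER BY id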

    def _build_insert_statement(
        self,
        values: List[BaseNode],
    ) -> str:
        _data = []
        for item in values:
            item_value_str = ",".join(
                [
                    f"'{column['extract_func'](item)}'"
                    for column in self.column_config.values()
                ]
            )
            _data.append(f"({item_value_str})")

        return f"""
            INSERT INTO TABLE
                {self.config.database}.{self.config.table}({",".join(self.column_config.keys())})
            VALUES
                {','.join(_data)}
            """

    def _build_hybrid_search_statement(
        self, stage_one_sql: str, query_str: str, similarity_top_k: int
    ) -> str:
        terms_pattern = [f"(?i){x}" for x in query_str.split(" ")]
        column_keys = self.column_config.keys()
        return (
            f"SELECT {','.join(filter(lambda k: k != 'vector', column_keys))}, "
            f"dist FROM ({stage_one_sql}) tempt "
            f"ORDER BY length(multiMatchAllIndices(text, {terms_pattern})) "
            f"AS distance1 DESC, "
            f"log(1 + countMatches(text, '(?i)({query_str.replace(' ', '|')})')) "
            f"AS distance2 DESC limit {similarity_top_k}"
        )
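
    # The hybrid statement reranks an oversampled ANN result set (stage_one_sql)
    # by text relevance: first by how many distinct query terms appear in `text`
    # (multiMatchAllIndices), then by a log-scaled term-frequency score
    # (countMatches). Note that interpolating {terms_pattern} relies on Python's
    # list repr matching ClickHouse's array-literal syntax, e.g.
    # ['(?i)foo', '(?i)bar'].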

    def _append_meta_filter_condition(
        self, where_str: Optional[str], exact_match_filter: list
    ) -> str:
        filter_str = " AND ".join(
            f"JSONExtractString(toJSONString("
            f"{self.metadata_column}), '{filter_item.key}') "
            f"= '{filter_item.value}'"
            for filter_item in exact_match_filter
        )
        if where_str is None:
            where_str = filter_str
        else:
            # append to the existing condition rather than discarding it
            where_str = f"{where_str} AND {filter_str}"
        return where_str
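
    # For reference, an exact-match filter on {"author": "alice"} expands to
    # (illustrative only):
    #
    #   JSONExtractString(toJSONString(metadata), 'author') = 'alice'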

    def add(
        self,
        nodes: List[BaseNode],
        **add_kwargs: Any,
    ) -> List[str]:
        """Add nodes to index.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings

        """
        if not nodes:
            return []

        if not self._index_existed:
            self._create_index(len(nodes[0].get_embedding()))

        for result_batch in iter_batch(nodes, self.config.batch_size):
            insert_statement = self._build_insert_statement(values=result_batch)
            self._client.command(insert_statement)

        return [result.node_id for result in nodes]
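
    # Example (a sketch; assumes embeddings were computed elsewhere, and the
    # 3-dim vectors below are placeholders):
    #
    #   from llama_index.schema import TextNode
    #
    #   nodes = [
    #       TextNode(text="hello world", embedding=[0.1, 0.2, 0.3]),
    #       TextNode(text="goodbye world", embedding=[0.3, 0.2, 0.1]),
    #   ]
    #   inserted_ids = vector_store.add(nodes)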

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Delete nodes with the given ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.

        """
        self._client.command(
            f"DELETE FROM {self.config.database}.{self.config.table} "
            f"WHERE doc_id='{ref_doc_id}'"
        )

    def drop(self) -> None:
        """Drop MyScale index and table."""
        self._client.command(
            f"DROP TABLE IF EXISTS {self.config.database}.{self.config.table}"
        )

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query (VectorStoreQuery): query

        """
        query_embedding = cast(List[float], query.query_embedding)
        where_str = (
            f"doc_id in {format_list_to_string(query.doc_ids)}"
            if query.doc_ids
            else None
        )
        if query.filters is not None and len(query.filters.legacy_filters()) > 0:
            where_str = self._append_meta_filter_condition(
                where_str, query.filters.legacy_filters()
            )

        # build query sql
        query_statement = self.config.build_query_statement(
            query_embed=query_embedding,
            where_str=where_str,
            limit=query.similarity_top_k,
        )
        if query.mode == VectorStoreQueryMode.HYBRID and query.query_str is not None:
            # oversample the ANN stage so the text reranker has enough
            # candidates; smaller top_k values get a larger amplification
            amplify_ratio = self.AMPLIFY_RATIO_LE5
            if 5 < query.similarity_top_k <= 50:
                amplify_ratio = self.AMPLIFY_RATIO_GT5
            elif query.similarity_top_k > 50:
                amplify_ratio = self.AMPLIFY_RATIO_GT50
            query_statement = self._build_hybrid_search_statement(
                self.config.build_query_statement(
                    query_embed=query_embedding,
                    where_str=where_str,
                    limit=query.similarity_top_k * amplify_ratio,
                ),
                query.query_str,
                query.similarity_top_k,
            )
            logger.debug(f"hybrid query_statement={query_statement}")
        nodes = []
        ids = []
        similarities = []
        for r in self._client.query(query_statement).named_results():
            start_char_idx = None
            end_char_idx = None
            if isinstance(r["node_info"], dict):
                start_char_idx = r["node_info"].get("start", None)
                end_char_idx = r["node_info"].get("end", None)
            node = TextNode(
                id_=r["id"],
                text=r["text"],
                metadata=r["metadata"],
                start_char_idx=start_char_idx,
                end_char_idx=end_char_idx,
                relationships={
                    NodeRelationship.SOURCE: RelatedNodeInfo(node_id=r["id"])
                },
            )
            nodes.append(node)
            similarities.append(r["dist"])
            ids.append(r["id"])

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
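

# Example query (a sketch; `vector_store` as constructed above, and the query
# embedding is a placeholder matching the table's vector dimension):
#
#   from llama_index.vector_stores.types import (
#       VectorStoreQuery,
#       VectorStoreQueryMode,
#   )
#
#   result = vector_store.query(
#       VectorStoreQuery(
#           query_embedding=[0.1, 0.2, 0.3],
#           similarity_top_k=5,
#           mode=VectorStoreQueryMode.HYBRID,
#           query_str="hello world",
#       )
#   )
#   for node, score in zip(result.nodes, result.similarities):
#       print(node.node_id, score)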