faiss_rag_enterprise/llama_index/vector_stores/milvus.py

341 lines
13 KiB
Python

"""Milvus vector store index.
An index that is built within Milvus.
"""
import logging
from typing import Any, Dict, List, Optional, Union
from llama_index.schema import BaseNode, TextNode
from llama_index.vector_stores.types import (
MetadataFilters,
VectorStore,
VectorStoreQuery,
VectorStoreQueryMode,
VectorStoreQueryResult,
)
from llama_index.vector_stores.utils import (
DEFAULT_DOC_ID_KEY,
DEFAULT_EMBEDDING_KEY,
metadata_dict_to_node,
node_to_metadata_dict,
)
logger = logging.getLogger(__name__)
MILVUS_ID_FIELD = "id"
def _to_milvus_filter(standard_filters: MetadataFilters) -> List[str]:
"""Translate standard metadata filters to Milvus specific spec."""
filters = []
for filter in standard_filters.legacy_filters():
if isinstance(filter.value, str):
filters.append(str(filter.key) + " == " + '"' + str(filter.value) + '"')
else:
filters.append(str(filter.key) + " == " + str(filter.value))
return filters
class MilvusVectorStore(VectorStore):
"""The Milvus Vector Store.
In this vector store we store the text, its embedding and
a its metadata in a Milvus collection. This implementation
allows the use of an already existing collection.
It also supports creating a new one if the collection doesn't
exist or if `overwrite` is set to True.
Args:
uri (str, optional): The URI to connect to, comes in the form of
"http://address:port".
token (str, optional): The token for log in. Empty if not using rbac, if
using rbac it will most likely be "username:password".
collection_name (str, optional): The name of the collection where data will be
stored. Defaults to "llamalection".
dim (int, optional): The dimension of the embedding vectors for the collection.
Required if creating a new collection.
embedding_field (str, optional): The name of the embedding field for the
collection, defaults to DEFAULT_EMBEDDING_KEY.
doc_id_field (str, optional): The name of the doc_id field for the collection,
defaults to DEFAULT_DOC_ID_KEY.
similarity_metric (str, optional): The similarity metric to use,
currently supports IP and L2.
consistency_level (str, optional): Which consistency level to use for a newly
created collection. Defaults to "Strong".
overwrite (bool, optional): Whether to overwrite existing collection with same
name. Defaults to False.
text_key (str, optional): What key text is stored in in the passed collection.
Used when bringing your own collection. Defaults to None.
index_config (dict, optional): The configuration used for building the
Milvus index. Defaults to None.
search_config (dict, optional): The configuration used for searching
the Milvus index. Note that this must be compatible with the index
type specified by `index_config`. Defaults to None.
Raises:
ImportError: Unable to import `pymilvus`.
MilvusException: Error communicating with Milvus, more can be found in logging
under Debug.
Returns:
MilvusVectorstore: Vectorstore that supports add, delete, and query.
"""
stores_text: bool = True
stores_node: bool = True
def __init__(
self,
uri: str = "http://localhost:19530",
token: str = "",
collection_name: str = "llamalection",
dim: Optional[int] = None,
embedding_field: str = DEFAULT_EMBEDDING_KEY,
doc_id_field: str = DEFAULT_DOC_ID_KEY,
similarity_metric: str = "IP",
consistency_level: str = "Strong",
overwrite: bool = False,
text_key: Optional[str] = None,
index_config: Optional[dict] = None,
search_config: Optional[dict] = None,
**kwargs: Any,
) -> None:
"""Init params."""
import_err_msg = (
"`pymilvus` package not found, please run `pip install pymilvus`"
)
try:
import pymilvus # noqa
except ImportError:
raise ImportError(import_err_msg)
from pymilvus import Collection, MilvusClient
self.collection_name = collection_name
self.dim = dim
self.embedding_field = embedding_field
self.doc_id_field = doc_id_field
self.consistency_level = consistency_level
self.overwrite = overwrite
self.text_key = text_key
self.index_config: Dict[str, Any] = index_config.copy() if index_config else {}
# Note: The search configuration is set at construction to avoid having
# to change the API for usage of the vector store (i.e. to pass the
# search config along with the rest of the query).
self.search_config: Dict[str, Any] = (
search_config.copy() if search_config else {}
)
# Select the similarity metric
if similarity_metric.lower() in ("ip"):
self.similarity_metric = "IP"
elif similarity_metric.lower() in ("l2", "euclidean"):
self.similarity_metric = "L2"
# Connect to Milvus instance
self.milvusclient = MilvusClient(
uri=uri,
token=token,
**kwargs, # pass additional arguments such as server_pem_path
)
# Delete previous collection if overwriting
if self.overwrite and self.collection_name in self.client.list_collections():
self.milvusclient.drop_collection(self.collection_name)
# Create the collection if it does not exist
if self.collection_name not in self.client.list_collections():
if self.dim is None:
raise ValueError("Dim argument required for collection creation.")
self.milvusclient.create_collection(
collection_name=self.collection_name,
dimension=self.dim,
primary_field_name=MILVUS_ID_FIELD,
vector_field_name=self.embedding_field,
id_type="string",
metric_type=self.similarity_metric,
max_length=65_535,
consistency_level=self.consistency_level,
)
self.collection = Collection(
self.collection_name, using=self.milvusclient._using
)
self._create_index_if_required()
logger.debug(f"Successfully created a new collection: {self.collection_name}")
@property
def client(self) -> Any:
"""Get client."""
return self.milvusclient
def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
"""Add the embeddings and their nodes into Milvus.
Args:
nodes (List[BaseNode]): List of nodes with embeddings
to insert.
Raises:
MilvusException: Failed to insert data.
Returns:
List[str]: List of ids inserted.
"""
insert_list = []
insert_ids = []
# Process that data we are going to insert
for node in nodes:
entry = node_to_metadata_dict(node)
entry[MILVUS_ID_FIELD] = node.node_id
entry[self.embedding_field] = node.embedding
insert_ids.append(node.node_id)
insert_list.append(entry)
# Insert the data into milvus
self.collection.insert(insert_list)
self.collection.flush()
self._create_index_if_required()
logger.debug(
f"Successfully inserted embeddings into: {self.collection_name} "
f"Num Inserted: {len(insert_list)}"
)
return insert_ids
def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
"""
Delete nodes using with ref_doc_id.
Args:
ref_doc_id (str): The doc_id of the document to delete.
Raises:
MilvusException: Failed to delete the doc.
"""
# Adds ability for multiple doc delete in future.
doc_ids: List[str]
if isinstance(ref_doc_id, list):
doc_ids = ref_doc_id # type: ignore
else:
doc_ids = [ref_doc_id]
# Begin by querying for the primary keys to delete
doc_ids = ['"' + entry + '"' for entry in doc_ids]
entries = self.milvusclient.query(
collection_name=self.collection_name,
filter=f"{self.doc_id_field} in [{','.join(doc_ids)}]",
)
ids = [entry["id"] for entry in entries]
self.milvusclient.delete(collection_name=self.collection_name, pks=ids)
logger.debug(f"Successfully deleted embedding with doc_id: {doc_ids}")
def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes.
Args:
query_embedding (List[float]): query embedding
similarity_top_k (int): top k most similar nodes
doc_ids (Optional[List[str]]): list of doc_ids to filter by
node_ids (Optional[List[str]]): list of node_ids to filter by
output_fields (Optional[List[str]]): list of fields to return
embedding_field (Optional[str]): name of embedding field
"""
if query.mode != VectorStoreQueryMode.DEFAULT:
raise ValueError(f"Milvus does not support {query.mode} yet.")
expr = []
output_fields = ["*"]
# Parse the filter
if query.filters is not None:
expr.extend(_to_milvus_filter(query.filters))
# Parse any docs we are filtering on
if query.doc_ids is not None and len(query.doc_ids) != 0:
expr_list = ['"' + entry + '"' for entry in query.doc_ids]
expr.append(f"{self.doc_id_field} in [{','.join(expr_list)}]")
# Parse any nodes we are filtering on
if query.node_ids is not None and len(query.node_ids) != 0:
expr_list = ['"' + entry + '"' for entry in query.node_ids]
expr.append(f"{MILVUS_ID_FIELD} in [{','.join(expr_list)}]")
# Limit output fields
if query.output_fields is not None:
output_fields = query.output_fields
# Convert to string expression
string_expr = ""
if len(expr) != 0:
string_expr = " and ".join(expr)
# Perform the search
res = self.milvusclient.search(
collection_name=self.collection_name,
data=[query.query_embedding],
filter=string_expr,
limit=query.similarity_top_k,
output_fields=output_fields,
search_params=self.search_config,
)
logger.debug(
f"Successfully searched embedding in collection: {self.collection_name}"
f" Num Results: {len(res[0])}"
)
nodes = []
similarities = []
ids = []
# Parse the results
for hit in res[0]:
if not self.text_key:
node = metadata_dict_to_node(
{
"_node_content": hit["entity"].get("_node_content", None),
"_node_type": hit["entity"].get("_node_type", None),
}
)
else:
try:
text = hit["entity"].get(self.text_key)
except Exception:
raise ValueError(
"The passed in text_key value does not exist "
"in the retrieved entity."
)
node = TextNode(
text=text,
)
nodes.append(node)
similarities.append(hit["distance"])
ids.append(hit["id"])
return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
def _create_index_if_required(self, force: bool = False) -> None:
# This helper method is introduced to allow the index to be created
# both in the constructor and in the `add` method. The `force` flag is
# provided to ensure that the index is created in the constructor even
# if self.overwrite is false. In the `add` method, the index is
# recreated only if self.overwrite is true.
if (self.collection.has_index() and self.overwrite) or force:
self.collection.release()
self.collection.drop_index()
base_params: Dict[str, Any] = self.index_config.copy()
index_type: str = base_params.pop("index_type", "FLAT")
index_params: Dict[str, Union[str, Dict[str, Any]]] = {
"params": base_params,
"metric_type": self.similarity_metric,
"index_type": index_type,
}
self.collection.create_index(
self.embedding_field, index_params=index_params
)
self.collection.load()