"""Milvus vector store index. An index that is built within Milvus. """ import logging from typing import Any, Dict, List, Optional, Union from llama_index.schema import BaseNode, TextNode from llama_index.vector_stores.types import ( MetadataFilters, VectorStore, VectorStoreQuery, VectorStoreQueryMode, VectorStoreQueryResult, ) from llama_index.vector_stores.utils import ( DEFAULT_DOC_ID_KEY, DEFAULT_EMBEDDING_KEY, metadata_dict_to_node, node_to_metadata_dict, ) logger = logging.getLogger(__name__) MILVUS_ID_FIELD = "id" def _to_milvus_filter(standard_filters: MetadataFilters) -> List[str]: """Translate standard metadata filters to Milvus specific spec.""" filters = [] for filter in standard_filters.legacy_filters(): if isinstance(filter.value, str): filters.append(str(filter.key) + " == " + '"' + str(filter.value) + '"') else: filters.append(str(filter.key) + " == " + str(filter.value)) return filters class MilvusVectorStore(VectorStore): """The Milvus Vector Store. In this vector store we store the text, its embedding and a its metadata in a Milvus collection. This implementation allows the use of an already existing collection. It also supports creating a new one if the collection doesn't exist or if `overwrite` is set to True. Args: uri (str, optional): The URI to connect to, comes in the form of "http://address:port". token (str, optional): The token for log in. Empty if not using rbac, if using rbac it will most likely be "username:password". collection_name (str, optional): The name of the collection where data will be stored. Defaults to "llamalection". dim (int, optional): The dimension of the embedding vectors for the collection. Required if creating a new collection. embedding_field (str, optional): The name of the embedding field for the collection, defaults to DEFAULT_EMBEDDING_KEY. doc_id_field (str, optional): The name of the doc_id field for the collection, defaults to DEFAULT_DOC_ID_KEY. similarity_metric (str, optional): The similarity metric to use, currently supports IP and L2. consistency_level (str, optional): Which consistency level to use for a newly created collection. Defaults to "Strong". overwrite (bool, optional): Whether to overwrite existing collection with same name. Defaults to False. text_key (str, optional): What key text is stored in in the passed collection. Used when bringing your own collection. Defaults to None. index_config (dict, optional): The configuration used for building the Milvus index. Defaults to None. search_config (dict, optional): The configuration used for searching the Milvus index. Note that this must be compatible with the index type specified by `index_config`. Defaults to None. Raises: ImportError: Unable to import `pymilvus`. MilvusException: Error communicating with Milvus, more can be found in logging under Debug. Returns: MilvusVectorstore: Vectorstore that supports add, delete, and query. """ stores_text: bool = True stores_node: bool = True def __init__( self, uri: str = "http://localhost:19530", token: str = "", collection_name: str = "llamalection", dim: Optional[int] = None, embedding_field: str = DEFAULT_EMBEDDING_KEY, doc_id_field: str = DEFAULT_DOC_ID_KEY, similarity_metric: str = "IP", consistency_level: str = "Strong", overwrite: bool = False, text_key: Optional[str] = None, index_config: Optional[dict] = None, search_config: Optional[dict] = None, **kwargs: Any, ) -> None: """Init params.""" import_err_msg = ( "`pymilvus` package not found, please run `pip install pymilvus`" ) try: import pymilvus # noqa except ImportError: raise ImportError(import_err_msg) from pymilvus import Collection, MilvusClient self.collection_name = collection_name self.dim = dim self.embedding_field = embedding_field self.doc_id_field = doc_id_field self.consistency_level = consistency_level self.overwrite = overwrite self.text_key = text_key self.index_config: Dict[str, Any] = index_config.copy() if index_config else {} # Note: The search configuration is set at construction to avoid having # to change the API for usage of the vector store (i.e. to pass the # search config along with the rest of the query). self.search_config: Dict[str, Any] = ( search_config.copy() if search_config else {} ) # Select the similarity metric if similarity_metric.lower() in ("ip"): self.similarity_metric = "IP" elif similarity_metric.lower() in ("l2", "euclidean"): self.similarity_metric = "L2" # Connect to Milvus instance self.milvusclient = MilvusClient( uri=uri, token=token, **kwargs, # pass additional arguments such as server_pem_path ) # Delete previous collection if overwriting if self.overwrite and self.collection_name in self.client.list_collections(): self.milvusclient.drop_collection(self.collection_name) # Create the collection if it does not exist if self.collection_name not in self.client.list_collections(): if self.dim is None: raise ValueError("Dim argument required for collection creation.") self.milvusclient.create_collection( collection_name=self.collection_name, dimension=self.dim, primary_field_name=MILVUS_ID_FIELD, vector_field_name=self.embedding_field, id_type="string", metric_type=self.similarity_metric, max_length=65_535, consistency_level=self.consistency_level, ) self.collection = Collection( self.collection_name, using=self.milvusclient._using ) self._create_index_if_required() logger.debug(f"Successfully created a new collection: {self.collection_name}") @property def client(self) -> Any: """Get client.""" return self.milvusclient def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]: """Add the embeddings and their nodes into Milvus. Args: nodes (List[BaseNode]): List of nodes with embeddings to insert. Raises: MilvusException: Failed to insert data. Returns: List[str]: List of ids inserted. """ insert_list = [] insert_ids = [] # Process that data we are going to insert for node in nodes: entry = node_to_metadata_dict(node) entry[MILVUS_ID_FIELD] = node.node_id entry[self.embedding_field] = node.embedding insert_ids.append(node.node_id) insert_list.append(entry) # Insert the data into milvus self.collection.insert(insert_list) self.collection.flush() self._create_index_if_required() logger.debug( f"Successfully inserted embeddings into: {self.collection_name} " f"Num Inserted: {len(insert_list)}" ) return insert_ids def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: """ Delete nodes using with ref_doc_id. Args: ref_doc_id (str): The doc_id of the document to delete. Raises: MilvusException: Failed to delete the doc. """ # Adds ability for multiple doc delete in future. doc_ids: List[str] if isinstance(ref_doc_id, list): doc_ids = ref_doc_id # type: ignore else: doc_ids = [ref_doc_id] # Begin by querying for the primary keys to delete doc_ids = ['"' + entry + '"' for entry in doc_ids] entries = self.milvusclient.query( collection_name=self.collection_name, filter=f"{self.doc_id_field} in [{','.join(doc_ids)}]", ) ids = [entry["id"] for entry in entries] self.milvusclient.delete(collection_name=self.collection_name, pks=ids) logger.debug(f"Successfully deleted embedding with doc_id: {doc_ids}") def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: """Query index for top k most similar nodes. Args: query_embedding (List[float]): query embedding similarity_top_k (int): top k most similar nodes doc_ids (Optional[List[str]]): list of doc_ids to filter by node_ids (Optional[List[str]]): list of node_ids to filter by output_fields (Optional[List[str]]): list of fields to return embedding_field (Optional[str]): name of embedding field """ if query.mode != VectorStoreQueryMode.DEFAULT: raise ValueError(f"Milvus does not support {query.mode} yet.") expr = [] output_fields = ["*"] # Parse the filter if query.filters is not None: expr.extend(_to_milvus_filter(query.filters)) # Parse any docs we are filtering on if query.doc_ids is not None and len(query.doc_ids) != 0: expr_list = ['"' + entry + '"' for entry in query.doc_ids] expr.append(f"{self.doc_id_field} in [{','.join(expr_list)}]") # Parse any nodes we are filtering on if query.node_ids is not None and len(query.node_ids) != 0: expr_list = ['"' + entry + '"' for entry in query.node_ids] expr.append(f"{MILVUS_ID_FIELD} in [{','.join(expr_list)}]") # Limit output fields if query.output_fields is not None: output_fields = query.output_fields # Convert to string expression string_expr = "" if len(expr) != 0: string_expr = " and ".join(expr) # Perform the search res = self.milvusclient.search( collection_name=self.collection_name, data=[query.query_embedding], filter=string_expr, limit=query.similarity_top_k, output_fields=output_fields, search_params=self.search_config, ) logger.debug( f"Successfully searched embedding in collection: {self.collection_name}" f" Num Results: {len(res[0])}" ) nodes = [] similarities = [] ids = [] # Parse the results for hit in res[0]: if not self.text_key: node = metadata_dict_to_node( { "_node_content": hit["entity"].get("_node_content", None), "_node_type": hit["entity"].get("_node_type", None), } ) else: try: text = hit["entity"].get(self.text_key) except Exception: raise ValueError( "The passed in text_key value does not exist " "in the retrieved entity." ) node = TextNode( text=text, ) nodes.append(node) similarities.append(hit["distance"]) ids.append(hit["id"]) return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids) def _create_index_if_required(self, force: bool = False) -> None: # This helper method is introduced to allow the index to be created # both in the constructor and in the `add` method. The `force` flag is # provided to ensure that the index is created in the constructor even # if self.overwrite is false. In the `add` method, the index is # recreated only if self.overwrite is true. if (self.collection.has_index() and self.overwrite) or force: self.collection.release() self.collection.drop_index() base_params: Dict[str, Any] = self.index_config.copy() index_type: str = base_params.pop("index_type", "FLAT") index_params: Dict[str, Union[str, Dict[str, Any]]] = { "params": base_params, "metric_type": self.similarity_metric, "index_type": index_type, } self.collection.create_index( self.embedding_field, index_params=index_params ) self.collection.load()