"""Chroma vector store.""" import logging import math from typing import Any, Dict, Generator, List, Optional, cast from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.schema import BaseNode, MetadataMode, TextNode from llama_index.utils import truncate_text from llama_index.vector_stores.types import ( BasePydanticVectorStore, MetadataFilters, VectorStoreQuery, VectorStoreQueryResult, ) from llama_index.vector_stores.utils import ( legacy_metadata_dict_to_node, metadata_dict_to_node, node_to_metadata_dict, ) logger = logging.getLogger(__name__) def _transform_chroma_filter_condition(condition: str) -> str: """Translate standard metadata filter op to Chroma specific spec.""" if condition == "and": return "$and" elif condition == "or": return "$or" else: raise ValueError(f"Filter condition {condition} not supported") def _transform_chroma_filter_operator(operator: str) -> str: """Translate standard metadata filter operator to Chroma specific spec.""" if operator == "!=": return "$ne" elif operator == "==": return "$eq" elif operator == ">": return "$gt" elif operator == "<": return "$lt" elif operator == ">=": return "$gte" elif operator == "<=": return "$lte" else: raise ValueError(f"Filter operator {operator} not supported") def _to_chroma_filter( standard_filters: MetadataFilters, ) -> dict: """Translate standard metadata filters to Chroma specific spec.""" filters = {} filters_list = [] condition = standard_filters.condition or "and" condition = _transform_chroma_filter_condition(condition) if standard_filters.filters: for filter in standard_filters.filters: if filter.operator: filters_list.append( { filter.key: { _transform_chroma_filter_operator( filter.operator ): filter.value } } ) else: filters_list.append({filter.key: filter.value}) if len(filters_list) == 1: # If there is only one filter, return it directly return filters_list[0] elif len(filters_list) > 1: filters[condition] = filters_list return filters import_err_msg = "`chromadb` package not found, please run `pip install chromadb`" MAX_CHUNK_SIZE = 41665 # One less than the max chunk size for ChromaDB def chunk_list( lst: List[BaseNode], max_chunk_size: int ) -> Generator[List[BaseNode], None, None]: """Yield successive max_chunk_size-sized chunks from lst. Args: lst (List[BaseNode]): list of nodes with embeddings max_chunk_size (int): max chunk size Yields: Generator[List[BaseNode], None, None]: list of nodes with embeddings """ for i in range(0, len(lst), max_chunk_size): yield lst[i : i + max_chunk_size] class ChromaVectorStore(BasePydanticVectorStore): """Chroma vector store. In this vector store, embeddings are stored within a ChromaDB collection. During query time, the index uses ChromaDB to query for the top k most similar nodes. 

import_err_msg = "`chromadb` package not found, please run `pip install chromadb`"

MAX_CHUNK_SIZE = 41665  # One less than the max chunk size for ChromaDB


def chunk_list(
    lst: List[BaseNode], max_chunk_size: int
) -> Generator[List[BaseNode], None, None]:
    """Yield successive max_chunk_size-sized chunks from lst.

    Args:
        lst (List[BaseNode]): list of nodes with embeddings
        max_chunk_size (int): max chunk size

    Yields:
        Generator[List[BaseNode], None, None]: list of nodes with embeddings
    """
    for i in range(0, len(lst), max_chunk_size):
        yield lst[i : i + max_chunk_size]


class ChromaVectorStore(BasePydanticVectorStore):
    """Chroma vector store.

    In this vector store, embeddings are stored within a ChromaDB collection.

    During query time, the index uses ChromaDB to query for the top
    k most similar nodes.

    Args:
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance

    """

    stores_text: bool = True
    flat_metadata: bool = True

    collection_name: Optional[str]
    host: Optional[str]
    port: Optional[str]
    ssl: bool
    headers: Optional[Dict[str, str]]
    persist_dir: Optional[str]
    collection_kwargs: Dict[str, Any] = Field(default_factory=dict)

    _collection: Any = PrivateAttr()

    def __init__(
        self,
        chroma_collection: Optional[Any] = None,
        collection_name: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[str] = None,
        ssl: bool = False,
        headers: Optional[Dict[str, str]] = None,
        persist_dir: Optional[str] = None,
        collection_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        try:
            import chromadb
        except ImportError:
            raise ImportError(import_err_msg)
        from chromadb.api.models.Collection import Collection

        # Normalize kwargs so `**collection_kwargs` below never sees None
        collection_kwargs = collection_kwargs or {}

        if chroma_collection is None:
            client = chromadb.HttpClient(
                host=host, port=port, ssl=ssl, headers=headers
            )
            self._collection = client.get_or_create_collection(
                name=collection_name, **collection_kwargs
            )
        else:
            self._collection = cast(Collection, chroma_collection)

        super().__init__(
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            collection_name=collection_name,
            persist_dir=persist_dir,
            collection_kwargs=collection_kwargs,
        )

    @classmethod
    def from_collection(cls, collection: Any) -> "ChromaVectorStore":
        try:
            from chromadb import Collection
        except ImportError:
            raise ImportError(import_err_msg)

        if not isinstance(collection, Collection):
            raise TypeError("argument is not a chromadb Collection instance")

        return cls(chroma_collection=collection)

    @classmethod
    def from_params(
        cls,
        collection_name: str,
        host: Optional[str] = None,
        port: Optional[str] = None,
        ssl: bool = False,
        headers: Optional[Dict[str, str]] = None,
        persist_dir: Optional[str] = None,
        collection_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> "ChromaVectorStore":
        try:
            import chromadb
        except ImportError:
            raise ImportError(import_err_msg)

        collection_kwargs = collection_kwargs or {}
        if persist_dir:
            client = chromadb.PersistentClient(path=persist_dir)
            collection = client.get_or_create_collection(
                name=collection_name, **collection_kwargs
            )
        elif host and port:
            client = chromadb.HttpClient(
                host=host, port=port, ssl=ssl, headers=headers
            )
            collection = client.get_or_create_collection(
                name=collection_name, **collection_kwargs
            )
        else:
            raise ValueError(
                "Either `persist_dir` or (`host`,`port`) must be specified"
            )

        return cls(
            chroma_collection=collection,
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            persist_dir=persist_dir,
            collection_kwargs=collection_kwargs,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        return "ChromaVectorStore"
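
    # Illustrative construction sketch (hedged; names and values below are
    # placeholders, not part of the API surface):
    #
    #     store = ChromaVectorStore.from_params(
    #         collection_name="quickstart", persist_dir="./chroma_db"
    #     )
    #
    # or wrap a collection you already hold:
    #
    #     store = ChromaVectorStore(chroma_collection=my_collection)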

    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Add nodes to index.

        Args:
            nodes (List[BaseNode]): list of nodes with embeddings

        """
        if not self._collection:
            raise ValueError("Collection not initialized")

        node_chunks = chunk_list(nodes, MAX_CHUNK_SIZE)

        all_ids = []
        for node_chunk in node_chunks:
            embeddings = []
            metadatas = []
            ids = []
            documents = []
            for node in node_chunk:
                embeddings.append(node.get_embedding())
                metadata_dict = node_to_metadata_dict(
                    node, remove_text=True, flat_metadata=self.flat_metadata
                )
                # Chroma rejects None metadata values; coerce them to ""
                for key in metadata_dict:
                    if metadata_dict[key] is None:
                        metadata_dict[key] = ""
                metadatas.append(metadata_dict)
                ids.append(node.node_id)
                documents.append(node.get_content(metadata_mode=MetadataMode.NONE))

            self._collection.add(
                embeddings=embeddings,
                ids=ids,
                metadatas=metadatas,
                documents=documents,
            )
            all_ids.extend(ids)

        return all_ids

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """
        Delete nodes with ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.

        """
        self._collection.delete(where={"document_id": ref_doc_id})

    @property
    def client(self) -> Any:
        """Return client."""
        return self._collection

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query (VectorStoreQuery): query spec containing the query
                embedding and similarity_top_k

        """
        if query.filters is not None:
            if "where" in kwargs:
                raise ValueError(
                    "Cannot specify metadata filters via both query and kwargs. "
                    "Use kwargs only for Chroma-specific items that are "
                    "not supported via the generic query interface."
                )
            where = _to_chroma_filter(query.filters)
        else:
            where = kwargs.pop("where", {})

        results = self._collection.query(
            query_embeddings=query.query_embedding,
            n_results=query.similarity_top_k,
            where=where,
            **kwargs,
        )

        logger.debug(f"> Top {len(results['documents'][0])} nodes:")
        nodes = []
        similarities = []
        ids = []
        for node_id, text, metadata, distance in zip(
            results["ids"][0],
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        ):
            try:
                node = metadata_dict_to_node(metadata)
                node.set_content(text)
            except Exception:
                # NOTE: deprecated legacy logic for backward compatibility
                metadata, node_info, relationships = legacy_metadata_dict_to_node(
                    metadata
                )

                node = TextNode(
                    text=text,
                    id_=node_id,
                    metadata=metadata,
                    start_char_idx=node_info.get("start", None),
                    end_char_idx=node_info.get("end", None),
                    relationships=relationships,
                )

            nodes.append(node)

            # Chroma returns a distance; convert it to a similarity in (0, 1]
            similarity_score = math.exp(-distance)
            similarities.append(similarity_score)

            logger.debug(
                f"> [Node {node_id}] [Similarity score: {similarity_score}] "
                f"{truncate_text(str(text), 100)}"
            )
            ids.append(node_id)

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
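

if __name__ == "__main__":
    # Minimal smoke test (a hedged sketch, not part of the library): assumes
    # `chromadb` 0.4+ is installed and uses its in-memory EphemeralClient with
    # hand-rolled 3-d embeddings, so no embedding model is required.
    import chromadb

    demo_client = chromadb.EphemeralClient()
    demo_collection = demo_client.get_or_create_collection("chroma_store_demo")
    store = ChromaVectorStore(chroma_collection=demo_collection)

    store.add(
        [
            TextNode(text="hello world", embedding=[1.0, 0.0, 0.0]),
            TextNode(text="goodbye world", embedding=[0.0, 1.0, 0.0]),
        ]
    )

    result = store.query(
        VectorStoreQuery(query_embedding=[1.0, 0.0, 0.0], similarity_top_k=1)
    )
    print(result.ids, result.similarities)  # expect the "hello world" node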