"""Knowledge Graph Index. Build a KG by extracting triplets, and leveraging the KG during query-time. """ import logging from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple from llama_index.constants import GRAPH_STORE_KEY from llama_index.core.base_retriever import BaseRetriever from llama_index.data_structs.data_structs import KG from llama_index.graph_stores.simple import SimpleGraphStore from llama_index.graph_stores.types import GraphStore from llama_index.indices.base import BaseIndex from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_KG_TRIPLET_EXTRACT_PROMPT from llama_index.schema import BaseNode, IndexNode, MetadataMode from llama_index.service_context import ServiceContext from llama_index.storage.docstore.types import RefDocInfo from llama_index.storage.storage_context import StorageContext from llama_index.utils import get_tqdm_iterable logger = logging.getLogger(__name__) class KnowledgeGraphIndex(BaseIndex[KG]): """Knowledge Graph Index. Build a KG by extracting triplets, and leveraging the KG during query-time. Args: kg_triple_extract_template (BasePromptTemplate): The prompt to use for extracting triplets. max_triplets_per_chunk (int): The maximum number of triplets to extract. service_context (Optional[ServiceContext]): The service context to use. storage_context (Optional[StorageContext]): The storage context to use. graph_store (Optional[GraphStore]): The graph store to use. show_progress (bool): Whether to show tqdm progress bars. Defaults to False. include_embeddings (bool): Whether to include embeddings in the index. Defaults to False. max_object_length (int): The maximum length of the object in a triplet. Defaults to 128. kg_triplet_extract_fn (Optional[Callable]): The function to use for extracting triplets. Defaults to None. """ index_struct_cls = KG def __init__( self, nodes: Optional[Sequence[BaseNode]] = None, objects: Optional[Sequence[IndexNode]] = None, index_struct: Optional[KG] = None, service_context: Optional[ServiceContext] = None, storage_context: Optional[StorageContext] = None, kg_triple_extract_template: Optional[BasePromptTemplate] = None, max_triplets_per_chunk: int = 10, include_embeddings: bool = False, show_progress: bool = False, max_object_length: int = 128, kg_triplet_extract_fn: Optional[Callable] = None, **kwargs: Any, ) -> None: """Initialize params.""" # need to set parameters before building index in base class. self.include_embeddings = include_embeddings self.max_triplets_per_chunk = max_triplets_per_chunk self.kg_triple_extract_template = ( kg_triple_extract_template or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT ) # NOTE: Partially format keyword extract template here. 
        self.kg_triple_extract_template = (
            self.kg_triple_extract_template.partial_format(
                max_knowledge_triplets=self.max_triplets_per_chunk
            )
        )
        self._max_object_length = max_object_length
        self._kg_triplet_extract_fn = kg_triplet_extract_fn
        super().__init__(
            nodes=nodes,
            index_struct=index_struct,
            service_context=service_context,
            storage_context=storage_context,
            show_progress=show_progress,
            objects=objects,
            **kwargs,
        )

        # TODO: legacy conversion - remove in next release
        if (
            len(self.index_struct.table) > 0
            and isinstance(self.graph_store, SimpleGraphStore)
            and len(self.graph_store._data.graph_dict) == 0
        ):
            logger.warning(
                "Upgrading previously saved KG index to new storage format."
            )
            self.graph_store._data.graph_dict = self.index_struct.rel_map

    @property
    def graph_store(self) -> GraphStore:
        return self._graph_store

    def as_retriever(self, **kwargs: Any) -> BaseRetriever:
        from llama_index.indices.knowledge_graph.retrievers import (
            KGRetrieverMode,
            KGTableRetriever,
        )

        if (
            len(self.index_struct.embedding_dict) > 0
            and "retriever_mode" not in kwargs
        ):
            kwargs["retriever_mode"] = KGRetrieverMode.HYBRID

        return KGTableRetriever(self, object_map=self._object_map, **kwargs)

    def _extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
        if self._kg_triplet_extract_fn is not None:
            return self._kg_triplet_extract_fn(text)
        else:
            return self._llm_extract_triplets(text)

    def _llm_extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
        """Extract triplets from text using the LLM."""
        response = self._service_context.llm.predict(
            self.kg_triple_extract_template,
            text=text,
        )
        return self._parse_triplet_response(
            response, max_length=self._max_object_length
        )

    @staticmethod
    def _parse_triplet_response(
        response: str, max_length: int = 128
    ) -> List[Tuple[str, str, str]]:
        knowledge_strs = response.strip().split("\n")
        results = []
        for text in knowledge_strs:
            if "(" not in text or ")" not in text or text.index(")") < text.index("("):
                # skip empty lines and non-triplets
                continue
            triplet_part = text[text.index("(") + 1 : text.index(")")]
            tokens = triplet_part.split(",")
            if len(tokens) != 3:
                continue

            if any(len(s.encode("utf-8")) > max_length for s in tokens):
                # Count byte length rather than len() to handle multi-byte
                # UTF-8 chars, and skip the triplet if any token is too long.
                # This is normally caused by a poorly formatted triplet
                # extraction; for more serious KG building, NLP models would
                # be needed to extract triplets more reliably.
                continue

            subj, pred, obj = map(str.strip, tokens)
            if not subj or not pred or not obj:
                # skip partial triplets
                continue

            # Strip double quotes and capitalize triplets for disambiguation.
            subj, pred, obj = (
                entity.strip('"').capitalize() for entity in [subj, pred, obj]
            )

            results.append((subj, pred, obj))
        return results

    def _build_index_from_nodes(self, nodes: Sequence[BaseNode]) -> KG:
        """Build the index from nodes."""
        # Extract triplets from each node and add them to the KG table.
        index_struct = self.index_struct_cls()
        nodes_with_progress = get_tqdm_iterable(
            nodes, self._show_progress, "Processing nodes"
        )
        for n in nodes_with_progress:
            triplets = self._extract_triplets(
                n.get_content(metadata_mode=MetadataMode.LLM)
            )
            logger.debug(f"> Extracted triplets: {triplets}")
            for triplet in triplets:
                subj, _, obj = triplet
                self.upsert_triplet(triplet)
                index_struct.add_node([subj, obj], n)

            if self.include_embeddings:
                triplet_texts = [str(t) for t in triplets]

                embed_model = self._service_context.embed_model
                embed_outputs = embed_model.get_text_embedding_batch(
                    triplet_texts, show_progress=self._show_progress
                )
                for rel_text, rel_embed in zip(triplet_texts, embed_outputs):
                    index_struct.add_to_embedding_dict(rel_text, rel_embed)

        return index_struct

    def _insert(self, nodes: Sequence[BaseNode], **insert_kwargs: Any) -> None:
        """Insert a document."""
        for n in nodes:
            triplets = self._extract_triplets(
                n.get_content(metadata_mode=MetadataMode.LLM)
            )
            logger.debug(f"Extracted triplets: {triplets}")
            for triplet in triplets:
                subj, _, obj = triplet
                triplet_str = str(triplet)
                self.upsert_triplet(triplet)
                self._index_struct.add_node([subj, obj], n)
                if (
                    self.include_embeddings
                    and triplet_str not in self._index_struct.embedding_dict
                ):
                    rel_embedding = (
                        self._service_context.embed_model.get_text_embedding(
                            triplet_str
                        )
                    )
                    self._index_struct.add_to_embedding_dict(
                        triplet_str, rel_embedding
                    )

    def upsert_triplet(self, triplet: Tuple[str, str, str]) -> None:
        """Insert a triplet.

        Used for manual insertion of KG triplets (in the form
        of (subject, relationship, object)).

        Args:
            triplet (Tuple[str, str, str]): Knowledge triplet to insert.

        """
        self._graph_store.upsert_triplet(*triplet)

    def add_node(self, keywords: List[str], node: BaseNode) -> None:
        """Add a node.

        Used for manual insertion of nodes (keyed by keywords).

        Args:
            keywords (List[str]): Keywords to index the node.
            node (BaseNode): Node to be indexed.

        """
        self._index_struct.add_node(keywords, node)
        self._docstore.add_documents([node], allow_update=True)

    def upsert_triplet_and_node(
        self, triplet: Tuple[str, str, str], node: BaseNode
    ) -> None:
        """Upsert a KG triplet and node.

        Calls both upsert_triplet and add_node.
        Behavior is idempotent; if the node already exists,
        only the triplet will be added.

        Args:
            triplet (Tuple[str, str, str]): Knowledge triplet to insert.
            node (BaseNode): Node to be indexed.
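
        Example:
            A minimal usage sketch; the node text and triplet below are
            illustrative assumptions, not part of the library::

                from llama_index.schema import TextNode

                node = TextNode(text="Alice works at Acme.")
                index.upsert_triplet_and_node(
                    ("Alice", "works at", "Acme"), node
                )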
""" subj, _, obj = triplet self.upsert_triplet(triplet) self.add_node([subj, obj], node) def _delete_node(self, node_id: str, **delete_kwargs: Any) -> None: """Delete a node.""" raise NotImplementedError("Delete is not supported for KG index yet.") @property def ref_doc_info(self) -> Dict[str, RefDocInfo]: """Retrieve a dict mapping of ingested documents and their nodes+metadata.""" node_doc_ids_sets = list(self._index_struct.table.values()) node_doc_ids = list(set().union(*node_doc_ids_sets)) nodes = self.docstore.get_nodes(node_doc_ids) all_ref_doc_info = {} for node in nodes: ref_node = node.source_node if not ref_node: continue ref_doc_info = self.docstore.get_ref_doc_info(ref_node.node_id) if not ref_doc_info: continue all_ref_doc_info[ref_node.node_id] = ref_doc_info return all_ref_doc_info def get_networkx_graph(self, limit: int = 100) -> Any: """Get networkx representation of the graph structure. Args: limit (int): Number of starting nodes to be included in the graph. NOTE: This function requires networkx to be installed. NOTE: This is a beta feature. """ try: import networkx as nx except ImportError: raise ImportError( "Please install networkx to visualize the graph: `pip install networkx`" ) g = nx.Graph() subjs = list(self.index_struct.table.keys()) # add edges rel_map = self._graph_store.get_rel_map(subjs=subjs, depth=1, limit=limit) added_nodes = set() for keyword in rel_map: for path in rel_map[keyword]: subj = keyword for i in range(0, len(path), 2): if i + 2 >= len(path): break if subj not in added_nodes: g.add_node(subj) added_nodes.add(subj) rel = path[i + 1] obj = path[i + 2] g.add_edge(subj, obj, label=rel, title=rel) subj = obj return g @property def query_context(self) -> Dict[str, Any]: return {GRAPH_STORE_KEY: self._graph_store} # legacy GPTKnowledgeGraphIndex = KnowledgeGraphIndex