"""Knowledge Graph Index.
|
|
|
|
Build a KG by extracting triplets, and leveraging the KG during query-time.
|
|
|
|
"""
|
|
|
|
import logging
|
|
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
|
|
|
|
from llama_index.constants import GRAPH_STORE_KEY
|
|
from llama_index.core.base_retriever import BaseRetriever
|
|
from llama_index.data_structs.data_structs import KG
|
|
from llama_index.graph_stores.simple import SimpleGraphStore
|
|
from llama_index.graph_stores.types import GraphStore
|
|
from llama_index.indices.base import BaseIndex
|
|
from llama_index.prompts import BasePromptTemplate
|
|
from llama_index.prompts.default_prompts import DEFAULT_KG_TRIPLET_EXTRACT_PROMPT
|
|
from llama_index.schema import BaseNode, IndexNode, MetadataMode
|
|
from llama_index.service_context import ServiceContext
|
|
from llama_index.storage.docstore.types import RefDocInfo
|
|
from llama_index.storage.storage_context import StorageContext
|
|
from llama_index.utils import get_tqdm_iterable
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class KnowledgeGraphIndex(BaseIndex[KG]):
    """Knowledge Graph Index.

    Build a KG by extracting triplets, and leveraging the KG during query-time.

    Args:
        kg_triple_extract_template (BasePromptTemplate): The prompt to use for
            extracting triplets.
        max_triplets_per_chunk (int): The maximum number of triplets to extract
            per node chunk.
        service_context (Optional[ServiceContext]): The service context to use.
        storage_context (Optional[StorageContext]): The storage context to use.
        graph_store (Optional[GraphStore]): The graph store to use.
        show_progress (bool): Whether to show tqdm progress bars. Defaults to False.
        include_embeddings (bool): Whether to include triplet embeddings in the
            index (enables hybrid retrieval). Defaults to False.
        max_object_length (int): The maximum length, in UTF-8 bytes, of any
            element of an extracted triplet; longer triplets are skipped.
            Defaults to 128.
        kg_triplet_extract_fn (Optional[Callable]): The function to use for
            extracting triplets, in place of the LLM-based extractor.
            Defaults to None.
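
    Example:
        A minimal usage sketch (illustrative, not normative); it assumes
        documents are loaded elsewhere, e.g. via ``SimpleDirectoryReader``
        pointed at a hypothetical ``./data`` directory::

            from llama_index import SimpleDirectoryReader

            documents = SimpleDirectoryReader("./data").load_data()
            index = KnowledgeGraphIndex.from_documents(
                documents, max_triplets_per_chunk=2
            )

        A custom extraction function may be supplied in place of the LLM-based
        extractor; ``my_extract_fn`` below is a hypothetical stand-in::

            def my_extract_fn(text: str) -> List[Tuple[str, str, str]]:
                ...  # return [(subject, predicate, object), ...]

            index = KnowledgeGraphIndex.from_documents(
                documents, kg_triplet_extract_fn=my_extract_fn
            )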
    """

    index_struct_cls = KG

    def __init__(
        self,
        nodes: Optional[Sequence[BaseNode]] = None,
        objects: Optional[Sequence[IndexNode]] = None,
        index_struct: Optional[KG] = None,
        service_context: Optional[ServiceContext] = None,
        storage_context: Optional[StorageContext] = None,
        kg_triple_extract_template: Optional[BasePromptTemplate] = None,
        max_triplets_per_chunk: int = 10,
        include_embeddings: bool = False,
        show_progress: bool = False,
        max_object_length: int = 128,
        kg_triplet_extract_fn: Optional[Callable] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        # Need to set parameters before building the index in the base class.
        self.include_embeddings = include_embeddings
        self.max_triplets_per_chunk = max_triplets_per_chunk
        self.kg_triple_extract_template = (
            kg_triple_extract_template or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT
        )
        # NOTE: Partially format the triplet extract template here, so that
        # only the text to process remains to be filled in at extraction time.
        self.kg_triple_extract_template = (
            self.kg_triple_extract_template.partial_format(
                max_knowledge_triplets=self.max_triplets_per_chunk
            )
        )
        self._max_object_length = max_object_length
        self._kg_triplet_extract_fn = kg_triplet_extract_fn

        super().__init__(
            nodes=nodes,
            index_struct=index_struct,
            service_context=service_context,
            storage_context=storage_context,
            show_progress=show_progress,
            objects=objects,
            **kwargs,
        )

        # TODO: legacy conversion - remove in next release
        if (
            len(self.index_struct.table) > 0
            and isinstance(self.graph_store, SimpleGraphStore)
            and len(self.graph_store._data.graph_dict) == 0
        ):
            logger.warning("Upgrading previously saved KG index to new storage format.")
            self.graph_store._data.graph_dict = self.index_struct.rel_map

    @property
    def graph_store(self) -> GraphStore:
        return self._graph_store

    def as_retriever(self, **kwargs: Any) -> BaseRetriever:
        """Return a KGTableRetriever over this index.

        Defaults to hybrid (keyword + embedding) retrieval when triplet
        embeddings are present; otherwise keyword-based retrieval is used.
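
        Example:
            A minimal usage sketch (the query string is illustrative)::

                retriever = index.as_retriever()
                nodes = retriever.retrieve("Who founded Acme?")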
        """
        from llama_index.indices.knowledge_graph.retrievers import (
            KGRetrieverMode,
            KGTableRetriever,
        )

        if len(self.index_struct.embedding_dict) > 0 and "retriever_mode" not in kwargs:
            kwargs["retriever_mode"] = KGRetrieverMode.HYBRID

        return KGTableRetriever(self, object_map=self._object_map, **kwargs)

    def _extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
        """Extract triplets, preferring the user-supplied extraction function."""
        if self._kg_triplet_extract_fn is not None:
            return self._kg_triplet_extract_fn(text)
        else:
            return self._llm_extract_triplets(text)

    def _llm_extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
        """Extract triplets from text via the LLM."""
        response = self._service_context.llm.predict(
            self.kg_triple_extract_template,
            text=text,
        )
        return self._parse_triplet_response(
            response, max_length=self._max_object_length
        )

    @staticmethod
    def _parse_triplet_response(
        response: str, max_length: int = 128
    ) -> List[Tuple[str, str, str]]:
        """Parse ``(subject, predicate, object)`` lines out of an LLM response.
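
        Lines without a well-formed parenthesized triple are skipped, as are
        triplets with empty or over-long elements. For example (the response
        text is made up; elements are capitalized for disambiguation)::

            response = '''
            Some triplets are:
            (Alice, works at, Acme)
            (incomplete
            '''
            KnowledgeGraphIndex._parse_triplet_response(response)
            # -> [('Alice', 'Works at', 'Acme')]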
        """
        knowledge_strs = response.strip().split("\n")
        results = []
        for text in knowledge_strs:
            if "(" not in text or ")" not in text or text.index(")") < text.index("("):
                # skip empty lines and non-triplets
                continue
            triplet_part = text[text.index("(") + 1 : text.index(")")]
            tokens = triplet_part.split(",")
            if len(tokens) != 3:
                continue

            if any(len(s.encode("utf-8")) > max_length for s in tokens):
                # We count byte length instead of len() to handle multi-byte
                # UTF-8 chars, and skip the triplet if any token is too long.
                # This is normally due to a poorly formatted triplet
                # extraction; in more serious KG-building cases
                # we'll need NLP models to better extract triplets.
                continue

            subj, pred, obj = map(str.strip, tokens)
            if not subj or not pred or not obj:
                # skip partial triplets
                continue

            # Strip double quotes and capitalize triplets for disambiguation.
            subj, pred, obj = (
                entity.strip('"').capitalize() for entity in [subj, pred, obj]
            )

            results.append((subj, pred, obj))
        return results

    def _build_index_from_nodes(self, nodes: Sequence[BaseNode]) -> KG:
        """Build the index from nodes."""
        # Extract triplets per node and populate the graph store + index struct.
        index_struct = self.index_struct_cls()
        nodes_with_progress = get_tqdm_iterable(
            nodes, self._show_progress, "Processing nodes"
        )
        for n in nodes_with_progress:
            triplets = self._extract_triplets(
                n.get_content(metadata_mode=MetadataMode.LLM)
            )
            logger.debug(f"> Extracted triplets: {triplets}")
            for triplet in triplets:
                subj, _, obj = triplet
                self.upsert_triplet(triplet)
                index_struct.add_node([subj, obj], n)

            if self.include_embeddings:
                triplet_texts = [str(t) for t in triplets]

                embed_model = self._service_context.embed_model
                embed_outputs = embed_model.get_text_embedding_batch(
                    triplet_texts, show_progress=self._show_progress
                )
                for rel_text, rel_embed in zip(triplet_texts, embed_outputs):
                    index_struct.add_to_embedding_dict(rel_text, rel_embed)

        return index_struct

    def _insert(self, nodes: Sequence[BaseNode], **insert_kwargs: Any) -> None:
        """Insert a document."""
        for n in nodes:
            triplets = self._extract_triplets(
                n.get_content(metadata_mode=MetadataMode.LLM)
            )
            logger.debug(f"Extracted triplets: {triplets}")
            for triplet in triplets:
                subj, _, obj = triplet
                triplet_str = str(triplet)
                self.upsert_triplet(triplet)
                self._index_struct.add_node([subj, obj], n)
                if (
                    self.include_embeddings
                    and triplet_str not in self._index_struct.embedding_dict
                ):
                    rel_embedding = (
                        self._service_context.embed_model.get_text_embedding(
                            triplet_str
                        )
                    )
                    self._index_struct.add_to_embedding_dict(triplet_str, rel_embedding)

    def upsert_triplet(self, triplet: Tuple[str, str, str]) -> None:
        """Insert or update a knowledge triplet in the graph store.

        Used for manual insertion of KG triplets (in the form
        of (subject, relationship, object)).

        Args:
            triplet (Tuple[str, str, str]): Knowledge triplet to upsert.
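
        Example:
            An illustrative call (entity names are made up)::

                index.upsert_triplet(("Alice", "works_at", "Acme"))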
        """
        self._graph_store.upsert_triplet(*triplet)

    def add_node(self, keywords: List[str], node: BaseNode) -> None:
        """Add a node to the index.

        Used for manual insertion of nodes (keyed by keywords).

        Args:
            keywords (List[str]): Keywords to index the node under.
            node (BaseNode): Node to be indexed.
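
        Example:
            An illustrative call, assuming ``node`` is an existing
            ``BaseNode`` that mentions both keywords::

                index.add_node(["Alice", "Acme"], node)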
        """
        self._index_struct.add_node(keywords, node)
        self._docstore.add_documents([node], allow_update=True)

    def upsert_triplet_and_node(
        self, triplet: Tuple[str, str, str], node: BaseNode
    ) -> None:
        """Upsert a KG triplet and the node it came from.

        Calls both upsert_triplet and add_node.
        Behavior is idempotent; if the node already exists,
        only the triplet will be added.

        Args:
            triplet (Tuple[str, str, str]): Knowledge triplet to upsert.
            node (BaseNode): Node to be indexed.
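
        Example:
            An illustrative call (names are made up; ``node`` is the source
            node the triplet was extracted from)::

                index.upsert_triplet_and_node(
                    ("Alice", "works_at", "Acme"), node
                )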
        """
        subj, _, obj = triplet
        self.upsert_triplet(triplet)
        self.add_node([subj, obj], node)

    def _delete_node(self, node_id: str, **delete_kwargs: Any) -> None:
        """Delete a node."""
        raise NotImplementedError("Delete is not supported for KG index yet.")

    @property
    def ref_doc_info(self) -> Dict[str, RefDocInfo]:
        """Retrieve a dict mapping of ingested documents and their nodes+metadata."""
        node_doc_ids_sets = list(self._index_struct.table.values())
        node_doc_ids = list(set().union(*node_doc_ids_sets))
        nodes = self.docstore.get_nodes(node_doc_ids)

        all_ref_doc_info = {}
        for node in nodes:
            ref_node = node.source_node
            if not ref_node:
                continue

            ref_doc_info = self.docstore.get_ref_doc_info(ref_node.node_id)
            if not ref_doc_info:
                continue

            all_ref_doc_info[ref_node.node_id] = ref_doc_info
        return all_ref_doc_info

    def get_networkx_graph(self, limit: int = 100) -> Any:
        """Get a networkx representation of the graph structure.

        Args:
            limit (int): Number of starting nodes to be included in the graph.

        NOTE: This function requires networkx to be installed.
        NOTE: This is a beta feature.
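
        Example:
            A drawing sketch (purely illustrative; ``nx.draw`` additionally
            requires matplotlib)::

                import matplotlib.pyplot as plt
                import networkx as nx

                g = index.get_networkx_graph()
                nx.draw(g, with_labels=True)
                plt.show()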
        """
        try:
            import networkx as nx
        except ImportError:
            raise ImportError(
                "Please install networkx to visualize the graph: `pip install networkx`"
            )

        g = nx.Graph()
        subjs = list(self.index_struct.table.keys())

        # add edges
        rel_map = self._graph_store.get_rel_map(subjs=subjs, depth=1, limit=limit)

        added_nodes = set()
        for keyword in rel_map:
            for path in rel_map[keyword]:
                subj = keyword
                # Each path is a flattened [subj, rel, obj, rel, obj, ...]
                # sequence, so walk it two elements at a time.
                for i in range(0, len(path), 2):
                    if i + 2 >= len(path):
                        break

                    if subj not in added_nodes:
                        g.add_node(subj)
                        added_nodes.add(subj)

                    rel = path[i + 1]
                    obj = path[i + 2]

                    g.add_edge(subj, obj, label=rel, title=rel)
                    subj = obj
        return g

    @property
    def query_context(self) -> Dict[str, Any]:
        return {GRAPH_STORE_KEY: self._graph_store}


# legacy
GPTKnowledgeGraphIndex = KnowledgeGraphIndex