# faiss_rag_enterprise/llama_index/indices/tree/inserter.py

# 179 lines
# 7.7 KiB
# Python

"""Tree Index inserter."""
from typing import Optional, Sequence
from llama_index.data_structs.data_structs import IndexGraph
from llama_index.indices.tree.utils import get_numbered_text_from_nodes
from llama_index.indices.utils import (
extract_numbers_given_response,
get_sorted_node_list,
)
from llama_index.prompts.base import BasePromptTemplate
from llama_index.prompts.default_prompts import (
DEFAULT_INSERT_PROMPT,
DEFAULT_SUMMARY_PROMPT,
)
from llama_index.schema import BaseNode, MetadataMode, TextNode
from llama_index.service_context import ServiceContext
from llama_index.storage.docstore import BaseDocumentStore
from llama_index.storage.docstore.registry import get_default_docstore
class TreeIndexInserter:
    """Insert nodes into an existing tree index (``IndexGraph``).

    New nodes are routed down the tree by asking the LLM (``insert_prompt``)
    which summary node they belong under. When a parent ends up with more than
    ``num_children`` children, those children are split in half and each half
    is summarized (``summary_prompt``) into a new intermediate summary node.
    Summaries of the nodes on the insertion path are regenerated on the way
    back up.
    """

    def __init__(
        self,
        index_graph: IndexGraph,
        service_context: ServiceContext,
        num_children: int = 10,
        insert_prompt: BasePromptTemplate = DEFAULT_INSERT_PROMPT,
        summary_prompt: BasePromptTemplate = DEFAULT_SUMMARY_PROMPT,
        docstore: Optional[BaseDocumentStore] = None,
    ) -> None:
        """Initialize with params.

        Args:
            index_graph: Tree structure to insert into; mutated in place.
            service_context: Supplies the LLM and the prompt helper.
            num_children: Maximum number of children a node may have before
                its children are split into two new summary nodes. Must be >= 2.
            insert_prompt: Prompt used to choose which child to descend into.
            summary_prompt: Prompt used to summarize a group of nodes.
            docstore: Store holding node contents. Defaults to
                ``get_default_docstore()`` when not given.

        Raises:
            ValueError: If ``num_children`` is less than 2.
        """
        if num_children < 2:
            raise ValueError("Invalid number of children.")
        self.num_children = num_children
        self.summary_prompt = summary_prompt
        self.insert_prompt = insert_prompt
        self.index_graph = index_graph
        self._service_context = service_context
        # explicit None check so a caller-supplied docstore is never swapped
        # out merely because it happens to evaluate as falsy
        self._docstore = (
            docstore if docstore is not None else get_default_docstore()
        )

    def _get_sorted_children(
        self, parent_node: Optional[BaseNode]
    ) -> Sequence[BaseNode]:
        """Return the children of ``parent_node`` loaded from the docstore.

        When ``parent_node`` is None this is the root layer. The nodes are
        ordered by ``get_sorted_node_list``.
        """
        child_ids = self.index_graph.get_children(parent_node)
        return get_sorted_node_list(self._docstore.get_node_dict(child_ids))

    def _summarize_nodes(self, nodes: Sequence[BaseNode]) -> str:
        """Return an LLM summary of ``nodes`` produced with ``summary_prompt``.

        The nodes' LLM-visible contents are truncated by the prompt helper
        against ``summary_prompt``, joined with newlines, and passed to the LLM
        as ``context_str``.
        """
        truncated_chunks = self._service_context.prompt_helper.truncate(
            prompt=self.summary_prompt,
            text_chunks=[
                child.get_content(metadata_mode=MetadataMode.LLM)
                for child in nodes
            ],
        )
        return self._service_context.llm.predict(
            self.summary_prompt, context_str="\n".join(truncated_chunks)
        )

    def _insert_under_parent_and_consolidate(
        self, text_node: BaseNode, parent_node: Optional[BaseNode]
    ) -> None:
        """Insert node under parent and consolidate.

        If the parent then has more than ``num_children`` children, those
        children are split into two halves. Each half is summarized into a new
        ``TextNode`` that takes the half as its children. The two summary nodes
        then replace the old children directly beneath the parent, or become
        the new root layer when ``parent_node`` is None.
        """
        # perform insertion
        self.index_graph.insert_under_parent(text_node, parent_node)

        # if under num_children limit, then we're fine
        if len(self.index_graph.get_children(parent_node)) <= self.num_children:
            return

        # perform consolidation: split the children in half
        # TODO: do better splitting (with a GPT prompt etc.)
        children = self._get_sorted_children(parent_node)
        mid = len(children) // 2
        halves = (children[:mid], children[mid:])

        # build one summary node per half, each owning that half as children
        summary_nodes = []
        for half in halves:
            summary_node = TextNode(text=self._summarize_nodes(half))
            self.index_graph.insert(summary_node, children_nodes=half)
            summary_nodes.append(summary_node)

        # detach the old children from the parent (or clear the root layer) so
        # only the two new summary nodes remain directly beneath it
        if parent_node is not None:
            self.index_graph.node_id_to_children_ids[parent_node.node_id] = []
        else:
            self.index_graph.root_nodes = {}
        for summary_node in summary_nodes:
            self.index_graph.insert_under_parent(
                summary_node,
                parent_node,
                new_index=self.index_graph.get_index(summary_node),
            )
            self._docstore.add_documents([summary_node], allow_update=False)

    def _select_child(
        self, node: BaseNode, candidates: Sequence[BaseNode]
    ) -> Optional[BaseNode]:
        """Ask the LLM which of ``candidates`` ``node`` should go under.

        Returns:
            The chosen candidate. Returns None if the response contains no
            number, or if the first number is outside ``1..len(candidates)``.
        """
        text_splitter = (
            self._service_context.prompt_helper.get_text_splitter_given_prompt(
                prompt=self.insert_prompt,
                num_chunks=len(candidates),
            )
        )
        numbered_text = get_numbered_text_from_nodes(
            candidates, text_splitter=text_splitter
        )
        response = self._service_context.llm.predict(
            self.insert_prompt,
            new_chunk_text=node.get_content(metadata_mode=MetadataMode.LLM),
            num_chunks=len(candidates),
            context_list=numbered_text,
        )
        numbers = extract_numbers_given_response(response)
        if not numbers:
            return None
        choice = int(numbers[0])
        # candidates are numbered from 1: reject 0 as well as values past the
        # end (0 would otherwise index -1 and silently pick the last node)
        if not 1 <= choice <= len(candidates):
            return None
        return candidates[choice - 1]

    def _insert_node(
        self, node: BaseNode, parent_node: Optional[BaseNode] = None
    ) -> None:
        """Insert ``node`` into the subtree below ``parent_node``.

        If the current layer is empty or holds leaf nodes, the node is attached
        here, with consolidation if needed. Otherwise the LLM picks a summary
        node to descend into; if its answer is unusable, the node is attached
        here instead. Afterwards ``parent_node``'s summary is regenerated from
        its current children, so updates bubble up the tree.
        """
        cur_graph_node_list = self._get_sorted_children(parent_node)

        # An empty layer (e.g. an empty graph) or a leaf layer means insert
        # right here.
        # NOTE(review): only the first node is inspected, so layers are assumed
        # to be uniformly leaf or non-leaf.
        if not cur_graph_node_list or not self.index_graph.get_children(
            cur_graph_node_list[0]
        ):
            self._insert_under_parent_and_consolidate(node, parent_node)
        else:
            selected_node = self._select_child(node, cur_graph_node_list)
            if selected_node is None:
                # NOTE: no usable number from the LLM -> insert under parent
                self._insert_under_parent_and_consolidate(node, parent_node)
            else:
                self._insert_node(node, selected_node)

        # now we need to update summary for parent node, since we
        # need to bubble updated summaries up the tree
        if parent_node is not None:
            new_summary = self._summarize_nodes(
                self._get_sorted_children(parent_node)
            )
            parent_node.set_content(new_summary)
            # The graph only tracks node ids; node content lives in the
            # docstore. get_node_dict may hand back copies, so write the
            # refreshed summary back rather than relying on the in-memory
            # mutation alone.
            self._docstore.add_documents([parent_node], allow_update=True)

    def insert(self, nodes: Sequence[BaseNode]) -> None:
        """Insert each of ``nodes``, in order, starting from the root layer."""
        for node in nodes:
            self._insert_node(node)