"""Tree Index inserter."""
|
|
|
|
from typing import Optional, Sequence
|
|
|
|
from llama_index.data_structs.data_structs import IndexGraph
|
|
from llama_index.indices.tree.utils import get_numbered_text_from_nodes
|
|
from llama_index.indices.utils import (
|
|
extract_numbers_given_response,
|
|
get_sorted_node_list,
|
|
)
|
|
from llama_index.prompts.base import BasePromptTemplate
|
|
from llama_index.prompts.default_prompts import (
|
|
DEFAULT_INSERT_PROMPT,
|
|
DEFAULT_SUMMARY_PROMPT,
|
|
)
|
|
from llama_index.schema import BaseNode, MetadataMode, TextNode
|
|
from llama_index.service_context import ServiceContext
|
|
from llama_index.storage.docstore import BaseDocumentStore
|
|
from llama_index.storage.docstore.registry import get_default_docstore
|
|
|
|
|
|
class TreeIndexInserter:
    """LlamaIndex inserter.

    Inserts nodes into an existing tree index, consolidating layers when a
    parent exceeds ``num_children`` and bubbling refreshed summaries up the
    tree.
    """

    def __init__(
        self,
        index_graph: IndexGraph,
        service_context: ServiceContext,
        num_children: int = 10,
        insert_prompt: BasePromptTemplate = DEFAULT_INSERT_PROMPT,
        summary_prompt: BasePromptTemplate = DEFAULT_SUMMARY_PROMPT,
        docstore: Optional[BaseDocumentStore] = None,
    ) -> None:
        """Initialize with params."""
        # A tree with fewer than 2 children per node degenerates to a list.
        if num_children < 2:
            raise ValueError("Invalid number of children.")
        self.num_children = num_children
        self.summary_prompt = summary_prompt
        self.insert_prompt = insert_prompt
        self.index_graph = index_graph
        self._service_context = service_context
        self._docstore = docstore or get_default_docstore()

    def _summarize_nodes(self, nodes: Sequence[BaseNode]) -> str:
        """Return an LLM-generated summary of ``nodes``.

        Node contents are truncated to fit the summary prompt's context
        window before being joined and sent to the LLM.
        """
        truncated_chunks = self._service_context.prompt_helper.truncate(
            prompt=self.summary_prompt,
            text_chunks=[
                n.get_content(metadata_mode=MetadataMode.LLM) for n in nodes
            ],
        )
        text_chunk = "\n".join(truncated_chunks)
        return self._service_context.llm.predict(
            self.summary_prompt, context_str=text_chunk
        )

    def _insert_under_parent_and_consolidate(
        self, text_node: BaseNode, parent_node: Optional[BaseNode]
    ) -> None:
        """Insert node under parent and consolidate.

        Consolidation will happen by dividing up child nodes, and creating a new
        intermediate layer of nodes.

        """
        # perform insertion
        self.index_graph.insert_under_parent(text_node, parent_node)

        # if under num_children limit, then we're fine
        if len(self.index_graph.get_children(parent_node)) <= self.num_children:
            return

        # perform consolidation
        cur_graph_node_ids = self.index_graph.get_children(parent_node)
        cur_graph_nodes = self._docstore.get_node_dict(cur_graph_node_ids)
        cur_graph_node_list = get_sorted_node_list(cur_graph_nodes)
        # this layer is all leaf nodes, consolidate and split leaf nodes in half
        # TODO: do better splitting (with a GPT prompt etc.)
        midpoint = len(cur_graph_node_list) // 2
        half1 = cur_graph_node_list[:midpoint]
        half2 = cur_graph_node_list[midpoint:]

        # summarize each half and insert a new intermediate node above it
        summary1 = self._summarize_nodes(half1)
        node1 = TextNode(text=summary1)
        self.index_graph.insert(node1, children_nodes=half1)

        summary2 = self._summarize_nodes(half2)
        node2 = TextNode(text=summary2)
        self.index_graph.insert(node2, children_nodes=half2)

        # insert node1 and node2 as the new (sole) children of parent_node;
        # first remove existing child indices from the parent node
        if parent_node is not None:
            self.index_graph.node_id_to_children_ids[parent_node.node_id] = []
        else:
            # no parent: the halves were root-level, so reset the roots
            self.index_graph.root_nodes = {}
        self.index_graph.insert_under_parent(
            node1, parent_node, new_index=self.index_graph.get_index(node1)
        )
        self._docstore.add_documents([node1], allow_update=False)
        self.index_graph.insert_under_parent(
            node2, parent_node, new_index=self.index_graph.get_index(node2)
        )
        self._docstore.add_documents([node2], allow_update=False)

    def _insert_node(
        self, node: BaseNode, parent_node: Optional[BaseNode] = None
    ) -> None:
        """Recursively insert ``node`` below ``parent_node``.

        Descends via LLM-selected summary nodes until a leaf layer is
        reached, then re-summarizes ancestors on the way back up.
        """
        cur_graph_node_ids = self.index_graph.get_children(parent_node)
        cur_graph_nodes = self._docstore.get_node_dict(cur_graph_node_ids)
        cur_graph_node_list = get_sorted_node_list(cur_graph_nodes)
        # if cur_graph_nodes is empty (start with empty graph), then insert under
        # parent (insert new root node)
        if len(cur_graph_nodes) == 0:
            self._insert_under_parent_and_consolidate(node, parent_node)
        # check if leaf nodes, then just insert under parent
        elif len(self.index_graph.get_children(cur_graph_node_list[0])) == 0:
            self._insert_under_parent_and_consolidate(node, parent_node)
        # else try to find the right summary node to insert under
        else:
            text_splitter = (
                self._service_context.prompt_helper.get_text_splitter_given_prompt(
                    prompt=self.insert_prompt,
                    num_chunks=len(cur_graph_node_list),
                )
            )
            numbered_text = get_numbered_text_from_nodes(
                cur_graph_node_list, text_splitter=text_splitter
            )
            response = self._service_context.llm.predict(
                self.insert_prompt,
                new_chunk_text=node.get_content(metadata_mode=MetadataMode.LLM),
                num_chunks=len(cur_graph_node_list),
                context_list=numbered_text,
            )
            numbers = extract_numbers_given_response(response)
            if numbers is None or len(numbers) == 0:
                # NOTE: if we can't extract a number, then we just insert under parent
                self._insert_under_parent_and_consolidate(node, parent_node)
            elif int(numbers[0]) > len(cur_graph_node_list) or int(numbers[0]) < 1:
                # NOTE: if number is out of range, then we just insert under parent.
                # The 1-based index must also be >= 1: 0 or a negative value would
                # otherwise wrap around via Python negative indexing below and
                # silently select the wrong child.
                self._insert_under_parent_and_consolidate(node, parent_node)
            else:
                selected_node = cur_graph_node_list[int(numbers[0]) - 1]
                self._insert_node(node, selected_node)

        # now we need to update summary for parent node, since we
        # need to bubble updated summaries up the tree
        if parent_node is not None:
            # refetch children (insertion/consolidation may have changed them)
            cur_graph_node_ids = self.index_graph.get_children(parent_node)
            cur_graph_nodes = self._docstore.get_node_dict(cur_graph_node_ids)
            cur_graph_node_list = get_sorted_node_list(cur_graph_nodes)
            new_summary = self._summarize_nodes(cur_graph_node_list)

            parent_node.set_content(new_summary)

    def insert(self, nodes: Sequence[BaseNode]) -> None:
        """Insert into index_graph."""
        for node in nodes:
            self._insert_node(node)