# faiss_rag_enterprise/llama_index/indices/common_tree/base.py
"""Common classes/functions for tree index operations."""
import asyncio
import logging
from typing import Dict, List, Optional, Sequence, Tuple
from llama_index.async_utils import run_async_tasks
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.data_structs.data_structs import IndexGraph
from llama_index.indices.utils import get_sorted_node_list, truncate_text
from llama_index.prompts import BasePromptTemplate
from llama_index.schema import BaseNode, MetadataMode, TextNode
from llama_index.service_context import ServiceContext
from llama_index.storage.docstore import BaseDocumentStore
from llama_index.storage.docstore.registry import get_default_docstore
from llama_index.utils import get_tqdm_iterable
logger = logging.getLogger(__name__)
class GPTTreeIndexBuilder:
    """GPT tree index builder.

    Helper class to build the tree-structured index,
    or to synthesize an answer.

    The tree is built bottom-up: leaf nodes are grouped into chunks of
    ``num_children``, each chunk is summarized by the LLM into a parent
    node, and the process repeats on the parents until at most
    ``num_children`` root nodes remain.
    """

    def __init__(
        self,
        num_children: int,
        summary_prompt: BasePromptTemplate,
        service_context: ServiceContext,
        docstore: Optional[BaseDocumentStore] = None,
        show_progress: bool = False,
        use_async: bool = False,
    ) -> None:
        """Initialize with params.

        Args:
            num_children: Fan-out of the tree; each parent summarizes up to
                this many children. Must be >= 2.
            summary_prompt: Prompt used to summarize each chunk of child nodes.
            service_context: Supplies the LLM, prompt helper, callback manager,
                and llama logger used during the build.
            docstore: Store for persisting nodes; a default docstore is created
                when omitted.
            show_progress: Whether to display a progress bar while generating
                summaries.
            use_async: Whether the synchronous ``build_index_from_nodes`` path
                dispatches the summary LLM calls concurrently.

        Raises:
            ValueError: If ``num_children`` is less than 2.
        """
        if num_children < 2:
            raise ValueError("Invalid number of children.")
        self.num_children = num_children
        self.summary_prompt = summary_prompt
        self._service_context = service_context
        self._use_async = use_async
        self._show_progress = show_progress
        self._docstore = docstore or get_default_docstore()

    @property
    def docstore(self) -> BaseDocumentStore:
        """Return docstore."""
        return self._docstore

    def build_from_nodes(
        self,
        nodes: Sequence[BaseNode],
        build_tree: bool = True,
    ) -> IndexGraph:
        """Build from text.

        Inserts every node into a fresh ``IndexGraph``; when ``build_tree``
        is True, recursively consolidates them into a summary tree.

        Returns:
            IndexGraph: graph object consisting of all_nodes, root_nodes
        """
        index_graph = IndexGraph()
        for node in nodes:
            index_graph.insert(node)

        if build_tree:
            return self.build_index_from_nodes(
                index_graph, index_graph.all_nodes, index_graph.all_nodes, level=0
            )
        else:
            return index_graph

    def _prepare_node_and_text_chunks(
        self, cur_node_ids: Dict[int, str]
    ) -> Tuple[List[int], List[List[BaseNode]], List[str]]:
        """Prepare node and text chunks.

        Groups the current level's nodes (in sorted index order) into chunks
        of at most ``num_children``, truncating each chunk's text to fit the
        summary prompt.

        Returns:
            A tuple of (chunk start indices, node chunks, joined text chunks),
            with the three lists aligned element-for-element.
        """
        cur_nodes = {
            index: self._docstore.get_node(node_id)
            for index, node_id in cur_node_ids.items()
        }
        cur_node_list = get_sorted_node_list(cur_nodes)
        # Ceiling division: a partial final group is still its own chunk,
        # matching the range(..., self.num_children) loop below.
        num_chunks = (
            len(cur_node_list) + self.num_children - 1
        ) // self.num_children
        logger.info(f"> Building index from nodes: {num_chunks} chunks")
        indices, cur_nodes_chunks, text_chunks = [], [], []
        for i in range(0, len(cur_node_list), self.num_children):
            cur_nodes_chunk = cur_node_list[i : i + self.num_children]
            # Truncate so the concatenated chunk fits the prompt's context.
            truncated_chunks = self._service_context.prompt_helper.truncate(
                prompt=self.summary_prompt,
                text_chunks=[
                    node.get_content(metadata_mode=MetadataMode.LLM)
                    for node in cur_nodes_chunk
                ],
            )
            text_chunk = "\n".join(truncated_chunks)
            indices.append(i)
            cur_nodes_chunks.append(cur_nodes_chunk)
            text_chunks.append(text_chunk)
        return indices, cur_nodes_chunks, text_chunks

    def _construct_parent_nodes(
        self,
        index_graph: IndexGraph,
        indices: List[int],
        cur_nodes_chunks: List[List[BaseNode]],
        summaries: List[str],
    ) -> Dict[int, str]:
        """Construct parent nodes.

        Creates one summary ``TextNode`` per chunk, links it to its children
        in the graph, and saves it to the docstore.

        Returns:
            Mapping from each new parent's graph index to its node id.
        """
        new_node_dict = {}
        for i, cur_nodes_chunk, new_summary in zip(
            indices, cur_nodes_chunks, summaries
        ):
            logger.debug(
                f"> {i}/{len(cur_nodes_chunk)}, "
                f"summary: {truncate_text(new_summary, 50)}"
            )
            new_node = TextNode(text=new_summary)
            index_graph.insert(new_node, children_nodes=cur_nodes_chunk)
            index = index_graph.get_index(new_node)
            new_node_dict[index] = new_node.node_id
            # allow_update=False: parent nodes are freshly created, never
            # overwriting an existing document.
            self._docstore.add_documents([new_node], allow_update=False)
        return new_node_dict

    def build_index_from_nodes(
        self,
        index_graph: IndexGraph,
        cur_node_ids: Dict[int, str],
        all_node_ids: Dict[int, str],
        level: int = 0,
    ) -> IndexGraph:
        """Consolidates chunks recursively, in a bottoms-up fashion."""
        # Base case: few enough nodes to serve as roots directly.
        if len(cur_node_ids) <= self.num_children:
            index_graph.root_nodes = cur_node_ids
            return index_graph

        indices, cur_nodes_chunks, text_chunks = self._prepare_node_and_text_chunks(
            cur_node_ids
        )

        with self._service_context.callback_manager.event(
            CBEventType.TREE, payload={EventPayload.CHUNKS: text_chunks}
        ) as event:
            if self._use_async:
                tasks = [
                    self._service_context.llm.apredict(
                        self.summary_prompt, context_str=text_chunk
                    )
                    for text_chunk in text_chunks
                ]
                # BUGFIX: apredict yields the prediction string itself (the
                # sync branch below uses llm.predict's return value directly),
                # so each output is the summary; indexing output[0] would keep
                # only the first character of each summary.
                summaries: List[str] = run_async_tasks(
                    tasks,
                    show_progress=self._show_progress,
                    progress_bar_desc="Generating summaries",
                )
            else:
                text_chunks_progress = get_tqdm_iterable(
                    text_chunks,
                    show_progress=self._show_progress,
                    desc="Generating summaries",
                )
                summaries = [
                    self._service_context.llm.predict(
                        self.summary_prompt, context_str=text_chunk
                    )
                    for text_chunk in text_chunks_progress
                ]
            self._service_context.llama_logger.add_log(
                {"summaries": summaries, "level": level}
            )

            event.on_end(payload={"summaries": summaries, "level": level})

        new_node_dict = self._construct_parent_nodes(
            index_graph, indices, cur_nodes_chunks, summaries
        )
        all_node_ids.update(new_node_dict)

        index_graph.root_nodes = new_node_dict

        if len(new_node_dict) <= self.num_children:
            return index_graph
        else:
            # More parents than the fan-out allows: recurse one level up.
            return self.build_index_from_nodes(
                index_graph, new_node_dict, all_node_ids, level=level + 1
            )

    async def abuild_index_from_nodes(
        self,
        index_graph: IndexGraph,
        cur_node_ids: Dict[int, str],
        all_node_ids: Dict[int, str],
        level: int = 0,
    ) -> IndexGraph:
        """Consolidates chunks recursively, in a bottoms-up fashion."""
        # Base case: few enough nodes to serve as roots directly.
        if len(cur_node_ids) <= self.num_children:
            index_graph.root_nodes = cur_node_ids
            return index_graph

        indices, cur_nodes_chunks, text_chunks = self._prepare_node_and_text_chunks(
            cur_node_ids
        )

        with self._service_context.callback_manager.event(
            CBEventType.TREE, payload={EventPayload.CHUNKS: text_chunks}
        ) as event:
            text_chunks_progress = get_tqdm_iterable(
                text_chunks,
                show_progress=self._show_progress,
                desc="Generating summaries",
            )
            tasks = [
                self._service_context.llm.apredict(
                    self.summary_prompt, context_str=text_chunk
                )
                for text_chunk in text_chunks_progress
            ]
            # BUGFIX: each gathered result is the summary string itself
            # (consistent with the sync llm.predict usage in
            # build_index_from_nodes); taking output[0] would keep only the
            # first character of each summary.
            summaries: List[str] = await asyncio.gather(*tasks)
            self._service_context.llama_logger.add_log(
                {"summaries": summaries, "level": level}
            )

            event.on_end(payload={"summaries": summaries, "level": level})

        new_node_dict = self._construct_parent_nodes(
            index_graph, indices, cur_nodes_chunks, summaries
        )
        all_node_ids.update(new_node_dict)

        index_graph.root_nodes = new_node_dict

        if len(new_node_dict) <= self.num_children:
            return index_graph
        else:
            # More parents than the fan-out allows: recurse one level up.
            return await self.abuild_index_from_nodes(
                index_graph, new_node_dict, all_node_ids, level=level + 1
            )