faiss_rag_enterprise/llama_index/indices/tree/base.py

180 lines
6.8 KiB
Python

"""Tree-based index."""
from enum import Enum
from typing import Any, Dict, Optional, Sequence, Union
from llama_index.core.base_retriever import BaseRetriever
# from llama_index.data_structs.data_structs import IndexGraph
from llama_index.data_structs.data_structs import IndexGraph
from llama_index.indices.base import BaseIndex
from llama_index.indices.common_tree.base import GPTTreeIndexBuilder
from llama_index.indices.tree.inserter import TreeIndexInserter
from llama_index.prompts import BasePromptTemplate
from llama_index.prompts.default_prompts import (
DEFAULT_INSERT_PROMPT,
DEFAULT_SUMMARY_PROMPT,
)
from llama_index.schema import BaseNode, IndexNode
from llama_index.service_context import ServiceContext
from llama_index.storage.docstore.types import RefDocInfo
class TreeRetrieverMode(str, Enum):
    """Retriever modes understood by ``TreeIndex.as_retriever``.

    The enum mixes in ``str``, so a mode can be given either as a member or
    as its underlying string value (for example ``"select_leaf"``).
    """

    # Dispatched to TreeSelectLeafRetriever.
    SELECT_LEAF = "select_leaf"
    # Dispatched to TreeSelectLeafEmbeddingRetriever.
    SELECT_LEAF_EMBEDDING = "select_leaf_embedding"
    # Dispatched to TreeAllLeafRetriever.
    ALL_LEAF = "all_leaf"
    # Dispatched to TreeRootRetriever.
    ROOT = "root"
# Retriever modes that are rejected by TreeIndex._validate_build_tree_required
# when the index was constructed with ``build_tree=False``. ALL_LEAF is the
# only mode that is not listed here.
REQUIRE_TREE_MODES = {
    TreeRetrieverMode.ROOT,
    TreeRetrieverMode.SELECT_LEAF,
    TreeRetrieverMode.SELECT_LEAF_EMBEDDING,
}
class TreeIndex(BaseIndex[IndexGraph]):
    """Tree Index.

    The tree index is a tree-structured index, where each node is a summary of
    its child nodes. During index construction, the tree is constructed
    in a bottom-up fashion until we end up with a set of root_nodes.

    There are a few different options during query time (see :ref:`Ref-Query`).
    The main option is to traverse down the tree from the root nodes.
    A secondary option is to directly synthesize the answer from the root nodes.

    Args:
        nodes (Optional[Sequence[BaseNode]]): Forwarded unchanged to the
            ``BaseIndex`` constructor.
        objects (Optional[Sequence[IndexNode]]): Forwarded unchanged to the
            ``BaseIndex`` constructor.
        index_struct (Optional[IndexGraph]): Forwarded unchanged to the
            ``BaseIndex`` constructor.
        service_context (Optional[ServiceContext]): Forwarded unchanged to the
            ``BaseIndex`` constructor.
        summary_template (Optional[BasePromptTemplate]): A Summarization Prompt
            (see :ref:`Prompt-Templates`). Falls back to
            ``DEFAULT_SUMMARY_PROMPT`` when not given.
        insert_prompt (Optional[BasePromptTemplate]): A Tree Insertion Prompt
            (see :ref:`Prompt-Templates`). Falls back to
            ``DEFAULT_INSERT_PROMPT`` when not given.
        num_children (int): The number of children each node should have.
            Defaults to 10.
        build_tree (bool): Whether to build the tree during index construction.
            Defaults to True.
        use_async (bool): Passed to the tree builder as ``use_async``.
            Defaults to False.
        show_progress (bool): Whether to show progress bars. Defaults to False.

    """

    index_struct_cls = IndexGraph

    def __init__(
        self,
        nodes: Optional[Sequence[BaseNode]] = None,
        objects: Optional[Sequence[IndexNode]] = None,
        index_struct: Optional[IndexGraph] = None,
        service_context: Optional[ServiceContext] = None,
        summary_template: Optional[BasePromptTemplate] = None,
        insert_prompt: Optional[BasePromptTemplate] = None,
        num_children: int = 10,
        build_tree: bool = True,
        use_async: bool = False,
        show_progress: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        # need to set parameters before building index in base class.
        self.num_children = num_children
        self.summary_template = summary_template or DEFAULT_SUMMARY_PROMPT
        self.insert_prompt: BasePromptTemplate = insert_prompt or DEFAULT_INSERT_PROMPT
        self.build_tree = build_tree
        self._use_async = use_async
        super().__init__(
            nodes=nodes,
            index_struct=index_struct,
            service_context=service_context,
            show_progress=show_progress,
            objects=objects,
            **kwargs,
        )

    def as_retriever(
        self,
        retriever_mode: Union[str, TreeRetrieverMode] = TreeRetrieverMode.SELECT_LEAF,
        **kwargs: Any,
    ) -> BaseRetriever:
        """Return the retriever that implements ``retriever_mode``.

        Args:
            retriever_mode (Union[str, TreeRetrieverMode]): The mode, given as
                an enum member or as its string value.
            **kwargs: Extra keyword arguments forwarded to the retriever
                constructor.

        Returns:
            BaseRetriever: A retriever constructed with this index and its
            object map.

        Raises:
            ValueError: If ``retriever_mode`` is not a valid
                ``TreeRetrieverMode`` value, or if the mode requires a tree
                but the index was constructed with ``build_tree=False``.

        """
        # NOTE: lazy import
        from llama_index.indices.tree.all_leaf_retriever import TreeAllLeafRetriever
        from llama_index.indices.tree.select_leaf_embedding_retriever import (
            TreeSelectLeafEmbeddingRetriever,
        )
        from llama_index.indices.tree.select_leaf_retriever import (
            TreeSelectLeafRetriever,
        )
        from llama_index.indices.tree.tree_root_retriever import TreeRootRetriever

        # Normalize once so validation and dispatch operate on the same value;
        # the Enum constructor raises ValueError for an unrecognised mode.
        mode = TreeRetrieverMode(retriever_mode)
        self._validate_build_tree_required(mode)

        retriever_classes: Dict[TreeRetrieverMode, Any] = {
            TreeRetrieverMode.SELECT_LEAF: TreeSelectLeafRetriever,
            TreeRetrieverMode.SELECT_LEAF_EMBEDDING: TreeSelectLeafEmbeddingRetriever,
            TreeRetrieverMode.ROOT: TreeRootRetriever,
            TreeRetrieverMode.ALL_LEAF: TreeAllLeafRetriever,
        }
        retriever_cls = retriever_classes.get(mode)
        if retriever_cls is None:
            # Defensive: only reachable if a mode is added to the enum without
            # a matching entry in the mapping above.
            raise ValueError(f"Unknown retriever mode: {retriever_mode}")
        return retriever_cls(self, object_map=self._object_map, **kwargs)

    def _validate_build_tree_required(self, retriever_mode: TreeRetrieverMode) -> None:
        """Check if index supports modes that require trees.

        Raises:
            ValueError: If ``retriever_mode`` is in ``REQUIRE_TREE_MODES`` and
                the index was constructed with ``build_tree=False``.

        """
        if retriever_mode in REQUIRE_TREE_MODES and not self.build_tree:
            raise ValueError(
                "Index was constructed without building trees, "
                f"but retriever mode {retriever_mode} requires trees."
            )

    def _build_index_from_nodes(self, nodes: Sequence[BaseNode]) -> IndexGraph:
        """Build the index from nodes.

        Delegates to ``GPTTreeIndexBuilder``, configured with this index's
        ``num_children``, summary prompt, service context, async flag,
        progress flag and docstore; ``build_tree`` is passed through to
        ``build_from_nodes``.
        """
        index_builder = GPTTreeIndexBuilder(
            self.num_children,
            self.summary_template,
            service_context=self._service_context,
            use_async=self._use_async,
            show_progress=self._show_progress,
            docstore=self._docstore,
        )
        return index_builder.build_from_nodes(nodes, build_tree=self.build_tree)

    def _insert(self, nodes: Sequence[BaseNode], **insert_kwargs: Any) -> None:
        """Insert nodes into the existing index structure.

        Delegates to ``TreeIndexInserter``. ``insert_kwargs`` is accepted but
        not used by this implementation.
        """
        # TODO: allow to customize insert prompt
        inserter = TreeIndexInserter(
            self.index_struct,
            num_children=self.num_children,
            insert_prompt=self.insert_prompt,
            summary_prompt=self.summary_template,
            service_context=self._service_context,
            docstore=self._docstore,
        )
        inserter.insert(nodes)

    def _delete_node(self, node_id: str, **delete_kwargs: Any) -> None:
        """Delete a node.

        Raises:
            NotImplementedError: Always; deletion is not implemented for the
                tree index.

        """
        raise NotImplementedError("Delete not implemented for tree index.")

    @property
    def ref_doc_info(self) -> Dict[str, RefDocInfo]:
        """Retrieve a dict mapping of ingested documents and their nodes+metadata.

        Nodes without a ``source_node``, and source nodes for which the
        docstore returns no ref-doc info, are skipped.
        """
        node_doc_ids = list(self.index_struct.all_nodes.values())
        nodes = self.docstore.get_nodes(node_doc_ids)

        all_ref_doc_info: Dict[str, RefDocInfo] = {}
        for node in nodes:
            ref_node = node.source_node
            if not ref_node:
                continue

            ref_doc_info = self.docstore.get_ref_doc_info(ref_node.node_id)
            if not ref_doc_info:
                continue

            # Keyed by the source document's node id.
            all_ref_doc_info[ref_node.node_id] = ref_doc_info
        return all_ref_doc_info
# Legacy alias: the older ``GPTTreeIndex`` name refers to the same class as
# ``TreeIndex``.
GPTTreeIndex = TreeIndex