"""Hierarchical node parser.""" from typing import Any, Dict, List, Optional, Sequence from llama_index.bridge.pydantic import Field from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.node_parser.interface import NodeParser from llama_index.node_parser.text.sentence import SentenceSplitter from llama_index.schema import BaseNode, Document, NodeRelationship from llama_index.utils import get_tqdm_iterable def _add_parent_child_relationship(parent_node: BaseNode, child_node: BaseNode) -> None: """Add parent/child relationship between nodes.""" child_list = parent_node.relationships.get(NodeRelationship.CHILD, []) child_list.append(child_node.as_related_node_info()) parent_node.relationships[NodeRelationship.CHILD] = child_list child_node.relationships[ NodeRelationship.PARENT ] = parent_node.as_related_node_info() def get_leaf_nodes(nodes: List[BaseNode]) -> List[BaseNode]: """Get leaf nodes.""" leaf_nodes = [] for node in nodes: if NodeRelationship.CHILD not in node.relationships: leaf_nodes.append(node) return leaf_nodes def get_root_nodes(nodes: List[BaseNode]) -> List[BaseNode]: """Get root nodes.""" root_nodes = [] for node in nodes: if NodeRelationship.PARENT not in node.relationships: root_nodes.append(node) return root_nodes class HierarchicalNodeParser(NodeParser): """Hierarchical node parser. Splits a document into a recursive hierarchy Nodes using a NodeParser. NOTE: this will return a hierarchy of nodes in a flat list, where there will be overlap between parent nodes (e.g. with a bigger chunk size), and child nodes per parent (e.g. with a smaller chunk size). For instance, this may return a list of nodes like: - list of top-level nodes with chunk size 2048 - list of second-level nodes, where each node is a child of a top-level node, chunk size 512 - list of third-level nodes, where each node is a child of a second-level node, chunk size 128 """ chunk_sizes: Optional[List[int]] = Field( default=None, description=( "The chunk sizes to use when splitting documents, in order of level." ), ) node_parser_ids: List[str] = Field( default_factory=list, description=( "List of ids for the node parsers to use when splitting documents, " + "in order of level (first id used for first level, etc.)." ), ) node_parser_map: Dict[str, NodeParser] = Field( description="Map of node parser id to node parser.", ) @classmethod def from_defaults( cls, chunk_sizes: Optional[List[int]] = None, chunk_overlap: int = 20, node_parser_ids: Optional[List[str]] = None, node_parser_map: Optional[Dict[str, NodeParser]] = None, include_metadata: bool = True, include_prev_next_rel: bool = True, callback_manager: Optional[CallbackManager] = None, ) -> "HierarchicalNodeParser": callback_manager = callback_manager or CallbackManager([]) if node_parser_ids is None: if chunk_sizes is None: chunk_sizes = [2048, 512, 128] node_parser_ids = [f"chunk_size_{chunk_size}" for chunk_size in chunk_sizes] node_parser_map = {} for chunk_size, node_parser_id in zip(chunk_sizes, node_parser_ids): node_parser_map[node_parser_id] = SentenceSplitter( chunk_size=chunk_size, callback_manager=callback_manager, chunk_overlap=chunk_overlap, include_metadata=include_metadata, include_prev_next_rel=include_prev_next_rel, ) else: if chunk_sizes is not None: raise ValueError("Cannot specify both node_parser_ids and chunk_sizes.") if node_parser_map is None: raise ValueError( "Must specify node_parser_map if using node_parser_ids." ) return cls( chunk_sizes=chunk_sizes, node_parser_ids=node_parser_ids, node_parser_map=node_parser_map, include_metadata=include_metadata, include_prev_next_rel=include_prev_next_rel, callback_manager=callback_manager, ) @classmethod def class_name(cls) -> str: return "HierarchicalNodeParser" def _recursively_get_nodes_from_nodes( self, nodes: List[BaseNode], level: int, show_progress: bool = False, ) -> List[BaseNode]: """Recursively get nodes from nodes.""" if level >= len(self.node_parser_ids): raise ValueError( f"Level {level} is greater than number of text " f"splitters ({len(self.node_parser_ids)})." ) # first split current nodes into sub-nodes nodes_with_progress = get_tqdm_iterable( nodes, show_progress, "Parsing documents into nodes" ) sub_nodes = [] for node in nodes_with_progress: cur_sub_nodes = self.node_parser_map[ self.node_parser_ids[level] ].get_nodes_from_documents([node]) # add parent relationship from sub node to parent node # add child relationship from parent node to sub node # NOTE: Only add relationships if level > 0, since we don't want to add # relationships for the top-level document objects that we are splitting if level > 0: for sub_node in cur_sub_nodes: _add_parent_child_relationship( parent_node=node, child_node=sub_node, ) sub_nodes.extend(cur_sub_nodes) # now for each sub-node, recursively split into sub-sub-nodes, and add if level < len(self.node_parser_ids) - 1: sub_sub_nodes = self._recursively_get_nodes_from_nodes( sub_nodes, level + 1, show_progress=show_progress, ) else: sub_sub_nodes = [] return sub_nodes + sub_sub_nodes def get_nodes_from_documents( self, documents: Sequence[Document], show_progress: bool = False, **kwargs: Any, ) -> List[BaseNode]: """Parse document into nodes. Args: documents (Sequence[Document]): documents to parse include_metadata (bool): whether to include metadata in nodes """ with self.callback_manager.event( CBEventType.NODE_PARSING, payload={EventPayload.DOCUMENTS: documents} ) as event: all_nodes: List[BaseNode] = [] documents_with_progress = get_tqdm_iterable( documents, show_progress, "Parsing documents into nodes" ) # TODO: a bit of a hack rn for tqdm for doc in documents_with_progress: nodes_from_doc = self._recursively_get_nodes_from_nodes([doc], 0) all_nodes.extend(nodes_from_doc) event.on_end(payload={EventPayload.NODES: all_nodes}) return all_nodes # Unused abstract method def _parse_nodes( self, nodes: Sequence[BaseNode], show_progress: bool = False, **kwargs: Any ) -> List[BaseNode]: return list(nodes)