"""Hierarchical node parser."""
from typing import Any, Dict, List, Optional, Sequence

from llama_index.bridge.pydantic import Field
from llama_index.callbacks.base import CallbackManager
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.node_parser.interface import NodeParser
from llama_index.node_parser.text.sentence import SentenceSplitter
from llama_index.schema import BaseNode, Document, NodeRelationship
from llama_index.utils import get_tqdm_iterable

def _add_parent_child_relationship(parent_node: BaseNode, child_node: BaseNode) -> None:
    """Wire *child_node* under *parent_node* in both directions.

    The child's related-node info is appended to the parent's CHILD
    relationship list (created on first use), and the child's PARENT
    relationship is pointed back at the parent.
    """
    # Fetch (or start) the parent's child list, grow it, and store it back.
    existing_children = parent_node.relationships.get(NodeRelationship.CHILD, [])
    existing_children.append(child_node.as_related_node_info())
    parent_node.relationships[NodeRelationship.CHILD] = existing_children

    # Back-link: the child records its single parent.
    child_node.relationships[NodeRelationship.PARENT] = (
        parent_node.as_related_node_info()
    )


def get_leaf_nodes(nodes: List[BaseNode]) -> List[BaseNode]:
    """Return the nodes at the bottom of the hierarchy.

    A leaf is any node whose relationships contain no CHILD entry.
    """
    return [
        candidate
        for candidate in nodes
        if NodeRelationship.CHILD not in candidate.relationships
    ]


def get_root_nodes(nodes: List[BaseNode]) -> List[BaseNode]:
    """Return the nodes at the top of the hierarchy.

    A root is any node whose relationships contain no PARENT entry.
    """
    return [
        candidate
        for candidate in nodes
        if NodeRelationship.PARENT not in candidate.relationships
    ]


class HierarchicalNodeParser(NodeParser):
    """Hierarchical node parser.

    Splits a document into a recursive hierarchy Nodes using a NodeParser.

    NOTE: this will return a hierarchy of nodes in a flat list, where there will be
    overlap between parent nodes (e.g. with a bigger chunk size), and child nodes
    per parent (e.g. with a smaller chunk size).

    For instance, this may return a list of nodes like:
    - list of top-level nodes with chunk size 2048
    - list of second-level nodes, where each node is a child of a top-level node,
      chunk size 512
    - list of third-level nodes, where each node is a child of a second-level node,
      chunk size 128
    """

    # Per-level chunk sizes (level 0 first); mutually exclusive with an
    # explicit node_parser_ids/node_parser_map pair in from_defaults().
    chunk_sizes: Optional[List[int]] = Field(
        default=None,
        description=(
            "The chunk sizes to use when splitting documents, in order of level."
        ),
    )
    # Ordered parser ids: ids[0] splits the raw documents, ids[1] splits the
    # level-0 output, and so on.
    node_parser_ids: List[str] = Field(
        default_factory=list,
        description=(
            "List of ids for the node parsers to use when splitting documents, "
            + "in order of level (first id used for first level, etc.)."
        ),
    )
    # Lookup from parser id to the NodeParser instance used at that level.
    node_parser_map: Dict[str, NodeParser] = Field(
        description="Map of node parser id to node parser.",
    )

    @classmethod
    def from_defaults(
        cls,
        chunk_sizes: Optional[List[int]] = None,
        chunk_overlap: int = 20,
        node_parser_ids: Optional[List[str]] = None,
        node_parser_map: Optional[Dict[str, NodeParser]] = None,
        include_metadata: bool = True,
        include_prev_next_rel: bool = True,
        callback_manager: Optional[CallbackManager] = None,
    ) -> "HierarchicalNodeParser":
        """Build a parser from chunk sizes, or from an explicit parser map.

        Args:
            chunk_sizes: chunk size per level (defaults to [2048, 512, 128]);
                cannot be combined with node_parser_ids.
            chunk_overlap: overlap passed to each default SentenceSplitter.
            node_parser_ids: explicit ordered parser ids; requires
                node_parser_map to also be provided.
            node_parser_map: map of parser id to NodeParser instance.
            include_metadata: forwarded to each default SentenceSplitter.
            include_prev_next_rel: forwarded to each default SentenceSplitter.
            callback_manager: shared callback manager; a fresh empty one is
                created when omitted.

        Raises:
            ValueError: if both node_parser_ids and chunk_sizes are given, or
                if node_parser_ids is given without node_parser_map.
        """
        callback_manager = callback_manager or CallbackManager([])

        if node_parser_ids is None:
            if chunk_sizes is None:
                # Default three-level hierarchy.
                chunk_sizes = [2048, 512, 128]

            node_parser_ids = [f"chunk_size_{chunk_size}" for chunk_size in chunk_sizes]
            # NOTE(review): any node_parser_map passed by the caller is
            # silently discarded on this branch; one SentenceSplitter is
            # built per chunk size instead.
            node_parser_map = {}
            for chunk_size, node_parser_id in zip(chunk_sizes, node_parser_ids):
                node_parser_map[node_parser_id] = SentenceSplitter(
                    chunk_size=chunk_size,
                    callback_manager=callback_manager,
                    chunk_overlap=chunk_overlap,
                    include_metadata=include_metadata,
                    include_prev_next_rel=include_prev_next_rel,
                )
        else:
            # Caller supplied explicit parsers: chunk_sizes must not also be
            # set, and the id -> parser map is mandatory.
            if chunk_sizes is not None:
                raise ValueError("Cannot specify both node_parser_ids and chunk_sizes.")
            if node_parser_map is None:
                raise ValueError(
                    "Must specify node_parser_map if using node_parser_ids."
                )

        return cls(
            chunk_sizes=chunk_sizes,
            node_parser_ids=node_parser_ids,
            node_parser_map=node_parser_map,
            include_metadata=include_metadata,
            include_prev_next_rel=include_prev_next_rel,
            callback_manager=callback_manager,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the serialization name of this class."""
        return "HierarchicalNodeParser"

    def _recursively_get_nodes_from_nodes(
        self,
        nodes: List[BaseNode],
        level: int,
        show_progress: bool = False,
    ) -> List[BaseNode]:
        """Recursively get nodes from nodes.

        Splits *nodes* with the parser registered for *level*, wires
        parent/child relationships (except at level 0), then recurses one
        level deeper. Returns this level's nodes followed by all deeper
        levels' nodes, as one flat list.

        Raises:
            ValueError: if *level* is out of range for node_parser_ids.
        """
        if level >= len(self.node_parser_ids):
            raise ValueError(
                f"Level {level} is greater than number of text "
                f"splitters ({len(self.node_parser_ids)})."
            )

        # first split current nodes into sub-nodes
        nodes_with_progress = get_tqdm_iterable(
            nodes, show_progress, "Parsing documents into nodes"
        )
        sub_nodes = []
        for node in nodes_with_progress:
            cur_sub_nodes = self.node_parser_map[
                self.node_parser_ids[level]
            ].get_nodes_from_documents([node])
            # add parent relationship from sub node to parent node
            # add child relationship from parent node to sub node
            # NOTE: Only add relationships if level > 0, since we don't want to add
            # relationships for the top-level document objects that we are splitting
            if level > 0:
                for sub_node in cur_sub_nodes:
                    _add_parent_child_relationship(
                        parent_node=node,
                        child_node=sub_node,
                    )

            sub_nodes.extend(cur_sub_nodes)

        # now for each sub-node, recursively split into sub-sub-nodes, and add
        if level < len(self.node_parser_ids) - 1:
            sub_sub_nodes = self._recursively_get_nodes_from_nodes(
                sub_nodes,
                level + 1,
                show_progress=show_progress,
            )
        else:
            # Deepest level reached: nothing further to split.
            sub_sub_nodes = []

        # Flat list: this level's nodes first, then all descendant levels'.
        return sub_nodes + sub_sub_nodes

    def get_nodes_from_documents(
        self,
        documents: Sequence[Document],
        show_progress: bool = False,
        **kwargs: Any,
    ) -> List[BaseNode]:
        """Parse document into nodes.

        Args:
            documents (Sequence[Document]): documents to parse
            show_progress (bool): whether to show a progress bar over documents

        """
        # Wrap the whole parse in a NODE_PARSING callback event so observers
        # see the input documents and the resulting nodes.
        with self.callback_manager.event(
            CBEventType.NODE_PARSING, payload={EventPayload.DOCUMENTS: documents}
        ) as event:
            all_nodes: List[BaseNode] = []
            documents_with_progress = get_tqdm_iterable(
                documents, show_progress, "Parsing documents into nodes"
            )

            # TODO: a bit of a hack rn for tqdm
            for doc in documents_with_progress:
                # NOTE(review): show_progress is not forwarded into the
                # recursion here — presumably to avoid nested progress bars.
                nodes_from_doc = self._recursively_get_nodes_from_nodes([doc], 0)
                all_nodes.extend(nodes_from_doc)

            event.on_end(payload={EventPayload.NODES: all_nodes})

        return all_nodes

    # Unused abstract method
    def _parse_nodes(
        self, nodes: Sequence[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Satisfy the NodeParser interface; returns *nodes* unchanged as a list."""
        return list(nodes)