182 lines
6.1 KiB
Python
182 lines
6.1 KiB
Python
"""Node parser interface."""
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, List, Sequence
|
|
|
|
from llama_index.bridge.pydantic import Field
|
|
from llama_index.callbacks import CallbackManager, CBEventType, EventPayload
|
|
from llama_index.node_parser.node_utils import (
|
|
IdFuncCallable,
|
|
build_nodes_from_splits,
|
|
default_id_func,
|
|
)
|
|
from llama_index.schema import (
|
|
BaseNode,
|
|
Document,
|
|
MetadataMode,
|
|
NodeRelationship,
|
|
TransformComponent,
|
|
)
|
|
from llama_index.utils import get_tqdm_iterable
|
|
|
|
|
|
class NodeParser(TransformComponent, ABC):
    """Base interface for node parser.

    A node parser takes a sequence of documents/nodes and splits them into
    smaller ``BaseNode`` chunks, optionally copying source-document metadata
    onto each chunk and wiring prev/next relationships between chunks.
    """

    include_metadata: bool = Field(
        default=True, description="Whether or not to consider metadata when splitting."
    )
    include_prev_next_rel: bool = Field(
        default=True, description="Include prev/next node relationships."
    )
    # Excluded from serialization: the callback manager holds runtime handlers.
    callback_manager: CallbackManager = Field(
        default_factory=CallbackManager, exclude=True
    )
    id_func: IdFuncCallable = Field(
        default=default_id_func,
        description="Function to generate node IDs.",
    )

    class Config:
        # CallbackManager / IdFuncCallable are not pydantic-native types.
        arbitrary_types_allowed = True

    @abstractmethod
    def _parse_nodes(
        self,
        nodes: Sequence[BaseNode],
        show_progress: bool = False,
        **kwargs: Any,
    ) -> List[BaseNode]:
        """Split ``nodes`` into smaller nodes; implemented by subclasses."""
        ...

    def get_nodes_from_documents(
        self,
        documents: Sequence[Document],
        show_progress: bool = False,
        **kwargs: Any,
    ) -> List[BaseNode]:
        """Parse documents into nodes.

        Args:
            documents (Sequence[Document]): documents to parse
            show_progress (bool): whether to show progress bar

        Returns:
            List[BaseNode]: parsed chunks, with char offsets, metadata, and
            prev/next relationships populated per the instance flags.
        """
        # Map doc id -> source document so chunks can be traced back below.
        doc_id_to_document = {doc.id_: doc for doc in documents}

        # Wrap parsing in a NODE_PARSING callback event for observability.
        with self.callback_manager.event(
            CBEventType.NODE_PARSING, payload={EventPayload.DOCUMENTS: documents}
        ) as event:
            nodes = self._parse_nodes(documents, show_progress=show_progress, **kwargs)

            for i, node in enumerate(nodes):
                if (
                    node.ref_doc_id is not None
                    and node.ref_doc_id in doc_id_to_document
                ):
                    ref_doc = doc_id_to_document[node.ref_doc_id]
                    # Locate the chunk's raw content (no metadata) inside the
                    # source text to recover its character offsets.
                    start_char_idx = ref_doc.text.find(
                        node.get_content(metadata_mode=MetadataMode.NONE)
                    )

                    # update start/end char idx; str.find returns -1 on a miss,
                    # in which case the offsets are left unset
                    if start_char_idx >= 0:
                        node.start_char_idx = start_char_idx
                        node.end_char_idx = start_char_idx + len(
                            node.get_content(metadata_mode=MetadataMode.NONE)
                        )

                    # update metadata: copy source-document metadata onto the
                    # chunk (document values overwrite colliding chunk keys)
                    if self.include_metadata:
                        node.metadata.update(
                            doc_id_to_document[node.ref_doc_id].metadata
                        )

                # Chain chunks in output order via PREVIOUS/NEXT relationships.
                if self.include_prev_next_rel:
                    if i > 0:
                        node.relationships[NodeRelationship.PREVIOUS] = nodes[
                            i - 1
                        ].as_related_node_info()
                    if i < len(nodes) - 1:
                        node.relationships[NodeRelationship.NEXT] = nodes[
                            i + 1
                        ].as_related_node_info()

            event.on_end({EventPayload.NODES: nodes})

        return nodes

    def __call__(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]:
        # NOTE(review): forwards arbitrary BaseNodes into a method annotated
        # for Document inputs — presumably relies on nodes exposing the same
        # `id_`/`text` attributes; confirm for non-Document node types.
        return self.get_nodes_from_documents(nodes, **kwargs)
|
|
|
|
|
|
class TextSplitter(NodeParser):
    """Node parser that chunks nodes by splitting their raw text."""

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split a single text into a list of chunks; subclass-defined."""
        ...

    def split_texts(self, texts: List[str]) -> List[str]:
        """Split every text in ``texts`` and return the flattened chunks."""
        flattened: List[str] = []
        for text in texts:
            flattened.extend(self.split_text(text))
        return flattened

    def _parse_nodes(
        self, nodes: Sequence[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Split each node's content and build child nodes from the chunks."""
        parsed: List[BaseNode] = []
        iterable = get_tqdm_iterable(nodes, show_progress, "Parsing nodes")

        for source_node in iterable:
            chunks = self.split_text(source_node.get_content())
            parsed.extend(
                build_nodes_from_splits(chunks, source_node, id_func=self.id_func)
            )

        return parsed
|
|
|
|
|
|
class MetadataAwareTextSplitter(TextSplitter):
    """Text splitter that budgets for metadata length when chunking."""

    @abstractmethod
    def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]:
        """Split ``text`` while reserving room for ``metadata_str``; subclass-defined."""
        ...

    def split_texts_metadata_aware(
        self, texts: List[str], metadata_strs: List[str]
    ) -> List[str]:
        """Split each text with its paired metadata string; return flat chunks.

        Raises:
            ValueError: if the two input lists differ in length.
        """
        if len(texts) != len(metadata_strs):
            raise ValueError("Texts and metadata_strs must have the same length")

        flattened: List[str] = []
        for text, metadata in zip(texts, metadata_strs):
            flattened.extend(self.split_text_metadata_aware(text, metadata))
        return flattened

    def _get_metadata_str(self, node: BaseNode) -> str:
        """Helper function to get the proper metadata str for splitting."""
        embed_metadata_str = node.get_metadata_str(mode=MetadataMode.EMBED)
        llm_metadata_str = node.get_metadata_str(mode=MetadataMode.LLM)

        # Use the longest metadata str for splitting; listing the LLM string
        # first makes it win ties, matching the original else-branch.
        return max(llm_metadata_str, embed_metadata_str, key=len)

    def _parse_nodes(
        self, nodes: Sequence[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Split each node's raw content (metadata-aware) into child nodes."""
        parsed: List[BaseNode] = []
        iterable = get_tqdm_iterable(nodes, show_progress, "Parsing nodes")

        for source_node in iterable:
            chunks = self.split_text_metadata_aware(
                source_node.get_content(metadata_mode=MetadataMode.NONE),
                metadata_str=self._get_metadata_str(source_node),
            )
            parsed.extend(
                build_nodes_from_splits(chunks, source_node, id_func=self.id_func)
            )

        return parsed
|