# faiss_rag_enterprise/llama_index/node_parser/interface.py
"""Node parser interface."""
from abc import ABC, abstractmethod
from typing import Any, List, Sequence
from llama_index.bridge.pydantic import Field
from llama_index.callbacks import CallbackManager, CBEventType, EventPayload
from llama_index.node_parser.node_utils import (
IdFuncCallable,
build_nodes_from_splits,
default_id_func,
)
from llama_index.schema import (
BaseNode,
Document,
MetadataMode,
NodeRelationship,
TransformComponent,
)
from llama_index.utils import get_tqdm_iterable
class NodeParser(TransformComponent, ABC):
    """Base interface for node parsers.

    A node parser transforms a sequence of documents (or nodes) into a list
    of ``BaseNode`` chunks. After chunking, it optionally back-fills each
    chunk's start/end character offsets in its source document, copies the
    source document's metadata onto the chunk, and wires up prev/next
    relationships between consecutive chunks.
    """

    # Whether to pass document metadata down to the splitter (and onto nodes).
    include_metadata: bool = Field(
        default=True, description="Whether or not to consider metadata when splitting."
    )
    # Whether to link consecutive nodes with PREVIOUS/NEXT relationships.
    include_prev_next_rel: bool = Field(
        default=True, description="Include prev/next node relationships."
    )
    # Excluded from serialization; callbacks are runtime-only state.
    callback_manager: CallbackManager = Field(
        default_factory=CallbackManager, exclude=True
    )
    id_func: IdFuncCallable = Field(
        default=default_id_func,
        description="Function to generate node IDs.",
    )

    class Config:
        arbitrary_types_allowed = True

    @abstractmethod
    def _parse_nodes(
        self,
        nodes: Sequence[BaseNode],
        show_progress: bool = False,
        **kwargs: Any,
    ) -> List[BaseNode]:
        """Split ``nodes`` into smaller nodes. Implemented by subclasses."""
        ...

    def get_nodes_from_documents(
        self,
        documents: Sequence[Document],
        show_progress: bool = False,
        **kwargs: Any,
    ) -> List[BaseNode]:
        """Parse documents into nodes.

        Args:
            documents (Sequence[Document]): documents to parse
            show_progress (bool): whether to show progress bar

        Returns:
            List[BaseNode]: parsed nodes with offsets, metadata, and
            prev/next relationships populated per the instance flags.
        """
        doc_id_to_document = {doc.id_: doc for doc in documents}

        with self.callback_manager.event(
            CBEventType.NODE_PARSING, payload={EventPayload.DOCUMENTS: documents}
        ) as event:
            nodes = self._parse_nodes(documents, show_progress=show_progress, **kwargs)

            for i, node in enumerate(nodes):
                if (
                    node.ref_doc_id is not None
                    and node.ref_doc_id in doc_id_to_document
                ):
                    ref_doc = doc_id_to_document[node.ref_doc_id]
                    # Hoisted: the raw chunk text is needed both to locate
                    # the chunk in its source document and to compute the
                    # end offset (was computed twice in the original).
                    node_text = node.get_content(metadata_mode=MetadataMode.NONE)
                    start_char_idx = ref_doc.text.find(node_text)

                    # update start/end char idx (str.find returns -1 on miss,
                    # e.g. when the splitter transformed the text)
                    if start_char_idx >= 0:
                        node.start_char_idx = start_char_idx
                        node.end_char_idx = start_char_idx + len(node_text)

                    # update metadata from the source document
                    if self.include_metadata:
                        node.metadata.update(ref_doc.metadata)

                if self.include_prev_next_rel:
                    if i > 0:
                        node.relationships[NodeRelationship.PREVIOUS] = nodes[
                            i - 1
                        ].as_related_node_info()
                    if i < len(nodes) - 1:
                        node.relationships[NodeRelationship.NEXT] = nodes[
                            i + 1
                        ].as_related_node_info()

            event.on_end({EventPayload.NODES: nodes})

        return nodes

    def __call__(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]:
        return self.get_nodes_from_documents(nodes, **kwargs)
class TextSplitter(NodeParser):
    """Node parser that chunks node content via a text-splitting strategy."""

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split a single string into a list of text chunks."""
        ...

    def split_texts(self, texts: List[str]) -> List[str]:
        """Split each string in ``texts`` and return the flattened chunks."""
        return [chunk for text in texts for chunk in self.split_text(text)]

    def _parse_nodes(
        self, nodes: Sequence[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Split every input node's content and build child nodes from it."""
        parsed: List[BaseNode] = []
        iterable = get_tqdm_iterable(nodes, show_progress, "Parsing nodes")
        for source_node in iterable:
            chunks = self.split_text(source_node.get_content())
            new_nodes = build_nodes_from_splits(
                chunks, source_node, id_func=self.id_func
            )
            parsed.extend(new_nodes)
        return parsed
class MetadataAwareTextSplitter(TextSplitter):
    """Text splitter that reserves room for metadata when sizing chunks."""

    @abstractmethod
    def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]:
        """Split ``text`` while accounting for the size of ``metadata_str``."""
        ...

    def split_texts_metadata_aware(
        self, texts: List[str], metadata_strs: List[str]
    ) -> List[str]:
        """Split each text alongside its metadata string; flatten the result.

        Raises:
            ValueError: if the two lists differ in length.
        """
        if len(texts) != len(metadata_strs):
            raise ValueError("Texts and metadata_strs must have the same length")
        return [
            chunk
            for text, metadata in zip(texts, metadata_strs)
            for chunk in self.split_text_metadata_aware(text, metadata)
        ]

    def _get_metadata_str(self, node: BaseNode) -> str:
        """Helper function to get the proper metadata str for splitting."""
        # Use whichever rendering (embed vs. LLM mode) is longer so the split
        # budget covers the worst case; max() returns its first argument on a
        # tie, matching the original's preference for the LLM string.
        return max(
            node.get_metadata_str(mode=MetadataMode.LLM),
            node.get_metadata_str(mode=MetadataMode.EMBED),
            key=len,
        )

    def _parse_nodes(
        self, nodes: Sequence[BaseNode], show_progress: bool = False, **kwargs: Any
    ) -> List[BaseNode]:
        """Metadata-aware variant of node parsing: split, then build nodes."""
        parsed: List[BaseNode] = []
        for source_node in get_tqdm_iterable(nodes, show_progress, "Parsing nodes"):
            chunks = self.split_text_metadata_aware(
                source_node.get_content(metadata_mode=MetadataMode.NONE),
                metadata_str=self._get_metadata_str(source_node),
            )
            parsed.extend(
                build_nodes_from_splits(chunks, source_node, id_func=self.id_func)
            )
        return parsed