faiss_rag_enterprise/llama_index/node_parser/file/markdown.py

122 lines
3.9 KiB
Python

"""Markdown node parser."""
import re
from typing import Any, Dict, List, Optional, Sequence
from llama_index.callbacks.base import CallbackManager
from llama_index.node_parser.interface import NodeParser
from llama_index.node_parser.node_utils import build_nodes_from_splits
from llama_index.schema import BaseNode, MetadataMode, TextNode
from llama_index.utils import get_tqdm_iterable
class MarkdownNodeParser(NodeParser):
"""Markdown node parser.
Splits a document into Nodes using custom Markdown splitting logic.
Args:
include_metadata (bool): whether to include metadata in nodes
include_prev_next_rel (bool): whether to include prev/next relationships
"""
@classmethod
def from_defaults(
cls,
include_metadata: bool = True,
include_prev_next_rel: bool = True,
callback_manager: Optional[CallbackManager] = None,
) -> "MarkdownNodeParser":
callback_manager = callback_manager or CallbackManager([])
return cls(
include_metadata=include_metadata,
include_prev_next_rel=include_prev_next_rel,
callback_manager=callback_manager,
)
@classmethod
def class_name(cls) -> str:
"""Get class name."""
return "MarkdownNodeParser"
def _parse_nodes(
self,
nodes: Sequence[BaseNode],
show_progress: bool = False,
**kwargs: Any,
) -> List[BaseNode]:
all_nodes: List[BaseNode] = []
nodes_with_progress = get_tqdm_iterable(nodes, show_progress, "Parsing nodes")
for node in nodes_with_progress:
nodes = self.get_nodes_from_node(node)
all_nodes.extend(nodes)
return all_nodes
def get_nodes_from_node(self, node: BaseNode) -> List[TextNode]:
"""Get nodes from document."""
text = node.get_content(metadata_mode=MetadataMode.NONE)
markdown_nodes = []
lines = text.split("\n")
metadata: Dict[str, str] = {}
code_block = False
current_section = ""
for line in lines:
if line.startswith("```"):
code_block = not code_block
header_match = re.match(r"^(#+)\s(.*)", line)
if header_match and not code_block:
if current_section != "":
markdown_nodes.append(
self._build_node_from_split(
current_section.strip(), node, metadata
)
)
metadata = self._update_metadata(
metadata, header_match.group(2), len(header_match.group(1).strip())
)
current_section = f"{header_match.group(2)}\n"
else:
current_section += line + "\n"
markdown_nodes.append(
self._build_node_from_split(current_section.strip(), node, metadata)
)
return markdown_nodes
def _update_metadata(
self, headers_metadata: dict, new_header: str, new_header_level: int
) -> dict:
"""Update the markdown headers for metadata.
Removes all headers that are equal or less than the level
of the newly found header
"""
updated_headers = {}
for i in range(1, new_header_level):
key = f"Header {i}"
if key in headers_metadata:
updated_headers[key] = headers_metadata[key]
updated_headers[f"Header {new_header_level}"] = new_header
return updated_headers
def _build_node_from_split(
self,
text_split: str,
node: BaseNode,
metadata: dict,
) -> TextNode:
"""Build node from single text split."""
node = build_nodes_from_splits([text_split], node, id_func=self.id_func)[0]
if self.include_metadata:
node.metadata = {**node.metadata, **metadata}
return node