122 lines
3.9 KiB
Python
122 lines
3.9 KiB
Python
"""Markdown node parser."""
|
|
import re
|
|
from typing import Any, Dict, List, Optional, Sequence
|
|
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.node_parser.interface import NodeParser
|
|
from llama_index.node_parser.node_utils import build_nodes_from_splits
|
|
from llama_index.schema import BaseNode, MetadataMode, TextNode
|
|
from llama_index.utils import get_tqdm_iterable
|
|
|
|
|
|
class MarkdownNodeParser(NodeParser):
    """Markdown node parser.

    Splits a document into Nodes using custom Markdown splitting logic:
    every Markdown header found outside a fenced code block starts a new
    section, and each section becomes one node. The chain of headers that
    encloses a section is recorded under the metadata keys ``"Header 1"``,
    ``"Header 2"``, ... (one key per header level).

    Args:
        include_metadata (bool): whether to include metadata in nodes
        include_prev_next_rel (bool): whether to include prev/next relationships

    """

    @classmethod
    def from_defaults(
        cls,
        include_metadata: bool = True,
        include_prev_next_rel: bool = True,
        callback_manager: Optional[CallbackManager] = None,
    ) -> "MarkdownNodeParser":
        """Create a parser, using an empty ``CallbackManager`` if none is given."""
        callback_manager = callback_manager or CallbackManager([])

        return cls(
            include_metadata=include_metadata,
            include_prev_next_rel=include_prev_next_rel,
            callback_manager=callback_manager,
        )

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "MarkdownNodeParser"

    def _parse_nodes(
        self,
        nodes: Sequence[BaseNode],
        show_progress: bool = False,
        **kwargs: Any,
    ) -> List[BaseNode]:
        """Split every input node into its Markdown sections.

        Args:
            nodes: the nodes whose content should be split.
            show_progress: forwarded to ``get_tqdm_iterable`` together with
                the description "Parsing nodes".
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            The section nodes of all inputs, concatenated in input order.
        """
        all_nodes: List[BaseNode] = []
        nodes_with_progress = get_tqdm_iterable(nodes, show_progress, "Parsing nodes")

        for node in nodes_with_progress:
            # Extend directly instead of rebinding the ``nodes`` parameter.
            all_nodes.extend(self.get_nodes_from_node(node))

        return all_nodes

    def get_nodes_from_node(self, node: BaseNode) -> List[TextNode]:
        """Split one node's text into a node per Markdown section.

        A line of the form ``#... <title>`` starts a new section unless it
        lies inside a fenced code block. The section text starts with the
        header title (without the leading ``#`` markers) followed by the
        lines up to the next header. Sections whose text is empty after
        stripping whitespace are skipped, so an empty document yields an
        empty list.

        Args:
            node: the node whose content (read without metadata) is split.

        Returns:
            The section nodes in document order.
        """
        text = node.get_content(metadata_mode=MetadataMode.NONE)
        markdown_nodes: List[TextNode] = []
        metadata: Dict[str, str] = {}
        in_code_block = False
        section_lines: List[str] = []
        # Compiled once; the loop below runs for every line of the document.
        header_pattern = re.compile(r"^(#+)\s(.*)")

        for line in text.split("\n"):
            # Fences may be indented (e.g. inside list items), so ignore
            # leading whitespace when looking for the ``` marker.
            if line.lstrip().startswith("```"):
                in_code_block = not in_code_block

            # "# ..." inside a code block is a comment, not a header.
            header_match = None if in_code_block else header_pattern.match(line)
            if header_match is None:
                section_lines.append(line)
                continue

            # Close the section that precedes this header, unless it holds
            # nothing but whitespace.
            section_text = "\n".join(section_lines).strip()
            if section_text:
                markdown_nodes.append(
                    self._build_node_from_split(section_text, node, metadata)
                )

            header_text = header_match.group(2)
            metadata = self._update_metadata(
                metadata, header_text, len(header_match.group(1))
            )
            section_lines = [header_text]

        # Emit the trailing section, again skipping empty text.
        section_text = "\n".join(section_lines).strip()
        if section_text:
            markdown_nodes.append(
                self._build_node_from_split(section_text, node, metadata)
            )

        return markdown_nodes

    def _update_metadata(
        self, headers_metadata: dict, new_header: str, new_header_level: int
    ) -> dict:
        """Return the header metadata that applies after a new header.

        Headers of a strictly shallower level (smaller number) than the new
        header are kept as its ancestors; headers of the same or a deeper
        level are dropped; the new header is stored under
        ``"Header <level>"``. The input dict is not modified.

        Args:
            headers_metadata: header metadata in effect before the new header.
            new_header: title text of the new header.
            new_header_level: number of ``#`` characters of the new header.

        Returns:
            A new dict with the ancestors' entries plus the new header.
        """
        ancestor_keys = (f"Header {level}" for level in range(1, new_header_level))
        updated_headers = {
            key: headers_metadata[key]
            for key in ancestor_keys
            if key in headers_metadata
        }
        updated_headers[f"Header {new_header_level}"] = new_header
        return updated_headers

    def _build_node_from_split(
        self,
        text_split: str,
        node: BaseNode,
        metadata: dict,
    ) -> TextNode:
        """Build node from single text split.

        Args:
            text_split: text of the section.
            node: source node passed to ``build_nodes_from_splits``.
            metadata: header metadata for the section; merged over the built
                node's metadata only when ``self.include_metadata`` is set.

        Returns:
            The node built for ``text_split``.
        """
        # Use a separate name so the ``node`` parameter is not shadowed.
        split_node = build_nodes_from_splits(
            [text_split], node, id_func=self.id_func
        )[0]

        if self.include_metadata:
            # Header entries win over same-named keys already on the node.
            split_node.metadata = {**split_node.metadata, **metadata}

        return split_node
|