faiss_rag_enterprise/llama_index/node_parser/relational/markdown_element.py

226 lines
9.0 KiB
Python

from io import StringIO
from typing import Any, Callable, List, Optional
import pandas as pd
from llama_index.node_parser.relational.base_element import (
BaseElementNodeParser,
Element,
)
from llama_index.schema import BaseNode, TextNode
def md_to_df(md_str: str) -> Optional[pd.DataFrame]:
    """Convert a Markdown pipe table to a dataframe.

    The table is rewritten as quoted CSV and handed to ``pd.read_csv``:
    every ``|`` becomes ``","`` so each cell ends up wrapped in double
    quotes, and the second line (the ``|---|---|`` header separator) is
    dropped.

    Args:
        md_str: Markdown table text, one row per line, each row starting
            and ending with a pipe. Whitespace around a row is ignored.

    Returns:
        The parsed dataframe, or ``None`` when the input contains no
        table rows at all (empty or blank-only input).
    """
    # Strip surrounding whitespace first: the slice below assumes every row
    # begins and ends exactly on a pipe, so a stray trailing space would
    # otherwise leave a dangling quote/comma in the CSV.
    rows = [row.strip() for row in md_str.split("\n")]

    # Remove the second line (table header separator)
    rows = rows[:1] + rows[2:]

    csv_rows = []
    for row in rows:
        # Escape existing quotes, then turn the pipes into quoted separators
        quoted = row.replace('"', '""').replace("|", '","')
        # Drop the leading `",` and trailing `,"` left over from the outer pipes
        csv_rows.append(quoted[2:-2])

    # Check if the table is empty (no row with any content)
    if not any(csv_rows):
        return None

    # Use pandas to read the CSV string into a DataFrame
    return pd.read_csv(StringIO("\n".join(csv_rows)))
class MarkdownElementNodeParser(BaseElementNodeParser):
    """Markdown element node parser.

    Splits a markdown document into Text Nodes and Index Nodes corresponding
    to embedded objects (e.g. tables).
    """

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this parser class."""
        return "MarkdownElementNodeParser"

    def get_nodes_from_node(self, node: TextNode) -> List[BaseNode]:
        """Get nodes from node.

        Extracts the markdown elements of ``node``'s content, runs the table
        summary step over the table elements, and converts the elements into
        nodes. ``get_table_elements``, ``extract_table_summaries`` and
        ``get_nodes_from_elements`` are not defined in this class
        (presumably inherited from ``BaseElementNodeParser``).

        Args:
            node: The text node whose content is parsed.

        Returns:
            The nodes produced by ``get_nodes_from_elements``.
        """
        elements = self.extract_elements(
            node.get_content(),
            table_filters=[self.filter_table],
            node_id=node.id_,
        )
        table_elements = self.get_table_elements(elements)
        # extract summaries over table elements
        # (return value is not used here; presumably updates the elements
        # in place -- confirm against the base class)
        self.extract_table_summaries(table_elements)
        # convert into nodes
        # will return a list of Nodes and Index Nodes
        return self.get_nodes_from_elements(elements)

    def extract_elements(
        self,
        text: str,
        node_id: Optional[str] = None,
        table_filters: Optional[List[Callable]] = None,
        **kwargs: Any,
    ) -> List[Element]:
        """Extract elements from text.

        Works in three passes:

        1. Scan the text line by line and group lines into raw elements
           (text, code, table, title).
        2. Re-create every element with its final id and type: tables that
           pass the checks become ``"table"`` (with a dataframe) or
           ``"table_text"`` (rows with differing column counts); everything
           else becomes ``"text"``.
        3. Merge consecutive ``"text"`` elements.

        Args:
            text: Markdown text to parse.
            node_id: Id of the source node. When given, it is embedded in
                every element id so that elements coming from different
                nodes do not share the same id.
            table_filters: Predicates applied to well-formed table elements;
                a table is kept only if all of them return a truthy value.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The merged list of elements.
        """
        current = None
        elements: List[Element] = []

        # Pass 1: parse the lines into raw elements
        for line in text.split("\n"):
            if line.startswith("```"):
                if current is not None and current.type == "code":
                    # this is the end of a code block
                    elements.append(current)
                    current = None
                    # if there is some text after the ``` create a text element with it
                    if len(line) > 3:
                        elements.append(
                            Element(
                                id=f"id_{len(elements)}",
                                type="text",
                                # strips every leading backtick
                                element=line.lstrip("`"),
                            )
                        )
                elif line.count("```") == 2 and line[-3] != "`":
                    # inline code block (aka have a second ``` in line but not at the end)
                    if current is not None:
                        elements.append(current)
                    current = Element(
                        id=f"id_{len(elements)}",
                        type="code",
                        # strips every leading backtick
                        element=line.lstrip("`"),
                    )
                elif current is not None and current.type == "text":
                    current.element += "\n" + line
                else:
                    if current is not None:
                        elements.append(current)
                    current = Element(
                        id=f"id_{len(elements)}", type="text", element=line
                    )
            elif current is not None and current.type == "code":
                # inside a code element every line is taken verbatim
                current.element += "\n" + line
            elif line.startswith("|"):
                if current is not None and current.type == "table":
                    # continue the table being built
                    current.element += "\n" + line
                else:
                    if current is not None:
                        elements.append(current)
                    current = Element(
                        id=f"id_{len(elements)}", type="table", element=line
                    )
            elif line.startswith("#"):
                if current is not None:
                    elements.append(current)
                title_text = line.lstrip("#")
                current = Element(
                    id=f"id_{len(elements)}",
                    type="title",
                    element=title_text,
                    # number of leading '#' characters
                    title_level=len(line) - len(title_text),
                )
            else:
                if current is not None and current.type == "text":
                    # continue the text being built
                    current.element += "\n" + line
                else:
                    if current is not None:
                        elements.append(current)
                    current = Element(
                        id=f"id_{len(elements)}", type="text", element=line
                    )

        if current is not None:
            elements.append(current)

        # Pass 2: assign final ids/types and convert tables
        for idx, element in enumerate(elements):
            element_id = f"id_{node_id}_{idx}" if node_id else f"id_{idx}"

            if element.type != "table":
                # if the element is not a table, keep it as text
                elements[idx] = Element(
                    id=element_id,
                    type="text",
                    element=element.element,
                )
                continue

            table_lines = element.element.split("\n")
            # a "perfect" table has the same number of columns on each row;
            # otherwise the raw text is stored instead of a dataframe
            column_counts = {len(row.split("|")) for row in table_lines}
            perfect_table = len(column_counts) == 1
            # the table (markdown) must have at least 2 rows
            should_keep = len(table_lines) >= 2

            # apply the table filter, now only filter empty tables
            if should_keep and perfect_table and table_filters is not None:
                should_keep = all(tf(element) for tf in table_filters)

            if not should_keep:
                elements[idx] = Element(
                    id=element_id,
                    type="text",
                    element=element.element,
                )
            elif perfect_table:
                # Store the raw markdown text (not the intermediate Element
                # object) so this branch carries a string payload like all
                # the other branches do.
                elements[idx] = Element(
                    id=element_id,
                    type="table",
                    element=element.element,
                    table=md_to_df(element.element),
                )
            else:
                # for non-perfect tables, we will store the raw text
                # and give it a different type to differentiate it from perfect tables
                elements[idx] = Element(
                    id=element_id,
                    type="table_text",
                    element=element.element,
                )

        # Pass 3: merge consecutive text elements together for now
        merged_elements: List[Element] = []
        for element in elements:
            if (
                merged_elements
                and element.type == "text"
                and merged_elements[-1].type == "text"
            ):
                merged_elements[-1].element += "\n" + element.element
            else:
                merged_elements.append(element)
        return merged_elements

    def filter_table(self, table_element: Any) -> bool:
        """Filter tables.

        Args:
            table_element: Object whose ``element`` attribute holds the
                markdown text of the table.

        Returns:
            True when the table should be kept, False otherwise.
        """
        table_df = md_to_df(table_element.element)
        # keep the table only if it parsed, is not empty, and has more than one column
        return (
            table_df is not None
            and not table_df.empty
            and len(table_df.columns) > 1
        )