from io import StringIO
from typing import Any, Callable, List, Optional

import pandas as pd

from llama_index.node_parser.relational.base_element import (
    BaseElementNodeParser,
    Element,
)
from llama_index.schema import BaseNode, TextNode


def md_to_df(md_str: str) -> Optional[pd.DataFrame]:
    """Convert a markdown pipe table to a dataframe.

    The table is rewritten into quoted CSV (every ``|`` becomes ``","``),
    the header separator row is dropped, and the result is parsed with
    ``pd.read_csv``.

    Args:
        md_str: Markdown table text; each row is expected to start and end
            with a pipe.

    Returns:
        The parsed dataframe, or ``None`` when nothing is left to parse.
    """
    # Double existing quotes so they survive being wrapped in CSV quotes.
    md_str = md_str.replace('"', '""')

    # Turn each pipe into `","` so every cell ends up quoted.
    md_str = md_str.replace("|", '","')

    # Trim each row first: a trailing "\r" (CRLF input) or space would
    # otherwise shift the fixed-width slice below and leave a stray comma,
    # which pandas reads as an extra unnamed column.
    rows = [row.strip() for row in md_str.split("\n")]

    # Drop the second row (the `|---|---|` header separator).
    rows = rows[:1] + rows[2:]

    # Remove the leading `",` and trailing `,"` produced by the outer pipes.
    csv_text = "\n".join(row[2:-2] for row in rows)

    # Nothing but whitespace/newlines left: report "no table" instead of
    # letting pandas raise EmptyDataError.
    if not csv_text.strip():
        return None

    return pd.read_csv(StringIO(csv_text))


class MarkdownElementNodeParser(BaseElementNodeParser):
    """Markdown element node parser.

    Splits a markdown document into Text Nodes and Index Nodes corresponding
    to embedded objects (e.g. tables).
    """

    @classmethod
    def class_name(cls) -> str:
        """Return the name of this parser class."""
        return "MarkdownElementNodeParser"

    def get_nodes_from_node(self, node: TextNode) -> List[BaseNode]:
        """Get nodes from node.

        Extracts elements from the node content (filtering tables with
        ``self.filter_table``), extracts summaries over the table elements,
        and converts all elements into nodes.
        """
        elements = self.extract_elements(
            node.get_content(),
            table_filters=[self.filter_table],
            node_id=node.id_,
        )
        table_elements = self.get_table_elements(elements)
        # extract summaries over table elements
        self.extract_table_summaries(table_elements)
        # convert into nodes
        # will return a list of Nodes and Index Nodes
        return self.get_nodes_from_elements(elements)

    def extract_elements(
        self,
        text: str,
        node_id: Optional[str] = None,
        table_filters: Optional[List[Callable]] = None,
        **kwargs: Any,
    ) -> List[Element]:
        """Extract elements from text.

        The text is scanned line by line and grouped into provisional
        elements (text, code, table, title). Afterwards every element is
        rebuilt: tables that pass the checks become ``"table"`` elements with
        a dataframe attached, tables with inconsistent column counts become
        ``"table_text"``, and everything else becomes ``"text"``. Consecutive
        text elements are merged.

        Args:
            text: Markdown text to parse.
            node_id: Optional id of the source node; when given it is
                embedded in the element ids so that elements coming from
                different nodes do not share ids.
            table_filters: Optional predicates applied to well-formed tables;
                a table is kept only if all of them return a truthy value.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The list of extracted elements.
        """
        current: Optional[Element] = None
        elements: List[Element] = []

        # First pass: group lines into provisional elements.
        for line in text.split("\n"):
            if line.startswith("```"):
                if current is not None and current.type == "code":
                    # closing fence of a code block
                    elements.append(current)
                    current = None
                    # text trailing the fence becomes its own text element
                    if len(line) > 3:
                        elements.append(
                            Element(
                                id=f"id_{len(elements)}",
                                type="text",
                                # lstrip takes a character set, so this
                                # removes *all* leading backticks
                                element=line.lstrip("`"),
                            )
                        )
                elif line.count("```") == 2 and line[-3] != "`":
                    # inline code block: a second ``` that is not at the
                    # very end of the line
                    if current is not None:
                        elements.append(current)
                    current = Element(
                        id=f"id_{len(elements)}",
                        type="code",
                        element=line.lstrip("`"),
                    )
                elif current is not None and current.type == "text":
                    current.element += "\n" + line
                else:
                    if current is not None:
                        elements.append(current)
                    current = Element(
                        id=f"id_{len(elements)}", type="text", element=line
                    )
            elif current is not None and current.type == "code":
                # inside a code element: swallow the line verbatim
                current.element += "\n" + line
            elif line.startswith("|"):
                if current is not None and current.type == "table":
                    current.element += "\n" + line
                else:
                    if current is not None:
                        elements.append(current)
                    current = Element(
                        id=f"id_{len(elements)}", type="table", element=line
                    )
            elif line.startswith("#"):
                if current is not None:
                    elements.append(current)
                heading_text = line.lstrip("#")
                current = Element(
                    id=f"id_{len(elements)}",
                    type="title",
                    element=heading_text,
                    # number of leading '#' characters
                    title_level=len(line) - len(heading_text),
                )
            else:
                if current is not None and current.type == "text":
                    current.element += "\n" + line
                else:
                    if current is not None:
                        elements.append(current)
                    current = Element(
                        id=f"id_{len(elements)}", type="text", element=line
                    )

        if current is not None:
            elements.append(current)

        # Second pass: assign final ids and types.
        for idx, element in enumerate(elements):
            final_id = f"id_{node_id}_{idx}" if node_id else f"id_{idx}"

            if element.type != "table":
                # non-table elements (including code and titles) are kept
                # as plain text
                elements[idx] = Element(
                    id=final_id, type="text", element=element.element
                )
                continue

            table_lines = element.element.split("\n")

            # a "perfect" table has the same number of columns on every row;
            # others are stored as raw text instead of a dataframe
            column_counts = {len(row.split("|")) for row in table_lines}
            perfect_table = len(column_counts) <= 1

            # a table needs at least 2 rows
            should_keep = len(table_lines) >= 2

            # apply the table filters (only to perfect tables)
            if should_keep and perfect_table and table_filters is not None:
                should_keep = all(tf(element) for tf in table_filters)

            if not should_keep:
                elements[idx] = Element(
                    id=final_id, type="text", element=element.element
                )
            elif perfect_table:
                table = md_to_df(element.element)
                # NOTE(review): the whole provisional Element (not
                # element.element) is stored here, unlike the other
                # branches -- confirm downstream consumers expect that.
                elements[idx] = Element(
                    id=final_id,
                    type="table",
                    element=element,
                    table=table,
                )
            else:
                # non-perfect tables keep their raw text under a distinct
                # type to differentiate them from perfect tables
                elements[idx] = Element(
                    id=final_id,
                    type="table_text",
                    element=element.element,
                )

        # merge consecutive text elements together for now
        merged_elements: List[Element] = []
        for element in elements:
            if (
                merged_elements
                and element.type == "text"
                and merged_elements[-1].type == "text"
            ):
                merged_elements[-1].element += "\n" + element.element
            else:
                merged_elements.append(element)

        return merged_elements

    def filter_table(self, table_element: Any) -> bool:
        """Filter tables.

        Returns True only when the table's markdown parses to a dataframe
        that is not None, is not empty, and has more than one column.
        """
        table_df = md_to_df(table_element.element)

        return (
            table_df is not None
            and not table_df.empty
            and len(table_df.columns) > 1
        )