# faiss_rag_enterprise/llama_index/node_parser/relational/base_element.py

import asyncio
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Tuple, cast
import pandas as pd
from tqdm import tqdm
from llama_index.async_utils import DEFAULT_NUM_WORKERS, run_jobs
from llama_index.bridge.pydantic import BaseModel, Field, ValidationError
from llama_index.callbacks.base import CallbackManager
from llama_index.core.response.schema import PydanticResponse
from llama_index.llms.llm import LLM
from llama_index.llms.openai import OpenAI
from llama_index.node_parser.interface import NodeParser
from llama_index.schema import BaseNode, Document, IndexNode, TextNode
from llama_index.utils import get_tqdm_iterable

DEFAULT_SUMMARY_QUERY_STR = """\
What is this table about? Give a very concise summary (imagine you are adding a \
new caption and summary for this table), and output the real/existing table \
title/caption if context is provided, and the real/existing table id if context \
is provided. Also output whether or not the table should be kept.\
"""


class TableColumnOutput(BaseModel):
    """Output from analyzing a table column."""

    col_name: str
    col_type: str
    summary: Optional[str] = None

    def __str__(self) -> str:
        """Convert to string representation."""
        return (
            f"Column: {self.col_name}\nType: {self.col_type}\nSummary: {self.summary}"
        )


class TableOutput(BaseModel):
    """Output from analyzing a table."""

    summary: str
    table_title: Optional[str] = None
    table_id: Optional[str] = None
    columns: List[TableColumnOutput]


class Element(BaseModel):
    """Element object."""

    id: str
    type: str
    element: Any
    title_level: Optional[int] = None
    table_output: Optional[TableOutput] = None
    table: Optional[pd.DataFrame] = None

    class Config:
        arbitrary_types_allowed = True
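

# Example (hypothetical values): a parsed markdown table might become
#   Element(id="id_2", type="table", element=table_md, table=pd.DataFrame(...))
# while the surrounding prose becomes
#   Element(id="id_1", type="text", element="Some paragraph ...").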


class BaseElementNodeParser(NodeParser):
    """
    Splits a document into Text Nodes and Index Nodes corresponding to embedded objects.

    Supports text and tables currently.
    """

    callback_manager: CallbackManager = Field(
        default_factory=CallbackManager, exclude=True
    )
    llm: Optional[LLM] = Field(
        default=None, description="LLM model to use for summarization."
    )
    summary_query_str: str = Field(
        default=DEFAULT_SUMMARY_QUERY_STR,
        description="Query string to use for summarization.",
    )
    num_workers: int = Field(
        default=DEFAULT_NUM_WORKERS,
        description="Number of workers for async jobs.",
    )
    show_progress: bool = Field(default=True, description="Whether to show progress.")

    @classmethod
    def class_name(cls) -> str:
        return "BaseStructuredNodeParser"

    @classmethod
    def from_defaults(
        cls,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> "BaseElementNodeParser":
        callback_manager = callback_manager or CallbackManager([])
        return cls(
            callback_manager=callback_manager,
            **kwargs,
        )

    def _parse_nodes(
        self,
        nodes: Sequence[BaseNode],
        show_progress: bool = False,
        **kwargs: Any,
    ) -> List[BaseNode]:
        all_nodes: List[BaseNode] = []
        nodes_with_progress = get_tqdm_iterable(nodes, show_progress, "Parsing nodes")
        for node in nodes_with_progress:
            # avoid shadowing the `nodes` argument while iterating
            cur_nodes = self.get_nodes_from_node(node)
            all_nodes.extend(cur_nodes)
        return all_nodes

    @abstractmethod
    def get_nodes_from_node(self, node: TextNode) -> List[BaseNode]:
        """Get nodes from node."""

    @abstractmethod
    def extract_elements(self, text: str, **kwargs: Any) -> List[Element]:
        """Extract elements from text."""

    def get_table_elements(self, elements: List[Element]) -> List[Element]:
        """Get table elements."""
        return [e for e in elements if e.type in ("table", "table_text")]

    def get_text_elements(self, elements: List[Element]) -> List[Element]:
        """Get text elements."""
        # TODO: maybe we should do something with titles
        # and other elements in the future?
        return [e for e in elements if e.type != "table"]
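
    # Concrete subclasses (e.g. MarkdownElementNodeParser) implement the two
    # abstract methods above; everything below is shared logic.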

    def extract_table_summaries(self, elements: List[Element]) -> None:
        """Go through elements, extract out summaries that are tables."""
        from llama_index.indices.list.base import SummaryIndex
        from llama_index.service_context import ServiceContext

        llm = self.llm or OpenAI()
        llm = cast(LLM, llm)

        service_context = ServiceContext.from_defaults(llm=llm, embed_model=None)

        table_elements = []
        table_context_list = []
        for idx, element in tqdm(enumerate(elements)):
            if element.type not in ("table", "table_text"):
                continue
            table_context = str(element.element)
            # pull in an adjacent caption line (e.g. "Table 3: ...") as extra context
            if idx > 0 and str(elements[idx - 1].element).lower().strip().startswith(
                "table"
            ):
                table_context = str(elements[idx - 1].element) + "\n" + table_context
            if idx < len(elements) - 1 and str(
                elements[idx + 1].element
            ).lower().strip().startswith("table"):
                table_context += "\n" + str(elements[idx + 1].element)

            table_elements.append(element)
            table_context_list.append(table_context)

        async def _get_table_output(table_context: str, summary_query_str: str) -> Any:
            index = SummaryIndex.from_documents(
                [Document(text=table_context)], service_context=service_context
            )
            query_engine = index.as_query_engine(output_cls=TableOutput)
            try:
                response = await query_engine.aquery(summary_query_str)
                return cast(PydanticResponse, response).response
            except ValidationError:
                # There was a pydantic validation error, so we will run with text
                # completion instead: fill in the summary and leave other fields blank.
                query_engine = index.as_query_engine()
                response_txt = await query_engine.aquery(summary_query_str)
                return TableOutput(summary=str(response_txt), columns=[])

        summary_jobs = [
            _get_table_output(table_context, self.summary_query_str)
            for table_context in table_context_list
        ]
        summary_outputs = asyncio.run(
            run_jobs(
                summary_jobs, show_progress=self.show_progress, workers=self.num_workers
            )
        )
        # the summaries line up with the filtered table elements, not with `elements`
        for element, summary_output in zip(table_elements, summary_outputs):
            element.table_output = summary_output
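
    # After `extract_table_summaries` runs, each table element carries a
    # `TableOutput`; for example (hypothetical values):
    #   TableOutput(
    #       summary="Quarterly revenue by region",
    #       table_title="Table 2: Revenue",
    #       table_id="table-2",
    #       columns=[TableColumnOutput(col_name="region", col_type="str")],
    #   )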

    def get_base_nodes_and_mappings(
        self, nodes: List[BaseNode]
    ) -> Tuple[List[BaseNode], Dict]:
        """Get base nodes and mappings.

        Given a list of nodes and IndexNode objects, return the base nodes and a
        mapping from index id to child nodes (which are excluded from the base nodes).
        """
        node_dict = {node.node_id: node for node in nodes}

        node_mappings = {}
        base_nodes = []

        # first map index nodes to their child nodes
        nonbase_node_ids = set()
        for node in nodes:
            if isinstance(node, IndexNode):
                node_mappings[node.index_id] = node_dict[node.index_id]
                nonbase_node_ids.add(node.index_id)

        # then add all nodes that are not children of index nodes
        for node in nodes:
            if node.node_id not in nonbase_node_ids:
                base_nodes.append(node)

        return base_nodes, node_mappings
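
    # A common downstream pattern (a sketch, not part of this class): feed the
    # mappings into a recursive retriever so retrieved IndexNode summaries
    # resolve to the full table nodes, e.g.
    #   base_nodes, node_mappings = parser.get_base_nodes_and_mappings(raw_nodes)
    #   retriever = RecursiveRetriever(
    #       "vector",
    #       retriever_dict={"vector": vector_index.as_retriever()},
    #       node_dict=node_mappings,
    #   )
    # where `parser` and `vector_index` are assumed to exist in the caller's code.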

    def get_nodes_and_objects(
        self, nodes: List[BaseNode]
    ) -> Tuple[List[BaseNode], List[IndexNode]]:
        """Split nodes into plain base nodes and IndexNodes with attached child objects."""
        base_nodes, node_mappings = self.get_base_nodes_and_mappings(nodes)

        nodes = []
        objects = []
        for node in base_nodes:
            if isinstance(node, IndexNode):
                node.obj = node_mappings[node.index_id]
                objects.append(node)
            else:
                nodes.append(node)

        return nodes, objects

    def _get_nodes_from_buffer(
        self, buffer: List[str], node_parser: NodeParser
    ) -> List[BaseNode]:
        """Get nodes from buffer."""
        doc = Document(text="\n\n".join(list(buffer)))
        return node_parser.get_nodes_from_documents([doc])

    def get_nodes_from_elements(self, elements: List[Element]) -> List[BaseNode]:
        """Get nodes and mappings."""
        from llama_index.node_parser import SentenceSplitter

        node_parser = SentenceSplitter()

        nodes = []
        cur_text_el_buffer: List[str] = []
        for element in elements:
            if element.type in ("table", "table_text"):
                # flush text buffer for table
                if len(cur_text_el_buffer) > 0:
                    cur_text_nodes = self._get_nodes_from_buffer(
                        cur_text_el_buffer, node_parser
                    )
                    nodes.extend(cur_text_nodes)
                    cur_text_el_buffer = []

                table_output = cast(TableOutput, element.table_output)
                table_md = ""
                if element.type == "table":
                    table_df = cast(pd.DataFrame, element.table)
                    # We serialize the table as markdown as it allows better accuracy.
                    # We do not use the table_df.to_markdown() method as it generates
                    # a table with a token-hungry format.
                    table_md = "|"
                    for col_name, _ in table_df.items():
                        table_md += f"{col_name}|"
                    table_md += "\n|"
                    for _ in table_df.items():
                        table_md += "---|"
                    table_md += "\n"
                    for row in table_df.itertuples():
                        table_md += "|"
                        for col in row[1:]:
                            table_md += f"{col}|"
                        table_md += "\n"
                elif element.type == "table_text":
                    # if the table is a non-perfect table, we still want to keep
                    # the original text of the table
                    table_md = str(element.element)

                table_id = element.id + "_table"
                table_ref_id = element.id + "_table_ref"

                col_schema = "\n\n".join([str(col) for col in table_output.columns])

                # We build a summary of the table containing the extracted summary
                # and a description of the columns.
                table_summary = str(table_output.summary)
                if table_output.table_title:
                    table_summary += ",\nwith the following table title:\n"
                    table_summary += str(table_output.table_title)

                table_summary += ",\nwith the following columns:\n"
                for col in table_output.columns:
                    table_summary += f"- {col.col_name}: {col.summary}\n"

                index_node = IndexNode(
                    text=table_summary,
                    metadata={"col_schema": col_schema},
                    excluded_embed_metadata_keys=["col_schema"],
                    id_=table_ref_id,
                    index_id=table_id,
                )

                table_str = table_summary + "\n" + table_md
                text_node = TextNode(
                    text=table_str,
                    id_=table_id,
                    metadata={
                        # serialize the table as a dictionary string for the
                        # dataframe of a perfect table
                        "table_df": str(table_df.to_dict())
                        if element.type == "table"
                        else table_md,
                        # add table summary for retrieval purposes
                        "table_summary": table_summary,
                    },
                    excluded_embed_metadata_keys=["table_df", "table_summary"],
                    excluded_llm_metadata_keys=["table_df", "table_summary"],
                )
                nodes.extend([index_node, text_node])
            else:
                cur_text_el_buffer.append(str(element.element))

        # flush the remaining text buffer
        if len(cur_text_el_buffer) > 0:
            cur_text_nodes = self._get_nodes_from_buffer(
                cur_text_el_buffer, node_parser
            )
            nodes.extend(cur_text_nodes)
            cur_text_el_buffer = []

        # remove empty nodes
        return [node for node in nodes if len(node.text) > 0]
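

# Example usage (a minimal sketch; `MyElementNodeParser` stands in for a concrete
# subclass that implements `extract_elements` and `get_nodes_from_node`):
#
#   parser = MyElementNodeParser.from_defaults(llm=OpenAI())
#   raw_nodes = parser.get_nodes_from_documents(documents)
#   nodes, objects = parser.get_nodes_and_objects(raw_nodes)
#   # `objects` are IndexNodes whose `.obj` is the full table TextNode; index
#   # `nodes + objects` so retrieval over a table summary can resolve to the table.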