import asyncio
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Tuple, cast

import pandas as pd
from tqdm import tqdm

from llama_index.async_utils import DEFAULT_NUM_WORKERS, run_jobs
from llama_index.bridge.pydantic import BaseModel, Field, ValidationError
from llama_index.callbacks.base import CallbackManager
from llama_index.core.response.schema import PydanticResponse
from llama_index.llms.llm import LLM
from llama_index.llms.openai import OpenAI
from llama_index.node_parser.interface import NodeParser
from llama_index.schema import BaseNode, Document, IndexNode, TextNode
from llama_index.utils import get_tqdm_iterable
DEFAULT_SUMMARY_QUERY_STR = """\
|
|
What is this table about? Give a very concise summary (imagine you are adding a new caption and summary for this table), \
|
|
and output the real/existing table title/caption if context provided.\
|
|
and output the real/existing table id if context provided.\
|
|
and also output whether or not the table should be kept.\
|
|
"""
|
|
|
|
|
|


class TableColumnOutput(BaseModel):
    """Output from analyzing a table column."""

    col_name: str
    col_type: str
    summary: Optional[str] = None

    def __str__(self) -> str:
        """Convert to string representation."""
        return (
            f"Column: {self.col_name}\nType: {self.col_type}\nSummary: {self.summary}"
        )


class TableOutput(BaseModel):
    """Output from analyzing a table."""

    summary: str
    table_title: Optional[str] = None
    table_id: Optional[str] = None
    columns: List[TableColumnOutput]


class Element(BaseModel):
    """Element object."""

    id: str
    type: str
    element: Any
    title_level: Optional[int] = None
    table_output: Optional[TableOutput] = None
    table: Optional[pd.DataFrame] = None

    class Config:
        arbitrary_types_allowed = True
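

# Element types used in this module (as inferred from the parsing logic below):
# "table" is a well-formed table parsed into a pandas DataFrame, "table_text"
# is table-like text that could not be parsed into a DataFrame, and any other
# type is treated as free-form text.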


class BaseElementNodeParser(NodeParser):
    """
    Splits a document into Text Nodes and Index Nodes corresponding to embedded objects.

    Supports text and tables currently.
    """

    callback_manager: CallbackManager = Field(
        default_factory=CallbackManager, exclude=True
    )
    llm: Optional[LLM] = Field(
        default=None, description="LLM model to use for summarization."
    )
    summary_query_str: str = Field(
        default=DEFAULT_SUMMARY_QUERY_STR,
        description="Query string to use for summarization.",
    )
    num_workers: int = Field(
        default=DEFAULT_NUM_WORKERS,
        description="Number of workers for async jobs.",
    )

    show_progress: bool = Field(default=True, description="Whether to show progress.")

    @classmethod
    def class_name(cls) -> str:
        return "BaseStructuredNodeParser"

    @classmethod
    def from_defaults(
        cls,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> "BaseElementNodeParser":
        callback_manager = callback_manager or CallbackManager([])

        return cls(
            callback_manager=callback_manager,
            **kwargs,
        )

    def _parse_nodes(
        self,
        nodes: Sequence[BaseNode],
        show_progress: bool = False,
        **kwargs: Any,
    ) -> List[BaseNode]:
        all_nodes: List[BaseNode] = []
        nodes_with_progress = get_tqdm_iterable(nodes, show_progress, "Parsing nodes")

        for node in nodes_with_progress:
            # avoid shadowing the `nodes` argument while iterating
            cur_nodes = self.get_nodes_from_node(node)
            all_nodes.extend(cur_nodes)

        return all_nodes

    @abstractmethod
    def get_nodes_from_node(self, node: TextNode) -> List[BaseNode]:
        """Get nodes from node."""

    @abstractmethod
    def extract_elements(self, text: str, **kwargs: Any) -> List[Element]:
        """Extract elements from text."""

    def get_table_elements(self, elements: List[Element]) -> List[Element]:
        """Get table elements."""
        return [e for e in elements if e.type in ("table", "table_text")]

    def get_text_elements(self, elements: List[Element]) -> List[Element]:
        """Get text elements."""
        # TODO: Here we should maybe do something with titles
        # and other elements in the future?
        return [e for e in elements if e.type != "table"]

    def extract_table_summaries(self, elements: List[Element]) -> None:
        """Go through elements and extract summaries for those that are tables."""
        from llama_index.indices.list.base import SummaryIndex
        from llama_index.service_context import ServiceContext

        llm = self.llm or OpenAI()
        llm = cast(LLM, llm)

        service_context = ServiceContext.from_defaults(llm=llm, embed_model=None)

        table_context_list = []
        for idx, element in tqdm(enumerate(elements)):
            if element.type not in ("table", "table_text"):
                continue
            table_context = str(element.element)
            # include the preceding element if it looks like a table caption
            if idx > 0 and str(elements[idx - 1].element).lower().strip().startswith(
                "table"
            ):
                table_context = str(elements[idx - 1].element) + "\n" + table_context
            # include the following element if it looks like a table caption
            if idx < len(elements) - 1 and str(
                elements[idx + 1].element
            ).lower().strip().startswith("table"):
                table_context += "\n" + str(elements[idx + 1].element)

            table_context_list.append(table_context)

        async def _get_table_output(table_context: str, summary_query_str: str) -> Any:
            index = SummaryIndex.from_documents(
                [Document(text=table_context)], service_context=service_context
            )
            query_engine = index.as_query_engine(output_cls=TableOutput)
            try:
                response = await query_engine.aquery(summary_query_str)
                return cast(PydanticResponse, response).response
            except ValidationError:
                # there was a pydantic validation error, so fall back to plain
                # text completion: fill in the summary and leave the other
                # fields blank
                query_engine = index.as_query_engine()
                response_txt = await query_engine.aquery(summary_query_str)
                return TableOutput(summary=str(response_txt), columns=[])

        summary_jobs = [
            _get_table_output(table_context, self.summary_query_str)
            for table_context in table_context_list
        ]
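
        # note: asyncio.run below raises a RuntimeError if an event loop is
        # already running (e.g. inside a notebook); applying ``nest_asyncio``
        # first is a common workaround in that situation.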
        summary_outputs = asyncio.run(
            run_jobs(
                summary_jobs, show_progress=self.show_progress, workers=self.num_workers
            )
        )

        # match each summary to the table element it was generated for
        # (summaries were only produced for "table"/"table_text" elements)
        table_elements = self.get_table_elements(elements)
        for element, summary_output in zip(table_elements, summary_outputs):
            element.table_output = summary_output

    def get_base_nodes_and_mappings(
        self, nodes: List[BaseNode]
    ) -> Tuple[List[BaseNode], Dict]:
        """Get base nodes and mappings.

        Given a list of nodes and IndexNode objects, return the base nodes and a mapping
        from index id to child nodes (which are excluded from the base nodes).

        """
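        # a hypothetical illustration of the contract: given
        #   [IndexNode(index_id="t1"), TextNode(id_="t1"), TextNode(id_="a")]
        # the base nodes are [IndexNode, TextNode("a")] and the mapping is
        #   {"t1": TextNode("t1")}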
        node_dict = {node.node_id: node for node in nodes}

        node_mappings = {}
        base_nodes = []

        # first map index nodes to their child nodes
        nonbase_node_ids = set()
        for node in nodes:
            if isinstance(node, IndexNode):
                node_mappings[node.index_id] = node_dict[node.index_id]
                nonbase_node_ids.add(node.index_id)

        # then add all nodes that are not children of index nodes
        for node in nodes:
            if node.node_id not in nonbase_node_ids:
                base_nodes.append(node)

        return base_nodes, node_mappings

    def get_nodes_and_objects(
        self, nodes: List[BaseNode]
    ) -> Tuple[List[BaseNode], List[IndexNode]]:
        """Split base nodes into plain nodes and retrievable index-node objects."""
        base_nodes, node_mappings = self.get_base_nodes_and_mappings(nodes)

        nodes = []
        objects = []
        for node in base_nodes:
            if isinstance(node, IndexNode):
                # attach the referenced child node so the index node can be
                # retrieved as an object
                node.obj = node_mappings[node.index_id]
                objects.append(node)
            else:
                nodes.append(node)

        return nodes, objects

    def _get_nodes_from_buffer(
        self, buffer: List[str], node_parser: NodeParser
    ) -> List[BaseNode]:
        """Get nodes from buffer."""
        doc = Document(text="\n\n".join(buffer))
        return node_parser.get_nodes_from_documents([doc])

    def get_nodes_from_elements(self, elements: List[Element]) -> List[BaseNode]:
        """Get nodes from elements."""
        from llama_index.node_parser import SentenceSplitter

        node_parser = SentenceSplitter()

        nodes = []
        cur_text_el_buffer: List[str] = []
        for element in elements:
            if element.type in ("table", "table_text"):
                # flush the text buffer before handling the table
                if len(cur_text_el_buffer) > 0:
                    cur_text_nodes = self._get_nodes_from_buffer(
                        cur_text_el_buffer, node_parser
                    )
                    nodes.extend(cur_text_nodes)
                    cur_text_el_buffer = []

                table_output = cast(TableOutput, element.table_output)
                table_md = ""
                if element.type == "table":
                    table_df = cast(pd.DataFrame, element.table)
                    # We serialize the table as markdown as it allows better
                    # accuracy. We do not use the table_df.to_markdown() method
                    # as it generates a table in a token-hungry format.
                    table_md = "|"
                    for col_name, col in table_df.items():
                        table_md += f"{col_name}|"
                    table_md += "\n|"
                    for col_name, col in table_df.items():
                        table_md += "---|"
                    table_md += "\n"
                    for row in table_df.itertuples():
                        table_md += "|"
                        for col in row[1:]:
                            table_md += f"{col}|"
                        table_md += "\n"
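                    # the serialization above yields a compact markdown table,
                    # e.g. (illustrative data):
                    #   |name|age|
                    #   |---|---|
                    #   |Alice|30|
                    #   |Bob|25|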
elif element.type == "table_text":
|
|
# if the table is non-perfect table, we still want to keep the original text of table
|
|
table_md = str(element.element)
|
|
table_id = element.id + "_table"
|
|
table_ref_id = element.id + "_table_ref"
|
|
|
|
col_schema = "\n\n".join([str(col) for col in table_output.columns])
|
|
|
|
# We build a summary of the table containing the extracted summary, and a description of the columns
|
|
table_summary = str(table_output.summary)
|
|
if table_output.table_title:
|
|
table_summary += ",\nwith the following table title:\n"
|
|
table_summary += str(table_output.table_title)
|
|
|
|
table_summary += ",\nwith the following columns:\n"
|
|
|
|
for col in table_output.columns:
|
|
table_summary += f"- {col.col_name}: {col.summary}\n"
|
|
|
|
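                # the resulting summary string looks roughly like:
                #   <llm summary>,
                #   with the following table title:
                #   <title>,
                #   with the following columns:
                #   - <col_name>: <col summary>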

                index_node = IndexNode(
                    text=table_summary,
                    metadata={"col_schema": col_schema},
                    excluded_embed_metadata_keys=["col_schema"],
                    id_=table_ref_id,
                    index_id=table_id,
                )

                table_str = table_summary + "\n" + table_md

                text_node = TextNode(
                    text=table_str,
                    id_=table_id,
                    metadata={
                        # for a perfect table, serialize the dataframe as a
                        # dictionary string; otherwise keep the markdown text
                        "table_df": str(table_df.to_dict())
                        if element.type == "table"
                        else table_md,
                        # add table summary for retrieval purposes
                        "table_summary": table_summary,
                    },
                    excluded_embed_metadata_keys=["table_df", "table_summary"],
                    excluded_llm_metadata_keys=["table_df", "table_summary"],
                )
                nodes.extend([index_node, text_node])
            else:
                cur_text_el_buffer.append(str(element.element))

        # flush any remaining text buffer
        if len(cur_text_el_buffer) > 0:
            cur_text_nodes = self._get_nodes_from_buffer(
                cur_text_el_buffer, node_parser
            )
            nodes.extend(cur_text_nodes)
            cur_text_el_buffer = []

        # remove empty nodes
        return [node for node in nodes if len(node.text) > 0]
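

# Example usage (a sketch; `MyElementNodeParser` is a hypothetical concrete
# subclass that implements `get_nodes_from_node` and `extract_elements`):
#
#   parser = MyElementNodeParser.from_defaults(llm=OpenAI())
#   raw_nodes = parser.get_nodes_from_documents(documents)
#   base_nodes, objects = parser.get_nodes_and_objects(raw_nodes)
#
# `objects` are IndexNode entries pointing at the full table nodes, suitable
# for e.g. a recursive retrieval setup over tables and text.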