faiss_rag_enterprise/llama_index/tools/query_plan.py

215 lines
7.5 KiB
Python

"""Query plan tool."""
from typing import Any, Dict, List, Optional
from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.response_synthesizers import BaseSynthesizer, get_response_synthesizer
from llama_index.schema import NodeWithScore, TextNode
from llama_index.tools.types import BaseTool, ToolMetadata, ToolOutput
from llama_index.utils import print_text
# Tool name used by QueryPlanTool.from_defaults when no name is given.
DEFAULT_NAME = "query_plan_tool"
# Field description for QueryNode.query_str (part of the JSON schema of the
# query plan).
QUERYNODE_QUERY_STR_DESC = """\
Question we are asking. This is the query string that will be executed. \
"""
# Description text for a QueryNode's `tool_name` field. It names the
# `dependencies` field, which is the field that actually holds sub-questions.
QUERYNODE_TOOL_NAME_DESC = """\
Name of the tool to execute the `query_str`. \
Should NOT be specified if there are subquestions to be specified, in which \
case `dependencies` should be nonempty instead.\
"""
# Field description for QueryNode.dependencies. The space after
# "`query_str`." keeps the two sentences from running together once the
# backslash line continuations are joined.
QUERYNODE_DEPENDENCIES_DESC = """\
List of sub-questions that need to be answered in order \
to answer the question given by `query_str`. \
Should be blank if there are no sub-questions to be specified, in which case \
`tool_name` is specified.\
"""
class QueryNode(BaseModel):
    """Query node.

    A query node represents a question (``query_str``) that must be answered.
    It is answered either directly by a tool (``tool_name``), or by first
    answering the sub-questions listed in ``dependencies`` (the ``id`` values
    of other query nodes in the same plan).

    ``tool_name`` and ``dependencies`` are mutually exclusive: when
    ``dependencies`` is non-empty, QueryPlanTool executes those nodes and
    does not use ``tool_name``.
    """

    # NOTE: inspired from https://github.com/jxnl/openai_function_call/pull/3/files
    # Identifier other nodes use to reference this node in `dependencies`.
    id: int = Field(..., description="ID of the query node.")
    query_str: str = Field(..., description=QUERYNODE_QUERY_STR_DESC)
    # Only consulted for leaf nodes (no dependencies).
    tool_name: Optional[str] = Field(
        default=None, description="Name of the tool to execute the `query_str`."
    )
    # IDs of the sub-question nodes; empty for leaf nodes.
    dependencies: List[int] = Field(
        default_factory=list, description=QUERYNODE_DEPENDENCIES_DESC
    )
class QueryPlan(BaseModel):
    """Query plan.

    Contains a flat list of QueryNode objects. Nodes refer to their
    sub-questions by ID via ``dependencies``, so together they form a graph
    (expected to be a DAG). Exactly one node must be the root: the node that
    isn't a dependency of any other node, i.e. the one answering the
    original question.
    """

    # The description is part of the JSON schema exposed for this model, so
    # it describes the list of nodes rather than a single question.
    nodes: List[QueryNode] = Field(
        ...,
        description=(
            "The query nodes making up the plan. Exactly one node (the root) "
            "must answer the original question; every other node must be a "
            "dependency of some other node."
        ),
    )
# Default opening text of QueryPlanTool's description; the listing of wrapped
# tool names and descriptions is placed after it.
DEFAULT_DESCRIPTION_PREFIX = (
    "This is a query plan tool that takes in a list of tools and executes a "
    "query plan over these tools to answer a query. "
    "The query plan is a DAG of query nodes.\n"
    "Given a list of tool names and the query plan schema, you "
    "can choose to generate a query plan to answer a question.\n"
    "The tool names and descriptions are as follows:\n"
)
class QueryPlanTool(BaseTool):
    """Query plan tool.

    A tool that takes in a list of tools and executes a query plan over them.
    The plan (the fields of a ``QueryPlan``) is passed as keyword arguments
    to ``__call__``. Leaf nodes are answered by the tool they name; inner
    nodes are answered by synthesizing the responses of their dependencies.
    """

    def __init__(
        self,
        query_engine_tools: List[BaseTool],
        response_synthesizer: BaseSynthesizer,
        name: str,
        description_prefix: str,
    ) -> None:
        """Initialize.

        Args:
            query_engine_tools: Tools that leaf query nodes may call; they are
                looked up by ``tool.metadata.name``.
            response_synthesizer: Combines child responses for inner nodes.
            name: Name reported in this tool's metadata.
            description_prefix: Text placed before the listing of tools in
                this tool's description.
        """
        # If two tools share a name, the later one replaces the earlier one.
        self._query_tools_dict = {t.metadata.name: t for t in query_engine_tools}
        self._response_synthesizer = response_synthesizer
        self._name = name
        self._description_prefix = description_prefix

    @classmethod
    def from_defaults(
        cls,
        query_engine_tools: List[BaseTool],
        response_synthesizer: Optional[BaseSynthesizer] = None,
        name: Optional[str] = None,
        description_prefix: Optional[str] = None,
    ) -> "QueryPlanTool":
        """Initialize from defaults.

        Any argument left as ``None`` (or otherwise falsy) falls back to
        ``DEFAULT_NAME``, ``DEFAULT_DESCRIPTION_PREFIX`` and
        ``get_response_synthesizer()`` respectively.
        """
        name = name or DEFAULT_NAME
        description_prefix = description_prefix or DEFAULT_DESCRIPTION_PREFIX
        response_synthesizer = response_synthesizer or get_response_synthesizer()
        return cls(
            query_engine_tools=query_engine_tools,
            response_synthesizer=response_synthesizer,
            name=name,
            description_prefix=description_prefix,
        )

    @property
    def metadata(self) -> ToolMetadata:
        """Metadata.

        The description is the configured prefix followed by one
        "Tool Name / Tool Description" entry per wrapped tool; the function
        schema is ``QueryPlan``.
        """
        tools_description = "\n\n".join(
            f"Tool Name: {tool.metadata.name}\n"
            f"Tool Description: {tool.metadata.description} "
            for tool in self._query_tools_dict.values()
        )
        description = f"{self._description_prefix}\n\n\n{tools_description}\n"
        return ToolMetadata(description, self._name, fn_schema=QueryPlan)

    def _execute_node(
        self, node: QueryNode, nodes_dict: Dict[int, QueryNode]
    ) -> ToolOutput:
        """Execute ``node``, recursively executing its dependencies first.

        A leaf node (no dependencies) is answered by calling the tool named
        by ``node.tool_name`` with ``node.query_str``. An inner node first
        executes every dependency, wraps each (query, response) pair in a
        ``TextNode``, and asks the response synthesizer to answer
        ``node.query_str`` from them.

        Assumes the plan already passed ``_validate_plan`` (all dependency
        IDs and leaf tool names exist, and there are no cycles).
        """
        print_text(f"Executing node {node.json()}\n", color="blue")
        if node.dependencies:
            print_text(
                f"Executing {len(node.dependencies)} child nodes\n", color="pink"
            )
            child_query_nodes = [nodes_dict[dep] for dep in node.dependencies]
            # Execute the child nodes first. A node shared by several parents
            # is executed once per parent.
            child_responses = [
                self._execute_node(child, nodes_dict) for child in child_query_nodes
            ]
            # Each child's question and answer become one synthesizer input,
            # all weighted equally.
            child_nodes_with_scores = [
                NodeWithScore(
                    node=TextNode(
                        text=(
                            f"Query: {child_query_node.query_str}\n"
                            f"Response: {child_response!s}\n"
                        )
                    ),
                    score=1.0,
                )
                for child_query_node, child_response in zip(
                    child_query_nodes, child_responses
                )
            ]
            response_obj = self._response_synthesizer.synthesize(
                query=node.query_str,
                nodes=child_nodes_with_scores,
            )
            response = ToolOutput(
                content=str(response_obj),
                # NOTE(review): the sub-question text is reported as the tool
                # name here; kept unchanged for backward compatibility.
                tool_name=node.query_str,
                raw_input={"query": node.query_str},
                raw_output=response_obj,
            )
        else:
            # Leaf request: run the query string through the named tool.
            tool = self._query_tools_dict[node.tool_name]
            print_text(f"Selected Tool: {tool.metadata}\n", color="pink")
            response = tool(node.query_str)
        print_text(
            "Executed query, got response.\n"
            f"Query: {node.query_str}\n"
            f"Response: {response!s}\n",
            color="blue",
        )
        return response

    def _find_root_nodes(self, nodes_dict: Dict[int, QueryNode]) -> List[QueryNode]:
        """Return the nodes that aren't a dependency of any other node.

        Nodes are returned in ``nodes_dict`` iteration order.
        """
        referenced_ids = {
            dep for node in nodes_dict.values() for dep in node.dependencies
        }
        return [
            node
            for node_id, node in nodes_dict.items()
            if node_id not in referenced_ids
        ]

    def _validate_plan(self, nodes_dict: Dict[int, QueryNode]) -> None:
        """Check the plan before any tool is executed.

        Raises:
            ValueError: if a dependency refers to an unknown node ID, a leaf
                node (no dependencies) names a tool that isn't registered,
                or the dependencies form a cycle.
        """
        for node in nodes_dict.values():
            for dep in node.dependencies:
                if dep not in nodes_dict:
                    raise ValueError(
                        f"Query node {node.id} depends on unknown node id {dep}."
                    )
            if not node.dependencies and node.tool_name not in self._query_tools_dict:
                raise ValueError(
                    f"Query node {node.id} has no dependencies but names "
                    f"unknown tool {node.tool_name!r}."
                )

        # Depth-first search: reaching a node that is still on the current
        # path closes a cycle, which would make _execute_node recurse forever.
        on_path = set()
        finished = set()

        def visit(node_id: int) -> None:
            if node_id in finished:
                return
            if node_id in on_path:
                raise ValueError("Query plan contains a dependency cycle.")
            on_path.add(node_id)
            for dep in nodes_dict[node_id].dependencies:
                visit(dep)
            on_path.discard(node_id)
            finished.add(node_id)

        for node_id in nodes_dict:
            visit(node_id)

    def __call__(self, *args: Any, **kwargs: Any) -> ToolOutput:
        """Validate and execute a query plan.

        ``kwargs`` are the fields of a ``QueryPlan`` (i.e. ``nodes=[...]``),
        matching this tool's ``fn_schema``; positional ``args`` are ignored.

        Raises:
            ValueError: if a dependency refers to an unknown node ID, a leaf
                node names an unknown tool, the dependencies contain a cycle,
                or the plan does not have exactly one root node.
        """
        # the kwargs represented as a JSON object
        # should be a QueryPlan object
        query_plan = QueryPlan(**kwargs)
        # If two nodes share an ID, the later one wins.
        nodes_dict = {node.id: node for node in query_plan.nodes}
        self._validate_plan(nodes_dict)
        root_nodes = self._find_root_nodes(nodes_dict)
        # `!= 1` rather than `> 1`: an empty plan has no root and would
        # otherwise fail with an IndexError on root_nodes[0].
        if len(root_nodes) != 1:
            raise ValueError("Query plan should have exactly one root node.")
        return self._execute_node(root_nodes[0], nodes_dict)