# faiss_rag_enterprise/llama_index/query_engine/multistep_query_engine.py

from typing import Any, Callable, Dict, List, Optional, Tuple, cast
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.core.response.schema import RESPONSE_TYPE
from llama_index.indices.query.query_transform.base import StepDecomposeQueryTransform
from llama_index.prompts.mixin import PromptMixinType
from llama_index.response_synthesizers import BaseSynthesizer, get_response_synthesizer
from llama_index.schema import NodeWithScore, QueryBundle, TextNode
def default_stop_fn(stop_dict: Dict) -> bool:
    """Default stop function for the multi-step query combiner.

    Args:
        stop_dict (Dict): Must contain a ``"query_bundle"`` entry holding the
            transformed query for the next step.

    Returns:
        bool: True when the transformed query contains the literal "none"
        (case-insensitive), which the step-decompose transform emits when no
        further sub-questions are needed.

    Raises:
        ValueError: If ``stop_dict`` has no ``"query_bundle"`` entry.
    """
    query_bundle = stop_dict.get("query_bundle")
    if query_bundle is None:
        # Fixed message: the missing value is the query bundle, not a response.
        raise ValueError("query_bundle must be provided to stop function.")
    return "none" in query_bundle.query_str.lower()
class MultiStepQueryEngine(BaseQueryEngine):
    """Multi-step query engine.

    This query engine can operate over an existing base query engine,
    along with the multi-step query transform.

    Args:
        query_engine (BaseQueryEngine): A BaseQueryEngine object.
        query_transform (StepDecomposeQueryTransform): A StepDecomposeQueryTransform
            object.
        response_synthesizer (Optional[BaseSynthesizer]): A BaseSynthesizer
            object.
        num_steps (Optional[int]): Number of steps to run the multi-step query.
        early_stopping (bool): Whether to stop early if the stop function returns True.
        index_summary (str): A string summary of the index.
        stop_fn (Optional[Callable[[Dict], bool]]): A stop function that takes in a
            dictionary of information and returns a boolean.
    """

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        query_transform: StepDecomposeQueryTransform,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        num_steps: Optional[int] = 3,
        early_stopping: bool = True,
        index_summary: str = "None",
        stop_fn: Optional[Callable[[Dict], bool]] = None,
    ) -> None:
        self._query_engine = query_engine
        self._query_transform = query_transform
        # Default synthesizer shares the base engine's callback manager so
        # callback events propagate through a single manager.
        self._response_synthesizer = response_synthesizer or get_response_synthesizer(
            callback_manager=self._query_engine.callback_manager
        )
        self._index_summary = index_summary
        self._num_steps = num_steps
        self._early_stopping = early_stopping
        # TODO: make interface to stop function better
        self._stop_fn = stop_fn or default_stop_fn
        # num_steps must be provided if early_stopping is False; otherwise
        # the query loop would have no termination condition at all.
        if not self._early_stopping and self._num_steps is None:
            raise ValueError("Must specify num_steps if early_stopping is False.")

        callback_manager = self._query_engine.callback_manager
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {
            "response_synthesizer": self._response_synthesizer,
            "query_transform": self._query_transform,
        }

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Run the multi-step loop, then synthesize a final response."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes, source_nodes, metadata = self._query_multistep(query_bundle)
            final_response = self._response_synthesizer.synthesize(
                query=query_bundle,
                nodes=nodes,
                additional_source_nodes=source_nodes,
            )
            # Attach per-step (sub-question, sub-response) pairs for callers.
            final_response.metadata = metadata
            query_event.on_end(payload={EventPayload.RESPONSE: final_response})

        return final_response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Async variant: only the final synthesis is awaited.

        NOTE(review): the step loop itself (`_query_multistep`) runs
        synchronously here, so each sub-query blocks the event loop —
        confirm whether an async step loop is intended.
        """
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes, source_nodes, metadata = self._query_multistep(query_bundle)
            final_response = await self._response_synthesizer.asynthesize(
                query=query_bundle,
                nodes=nodes,
                additional_source_nodes=source_nodes,
            )
            final_response.metadata = metadata
            query_event.on_end(payload={EventPayload.RESPONSE: final_response})

        return final_response

    def _combine_queries(
        self, query_bundle: QueryBundle, prev_reasoning: str
    ) -> QueryBundle:
        """Transform the original query given the reasoning accumulated so far."""
        transform_metadata = {
            "prev_reasoning": prev_reasoning,
            "index_summary": self._index_summary,
        }
        return self._query_transform(query_bundle, metadata=transform_metadata)

    def _query_multistep(
        self, query_bundle: QueryBundle
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore], Dict[str, Any]]:
        """Run the multi-step query loop.

        Repeatedly transforms the query, runs it against the base engine,
        and accumulates (question, answer) text until either the step budget
        is exhausted or the stop function fires.

        Returns:
            Tuple of (QA text nodes for synthesis, accumulated source nodes,
            metadata dict whose ``"sub_qa"`` entry lists each
            (sub-question string, sub-response) pair).
        """
        prev_reasoning = ""
        cur_steps = 0

        final_response_metadata: Dict[str, Any] = {"sub_qa": []}

        text_chunks: List[str] = []
        source_nodes: List[NodeWithScore] = []

        # The original loop kept a `should_stop` flag with an unreachable
        # `elif should_stop` branch; plain breaks are equivalent and clearer.
        while True:
            # Hard cap on the number of decomposition steps.
            if self._num_steps is not None and cur_steps >= self._num_steps:
                break

            updated_query_bundle = self._combine_queries(query_bundle, prev_reasoning)

            # TODO: make stop logic better
            stop_dict = {"query_bundle": updated_query_bundle}
            if self._stop_fn(stop_dict):
                break

            cur_response = self._query_engine.query(updated_query_bundle)

            # Record the sub-question/answer pair for final synthesis.
            cur_qa_text = (
                f"\nQuestion: {updated_query_bundle.query_str}\n"
                f"Answer: {cur_response!s}"
            )
            text_chunks.append(cur_qa_text)
            source_nodes.extend(cur_response.source_nodes)

            # Update metadata with this step's QA pair.
            final_response_metadata["sub_qa"].append(
                (updated_query_bundle.query_str, cur_response)
            )

            # Feed this step's QA back into the next transform call.
            prev_reasoning += (
                f"- {updated_query_bundle.query_str}\n" f"- {cur_response!s}\n"
            )
            cur_steps += 1

        nodes = [
            NodeWithScore(node=TextNode(text=text_chunk)) for text_chunk in text_chunks
        ]
        return nodes, source_nodes, final_response_metadata