386 lines
15 KiB
Python
386 lines
15 KiB
Python
import logging
|
|
from typing import Callable, List, Optional, Sequence
|
|
|
|
from llama_index.async_utils import run_async_tasks
|
|
from llama_index.bridge.pydantic import BaseModel
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.callbacks.schema import CBEventType, EventPayload
|
|
from llama_index.core.base_query_engine import BaseQueryEngine
|
|
from llama_index.core.base_retriever import BaseRetriever
|
|
from llama_index.core.base_selector import BaseSelector
|
|
from llama_index.core.response.schema import (
|
|
RESPONSE_TYPE,
|
|
PydanticResponse,
|
|
Response,
|
|
StreamingResponse,
|
|
)
|
|
from llama_index.objects.base import ObjectRetriever
|
|
from llama_index.prompts.default_prompt_selectors import (
|
|
DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
|
|
)
|
|
from llama_index.prompts.mixin import PromptMixinType
|
|
from llama_index.response_synthesizers import TreeSummarize
|
|
from llama_index.schema import BaseNode, QueryBundle
|
|
from llama_index.selectors.utils import get_selector_from_context
|
|
from llama_index.service_context import ServiceContext
|
|
from llama_index.tools.query_engine import QueryEngineTool
|
|
from llama_index.tools.types import ToolMetadata
|
|
from llama_index.utils import print_text
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def combine_responses(
    summarizer: TreeSummarize, responses: List[RESPONSE_TYPE], query_bundle: QueryBundle
) -> RESPONSE_TYPE:
    """Merge the responses of several sub-engines into a single response.

    Every response is turned into text and its source nodes are pooled. The
    texts are summarized with ``summarizer.get_response`` using the query
    string, and the summary is wrapped according to its type.

    Args:
        summarizer: Tree summarizer used to merge the response texts.
        responses: Responses returned by the selected sub-engines.
        query_bundle: The query those responses answer.

    Returns:
        A ``Response`` for a ``str`` summary, a ``PydanticResponse`` for a
        pydantic ``BaseModel`` summary, and otherwise a ``StreamingResponse``
        whose generator is the summary. The pooled source nodes are attached
        in every case.
    """
    logger.info("Combining responses from multiple query engines.")

    texts: List[str] = []
    pooled_nodes = []
    for sub_response in responses:
        # Streaming and pydantic responses are converted with get_response()
        # before their source nodes are read.
        plain = (
            sub_response.get_response()
            if isinstance(sub_response, (StreamingResponse, PydanticResponse))
            else sub_response
        )
        pooled_nodes.extend(plain.source_nodes)
        texts.append(str(sub_response))

    summary = summarizer.get_response(query_bundle.query_str, texts)

    if isinstance(summary, str):
        return Response(response=summary, source_nodes=pooled_nodes)
    if isinstance(summary, BaseModel):
        return PydanticResponse(response=summary, source_nodes=pooled_nodes)
    return StreamingResponse(response_gen=summary, source_nodes=pooled_nodes)
|
|
|
|
|
|
async def acombine_responses(
    summarizer: TreeSummarize, responses: List[RESPONSE_TYPE], query_bundle: QueryBundle
) -> RESPONSE_TYPE:
    """Asynchronously merge the responses of several sub-engines into one.

    This is the async counterpart of combining responses. The texts of all
    responses are summarized with ``await summarizer.aget_response(...)``, and
    the source nodes of every response are attached to the result.

    Args:
        summarizer: Tree summarizer used to merge the response texts.
        responses: Responses returned by the selected sub-engines.
        query_bundle: The query those responses answer.

    Returns:
        ``Response`` when the summary is a ``str``, ``PydanticResponse`` when it
        is a pydantic ``BaseModel``, and a ``StreamingResponse`` wrapping it
        otherwise.
    """
    logger.info("Combining responses from multiple query engines.")

    strs_acc = []
    nodes_acc = []
    for item in responses:
        base = item
        if isinstance(item, (StreamingResponse, PydanticResponse)):
            # Convert to a plain response object to read its source nodes.
            base = item.get_response()
        nodes_acc.extend(base.source_nodes)
        strs_acc.append(str(item))

    combined = await summarizer.aget_response(query_bundle.query_str, strs_acc)

    if isinstance(combined, str):
        wrapped = Response(response=combined, source_nodes=nodes_acc)
    elif isinstance(combined, BaseModel):
        wrapped = PydanticResponse(response=combined, source_nodes=nodes_acc)
    else:
        wrapped = StreamingResponse(response_gen=combined, source_nodes=nodes_acc)
    return wrapped
|
|
|
|
|
|
class RouterQueryEngine(BaseQueryEngine):
    """Router query engine.

    Selects one or more of several candidate query engines to execute a query.
    When the selector picks more than one engine, every selected engine is
    queried and the responses are merged with the summarizer.

    Args:
        selector (BaseSelector): A selector that chooses among the candidates
            based on each candidate's metadata and the query.
        query_engine_tools (Sequence[QueryEngineTool]): A sequence of candidate
            query engines. They must be wrapped as tools to expose metadata to
            the selector.
        service_context (Optional[ServiceContext]): A service context. Defaults
            to ``ServiceContext.from_defaults()``.
        summarizer (Optional[TreeSummarize]): Tree summarizer to summarize
            sub-results. Defaults to a ``TreeSummarize`` built from the
            service context with ``DEFAULT_TREE_SUMMARIZE_PROMPT_SEL``.
        verbose (bool): If True, also print each selection decision.

    """

    def __init__(
        self,
        selector: BaseSelector,
        query_engine_tools: Sequence[QueryEngineTool],
        service_context: Optional[ServiceContext] = None,
        summarizer: Optional[TreeSummarize] = None,
        verbose: bool = False,
    ) -> None:
        self.service_context = service_context or ServiceContext.from_defaults()
        self._selector = selector
        # Parallel lists: the selector sees ``_metadatas`` and the indices it
        # returns are used to pick from ``_query_engines``.
        self._query_engines = [x.query_engine for x in query_engine_tools]
        self._metadatas = [x.metadata for x in query_engine_tools]
        self._summarizer = summarizer or TreeSummarize(
            service_context=self.service_context,
            summary_template=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
        )
        self._verbose = verbose

        super().__init__(self.service_context.callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        # NOTE: don't include tools for now
        return {"summarizer": self._summarizer, "selector": self._selector}

    @classmethod
    def from_defaults(
        cls,
        query_engine_tools: Sequence[QueryEngineTool],
        service_context: Optional[ServiceContext] = None,
        selector: Optional[BaseSelector] = None,
        summarizer: Optional[TreeSummarize] = None,
        select_multi: bool = False,
        verbose: bool = False,
    ) -> "RouterQueryEngine":
        """Build a router, deriving a selector from the service context if needed.

        Args:
            query_engine_tools: Candidate query engines wrapped as tools.
            service_context: Service context. Defaults to
                ``ServiceContext.from_defaults()``.
            selector: Selector to use. If omitted, one is obtained from
                ``get_selector_from_context``.
            summarizer: Summarizer for merging multi-engine results.
            select_multi: Passed as ``is_multi`` when building the default
                selector. Ignored when ``selector`` is given.
            verbose: Forwarded to the constructor. When True, selection
                decisions are printed.

        Returns:
            A configured ``RouterQueryEngine``.
        """
        service_context = service_context or ServiceContext.from_defaults()

        selector = selector or get_selector_from_context(
            service_context, is_multi=select_multi
        )

        assert selector is not None

        return cls(
            selector,
            query_engine_tools,
            service_context=service_context,
            summarizer=summarizer,
            verbose=verbose,
        )

    def _log_selection(self, engine_ind: int, reason: str) -> None:
        """Log why a query engine was selected, and print it when verbose."""
        log_str = f"Selecting query engine {engine_ind}: {reason}."
        logger.info(log_str)
        if self._verbose:
            print_text(log_str + "\n", color="pink")

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Route the query to the selected engine(s) and return the response.

        The selector result is stored under ``metadata["selector_result"]``.

        Raises:
            ValueError: If reading the single selection from the selector
                result fails.
        """
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            result = self._selector.select(self._metadatas, query_bundle)

            if len(result.inds) > 1:
                responses = []
                for engine_ind, reason in zip(result.inds, result.reasons):
                    self._log_selection(engine_ind, reason)
                    responses.append(
                        self._query_engines[engine_ind].query(query_bundle)
                    )
                # More than one engine was selected, so there is always more
                # than one response to merge.
                final_response = combine_responses(
                    self._summarizer, responses, query_bundle
                )
            else:
                try:
                    selected_query_engine = self._query_engines[result.ind]
                    self._log_selection(result.ind, result.reason)
                except ValueError as e:
                    # NOTE(review): presumably raised by result.ind/result.reason
                    # when there is not exactly one selection - confirm.
                    raise ValueError("Failed to select query engine") from e

                final_response = selected_query_engine.query(query_bundle)

            # add selected result
            final_response.metadata = final_response.metadata or {}
            final_response.metadata["selector_result"] = result

            query_event.on_end(payload={EventPayload.RESPONSE: final_response})

        return final_response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Async variant of ``_query``.

        Selected engines are queried concurrently.
        """
        import asyncio

        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            result = await self._selector.aselect(self._metadatas, query_bundle)

            if len(result.inds) > 1:
                tasks = []
                for engine_ind, reason in zip(result.inds, result.reasons):
                    self._log_selection(engine_ind, reason)
                    tasks.append(self._query_engines[engine_ind].aquery(query_bundle))

                # We are already inside a running event loop, so await the
                # coroutines here. The synchronous run_async_tasks helper would
                # have to drive its own loop, which fails inside a running one.
                responses = await asyncio.gather(*tasks)
                final_response = await acombine_responses(
                    self._summarizer, responses, query_bundle
                )
            else:
                try:
                    selected_query_engine = self._query_engines[result.ind]
                    self._log_selection(result.ind, result.reason)
                except ValueError as e:
                    # NOTE(review): presumably raised by result.ind/result.reason
                    # when there is not exactly one selection - confirm.
                    raise ValueError("Failed to select query engine") from e

                final_response = await selected_query_engine.aquery(query_bundle)

            # add selected result
            final_response.metadata = final_response.metadata or {}
            final_response.metadata["selector_result"] = result

            query_event.on_end(payload={EventPayload.RESPONSE: final_response})

        return final_response
|
|
|
|
|
|
def default_node_to_metadata_fn(node: BaseNode) -> ToolMetadata:
    """Build tool metadata from a node.

    The tool name comes from ``node.metadata["tool_name"]``. The description
    is the node's content (``node.get_content()``).

    Raises:
        ValueError: If the node's metadata has no ``tool_name`` entry.
    """
    node_metadata = node.metadata or {}
    if "tool_name" in node_metadata:
        return ToolMetadata(
            name=node_metadata["tool_name"], description=node.get_content()
        )
    raise ValueError("Node must have a tool_name in metadata.")
|
|
|
|
|
|
class RetrieverRouterQueryEngine(BaseQueryEngine):
    """Retriever-based router query engine.

    NOTE: this is deprecated, please use our new ToolRetrieverRouterQueryEngine

    Use a retriever to select a node, then map that node to a query engine
    with ``node_to_query_engine_fn``. That query engine answers the query.

    NOTE: this is a beta feature. We are figuring out the right interface
    between the retriever and query engine.

    Args:
        retriever (BaseRetriever): Retriever that returns the candidate node.
            Currently it must return exactly one node per query.
        node_to_query_engine_fn (Callable): Maps the retrieved node to the
            query engine that should answer the query.
        callback_manager (Optional[CallbackManager]): A callback manager.

    """

    def __init__(
        self,
        retriever: BaseRetriever,
        node_to_query_engine_fn: Callable,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        self._retriever = retriever
        self._node_to_query_engine_fn = node_to_query_engine_fn
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        # NOTE: don't include tools for now
        return {"retriever": self._retriever}

    def _get_query_engine(self, nodes_with_score: Sequence) -> BaseQueryEngine:
        """Check that exactly one node was retrieved and map it to an engine.

        Raises:
            ValueError: If more than one node was retrieved.
            IndexError: If no node was retrieved.
        """
        # TODO: for now we only support retrieving one node
        if len(nodes_with_score) > 1:
            raise ValueError("Retrieved more than one node.")

        node = nodes_with_score[0].node
        return self._node_to_query_engine_fn(node)

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        nodes_with_score = self._retriever.retrieve(query_bundle)
        query_engine = self._get_query_engine(nodes_with_score)
        return query_engine.query(query_bundle)

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        # Use the async retrieval/query APIs instead of delegating to the
        # synchronous _query, which would block the event loop.
        nodes_with_score = await self._retriever.aretrieve(query_bundle)
        query_engine = self._get_query_engine(nodes_with_score)
        return await query_engine.aquery(query_bundle)
|
|
|
|
|
|
class ToolRetrieverRouterQueryEngine(BaseQueryEngine):
    """Tool Retriever router query engine.

    Selects a set of candidate query engines to execute a query. The retriever
    returns query engine tools. Every retrieved engine is queried, and when
    more than one responds, the responses are merged with the summarizer.

    Args:
        retriever (ObjectRetriever): A retriever that retrieves a set of
            query engine tools.
        service_context (Optional[ServiceContext]): A service context.
            Defaults to ``ServiceContext.from_defaults()``.
        summarizer (Optional[TreeSummarize]): Tree summarizer to summarize
            sub-results. Defaults to a ``TreeSummarize`` built from the
            service context with ``DEFAULT_TREE_SUMMARIZE_PROMPT_SEL``.

    """

    def __init__(
        self,
        retriever: ObjectRetriever[QueryEngineTool],
        service_context: Optional[ServiceContext] = None,
        summarizer: Optional[TreeSummarize] = None,
    ) -> None:
        self.service_context = service_context or ServiceContext.from_defaults()
        self._summarizer = summarizer or TreeSummarize(
            service_context=self.service_context,
            summary_template=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
        )
        self._retriever = retriever

        super().__init__(self.service_context.callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        # NOTE: don't include tools for now
        return {"summarizer": self._summarizer}

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Query every retrieved engine and return the (merged) response.

        The retrieved tools are stored under ``metadata["retrieved_tools"]``.
        If no tools are retrieved, indexing ``responses[0]`` raises IndexError.
        """
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            query_engine_tools = self._retriever.retrieve(query_bundle)
            responses = [
                tool.query_engine.query(query_bundle) for tool in query_engine_tools
            ]

            if len(responses) > 1:
                final_response = combine_responses(
                    self._summarizer, responses, query_bundle
                )
            else:
                final_response = responses[0]

            # add selected result
            final_response.metadata = final_response.metadata or {}
            final_response.metadata["retrieved_tools"] = query_engine_tools

            query_event.on_end(payload={EventPayload.RESPONSE: final_response})

        return final_response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Async variant of ``_query``.

        Retrieved engines are queried concurrently.
        """
        import asyncio

        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            query_engine_tools = self._retriever.retrieve(query_bundle)
            tasks = [
                tool.query_engine.aquery(query_bundle) for tool in query_engine_tools
            ]
            # We are already inside a running event loop, so await the
            # coroutines here. The synchronous run_async_tasks helper would have
            # to drive its own loop, which fails inside a running one.
            responses = await asyncio.gather(*tasks)

            if len(responses) > 1:
                final_response = await acombine_responses(
                    self._summarizer, responses, query_bundle
                )
            else:
                final_response = responses[0]

            # add selected result
            final_response.metadata = final_response.metadata or {}
            final_response.metadata["retrieved_tools"] = query_engine_tools

            query_event.on_end(payload={EventPayload.RESPONSE: final_response})

        return final_response
|