193 lines
7.3 KiB
Python
193 lines
7.3 KiB
Python
from typing import Dict, List, Optional, Tuple, Union
|
|
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.callbacks.schema import CBEventType, EventPayload
|
|
from llama_index.core.base_query_engine import BaseQueryEngine
|
|
from llama_index.core.base_retriever import BaseRetriever
|
|
from llama_index.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle, TextNode
|
|
from llama_index.utils import print_text
|
|
|
|
# Fallback template used by RecursiveRetriever when no `query_response_tmpl`
# is supplied. It is formatted with `query_str` and `response` to build the
# text of the node that wraps a query engine's answer.
DEFAULT_QUERY_RESPONSE_TMPL = "Query: {query_str}\nResponse: {response}"
|
|
|
|
|
|
# The kinds of object a query id can resolve to (retriever, query engine or
# node); used as the return type of `RecursiveRetriever._get_object`.
RQN_TYPE = Union[BaseRetriever, BaseQueryEngine, BaseNode]
|
|
|
|
|
|
class RecursiveRetriever(BaseRetriever):
    """Recursive retriever.

    This retriever will recursively explore links from nodes to other
    retrievers/query engines.

    For any retrieved nodes, if any of the nodes are IndexNodes,
    then it will explore the linked retriever/query engine, and query that.

    Args:
        root_id (str): The root id of the query graph. Must be a key of
            `retriever_dict`.
        retriever_dict (Dict[str, BaseRetriever]): A dictionary
            of id to retrievers.
        query_engine_dict (Optional[Dict[str, BaseQueryEngine]]): A dictionary of
            id to query engines. Ids must not overlap with `retriever_dict`.
        node_dict (Optional[Dict[str, BaseNode]]): A dictionary of id to nodes.
            Consulted before retrievers and query engines when resolving an id.
        callback_manager (Optional[CallbackManager]): Forwarded to the
            `BaseRetriever` constructor.
        query_response_tmpl (Optional[str]): Template used to turn a query
            engine response into node text; formatted with `query_str` and
            `response`. Defaults to `DEFAULT_QUERY_RESPONSE_TMPL`.
        verbose (bool): Forwarded to the `BaseRetriever` constructor; when
            set, progress is printed while retrieving.

    Raises:
        ValueError: If `root_id` is not in `retriever_dict`, or if retriever
            and query engine ids overlap.

    """

    def __init__(
        self,
        root_id: str,
        retriever_dict: Dict[str, BaseRetriever],
        query_engine_dict: Optional[Dict[str, BaseQueryEngine]] = None,
        node_dict: Optional[Dict[str, BaseNode]] = None,
        callback_manager: Optional[CallbackManager] = None,
        query_response_tmpl: Optional[str] = None,
        verbose: bool = False,
    ) -> None:
        """Init params."""
        self._root_id = root_id
        if root_id not in retriever_dict:
            raise ValueError(
                f"Root id {root_id} not in retriever_dict, it must be a retriever."
            )
        self._retriever_dict = retriever_dict
        self._query_engine_dict = query_engine_dict or {}
        self._node_dict = node_dict or {}

        # make sure keys don't overlap
        if set(self._retriever_dict.keys()) & set(self._query_engine_dict.keys()):
            raise ValueError("Retriever and query engine ids must not overlap.")

        self._query_response_tmpl = query_response_tmpl or DEFAULT_QUERY_RESPONSE_TMPL
        super().__init__(callback_manager, verbose=verbose)

    def _query_retrieved_nodes(
        self, query_bundle: QueryBundle, nodes_with_score: List[NodeWithScore]
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
        """Query for retrieved nodes.

        If node is an IndexNode, then recursively query the retriever/query engine.
        If node is a TextNode, then simply return the node.

        Returns:
            A tuple `(nodes_to_add, additional_nodes)`: the nodes produced for
            the given input nodes, and any extra source nodes gathered from
            query engines reached during the recursion.

        """
        nodes_to_add: List[NodeWithScore] = []
        additional_nodes: List[NodeWithScore] = []
        visited_ids = set()

        for node_with_score in nodes_with_score:
            node = node_with_score.node
            if isinstance(node, IndexNode):
                # dedup index nodes that reference same index id: only the
                # first occurrence is followed, later ones are dropped.
                if node.index_id in visited_ids:
                    continue
                visited_ids.add(node.index_id)

                if self._verbose:
                    print_text(
                        "Retrieved node with id, entering: " f"{node.index_id}\n",
                        color="pink",
                    )
                # recursively retrieve from whatever the index id points at
                cur_retrieved_nodes, cur_additional_nodes = self._retrieve_rec(
                    query_bundle,
                    query_id=node.index_id,
                    cur_similarity=node_with_score.score,
                )
            else:
                assert isinstance(node, TextNode)
                if self._verbose:
                    print_text(
                        "Retrieving text node: " f"{node.get_content()}\n",
                        color="pink",
                    )
                cur_retrieved_nodes = [node_with_score]
                cur_additional_nodes = []

            nodes_to_add.extend(cur_retrieved_nodes)
            additional_nodes.extend(cur_additional_nodes)

        return nodes_to_add, additional_nodes

    def _get_object(self, query_id: str) -> RQN_TYPE:
        """Fetch the node, retriever or query engine registered under an id.

        Lookup order is `node_dict`, then `retriever_dict`, then
        `query_engine_dict`; the first match wins.

        Raises:
            ValueError: If `query_id` is not present in any of the three dicts.

        """
        for registry in (
            self._node_dict,
            self._retriever_dict,
            self._query_engine_dict,
        ):
            obj = registry.get(query_id, None)
            if obj is not None:
                return obj
        raise ValueError(
            f"Query id {query_id} not found in `node_dict`, `retriever_dict` "
            "or `query_engine_dict`."
        )

    def _retrieve_rec(
        self,
        query_bundle: QueryBundle,
        query_id: Optional[str] = None,
        cur_similarity: Optional[float] = None,
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
        """Query recursively.

        Args:
            query_bundle (QueryBundle): The query to run.
            query_id (Optional[str]): Id of the object to query; falls back to
                the root id when not given.
            cur_similarity (Optional[float]): Score to attach to nodes created
                at this level; defaults to 1.0 only when it is None.

        Returns:
            A tuple `(nodes_to_add, additional_nodes)`.

        Raises:
            ValueError: If the id cannot be resolved to a supported object.

        """
        # Resolve defaults before logging so the verbose output shows the id
        # that is actually queried rather than `None`.
        query_id = query_id or self._root_id
        # Only substitute the default for a missing score; a score of 0.0 is
        # a valid similarity and must not be promoted to 1.0.
        if cur_similarity is None:
            cur_similarity = 1.0

        if self._verbose:
            print_text(
                f"Retrieving with query id {query_id}: {query_bundle.query_str}\n",
                color="blue",
            )

        obj = self._get_object(query_id)
        if isinstance(obj, BaseNode):
            # A plain node is returned as-is, carrying the incoming score.
            nodes_to_add = [NodeWithScore(node=obj, score=cur_similarity)]
            additional_nodes: List[NodeWithScore] = []
        elif isinstance(obj, BaseRetriever):
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as event:
                nodes = obj.retrieve(query_bundle)
                event.on_end(payload={EventPayload.NODES: nodes})

            # Follow any IndexNodes among the retrieved nodes.
            nodes_to_add, additional_nodes = self._query_retrieved_nodes(
                query_bundle, nodes
            )

        elif isinstance(obj, BaseQueryEngine):
            sub_resp = obj.query(query_bundle)
            if self._verbose:
                print_text(
                    f"Got response: {sub_resp!s}\n",
                    color="green",
                )
            # format with both the query and the response
            node_text = self._query_response_tmpl.format(
                query_str=query_bundle.query_str, response=str(sub_resp)
            )
            node = TextNode(text=node_text)
            nodes_to_add = [NodeWithScore(node=node, score=cur_similarity)]
            # The engine's own sources are surfaced separately from the
            # synthesized answer node.
            additional_nodes = sub_resp.source_nodes
        else:
            raise ValueError("Must be a retriever or query engine.")

        return nodes_to_add, additional_nodes

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes starting from the root id, dropping additional sources."""
        retrieved_nodes, _ = self._retrieve_rec(query_bundle, query_id=None)
        return retrieved_nodes

    def retrieve_all(
        self, query_bundle: QueryBundle
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
        """Retrieve all nodes.

        Unlike default `retrieve` method, this also fetches additional sources.

        Returns:
            A tuple `(retrieved_nodes, additional_nodes)` from the recursive
            retrieval started at the root id.

        """
        return self._retrieve_rec(query_bundle, query_id=None)
|