# faiss_rag_enterprise/llama_index/retrievers/recursive_retriever.py
#
# 193 lines
# 7.3 KiB
# Python

from typing import Dict, List, Optional, Tuple, Union
from llama_index.callbacks.base import CallbackManager
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.core.base_retriever import BaseRetriever
from llama_index.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle, TextNode
from llama_index.utils import print_text
# Default template used to render a query engine's answer as node text;
# it is formatted with the `query_str` and `response` fields.
DEFAULT_QUERY_RESPONSE_TMPL = (
    "Query: {query_str}\n"
    "Response: {response}"
)
# The kinds of object a query id may resolve to: a retriever, a query
# engine, or a plain node.
RQN_TYPE = Union[
    BaseRetriever,
    BaseQueryEngine,
    BaseNode,
]
class RecursiveRetriever(BaseRetriever):
    """Recursive retriever.

    This retriever will recursively explore links from nodes to other
    retrievers/query engines.

    For any retrieved nodes, if any of the nodes are IndexNodes,
    then it will explore the linked retriever/query engine, and query that.

    Args:
        root_id (str): The root id of the query graph. It must be a key of
            `retriever_dict`.
        retriever_dict (Dict[str, BaseRetriever]): A dictionary of id to
            retrievers.
        query_engine_dict (Optional[Dict[str, BaseQueryEngine]]): A dictionary
            of id to query engines. Its ids must not overlap with those of
            `retriever_dict`.
        node_dict (Optional[Dict[str, BaseNode]]): A dictionary of id to
            nodes. It is consulted before the retriever and query engine
            dictionaries when an id is resolved.
        callback_manager (Optional[CallbackManager]): Forwarded to the base
            retriever; a RETRIEVE event is emitted around each sub-retriever
            call.
        query_response_tmpl (Optional[str]): Template formatted with
            `query_str` and `response` to build the text of the node created
            from a query engine response. Defaults to
            `DEFAULT_QUERY_RESPONSE_TMPL`.
        verbose (bool): Whether to print a trace of the traversal.

    Raises:
        ValueError: If `root_id` is not in `retriever_dict`, or if the
            retriever and query engine ids overlap.
    """

    def __init__(
        self,
        root_id: str,
        retriever_dict: Dict[str, BaseRetriever],
        query_engine_dict: Optional[Dict[str, BaseQueryEngine]] = None,
        node_dict: Optional[Dict[str, BaseNode]] = None,
        callback_manager: Optional[CallbackManager] = None,
        query_response_tmpl: Optional[str] = None,
        verbose: bool = False,
    ) -> None:
        """Init params."""
        self._root_id = root_id
        if root_id not in retriever_dict:
            raise ValueError(
                f"Root id {root_id} not in retriever_dict, it must be a retriever."
            )
        self._retriever_dict = retriever_dict
        self._query_engine_dict = query_engine_dict or {}
        self._node_dict = node_dict or {}

        # make sure keys don't overlap
        if set(self._retriever_dict.keys()) & set(self._query_engine_dict.keys()):
            raise ValueError("Retriever and query engine ids must not overlap.")

        self._query_response_tmpl = query_response_tmpl or DEFAULT_QUERY_RESPONSE_TMPL
        super().__init__(callback_manager, verbose=verbose)

    def _query_retrieved_nodes(
        self, query_bundle: QueryBundle, nodes_with_score: List[NodeWithScore]
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
        """Query for retrieved nodes.

        If node is an IndexNode, then recursively query the retriever/query engine.
        If node is a TextNode, then simply return the node.

        Args:
            query_bundle (QueryBundle): The query forwarded to linked objects.
            nodes_with_score (List[NodeWithScore]): Nodes returned by a
                retriever.

        Returns:
            A tuple `(nodes_to_add, additional_nodes)`: the nodes produced by
            the traversal, and the extra source nodes collected from the
            recursive calls.
        """
        nodes_to_add: List[NodeWithScore] = []
        additional_nodes: List[NodeWithScore] = []
        visited_ids = set()

        # dedup index nodes that reference same index id; only the first
        # occurrence of each index id is kept, other nodes pass through.
        new_nodes_with_score = []
        for node_with_score in nodes_with_score:
            node = node_with_score.node
            if isinstance(node, IndexNode):
                if node.index_id not in visited_ids:
                    visited_ids.add(node.index_id)
                    new_nodes_with_score.append(node_with_score)
            else:
                new_nodes_with_score.append(node_with_score)

        nodes_with_score = new_nodes_with_score

        # recursively retrieve
        for node_with_score in nodes_with_score:
            node = node_with_score.node
            if isinstance(node, IndexNode):
                if self._verbose:
                    print_text(
                        f"Retrieved node with id, entering: {node.index_id}\n",
                        color="pink",
                    )
                # The index node's own score is passed down as the similarity
                # for whatever the linked object produces.
                cur_retrieved_nodes, cur_additional_nodes = self._retrieve_rec(
                    query_bundle,
                    query_id=node.index_id,
                    cur_similarity=node_with_score.score,
                )
            else:
                assert isinstance(node, TextNode)
                if self._verbose:
                    print_text(
                        f"Retrieving text node: {node.get_content()}\n",
                        color="pink",
                    )
                cur_retrieved_nodes = [node_with_score]
                cur_additional_nodes = []
            nodes_to_add.extend(cur_retrieved_nodes)
            additional_nodes.extend(cur_additional_nodes)

        return nodes_to_add, additional_nodes

    def _get_object(self, query_id: str) -> RQN_TYPE:
        """Fetch retriever or query engine.

        The id is looked up in the node dictionary first, then the retriever
        dictionary, then the query engine dictionary.

        Raises:
            ValueError: If `query_id` is not found in any of them.
        """
        node = self._node_dict.get(query_id, None)
        if node is not None:
            return node
        retriever = self._retriever_dict.get(query_id, None)
        if retriever is not None:
            return retriever
        query_engine = self._query_engine_dict.get(query_id, None)
        if query_engine is not None:
            return query_engine
        raise ValueError(
            f"Query id {query_id} not found in either `retriever_dict` "
            "or `query_engine_dict`."
        )

    def _retrieve_rec(
        self,
        query_bundle: QueryBundle,
        query_id: Optional[str] = None,
        cur_similarity: Optional[float] = None,
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
        """Query recursively.

        Args:
            query_bundle (QueryBundle): The query to run.
            query_id (Optional[str]): Id of the object to query. Defaults to
                the root id.
            cur_similarity (Optional[float]): Score assigned to nodes created
                at this level. Defaults to 1.0 when not given.

        Returns:
            A tuple `(nodes_to_add, additional_nodes)`.

        Raises:
            ValueError: If the id is unknown or resolves to an unsupported
                object.
        """
        # Resolve the effective id before logging so the trace shows the root
        # id instead of "None" for the top-level call.
        query_id = query_id or self._root_id
        if self._verbose:
            print_text(
                f"Retrieving with query id {query_id}: {query_bundle.query_str}\n",
                color="blue",
            )
        # Only default when no score was supplied: `cur_similarity or 1.0`
        # would silently turn a legitimate 0.0 score into 1.0.
        if cur_similarity is None:
            cur_similarity = 1.0

        obj = self._get_object(query_id)
        if isinstance(obj, BaseNode):
            nodes_to_add = [NodeWithScore(node=obj, score=cur_similarity)]
            additional_nodes: List[NodeWithScore] = []
        elif isinstance(obj, BaseRetriever):
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as event:
                nodes = obj.retrieve(query_bundle)
                event.on_end(payload={EventPayload.NODES: nodes})

            nodes_to_add, additional_nodes = self._query_retrieved_nodes(
                query_bundle, nodes
            )
        elif isinstance(obj, BaseQueryEngine):
            sub_resp = obj.query(query_bundle)
            if self._verbose:
                print_text(
                    f"Got response: {sub_resp!s}\n",
                    color="green",
                )
            # format with both the query and the response
            node_text = self._query_response_tmpl.format(
                query_str=query_bundle.query_str, response=str(sub_resp)
            )
            node = TextNode(text=node_text)
            nodes_to_add = [NodeWithScore(node=node, score=cur_similarity)]
            additional_nodes = sub_resp.source_nodes
        else:
            raise ValueError("Must be a retriever or query engine.")

        return nodes_to_add, additional_nodes

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes starting from the root, dropping additional sources."""
        retrieved_nodes, _ = self._retrieve_rec(query_bundle, query_id=None)
        return retrieved_nodes

    def retrieve_all(
        self, query_bundle: QueryBundle
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
        """Retrieve all nodes.

        Unlike default `retrieve` method, this also fetches additional sources.
        """
        return self._retrieve_rec(query_bundle, query_id=None)