# faiss_rag_enterprise/llama_index/indices/tree/tree_root_retriever.py
#
# 49 lines
# 1.7 KiB
# Python

"""Retrieve query."""
import logging
from typing import Any, List, Optional
from llama_index.callbacks.base import CallbackManager
from llama_index.core.base_retriever import BaseRetriever
from llama_index.indices.query.schema import QueryBundle
from llama_index.indices.tree.base import TreeIndex
from llama_index.indices.utils import get_sorted_node_list
from llama_index.schema import NodeWithScore, QueryBundle
logger = logging.getLogger(__name__)
class TreeRootRetriever(BaseRetriever):
    """Tree root retriever.

    This class directly retrieves the answer from the root nodes.

    Unlike GPTTreeIndexLeafQuery, this class assumes the graph already stores
    the answer (because it was constructed with a query_str), so it does not
    attempt to parse information down the graph in order to synthesize an answer.
    """

    def __init__(
        self,
        index: TreeIndex,
        callback_manager: Optional[CallbackManager] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create a retriever over the root nodes of ``index``.

        Args:
            index: Tree index whose ``index_struct`` and ``docstore`` are read.
            callback_manager: Forwarded unchanged to ``BaseRetriever.__init__``.
            object_map: Forwarded unchanged to ``BaseRetriever.__init__``.
            verbose: Forwarded unchanged to ``BaseRetriever.__init__``.
            **kwargs: Accepted so callers may pass extra options, but they are
                currently ignored (not forwarded anywhere).
        """
        self._index = index
        self._index_struct = index.index_struct
        self._docstore = index.docstore
        super().__init__(
            callback_manager=callback_manager, object_map=object_map, verbose=verbose
        )

    def _retrieve(
        self,
        query_bundle: QueryBundle,
    ) -> List[NodeWithScore]:
        """Return the index's root nodes wrapped as ``NodeWithScore``.

        The query does not influence which nodes are returned; it is only
        logged. Every root node listed in ``index_struct.root_nodes`` is
        fetched from the docstore and returned.

        Args:
            query_bundle: The incoming query; only ``query_str`` is read, for logging.

        Returns:
            One ``NodeWithScore`` per root node, in the order produced by
            ``get_sorted_node_list`` (presumably ordered by the keys of
            ``root_nodes`` — confirm against that helper). No score is passed.
        """
        # Lazy %-style arguments: the message is only formatted when INFO
        # logging is enabled; the emitted text is identical to the f-string form.
        logger.info("> Starting query: %s", query_bundle.query_str)
        root_nodes = self._docstore.get_node_dict(self._index_struct.root_nodes)
        sorted_nodes = get_sorted_node_list(root_nodes)
        return [NodeWithScore(node=node) for node in sorted_nodes]