"""Retrieve query."""
|
|
import logging
|
|
from typing import Any, List, Optional
|
|
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.core.base_retriever import BaseRetriever
|
|
from llama_index.indices.query.schema import QueryBundle
|
|
from llama_index.indices.tree.base import TreeIndex
|
|
from llama_index.indices.utils import get_sorted_node_list
|
|
from llama_index.schema import NodeWithScore, QueryBundle
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TreeRootRetriever(BaseRetriever):
    """Tree root retriever.

    This class directly retrieves the answer from the root nodes.

    Unlike GPTTreeIndexLeafQuery, this class assumes the graph already stores
    the answer (because it was constructed with a query_str), so it does not
    attempt to parse information down the graph in order to synthesize an answer.
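
    Example (illustrative sketch, not part of the original module; assumes a
    ``TreeIndex`` whose root nodes already contain synthesized answers)::

        retriever = TreeRootRetriever(index=tree_index)
        nodes = retriever.retrieve("What did the author do?")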
    """

    def __init__(
        self,
        index: TreeIndex,
        callback_manager: Optional[CallbackManager] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
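        # Keep direct references to the wrapped index, its index struct, and its docstore.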
        self._index = index
        self._index_struct = index.index_struct
        self._docstore = index.docstore
        super().__init__(
            callback_manager=callback_manager, object_map=object_map, verbose=verbose
        )

    def _retrieve(
        self,
        query_bundle: QueryBundle,
    ) -> List[NodeWithScore]:
        """Get nodes for response."""
        logger.info(f"> Starting query: {query_bundle.query_str}")
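        # The root nodes are assumed to already hold the answer text, so they
        # are fetched directly from the docstore rather than traversed.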
        root_nodes = self._docstore.get_node_dict(self._index_struct.root_nodes)
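        # Order the roots by their index position within the tree structure.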
        sorted_nodes = get_sorted_node_list(root_nodes)
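        # No similarity scoring happens here; nodes are returned without scores.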
        return [NodeWithScore(node=node) for node in sorted_nodes]