"""Leaf query mechanism.""" import logging from typing import Any, Dict, List, Optional, cast from llama_index.callbacks.base import CallbackManager from llama_index.core.base_retriever import BaseRetriever from llama_index.core.response.schema import Response from llama_index.indices.query.schema import QueryBundle from llama_index.indices.tree.base import TreeIndex from llama_index.indices.tree.utils import get_numbered_text_from_nodes from llama_index.indices.utils import ( extract_numbers_given_response, get_sorted_node_list, ) from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompt_selectors import DEFAULT_REFINE_PROMPT_SEL from llama_index.prompts.default_prompts import ( DEFAULT_QUERY_PROMPT, DEFAULT_QUERY_PROMPT_MULTIPLE, DEFAULT_TEXT_QA_PROMPT, ) from llama_index.response_synthesizers import get_response_synthesizer from llama_index.schema import BaseNode, MetadataMode, NodeWithScore, QueryBundle from llama_index.utils import print_text, truncate_text logger = logging.getLogger(__name__) def get_text_from_node( node: BaseNode, level: Optional[int] = None, verbose: bool = False, ) -> str: """Get text from node.""" level_str = "" if level is None else f"[Level {level}]" fmt_text_chunk = truncate_text(node.get_content(metadata_mode=MetadataMode.LLM), 50) logger.debug(f">{level_str} Searching in chunk: {fmt_text_chunk}") response_txt = node.get_content(metadata_mode=MetadataMode.LLM) fmt_response = truncate_text(response_txt, 200) if verbose: print_text(f">{level_str} Got node text: {fmt_response}\n", color="blue") return response_txt class TreeSelectLeafRetriever(BaseRetriever): """Tree select leaf retriever. This class traverses the index graph and searches for a leaf node that can best answer the query. Args: query_template (Optional[BasePromptTemplate]): Tree Select Query Prompt (see :ref:`Prompt-Templates`). query_template_multiple (Optional[BasePromptTemplate]): Tree Select Query Prompt (Multiple) (see :ref:`Prompt-Templates`). child_branch_factor (int): Number of child nodes to consider at each level. If child_branch_factor is 1, then the query will only choose one child node to traverse for any given parent node. If child_branch_factor is 2, then the query will choose two child nodes. """ def __init__( self, index: TreeIndex, query_template: Optional[BasePromptTemplate] = None, text_qa_template: Optional[BasePromptTemplate] = None, refine_template: Optional[BasePromptTemplate] = None, query_template_multiple: Optional[BasePromptTemplate] = None, child_branch_factor: int = 1, verbose: bool = False, callback_manager: Optional[CallbackManager] = None, object_map: Optional[dict] = None, **kwargs: Any, ): self._index = index self._index_struct = index.index_struct self._docstore = index.docstore self._service_context = index.service_context self._text_qa_template = text_qa_template or DEFAULT_TEXT_QA_PROMPT self._refine_template = refine_template or DEFAULT_REFINE_PROMPT_SEL self.query_template = query_template or DEFAULT_QUERY_PROMPT self.query_template_multiple = ( query_template_multiple or DEFAULT_QUERY_PROMPT_MULTIPLE ) self.child_branch_factor = child_branch_factor super().__init__( callback_manager=callback_manager, object_map=object_map, verbose=verbose ) def _query_with_selected_node( self, selected_node: BaseNode, query_bundle: QueryBundle, prev_response: Optional[str] = None, level: int = 0, ) -> str: """Get response for selected node. If not leaf node, it will recursively call _query on the child nodes. If prev_response is provided, we will update prev_response with the answer. """ query_str = query_bundle.query_str if len(self._index_struct.get_children(selected_node)) == 0: response_builder = get_response_synthesizer( service_context=self._service_context, text_qa_template=self._text_qa_template, refine_template=self._refine_template, ) # use response builder to get answer from node node_text = get_text_from_node(selected_node, level=level) cur_response = response_builder.get_response( query_str, [node_text], prev_response=prev_response ) cur_response = cast(str, cur_response) logger.debug(f">[Level {level}] Current answer response: {cur_response} ") else: cur_response = self._query_level( self._index_struct.get_children(selected_node), query_bundle, level=level + 1, ) if prev_response is None: return cur_response else: context_msg = selected_node.get_content(metadata_mode=MetadataMode.LLM) cur_response = self._service_context.llm.predict( self._refine_template, query_str=query_str, existing_answer=prev_response, context_msg=context_msg, ) logger.debug(f">[Level {level}] Current refined response: {cur_response} ") return cur_response def _query_level( self, cur_node_ids: Dict[int, str], query_bundle: QueryBundle, level: int = 0, ) -> str: """Answer a query recursively.""" query_str = query_bundle.query_str cur_nodes = { index: self._docstore.get_node(node_id) for index, node_id in cur_node_ids.items() } cur_node_list = get_sorted_node_list(cur_nodes) if len(cur_node_list) == 1: logger.debug(f">[Level {level}] Only one node left. Querying node.") return self._query_with_selected_node( cur_node_list[0], query_bundle, level=level ) elif self.child_branch_factor == 1: query_template = self.query_template.partial_format( num_chunks=len(cur_node_list), query_str=query_str ) text_splitter = ( self._service_context.prompt_helper.get_text_splitter_given_prompt( prompt=query_template, num_chunks=len(cur_node_list), ) ) numbered_node_text = get_numbered_text_from_nodes( cur_node_list, text_splitter=text_splitter ) response = self._service_context.llm.predict( query_template, context_list=numbered_node_text, ) else: query_template_multiple = self.query_template_multiple.partial_format( num_chunks=len(cur_node_list), query_str=query_str, branching_factor=self.child_branch_factor, ) text_splitter = ( self._service_context.prompt_helper.get_text_splitter_given_prompt( prompt=query_template_multiple, num_chunks=len(cur_node_list), ) ) numbered_node_text = get_numbered_text_from_nodes( cur_node_list, text_splitter=text_splitter ) response = self._service_context.llm.predict( query_template_multiple, context_list=numbered_node_text, ) debug_str = f">[Level {level}] Current response: {response}" logger.debug(debug_str) if self._verbose: print_text(debug_str, end="\n") numbers = extract_numbers_given_response(response, n=self.child_branch_factor) if numbers is None: debug_str = ( f">[Level {level}] Could not retrieve response - no numbers present" ) logger.debug(debug_str) if self._verbose: print_text(debug_str, end="\n") # just join text from current nodes as response return response result_response = None for number_str in numbers: number = int(number_str) if number > len(cur_node_list): logger.debug( f">[Level {level}] Invalid response: {response} - " f"number {number} out of range" ) return response # number is 1-indexed, so subtract 1 selected_node = cur_node_list[number - 1] info_str = ( f">[Level {level}] Selected node: " f"[{number}]/[{','.join([str(int(n)) for n in numbers])}]" ) logger.info(info_str) if self._verbose: print_text(info_str, end="\n") debug_str = " ".join( selected_node.get_content(metadata_mode=MetadataMode.LLM).splitlines() ) full_debug_str = ( f">[Level {level}] Node " f"[{number}] Summary text: " f"{ selected_node.get_content(metadata_mode=MetadataMode.LLM) }" ) logger.debug(full_debug_str) if self._verbose: print_text(full_debug_str, end="\n") result_response = self._query_with_selected_node( selected_node, query_bundle, prev_response=result_response, level=level, ) # result_response should not be None return cast(str, result_response) def _query(self, query_bundle: QueryBundle) -> Response: """Answer a query.""" # NOTE: this overrides the _query method in the base class info_str = f"> Starting query: {query_bundle.query_str}" logger.info(info_str) if self._verbose: print_text(info_str, end="\n") response_str = self._query_level( self._index_struct.root_nodes, query_bundle, level=0, ).strip() # TODO: fix source nodes return Response(response_str, source_nodes=[]) def _select_nodes( self, cur_node_list: List[BaseNode], query_bundle: QueryBundle, level: int = 0, ) -> List[BaseNode]: query_str = query_bundle.query_str if self.child_branch_factor == 1: query_template = self.query_template.partial_format( num_chunks=len(cur_node_list), query_str=query_str ) text_splitter = ( self._service_context.prompt_helper.get_text_splitter_given_prompt( prompt=query_template, num_chunks=len(cur_node_list), ) ) numbered_node_text = get_numbered_text_from_nodes( cur_node_list, text_splitter=text_splitter ) response = self._service_context.llm.predict( query_template, context_list=numbered_node_text, ) else: query_template_multiple = self.query_template_multiple.partial_format( num_chunks=len(cur_node_list), query_str=query_str, branching_factor=self.child_branch_factor, ) text_splitter = ( self._service_context.prompt_helper.get_text_splitter_given_prompt( prompt=query_template_multiple, num_chunks=len(cur_node_list), ) ) numbered_node_text = get_numbered_text_from_nodes( cur_node_list, text_splitter=text_splitter ) response = self._service_context.llm.predict( query_template_multiple, context_list=numbered_node_text, ) debug_str = f">[Level {level}] Current response: {response}" logger.debug(debug_str) if self._verbose: print_text(debug_str, end="\n") numbers = extract_numbers_given_response(response, n=self.child_branch_factor) if numbers is None: debug_str = ( f">[Level {level}] Could not retrieve response - no numbers present" ) logger.debug(debug_str) if self._verbose: print_text(debug_str, end="\n") # just join text from current nodes as response return [] selected_nodes = [] for number_str in numbers: number = int(number_str) if number > len(cur_node_list): logger.debug( f">[Level {level}] Invalid response: {response} - " f"number {number} out of range" ) continue # number is 1-indexed, so subtract 1 selected_node = cur_node_list[number - 1] info_str = ( f">[Level {level}] Selected node: " f"[{number}]/[{','.join([str(int(n)) for n in numbers])}]" ) logger.info(info_str) if self._verbose: print_text(info_str, end="\n") debug_str = " ".join( selected_node.get_content(metadata_mode=MetadataMode.LLM).splitlines() ) full_debug_str = ( f">[Level {level}] Node " f"[{number}] Summary text: " f"{ selected_node.get_content(metadata_mode=MetadataMode.LLM) }" ) logger.debug(full_debug_str) if self._verbose: print_text(full_debug_str, end="\n") selected_nodes.append(selected_node) return selected_nodes def _retrieve_level( self, cur_node_ids: Dict[int, str], query_bundle: QueryBundle, level: int = 0, ) -> List[BaseNode]: """Answer a query recursively.""" cur_nodes = { index: self._docstore.get_node(node_id) for index, node_id in cur_node_ids.items() } cur_node_list = get_sorted_node_list(cur_nodes) if len(cur_node_list) > self.child_branch_factor: selected_nodes = self._select_nodes( cur_node_list, query_bundle, level=level, ) else: selected_nodes = cur_node_list children_nodes = {} for node in selected_nodes: node_dict = self._index_struct.get_children(node) children_nodes.update(node_dict) if len(children_nodes) == 0: # NOTE: leaf level return selected_nodes else: return self._retrieve_level(children_nodes, query_bundle, level + 1) def _retrieve( self, query_bundle: QueryBundle, ) -> List[NodeWithScore]: """Get nodes for response.""" nodes = self._retrieve_level( self._index_struct.root_nodes, query_bundle, level=0, ) return [NodeWithScore(node=node) for node in nodes]