"""Document summary retrievers. This module contains retrievers for document summary indices. """ import logging from typing import Any, Callable, List, Optional from llama_index.callbacks.base import CallbackManager from llama_index.core.base_retriever import BaseRetriever from llama_index.indices.document_summary.base import DocumentSummaryIndex from llama_index.indices.utils import ( default_format_node_batch_fn, default_parse_choice_select_answer_fn, ) from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_CHOICE_SELECT_PROMPT from llama_index.schema import NodeWithScore, QueryBundle from llama_index.service_context import ServiceContext from llama_index.vector_stores.types import VectorStoreQuery logger = logging.getLogger(__name__) class DocumentSummaryIndexLLMRetriever(BaseRetriever): """Document Summary Index LLM Retriever. By default, select relevant summaries from index using LLM calls. Args: index (DocumentSummaryIndex): The index to retrieve from. choice_select_prompt (Optional[BasePromptTemplate]): The prompt to use for selecting relevant summaries. choice_batch_size (int): The number of summary nodes to send to LLM at a time. choice_top_k (int): The number of summary nodes to retrieve. format_node_batch_fn (Callable): Function to format a batch of nodes for LLM. parse_choice_select_answer_fn (Callable): Function to parse LLM response. service_context (ServiceContext): The service context to use. """ def __init__( self, index: DocumentSummaryIndex, choice_select_prompt: Optional[BasePromptTemplate] = None, choice_batch_size: int = 10, choice_top_k: int = 1, format_node_batch_fn: Optional[Callable] = None, parse_choice_select_answer_fn: Optional[Callable] = None, service_context: Optional[ServiceContext] = None, callback_manager: Optional[CallbackManager] = None, object_map: Optional[dict] = None, verbose: bool = False, **kwargs: Any, ) -> None: self._index = index self._choice_select_prompt = ( choice_select_prompt or DEFAULT_CHOICE_SELECT_PROMPT ) self._choice_batch_size = choice_batch_size self._choice_top_k = choice_top_k self._format_node_batch_fn = ( format_node_batch_fn or default_format_node_batch_fn ) self._parse_choice_select_answer_fn = ( parse_choice_select_answer_fn or default_parse_choice_select_answer_fn ) self._service_context = service_context or index.service_context super().__init__( callback_manager=callback_manager, object_map=object_map, verbose=verbose ) def _retrieve( self, query_bundle: QueryBundle, ) -> List[NodeWithScore]: """Retrieve nodes.""" summary_ids = self._index.index_struct.summary_ids all_summary_ids: List[str] = [] all_relevances: List[float] = [] for idx in range(0, len(summary_ids), self._choice_batch_size): summary_ids_batch = summary_ids[idx : idx + self._choice_batch_size] summary_nodes = self._index.docstore.get_nodes(summary_ids_batch) query_str = query_bundle.query_str fmt_batch_str = self._format_node_batch_fn(summary_nodes) # call each batch independently raw_response = self._service_context.llm.predict( self._choice_select_prompt, context_str=fmt_batch_str, query_str=query_str, ) raw_choices, relevances = self._parse_choice_select_answer_fn( raw_response, len(summary_nodes) ) choice_idxs = [choice - 1 for choice in raw_choices] choice_summary_ids = [summary_ids_batch[ci] for ci in choice_idxs] all_summary_ids.extend(choice_summary_ids) all_relevances.extend(relevances) zipped_list = list(zip(all_summary_ids, all_relevances)) sorted_list = sorted(zipped_list, key=lambda x: x[1], reverse=True) top_k_list = sorted_list[: self._choice_top_k] results = [] for summary_id, relevance in top_k_list: node_ids = self._index.index_struct.summary_id_to_node_ids[summary_id] nodes = self._index.docstore.get_nodes(node_ids) results.extend([NodeWithScore(node=n, score=relevance) for n in nodes]) return results class DocumentSummaryIndexEmbeddingRetriever(BaseRetriever): """Document Summary Index Embedding Retriever. Args: index (DocumentSummaryIndex): The index to retrieve from. similarity_top_k (int): The number of summary nodes to retrieve. """ def __init__( self, index: DocumentSummaryIndex, similarity_top_k: int = 1, callback_manager: Optional[CallbackManager] = None, object_map: Optional[dict] = None, verbose: bool = False, **kwargs: Any, ) -> None: """Init params.""" self._index = index self._vector_store = self._index.vector_store self._service_context = self._index.service_context self._docstore = self._index.docstore self._index_struct = self._index.index_struct self._similarity_top_k = similarity_top_k super().__init__( callback_manager=callback_manager, object_map=object_map, verbose=verbose ) def _retrieve( self, query_bundle: QueryBundle, ) -> List[NodeWithScore]: """Retrieve nodes.""" if self._vector_store.is_embedding_query: if query_bundle.embedding is None: query_bundle.embedding = ( self._service_context.embed_model.get_agg_embedding_from_queries( query_bundle.embedding_strs ) ) query = VectorStoreQuery( query_embedding=query_bundle.embedding, similarity_top_k=self._similarity_top_k, ) query_result = self._vector_store.query(query) top_k_summary_ids: List[str] if query_result.ids is not None: top_k_summary_ids = query_result.ids elif query_result.nodes is not None: top_k_summary_ids = [n.node_id for n in query_result.nodes] else: raise ValueError( "Vector store query result should return " "at least one of nodes or ids." ) results = [] for summary_id in top_k_summary_ids: node_ids = self._index_struct.summary_id_to_node_ids[summary_id] nodes = self._docstore.get_nodes(node_ids) results.extend([NodeWithScore(node=n) for n in nodes]) return results # legacy, backward compatibility DocumentSummaryIndexRetriever = DocumentSummaryIndexLLMRetriever