# faiss_rag_enterprise/llama_index/indices/document_summary/retrievers.py
"""Document summary retrievers.
This module contains retrievers for document summary indices.
"""
import logging
from typing import Any, Callable, List, Optional
from llama_index.callbacks.base import CallbackManager
from llama_index.core.base_retriever import BaseRetriever
from llama_index.indices.document_summary.base import DocumentSummaryIndex
from llama_index.indices.utils import (
default_format_node_batch_fn,
default_parse_choice_select_answer_fn,
)
from llama_index.prompts import BasePromptTemplate
from llama_index.prompts.default_prompts import DEFAULT_CHOICE_SELECT_PROMPT
from llama_index.schema import NodeWithScore, QueryBundle
from llama_index.service_context import ServiceContext
from llama_index.vector_stores.types import VectorStoreQuery
logger = logging.getLogger(__name__)
class DocumentSummaryIndexLLMRetriever(BaseRetriever):
    """Document Summary Index LLM Retriever.

    Selects relevant summaries by sending batches of summary nodes to an LLM,
    parsing the LLM's choice/relevance answer, and returning the content nodes
    backing the top-k chosen summaries.

    Args:
        index (DocumentSummaryIndex): The index to retrieve from.
        choice_select_prompt (Optional[BasePromptTemplate]): The prompt to use
            for selecting relevant summaries. Defaults to
            ``DEFAULT_CHOICE_SELECT_PROMPT``.
        choice_batch_size (int): The number of summary nodes to send to the
            LLM at a time.
        choice_top_k (int): The number of summary nodes to retrieve.
        format_node_batch_fn (Callable): Function to format a batch of nodes
            for the LLM.
        parse_choice_select_answer_fn (Callable): Function to parse the LLM
            response into 1-based choice indices and relevance scores.
        service_context (ServiceContext): The service context to use.
            Defaults to the index's own service context.
    """

    def __init__(
        self,
        index: DocumentSummaryIndex,
        choice_select_prompt: Optional[BasePromptTemplate] = None,
        choice_batch_size: int = 10,
        choice_top_k: int = 1,
        format_node_batch_fn: Optional[Callable] = None,
        parse_choice_select_answer_fn: Optional[Callable] = None,
        service_context: Optional[ServiceContext] = None,
        callback_manager: Optional[CallbackManager] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        self._index = index
        self._choice_select_prompt = (
            choice_select_prompt or DEFAULT_CHOICE_SELECT_PROMPT
        )
        self._choice_batch_size = choice_batch_size
        self._choice_top_k = choice_top_k
        self._format_node_batch_fn = (
            format_node_batch_fn or default_format_node_batch_fn
        )
        self._parse_choice_select_answer_fn = (
            parse_choice_select_answer_fn or default_parse_choice_select_answer_fn
        )
        self._service_context = service_context or index.service_context
        super().__init__(
            callback_manager=callback_manager, object_map=object_map, verbose=verbose
        )

    def _retrieve(
        self,
        query_bundle: QueryBundle,
    ) -> List[NodeWithScore]:
        """Retrieve nodes.

        Scores every summary via batched LLM calls, keeps the top-k summaries
        by relevance, and returns all content nodes mapped to those summaries
        (each carrying its summary's relevance score).
        """
        summary_ids = self._index.index_struct.summary_ids
        all_summary_ids: List[str] = []
        all_relevances: List[float] = []
        for idx in range(0, len(summary_ids), self._choice_batch_size):
            summary_ids_batch = summary_ids[idx : idx + self._choice_batch_size]
            summary_nodes = self._index.docstore.get_nodes(summary_ids_batch)
            query_str = query_bundle.query_str
            fmt_batch_str = self._format_node_batch_fn(summary_nodes)
            # call each batch independently
            raw_response = self._service_context.llm.predict(
                self._choice_select_prompt,
                context_str=fmt_batch_str,
                query_str=query_str,
            )
            raw_choices, relevances = self._parse_choice_select_answer_fn(
                raw_response, len(summary_nodes)
            )
            # The LLM answers with 1-based indices. Validate each choice
            # against the batch bounds before using it: previously a
            # hallucinated choice of 0 became index -1 and silently selected
            # the *last* summary of the batch, and a too-large choice raised
            # IndexError. Iterate choices and relevances together so the
            # (summary_id, relevance) pairing stays aligned when we skip.
            for choice, relevance in zip(raw_choices, relevances):
                zero_based = choice - 1
                if 0 <= zero_based < len(summary_ids_batch):
                    all_summary_ids.append(summary_ids_batch[zero_based])
                    all_relevances.append(relevance)
                else:
                    logger.warning(
                        "Ignoring out-of-range choice %s from LLM response.",
                        choice,
                    )
        zipped_list = list(zip(all_summary_ids, all_relevances))
        sorted_list = sorted(zipped_list, key=lambda x: x[1], reverse=True)
        top_k_list = sorted_list[: self._choice_top_k]
        results = []
        for summary_id, relevance in top_k_list:
            node_ids = self._index.index_struct.summary_id_to_node_ids[summary_id]
            nodes = self._index.docstore.get_nodes(node_ids)
            results.extend([NodeWithScore(node=n, score=relevance) for n in nodes])
        return results
class DocumentSummaryIndexEmbeddingRetriever(BaseRetriever):
    """Document Summary Index Embedding Retriever.

    Finds the summaries most similar to the query embedding, then returns
    the content nodes backing each retrieved summary.

    Args:
        index (DocumentSummaryIndex): The index to retrieve from.
        similarity_top_k (int): The number of summary nodes to retrieve.
    """

    def __init__(
        self,
        index: DocumentSummaryIndex,
        similarity_top_k: int = 1,
        callback_manager: Optional[CallbackManager] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        self._index = index
        # Cache frequently used collaborators off the index.
        self._vector_store = index.vector_store
        self._service_context = index.service_context
        self._docstore = index.docstore
        self._index_struct = index.index_struct
        self._similarity_top_k = similarity_top_k
        super().__init__(
            callback_manager=callback_manager, object_map=object_map, verbose=verbose
        )

    def _retrieve(
        self,
        query_bundle: QueryBundle,
    ) -> List[NodeWithScore]:
        """Retrieve nodes."""
        # Compute the query embedding lazily, only when the store needs one
        # and the bundle does not already carry it.
        if self._vector_store.is_embedding_query and query_bundle.embedding is None:
            embed_model = self._service_context.embed_model
            query_bundle.embedding = embed_model.get_agg_embedding_from_queries(
                query_bundle.embedding_strs
            )

        result = self._vector_store.query(
            VectorStoreQuery(
                query_embedding=query_bundle.embedding,
                similarity_top_k=self._similarity_top_k,
            )
        )

        # Resolve summary ids from whichever field the store populated;
        # ids take precedence over nodes.
        summary_ids: List[str]
        if result.ids is not None:
            summary_ids = result.ids
        elif result.nodes is not None:
            summary_ids = [node.node_id for node in result.nodes]
        else:
            raise ValueError(
                "Vector store query result should return "
                "at least one of nodes or ids."
            )

        retrieved: List[NodeWithScore] = []
        for summary_id in summary_ids:
            content_node_ids = self._index_struct.summary_id_to_node_ids[summary_id]
            for node in self._docstore.get_nodes(content_node_ids):
                retrieved.append(NodeWithScore(node=node))
        return retrieved
# legacy, backward compatibility
# Alias kept so existing imports of ``DocumentSummaryIndexRetriever`` continue
# to resolve to the LLM-based retriever.
DocumentSummaryIndexRetriever = DocumentSummaryIndexLLMRetriever