"""Retrievers for SummaryIndex."""
|
|
import logging
|
|
from typing import Any, Callable, List, Optional, Tuple
|
|
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.core.base_retriever import BaseRetriever
|
|
from llama_index.indices.list.base import SummaryIndex
|
|
from llama_index.indices.query.embedding_utils import get_top_k_embeddings
|
|
from llama_index.indices.utils import (
|
|
default_format_node_batch_fn,
|
|
default_parse_choice_select_answer_fn,
|
|
)
|
|
from llama_index.prompts import PromptTemplate
|
|
from llama_index.prompts.default_prompts import (
|
|
DEFAULT_CHOICE_SELECT_PROMPT,
|
|
)
|
|
from llama_index.schema import BaseNode, MetadataMode, NodeWithScore, QueryBundle
|
|
from llama_index.service_context import ServiceContext
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|


class SummaryIndexRetriever(BaseRetriever):
    """Simple retriever for SummaryIndex that returns all nodes.

    Args:
        index (SummaryIndex): The index to retrieve from.

    """

    def __init__(
        self,
        index: SummaryIndex,
        callback_manager: Optional[CallbackManager] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        self._index = index
        super().__init__(
            callback_manager=callback_manager, object_map=object_map, verbose=verbose
        )

    def _retrieve(
        self,
        query_bundle: QueryBundle,
    ) -> List[NodeWithScore]:
        """Retrieve nodes."""
        # The query is intentionally ignored: this retriever always returns
        # every node stored in the index.
        del query_bundle

        node_ids = self._index.index_struct.nodes
        nodes = self._index.docstore.get_nodes(node_ids)
        return [NodeWithScore(node=node) for node in nodes]
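
# A minimal usage sketch (hypothetical: assumes `documents` is a list of
# llama_index `Document` objects loaded elsewhere):
#
#     index = SummaryIndex.from_documents(documents)
#     retriever = SummaryIndexRetriever(index)
#     nodes = retriever.retrieve("any query")  # query text is ignored;
#                                              # every node comes back unscored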


class SummaryIndexEmbeddingRetriever(BaseRetriever):
    """Embedding-based retriever for SummaryIndex.

    Embeddings are generated lazily for nodes as they are traversed.

    Args:
        index (SummaryIndex): The index to retrieve from.
        similarity_top_k (Optional[int]): The number of top nodes to return.

    """

    def __init__(
        self,
        index: SummaryIndex,
        similarity_top_k: Optional[int] = 1,
        callback_manager: Optional[CallbackManager] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        self._index = index
        self._similarity_top_k = similarity_top_k
        super().__init__(
            callback_manager=callback_manager, object_map=object_map, verbose=verbose
        )

    def _retrieve(
        self,
        query_bundle: QueryBundle,
    ) -> List[NodeWithScore]:
        """Retrieve nodes."""
        node_ids = self._index.index_struct.nodes
        nodes = self._index.docstore.get_nodes(node_ids)
        query_embedding, node_embeddings = self._get_embeddings(query_bundle, nodes)

        # Rank all nodes by similarity to the query and keep the top k.
        top_similarities, top_idxs = get_top_k_embeddings(
            query_embedding,
            node_embeddings,
            similarity_top_k=self._similarity_top_k,
            embedding_ids=list(range(len(nodes))),
        )

        top_k_nodes = [nodes[i] for i in top_idxs]

        node_with_scores = []
        for node, similarity in zip(top_k_nodes, top_similarities):
            node_with_scores.append(NodeWithScore(node=node, score=similarity))

        logger.debug(f"> Top {len(top_idxs)} nodes:\n")
        logger.debug("\n".join(n.get_content() for n in top_k_nodes))
        return node_with_scores

    def _get_embeddings(
        self, query_bundle: QueryBundle, nodes: List[BaseNode]
    ) -> Tuple[List[float], List[List[float]]]:
        """Get the query embedding and embeddings for all nodes.

        Node embeddings are computed lazily: nodes that already carry an
        embedding are left untouched.
        """
        if query_bundle.embedding is None:
            query_bundle.embedding = (
                self._index.service_context.embed_model.get_agg_embedding_from_queries(
                    query_bundle.embedding_strs
                )
            )

        node_embeddings: List[List[float]] = []
        nodes_embedded = 0
        for node in nodes:
            if node.embedding is None:
                nodes_embedded += 1
                node.embedding = (
                    self._index.service_context.embed_model.get_text_embedding(
                        node.get_content(metadata_mode=MetadataMode.EMBED)
                    )
                )

            node_embeddings.append(node.embedding)
        return query_bundle.embedding, node_embeddings
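
# A minimal usage sketch (hypothetical names; assumes an embedding model is
# configured on the index's service context):
#
#     retriever = SummaryIndexEmbeddingRetriever(index, similarity_top_k=3)
#     nodes = retriever.retrieve("What does the report conclude?")
#     for n in nodes:
#         print(n.score, n.node.get_content()[:80])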


class SummaryIndexLLMRetriever(BaseRetriever):
    """LLM retriever for SummaryIndex.

    Args:
        index (SummaryIndex): The index to retrieve from.
        choice_select_prompt (Optional[PromptTemplate]): A Choice-Select Prompt
            (see :ref:`Prompt-Templates`).
        choice_batch_size (int): The number of nodes to query at a time.
        format_node_batch_fn (Optional[Callable]): A function that formats a
            batch of nodes.
        parse_choice_select_answer_fn (Optional[Callable]): A function that parses the
            choice select answer.
        service_context (Optional[ServiceContext]): A service context.

    """

    def __init__(
        self,
        index: SummaryIndex,
        choice_select_prompt: Optional[PromptTemplate] = None,
        choice_batch_size: int = 10,
        format_node_batch_fn: Optional[Callable] = None,
        parse_choice_select_answer_fn: Optional[Callable] = None,
        service_context: Optional[ServiceContext] = None,
        callback_manager: Optional[CallbackManager] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        self._index = index
        self._choice_select_prompt = (
            choice_select_prompt or DEFAULT_CHOICE_SELECT_PROMPT
        )
        self._choice_batch_size = choice_batch_size
        self._format_node_batch_fn = (
            format_node_batch_fn or default_format_node_batch_fn
        )
        self._parse_choice_select_answer_fn = (
            parse_choice_select_answer_fn or default_parse_choice_select_answer_fn
        )
        self._service_context = service_context or index.service_context
        super().__init__(
            callback_manager=callback_manager, object_map=object_map, verbose=verbose
        )

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes."""
        node_ids = self._index.index_struct.nodes
        query_str = query_bundle.query_str
        results = []
        for idx in range(0, len(node_ids), self._choice_batch_size):
            node_ids_batch = node_ids[idx : idx + self._choice_batch_size]
            nodes_batch = self._index.docstore.get_nodes(node_ids_batch)

            fmt_batch_str = self._format_node_batch_fn(nodes_batch)
            # call each batch independently
            raw_response = self._service_context.llm.predict(
                self._choice_select_prompt,
                context_str=fmt_batch_str,
                query_str=query_str,
            )

            # The LLM answers with 1-based choice numbers (and optional
            # relevance scores); map them back to node ids in this batch.
            raw_choices, relevances = self._parse_choice_select_answer_fn(
                raw_response, len(nodes_batch)
            )
            choice_idxs = [int(choice) - 1 for choice in raw_choices]
            choice_node_ids = [node_ids_batch[idx] for idx in choice_idxs]

            choice_nodes = self._index.docstore.get_nodes(choice_node_ids)
            relevances = relevances or [1.0 for _ in choice_nodes]
            results.extend(
                [
                    NodeWithScore(node=node, score=relevance)
                    for node, relevance in zip(choice_nodes, relevances)
                ]
            )
        return results
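
# A minimal usage sketch (hypothetical names; assumes an LLM is configured on
# the service context — note each batch costs one LLM call):
#
#     retriever = SummaryIndexLLMRetriever(index, choice_batch_size=5)
#     nodes = retriever.retrieve("Which sections discuss pricing?")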


# for backwards compatibility
ListIndexEmbeddingRetriever = SummaryIndexEmbeddingRetriever
ListIndexLLMRetriever = SummaryIndexLLMRetriever
ListIndexRetriever = SummaryIndexRetriever