"""Document summary index.
|
|
|
|
A data structure where LlamaIndex stores the summary per document, and maps
|
|
the summary to the underlying Nodes.
|
|
This summary can be used for retrieval.
|
|
|
|
"""
|
|
import logging
|
|
from collections import defaultdict
|
|
from enum import Enum
|
|
from typing import Any, Dict, List, Optional, Sequence, Union, cast
|
|
|
|
from llama_index.core.base_retriever import BaseRetriever
|
|
from llama_index.core.response.schema import Response
|
|
from llama_index.data_structs.document_summary import IndexDocumentSummary
|
|
from llama_index.indices.base import BaseIndex
|
|
from llama_index.indices.utils import embed_nodes
|
|
from llama_index.response_synthesizers import (
|
|
BaseSynthesizer,
|
|
ResponseMode,
|
|
get_response_synthesizer,
|
|
)
|
|
from llama_index.schema import (
|
|
BaseNode,
|
|
IndexNode,
|
|
NodeRelationship,
|
|
NodeWithScore,
|
|
RelatedNodeInfo,
|
|
TextNode,
|
|
)
|
|
from llama_index.service_context import ServiceContext
|
|
from llama_index.storage.docstore.types import RefDocInfo
|
|
from llama_index.storage.storage_context import StorageContext
|
|
from llama_index.utils import get_tqdm_iterable
|
|
from llama_index.vector_stores.types import VectorStore
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
DEFAULT_SUMMARY_QUERY = (
|
|
"Describe what the provided text is about. "
|
|
"Also describe some of the questions that this text can answer. "
|
|
)
|
|
|
|
|
|
class DocumentSummaryRetrieverMode(str, Enum):
|
|
EMBEDDING = "embedding"
|
|
LLM = "llm"
|
|
|
|
|
|
_RetrieverMode = DocumentSummaryRetrieverMode
|
|
|
|
|
|


class DocumentSummaryIndex(BaseIndex[IndexDocumentSummary]):
    """Document Summary Index.

    Args:
        response_synthesizer (BaseSynthesizer): A response synthesizer for
            generating summaries.
        summary_query (str): The query to use to generate the summary for each
            document.
        show_progress (bool): Whether to show tqdm progress bars.
            Defaults to False.
        embed_summaries (bool): Whether to embed the summaries.
            This is required for running the default embedding-based retriever.
            Defaults to True.
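
    Example:
        A minimal usage sketch, assuming documents loaded with
        ``SimpleDirectoryReader`` and default settings otherwise::

            from llama_index import SimpleDirectoryReader

            documents = SimpleDirectoryReader("./data").load_data()
            index = DocumentSummaryIndex.from_documents(documents)
            summary = index.get_document_summary(documents[0].doc_id)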

    """

    index_struct_cls = IndexDocumentSummary

    def __init__(
        self,
        nodes: Optional[Sequence[BaseNode]] = None,
        objects: Optional[Sequence[IndexNode]] = None,
        index_struct: Optional[IndexDocumentSummary] = None,
        service_context: Optional[ServiceContext] = None,
        storage_context: Optional[StorageContext] = None,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        summary_query: str = DEFAULT_SUMMARY_QUERY,
        show_progress: bool = False,
        embed_summaries: bool = True,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        self._response_synthesizer = response_synthesizer or get_response_synthesizer(
            service_context=service_context, response_mode=ResponseMode.TREE_SUMMARIZE
        )
        self._summary_query = summary_query
        self._embed_summaries = embed_summaries

        super().__init__(
            nodes=nodes,
            index_struct=index_struct,
            service_context=service_context,
            storage_context=storage_context,
            show_progress=show_progress,
            objects=objects,
            **kwargs,
        )

    @property
    def vector_store(self) -> VectorStore:
        return self._vector_store

    def as_retriever(
        self,
        retriever_mode: Union[str, _RetrieverMode] = _RetrieverMode.EMBEDDING,
        **kwargs: Any,
    ) -> BaseRetriever:
        """Get retriever.

        Args:
            retriever_mode (Union[str, DocumentSummaryRetrieverMode]): A retriever mode.
                Defaults to DocumentSummaryRetrieverMode.EMBEDDING.
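
        Example:
            A sketch of selecting a mode; since the mode enum subclasses
            ``str``, a plain string such as ``"embedding"`` or ``"llm"``
            also works::

                retriever = index.as_retriever(retriever_mode="embedding")
                nodes = retriever.retrieve("What is this document about?")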

        """
        from llama_index.indices.document_summary.retrievers import (
            DocumentSummaryIndexEmbeddingRetriever,
            DocumentSummaryIndexLLMRetriever,
        )

        LLMRetriever = DocumentSummaryIndexLLMRetriever
        EmbeddingRetriever = DocumentSummaryIndexEmbeddingRetriever

        if retriever_mode == _RetrieverMode.EMBEDDING:
            if not self._embed_summaries:
                raise ValueError(
                    "Cannot use embedding retriever if embed_summaries is False"
                )
            if "service_context" not in kwargs:
                kwargs["service_context"] = self._service_context
            return EmbeddingRetriever(self, object_map=self._object_map, **kwargs)
        elif retriever_mode == _RetrieverMode.LLM:
            return LLMRetriever(self, object_map=self._object_map, **kwargs)
        else:
            raise ValueError(f"Unknown retriever mode: {retriever_mode}")

    def get_document_summary(self, doc_id: str) -> str:
        """Get document summary by doc id.

        Args:
            doc_id (str): A document id.
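
        Example:
            A sketch, where ``doc`` is one of the ingested documents::

                summary_text = index.get_document_summary(doc.doc_id)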

        """
        if doc_id not in self._index_struct.doc_id_to_summary_id:
            raise ValueError(f"doc_id {doc_id} not in index")
        summary_id = self._index_struct.doc_id_to_summary_id[doc_id]
        return self.docstore.get_node(summary_id).get_content()

    def _add_nodes_to_index(
        self,
        index_struct: IndexDocumentSummary,
        nodes: Sequence[BaseNode],
        show_progress: bool = False,
    ) -> None:
        """Add nodes to index."""
        # Group nodes by the document they came from.
        doc_id_to_nodes = defaultdict(list)
        for node in nodes:
            if node.ref_doc_id is None:
                raise ValueError(
                    "ref_doc_id of node cannot be None when building a document "
                    "summary index"
                )
            doc_id_to_nodes[node.ref_doc_id].append(node)

        summary_node_dict = {}
        items = doc_id_to_nodes.items()
        iterable_with_progress = get_tqdm_iterable(
            items, show_progress, "Summarizing documents"
        )

        for doc_id, nodes in iterable_with_progress:
            logger.debug(f"current doc id: {doc_id}")
            nodes_with_scores = [NodeWithScore(node=n) for n in nodes]
            # Get the summary for each doc_id.
            summary_response = self._response_synthesizer.synthesize(
                query=self._summary_query,
                nodes=nodes_with_scores,
            )
            summary_response = cast(Response, summary_response)
            # Store the summary as a node linked back to its source document.
            summary_node_dict[doc_id] = TextNode(
                text=summary_response.response,
                relationships={
                    NodeRelationship.SOURCE: RelatedNodeInfo(node_id=doc_id)
                },
            )
            self.docstore.add_documents([summary_node_dict[doc_id]])
            logger.info(
                f"> Generated summary for doc {doc_id}: {summary_response.response}"
            )

        for doc_id, nodes in doc_id_to_nodes.items():
            index_struct.add_summary_and_nodes(summary_node_dict[doc_id], nodes)

        if self._embed_summaries:
            # Embed the summary nodes and add them to the vector store so the
            # embedding-based retriever can search over them.
            embed_model = self._service_context.embed_model
            summary_nodes = list(summary_node_dict.values())
            id_to_embed_map = embed_nodes(
                summary_nodes, embed_model, show_progress=show_progress
            )

            summary_nodes_with_embedding = []
            for node in summary_nodes:
                node_with_embedding = node.copy()
                node_with_embedding.embedding = id_to_embed_map[node.node_id]
                summary_nodes_with_embedding.append(node_with_embedding)

            self._vector_store.add(summary_nodes_with_embedding)

    def _build_index_from_nodes(
        self, nodes: Sequence[BaseNode]
    ) -> IndexDocumentSummary:
        """Build index from nodes."""
        # First get a doc_id -> nodes mapping, generate a summary for each
        # doc_id, then build the index struct.
        index_struct = IndexDocumentSummary()
        self._add_nodes_to_index(index_struct, nodes, self._show_progress)
        return index_struct

    def _insert(self, nodes: Sequence[BaseNode], **insert_kwargs: Any) -> None:
        """Insert nodes into the index, grouped by source document."""
        self._add_nodes_to_index(self._index_struct, nodes)

    def _delete_node(self, node_id: str, **delete_kwargs: Any) -> None:
        # Single-node deletion is a no-op; use delete_nodes or delete_ref_doc.
        pass

    def delete_nodes(
        self,
        node_ids: List[str],
        delete_from_docstore: bool = False,
        **delete_kwargs: Any,
    ) -> None:
        """Delete a list of nodes from the index.

        Args:
            node_ids (List[str]): A list of node_ids of the nodes to delete.
            delete_from_docstore (bool): Whether to also delete the nodes
                from the docstore. Defaults to False.

        """
        index_nodes = self._index_struct.node_id_to_summary_id.keys()
        # Filter first: removing items from node_ids while iterating over it
        # would skip elements.
        existing_node_ids = []
        for node_id in node_ids:
            if node_id not in index_nodes:
                logger.warning(f"node_id {node_id} not found, will not be deleted.")
            else:
                existing_node_ids.append(node_id)

        self._index_struct.delete_nodes(existing_node_ids)
        if delete_from_docstore:
            for node_id in existing_node_ids:
                self.docstore.delete_document(node_id, raise_error=False)

        # Remove summaries that no longer have any nodes attached ...
        remove_summary_ids = [
            summary_id
            for summary_id in self._index_struct.summary_id_to_node_ids
            if len(self._index_struct.summary_id_to_node_ids[summary_id]) == 0
        ]

        # ... and the documents that those summaries belonged to.
        remove_docs = [
            doc_id
            for doc_id in self._index_struct.doc_id_to_summary_id
            if self._index_struct.doc_id_to_summary_id[doc_id] in remove_summary_ids
        ]

        for doc_id in remove_docs:
            self.delete_ref_doc(doc_id, delete_from_docstore=delete_from_docstore)

    def delete_ref_doc(
        self, ref_doc_id: str, delete_from_docstore: bool = False, **delete_kwargs: Any
    ) -> None:
        """Delete a document from the index.

        All nodes in the index related to the document will be deleted.
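
        Example:
            A sketch; pass ``delete_from_docstore=True`` to also remove the
            document's nodes from the docstore::

                index.delete_ref_doc("<ref_doc_id>", delete_from_docstore=True)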

        """
        ref_doc_info = self.docstore.get_ref_doc_info(ref_doc_id)
        if ref_doc_info is None:
            logger.warning(f"ref_doc_id {ref_doc_id} not found, nothing deleted.")
            return

        self._index_struct.delete(ref_doc_id)
        self._vector_store.delete(ref_doc_id)

        if delete_from_docstore:
            self.docstore.delete_ref_doc(ref_doc_id, raise_error=False)

        self._storage_context.index_store.add_index_struct(self._index_struct)

    @property
    def ref_doc_info(self) -> Dict[str, RefDocInfo]:
        """Retrieve a dict mapping of ingested documents and their nodes+metadata."""
        ref_doc_ids = list(self._index_struct.doc_id_to_summary_id.keys())

        all_ref_doc_info = {}
        for ref_doc_id in ref_doc_ids:
            ref_doc_info = self.docstore.get_ref_doc_info(ref_doc_id)
            if not ref_doc_info:
                continue

            all_ref_doc_info[ref_doc_id] = ref_doc_info
        return all_ref_doc_info


# legacy
GPTDocumentSummaryIndex = DocumentSummaryIndex