faiss_rag_enterprise/llama_index/indices/base.py

416 lines
16 KiB
Python

"""Base index classes."""
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, cast
from llama_index.chat_engine.types import BaseChatEngine, ChatMode
from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.core.base_retriever import BaseRetriever
from llama_index.data_structs.data_structs import IndexStruct
from llama_index.ingestion import run_transformations
from llama_index.schema import BaseNode, Document, IndexNode
from llama_index.service_context import ServiceContext
from llama_index.storage.docstore.types import BaseDocumentStore, RefDocInfo
from llama_index.storage.storage_context import StorageContext
IS = TypeVar("IS", bound=IndexStruct)
IndexType = TypeVar("IndexType", bound="BaseIndex")
logger = logging.getLogger(__name__)
class BaseIndex(Generic[IS], ABC):
"""Base LlamaIndex.
Args:
nodes (List[Node]): List of nodes to index
show_progress (bool): Whether to show tqdm progress bars. Defaults to False.
service_context (ServiceContext): Service context container (contains
components like LLM, Embeddings, etc.).
"""
index_struct_cls: Type[IS]
def __init__(
self,
nodes: Optional[Sequence[BaseNode]] = None,
objects: Optional[Sequence[IndexNode]] = None,
index_struct: Optional[IS] = None,
storage_context: Optional[StorageContext] = None,
service_context: Optional[ServiceContext] = None,
show_progress: bool = False,
**kwargs: Any,
) -> None:
"""Initialize with parameters."""
if index_struct is None and nodes is None and objects is None:
raise ValueError("One of nodes, objects, or index_struct must be provided.")
if index_struct is not None and nodes is not None:
raise ValueError("Only one of nodes or index_struct can be provided.")
# This is to explicitly make sure that the old UX is not used
if nodes is not None and len(nodes) >= 1 and not isinstance(nodes[0], BaseNode):
if isinstance(nodes[0], Document):
raise ValueError(
"The constructor now takes in a list of Node objects. "
"Since you are passing in a list of Document objects, "
"please use `from_documents` instead."
)
else:
raise ValueError("nodes must be a list of Node objects.")
self._service_context = service_context or ServiceContext.from_defaults()
self._storage_context = storage_context or StorageContext.from_defaults()
self._docstore = self._storage_context.docstore
self._show_progress = show_progress
self._vector_store = self._storage_context.vector_store
self._graph_store = self._storage_context.graph_store
objects = objects or []
self._object_map = {obj.index_id: obj.obj for obj in objects}
with self._service_context.callback_manager.as_trace("index_construction"):
if index_struct is None:
nodes = nodes or []
index_struct = self.build_index_from_nodes(
nodes + objects # type: ignore
)
self._index_struct = index_struct
self._storage_context.index_store.add_index_struct(self._index_struct)
@classmethod
def from_documents(
cls: Type[IndexType],
documents: Sequence[Document],
storage_context: Optional[StorageContext] = None,
service_context: Optional[ServiceContext] = None,
show_progress: bool = False,
**kwargs: Any,
) -> IndexType:
"""Create index from documents.
Args:
documents (Optional[Sequence[BaseDocument]]): List of documents to
build the index from.
"""
storage_context = storage_context or StorageContext.from_defaults()
service_context = service_context or ServiceContext.from_defaults()
docstore = storage_context.docstore
with service_context.callback_manager.as_trace("index_construction"):
for doc in documents:
docstore.set_document_hash(doc.get_doc_id(), doc.hash)
nodes = run_transformations(
documents, # type: ignore
service_context.transformations,
show_progress=show_progress,
**kwargs,
)
return cls(
nodes=nodes,
storage_context=storage_context,
service_context=service_context,
show_progress=show_progress,
**kwargs,
)
@property
def index_struct(self) -> IS:
"""Get the index struct."""
return self._index_struct
@property
def index_id(self) -> str:
"""Get the index struct."""
return self._index_struct.index_id
def set_index_id(self, index_id: str) -> None:
"""Set the index id.
NOTE: if you decide to set the index_id on the index_struct manually,
you will need to explicitly call `add_index_struct` on the `index_store`
to update the index store.
.. code-block:: python
index.index_struct.index_id = index_id
index.storage_context.index_store.add_index_struct(index.index_struct)
Args:
index_id (str): Index id to set.
"""
# delete the old index struct
old_id = self._index_struct.index_id
self._storage_context.index_store.delete_index_struct(old_id)
# add the new index struct
self._index_struct.index_id = index_id
self._storage_context.index_store.add_index_struct(self._index_struct)
@property
def docstore(self) -> BaseDocumentStore:
"""Get the docstore corresponding to the index."""
return self._docstore
@property
def service_context(self) -> ServiceContext:
return self._service_context
@property
def storage_context(self) -> StorageContext:
return self._storage_context
@property
def summary(self) -> str:
return str(self._index_struct.summary)
@summary.setter
def summary(self, new_summary: str) -> None:
self._index_struct.summary = new_summary
self._storage_context.index_store.add_index_struct(self._index_struct)
@abstractmethod
def _build_index_from_nodes(self, nodes: Sequence[BaseNode]) -> IS:
"""Build the index from nodes."""
def build_index_from_nodes(self, nodes: Sequence[BaseNode]) -> IS:
"""Build the index from nodes."""
self._docstore.add_documents(nodes, allow_update=True)
return self._build_index_from_nodes(nodes)
@abstractmethod
def _insert(self, nodes: Sequence[BaseNode], **insert_kwargs: Any) -> None:
"""Index-specific logic for inserting nodes to the index struct."""
def insert_nodes(self, nodes: Sequence[BaseNode], **insert_kwargs: Any) -> None:
"""Insert nodes."""
with self._service_context.callback_manager.as_trace("insert_nodes"):
self.docstore.add_documents(nodes, allow_update=True)
self._insert(nodes, **insert_kwargs)
self._storage_context.index_store.add_index_struct(self._index_struct)
def insert(self, document: Document, **insert_kwargs: Any) -> None:
"""Insert a document."""
with self._service_context.callback_manager.as_trace("insert"):
nodes = run_transformations(
[document],
self._service_context.transformations,
show_progress=self._show_progress,
)
self.insert_nodes(nodes, **insert_kwargs)
self.docstore.set_document_hash(document.get_doc_id(), document.hash)
@abstractmethod
def _delete_node(self, node_id: str, **delete_kwargs: Any) -> None:
"""Delete a node."""
def delete_nodes(
self,
node_ids: List[str],
delete_from_docstore: bool = False,
**delete_kwargs: Any,
) -> None:
"""Delete a list of nodes from the index.
Args:
doc_ids (List[str]): A list of doc_ids from the nodes to delete
"""
for node_id in node_ids:
self._delete_node(node_id, **delete_kwargs)
if delete_from_docstore:
self.docstore.delete_document(node_id, raise_error=False)
self._storage_context.index_store.add_index_struct(self._index_struct)
def delete(self, doc_id: str, **delete_kwargs: Any) -> None:
"""Delete a document from the index.
All nodes in the index related to the index will be deleted.
Args:
doc_id (str): A doc_id of the ingested document
"""
logger.warning(
"delete() is now deprecated, please refer to delete_ref_doc() to delete "
"ingested documents+nodes or delete_nodes to delete a list of nodes."
)
self.delete_ref_doc(doc_id)
def delete_ref_doc(
self, ref_doc_id: str, delete_from_docstore: bool = False, **delete_kwargs: Any
) -> None:
"""Delete a document and it's nodes by using ref_doc_id."""
ref_doc_info = self.docstore.get_ref_doc_info(ref_doc_id)
if ref_doc_info is None:
logger.warning(f"ref_doc_id {ref_doc_id} not found, nothing deleted.")
return
self.delete_nodes(
ref_doc_info.node_ids,
delete_from_docstore=False,
**delete_kwargs,
)
if delete_from_docstore:
self.docstore.delete_ref_doc(ref_doc_id, raise_error=False)
def update(self, document: Document, **update_kwargs: Any) -> None:
"""Update a document and it's corresponding nodes.
This is equivalent to deleting the document and then inserting it again.
Args:
document (Union[BaseDocument, BaseIndex]): document to update
insert_kwargs (Dict): kwargs to pass to insert
delete_kwargs (Dict): kwargs to pass to delete
"""
logger.warning(
"update() is now deprecated, please refer to update_ref_doc() to update "
"ingested documents+nodes."
)
self.update_ref_doc(document, **update_kwargs)
def update_ref_doc(self, document: Document, **update_kwargs: Any) -> None:
"""Update a document and it's corresponding nodes.
This is equivalent to deleting the document and then inserting it again.
Args:
document (Union[BaseDocument, BaseIndex]): document to update
insert_kwargs (Dict): kwargs to pass to insert
delete_kwargs (Dict): kwargs to pass to delete
"""
with self._service_context.callback_manager.as_trace("update"):
self.delete_ref_doc(
document.get_doc_id(),
delete_from_docstore=True,
**update_kwargs.pop("delete_kwargs", {}),
)
self.insert(document, **update_kwargs.pop("insert_kwargs", {}))
def refresh(
self, documents: Sequence[Document], **update_kwargs: Any
) -> List[bool]:
"""Refresh an index with documents that have changed.
This allows users to save LLM and Embedding model calls, while only
updating documents that have any changes in text or metadata. It
will also insert any documents that previously were not stored.
"""
logger.warning(
"refresh() is now deprecated, please refer to refresh_ref_docs() to "
"refresh ingested documents+nodes with an updated list of documents."
)
return self.refresh_ref_docs(documents, **update_kwargs)
def refresh_ref_docs(
self, documents: Sequence[Document], **update_kwargs: Any
) -> List[bool]:
"""Refresh an index with documents that have changed.
This allows users to save LLM and Embedding model calls, while only
updating documents that have any changes in text or metadata. It
will also insert any documents that previously were not stored.
"""
with self._service_context.callback_manager.as_trace("refresh"):
refreshed_documents = [False] * len(documents)
for i, document in enumerate(documents):
existing_doc_hash = self._docstore.get_document_hash(
document.get_doc_id()
)
if existing_doc_hash is None:
self.insert(document, **update_kwargs.pop("insert_kwargs", {}))
refreshed_documents[i] = True
elif existing_doc_hash != document.hash:
self.update_ref_doc(
document, **update_kwargs.pop("update_kwargs", {})
)
refreshed_documents[i] = True
return refreshed_documents
@property
@abstractmethod
def ref_doc_info(self) -> Dict[str, RefDocInfo]:
"""Retrieve a dict mapping of ingested documents and their nodes+metadata."""
...
@abstractmethod
def as_retriever(self, **kwargs: Any) -> BaseRetriever:
...
def as_query_engine(self, **kwargs: Any) -> BaseQueryEngine:
# NOTE: lazy import
from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine
retriever = self.as_retriever(**kwargs)
kwargs["retriever"] = retriever
if "service_context" not in kwargs:
kwargs["service_context"] = self._service_context
return RetrieverQueryEngine.from_args(**kwargs)
def as_chat_engine(
self, chat_mode: ChatMode = ChatMode.BEST, **kwargs: Any
) -> BaseChatEngine:
query_engine = self.as_query_engine(**kwargs)
if "service_context" not in kwargs:
kwargs["service_context"] = self._service_context
# resolve chat mode
if chat_mode in [ChatMode.REACT, ChatMode.OPENAI, ChatMode.BEST]:
# use an agent with query engine tool in these chat modes
# NOTE: lazy import
from llama_index.agent import AgentRunner
from llama_index.tools.query_engine import QueryEngineTool
# get LLM
service_context = cast(ServiceContext, kwargs["service_context"])
llm = service_context.llm
# convert query engine to tool
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
return AgentRunner.from_llm(tools=[query_engine_tool], llm=llm, **kwargs)
if chat_mode == ChatMode.CONDENSE_QUESTION:
# NOTE: lazy import
from llama_index.chat_engine import CondenseQuestionChatEngine
return CondenseQuestionChatEngine.from_defaults(
query_engine=query_engine,
**kwargs,
)
elif chat_mode == ChatMode.CONTEXT:
from llama_index.chat_engine import ContextChatEngine
return ContextChatEngine.from_defaults(
retriever=self.as_retriever(**kwargs),
**kwargs,
)
elif chat_mode == ChatMode.CONDENSE_PLUS_CONTEXT:
from llama_index.chat_engine import CondensePlusContextChatEngine
return CondensePlusContextChatEngine.from_defaults(
retriever=self.as_retriever(**kwargs),
**kwargs,
)
elif chat_mode == ChatMode.SIMPLE:
from llama_index.chat_engine import SimpleChatEngine
return SimpleChatEngine.from_defaults(
**kwargs,
)
else:
raise ValueError(f"Unknown chat mode: {chat_mode}")
# legacy
BaseGPTIndex = BaseIndex