faiss_rag_enterprise/llama_index/composability/joint_qa_summary.py

99 lines
3.3 KiB
Python

"""Joint QA Summary graph."""
from typing import Optional, Sequence
from llama_index.indices.list.base import SummaryIndex
from llama_index.indices.vector_store import VectorStoreIndex
from llama_index.ingestion import run_transformations
from llama_index.query_engine.router_query_engine import RouterQueryEngine
from llama_index.schema import Document
from llama_index.service_context import ServiceContext
from llama_index.storage.storage_context import StorageContext
from llama_index.tools.query_engine import QueryEngineTool
DEFAULT_SUMMARY_TEXT = "Use this index for summarization queries"
DEFAULT_QA_TEXT = (
"Use this index for queries that require retrieval of specific "
"context from documents."
)
class QASummaryQueryEngineBuilder:
"""Joint QA Summary graph builder.
Can build a graph that provides a unified query interface
for both QA and summarization tasks.
NOTE: this is a beta feature. The API may change in the future.
Args:
docstore (BaseDocumentStore): A BaseDocumentStore to use for storing nodes.
service_context (ServiceContext): A ServiceContext to use for
building indices.
summary_text (str): Text to use for the summary index.
qa_text (str): Text to use for the QA index.
node_parser (NodeParser): A NodeParser to use for parsing.
"""
def __init__(
self,
storage_context: Optional[StorageContext] = None,
service_context: Optional[ServiceContext] = None,
summary_text: str = DEFAULT_SUMMARY_TEXT,
qa_text: str = DEFAULT_QA_TEXT,
) -> None:
"""Init params."""
self._storage_context = storage_context or StorageContext.from_defaults()
self._service_context = service_context or ServiceContext.from_defaults()
self._summary_text = summary_text
self._qa_text = qa_text
def build_from_documents(
self,
documents: Sequence[Document],
) -> RouterQueryEngine:
"""Build query engine."""
# parse nodes
nodes = run_transformations(
documents, self._service_context.transformations # type: ignore
)
# ingest nodes
self._storage_context.docstore.add_documents(nodes, allow_update=True)
# build indices
vector_index = VectorStoreIndex(
nodes,
service_context=self._service_context,
storage_context=self._storage_context,
)
summary_index = SummaryIndex(
nodes,
service_context=self._service_context,
storage_context=self._storage_context,
)
vector_query_engine = vector_index.as_query_engine(
service_context=self._service_context
)
list_query_engine = summary_index.as_query_engine(
service_context=self._service_context,
response_mode="tree_summarize",
)
# build query engine
return RouterQueryEngine.from_defaults(
query_engine_tools=[
QueryEngineTool.from_defaults(
vector_query_engine, description=self._qa_text
),
QueryEngineTool.from_defaults(
list_query_engine, description=self._summary_text
),
],
service_context=self._service_context,
select_multi=False,
)