# faiss_rag_enterprise/llama_index/query_engine/citation_query_engine.py
#
# 305 lines
# 12 KiB
# Python

from typing import Any, List, Optional, Sequence
from llama_index.callbacks.base import CallbackManager
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.core.base_retriever import BaseRetriever
from llama_index.core.response.schema import RESPONSE_TYPE
from llama_index.indices.base import BaseGPTIndex
from llama_index.node_parser import SentenceSplitter, TextSplitter
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.prompts import PromptTemplate
from llama_index.prompts.base import BasePromptTemplate
from llama_index.prompts.mixin import PromptMixinType
from llama_index.response_synthesizers import (
BaseSynthesizer,
ResponseMode,
get_response_synthesizer,
)
from llama_index.schema import MetadataMode, NodeWithScore, QueryBundle, TextNode
# Default `text_qa_template` handed to the response synthesizer by
# `CitationQueryEngine.from_args`. Layout: citing instructions, a worked
# two-source example, then the numbered sources ({context_str}) and the
# query ({query_str}).
CITATION_QA_TEMPLATE = PromptTemplate(
    # --- instructions ---
    "Please provide an answer based solely on the provided sources. "
    "When referencing information from a source, cite the appropriate "
    "source(s) using their corresponding numbers. "
    "Every answer should include at least one source citation. "
    "Only cite a source when you are explicitly referencing it. "
    "If none of the sources are helpful, you should indicate that. "
    # --- worked example ---
    "For example:\n"
    "Source 1:\nThe sky is red in the evening and blue in the morning.\n"
    "Source 2:\nWater is wet when the sky is red.\n"
    "Query: When is water wet?\n"
    "Answer: Water will be wet when the sky is red [2], which occurs in "
    "the evening [1].\n"
    # --- actual task ---
    "Now it's your turn. "
    "Below are several numbered sources of information:\n"
    "------\n"
    "{context_str}\n"
    "------\n"
    "Query: {query_str}\n"
    "Answer: "
)
# Default `refine_template` handed to the response synthesizer by
# `CitationQueryEngine.from_args`. Carries the same citing instructions and
# worked example as CITATION_QA_TEMPLATE, followed by the answer produced so
# far ({existing_answer}), the numbered sources ({context_msg}) and the query
# ({query_str}).
CITATION_REFINE_TEMPLATE = PromptTemplate(
    "Please provide an answer based solely on the provided sources. "
    "When referencing information from a source, "
    "cite the appropriate source(s) using their corresponding numbers. "
    "Every answer should include at least one source citation. "
    "Only cite a source when you are explicitly referencing it. "
    "If none of the sources are helpful, you should indicate that. "
    "For example:\n"
    "Source 1:\n"
    "The sky is red in the evening and blue in the morning.\n"
    "Source 2:\n"
    "Water is wet when the sky is red.\n"
    "Query: When is water wet?\n"
    "Answer: Water will be wet when the sky is red [2], "
    "which occurs in the evening [1].\n"
    "Now it's your turn. "
    # The newline after the placeholder keeps the substituted answer from
    # running straight into the following sentence of the prompt.
    "We have provided an existing answer: {existing_answer}\n"
    "Below are several numbered sources of information. "
    "Use them to refine the existing answer. "
    "If the provided sources are not helpful, you will repeat the existing answer."
    "\nBegin refining!"
    "\n------\n"
    "{context_msg}"
    "\n------\n"
    "Query: {query_str}\n"
    "Answer: "
)
# Defaults for the SentenceSplitter that CitationQueryEngine builds when the
# caller supplies no text_splitter: they are passed as its chunk_size and
# chunk_overlap respectively.
DEFAULT_CITATION_CHUNK_SIZE = 512
DEFAULT_CITATION_CHUNK_OVERLAP = 20
class CitationQueryEngine(BaseQueryEngine):
    """Citation query engine.

    Retrieved nodes are re-split into smaller chunks, each prefixed with a
    ``Source N:`` header, and those chunks are what the response synthesizer
    receives, so the answer can cite sources by number.

    Args:
        retriever (BaseRetriever): A retriever object.
        response_synthesizer (Optional[BaseSynthesizer]):
            A BaseSynthesizer object. When omitted, one is created with
            ``get_response_synthesizer`` from the retriever's service context.
        citation_chunk_size (int):
            Size of citation chunks, default=512. Useful for controlling
            granularity of sources. Ignored when ``text_splitter`` is given.
        citation_chunk_overlap (int): Overlap of citation nodes, default=20.
            Ignored when ``text_splitter`` is given.
        text_splitter (Optional[TextSplitter]):
            A text splitter for creating citation source nodes. Default is
            a SentenceSplitter.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]):
            Postprocessors applied, in order, to the retrieved nodes before
            they are turned into citation chunks.
        callback_manager (Optional[CallbackManager]): A callback manager.
        metadata_mode (MetadataMode): A MetadataMode object that controls how
            metadata is included in the citation prompt.
    """

    def __init__(
        self,
        retriever: BaseRetriever,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        citation_chunk_size: int = DEFAULT_CITATION_CHUNK_SIZE,
        citation_chunk_overlap: int = DEFAULT_CITATION_CHUNK_OVERLAP,
        text_splitter: Optional[TextSplitter] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        callback_manager: Optional[CallbackManager] = None,
        metadata_mode: MetadataMode = MetadataMode.NONE,
    ) -> None:
        self.text_splitter = text_splitter or SentenceSplitter(
            chunk_size=citation_chunk_size, chunk_overlap=citation_chunk_overlap
        )
        self._retriever = retriever
        # The callback manager is forwarded exactly as received (possibly
        # None); the fallback to a fresh CallbackManager below only applies to
        # this engine and its postprocessors.
        self._response_synthesizer = response_synthesizer or get_response_synthesizer(
            service_context=retriever.get_service_context(),
            callback_manager=callback_manager,
        )
        self._node_postprocessors = node_postprocessors or []
        self._metadata_mode = metadata_mode

        callback_manager = callback_manager or CallbackManager()
        # Postprocessors share the engine's callback manager.
        for node_postprocessor in self._node_postprocessors:
            node_postprocessor.callback_manager = callback_manager

        super().__init__(callback_manager)

    @classmethod
    def from_args(
        cls,
        index: BaseGPTIndex,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        citation_chunk_size: int = DEFAULT_CITATION_CHUNK_SIZE,
        citation_chunk_overlap: int = DEFAULT_CITATION_CHUNK_OVERLAP,
        text_splitter: Optional[TextSplitter] = None,
        citation_qa_template: BasePromptTemplate = CITATION_QA_TEMPLATE,
        citation_refine_template: BasePromptTemplate = CITATION_REFINE_TEMPLATE,
        retriever: Optional[BaseRetriever] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        # response synthesizer args
        response_mode: ResponseMode = ResponseMode.COMPACT,
        use_async: bool = False,
        streaming: bool = False,
        # class-specific args
        metadata_mode: MetadataMode = MetadataMode.NONE,
        **kwargs: Any,
    ) -> "CitationQueryEngine":
        """Initialize a CitationQueryEngine object.

        Args:
            index (BaseGPTIndex): Index to use for querying. Its service
                context (including that context's callback manager) is used
                for the engine.
            response_synthesizer (Optional[BaseSynthesizer]):
                A ready-made synthesizer. When given, the template,
                ``response_mode``, ``use_async`` and ``streaming`` arguments
                below are not used.
            citation_chunk_size (int):
                Size of citation chunks, default=512. Useful for controlling
                granularity of sources.
            citation_chunk_overlap (int): Overlap of citation nodes, default=20.
            text_splitter (Optional[TextSplitter]):
                A text splitter for creating citation source nodes. Default is
                a SentenceSplitter.
            citation_qa_template (BasePromptTemplate): Template for initial
                citation QA.
            citation_refine_template (BasePromptTemplate):
                Template for citation refinement.
            retriever (Optional[BaseRetriever]): A retriever object. When
                omitted, ``index.as_retriever(**kwargs)`` is used.
            node_postprocessors (Optional[List[BaseNodePostprocessor]]): A list of
                node postprocessors.
            response_mode (ResponseMode): A ResponseMode object.
            use_async (bool): Whether to use async.
            streaming (bool): Whether to use streaming.
            metadata_mode (MetadataMode): Controls how metadata is included in
                the citation prompt.
            **kwargs: Forwarded to ``index.as_retriever`` when no ``retriever``
                is supplied.

        Returns:
            CitationQueryEngine: The configured engine.
        """
        retriever = retriever or index.as_retriever(**kwargs)

        response_synthesizer = response_synthesizer or get_response_synthesizer(
            service_context=index.service_context,
            text_qa_template=citation_qa_template,
            refine_template=citation_refine_template,
            response_mode=response_mode,
            use_async=use_async,
            streaming=streaming,
        )

        return cls(
            retriever=retriever,
            response_synthesizer=response_synthesizer,
            callback_manager=index.service_context.callback_manager,
            citation_chunk_size=citation_chunk_size,
            citation_chunk_overlap=citation_chunk_overlap,
            text_splitter=text_splitter,
            node_postprocessors=node_postprocessors,
            metadata_mode=metadata_mode,
        )

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {"response_synthesizer": self._response_synthesizer}

    def _create_citation_nodes(self, nodes: List[NodeWithScore]) -> List[NodeWithScore]:
        """Modify retrieved nodes to be granular sources.

        Each node's content (rendered with the engine's metadata mode) is split
        with ``self.text_splitter``. Every resulting chunk becomes its own
        ``NodeWithScore`` that copies the originating node, keeps its score and
        has its text replaced by ``"Source N:\\n<chunk>\\n"``. ``N`` starts at 1
        and keeps counting across all input nodes.
        """
        new_nodes: List[NodeWithScore] = []
        for node in nodes:
            text_chunks = self.text_splitter.split_text(
                node.node.get_content(metadata_mode=self._metadata_mode)
            )

            for text_chunk in text_chunks:
                text = f"Source {len(new_nodes)+1}:\n{text_chunk}\n"

                # Build a separate TextNode per chunk so that overwriting the
                # text does not touch the retrieved node.
                new_node = NodeWithScore(
                    node=TextNode.parse_obj(node.node), score=node.score
                )
                new_node.node.text = text
                new_nodes.append(new_node)
        return new_nodes

    def _apply_node_postprocessors(
        self, nodes: List[NodeWithScore], query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Run every configured node postprocessor over ``nodes``, in order."""
        for postprocessor in self._node_postprocessors:
            nodes = postprocessor.postprocess_nodes(nodes, query_bundle=query_bundle)
        return nodes

    def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes for the query and apply the node postprocessors."""
        nodes = self._retriever.retrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle)

    async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async variant of ``retrieve``.

        Only the retriever call is awaited; the postprocessors are invoked
        through their synchronous ``postprocess_nodes``.
        """
        nodes = await self._retriever.aretrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle)

    @property
    def retriever(self) -> BaseRetriever:
        """Get the retriever object."""
        return self._retriever

    def synthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Turn ``nodes`` into citation chunks and synthesize a response."""
        nodes = self._create_citation_nodes(nodes)
        return self._response_synthesizer.synthesize(
            query=query_bundle,
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    async def asynthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Async variant of ``synthesize``."""
        nodes = self._create_citation_nodes(nodes)
        return await self._response_synthesizer.asynthesize(
            query=query_bundle,
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query.

        Runs inside a QUERY callback event. Retrieval plus citation chunking
        happen inside a nested RETRIEVE event whose end payload carries the
        citation chunks; the synthesized response ends the QUERY event.
        """
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as retrieve_event:
                nodes = self.retrieve(query_bundle)
                nodes = self._create_citation_nodes(nodes)

                retrieve_event.on_end(payload={EventPayload.NODES: nodes})

            # Call the synthesizer directly: self.synthesize would split the
            # already-split citation chunks a second time.
            response = self._response_synthesizer.synthesize(
                query=query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query asynchronously.

        Same event structure as ``_query``, using ``aretrieve`` and the
        synthesizer's ``asynthesize``.
        """
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as retrieve_event:
                nodes = await self.aretrieve(query_bundle)
                nodes = self._create_citation_nodes(nodes)

                retrieve_event.on_end(payload={EventPayload.NODES: nodes})

            # Call the synthesizer directly: self.asynthesize would split the
            # already-split citation chunks a second time.
            response = await self._response_synthesizer.asynthesize(
                query=query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response