from typing import Any, List, Optional, Sequence

from llama_index.callbacks.base import CallbackManager
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.core.base_retriever import BaseRetriever
from llama_index.core.response.schema import RESPONSE_TYPE
from llama_index.indices.base import BaseGPTIndex
from llama_index.node_parser import SentenceSplitter, TextSplitter
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.prompts import PromptTemplate
from llama_index.prompts.base import BasePromptTemplate
from llama_index.prompts.mixin import PromptMixinType
from llama_index.response_synthesizers import (
    BaseSynthesizer,
    ResponseMode,
    get_response_synthesizer,
)
from llama_index.schema import MetadataMode, NodeWithScore, QueryBundle, TextNode

# Prompt for the first pass of answer synthesis. The retrieved context is
# rendered as numbered "Source N:" chunks (see
# CitationQueryEngine._create_citation_nodes), and the model is asked to cite
# them by number.
CITATION_QA_TEMPLATE = PromptTemplate(
    "Please provide an answer based solely on the provided sources. "
    "When referencing information from a source, "
    "cite the appropriate source(s) using their corresponding numbers. "
    "Every answer should include at least one source citation. "
    "Only cite a source when you are explicitly referencing it. "
    "If none of the sources are helpful, you should indicate that. "
    "For example:\n"
    "Source 1:\n"
    "The sky is red in the evening and blue in the morning.\n"
    "Source 2:\n"
    "Water is wet when the sky is red.\n"
    "Query: When is water wet?\n"
    "Answer: Water will be wet when the sky is red [2], "
    "which occurs in the evening [1].\n"
    "Now it's your turn. Below are several numbered sources of information:"
    "\n------\n"
    "{context_str}"
    "\n------\n"
    "Query: {query_str}\n"
    "Answer: "
)

# Prompt for refining an existing answer with additional numbered sources.
CITATION_REFINE_TEMPLATE = PromptTemplate(
    "Please provide an answer based solely on the provided sources. "
    "When referencing information from a source, "
    "cite the appropriate source(s) using their corresponding numbers. "
    "Every answer should include at least one source citation. "
    "Only cite a source when you are explicitly referencing it. "
    "If none of the sources are helpful, you should indicate that. "
    "For example:\n"
    "Source 1:\n"
    "The sky is red in the evening and blue in the morning.\n"
    "Source 2:\n"
    "Water is wet when the sky is red.\n"
    "Query: When is water wet?\n"
    "Answer: Water will be wet when the sky is red [2], "
    "which occurs in the evening [1].\n"
    "Now it's your turn. "
    # Trailing newline separates the (arbitrary) existing answer from the
    # next instruction; without it the two were glued together in the
    # rendered prompt.
    "We have provided an existing answer: {existing_answer}\n"
    "Below are several numbered sources of information. "
    "Use them to refine the existing answer. "
    "If the provided sources are not helpful, you will repeat the existing answer."
    "\nBegin refining!"
    "\n------\n"
    "{context_msg}"
    "\n------\n"
    "Query: {query_str}\n"
    "Answer: "
)

DEFAULT_CITATION_CHUNK_SIZE = 512
DEFAULT_CITATION_CHUNK_OVERLAP = 20


class CitationQueryEngine(BaseQueryEngine):
    """Citation query engine.

    Retrieves nodes, splits each one into smaller numbered "Source N" chunks,
    and synthesizes an answer from those chunks so the response can cite them.

    Args:
        retriever (BaseRetriever): A retriever object.
        response_synthesizer (Optional[BaseSynthesizer]):
            A BaseSynthesizer object. If omitted, one is built with
            ``get_response_synthesizer`` using the retriever's service context.
        citation_chunk_size (int):
            Size of citation chunks, default=512. Useful for controlling
            granularity of sources. Ignored when ``text_splitter`` is given.
        citation_chunk_overlap (int): Overlap of citation nodes, default=20.
            Ignored when ``text_splitter`` is given.
        text_splitter (Optional[TextSplitter]):
            A text splitter for creating citation source nodes. Default is
            a SentenceSplitter.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]):
            Postprocessors applied, in order, to the retrieved nodes before
            they are split into citation sources.
        callback_manager (Optional[CallbackManager]): A callback manager.
        metadata_mode (MetadataMode): A MetadataMode object that controls how
            metadata is included in the citation prompt.
    """

    def __init__(
        self,
        retriever: BaseRetriever,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        citation_chunk_size: int = DEFAULT_CITATION_CHUNK_SIZE,
        citation_chunk_overlap: int = DEFAULT_CITATION_CHUNK_OVERLAP,
        text_splitter: Optional[TextSplitter] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        callback_manager: Optional[CallbackManager] = None,
        metadata_mode: MetadataMode = MetadataMode.NONE,
    ) -> None:
        self.text_splitter = text_splitter or SentenceSplitter(
            chunk_size=citation_chunk_size, chunk_overlap=citation_chunk_overlap
        )
        self._retriever = retriever
        # NOTE: the caller-supplied (possibly None) callback_manager is what
        # the default synthesizer receives; the CallbackManager() fallback
        # below only applies to this engine and its postprocessors.
        self._response_synthesizer = response_synthesizer or get_response_synthesizer(
            service_context=retriever.get_service_context(),
            callback_manager=callback_manager,
        )
        self._node_postprocessors = node_postprocessors or []
        self._metadata_mode = metadata_mode

        callback_manager = callback_manager or CallbackManager()
        for node_postprocessor in self._node_postprocessors:
            node_postprocessor.callback_manager = callback_manager

        super().__init__(callback_manager)

    @classmethod
    def from_args(
        cls,
        index: BaseGPTIndex,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        citation_chunk_size: int = DEFAULT_CITATION_CHUNK_SIZE,
        citation_chunk_overlap: int = DEFAULT_CITATION_CHUNK_OVERLAP,
        text_splitter: Optional[TextSplitter] = None,
        citation_qa_template: BasePromptTemplate = CITATION_QA_TEMPLATE,
        citation_refine_template: BasePromptTemplate = CITATION_REFINE_TEMPLATE,
        retriever: Optional[BaseRetriever] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        # response synthesizer args
        response_mode: ResponseMode = ResponseMode.COMPACT,
        use_async: bool = False,
        streaming: bool = False,
        # class-specific args
        metadata_mode: MetadataMode = MetadataMode.NONE,
        **kwargs: Any,
    ) -> "CitationQueryEngine":
        """Initialize a CitationQueryEngine object from an index.

        Args:
            index (BaseGPTIndex): Index to use for querying.
            response_synthesizer (Optional[BaseSynthesizer]):
                A BaseSynthesizer object. If omitted, one is built from the
                index's service context using ``citation_qa_template``,
                ``citation_refine_template``, ``response_mode``,
                ``use_async`` and ``streaming``.
            citation_chunk_size (int):
                Size of citation chunks, default=512. Useful for controlling
                granularity of sources.
            citation_chunk_overlap (int): Overlap of citation nodes, default=20.
            text_splitter (Optional[TextSplitter]):
                A text splitter for creating citation source nodes. Default is
                a SentenceSplitter.
            citation_qa_template (BasePromptTemplate):
                Template for initial citation QA.
            citation_refine_template (BasePromptTemplate):
                Template for citation refinement.
            retriever (Optional[BaseRetriever]): A retriever object. If
                omitted, ``index.as_retriever(**kwargs)`` is used.
            node_postprocessors (Optional[List[BaseNodePostprocessor]]):
                A list of node postprocessors.
            response_mode (ResponseMode): A ResponseMode object.
            use_async (bool): Whether to use async.
            streaming (bool): Whether to use streaming.
            metadata_mode (MetadataMode): Controls how node metadata is
                included in the citation source text.
            **kwargs: Forwarded to ``index.as_retriever`` (only used when
                ``retriever`` is not given).

        Returns:
            CitationQueryEngine: The configured engine, using the index's
            service-context callback manager.
        """
        retriever = retriever or index.as_retriever(**kwargs)

        response_synthesizer = response_synthesizer or get_response_synthesizer(
            service_context=index.service_context,
            text_qa_template=citation_qa_template,
            refine_template=citation_refine_template,
            response_mode=response_mode,
            use_async=use_async,
            streaming=streaming,
        )

        return cls(
            retriever=retriever,
            response_synthesizer=response_synthesizer,
            callback_manager=index.service_context.callback_manager,
            citation_chunk_size=citation_chunk_size,
            citation_chunk_overlap=citation_chunk_overlap,
            text_splitter=text_splitter,
            node_postprocessors=node_postprocessors,
            metadata_mode=metadata_mode,
        )

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {"response_synthesizer": self._response_synthesizer}

    def _create_citation_nodes(self, nodes: List[NodeWithScore]) -> List[NodeWithScore]:
        """Modify retrieved nodes to be granular, numbered sources.

        Each node's content (rendered with ``self._metadata_mode``) is split
        with ``self.text_splitter``. Every chunk becomes a copy of the source
        node whose text is replaced with ``"Source {i}:\\n{chunk}\\n"``, where
        ``i`` counts from 1 across *all* input nodes. The original score is
        kept on every chunk.
        """
        new_nodes: List[NodeWithScore] = []
        for node in nodes:
            text_chunks = self.text_splitter.split_text(
                node.node.get_content(metadata_mode=self._metadata_mode)
            )

            for text_chunk in text_chunks:
                text = f"Source {len(new_nodes)+1}:\n{text_chunk}\n"

                # Copy the node so the original retrieved node is left untouched.
                new_node = NodeWithScore(
                    node=TextNode.parse_obj(node.node), score=node.score
                )
                new_node.node.text = text
                new_nodes.append(new_node)
        return new_nodes

    def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes and run them through the node postprocessors, in order."""
        nodes = self._retriever.retrieve(query_bundle)

        for postprocessor in self._node_postprocessors:
            nodes = postprocessor.postprocess_nodes(nodes, query_bundle=query_bundle)

        return nodes

    async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Asynchronously retrieve nodes, then run the postprocessors.

        Only the retrieval step is awaited; postprocessors are invoked
        through their synchronous ``postprocess_nodes`` method.
        """
        nodes = await self._retriever.aretrieve(query_bundle)

        for postprocessor in self._node_postprocessors:
            nodes = postprocessor.postprocess_nodes(nodes, query_bundle=query_bundle)

        return nodes

    @property
    def retriever(self) -> BaseRetriever:
        """Get the retriever object."""
        return self._retriever

    def synthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Split ``nodes`` into citation sources and synthesize a response.

        ``additional_source_nodes`` is passed to the synthesizer as-is; it is
        not converted into citation sources.
        """
        nodes = self._create_citation_nodes(nodes)
        return self._response_synthesizer.synthesize(
            query=query_bundle,
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    async def asynthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Async counterpart of :meth:`synthesize`."""
        nodes = self._create_citation_nodes(nodes)
        return await self._response_synthesizer.asynthesize(
            query=query_bundle,
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as retrieve_event:
                nodes = self.retrieve(query_bundle)
                nodes = self._create_citation_nodes(nodes)

                # The RETRIEVE event reports the numbered citation chunks,
                # not the raw retrieved nodes.
                retrieve_event.on_end(payload={EventPayload.NODES: nodes})

            response = self._response_synthesizer.synthesize(
                query=query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query asynchronously."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as retrieve_event:
                nodes = await self.aretrieve(query_bundle)
                nodes = self._create_citation_nodes(nodes)

                # The RETRIEVE event reports the numbered citation chunks,
                # not the raw retrieved nodes.
                retrieve_event.on_end(payload={EventPayload.NODES: nodes})

            response = await self._response_synthesizer.asynthesize(
                query=query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response