from typing import Any, List, Optional, Sequence

from llama_index.bridge.pydantic import BaseModel
from llama_index.callbacks.base import CallbackManager
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.core.base_retriever import BaseRetriever
from llama_index.core.response.schema import RESPONSE_TYPE
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.prompts import BasePromptTemplate
from llama_index.prompts.mixin import PromptMixinType
from llama_index.response_synthesizers import (
    BaseSynthesizer,
    ResponseMode,
    get_response_synthesizer,
)
from llama_index.schema import NodeWithScore, QueryBundle
from llama_index.service_context import ServiceContext


class RetrieverQueryEngine(BaseQueryEngine):
    """Retriever query engine.

    Retrieves nodes for a query, runs them through the configured node
    postprocessors, and hands the result to a response synthesizer.

    Args:
        retriever (BaseRetriever): A retriever object.
        response_synthesizer (Optional[BaseSynthesizer]): A BaseSynthesizer
            object.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]): A list of
            node postprocessors applied, in order, to the retrieved nodes.
        callback_manager (Optional[CallbackManager]): A callback manager.

    """

    def __init__(
        self,
        retriever: BaseRetriever,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Initialize the engine.

        When no ``response_synthesizer`` is supplied, one is built from the
        retriever's service context and the given ``callback_manager`` (which
        may still be ``None`` at that point).
        """
        self._retriever = retriever
        self._response_synthesizer = response_synthesizer or get_response_synthesizer(
            service_context=retriever.get_service_context(),
            callback_manager=callback_manager,
        )

        self._node_postprocessors = node_postprocessors or []
        # Fall back to an empty manager only after the synthesizer was built,
        # so the synthesizer factory receives the caller's original value.
        callback_manager = callback_manager or CallbackManager([])
        # Every postprocessor is pointed at the engine's callback manager.
        for node_postprocessor in self._node_postprocessors:
            node_postprocessor.callback_manager = callback_manager

        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {"response_synthesizer": self._response_synthesizer}

    @classmethod
    def from_args(
        cls,
        retriever: BaseRetriever,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        service_context: Optional[ServiceContext] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        # response synthesizer args
        response_mode: ResponseMode = ResponseMode.COMPACT,
        text_qa_template: Optional[BasePromptTemplate] = None,
        refine_template: Optional[BasePromptTemplate] = None,
        summary_template: Optional[BasePromptTemplate] = None,
        simple_template: Optional[BasePromptTemplate] = None,
        output_cls: Optional[BaseModel] = None,
        use_async: bool = False,
        streaming: bool = False,
        # class-specific args
        **kwargs: Any,
    ) -> "RetrieverQueryEngine":
        """Initialize a RetrieverQueryEngine object.

        Args:
            retriever (BaseRetriever): A retriever object.
            response_synthesizer (Optional[BaseSynthesizer]): A ready-made
                synthesizer. When given, it is used as-is and the response
                synthesizer arguments below are not consulted.
            service_context (Optional[ServiceContext]): A ServiceContext
                object. Its callback manager is used for the engine; when
                omitted, a new empty CallbackManager is created.
            node_postprocessors (Optional[List[BaseNodePostprocessor]]): A
                list of node postprocessors.
            response_mode (ResponseMode): A ResponseMode object.
            text_qa_template (Optional[BasePromptTemplate]): A
                BasePromptTemplate object.
            refine_template (Optional[BasePromptTemplate]): A
                BasePromptTemplate object.
            summary_template (Optional[BasePromptTemplate]): A
                BasePromptTemplate object.
            simple_template (Optional[BasePromptTemplate]): A
                BasePromptTemplate object.
            output_cls (Optional[BaseModel]): Forwarded to
                ``get_response_synthesizer`` as ``output_cls``.
            use_async (bool): Whether to use async.
            streaming (bool): Whether to use streaming.
            **kwargs (Any): Accepted for call-site compatibility; not used by
                this method.

        Returns:
            RetrieverQueryEngine: The constructed query engine.

        """
        response_synthesizer = response_synthesizer or get_response_synthesizer(
            service_context=service_context,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            summary_template=summary_template,
            simple_template=simple_template,
            response_mode=response_mode,
            output_cls=output_cls,
            use_async=use_async,
            streaming=streaming,
        )

        callback_manager = (
            service_context.callback_manager if service_context else CallbackManager([])
        )

        return cls(
            retriever=retriever,
            response_synthesizer=response_synthesizer,
            callback_manager=callback_manager,
            node_postprocessors=node_postprocessors,
        )

    def _apply_node_postprocessors(
        self, nodes: List[NodeWithScore], query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Run the nodes through each postprocessor in order.

        The output of one postprocessor is the input of the next; with no
        postprocessors configured the nodes are returned unchanged.
        """
        for node_postprocessor in self._node_postprocessors:
            nodes = node_postprocessor.postprocess_nodes(
                nodes, query_bundle=query_bundle
            )
        return nodes

    def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes for the query and apply the node postprocessors."""
        nodes = self._retriever.retrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle=query_bundle)

    async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Asynchronously retrieve nodes, then apply the node postprocessors.

        Only the retrieval is awaited; postprocessing goes through the same
        synchronous helper as ``retrieve``.
        """
        nodes = await self._retriever.aretrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle=query_bundle)

    def with_retriever(self, retriever: BaseRetriever) -> "RetrieverQueryEngine":
        """Return a new engine that uses ``retriever``.

        The new engine reuses this engine's response synthesizer, callback
        manager and node postprocessor list. It is always constructed as a
        plain ``RetrieverQueryEngine``.
        """
        return RetrieverQueryEngine(
            retriever=retriever,
            response_synthesizer=self._response_synthesizer,
            callback_manager=self.callback_manager,
            node_postprocessors=self._node_postprocessors,
        )

    def synthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Synthesize a response for the query from already-retrieved nodes.

        All arguments are forwarded unchanged to the response synthesizer.
        """
        return self._response_synthesizer.synthesize(
            query=query_bundle,
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    async def asynthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Asynchronous counterpart of ``synthesize``."""
        return await self._response_synthesizer.asynthesize(
            query=query_bundle,
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query."""
        # Retrieval and synthesis both run inside a single QUERY callback
        # event, which is closed with the response as its payload.
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes = self.retrieve(query_bundle)
            response = self._response_synthesizer.synthesize(
                query=query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query asynchronously."""
        # Same flow as _query, using the async retrieval and synthesis calls.
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes = await self.aretrieve(query_bundle)
            response = await self._response_synthesizer.asynthesize(
                query=query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    @property
    def retriever(self) -> BaseRetriever:
        """Get the retriever object."""
        return self._retriever