faiss_rag_enterprise/llama_index/query_engine/retriever_query_engine.py

201 lines
7.7 KiB
Python

from typing import Any, List, Optional, Sequence
from llama_index.bridge.pydantic import BaseModel
from llama_index.callbacks.base import CallbackManager
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.core.base_retriever import BaseRetriever
from llama_index.core.response.schema import RESPONSE_TYPE
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.prompts import BasePromptTemplate
from llama_index.prompts.mixin import PromptMixinType
from llama_index.response_synthesizers import (
BaseSynthesizer,
ResponseMode,
get_response_synthesizer,
)
from llama_index.schema import NodeWithScore, QueryBundle
from llama_index.service_context import ServiceContext
class RetrieverQueryEngine(BaseQueryEngine):
    """Retriever query engine.

    Retrieves nodes with a retriever, passes them through the configured node
    postprocessors in order, and hands the result to a response synthesizer.

    Args:
        retriever (BaseRetriever): A retriever object.
        response_synthesizer (Optional[BaseSynthesizer]): A BaseSynthesizer
            object. When omitted, one is created with
            ``get_response_synthesizer`` using the retriever's service context.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]): Node
            postprocessors applied, in list order, to retrieved nodes.
        callback_manager (Optional[CallbackManager]): A callback manager.
            When omitted, an empty ``CallbackManager`` is used.

    """

    def __init__(
        self,
        retriever: BaseRetriever,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        self._retriever = retriever
        # The fallback synthesizer receives ``callback_manager`` exactly as the
        # caller passed it (possibly None); the empty-manager default is only
        # applied further down.
        self._response_synthesizer = response_synthesizer or get_response_synthesizer(
            service_context=retriever.get_service_context(),
            callback_manager=callback_manager,
        )

        self._node_postprocessors = node_postprocessors or []
        callback_manager = callback_manager or CallbackManager([])
        # Point every postprocessor at the same callback manager the engine
        # itself is initialised with.
        for node_postprocessor in self._node_postprocessors:
            node_postprocessor.callback_manager = callback_manager

        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {"response_synthesizer": self._response_synthesizer}

    @classmethod
    def from_args(
        cls,
        retriever: BaseRetriever,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        service_context: Optional[ServiceContext] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        # response synthesizer args
        response_mode: ResponseMode = ResponseMode.COMPACT,
        text_qa_template: Optional[BasePromptTemplate] = None,
        refine_template: Optional[BasePromptTemplate] = None,
        summary_template: Optional[BasePromptTemplate] = None,
        simple_template: Optional[BasePromptTemplate] = None,
        output_cls: Optional[BaseModel] = None,
        use_async: bool = False,
        streaming: bool = False,
        # class-specific args
        **kwargs: Any,
    ) -> "RetrieverQueryEngine":
        """Initialize a RetrieverQueryEngine object.

        The response synthesizer arguments (``response_mode``, the templates,
        ``output_cls``, ``use_async``, ``streaming``) are only used when
        ``response_synthesizer`` is not supplied.

        Args:
            retriever (BaseRetriever): A retriever object.
            response_synthesizer (Optional[BaseSynthesizer]): A ready-made
                synthesizer. When omitted, one is created with
                ``get_response_synthesizer`` from the arguments below.
            service_context (Optional[ServiceContext]): A ServiceContext object.
                Its ``callback_manager`` is passed to the engine; without a
                service context an empty ``CallbackManager`` is used.
            node_postprocessors (Optional[List[BaseNodePostprocessor]]): A list
                of node postprocessors.
            response_mode (ResponseMode): A ResponseMode object.
            text_qa_template (Optional[BasePromptTemplate]): A BasePromptTemplate
                object.
            refine_template (Optional[BasePromptTemplate]): A BasePromptTemplate
                object.
            summary_template (Optional[BasePromptTemplate]): A BasePromptTemplate
                object.
            simple_template (Optional[BasePromptTemplate]): A BasePromptTemplate
                object.
            output_cls (Optional[BaseModel]): Forwarded to
                ``get_response_synthesizer`` as ``output_cls``.
            use_async (bool): Whether to use async.
            streaming (bool): Whether to use streaming.
            **kwargs (Any): Accepted but not used by this method.

        Returns:
            RetrieverQueryEngine: An instance of ``cls``.

        """
        response_synthesizer = response_synthesizer or get_response_synthesizer(
            service_context=service_context,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            summary_template=summary_template,
            simple_template=simple_template,
            response_mode=response_mode,
            output_cls=output_cls,
            use_async=use_async,
            streaming=streaming,
        )

        callback_manager = (
            service_context.callback_manager if service_context else CallbackManager([])
        )

        return cls(
            retriever=retriever,
            response_synthesizer=response_synthesizer,
            callback_manager=callback_manager,
            node_postprocessors=node_postprocessors,
        )

    def _apply_node_postprocessors(
        self, nodes: List[NodeWithScore], query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Run ``nodes`` through each node postprocessor in order.

        The output of one postprocessor is the input of the next; with no
        postprocessors configured, ``nodes`` is returned as given.
        """
        for node_postprocessor in self._node_postprocessors:
            nodes = node_postprocessor.postprocess_nodes(
                nodes, query_bundle=query_bundle
            )
        return nodes

    def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes for the query and apply the node postprocessors."""
        nodes = self._retriever.retrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle=query_bundle)

    async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async counterpart of ``retrieve``.

        Only the retriever call is awaited; the postprocessors are invoked
        through the same synchronous helper as in ``retrieve``.
        """
        nodes = await self._retriever.aretrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle=query_bundle)

    def with_retriever(self, retriever: BaseRetriever) -> "RetrieverQueryEngine":
        """Return a new engine that uses ``retriever``.

        The new engine reuses this engine's response synthesizer, callback
        manager and node postprocessor list. It is always constructed as a
        ``RetrieverQueryEngine``.
        """
        return RetrieverQueryEngine(
            retriever=retriever,
            response_synthesizer=self._response_synthesizer,
            callback_manager=self.callback_manager,
            node_postprocessors=self._node_postprocessors,
        )

    def synthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Synthesize a response for the query from the given nodes.

        Delegates to the response synthesizer, forwarding
        ``additional_source_nodes`` unchanged.
        """
        return self._response_synthesizer.synthesize(
            query=query_bundle,
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    async def asynthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Async counterpart of ``synthesize``."""
        return await self._response_synthesizer.asynthesize(
            query=query_bundle,
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query.

        Retrieval (including postprocessing) and synthesis both happen inside
        a single QUERY callback event, which is ended with the response as its
        payload.
        """
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes = self.retrieve(query_bundle)
            response = self._response_synthesizer.synthesize(
                query=query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query asynchronously.

        Same flow as ``_query``, awaiting ``aretrieve`` and the synthesizer's
        ``asynthesize`` inside the QUERY callback event.
        """
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes = await self.aretrieve(query_bundle)
            response = await self._response_synthesizer.asynthesize(
                query=query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    @property
    def retriever(self) -> BaseRetriever:
        """Get the retriever object."""
        return self._retriever