201 lines
7.7 KiB
Python
201 lines
7.7 KiB
Python
from typing import Any, List, Optional, Sequence
|
|
|
|
from llama_index.bridge.pydantic import BaseModel
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.callbacks.schema import CBEventType, EventPayload
|
|
from llama_index.core.base_query_engine import BaseQueryEngine
|
|
from llama_index.core.base_retriever import BaseRetriever
|
|
from llama_index.core.response.schema import RESPONSE_TYPE
|
|
from llama_index.postprocessor.types import BaseNodePostprocessor
|
|
from llama_index.prompts import BasePromptTemplate
|
|
from llama_index.prompts.mixin import PromptMixinType
|
|
from llama_index.response_synthesizers import (
|
|
BaseSynthesizer,
|
|
ResponseMode,
|
|
get_response_synthesizer,
|
|
)
|
|
from llama_index.schema import NodeWithScore, QueryBundle
|
|
from llama_index.service_context import ServiceContext
|
|
|
|
|
|
class RetrieverQueryEngine(BaseQueryEngine):
    """Retriever query engine.

    Answers a query in two stages:

    1. Fetch nodes with ``retriever``, then pass them through each of
       ``node_postprocessors`` in order.
    2. Hand the resulting nodes to ``response_synthesizer`` to build the
       response.

    Args:
        retriever (BaseRetriever): A retriever object.
        response_synthesizer (Optional[BaseSynthesizer]): A BaseSynthesizer
            object. When omitted, one is created with
            ``get_response_synthesizer`` from the retriever's service context.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]): Node
            postprocessors applied, in order, to the retrieved nodes.
        callback_manager (Optional[CallbackManager]): A callback manager.
            When omitted, the fallback order is:

            1. the retriever's service-context callback manager, if the
               retriever has a service context;
            2. otherwise, an empty ``CallbackManager``.
    """

    def __init__(
        self,
        retriever: BaseRetriever,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        self._retriever = retriever
        service_context = retriever.get_service_context()

        # Resolve the callback manager once, so that the engine, the default
        # response synthesizer and every node postprocessor share one manager.
        # Previously the synthesizer received the raw (possibly None) argument
        # while the engine fell back to a fresh empty manager, which split the
        # event trace. The fallback order matches ``from_args``.
        if callback_manager is None:
            callback_manager = (
                service_context.callback_manager
                if service_context
                else CallbackManager([])
            )

        self._response_synthesizer = response_synthesizer or get_response_synthesizer(
            service_context=service_context,
            callback_manager=callback_manager,
        )

        self._node_postprocessors = node_postprocessors or []
        # Postprocessors report their events to the engine's manager.
        for node_postprocessor in self._node_postprocessors:
            node_postprocessor.callback_manager = callback_manager

        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {"response_synthesizer": self._response_synthesizer}

    @classmethod
    def from_args(
        cls,
        retriever: BaseRetriever,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        service_context: Optional[ServiceContext] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        # response synthesizer args
        response_mode: ResponseMode = ResponseMode.COMPACT,
        text_qa_template: Optional[BasePromptTemplate] = None,
        refine_template: Optional[BasePromptTemplate] = None,
        summary_template: Optional[BasePromptTemplate] = None,
        simple_template: Optional[BasePromptTemplate] = None,
        output_cls: Optional[BaseModel] = None,
        use_async: bool = False,
        streaming: bool = False,
        # class-specific args
        **kwargs: Any,
    ) -> "RetrieverQueryEngine":
        """Initialize a RetrieverQueryEngine object.

        Args:
            retriever (BaseRetriever): A retriever object.
            response_synthesizer (Optional[BaseSynthesizer]): A pre-built
                synthesizer. When given, the response-synthesizer arguments
                below (``response_mode`` through ``streaming``) are not used.
            service_context (Optional[ServiceContext]): A ServiceContext
                object. It is used to build the default synthesizer. It is
                also the source of the engine's callback manager; an empty
                ``CallbackManager`` is used when it is omitted.
            node_postprocessors (Optional[List[BaseNodePostprocessor]]): A list
                of node postprocessors.
            response_mode (ResponseMode): A ResponseMode object.
            text_qa_template (Optional[BasePromptTemplate]): A
                BasePromptTemplate object.
            refine_template (Optional[BasePromptTemplate]): A
                BasePromptTemplate object.
            summary_template (Optional[BasePromptTemplate]): A
                BasePromptTemplate object.
            simple_template (Optional[BasePromptTemplate]): A
                BasePromptTemplate object.
            output_cls (Optional[BaseModel]): Forwarded to
                ``get_response_synthesizer`` (presumably a pydantic model
                class for structured output — confirm there).
            use_async (bool): Whether to use async.
            streaming (bool): Whether to use streaming.
            **kwargs: Accepted for call-site compatibility; currently ignored.

        Returns:
            RetrieverQueryEngine: The configured query engine.
        """
        response_synthesizer = response_synthesizer or get_response_synthesizer(
            service_context=service_context,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            summary_template=summary_template,
            simple_template=simple_template,
            response_mode=response_mode,
            output_cls=output_cls,
            use_async=use_async,
            streaming=streaming,
        )

        callback_manager = (
            service_context.callback_manager if service_context else CallbackManager([])
        )

        return cls(
            retriever=retriever,
            response_synthesizer=response_synthesizer,
            callback_manager=callback_manager,
            node_postprocessors=node_postprocessors,
        )

    def _apply_node_postprocessors(
        self, nodes: List[NodeWithScore], query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Run ``nodes`` through each node postprocessor, in order.

        Each postprocessor receives the previous one's output.
        """
        for node_postprocessor in self._node_postprocessors:
            nodes = node_postprocessor.postprocess_nodes(
                nodes, query_bundle=query_bundle
            )
        return nodes

    def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes for ``query_bundle`` and apply the postprocessors."""
        nodes = self._retriever.retrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle=query_bundle)

    async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async counterpart of ``retrieve``.

        Postprocessing itself still runs synchronously.
        """
        nodes = await self._retriever.aretrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle=query_bundle)

    def with_retriever(self, retriever: BaseRetriever) -> "RetrieverQueryEngine":
        """Return a new engine that uses ``retriever``.

        The new engine reuses this engine's synthesizer, callback manager and
        postprocessor list.
        """
        return RetrieverQueryEngine(
            retriever=retriever,
            response_synthesizer=self._response_synthesizer,
            callback_manager=self.callback_manager,
            node_postprocessors=self._node_postprocessors,
        )

    def synthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Synthesize a response to ``query_bundle`` from already-retrieved nodes."""
        return self._response_synthesizer.synthesize(
            query=query_bundle,
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    async def asynthesize(
        self,
        query_bundle: QueryBundle,
        nodes: List[NodeWithScore],
        additional_source_nodes: Optional[Sequence[NodeWithScore]] = None,
    ) -> RESPONSE_TYPE:
        """Async counterpart of ``synthesize``."""
        return await self._response_synthesizer.asynthesize(
            query=query_bundle,
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query.

        Retrieves and postprocesses nodes, then synthesizes a response. The
        work is wrapped in a QUERY callback event.
        """
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes = self.retrieve(query_bundle)
            response = self._response_synthesizer.synthesize(
                query=query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query asynchronously.

        This is the async counterpart of ``_query``.
        """
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes = await self.aretrieve(query_bundle)

            response = await self._response_synthesizer.asynthesize(
                query=query_bundle,
                nodes=nodes,
            )

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    @property
    def retriever(self) -> BaseRetriever:
        """Get the retriever object."""
        return self._retriever
|