132 lines
5.3 KiB
Python
132 lines
5.3 KiB
Python
from typing import Any, List, Optional
|
|
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.callbacks.schema import CBEventType, EventPayload
|
|
from llama_index.core.base_query_engine import BaseQueryEngine
|
|
from llama_index.core.base_retriever import BaseRetriever
|
|
from llama_index.core.response.schema import RESPONSE_TYPE, Response
|
|
from llama_index.indices.managed.vectara.retriever import VectaraRetriever
|
|
from llama_index.postprocessor.types import BaseNodePostprocessor
|
|
from llama_index.prompts.mixin import PromptDictType, PromptMixinType
|
|
from llama_index.schema import NodeWithScore, QueryBundle
|
|
|
|
|
|
class VectaraQueryEngine(BaseQueryEngine):
    """Retriever query engine for Vectara.

    Args:
        retriever (VectaraRetriever): A retriever object.
        summary_enabled: when True, the summary options below are forwarded
            to the retriever's Vectara query call; when False, none of them
            are sent.
        node_postprocessors: postprocessors applied, in order, to the nodes
            returned by ``retrieve`` / ``aretrieve``.
        callback_manager: callback manager handed to the base query engine.
        summary_response_lang: response language for summary (ISO 639-2 code)
        summary_num_results: number of results to use for summary generation.
        summary_prompt_name: name of the prompt to use for summary generation.
    """

    def __init__(
        self,
        retriever: VectaraRetriever,
        summary_enabled: bool = False,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        callback_manager: Optional[CallbackManager] = None,
        summary_response_lang: str = "eng",
        summary_num_results: int = 5,
        summary_prompt_name: str = "vectara-experimental-summary-ext-2023-10-23-small",
    ) -> None:
        self._retriever = retriever
        self._summary_enabled = summary_enabled
        self._summary_response_lang = summary_response_lang
        self._summary_num_results = summary_num_results
        self._summary_prompt_name = summary_prompt_name
        # Normalise ``None`` to an empty list so iteration never needs a guard.
        self._node_postprocessors = node_postprocessors or []
        super().__init__(callback_manager=callback_manager)

    @classmethod
    def from_args(
        cls,
        retriever: VectaraRetriever,
        summary_enabled: bool = False,
        summary_response_lang: str = "eng",
        summary_num_results: int = 5,
        summary_prompt_name: str = "vectara-experimental-summary-ext-2023-10-23-small",
        **kwargs: Any,
    ) -> "VectaraQueryEngine":
        """Initialize a VectaraQueryEngine object.

        Args:
            retriever (VectaraRetriever): A Vectara retriever object.
            summary_enabled: whether summary options are sent with the query.
            summary_response_lang: response language for summary (ISO 639-2 code)
            summary_num_results: number of results to use for summary generation.
            summary_prompt_name: name of the prompt to use for summary generation.
            **kwargs: only ``node_postprocessors`` and ``callback_manager``
                are read and forwarded to the constructor; any other key is
                ignored.

        Returns:
            A new ``VectaraQueryEngine`` (or subclass) instance.
        """
        # Forward only the constructor options that have no explicit
        # parameter here. Unrelated keys remain ignored so callers that pass
        # a shared kwargs dict keep working.
        return cls(
            retriever=retriever,
            summary_enabled=summary_enabled,
            node_postprocessors=kwargs.get("node_postprocessors"),
            callback_manager=kwargs.get("callback_manager"),
            summary_response_lang=summary_response_lang,
            summary_num_results=summary_num_results,
            summary_prompt_name=summary_prompt_name,
        )

    def _apply_node_postprocessors(
        self, nodes: List[NodeWithScore], query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Run every configured postprocessor over ``nodes``, in order.

        Each postprocessor receives the output of the previous one.
        """
        for node_postprocessor in self._node_postprocessors:
            nodes = node_postprocessor.postprocess_nodes(
                nodes, query_bundle=query_bundle
            )
        return nodes

    def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes for the query and apply the node postprocessors."""
        nodes = self._retriever.retrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle=query_bundle)

    async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Asynchronously retrieve nodes and apply the node postprocessors."""
        nodes = await self._retriever.aretrieve(query_bundle)
        return self._apply_node_postprocessors(nodes, query_bundle=query_bundle)

    def with_retriever(self, retriever: VectaraRetriever) -> "VectaraQueryEngine":
        """Return a new engine that uses ``retriever`` and this engine's settings.

        The summary options, node postprocessors and callback manager of the
        current engine are carried over to the new one.
        """
        return VectaraQueryEngine(
            retriever=retriever,
            summary_enabled=self._summary_enabled,
            # Copy the list so the two engines do not share mutable state.
            node_postprocessors=list(self._node_postprocessors),
            callback_manager=self.callback_manager,
            summary_response_lang=self._summary_response_lang,
            summary_num_results=self._summary_num_results,
            summary_prompt_name=self._summary_prompt_name,
        )

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query.

        Calls the retriever's ``_vectara_query`` inside a ``QUERY`` callback
        event and wraps the returned nodes and response in a ``Response``.
        The node postprocessors are not applied on this path.
        """
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            # Summary options are only sent when summarisation is enabled.
            kwargs = (
                {
                    "summary_response_lang": self._summary_response_lang,
                    "summary_num_results": self._summary_num_results,
                    "summary_prompt_name": self._summary_prompt_name,
                }
                if self._summary_enabled
                else {}
            )
            nodes, response = self._retriever._vectara_query(query_bundle, **kwargs)
            query_event.on_end(payload={EventPayload.RESPONSE: response})
        return Response(response=response, source_nodes=nodes)

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query from async code.

        This calls the synchronous ``_query`` directly; nothing is awaited.
        """
        return self._query(query_bundle)

    @property
    def retriever(self) -> BaseRetriever:
        """Get the retriever object."""
        return self._retriever

    # required for PromptMixin
    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {}

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt modules."""
        return {}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
|