faiss_rag_enterprise/llama_index/retrievers/router_retriever.py

143 lines
5.7 KiB
Python

"""Router retriever."""
import asyncio
import logging
from typing import List, Optional, Sequence
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.core.base_retriever import BaseRetriever
from llama_index.core.base_selector import BaseSelector
from llama_index.prompts.mixin import PromptMixinType
from llama_index.schema import IndexNode, NodeWithScore, QueryBundle
from llama_index.selectors.utils import get_selector_from_context
from llama_index.service_context import ServiceContext
from llama_index.tools.retriever_tool import RetrieverTool
# Module-level logger; the router uses it to report which retriever(s) were selected.
logger = logging.getLogger(__name__)
class RouterRetriever(BaseRetriever):
    """Router retriever.

    Selects one (or multiple) out of several candidate retrievers to execute a query.

    Args:
        selector (BaseSelector): A selector that chooses one out of many options based
            on each candidate's metadata and query.
        retriever_tools (Sequence[RetrieverTool]): A sequence of candidate
            retrievers. They must be wrapped as tools to expose metadata to
            the selector.
        service_context (Optional[ServiceContext]): A service context. When
            omitted, ``ServiceContext.from_defaults()`` is used.
        objects (Optional[List[IndexNode]]): Passed through unchanged to
            ``BaseRetriever.__init__``.
        object_map (Optional[dict]): Passed through unchanged to
            ``BaseRetriever.__init__``.
        verbose (bool): Passed through unchanged to ``BaseRetriever.__init__``.

    """

    def __init__(
        self,
        selector: BaseSelector,
        retriever_tools: Sequence[RetrieverTool],
        service_context: Optional[ServiceContext] = None,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
    ) -> None:
        # The service context is resolved first because its callback manager
        # is handed to the base-class constructor below.
        self.service_context = service_context or ServiceContext.from_defaults()
        self._selector = selector
        # Retrievers and their metadata are kept as parallel lists: the index
        # chosen by the selector over ``_metadatas`` addresses ``_retrievers``.
        self._retrievers: List[BaseRetriever] = [x.retriever for x in retriever_tools]
        self._metadatas = [x.metadata for x in retriever_tools]

        super().__init__(
            callback_manager=self.service_context.callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        # NOTE: don't include tools for now
        return {"selector": self._selector}

    @classmethod
    def from_defaults(
        cls,
        retriever_tools: Sequence[RetrieverTool],
        service_context: Optional[ServiceContext] = None,
        selector: Optional[BaseSelector] = None,
        select_multi: bool = False,
    ) -> "RouterRetriever":
        """Build a router, deriving the selector from the service context if needed.

        Args:
            retriever_tools: Candidate retrievers wrapped as tools.
            service_context: Service context to use; defaults to
                ``ServiceContext.from_defaults()``.
            selector: Selector to use. When omitted, one is obtained from
                ``get_selector_from_context`` with ``is_multi=select_multi``.
            select_multi: Only consulted when ``selector`` is omitted.

        Returns:
            A ``RouterRetriever`` whose selector and service context come from
            the same ``ServiceContext`` instance.
        """
        # Resolve the context exactly once so that the default selector and the
        # retriever share it, instead of each building its own default context.
        service_context = service_context or ServiceContext.from_defaults()
        selector = selector or get_selector_from_context(
            service_context, is_multi=select_multi
        )

        return cls(
            selector,
            retriever_tools,
            service_context=service_context,
        )

    def _pick_single_retriever(self, result) -> BaseRetriever:
        """Return the retriever addressed by a single-choice selector result.

        Shared by the sync and async retrieval paths.

        Raises:
            ValueError: If reading ``result.ind`` or ``result.reason`` raises
                ``ValueError``; it is re-raised with the message
                "Failed to select retriever" and the original as its cause.
        """
        try:
            selected_retriever = self._retrievers[result.ind]
            logger.info("Selecting retriever %s: %s.", result.ind, result.reason)
        except ValueError as e:
            raise ValueError("Failed to select retriever") from e
        return selected_retriever

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Route the query to the selected retriever(s) and return their nodes.

        When the selector picks several retrievers they are queried one after
        another and the results are merged by node id.

        Raises:
            ValueError: Propagated from the single-selection path when the
                selector result cannot provide a single choice.
        """
        with self.callback_manager.event(
            CBEventType.RETRIEVE,
            payload={EventPayload.QUERY_STR: query_bundle.query_str},
        ) as query_event:
            result = self._selector.select(self._metadatas, query_bundle)

            if len(result.inds) > 1:
                # Keyed by node id so a node returned by several retrievers
                # appears once; the entry from the later retriever wins.
                retrieved_results = {}
                for i, engine_ind in enumerate(result.inds):
                    logger.info(
                        "Selecting retriever %s: %s.", engine_ind, result.reasons[i]
                    )
                    selected_retriever = self._retrievers[engine_ind]
                    cur_results = selected_retriever.retrieve(query_bundle)
                    retrieved_results.update({n.node.node_id: n for n in cur_results})
            else:
                selected_retriever = self._pick_single_retriever(result)
                cur_results = selected_retriever.retrieve(query_bundle)
                retrieved_results = {n.node.node_id: n for n in cur_results}

            # Hand callbacks a concrete list rather than a live dict view.
            query_event.on_end(
                payload={EventPayload.NODES: list(retrieved_results.values())}
            )

        return list(retrieved_results.values())

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async counterpart of ``_retrieve``.

        When the selector picks several retrievers their ``aretrieve`` calls
        are awaited together via ``asyncio.gather`` and the results are merged
        by node id.

        Raises:
            ValueError: Propagated from the single-selection path when the
                selector result cannot provide a single choice.
        """
        with self.callback_manager.event(
            CBEventType.RETRIEVE,
            payload={EventPayload.QUERY_STR: query_bundle.query_str},
        ) as query_event:
            result = await self._selector.aselect(self._metadatas, query_bundle)

            if len(result.inds) > 1:
                retrieved_results = {}
                tasks = []
                for i, engine_ind in enumerate(result.inds):
                    logger.info(
                        "Selecting retriever %s: %s.", engine_ind, result.reasons[i]
                    )
                    selected_retriever = self._retrievers[engine_ind]
                    tasks.append(selected_retriever.aretrieve(query_bundle))

                # gather returns results in task order, so flattening keeps the
                # same "later retriever wins" merge order as the sync path.
                results_of_results = await asyncio.gather(*tasks)
                cur_results = [
                    item for sublist in results_of_results for item in sublist
                ]
                retrieved_results.update({n.node.node_id: n for n in cur_results})
            else:
                selected_retriever = self._pick_single_retriever(result)
                cur_results = await selected_retriever.aretrieve(query_bundle)
                retrieved_results = {n.node.node_id: n for n in cur_results}

            # Hand callbacks a concrete list rather than a live dict view.
            query_event.on_end(
                payload={EventPayload.NODES: list(retrieved_results.values())}
            )

        return list(retrieved_results.values())