143 lines
5.7 KiB
Python
143 lines
5.7 KiB
Python
"""Router retriever."""
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import List, Optional, Sequence
|
|
|
|
from llama_index.callbacks.schema import CBEventType, EventPayload
|
|
from llama_index.core.base_retriever import BaseRetriever
|
|
from llama_index.core.base_selector import BaseSelector
|
|
from llama_index.prompts.mixin import PromptMixinType
|
|
from llama_index.schema import IndexNode, NodeWithScore, QueryBundle
|
|
from llama_index.selectors.utils import get_selector_from_context
|
|
from llama_index.service_context import ServiceContext
|
|
from llama_index.tools.retriever_tool import RetrieverTool
|
|
|
|
# Module-level logger; the router reports which retriever(s) it selects, and why, at INFO level.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RouterRetriever(BaseRetriever):
    """Router retriever.

    Selects one (or multiple) out of several candidate retrievers to execute a query.

    Args:
        selector (BaseSelector): A selector that chooses one out of many options based
            on each candidate's metadata and query.
        retriever_tools (Sequence[RetrieverTool]): A sequence of candidate
            retrievers. They must be wrapped as tools to expose metadata to
            the selector.
        service_context (Optional[ServiceContext]): A service context. When
            omitted, ``ServiceContext.from_defaults()`` is used. Its callback
            manager is handed to ``BaseRetriever``.
        objects (Optional[List[IndexNode]]): Forwarded unchanged to
            ``BaseRetriever.__init__``.
        object_map (Optional[dict]): Forwarded unchanged to
            ``BaseRetriever.__init__``.
        verbose (bool): Forwarded unchanged to ``BaseRetriever.__init__``.

    """

    def __init__(
        self,
        selector: BaseSelector,
        retriever_tools: Sequence[RetrieverTool],
        service_context: Optional[ServiceContext] = None,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
    ) -> None:
        self.service_context = service_context or ServiceContext.from_defaults()
        self._selector = selector
        # Retrievers and their metadata are kept as parallel lists: the selector
        # sees only the metadata and answers with positions into these lists.
        self._retrievers: List[BaseRetriever] = [x.retriever for x in retriever_tools]
        self._metadatas = [x.metadata for x in retriever_tools]

        super().__init__(
            callback_manager=self.service_context.callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        # NOTE: don't include tools for now
        return {"selector": self._selector}

    @classmethod
    def from_defaults(
        cls,
        retriever_tools: Sequence[RetrieverTool],
        service_context: Optional[ServiceContext] = None,
        selector: Optional[BaseSelector] = None,
        select_multi: bool = False,
        *,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
    ) -> "RouterRetriever":
        """Build a router, deriving a selector from the service context if needed.

        Args:
            retriever_tools: Candidate retrievers wrapped as tools.
            service_context: Service context to use. A default one is built
                once here and shared by the selector and the router.
            selector: Explicit selector. When omitted, one is obtained from
                ``get_selector_from_context``.
            select_multi: Passed as ``is_multi`` when a selector is derived.
            objects: Forwarded to the constructor.
            object_map: Forwarded to the constructor.
            verbose: Forwarded to the constructor.

        Returns:
            A configured ``RouterRetriever``.
        """
        # Build the default context once, instead of once for the selector and
        # again inside ``__init__``.
        service_context = service_context or ServiceContext.from_defaults()
        selector = selector or get_selector_from_context(
            service_context, is_multi=select_multi
        )

        return cls(
            selector,
            retriever_tools,
            service_context=service_context,
            objects=objects,
            object_map=object_map,
            verbose=verbose,
        )

    def _select_retrievers(self, result) -> List[BaseRetriever]:
        """Map a selector result onto the chosen retrievers, logging each choice.

        Every index is validated before any retriever is run, so a bad
        selection fails fast instead of after partial retrieval.

        Args:
            result: The object returned by ``select``/``aselect``. It is read
                via ``inds``/``reasons`` for multiple selections and via
                ``ind``/``reason`` otherwise.

        Returns:
            The selected retrievers, in selection order.

        Raises:
            ValueError: If no single selection can be read from ``result``, or
                if a selected index is outside the candidate range.
        """
        if len(result.inds) > 1:
            # ``inds`` and ``reasons`` are read as parallel sequences.
            choices = list(zip(result.inds, result.reasons))
        else:
            try:
                # Presumably ``.ind``/``.reason`` raise ValueError when there is
                # not exactly one selection (e.g. the selector picked nothing).
                choices = [(result.ind, result.reason)]
            except ValueError as e:
                raise ValueError("Failed to select retriever") from e

        num_retrievers = len(self._retrievers)
        selected: List[BaseRetriever] = []
        for ind, reason in choices:
            # Reject out-of-range and negative indices explicitly. Plain list
            # indexing would raise IndexError, or silently wrap around for
            # negative values.
            if not 0 <= ind < num_retrievers:
                raise ValueError(
                    f"Failed to select retriever: index {ind} is out of range "
                    f"for {num_retrievers} candidate retrievers"
                )
            logger.info("Selecting retriever %s: %s.", ind, reason)
            selected.append(self._retrievers[ind])
        return selected

    @staticmethod
    def _merge_results(result_lists) -> List[NodeWithScore]:
        """Flatten per-retriever results, de-duplicating by node id.

        When the same node id appears more than once, the last occurrence's
        ``NodeWithScore`` is kept, at the position of the first occurrence.
        """
        merged = {n.node.node_id: n for nodes in result_lists for n in nodes}
        return list(merged.values())

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Select retriever(s) for the query and return their merged nodes."""
        with self.callback_manager.event(
            CBEventType.RETRIEVE,
            payload={EventPayload.QUERY_STR: query_bundle.query_str},
        ) as query_event:
            result = self._selector.select(self._metadatas, query_bundle)
            retrievers = self._select_retrievers(result)
            nodes = self._merge_results(
                retriever.retrieve(query_bundle) for retriever in retrievers
            )
            # Hand listeners their own list rather than the one we return.
            query_event.on_end(payload={EventPayload.NODES: list(nodes)})

        return nodes

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async variant of ``_retrieve``: selected retrievers run concurrently."""
        with self.callback_manager.event(
            CBEventType.RETRIEVE,
            payload={EventPayload.QUERY_STR: query_bundle.query_str},
        ) as query_event:
            result = await self._selector.aselect(self._metadatas, query_bundle)
            retrievers = self._select_retrievers(result)
            # gather preserves argument order, so merge order matches selection order.
            results_of_results = await asyncio.gather(
                *(retriever.aretrieve(query_bundle) for retriever in retrievers)
            )
            nodes = self._merge_results(results_of_results)
            # Hand listeners their own list rather than the one we return.
            query_event.on_end(payload={EventPayload.NODES: list(nodes)})

        return nodes
|