124 lines
4.5 KiB
Python
124 lines
4.5 KiB
Python
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
from llama_index.callbacks.schema import CBEventType, EventPayload
|
|
from llama_index.core.base_query_engine import BaseQueryEngine
|
|
from llama_index.core.response.schema import RESPONSE_TYPE
|
|
from llama_index.indices.composability.graph import ComposableGraph
|
|
from llama_index.schema import IndexNode, NodeWithScore, QueryBundle, TextNode
|
|
|
|
|
|
class ComposableGraphQueryEngine(BaseQueryEngine):
    """Composable graph query engine.

    This query engine can operate over a ComposableGraph.
    It can take in custom query engines for its sub-indices.

    Args:
        graph (ComposableGraph): A ComposableGraph object.
        custom_query_engines (Optional[Dict[str, BaseQueryEngine]]): A dictionary of
            custom query engines, keyed by index id. An index whose id is a key
            here is queried with the given engine instead of one built via
            ``graph.get_index(index_id).as_query_engine(**kwargs)``.
        recursive (bool): Whether to recursively query the graph. When True,
            every retrieved ``IndexNode`` is replaced by the response obtained
            by querying the index it refers to.
        **kwargs: additional arguments to be passed to the underlying index query
            engine (forwarded to ``as_query_engine`` for indices that have no
            custom engine).

    """

    def __init__(
        self,
        graph: ComposableGraph,
        custom_query_engines: Optional[Dict[str, BaseQueryEngine]] = None,
        recursive: bool = True,
        **kwargs: Any
    ) -> None:
        """Init params."""
        self._graph = graph
        self._custom_query_engines = custom_query_engines or {}
        self._kwargs = kwargs

        # additional configs
        self._recursive = recursive
        # This engine reports its QUERY/RETRIEVE events to the callback
        # manager of the graph's service context.
        callback_manager = self._graph.service_context.callback_manager
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> Dict[str, Any]:
        """Get prompt modules."""
        # This engine defines no prompts of its own.
        return {}

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Asynchronously query the graph, starting from its root index."""
        # NOTE(review): this delegates to the synchronous traversal, so the
        # retrieve/synthesize calls run (and block) inside this coroutine.
        return self._query_index(query_bundle, index_id=None, level=0)

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Query the graph, starting from its root index."""
        return self._query_index(query_bundle, index_id=None, level=0)

    def _query_index(
        self,
        query_bundle: QueryBundle,
        index_id: Optional[str] = None,
        level: int = 0,
    ) -> RESPONSE_TYPE:
        """Query a single index.

        Args:
            query_bundle: The query to run against the index.
            index_id: Id of the graph index to query. ``None`` selects the
                graph's root index.
            level: Recursion depth of this call (0 for the root query); it is
                passed on to ``_fetch_recursive_nodes``.

        Returns:
            The response synthesized by the query engine of ``index_id``.
        """
        # Only ``None`` means "root". A falsy but present id (e.g. "") must not
        # fall back to the root: an IndexNode carrying such an id would
        # re-query the root index and could recurse without bound.
        if index_id is None:
            index_id = self._graph.root_id

        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            # Prefer a caller-supplied engine for this index; otherwise build a
            # default engine from the index using the engine-wide kwargs.
            query_engine = self._custom_query_engines.get(index_id)
            if query_engine is None:
                query_engine = self._graph.get_index(index_id).as_query_engine(
                    **self._kwargs
                )

            with self.callback_manager.event(
                CBEventType.RETRIEVE,
                payload={EventPayload.QUERY_STR: query_bundle.query_str},
            ) as retrieve_event:
                nodes = query_engine.retrieve(query_bundle)
                retrieve_event.on_end(payload={EventPayload.NODES: nodes})

            if self._recursive:
                # Replace each IndexNode with the answer from the index it
                # refers to. The sub-queries' own source nodes are handed to
                # synthesize as additional source nodes.
                nodes_for_synthesis: List[NodeWithScore] = []
                additional_source_nodes: List[NodeWithScore] = []
                for node_with_score in nodes:
                    node_with_score, source_nodes = self._fetch_recursive_nodes(
                        node_with_score, query_bundle, level
                    )
                    nodes_for_synthesis.append(node_with_score)
                    additional_source_nodes.extend(source_nodes)
                response = query_engine.synthesize(
                    query_bundle, nodes_for_synthesis, additional_source_nodes
                )
            else:
                response = query_engine.synthesize(query_bundle, nodes)

            query_event.on_end(payload={EventPayload.RESPONSE: response})

        return response

    def _fetch_recursive_nodes(
        self,
        node_with_score: NodeWithScore,
        query_bundle: QueryBundle,
        level: int,
    ) -> Tuple[NodeWithScore, List[NodeWithScore]]:
        """Fetch nodes.

        Uses existing node if it's not an index node.
        Otherwise fetch response from corresponding index.

        Args:
            node_with_score: A node retrieved from the current index.
            query_bundle: The query, reused for the sub-index query.
            level: Recursion depth of the caller; the sub-query runs at
                ``level + 1``.

        Returns:
            A ``(node, source_nodes)`` pair. For a non-index node this is the
            input node unchanged and an empty list. For an ``IndexNode`` it is
            a new ``TextNode`` holding the text of the sub-index response
            (carrying over the original score), plus that response's
            ``source_nodes``.
        """
        if not isinstance(node_with_score.node, IndexNode):
            return node_with_score, []

        index_node = node_with_score.node
        # recursive call into the index this node refers to
        response = self._query_index(query_bundle, index_node.index_id, level + 1)

        new_node = TextNode(text=str(response))
        new_node_with_score = NodeWithScore(
            node=new_node, score=node_with_score.score
        )
        return new_node_with_score, response.source_nodes
|