181 lines
6.9 KiB
Python
181 lines
6.9 KiB
Python
# Auto Merging Retriever
|
|
|
|
import logging
|
|
from collections import defaultdict
|
|
from typing import Dict, List, Optional, Tuple, cast
|
|
|
|
from llama_index.callbacks.base import CallbackManager
|
|
from llama_index.core.base_retriever import BaseRetriever
|
|
from llama_index.indices.query.schema import QueryBundle
|
|
from llama_index.indices.utils import truncate_text
|
|
from llama_index.indices.vector_store.retrievers.retriever import VectorIndexRetriever
|
|
from llama_index.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle
|
|
from llama_index.storage.storage_context import StorageContext
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AutoMergingRetriever(BaseRetriever):
|
|
"""This retriever will try to merge context into parent context.
|
|
|
|
The retriever first retrieves chunks from a vector store.
|
|
Then, it will try to merge the chunks into a single context.
|
|
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
vector_retriever: VectorIndexRetriever,
|
|
storage_context: StorageContext,
|
|
simple_ratio_thresh: float = 0.5,
|
|
verbose: bool = False,
|
|
callback_manager: Optional[CallbackManager] = None,
|
|
object_map: Optional[dict] = None,
|
|
objects: Optional[List[IndexNode]] = None,
|
|
) -> None:
|
|
"""Init params."""
|
|
self._vector_retriever = vector_retriever
|
|
self._storage_context = storage_context
|
|
self._simple_ratio_thresh = simple_ratio_thresh
|
|
super().__init__(
|
|
callback_manager=callback_manager,
|
|
object_map=object_map,
|
|
objects=objects,
|
|
verbose=verbose,
|
|
)
|
|
|
|
def _get_parents_and_merge(
|
|
self, nodes: List[NodeWithScore]
|
|
) -> Tuple[List[NodeWithScore], bool]:
|
|
"""Get parents and merge nodes."""
|
|
# retrieve all parent nodes
|
|
parent_nodes: Dict[str, BaseNode] = {}
|
|
parent_cur_children_dict: Dict[str, List[NodeWithScore]] = defaultdict(list)
|
|
for node in nodes:
|
|
if node.node.parent_node is None:
|
|
continue
|
|
parent_node_info = node.node.parent_node
|
|
|
|
# Fetch actual parent node if doesn't exist in `parent_nodes` cache yet
|
|
parent_node_id = parent_node_info.node_id
|
|
if parent_node_id not in parent_nodes:
|
|
parent_node = self._storage_context.docstore.get_document(
|
|
parent_node_id
|
|
)
|
|
parent_nodes[parent_node_id] = cast(BaseNode, parent_node)
|
|
|
|
# add reference to child from parent
|
|
parent_cur_children_dict[parent_node_id].append(node)
|
|
|
|
# compute ratios and "merge" nodes
|
|
# merging: delete some children nodes, add some parent nodes
|
|
node_ids_to_delete = set()
|
|
nodes_to_add: Dict[str, BaseNode] = {}
|
|
for parent_node_id, parent_node in parent_nodes.items():
|
|
parent_child_nodes = parent_node.child_nodes
|
|
parent_num_children = len(parent_child_nodes) if parent_child_nodes else 1
|
|
parent_cur_children = parent_cur_children_dict[parent_node_id]
|
|
ratio = len(parent_cur_children) / parent_num_children
|
|
|
|
# if ratio is high enough, merge
|
|
if ratio > self._simple_ratio_thresh:
|
|
node_ids_to_delete.update(
|
|
set({n.node.node_id for n in parent_cur_children})
|
|
)
|
|
|
|
parent_node_text = truncate_text(parent_node.text, 100)
|
|
info_str = (
|
|
f"> Merging {len(parent_cur_children)} nodes into parent node.\n"
|
|
f"> Parent node id: {parent_node_id}.\n"
|
|
f"> Parent node text: {parent_node_text}\n"
|
|
)
|
|
logger.info(info_str)
|
|
if self._verbose:
|
|
print(info_str)
|
|
|
|
# add parent node
|
|
# can try averaging score across embeddings for now
|
|
|
|
avg_score = sum(
|
|
[n.get_score() or 0.0 for n in parent_cur_children]
|
|
) / len(parent_cur_children)
|
|
parent_node_with_score = NodeWithScore(
|
|
node=parent_node, score=avg_score
|
|
)
|
|
nodes_to_add[parent_node_id] = parent_node_with_score
|
|
|
|
# delete old child nodes, add new parent nodes
|
|
new_nodes = [n for n in nodes if n.node.node_id not in node_ids_to_delete]
|
|
# add parent nodes
|
|
new_nodes.extend(list(nodes_to_add.values()))
|
|
|
|
is_changed = len(node_ids_to_delete) > 0
|
|
|
|
return new_nodes, is_changed
|
|
|
|
def _fill_in_nodes(
|
|
self, nodes: List[NodeWithScore]
|
|
) -> Tuple[List[NodeWithScore], bool]:
|
|
"""Fill in nodes."""
|
|
new_nodes = []
|
|
is_changed = False
|
|
for idx, node in enumerate(nodes):
|
|
new_nodes.append(node)
|
|
if idx >= len(nodes) - 1:
|
|
continue
|
|
|
|
cur_node = cast(BaseNode, node.node)
|
|
# if there's a node in the middle, add that to the queue
|
|
if (
|
|
cur_node.next_node is not None
|
|
and cur_node.next_node == nodes[idx + 1].node.prev_node
|
|
):
|
|
is_changed = True
|
|
next_node = self._storage_context.docstore.get_document(
|
|
cur_node.next_node.node_id
|
|
)
|
|
next_node = cast(BaseNode, next_node)
|
|
|
|
next_node_text = truncate_text(next_node.get_text(), 100)
|
|
info_str = (
|
|
f"> Filling in node. Node id: {cur_node.next_node.node_id}"
|
|
f"> Node text: {next_node_text}\n"
|
|
)
|
|
logger.info(info_str)
|
|
if self._verbose:
|
|
print(info_str)
|
|
|
|
# set score to be average of current node and next node
|
|
avg_score = (node.get_score() + nodes[idx + 1].get_score()) / 2
|
|
new_nodes.append(NodeWithScore(node=next_node, score=avg_score))
|
|
return new_nodes, is_changed
|
|
|
|
def _try_merging(
|
|
self, nodes: List[NodeWithScore]
|
|
) -> Tuple[List[NodeWithScore], bool]:
|
|
"""Try different ways to merge nodes."""
|
|
# first try filling in nodes
|
|
nodes, is_changed_0 = self._fill_in_nodes(nodes)
|
|
# then try merging nodes
|
|
nodes, is_changed_1 = self._get_parents_and_merge(nodes)
|
|
return nodes, is_changed_0 or is_changed_1
|
|
|
|
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
|
|
"""Retrieve nodes given query.
|
|
|
|
Implemented by the user.
|
|
|
|
"""
|
|
initial_nodes = self._vector_retriever.retrieve(query_bundle)
|
|
|
|
cur_nodes, is_changed = self._try_merging(initial_nodes)
|
|
# cur_nodes, is_changed = self._get_parents_and_merge(initial_nodes)
|
|
while is_changed:
|
|
cur_nodes, is_changed = self._try_merging(cur_nodes)
|
|
# cur_nodes, is_changed = self._get_parents_and_merge(cur_nodes)
|
|
|
|
# sort by similarity
|
|
cur_nodes.sort(key=lambda x: x.get_score(), reverse=True)
|
|
|
|
return cur_nodes
|