import asyncio
from enum import Enum
from typing import Dict, List, Optional, Tuple, cast

from llama_index.async_utils import run_async_tasks
from llama_index.callbacks.base import CallbackManager
from llama_index.constants import DEFAULT_SIMILARITY_TOP_K
from llama_index.llms.utils import LLMType, resolve_llm
from llama_index.prompts import PromptTemplate
from llama_index.prompts.mixin import PromptDictType
from llama_index.retrievers import BaseRetriever
from llama_index.schema import IndexNode, NodeWithScore, QueryBundle

QUERY_GEN_PROMPT = (
    "You are a helpful assistant that generates multiple search queries based on a "
    "single input query. Generate {num_queries} search queries, one on each line, "
    "related to the following input query:\n"
    "Query: {query}\n"
    "Queries:\n"
)


class FUSION_MODES(str, Enum):
    """Enum for different fusion modes."""

    RECIPROCAL_RANK = "reciprocal_rerank"  # apply reciprocal rank fusion
    SIMPLE = "simple"  # simple re-ordering of results based on original scores


class QueryFusionRetriever(BaseRetriever):
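    """Retriever that fuses results from multiple retrievers and multiple
    LLM-generated queries.

    The input query is optionally expanded into several related queries, each
    query is run against every sub-retriever, and the per-query result lists
    are combined with either reciprocal rank fusion or a simple score-based
    re-sort.
    """
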
    def __init__(
        self,
        retrievers: List[BaseRetriever],
        llm: Optional[LLMType] = "default",
        query_gen_prompt: Optional[str] = None,
        mode: FUSION_MODES = FUSION_MODES.SIMPLE,
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        num_queries: int = 4,
        use_async: bool = True,
        verbose: bool = False,
        callback_manager: Optional[CallbackManager] = None,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
    ) -> None:
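        """Initialize params.

        `num_queries` controls query expansion: the LLM is prompted for
        `num_queries - 1` generated queries, and setting `num_queries=1`
        skips generation so only the original query is used.
        """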
        self.num_queries = num_queries
        self.query_gen_prompt = query_gen_prompt or QUERY_GEN_PROMPT
        self.similarity_top_k = similarity_top_k
        self.mode = mode
        self.use_async = use_async

        self._retrievers = retrievers
        self._llm = resolve_llm(llm)
        super().__init__(
            callback_manager=callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"query_gen_prompt": PromptTemplate(self.query_gen_prompt)}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "query_gen_prompt" in prompts:
            self.query_gen_prompt = cast(
                PromptTemplate, prompts["query_gen_prompt"]
            ).template

    def _get_queries(self, original_query: str) -> List[str]:
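        # Prompt the LLM for `num_queries - 1` related search queries based on
        # the original query.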
        prompt_str = self.query_gen_prompt.format(
            num_queries=self.num_queries - 1,
            query=original_query,
        )
        response = self._llm.complete(prompt_str)

        # assume the LLM properly put each query on a newline
        queries = response.text.split("\n")
        if self._verbose:
            queries_str = "\n".join(queries)
            print(f"Generated queries:\n{queries_str}")
        return queries

    def _reciprocal_rerank_fusion(
        self, results: Dict[Tuple[str, int], List[NodeWithScore]]
    ) -> List[NodeWithScore]:
        """Apply reciprocal rank fusion.

        The original paper uses k=60 for best results:
        https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
        """
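        # Each node's fused score is the sum, over every result list it appears
        # in, of 1 / (rank + k), where rank is its 0-based position in that list.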
        k = 60.0  # `k` is a parameter used to control the impact of outlier rankings.
        fused_scores = {}
        text_to_node = {}

        # compute reciprocal rank scores
        for nodes_with_scores in results.values():
            for rank, node_with_score in enumerate(
                sorted(nodes_with_scores, key=lambda x: x.score or 0.0, reverse=True)
            ):
                text = node_with_score.node.get_content()
                text_to_node[text] = node_with_score
                if text not in fused_scores:
                    fused_scores[text] = 0.0
                fused_scores[text] += 1.0 / (rank + k)

        # sort results
        reranked_results = dict(
            sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
        )

        # adjust node scores
        reranked_nodes: List[NodeWithScore] = []
        for text, score in reranked_results.items():
            reranked_nodes.append(text_to_node[text])
            reranked_nodes[-1].score = score

        return reranked_nodes

    def _simple_fusion(
        self, results: Dict[Tuple[str, int], List[NodeWithScore]]
    ) -> List[NodeWithScore]:
        """Apply simple fusion."""
        # Use a dict to de-duplicate nodes
        all_nodes: Dict[str, NodeWithScore] = {}
        for nodes_with_scores in results.values():
            for node_with_score in nodes_with_scores:
                text = node_with_score.node.get_content()
                all_nodes[text] = node_with_score

        return sorted(all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True)

    def _run_nested_async_queries(
        self, queries: List[str]
    ) -> Dict[Tuple[str, int], List[NodeWithScore]]:
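        # Fan out every (query, retriever) pair concurrently from synchronous
        # code, using run_async_tasks to drive the event loop.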
        tasks, task_queries = [], []
        for query in queries:
            for i, retriever in enumerate(self._retrievers):
                tasks.append(retriever.aretrieve(query))
                task_queries.append(query)

        task_results = run_async_tasks(tasks)

        results = {}
        for i, (query, query_result) in enumerate(zip(task_queries, task_results)):
            results[(query, i)] = query_result

        return results

    async def _run_async_queries(
        self, queries: List[str]
    ) -> Dict[Tuple[str, int], List[NodeWithScore]]:
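        # Same fan-out as _run_nested_async_queries, but awaited directly with
        # asyncio.gather since we are already inside an event loop.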
        tasks, task_queries = [], []
        for query in queries:
            for i, retriever in enumerate(self._retrievers):
                tasks.append(retriever.aretrieve(query))
                task_queries.append(query)

        task_results = await asyncio.gather(*tasks)

        results = {}
        for i, (query, query_result) in enumerate(zip(task_queries, task_results)):
            results[(query, i)] = query_result

        return results

    def _run_sync_queries(
        self, queries: List[str]
    ) -> Dict[Tuple[str, int], List[NodeWithScore]]:
        results = {}
        for query in queries:
            for i, retriever in enumerate(self._retrievers):
                results[(query, i)] = retriever.retrieve(query)

        return results

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
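        # Expand the query (when enabled), run every query against every
        # retriever, then fuse the per-query result lists into a single ranking.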
        if self.num_queries > 1:
            queries = self._get_queries(query_bundle.query_str)
        else:
            queries = [query_bundle.query_str]

        if self.use_async:
            results = self._run_nested_async_queries(queries)
        else:
            results = self._run_sync_queries(queries)

        if self.mode == FUSION_MODES.RECIPROCAL_RANK:
            return self._reciprocal_rerank_fusion(results)[: self.similarity_top_k]
        elif self.mode == FUSION_MODES.SIMPLE:
            return self._simple_fusion(results)[: self.similarity_top_k]
        else:
            raise ValueError(f"Invalid fusion mode: {self.mode}")

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        if self.num_queries > 1:
            queries = self._get_queries(query_bundle.query_str)
        else:
            queries = [query_bundle.query_str]

        results = await self._run_async_queries(queries)

        if self.mode == FUSION_MODES.RECIPROCAL_RANK:
            return self._reciprocal_rerank_fusion(results)[: self.similarity_top_k]
        elif self.mode == FUSION_MODES.SIMPLE:
            return self._simple_fusion(results)[: self.similarity_top_k]
        else:
            raise ValueError(f"Invalid fusion mode: {self.mode}")
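
# Example usage (a minimal sketch; `vector_retriever` and `bm25_retriever` are
# placeholder names for retrievers built elsewhere, not part of this module):
#
#     fusion_retriever = QueryFusionRetriever(
#         [vector_retriever, bm25_retriever],
#         mode=FUSION_MODES.RECIPROCAL_RANK,
#         similarity_top_k=5,
#         num_queries=4,
#         use_async=True,
#         verbose=True,
#     )
#     nodes = fusion_retriever.retrieve("query string here")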