faiss_rag_enterprise/llama_index/retrievers/fusion_retriever.py

210 lines
7.6 KiB
Python

import asyncio
from enum import Enum
from typing import Dict, List, Optional, Tuple, cast
from llama_index.async_utils import run_async_tasks
from llama_index.callbacks.base import CallbackManager
from llama_index.constants import DEFAULT_SIMILARITY_TOP_K
from llama_index.llms.utils import LLMType, resolve_llm
from llama_index.prompts import PromptTemplate
from llama_index.prompts.mixin import PromptDictType
from llama_index.retrievers import BaseRetriever
from llama_index.schema import IndexNode, NodeWithScore, QueryBundle
# Default template used to ask the LLM for alternative phrasings of a query.
# Placeholders filled at runtime: ``{num_queries}`` (how many new queries to
# generate) and ``{query}`` (the user's original query). The LLM is asked to
# put one query per line.
QUERY_GEN_PROMPT = (
    "You are a helpful assistant that generates multiple search queries "
    "based on a single input query. "
    "Generate {num_queries} search queries, one on each line, "
    "related to the following input query:\n"
    "Query: {query}\n"
    "Queries:\n"
)
class FUSION_MODES(str, Enum):
    """Strategies for merging the result lists of several retrievers/queries.

    Because the enum mixes in ``str``, each member compares equal to its raw
    string value (e.g. ``FUSION_MODES.SIMPLE == "simple"``).
    """

    # Combine the per-list rankings with reciprocal rank fusion.
    RECIPROCAL_RANK = "reciprocal_rerank"
    # Merge all results and re-order them by the scores the retrievers gave.
    SIMPLE = "simple"
class QueryFusionRetriever(BaseRetriever):
    """Retrieve with several retrievers and queries, then fuse the results.

    When ``num_queries > 1`` the LLM is asked for ``num_queries - 1``
    alternative phrasings of the user's query. The original query plus the
    generated ones are run against every retriever. The resulting node lists
    are merged according to ``mode`` and truncated to ``similarity_top_k``.

    Args:
        retrievers: Retrievers to query.
        llm: LLM (or spec accepted by ``resolve_llm``) used to generate the
            extra queries.
        query_gen_prompt: Template with ``{num_queries}`` and ``{query}``
            placeholders; defaults to ``QUERY_GEN_PROMPT``.
        mode: Fusion strategy (see ``FUSION_MODES``).
        similarity_top_k: Maximum number of fused nodes returned.
        num_queries: Total number of queries, *including* the original one.
        use_async: In the synchronous ``_retrieve`` path, fan out through the
            retrievers' ``aretrieve`` instead of calling ``retrieve`` serially.
        verbose: Print the generated queries.
        callback_manager, objects, object_map: Forwarded to ``BaseRetriever``.
    """

    def __init__(
        self,
        retrievers: List[BaseRetriever],
        llm: Optional[LLMType] = "default",
        query_gen_prompt: Optional[str] = None,
        mode: FUSION_MODES = FUSION_MODES.SIMPLE,
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        num_queries: int = 4,
        use_async: bool = True,
        verbose: bool = False,
        callback_manager: Optional[CallbackManager] = None,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
    ) -> None:
        self.num_queries = num_queries
        self.query_gen_prompt = query_gen_prompt or QUERY_GEN_PROMPT
        self.similarity_top_k = similarity_top_k
        self.mode = mode
        self.use_async = use_async
        self._retrievers = retrievers
        self._llm = resolve_llm(llm)
        super().__init__(
            callback_manager=callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"query_gen_prompt": PromptTemplate(self.query_gen_prompt)}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "query_gen_prompt" in prompts:
            self.query_gen_prompt = cast(
                PromptTemplate, prompts["query_gen_prompt"]
            ).template

    # ------------------------------------------------------------------
    # Query generation
    # ------------------------------------------------------------------

    def _format_query_gen_prompt(self, original_query: str) -> str:
        """Fill the query-generation template for ``original_query``."""
        # ``num_queries`` counts the original query, so only ask for the rest.
        return self.query_gen_prompt.format(
            num_queries=self.num_queries - 1,
            query=original_query,
        )

    def _parse_generated_queries(self, text: str) -> List[str]:
        """Turn a raw LLM completion into a clean list of generated queries.

        Assumes the LLM put one query per line. Blank lines and surrounding
        whitespace are dropped. The list is capped at ``num_queries - 1``
        because LLMs frequently return more lines than requested.
        """
        queries = [line.strip() for line in text.splitlines() if line.strip()]
        queries = queries[: max(self.num_queries - 1, 0)]
        if self._verbose:
            queries_str = "\n".join(queries)
            print(f"Generated queries:\n{queries_str}")
        return queries

    def _get_queries(self, original_query: str) -> List[str]:
        """Ask the LLM for alternative queries (the original is not included)."""
        response = self._llm.complete(self._format_query_gen_prompt(original_query))
        return self._parse_generated_queries(response.text)

    async def _aget_queries(self, original_query: str) -> List[str]:
        """Async counterpart of ``_get_queries``.

        Uses ``acomplete`` so the event loop is not blocked by the LLM call.
        """
        response = await self._llm.acomplete(
            self._format_query_gen_prompt(original_query)
        )
        return self._parse_generated_queries(response.text)

    @staticmethod
    def _merge_queries(original_query: str, generated: List[str]) -> List[str]:
        """Return the original query followed by the generated ones.

        Duplicates are removed while preserving order. This keeps every result
        key unique and stops a repeated query from being counted twice by
        reciprocal rank fusion.
        """
        return list(dict.fromkeys([original_query, *generated]))

    # ------------------------------------------------------------------
    # Fusion
    # ------------------------------------------------------------------

    def _reciprocal_rerank_fusion(
        self, results: Dict[Tuple[str, int], List[NodeWithScore]]
    ) -> List[NodeWithScore]:
        """Apply reciprocal rank fusion.

        The original paper uses k=60 for best results:
        https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf

        Nodes are identified by their content. Each result list contributes
        ``1 / (rank + k)`` per node, where ``rank`` is 0-based here. New
        ``NodeWithScore`` objects are returned, so the retrievers' results
        are not mutated.
        """
        k = 60.0  # `k` is a parameter used to control the impact of outlier rankings.
        fused_scores: Dict[str, float] = {}
        text_to_node: Dict[str, NodeWithScore] = {}

        # compute reciprocal rank scores
        for nodes_with_scores in results.values():
            ranked = sorted(
                nodes_with_scores, key=lambda x: x.score or 0.0, reverse=True
            )
            for rank, node_with_score in enumerate(ranked):
                text = node_with_score.node.get_content()
                text_to_node[text] = node_with_score
                fused_scores[text] = fused_scores.get(text, 0.0) + 1.0 / (rank + k)

        # sort by fused score and attach it to a fresh NodeWithScore
        reranked = sorted(fused_scores.items(), key=lambda item: item[1], reverse=True)
        return [
            NodeWithScore(node=text_to_node[text].node, score=score)
            for text, score in reranked
        ]

    def _simple_fusion(
        self, results: Dict[Tuple[str, int], List[NodeWithScore]]
    ) -> List[NodeWithScore]:
        """Apply simple fusion: de-duplicate by content, then sort by score.

        When the same content is returned more than once (by several
        retrievers or queries), the entry with the highest score is kept. This
        makes the outcome independent of the order in which results arrive.
        """
        best: Dict[str, NodeWithScore] = {}
        for nodes_with_scores in results.values():
            for node_with_score in nodes_with_scores:
                text = node_with_score.node.get_content()
                current = best.get(text)
                if current is None or (node_with_score.score or 0.0) > (
                    current.score or 0.0
                ):
                    best[text] = node_with_score
        return sorted(best.values(), key=lambda x: x.score or 0.0, reverse=True)

    def _fuse(
        self, results: Dict[Tuple[str, int], List[NodeWithScore]]
    ) -> List[NodeWithScore]:
        """Merge ``results`` according to ``self.mode`` and apply top-k.

        Raises:
            ValueError: If ``self.mode`` is not a known fusion mode.
        """
        if self.mode == FUSION_MODES.RECIPROCAL_RANK:
            fused = self._reciprocal_rerank_fusion(results)
        elif self.mode == FUSION_MODES.SIMPLE:
            fused = self._simple_fusion(results)
        else:
            raise ValueError(f"Invalid fusion mode: {self.mode}")
        return fused[: self.similarity_top_k]

    # ------------------------------------------------------------------
    # Retrieval fan-out; every runner keys results by (query, retriever index)
    # ------------------------------------------------------------------

    def _build_async_tasks(
        self, queries: List[str]
    ) -> Tuple[List[Tuple[str, int]], list]:
        """Create one ``aretrieve`` coroutine per (query, retriever) pair."""
        keys: List[Tuple[str, int]] = []
        tasks = []
        for query in queries:
            for idx, retriever in enumerate(self._retrievers):
                keys.append((query, idx))
                tasks.append(retriever.aretrieve(query))
        return keys, tasks

    def _run_nested_async_queries(
        self, queries: List[str]
    ) -> Dict[Tuple[str, int], List[NodeWithScore]]:
        """Run all (query, retriever) pairs concurrently from synchronous code."""
        keys, tasks = self._build_async_tasks(queries)
        task_results = run_async_tasks(tasks)
        return dict(zip(keys, task_results))

    async def _run_async_queries(
        self, queries: List[str]
    ) -> Dict[Tuple[str, int], List[NodeWithScore]]:
        """Run all (query, retriever) pairs concurrently inside an event loop."""
        keys, tasks = self._build_async_tasks(queries)
        task_results = await asyncio.gather(*tasks)
        return dict(zip(keys, task_results))

    def _run_sync_queries(
        self, queries: List[str]
    ) -> Dict[Tuple[str, int], List[NodeWithScore]]:
        """Run all (query, retriever) pairs one after another."""
        return {
            (query, idx): retriever.retrieve(query)
            for query in queries
            for idx, retriever in enumerate(self._retrievers)
        }

    # ------------------------------------------------------------------
    # Entry points
    # ------------------------------------------------------------------

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        generated = (
            self._get_queries(query_bundle.query_str) if self.num_queries > 1 else []
        )
        queries = self._merge_queries(query_bundle.query_str, generated)

        if self.use_async:
            results = self._run_nested_async_queries(queries)
        else:
            results = self._run_sync_queries(queries)
        return self._fuse(results)

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        generated = (
            await self._aget_queries(query_bundle.query_str)
            if self.num_queries > 1
            else []
        )
        queries = self._merge_queries(query_bundle.query_str, generated)

        results = await self._run_async_queries(queries)
        return self._fuse(results)