import asyncio
from enum import Enum
from typing import Dict, List, Optional, Tuple, cast

from llama_index.async_utils import run_async_tasks
from llama_index.callbacks.base import CallbackManager
from llama_index.constants import DEFAULT_SIMILARITY_TOP_K
from llama_index.llms.utils import LLMType, resolve_llm
from llama_index.prompts import PromptTemplate
from llama_index.prompts.mixin import PromptDictType
from llama_index.retrievers import BaseRetriever
from llama_index.schema import IndexNode, NodeWithScore, QueryBundle

QUERY_GEN_PROMPT = (
    "You are a helpful assistant that generates multiple search queries based on a "
    "single input query. Generate {num_queries} search queries, one on each line, "
    "related to the following input query:\n"
    "Query: {query}\n"
    "Queries:\n"
)


class FUSION_MODES(str, Enum):
    """Enum for different fusion modes."""

    RECIPROCAL_RANK = "reciprocal_rerank"  # apply reciprocal rank fusion
    SIMPLE = "simple"  # simple re-ordering of results based on original scores


class QueryFusionRetriever(BaseRetriever):
    """Run several retrievers over several query variants and fuse the results.

    When ``num_queries > 1`` the LLM is asked for ``num_queries - 1``
    additional queries derived from the input query. The original query plus
    the generated ones are sent to every retriever, and the resulting
    per-(query, retriever) lists are merged according to ``mode``. At most
    ``similarity_top_k`` fused nodes are returned.
    """

    def __init__(
        self,
        retrievers: List[BaseRetriever],
        llm: Optional[LLMType] = "default",
        query_gen_prompt: Optional[str] = None,
        mode: FUSION_MODES = FUSION_MODES.SIMPLE,
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        num_queries: int = 4,
        use_async: bool = True,
        verbose: bool = False,
        callback_manager: Optional[CallbackManager] = None,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
    ) -> None:
        """Configure the fusion retriever.

        Args:
            retrievers: Retrievers queried for every query variant.
            llm: LLM (or spec accepted by ``resolve_llm``) used to generate
                extra queries.
            query_gen_prompt: ``str.format`` template with ``{num_queries}``
                and ``{query}`` placeholders; defaults to ``QUERY_GEN_PROMPT``.
            mode: How per-query result lists are merged.
            similarity_top_k: Maximum number of fused nodes returned.
            num_queries: Total number of queries to run, including the
                original one; values <= 1 disable query generation.
            use_async: In the synchronous ``retrieve`` path, run all
                retriever calls concurrently via ``run_async_tasks``.
            verbose: Print the generated queries.
            callback_manager, objects, object_map: Forwarded to
                ``BaseRetriever``.
        """
        self.num_queries = num_queries
        self.query_gen_prompt = query_gen_prompt or QUERY_GEN_PROMPT
        self.similarity_top_k = similarity_top_k
        self.mode = mode
        self.use_async = use_async

        self._retrievers = retrievers
        self._llm = resolve_llm(llm)
        super().__init__(
            callback_manager=callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"query_gen_prompt": PromptTemplate(self.query_gen_prompt)}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "query_gen_prompt" in prompts:
            self.query_gen_prompt = cast(
                PromptTemplate, prompts["query_gen_prompt"]
            ).template

    def _format_query_gen_prompt(self, original_query: str) -> str:
        """Fill the query-generation template for ``original_query``."""
        # The original query is always run as well, so only ask for the rest.
        return self.query_gen_prompt.format(
            num_queries=self.num_queries - 1,
            query=original_query,
        )

    def _parse_queries(self, text: str) -> List[str]:
        """Split raw LLM output into at most ``num_queries - 1`` clean queries."""
        # The prompt asks for one query per line. Drop blank lines and
        # surrounding whitespace, and trim any surplus queries the LLM produced.
        queries = [line.strip() for line in text.split("\n") if line.strip()]
        queries = queries[: self.num_queries - 1]
        if self._verbose:
            queries_str = "\n".join(queries)
            print(f"Generated queries:\n{queries_str}")
        return queries

    def _get_queries(self, original_query: str) -> List[str]:
        """Ask the LLM (synchronously) for additional search queries."""
        response = self._llm.complete(self._format_query_gen_prompt(original_query))
        return self._parse_queries(response.text)

    async def _aget_queries(self, original_query: str) -> List[str]:
        """Ask the LLM (asynchronously) for additional search queries."""
        response = await self._llm.acomplete(
            self._format_query_gen_prompt(original_query)
        )
        return self._parse_queries(response.text)

    @staticmethod
    def _combine_queries(original_query: str, generated: List[str]) -> List[str]:
        """Return the original query followed by the generated ones, de-duplicated.

        De-duplication keeps the first occurrence. It means a query repeated
        by the LLM is neither retrieved twice nor double-weighted during
        fusion, regardless of which execution path is used.
        """
        return list(dict.fromkeys([original_query, *generated]))

    def _reciprocal_rerank_fusion(
        self, results: Dict[Tuple[str, int], List[NodeWithScore]]
    ) -> List[NodeWithScore]:
        """Apply reciprocal rank fusion.

        Each result list is ranked by its own scores, and every node gains
        ``1 / (rank + k)`` per list it appears in. Nodes are identified by
        their content. The returned nodes carry the fused score (their
        ``score`` attribute is overwritten).

        The original paper uses k=60 for best results:
        https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
        """
        k = 60.0  # `k` is a parameter used to control the impact of outlier rankings.
        fused_scores: Dict[str, float] = {}
        text_to_node: Dict[str, NodeWithScore] = {}

        # compute reciprocal rank scores
        for nodes_with_scores in results.values():
            ranked = sorted(
                nodes_with_scores, key=lambda x: x.score or 0.0, reverse=True
            )
            for rank, node_with_score in enumerate(ranked):
                text = node_with_score.node.get_content()
                text_to_node[text] = node_with_score
                fused_scores[text] = fused_scores.get(text, 0.0) + 1.0 / (rank + k)

        # sort by fused score and write it back onto the nodes
        reranked_nodes: List[NodeWithScore] = []
        for text, score in sorted(
            fused_scores.items(), key=lambda x: x[1], reverse=True
        ):
            node_with_score = text_to_node[text]
            node_with_score.score = score
            reranked_nodes.append(node_with_score)
        return reranked_nodes

    def _simple_fusion(
        self, results: Dict[Tuple[str, int], List[NodeWithScore]]
    ) -> List[NodeWithScore]:
        """Apply simple fusion.

        De-duplicate nodes by content, keeping the highest-scoring copy (a
        missing score counts as 0.0). Return them sorted by score, descending.
        """
        all_nodes: Dict[str, NodeWithScore] = {}
        for nodes_with_scores in results.values():
            for node_with_score in nodes_with_scores:
                text = node_with_score.node.get_content()
                existing = all_nodes.get(text)
                # Keep the best-scoring duplicate rather than whichever came last.
                if existing is None or (node_with_score.score or 0.0) > (
                    existing.score or 0.0
                ):
                    all_nodes[text] = node_with_score

        return sorted(all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True)

    def _build_async_tasks(self, queries: List[str]) -> Tuple[List[Tuple[str, int]], list]:
        """Create one ``aretrieve`` coroutine per (query, retriever index) pair."""
        keys: List[Tuple[str, int]] = []
        tasks = []
        for query in queries:
            for retriever_idx, retriever in enumerate(self._retrievers):
                keys.append((query, retriever_idx))
                tasks.append(retriever.aretrieve(query))
        return keys, tasks

    def _run_nested_async_queries(
        self, queries: List[str]
    ) -> Dict[Tuple[str, int], List[NodeWithScore]]:
        """Run all retrievers concurrently from synchronous code."""
        keys, tasks = self._build_async_tasks(queries)
        task_results = run_async_tasks(tasks)
        # Key by retriever index (not task index), matching `_run_sync_queries`.
        return dict(zip(keys, task_results))

    async def _run_async_queries(
        self, queries: List[str]
    ) -> Dict[Tuple[str, int], List[NodeWithScore]]:
        """Run all retrievers concurrently inside an event loop."""
        keys, tasks = self._build_async_tasks(queries)
        task_results = await asyncio.gather(*tasks)
        # Key by retriever index (not task index), matching `_run_sync_queries`.
        return dict(zip(keys, task_results))

    def _run_sync_queries(
        self, queries: List[str]
    ) -> Dict[Tuple[str, int], List[NodeWithScore]]:
        """Run all retrievers one after another."""
        results = {}
        for query in queries:
            for retriever_idx, retriever in enumerate(self._retrievers):
                results[(query, retriever_idx)] = retriever.retrieve(query)

        return results

    def _fuse(
        self, results: Dict[Tuple[str, int], List[NodeWithScore]]
    ) -> List[NodeWithScore]:
        """Merge results according to ``self.mode`` and keep the top-k.

        Raises:
            ValueError: If ``self.mode`` is not a known fusion mode.
        """
        if self.mode == FUSION_MODES.RECIPROCAL_RANK:
            fused = self._reciprocal_rerank_fusion(results)
        elif self.mode == FUSION_MODES.SIMPLE:
            fused = self._simple_fusion(results)
        else:
            raise ValueError(f"Invalid fusion mode: {self.mode}")
        return fused[: self.similarity_top_k]

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Generate query variants, retrieve for each, and fuse the results."""
        generated = (
            self._get_queries(query_bundle.query_str) if self.num_queries > 1 else []
        )
        queries = self._combine_queries(query_bundle.query_str, generated)

        if self.use_async:
            results = self._run_nested_async_queries(queries)
        else:
            results = self._run_sync_queries(queries)

        return self._fuse(results)

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async variant of ``_retrieve``; the LLM call does not block the loop."""
        generated = (
            await self._aget_queries(query_bundle.query_str)
            if self.num_queries > 1
            else []
        )
        queries = self._combine_queries(query_bundle.query_str, generated)

        results = await self._run_async_queries(queries)
        return self._fuse(results)