faiss_rag_enterprise/llama_index/postprocessor/llm_rerank.py

"""LLM reranker."""
from typing import Callable, List, Optional
from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.indices.utils import (
    default_format_node_batch_fn,
    default_parse_choice_select_answer_fn,
)
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.prompts import BasePromptTemplate
from llama_index.prompts.default_prompts import DEFAULT_CHOICE_SELECT_PROMPT
from llama_index.prompts.mixin import PromptDictType
from llama_index.schema import NodeWithScore, QueryBundle
from llama_index.service_context import ServiceContext


class LLMRerank(BaseNodePostprocessor):
"""LLM-based reranker."""
top_n: int = Field(description="Top N nodes to return.")
choice_select_prompt: BasePromptTemplate = Field(
description="Choice select prompt."
)
choice_batch_size: int = Field(description="Batch size for choice select.")
service_context: ServiceContext = Field(
description="Service context.", exclude=True
)
_format_node_batch_fn: Callable = PrivateAttr()
_parse_choice_select_answer_fn: Callable = PrivateAttr()

    def __init__(
        self,
        choice_select_prompt: Optional[BasePromptTemplate] = None,
        choice_batch_size: int = 10,
        format_node_batch_fn: Optional[Callable] = None,
        parse_choice_select_answer_fn: Optional[Callable] = None,
        service_context: Optional[ServiceContext] = None,
        top_n: int = 10,
    ) -> None:
        choice_select_prompt = choice_select_prompt or DEFAULT_CHOICE_SELECT_PROMPT
        service_context = service_context or ServiceContext.from_defaults()

        self._format_node_batch_fn = (
            format_node_batch_fn or default_format_node_batch_fn
        )
        self._parse_choice_select_answer_fn = (
            parse_choice_select_answer_fn or default_parse_choice_select_answer_fn
        )

        super().__init__(
            choice_select_prompt=choice_select_prompt,
            choice_batch_size=choice_batch_size,
            service_context=service_context,
            top_n=top_n,
        )

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"choice_select_prompt": self.choice_select_prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "choice_select_prompt" in prompts:
            self.choice_select_prompt = prompts["choice_select_prompt"]

    @classmethod
    def class_name(cls) -> str:
        return "LLMRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")
        if len(nodes) == 0:
            return []

        initial_results: List[NodeWithScore] = []
        for idx in range(0, len(nodes), self.choice_batch_size):
            nodes_batch = [
                node.node for node in nodes[idx : idx + self.choice_batch_size]
            ]

            query_str = query_bundle.query_str
            fmt_batch_str = self._format_node_batch_fn(nodes_batch)
            # call each batch independently
            raw_response = self.service_context.llm.predict(
                self.choice_select_prompt,
                context_str=fmt_batch_str,
                query_str=query_str,
            )

            # parse the LLM's answer into 1-based choice numbers and
            # per-choice relevance scores for this batch
            raw_choices, relevances = self._parse_choice_select_answer_fn(
                raw_response, len(nodes_batch)
            )
            choice_idxs = [int(choice) - 1 for choice in raw_choices]
            choice_nodes = [nodes_batch[idx] for idx in choice_idxs]
            # fall back to a uniform score if the parser returned none
            relevances = relevances or [1.0 for _ in choice_nodes]
            initial_results.extend(
                [
                    NodeWithScore(node=node, score=relevance)
                    for node, relevance in zip(choice_nodes, relevances)
                ]
            )

        # merge all batches and keep the top_n highest-scoring nodes
        return sorted(initial_results, key=lambda x: x.score or 0.0, reverse=True)[
            : self.top_n
        ]
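

if __name__ == "__main__":
    # Usage sketch, not part of the original module: a minimal example of
    # wiring LLMRerank behind a vector retriever. The sample documents, the
    # batch size, and the default LLM resolved by ServiceContext.from_defaults()
    # (OpenAI, which requires OPENAI_API_KEY) are illustrative assumptions,
    # not project configuration.
    from llama_index import VectorStoreIndex
    from llama_index.schema import Document

    documents = [
        Document(text="FAISS stores dense vectors for similarity search."),
        Document(text="LLM rerankers reorder retrieved nodes by relevance."),
        Document(text="Batching keeps each rerank prompt within context limits."),
    ]
    service_context = ServiceContext.from_defaults()
    index = VectorStoreIndex.from_documents(
        documents, service_context=service_context
    )
    retriever = index.as_retriever(similarity_top_k=3)

    query = "How does LLM-based reranking work?"
    candidates = retriever.retrieve(query)

    # rerank the retrieved candidates in batches of 2, keeping the best 2
    reranker = LLMRerank(
        service_context=service_context, choice_batch_size=2, top_n=2
    )
    reranked = reranker.postprocess_nodes(
        candidates, query_bundle=QueryBundle(query_str=query)
    )
    for node in reranked:
        print(node.score, node.node.get_content()[:60])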