"""LLM reranker."""
|
|
from typing import Callable, List, Optional
|
|
|
|
from llama_index.bridge.pydantic import Field, PrivateAttr
|
|
from llama_index.indices.utils import (
|
|
default_format_node_batch_fn,
|
|
default_parse_choice_select_answer_fn,
|
|
)
|
|
from llama_index.postprocessor.types import BaseNodePostprocessor
|
|
from llama_index.prompts import BasePromptTemplate
|
|
from llama_index.prompts.default_prompts import DEFAULT_CHOICE_SELECT_PROMPT
|
|
from llama_index.prompts.mixin import PromptDictType
|
|
from llama_index.schema import NodeWithScore, QueryBundle
|
|
from llama_index.service_context import ServiceContext
|
|
|
|
|
|


class LLMRerank(BaseNodePostprocessor):
    """LLM-based reranker.

    Asks the LLM, via a choice-select prompt, which of the candidate nodes
    are relevant to a query, then reorders the nodes by the relevance
    scores the LLM assigns and keeps the top ``top_n``.
    """

    top_n: int = Field(description="Top N nodes to return.")
    choice_select_prompt: BasePromptTemplate = Field(
        description="Choice select prompt."
    )
    choice_batch_size: int = Field(description="Batch size for choice select.")
    service_context: ServiceContext = Field(
        description="Service context.", exclude=True
    )

    _format_node_batch_fn: Callable = PrivateAttr()
    _parse_choice_select_answer_fn: Callable = PrivateAttr()

    def __init__(
        self,
        choice_select_prompt: Optional[BasePromptTemplate] = None,
        choice_batch_size: int = 10,
        format_node_batch_fn: Optional[Callable] = None,
        parse_choice_select_answer_fn: Optional[Callable] = None,
        service_context: Optional[ServiceContext] = None,
        top_n: int = 10,
    ) -> None:
        choice_select_prompt = choice_select_prompt or DEFAULT_CHOICE_SELECT_PROMPT
        service_context = service_context or ServiceContext.from_defaults()

        # Private attributes are assigned directly; the pydantic fields are
        # validated through the parent constructor below.
        self._format_node_batch_fn = (
            format_node_batch_fn or default_format_node_batch_fn
        )
        self._parse_choice_select_answer_fn = (
            parse_choice_select_answer_fn or default_parse_choice_select_answer_fn
        )

        super().__init__(
            choice_select_prompt=choice_select_prompt,
            choice_batch_size=choice_batch_size,
            service_context=service_context,
            top_n=top_n,
        )

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"choice_select_prompt": self.choice_select_prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "choice_select_prompt" in prompts:
            self.choice_select_prompt = prompts["choice_select_prompt"]

    @classmethod
    def class_name(cls) -> str:
        return "LLMRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")
        if len(nodes) == 0:
            return []

        query_str = query_bundle.query_str
        initial_results: List[NodeWithScore] = []
        # Rerank in batches of `choice_batch_size` so each prompt stays within
        # the LLM's context window; each batch is an independent LLM call.
        for idx in range(0, len(nodes), self.choice_batch_size):
            nodes_batch = [
                node.node for node in nodes[idx : idx + self.choice_batch_size]
            ]

            fmt_batch_str = self._format_node_batch_fn(nodes_batch)
            raw_response = self.service_context.llm.predict(
                self.choice_select_prompt,
                context_str=fmt_batch_str,
                query_str=query_str,
            )

            # Parse the raw answer into document choices and (optional)
            # relevance scores.
            raw_choices, relevances = self._parse_choice_select_answer_fn(
                raw_response, len(nodes_batch)
            )
            # Choices are numbered from 1 in the prompt; convert to 0-based
            # list indices.
            choice_idxs = [int(choice) - 1 for choice in raw_choices]
            choice_nodes = [nodes_batch[idx] for idx in choice_idxs]
            # Fall back to a uniform score if the LLM returned no relevances.
            relevances = relevances or [1.0 for _ in choice_nodes]
            initial_results.extend(
                NodeWithScore(node=node, score=relevance)
                for node, relevance in zip(choice_nodes, relevances)
            )

        # Merge all batch results by score and keep the overall top_n.
        return sorted(initial_results, key=lambda x: x.score or 0.0, reverse=True)[
            : self.top_n
        ]
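

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original module). It
# assumes the legacy ServiceContext-era llama_index API used above, with an
# LLM already configured via ServiceContext.from_defaults() and a pre-built
# `index` (hypothetical):
#
#   from llama_index.schema import QueryBundle
#
#   retriever = index.as_retriever(similarity_top_k=20)
#   nodes = retriever.retrieve("what does the reranker do?")
#
#   reranker = LLMRerank(choice_batch_size=5, top_n=3)
#   reranked = reranker.postprocess_nodes(
#       nodes, QueryBundle("what does the reranker do?")
#   )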