import logging
from typing import Any, Dict, List, Optional, Sequence

from llama_index.bridge.pydantic import Field
from llama_index.llms import LLM, ChatMessage, ChatResponse, OpenAI
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.prompts import BasePromptTemplate
from llama_index.prompts.default_prompts import RANKGPT_RERANK_PROMPT
from llama_index.prompts.mixin import PromptDictType
from llama_index.schema import NodeWithScore, QueryBundle
from llama_index.utils import print_text

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

class RankGPTRerank(BaseNodePostprocessor):
    """RankGPT-based listwise reranker.

    Presents the candidate passages to an LLM as a numbered list
    (``[1] ...``, ``[2] ...``) and asks it for a permutation of the
    identifiers ordered by relevance to the query; nodes are then
    reordered accordingly and the top ``top_n`` are returned.
    """

    # Number of nodes to keep after reranking.
    top_n: int = Field(default=5, description="Top N nodes to return from reranking.")
    # Chat-capable LLM that produces the ranking permutation.
    llm: LLM = Field(
        default_factory=lambda: OpenAI(model="gpt-3.5-turbo-16k"),
        description="LLM to use for rankGPT",
    )
    # When True, prints the reranked order via print_text.
    verbose: bool = Field(
        default=False, description="Whether to print intermediate steps."
    )
    # Final user prompt asking for the permutation (formatted with query/num).
    rankgpt_rerank_prompt: BasePromptTemplate = Field(
        description="rankGPT rerank prompt."
    )

    def __init__(
        self,
        top_n: int = 5,
        llm: Optional[LLM] = None,
        verbose: bool = False,
        rankgpt_rerank_prompt: Optional[BasePromptTemplate] = None,
    ):
        rankgpt_rerank_prompt = rankgpt_rerank_prompt or RANKGPT_RERANK_PROMPT
        # BUGFIX: previously an explicit ``llm=None`` was forwarded to
        # super().__init__(). The ``llm`` field is not Optional and pydantic
        # does NOT apply ``default_factory`` to an explicitly-passed None,
        # so ``RankGPTRerank()`` failed validation. Fall back to the same
        # default model the field's factory would create.
        super().__init__(
            verbose=verbose,
            llm=llm or OpenAI(model="gpt-3.5-turbo-16k"),
            top_n=top_n,
            rankgpt_rerank_prompt=rankgpt_rerank_prompt,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the serialization class name."""
        return "RankGPTRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query and return the top ``top_n``.

        Raises:
            ValueError: if ``query_bundle`` is not provided.
        """
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")

        items = {
            "query": query_bundle.query_str,
            "hits": [{"content": node.get_content()} for node in nodes],
        }

        messages = self.create_permutation_instruction(item=items)
        permutation = self.run_llm(messages=messages)
        if permutation.message is not None and permutation.message.content is not None:
            rerank_ranks = self._receive_permutation(
                items, str(permutation.message.content)
            )
            if self.verbose:
                print_text(f"After Reranking, new rank list for nodes: {rerank_ranks}")

            # Rebuild the node list in the LLM-provided order, keeping the
            # original retrieval scores attached to each node.
            initial_results: List[NodeWithScore] = [
                NodeWithScore(node=nodes[idx].node, score=nodes[idx].score)
                for idx in rerank_ranks
            ]
            return initial_results[: self.top_n]
        else:
            # LLM returned no usable content; fall back to original order.
            return nodes[: self.top_n]

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"rankgpt_rerank_prompt": self.rankgpt_rerank_prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "rankgpt_rerank_prompt" in prompts:
            self.rankgpt_rerank_prompt = prompts["rankgpt_rerank_prompt"]

    def _get_prefix_prompt(self, query: str, num: int) -> List[ChatMessage]:
        """Build the system/user/assistant preamble that frames the task."""
        return [
            ChatMessage(
                role="system",
                content="You are RankGPT, an intelligent assistant that can rank passages based on their relevancy to the query.",
            ),
            ChatMessage(
                role="user",
                content=f"I will provide you with {num} passages, each indicated by number identifier []. \nRank the passages based on their relevance to query: {query}.",
            ),
            ChatMessage(role="assistant", content="Okay, please provide the passages."),
        ]

    def _get_post_prompt(self, query: str, num: int) -> str:
        """Format the final instruction asking for the ranked permutation."""
        return self.rankgpt_rerank_prompt.format(query=query, num=num)

    def create_permutation_instruction(self, item: Dict[str, Any]) -> List[ChatMessage]:
        """Build the full chat transcript that feeds passages to the LLM.

        Each hit becomes a ``[rank] content`` user message (truncated to
        300 whitespace-separated tokens) followed by a canned assistant
        acknowledgement, ending with the rerank instruction prompt.
        """
        query = item["query"]
        num = len(item["hits"])

        messages = self._get_prefix_prompt(query, num)
        for rank, hit in enumerate(item["hits"], start=1):
            content = hit["content"]
            content = content.replace("Title: Content: ", "")
            content = content.strip()
            # For Japanese should cut by character: content = content[:int(max_length)]
            content = " ".join(content.split()[:300])
            messages.append(ChatMessage(role="user", content=f"[{rank}] {content}"))
            messages.append(
                ChatMessage(role="assistant", content=f"Received passage [{rank}].")
            )
        messages.append(
            ChatMessage(role="user", content=self._get_post_prompt(query, num))
        )
        return messages

    def run_llm(self, messages: Sequence[ChatMessage]) -> ChatResponse:
        """Send the ranking transcript to the configured LLM."""
        return self.llm.chat(messages)

    def _clean_response(self, response: str) -> str:
        """Replace every non-digit character with a space and strip the ends."""
        return "".join(c if c.isdigit() else " " for c in response).strip()

    def _remove_duplicate(self, response: List[int]) -> List[int]:
        """Drop duplicate indices while preserving first-seen order."""
        return list(dict.fromkeys(response))

    def _receive_permutation(self, item: Dict[str, Any], permutation: str) -> List[int]:
        """Parse the LLM's permutation string into 0-based node indices.

        Out-of-range and duplicate identifiers are discarded; any indices
        the LLM omitted are appended afterwards in their original order so
        the result is always a full permutation of ``range(len(hits))``.
        """
        rank_end = len(item["hits"])

        response = self._clean_response(permutation)
        response_list = [int(x) - 1 for x in response.split()]
        response_list = self._remove_duplicate(response_list)
        response_list = [ss for ss in response_list if ss in range(rank_end)]
        return response_list + [
            tt for tt in range(rank_end) if tt not in response_list
        ]  # add the rest of the rank