faiss_rag_enterprise/llama_index/postprocessor/rankGPT_rerank.py

159 lines
5.9 KiB
Python

import logging
from typing import Any, Dict, List, Optional, Sequence
from llama_index.bridge.pydantic import Field
from llama_index.llms import LLM, ChatMessage, ChatResponse, OpenAI
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.prompts import BasePromptTemplate
from llama_index.prompts.default_prompts import RANKGPT_RERANK_PROMPT
from llama_index.prompts.mixin import PromptDictType
from llama_index.schema import NodeWithScore, QueryBundle
from llama_index.utils import print_text
# Module-level logger for this file; its level is pinned to WARNING at import time.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.WARNING)
class RankGPTRerank(BaseNodePostprocessor):
    """RankGPT-based reranker.

    Sends the query and the candidate passages to an LLM as a numbered,
    multi-turn chat, asks it for a permutation of the passage identifiers,
    and reorders the nodes accordingly. Only the first ``top_n`` reranked
    nodes are returned.
    """

    top_n: int = Field(default=5, description="Top N nodes to return from reranking.")
    llm: LLM = Field(
        default_factory=lambda: OpenAI(model="gpt-3.5-turbo-16k"),
        description="LLM to use for rankGPT",
    )
    verbose: bool = Field(
        default=False, description="Whether to print intermediate steps."
    )
    rankgpt_rerank_prompt: BasePromptTemplate = Field(
        description="rankGPT rerank prompt."
    )

    def __init__(
        self,
        top_n: int = 5,
        llm: Optional[LLM] = None,
        verbose: bool = False,
        rankgpt_rerank_prompt: Optional[BasePromptTemplate] = None,
    ):
        """Create the reranker.

        Args:
            top_n: Number of reranked nodes to return.
            llm: LLM used for ranking. When omitted, the field's default
                factory (``OpenAI(model="gpt-3.5-turbo-16k")``) is used.
            verbose: Whether to print the reranked index list.
            rankgpt_rerank_prompt: Template for the closing ranking
                instruction; defaults to ``RANKGPT_RERANK_PROMPT``.
        """
        if rankgpt_rerank_prompt is None:
            rankgpt_rerank_prompt = RANKGPT_RERANK_PROMPT
        init_kwargs: Dict[str, Any] = {
            "verbose": verbose,
            "top_n": top_n,
            "rankgpt_rerank_prompt": rankgpt_rerank_prompt,
        }
        # Forward ``llm`` only when the caller supplied one: pydantic applies a
        # field's ``default_factory`` only when the key is absent, so passing
        # ``llm=None`` explicitly would skip the default model and hand
        # ``None`` to a field typed as ``LLM``.
        if llm is not None:
            init_kwargs["llm"] = llm
        super().__init__(**init_kwargs)

    @classmethod
    def class_name(cls) -> str:
        """Return the class name identifier ``"RankGPTRerank"``."""
        return "RankGPTRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query with the LLM and keep ``top_n``.

        Args:
            nodes: Candidate nodes, in their incoming order.
            query_bundle: Query to rank against; required.

        Returns:
            Up to ``top_n`` nodes in LLM-ranked order (new ``NodeWithScore``
            wrappers carrying the original node and score). If the LLM
            response carries no content, the first ``top_n`` incoming nodes
            are returned unchanged.

        Raises:
            ValueError: If ``query_bundle`` is None.
        """
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")
        if not nodes:
            # Nothing to rank: skip the LLM round-trip, the result is empty anyway.
            return []

        items = {
            "query": query_bundle.query_str,
            "hits": [{"content": node.get_content()} for node in nodes],
        }
        messages = self.create_permutation_instruction(item=items)
        permutation = self.run_llm(messages=messages)
        if permutation.message is None or permutation.message.content is None:
            # No usable ranking from the LLM: fall back to the incoming order.
            return nodes[: self.top_n]

        rerank_ranks = self._receive_permutation(
            items, str(permutation.message.content)
        )
        if self.verbose:
            print_text(f"After Reranking, new rank list for nodes: {rerank_ranks}")
        return [
            NodeWithScore(node=nodes[idx].node, score=nodes[idx].score)
            for idx in rerank_ranks[: self.top_n]
        ]

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"rankgpt_rerank_prompt": self.rankgpt_rerank_prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "rankgpt_rerank_prompt" in prompts:
            self.rankgpt_rerank_prompt = prompts["rankgpt_rerank_prompt"]

    def _get_prefix_prompt(self, query: str, num: int) -> List[ChatMessage]:
        """Build the opening system/user/assistant turns announcing ``num`` passages."""
        return [
            ChatMessage(
                role="system",
                content="You are RankGPT, an intelligent assistant that can rank passages based on their relevancy to the query.",
            ),
            ChatMessage(
                role="user",
                content=f"I will provide you with {num} passages, each indicated by number identifier []. \nRank the passages based on their relevance to query: {query}.",
            ),
            ChatMessage(role="assistant", content="Okay, please provide the passages."),
        ]

    def _get_post_prompt(self, query: str, num: int) -> str:
        """Render the closing ranking instruction from ``rankgpt_rerank_prompt``."""
        return self.rankgpt_rerank_prompt.format(query=query, num=num)

    def create_permutation_instruction(self, item: Dict[str, Any]) -> List[ChatMessage]:
        """Build the full chat asking the LLM to rank ``item["hits"]``.

        Args:
            item: Mapping with ``"query"`` (str) and ``"hits"`` (list of dicts
                each holding a ``"content"`` string).

        Returns:
            Prefix turns, then one user/assistant pair per passage (numbered
            from 1), then the closing ranking instruction.
        """
        query = item["query"]
        hits = item["hits"]
        num = len(hits)
        messages = self._get_prefix_prompt(query, num)
        for rank, hit in enumerate(hits, start=1):
            content = hit["content"].replace("Title: Content: ", "").strip()
            # Truncate to the first 300 whitespace-separated words. For languages
            # written without spaces (e.g. Japanese) this should cut by characters.
            content = " ".join(content.split()[:300])
            messages.append(ChatMessage(role="user", content=f"[{rank}] {content}"))
            messages.append(
                ChatMessage(role="assistant", content=f"Received passage [{rank}].")
            )
        messages.append(
            ChatMessage(role="user", content=self._get_post_prompt(query, num))
        )
        return messages

    def run_llm(self, messages: Sequence[ChatMessage]) -> ChatResponse:
        """Send ``messages`` to ``self.llm`` via ``chat`` and return its response."""
        return self.llm.chat(messages)

    def _clean_response(self, response: str) -> str:
        """Replace every non-digit character with a space and strip the ends.

        ``str.isdecimal`` is used instead of ``str.isdigit``: ``isdigit`` also
        accepts characters such as superscripts ("²") or circled digits that
        ``int()`` cannot parse, which would crash ``_receive_permutation``.
        """
        return "".join(c if c.isdecimal() else " " for c in response).strip()

    def _remove_duplicate(self, response: List[int]) -> List[int]:
        """Return ``response`` without duplicates, keeping first occurrences in order."""
        return list(dict.fromkeys(response))

    def _receive_permutation(self, item: Dict[str, Any], permutation: str) -> List[int]:
        """Parse the LLM's permutation text into a full list of 0-based hit indices.

        Numbers in ``permutation`` are 1-based passage identifiers. Duplicates
        and out-of-range identifiers are dropped, and any passage the LLM did
        not mention is appended in its original order, so the result is
        always a permutation of ``range(len(item["hits"]))``.
        """
        rank_end = len(item["hits"])
        cleaned = self._clean_response(permutation)
        ranked = self._remove_duplicate([int(tok) - 1 for tok in cleaned.split()])
        ranked = [idx for idx in ranked if 0 <= idx < rank_end]
        mentioned = set(ranked)
        # Append the passages the LLM left out, preserving their original order.
        return ranked + [idx for idx in range(rank_end) if idx not in mentioned]