from typing import Any, List, Optional

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CBEventType, EventPayload
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.schema import MetadataMode, NodeWithScore, QueryBundle
from llama_index.utils import infer_torch_device

DEFAULT_SENTENCE_TRANSFORMER_MAX_LENGTH = 512


class SentenceTransformerRerank(BaseNodePostprocessor):
    """Rerank nodes with a sentence-transformers CrossEncoder.

    Each (query, node content) pair is scored by the cross-encoder and the
    top_n highest-scoring nodes are returned.
    """

    model: str = Field(description="Sentence transformer model name.")
    top_n: int = Field(description="Number of nodes to return sorted by score.")
    device: str = Field(
        default="cpu",
        description="Device to use for sentence transformer.",
    )
    keep_retrieval_score: bool = Field(
        default=False,
        description="Whether to keep the retrieval score in metadata.",
    )
    _model: Any = PrivateAttr()

    def __init__(
        self,
        top_n: int = 2,
        model: str = "cross-encoder/stsb-distilroberta-base",
        device: Optional[str] = None,
        keep_retrieval_score: Optional[bool] = False,
    ):
        try:
            from sentence_transformers import CrossEncoder
        except ImportError:
            raise ImportError(
                "Cannot import sentence-transformers or torch package, "
                "please run `pip install torch sentence-transformers`"
            )

        device = infer_torch_device() if device is None else device
        self._model = CrossEncoder(
            model, max_length=DEFAULT_SENTENCE_TRANSFORMER_MAX_LENGTH, device=device
        )
        super().__init__(
            top_n=top_n,
            model=model,
            device=device,
            keep_retrieval_score=keep_retrieval_score,
        )

    @classmethod
    def class_name(cls) -> str:
        return "SentenceTransformerRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if len(nodes) == 0:
            return []

        # Build (query, node text) pairs for the cross-encoder.
        query_and_nodes = [
            (
                query_bundle.query_str,
                node.node.get_content(metadata_mode=MetadataMode.EMBED),
            )
            for node in nodes
        ]

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            scores = self._model.predict(query_and_nodes)

            assert len(scores) == len(nodes)

            # Overwrite each node's score with the cross-encoder score,
            # optionally preserving the original retrieval score in metadata.
            for node, score in zip(nodes, scores):
                if self.keep_retrieval_score:
                    # keep the retrieval score in metadata
                    node.node.metadata["retrieval_score"] = node.score
                node.score = score

            new_nodes = sorted(nodes, key=lambda x: -x.score if x.score else 0)[
                : self.top_n
            ]
            event.on_end(payload={EventPayload.NODES: new_nodes})

        return new_nodes
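

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module: retrieve a
    # broad candidate set from an index and rerank it down to top_n with the
    # cross-encoder. Assumes a local "./data" directory with documents, a
    # configured embedding model, and an arbitrary example cross-encoder name;
    # all of these are placeholder choices for illustration.
    from llama_index import SimpleDirectoryReader, VectorStoreIndex

    documents = SimpleDirectoryReader("./data").load_data()
    index = VectorStoreIndex.from_documents(documents)

    reranker = SentenceTransformerRerank(
        model="cross-encoder/ms-marco-MiniLM-L-6-v2", top_n=3
    )

    query = "What did the author do growing up?"
    retrieved = index.as_retriever(similarity_top_k=10).retrieve(query)
    reranked = reranker.postprocess_nodes(retrieved, QueryBundle(query))

    for node in reranked:
        print(node.score, node.node.get_content()[:80])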