97 lines
3.2 KiB
Python
97 lines
3.2 KiB
Python
from typing import Any, List, Optional
|
|
|
|
from llama_index.bridge.pydantic import Field, PrivateAttr
|
|
from llama_index.callbacks import CBEventType, EventPayload
|
|
from llama_index.postprocessor.types import BaseNodePostprocessor
|
|
from llama_index.schema import MetadataMode, NodeWithScore, QueryBundle
|
|
from llama_index.utils import infer_torch_device
|
|
|
|
# Max sequence length (in tokens) passed to the CrossEncoder tokenizer;
# 512 matches the typical BERT-family context window.
DEFAULT_SENTENCE_TRANSFORMER_MAX_LENGTH = 512
|
|
|
|
|
|
class SentenceTransformerRerank(BaseNodePostprocessor):
    """Rerank retrieved nodes with a sentence-transformers cross-encoder.

    Each (query, node content) pair is scored by a ``CrossEncoder`` model
    and the top ``top_n`` nodes, sorted by that score, are returned.
    """

    model: str = Field(description="Sentence transformer model name.")
    top_n: int = Field(description="Number of nodes to return sorted by score.")
    device: str = Field(
        default="cpu",
        description="Device to use for sentence transformer.",
    )
    keep_retrieval_score: bool = Field(
        default=False,
        description="Whether to keep the retrieval score in metadata.",
    )
    # Lazily-constructed CrossEncoder instance; excluded from the pydantic
    # schema via PrivateAttr.
    _model: Any = PrivateAttr()

    def __init__(
        self,
        top_n: int = 2,
        model: str = "cross-encoder/stsb-distilroberta-base",
        device: Optional[str] = None,
        keep_retrieval_score: Optional[bool] = False,
    ) -> None:
        """Load the cross-encoder model.

        Args:
            top_n: Number of top-scoring nodes to return.
            model: HuggingFace cross-encoder model name.
            device: Torch device string; auto-detected when ``None``.
            keep_retrieval_score: If True, stash the original retrieval
                score in each node's metadata before overwriting it.

        Raises:
            ImportError: If sentence-transformers / torch are not installed.
        """
        try:
            from sentence_transformers import CrossEncoder
        except ImportError as err:
            # BUGFIX: the original passed two comma-separated strings to
            # ImportError, which made the exception message a tuple of two
            # strings instead of one readable message.
            raise ImportError(
                "Cannot import sentence-transformers or torch package, "
                "please `pip install torch sentence-transformers`"
            ) from err

        device = infer_torch_device() if device is None else device
        self._model = CrossEncoder(
            model, max_length=DEFAULT_SENTENCE_TRANSFORMER_MAX_LENGTH, device=device
        )
        super().__init__(
            top_n=top_n,
            model=model,
            device=device,
            keep_retrieval_score=keep_retrieval_score,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the serialized class name used by the framework registry."""
        return "SentenceTransformerRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Score nodes against the query and keep the top ``top_n``.

        Args:
            nodes: Candidate nodes (each node's score is overwritten with
                the cross-encoder score).
            query_bundle: Carries the query string; required.

        Raises:
            ValueError: If ``query_bundle`` is missing.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if len(nodes) == 0:
            return []

        # One (query, node text) pair per node; EMBED metadata mode matches
        # what the node exposed at embedding time.
        query_and_nodes = [
            (
                query_bundle.query_str,
                node.node.get_content(metadata_mode=MetadataMode.EMBED),
            )
            for node in nodes
        ]

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            scores = self._model.predict(query_and_nodes)

            # CrossEncoder.predict returns one score per input pair.
            assert len(scores) == len(nodes)

            for node, score in zip(nodes, scores):
                if self.keep_retrieval_score:
                    # keep the retrieval score in metadata
                    node.node.metadata["retrieval_score"] = node.score
                node.score = score

            # Sort descending by score (falsy/None scores sort as 0) and
            # truncate to top_n.
            new_nodes = sorted(nodes, key=lambda x: -x.score if x.score else 0)[
                : self.top_n
            ]
            event.on_end(payload={EventPayload.NODES: new_nodes})

        return new_nodes
|