faiss_rag_enterprise/llama_index/evaluation/retrieval/evaluator.py

135 lines
4.5 KiB
Python

"""Retrieval evaluators."""
from typing import Any, List, Optional, Sequence, Tuple
from llama_index.bridge.pydantic import Field
from llama_index.core.base_retriever import BaseRetriever
from llama_index.evaluation.retrieval.base import (
BaseRetrievalEvaluator,
RetrievalEvalMode,
)
from llama_index.evaluation.retrieval.metrics_base import (
BaseRetrievalMetric,
)
from llama_index.indices.base_retriever import BaseRetriever
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.schema import ImageNode, TextNode
class RetrieverEvaluator(BaseRetrievalEvaluator):
    """Evaluate a text retriever against a collection of retrieval metrics.

    Args:
        metrics (Sequence[BaseRetrievalMetric]): Metrics to compute.
        retriever (BaseRetriever): Retriever whose results are evaluated.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]):
            Post-processors applied, in order, to the retrieved nodes.
    """

    retriever: BaseRetriever = Field(..., description="Retriever to evaluate")
    node_postprocessors: Optional[List[BaseNodePostprocessor]] = Field(
        default=None, description="Optional post-processor"
    )

    def __init__(
        self,
        metrics: Sequence[BaseRetrievalMetric],
        retriever: BaseRetriever,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        **kwargs: Any,
    ) -> None:
        """Forward all configuration to the base evaluator's initializer."""
        super().__init__(
            metrics=metrics,
            retriever=retriever,
            node_postprocessors=node_postprocessors,
            **kwargs,
        )

    async def _aget_retrieved_ids_and_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        """Retrieve nodes for ``query`` and return their ids and texts.

        Args:
            query: Query string passed to the retriever.
            mode: Accepted for interface compatibility; not used by this
                evaluator.

        Returns:
            A ``(ids, texts)`` pair of parallel lists, one entry per node that
            remains after the optional post-processors have run.
        """
        scored_nodes = await self.retriever.aretrieve(query)

        # A ``None`` or empty post-processor list leaves the results untouched.
        for postprocessor in self.node_postprocessors or []:
            scored_nodes = postprocessor.postprocess_nodes(
                scored_nodes, query_str=query
            )

        retrieved_ids = [scored.node.node_id for scored in scored_nodes]
        retrieved_texts = [scored.node.text for scored in scored_nodes]
        return retrieved_ids, retrieved_texts
class MultiModalRetrieverEvaluator(BaseRetrievalEvaluator):
    """Multi-modal retriever evaluator.

    This module will evaluate a retriever that may return both text and image
    nodes, using a set of metrics.

    Args:
        metrics (Sequence[BaseRetrievalMetric]): Metrics to evaluate.
        retriever (BaseRetriever): Retriever to evaluate.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]):
            Post-processors applied, in order, after retrieval.
    """

    retriever: BaseRetriever = Field(..., description="Retriever to evaluate")
    node_postprocessors: Optional[List[BaseNodePostprocessor]] = Field(
        default=None, description="Optional post-processor"
    )

    def __init__(
        self,
        metrics: Sequence[BaseRetrievalMetric],
        retriever: BaseRetriever,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        super().__init__(
            metrics=metrics,
            retriever=retriever,
            node_postprocessors=node_postprocessors,
            **kwargs,
        )

    async def _aget_retrieved_ids_and_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        """Get retrieved ids and texts for the requested modality.

        The method name matches the hook implemented by ``RetrieverEvaluator``
        in this module (presumably the one declared by
        ``BaseRetrievalEvaluator`` -- confirm against the base class).

        Args:
            query: Query string passed to the retriever.
            mode: ``"text"`` selects every retrieved node with non-empty text
                (image nodes that carry text are included); ``"image"``
                selects only ``ImageNode`` instances.

        Returns:
            A ``(ids, texts)`` pair of parallel lists for the selected nodes.

        Raises:
            ValueError: If ``mode`` is neither ``"text"`` nor ``"image"``.
        """
        retrieved_nodes = await self.retriever.aretrieve(query)
        image_nodes: List[ImageNode] = []
        text_nodes: List[TextNode] = []

        if self.node_postprocessors:
            for node_postprocessor in self.node_postprocessors:
                retrieved_nodes = node_postprocessor.postprocess_nodes(
                    retrieved_nodes, query_str=query
                )

        for scored_node in retrieved_nodes:
            node = scored_node.node
            if isinstance(node, ImageNode):
                image_nodes.append(node)
            # Deliberately not ``elif``: an image node that also has text
            # counts towards both modalities.
            if node.text:
                text_nodes.append(node)

        if mode == "text":
            return (
                [node.node_id for node in text_nodes],
                [node.text for node in text_nodes],
            )
        elif mode == "image":
            return (
                [node.node_id for node in image_nodes],
                [node.text for node in image_nodes],
            )
        else:
            raise ValueError("Unsupported mode.")

    async def _aget_retrieved_ids_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        """Backward-compatible alias for ``_aget_retrieved_ids_and_texts``.

        Kept so that existing callers of the previous (misspelled) method name
        continue to work.
        """
        return await self._aget_retrieved_ids_and_texts(query, mode)