135 lines
4.5 KiB
Python
135 lines
4.5 KiB
Python
"""Retrieval evaluators."""
|
|
|
|
from typing import Any, List, Optional, Sequence, Tuple
|
|
|
|
from llama_index.bridge.pydantic import Field
|
|
from llama_index.core.base_retriever import BaseRetriever
|
|
from llama_index.evaluation.retrieval.base import (
|
|
BaseRetrievalEvaluator,
|
|
RetrievalEvalMode,
|
|
)
|
|
from llama_index.evaluation.retrieval.metrics_base import (
|
|
BaseRetrievalMetric,
|
|
)
|
|
from llama_index.indices.base_retriever import BaseRetriever
|
|
from llama_index.postprocessor.types import BaseNodePostprocessor
|
|
from llama_index.schema import ImageNode, TextNode
|
|
|
|
|
|
class RetrieverEvaluator(BaseRetrievalEvaluator):
    """Evaluate a retriever against a set of retrieval metrics.

    For each query the retriever is called asynchronously; when
    post-processors are configured, the retrieved nodes are passed through
    each of them in order before ids and texts are extracted.

    Args:
        metrics (Sequence[BaseRetrievalMetric]): Metrics to evaluate with.
        retriever (BaseRetriever): Retriever to evaluate.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]):
            Post-processors to apply after retrieval.

    """

    retriever: BaseRetriever = Field(..., description="Retriever to evaluate")
    node_postprocessors: Optional[List[BaseNodePostprocessor]] = Field(
        default=None, description="Optional post-processor"
    )

    def __init__(
        self,
        metrics: Sequence[BaseRetrievalMetric],
        retriever: BaseRetriever,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        **kwargs: Any,
    ) -> None:
        """Forward every constructor argument to the base evaluator."""
        explicit_args = {
            "metrics": metrics,
            "retriever": retriever,
            "node_postprocessors": node_postprocessors,
        }
        super().__init__(**explicit_args, **kwargs)

    async def _aget_retrieved_ids_and_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        """Retrieve nodes for ``query`` and return their ids and texts.

        ``mode`` is accepted to match the shared signature but is not
        consulted by this evaluator.

        Returns:
            A ``(node_ids, node_texts)`` pair, both in retrieval order.
        """
        scored_nodes = await self.retriever.aretrieve(query)

        # ``None`` and an empty list both mean "no post-processing".
        for postprocessor in self.node_postprocessors or []:
            scored_nodes = postprocessor.postprocess_nodes(
                scored_nodes, query_str=query
            )

        node_ids = [scored.node.node_id for scored in scored_nodes]
        node_texts = [scored.node.text for scored in scored_nodes]
        return (node_ids, node_texts)
|
|
|
|
|
|
class MultiModalRetrieverEvaluator(BaseRetrievalEvaluator):
    """Multi-modal retriever evaluator.

    This module will evaluate a retriever using a set of metrics. Retrieved
    nodes are split into image nodes and nodes carrying text, and the ids
    and texts of the group selected by the evaluation mode are returned.

    Args:
        metrics (Sequence[BaseRetrievalMetric]): Metrics to evaluate with.
        retriever (BaseRetriever): Retriever to evaluate.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]):
            Post-processors to apply, in order, after retrieval.

    """

    retriever: BaseRetriever = Field(..., description="Retriever to evaluate")
    node_postprocessors: Optional[List[BaseNodePostprocessor]] = Field(
        default=None, description="Optional post-processor"
    )

    def __init__(
        self,
        metrics: Sequence[BaseRetrievalMetric],
        retriever: BaseRetriever,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        super().__init__(
            metrics=metrics,
            retriever=retriever,
            node_postprocessors=node_postprocessors,
            **kwargs,
        )

    async def _aget_retrieved_ids_and_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        """Get retrieved ids and texts, potentially applying a post-processor.

        Args:
            query: Query string passed to the retriever.
            mode: Which group of retrieved nodes to report: ``"text"`` or
                ``"image"``.

        Returns:
            A ``(node_ids, node_texts)`` pair for the nodes of the selected
            group, in retrieval order.

        Raises:
            ValueError: If ``mode`` is neither ``"text"`` nor ``"image"``.
        """
        retrieved_nodes = await self.retriever.aretrieve(query)

        if self.node_postprocessors:
            for node_postprocessor in self.node_postprocessors:
                retrieved_nodes = node_postprocessor.postprocess_nodes(
                    retrieved_nodes, query_str=query
                )

        image_nodes: List[ImageNode] = []
        text_nodes: List[TextNode] = []
        for scored_node in retrieved_nodes:
            node = scored_node.node
            if isinstance(node, ImageNode):
                image_nodes.append(node)
            # Deliberately not ``elif``: an image node that also carries text
            # is reported in both groups.
            if node.text:
                text_nodes.append(node)

        # NOTE(review): mode is compared against raw strings; this presumably
        # relies on RetrievalEvalMode members comparing equal to their string
        # values -- confirm against the enum definition.
        if mode == "text":
            selected_nodes = text_nodes
        elif mode == "image":
            selected_nodes = image_nodes
        else:
            raise ValueError("Unsupported mode.")

        return (
            [node.node_id for node in selected_nodes],
            [node.text for node in selected_nodes],
        )

    async def _aget_retrieved_ids_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        """Backward-compatible alias of ``_aget_retrieved_ids_and_texts``.

        The method was previously defined under this name, which did not
        match the hook name used by ``RetrieverEvaluator``; it is kept so
        existing callers of the old name continue to work.
        """
        return await self._aget_retrieved_ids_and_texts(query, mode)
|