"""Base retrieval abstractions."""

import asyncio
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.evaluation.retrieval.metrics import resolve_metrics
from llama_index.evaluation.retrieval.metrics_base import (
    BaseRetrievalMetric,
    RetrievalMetricResult,
)
from llama_index.finetuning.embeddings.common import EmbeddingQAFinetuneDataset


class RetrievalEvalMode(str, Enum):
    """Evaluation of retrieval modality."""

    TEXT = "text"
    IMAGE = "image"

    @classmethod
    def from_str(cls, label: str) -> "RetrievalEvalMode":
        """Parse a string label into a RetrievalEvalMode member."""
        if label == "text":
            return RetrievalEvalMode.TEXT
        elif label == "image":
            return RetrievalEvalMode.IMAGE
        else:
            raise NotImplementedError(f"Unsupported retrieval eval mode: {label}")
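

# Example (illustrative): `from_str` maps the mode labels stored on datasets
# to enum members; unknown labels raise NotImplementedError.
#
#   RetrievalEvalMode.from_str("text")   # -> RetrievalEvalMode.TEXT
#   RetrievalEvalMode.from_str("image")  # -> RetrievalEvalMode.IMAGE
#   RetrievalEvalMode.from_str("audio")  # raises NotImplementedError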
class RetrievalEvalResult(BaseModel):
    """Retrieval eval result.

    NOTE: this abstraction might change in the future.

    Attributes:
        query (str): Query string
        expected_ids (List[str]): Expected ids
        retrieved_ids (List[str]): Retrieved ids
        metric_dict (Dict[str, RetrievalMetricResult]): \
            Metric dictionary for the evaluation

    """

    class Config:
        arbitrary_types_allowed = True

    query: str = Field(..., description="Query string")
    expected_ids: List[str] = Field(..., description="Expected ids")
    expected_texts: Optional[List[str]] = Field(
        default=None,
        description="Expected texts associated with nodes provided in `expected_ids`",
    )
    retrieved_ids: List[str] = Field(..., description="Retrieved ids")
    retrieved_texts: List[str] = Field(..., description="Retrieved texts")
    mode: "RetrievalEvalMode" = Field(
        default=RetrievalEvalMode.TEXT, description="text or image"
    )
    metric_dict: Dict[str, RetrievalMetricResult] = Field(
        ..., description="Metric dictionary for the evaluation"
    )

    @property
    def metric_vals_dict(self) -> Dict[str, float]:
        """Dictionary of metric values."""
        return {k: v.score for k, v in self.metric_dict.items()}

    def __str__(self) -> str:
        """String representation."""
        return f"Query: {self.query}\nMetrics: {self.metric_vals_dict!s}\n"
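

# Example (illustrative; metric names and scores are made up):
#
#   result = evaluator.evaluate(query="q", expected_ids=["doc1"])
#   result.metric_vals_dict  # -> {"hit_rate": 1.0, "mrr": 1.0}
#   print(result)            # -> "Query: q\nMetrics: {'hit_rate': 1.0, ...}"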
class BaseRetrievalEvaluator(BaseModel):
    """Base Retrieval Evaluator class."""

    metrics: List[BaseRetrievalMetric] = Field(
        ..., description="List of metrics to evaluate"
    )

    class Config:
        arbitrary_types_allowed = True

    @classmethod
    def from_metric_names(
        cls, metric_names: List[str], **kwargs: Any
    ) -> "BaseRetrievalEvaluator":
        """Create evaluator from metric names.

        Args:
            metric_names (List[str]): List of metric names
            **kwargs: Additional arguments for the evaluator

        """
        metric_types = resolve_metrics(metric_names)
        return cls(metrics=[metric() for metric in metric_types], **kwargs)

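    # Example (illustrative; assumes "hit_rate" and "mrr" are metric names
    # known to `resolve_metrics`, and `MyEvaluator` is a concrete subclass):
    #
    #   evaluator = MyEvaluator.from_metric_names(["hit_rate", "mrr"])
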
    @abstractmethod
    async def _aget_retrieved_ids_and_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        """Get retrieved ids and texts."""
        raise NotImplementedError
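
    # A concrete subclass wires this to a real retriever; a minimal sketch,
    # assuming `self._retriever` is a llama_index retriever with `aretrieve`:
    #
    #   async def _aget_retrieved_ids_and_texts(self, query, mode=...):
    #       nodes = await self._retriever.aretrieve(query)
    #       return (
    #           [n.node.node_id for n in nodes],
    #           [n.node.get_content() for n in nodes],
    #       )
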
    def evaluate(
        self,
        query: str,
        expected_ids: List[str],
        expected_texts: Optional[List[str]] = None,
        mode: RetrievalEvalMode = RetrievalEvalMode.TEXT,
        **kwargs: Any,
    ) -> RetrievalEvalResult:
        """Run evaluation with query string and expected ids.

        Synchronous wrapper around `aevaluate`; it calls `asyncio.run`,
        so it cannot be used from within an already-running event loop.

        Args:
            query (str): Query string
            expected_ids (List[str]): Expected ids

        Returns:
            RetrievalEvalResult: Evaluation result

        """
        return asyncio.run(
            self.aevaluate(
                query=query,
                expected_ids=expected_ids,
                expected_texts=expected_texts,
                mode=mode,
                **kwargs,
            )
        )

    async def aevaluate(
        self,
        query: str,
        expected_ids: List[str],
        expected_texts: Optional[List[str]] = None,
        mode: RetrievalEvalMode = RetrievalEvalMode.TEXT,
        **kwargs: Any,
    ) -> RetrievalEvalResult:
        """Run evaluation with query string and expected ids.

        Subclasses can override this method to provide custom evaluation logic
        and take in additional arguments.
        """
        retrieved_ids, retrieved_texts = await self._aget_retrieved_ids_and_texts(
            query, mode
        )
        # Compute every configured metric against the retrieved results.
        metric_dict = {}
        for metric in self.metrics:
            eval_result = metric.compute(
                query, expected_ids, retrieved_ids, expected_texts, retrieved_texts
            )
            metric_dict[metric.metric_name] = eval_result

        return RetrievalEvalResult(
            query=query,
            expected_ids=expected_ids,
            expected_texts=expected_texts,
            retrieved_ids=retrieved_ids,
            retrieved_texts=retrieved_texts,
            mode=mode,
            metric_dict=metric_dict,
        )

    async def aevaluate_dataset(
        self,
        dataset: EmbeddingQAFinetuneDataset,
        workers: int = 2,
        show_progress: bool = False,
        **kwargs: Any,
    ) -> List[RetrievalEvalResult]:
        """Run evaluation over an entire dataset, `workers` queries at a time."""
        # Bound concurrency: at most `workers` queries are evaluated at once.
        semaphore = asyncio.Semaphore(workers)

        async def eval_worker(
            query: str, expected_ids: List[str], mode: RetrievalEvalMode
        ) -> RetrievalEvalResult:
            async with semaphore:
                return await self.aevaluate(query, expected_ids=expected_ids, mode=mode)

        response_jobs = []
        mode = RetrievalEvalMode.from_str(dataset.mode)
        for query_id, query in dataset.queries.items():
            expected_ids = dataset.relevant_docs[query_id]
            response_jobs.append(eval_worker(query, expected_ids, mode))
        if show_progress:
            from tqdm.asyncio import tqdm_asyncio

            eval_results = await tqdm_asyncio.gather(*response_jobs)
        else:
            eval_results = await asyncio.gather(*response_jobs)

        return eval_results
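

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API): a toy
# evaluator whose "retrieval" is a hard-coded lookup. `ToyRetrievalEvaluator`
# and `_TOY_INDEX` are hypothetical names introduced for this example, and
# "hit_rate" / "mrr" are assumed to be registered with `resolve_metrics`.
if __name__ == "__main__":
    _TOY_INDEX: Dict[str, Tuple[List[str], List[str]]] = {
        "what is llamaindex?": (
            ["doc1", "doc2"],
            ["LlamaIndex is a data framework.", "It helps build RAG apps."],
        ),
    }

    class ToyRetrievalEvaluator(BaseRetrievalEvaluator):
        """Toy evaluator that retrieves from a fixed in-memory mapping."""

        async def _aget_retrieved_ids_and_texts(
            self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
        ) -> Tuple[List[str], List[str]]:
            return _TOY_INDEX.get(query.lower(), ([], []))

    evaluator = ToyRetrievalEvaluator.from_metric_names(["hit_rate", "mrr"])
    # `evaluate` calls `asyncio.run`, so it must run outside any event loop.
    result = evaluator.evaluate(query="What is LlamaIndex?", expected_ids=["doc1"])
    print(result)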