faiss_rag_enterprise/llama_index/evaluation/retrieval/base.py

"""Base retrieval abstractions."""
import asyncio
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from llama_index.bridge.pydantic import BaseModel, Field
from llama_index.evaluation.retrieval.metrics import resolve_metrics
from llama_index.evaluation.retrieval.metrics_base import (
    BaseRetrievalMetric,
    RetrievalMetricResult,
)
from llama_index.finetuning.embeddings.common import EmbeddingQAFinetuneDataset

class RetrievalEvalMode(str, Enum):
    """Evaluation modality of the retrieval (text or image)."""

    TEXT = "text"
    IMAGE = "image"

    @classmethod
    def from_str(cls, label: str) -> "RetrievalEvalMode":
        """Resolve a mode from its string label."""
        if label == "text":
            return RetrievalEvalMode.TEXT
        elif label == "image":
            return RetrievalEvalMode.IMAGE
        else:
            raise NotImplementedError(f"Unsupported retrieval eval mode: {label}")
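
# For example, ``RetrievalEvalMode.from_str("text")`` yields
# ``RetrievalEvalMode.TEXT``; ``aevaluate_dataset`` below uses this to map a
# dataset's ``mode`` string onto the enum.
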
class RetrievalEvalResult(BaseModel):
    """Retrieval eval result.

    NOTE: this abstraction might change in the future.

    Attributes:
        query (str): Query string
        expected_ids (List[str]): Expected ids
        retrieved_ids (List[str]): Retrieved ids
        metric_dict (Dict[str, RetrievalMetricResult]): \
            Metric dictionary for the evaluation

    """

    class Config:
        arbitrary_types_allowed = True

    query: str = Field(..., description="Query string")
    expected_ids: List[str] = Field(..., description="Expected ids")
    expected_texts: Optional[List[str]] = Field(
        default=None,
        description="Expected texts associated with nodes provided in `expected_ids`",
    )
    retrieved_ids: List[str] = Field(..., description="Retrieved ids")
    retrieved_texts: List[str] = Field(..., description="Retrieved texts")
    mode: "RetrievalEvalMode" = Field(
        default=RetrievalEvalMode.TEXT, description="text or image"
    )
    metric_dict: Dict[str, RetrievalMetricResult] = Field(
        ..., description="Metric dictionary for the evaluation"
    )

    @property
    def metric_vals_dict(self) -> Dict[str, float]:
        """Dictionary of metric values."""
        return {k: v.score for k, v in self.metric_dict.items()}

    def __str__(self) -> str:
        """String representation."""
        return f"Query: {self.query}\nMetrics: {self.metric_vals_dict!s}\n"
class BaseRetrievalEvaluator(BaseModel):
    """Base Retrieval Evaluator class."""

    metrics: List[BaseRetrievalMetric] = Field(
        ..., description="List of metrics to evaluate"
    )

    class Config:
        arbitrary_types_allowed = True

    @classmethod
    def from_metric_names(
        cls, metric_names: List[str], **kwargs: Any
    ) -> "BaseRetrievalEvaluator":
        """Create evaluator from metric names.

        Args:
            metric_names (List[str]): List of metric names
            **kwargs: Additional arguments for the evaluator

        """
        metric_types = resolve_metrics(metric_names)
        return cls(metrics=[metric() for metric in metric_types], **kwargs)
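
    # Usage sketch (assuming "mrr" and "hit_rate" are names known to
    # ``resolve_metrics``, and ``MyEvaluator`` is a hypothetical subclass whose
    # extra fields are forwarded through ``**kwargs``):
    #   evaluator = MyEvaluator.from_metric_names(["mrr", "hit_rate"], retriever=r)
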
    @abstractmethod
    async def _aget_retrieved_ids_and_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        """Get retrieved ids and texts."""
        raise NotImplementedError
    def evaluate(
        self,
        query: str,
        expected_ids: List[str],
        expected_texts: Optional[List[str]] = None,
        mode: RetrievalEvalMode = RetrievalEvalMode.TEXT,
        **kwargs: Any,
    ) -> RetrievalEvalResult:
        """Run evaluation with query string and expected ids.

        Args:
            query (str): Query string
            expected_ids (List[str]): Expected ids

        Returns:
            RetrievalEvalResult: Evaluation result

        """
        return asyncio.run(
            self.aevaluate(
                query=query,
                expected_ids=expected_ids,
                expected_texts=expected_texts,
                mode=mode,
                **kwargs,
            )
        )
    # @abstractmethod
    async def aevaluate(
        self,
        query: str,
        expected_ids: List[str],
        expected_texts: Optional[List[str]] = None,
        mode: RetrievalEvalMode = RetrievalEvalMode.TEXT,
        **kwargs: Any,
    ) -> RetrievalEvalResult:
        """Run evaluation with query string and expected ids; the retrieved
        ids and texts are obtained internally via `_aget_retrieved_ids_and_texts`.

        Subclasses can override this method to provide custom evaluation logic and
        take in additional arguments.
        """
        retrieved_ids, retrieved_texts = await self._aget_retrieved_ids_and_texts(
            query, mode
        )
        metric_dict = {}
        for metric in self.metrics:
            eval_result = metric.compute(
                query, expected_ids, retrieved_ids, expected_texts, retrieved_texts
            )
            metric_dict[metric.metric_name] = eval_result
        return RetrievalEvalResult(
            query=query,
            expected_ids=expected_ids,
            expected_texts=expected_texts,
            retrieved_ids=retrieved_ids,
            retrieved_texts=retrieved_texts,
            mode=mode,
            metric_dict=metric_dict,
        )
    async def aevaluate_dataset(
        self,
        dataset: EmbeddingQAFinetuneDataset,
        workers: int = 2,
        show_progress: bool = False,
        **kwargs: Any,
    ) -> List[RetrievalEvalResult]:
        """Run evaluation over a full dataset, at most `workers` queries at a time."""
        semaphore = asyncio.Semaphore(workers)

        async def eval_worker(
            query: str, expected_ids: List[str], mode: RetrievalEvalMode
        ) -> RetrievalEvalResult:
            async with semaphore:
                return await self.aevaluate(query, expected_ids=expected_ids, mode=mode)

        response_jobs = []
        mode = RetrievalEvalMode.from_str(dataset.mode)
        for query_id, query in dataset.queries.items():
            expected_ids = dataset.relevant_docs[query_id]
            response_jobs.append(eval_worker(query, expected_ids, mode))
        if show_progress:
            from tqdm.asyncio import tqdm_asyncio

            eval_results = await tqdm_asyncio.gather(*response_jobs)
        else:
            eval_results = await asyncio.gather(*response_jobs)
        return eval_results
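

# --- Illustrative sketch (not part of the original module) ---
# A minimal concrete evaluator, assuming a retriever whose async
# ``aretrieve(query)`` returns scored nodes exposing ``.node.node_id`` and
# ``.node.get_content()`` (the usual LlamaIndex ``NodeWithScore`` shape).
# The class name and the ``retriever`` field below are hypothetical.
class ExampleRetrieverEvaluator(BaseRetrievalEvaluator):
    """Evaluate a single retriever against expected node ids."""

    retriever: Any = Field(..., description="Retriever exposing async `aretrieve`")

    async def _aget_retrieved_ids_and_texts(
        self, query: str, mode: RetrievalEvalMode = RetrievalEvalMode.TEXT
    ) -> Tuple[List[str], List[str]]:
        """Retrieve nodes for `query` and unpack their ids and texts."""
        retrieved_nodes = await self.retriever.aretrieve(query)
        return (
            [scored.node.node_id for scored in retrieved_nodes],
            [scored.node.get_content() for scored in retrieved_nodes],
        )


# Usage sketch: build from metric names, then evaluate one query or a dataset:
#   evaluator = ExampleRetrieverEvaluator.from_metric_names(
#       ["mrr", "hit_rate"], retriever=my_retriever
#   )
#   result = evaluator.evaluate(query="...", expected_ids=["node_1"])
#   results = await evaluator.aevaluate_dataset(qa_dataset, workers=4)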