faiss_rag_enterprise/llama_index/evaluation/retrieval/metrics.py

import os
from typing import Any, Callable, Dict, List, Literal, Optional, Type
import numpy as np
from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.evaluation.retrieval.metrics_base import (
    BaseRetrievalMetric,
    RetrievalMetricResult,
)

_AGG_FUNC: Dict[str, Callable] = {"mean": np.mean, "median": np.median, "max": np.max}


class HitRate(BaseRetrievalMetric):
    """Hit rate metric: 1.0 if any retrieved id matches an expected id, else 0.0."""

    metric_name: str = "hit_rate"

    def compute(
        self,
        query: Optional[str] = None,
        expected_ids: Optional[List[str]] = None,
        retrieved_ids: Optional[List[str]] = None,
        expected_texts: Optional[List[str]] = None,
        retrieved_texts: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> RetrievalMetricResult:
        """Compute metric."""
        if retrieved_ids is None or expected_ids is None:
            raise ValueError("Retrieved ids and expected ids must be provided")
        is_hit = any(doc_id in expected_ids for doc_id in retrieved_ids)
        return RetrievalMetricResult(
            score=1.0 if is_hit else 0.0,
        )
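
# Illustrative usage (not part of the original module): hit rate is
# order-insensitive, so any overlap between retrieved and expected ids
# scores 1.0, regardless of rank.
#
#     HitRate().compute(expected_ids=["a"], retrieved_ids=["x", "a"]).score  # 1.0
#     HitRate().compute(expected_ids=["a"], retrieved_ids=["x", "y"]).score  # 0.0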


class MRR(BaseRetrievalMetric):
    """MRR metric: reciprocal rank of the first retrieved id that is expected."""

    metric_name: str = "mrr"

    def compute(
        self,
        query: Optional[str] = None,
        expected_ids: Optional[List[str]] = None,
        retrieved_ids: Optional[List[str]] = None,
        expected_texts: Optional[List[str]] = None,
        retrieved_texts: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> RetrievalMetricResult:
        """Compute metric."""
        if retrieved_ids is None or expected_ids is None:
            raise ValueError("Retrieved ids and expected ids must be provided")
        for i, doc_id in enumerate(retrieved_ids):
            if doc_id in expected_ids:
                return RetrievalMetricResult(
                    score=1.0 / (i + 1),
                )
        return RetrievalMetricResult(
            score=0.0,
        )
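
# Illustrative usage (not part of the original module): the first match at
# 1-based rank k scores 1/k, so a hit in position 2 scores 0.5; no hit scores 0.0.
#
#     MRR().compute(expected_ids=["a"], retrieved_ids=["x", "a"]).score  # 0.5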


class CohereRerankRelevancyMetric(BaseRetrievalMetric):
    """Cohere rerank relevancy metric.

    Scores retrieved texts against the query with Cohere's rerank API and
    aggregates the per-chunk relevance scores.
    """

    model: str = Field(description="Cohere model name.")
    metric_name: str = "cohere_rerank_relevancy"

    _client: Any = PrivateAttr()

    def __init__(
        self,
        model: str = "rerank-english-v2.0",
        api_key: Optional[str] = None,
    ):
        try:
            api_key = api_key or os.environ["COHERE_API_KEY"]
        except KeyError:
            raise ValueError(
                "Must pass in cohere api key or "
                "specify via COHERE_API_KEY environment variable"
            )
        try:
            from cohere import Client
        except ImportError:
            raise ImportError(
                "Cannot import cohere package, please `pip install cohere`."
            )
        self._client = Client(api_key=api_key)
        super().__init__(model=model)

    def _get_agg_func(self, agg: Literal["max", "median", "mean"]) -> Callable:
        """Get agg func."""
        return _AGG_FUNC[agg]

    def compute(
        self,
        query: Optional[str] = None,
        expected_ids: Optional[List[str]] = None,
        retrieved_ids: Optional[List[str]] = None,
        expected_texts: Optional[List[str]] = None,
        retrieved_texts: Optional[List[str]] = None,
        agg: Literal["max", "median", "mean"] = "max",
        **kwargs: Any,
    ) -> RetrievalMetricResult:
        """Compute metric."""
        del expected_texts  # unused
        if retrieved_texts is None:
            raise ValueError("Retrieved texts must be provided")
        results = self._client.rerank(
            model=self.model,
            top_n=len(retrieved_texts),  # i.e. get a rank score for each retrieved chunk
            query=query,
            documents=retrieved_texts,
        )
        relevance_scores = [r.relevance_score for r in results]
        agg_func = self._get_agg_func(agg)
        return RetrievalMetricResult(
            score=agg_func(relevance_scores), metadata={"agg": agg}
        )
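
# Illustrative usage (not part of the original module; requires a valid
# COHERE_API_KEY and makes a network call). With agg="max", the score is the
# highest relevance the reranker assigns to any retrieved chunk for the query.
#
#     metric = CohereRerankRelevancyMetric()
#     result = metric.compute(
#         query="What is FAISS?",
#         retrieved_texts=["FAISS is a vector search library.", "Unrelated text."],
#         agg="max",
#     )
#     print(result.score, result.metadata)  # aggregated relevance, {'agg': 'max'}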


METRIC_REGISTRY: Dict[str, Type[BaseRetrievalMetric]] = {
    "hit_rate": HitRate,
    "mrr": MRR,
    "cohere_rerank_relevancy": CohereRerankRelevancyMetric,
}


def resolve_metrics(metrics: List[str]) -> List[Type[BaseRetrievalMetric]]:
    """Resolve metrics from list of metric names."""
    for metric in metrics:
        if metric not in METRIC_REGISTRY:
            raise ValueError(f"Invalid metric name: {metric}")
    return [METRIC_REGISTRY[metric] for metric in metrics]
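

if __name__ == "__main__":
    # Minimal smoke test (an assumed entry point, not part of the original
    # module): resolve metric *classes* by name, instantiate them, and score
    # a single query. CohereRerankRelevancyMetric is omitted because it
    # requires an API key and a network call.
    metric_classes = resolve_metrics(["hit_rate", "mrr"])
    for metric_cls in metric_classes:
        metric = metric_cls()
        result = metric.compute(
            query="example query",
            expected_ids=["doc_2"],
            retrieved_ids=["doc_1", "doc_2", "doc_3"],
        )
        print(metric.metric_name, result.score)  # hit_rate 1.0, then mrr 0.5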