import os
from typing import Any, Callable, Dict, List, Literal, Optional, Type

import numpy as np

from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.evaluation.retrieval.metrics_base import (
    BaseRetrievalMetric,
    RetrievalMetricResult,
)

_AGG_FUNC: Dict[str, Callable] = {"mean": np.mean, "median": np.median, "max": np.max}


class HitRate(BaseRetrievalMetric):
    """Hit rate metric.

    Scores 1.0 if any retrieved id appears in the expected ids, else 0.0.
    """

    metric_name: str = "hit_rate"

    def compute(
        self,
        query: Optional[str] = None,
        expected_ids: Optional[List[str]] = None,
        retrieved_ids: Optional[List[str]] = None,
        expected_texts: Optional[List[str]] = None,
        retrieved_texts: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> RetrievalMetricResult:
        """Compute metric."""
        if retrieved_ids is None or expected_ids is None:
            raise ValueError("Retrieved ids and expected ids must be provided")
        is_hit = any(id in expected_ids for id in retrieved_ids)
        return RetrievalMetricResult(
            score=1.0 if is_hit else 0.0,
        )


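# Illustrative sketch (not part of the upstream module): the hit rate for a
# single query is 1.0 here because "node_7" appears in both lists. The node ids
# are hypothetical.
def _example_hit_rate() -> RetrievalMetricResult:
    return HitRate().compute(
        expected_ids=["node_1", "node_7"],
        retrieved_ids=["node_3", "node_7", "node_9"],
    )  # .score == 1.0

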
class MRR(BaseRetrievalMetric):
    """MRR metric.

    Scores the reciprocal rank of the first retrieved id that appears in the
    expected ids, or 0.0 if none of the retrieved ids match.
    """

    metric_name: str = "mrr"

    def compute(
        self,
        query: Optional[str] = None,
        expected_ids: Optional[List[str]] = None,
        retrieved_ids: Optional[List[str]] = None,
        expected_texts: Optional[List[str]] = None,
        retrieved_texts: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> RetrievalMetricResult:
        """Compute metric."""
        if retrieved_ids is None or expected_ids is None:
            raise ValueError("Retrieved ids and expected ids must be provided")
        for i, id in enumerate(retrieved_ids):
            if id in expected_ids:
                return RetrievalMetricResult(
                    score=1.0 / (i + 1),
                )
        return RetrievalMetricResult(
            score=0.0,
        )


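# Illustrative sketch (not part of the upstream module): the only expected id is
# retrieved at rank 2, so the reciprocal rank is 1.0 / 2. The node ids are
# hypothetical.
def _example_mrr() -> RetrievalMetricResult:
    return MRR().compute(
        expected_ids=["node_7"],
        retrieved_ids=["node_3", "node_7", "node_9"],
    )  # .score == 0.5

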
class CohereRerankRelevancyMetric(BaseRetrievalMetric):
    """Cohere rerank relevancy metric.

    Scores each retrieved text against the query with Cohere's rerank endpoint
    and aggregates the per-chunk relevance scores (max, median, or mean).
    """

    model: str = Field(description="Cohere model name.")
    metric_name: str = "cohere_rerank_relevancy"

    _client: Any = PrivateAttr()

    def __init__(
        self,
        model: str = "rerank-english-v2.0",
        api_key: Optional[str] = None,
    ):
        try:
            # os.environ raises KeyError (not IndexError) when the key is unset
            api_key = api_key or os.environ["COHERE_API_KEY"]
        except KeyError:
            raise ValueError(
                "Must pass in cohere api key or "
                "specify via COHERE_API_KEY environment variable"
            )
        try:
            from cohere import Client
        except ImportError:
            raise ImportError(
                "Cannot import cohere package, please `pip install cohere`."
            )

        self._client = Client(api_key=api_key)
        super().__init__(model=model)

    def _get_agg_func(self, agg: Literal["max", "median", "mean"]) -> Callable:
        """Get agg func."""
        return _AGG_FUNC[agg]

    def compute(
        self,
        query: Optional[str] = None,
        expected_ids: Optional[List[str]] = None,
        retrieved_ids: Optional[List[str]] = None,
        expected_texts: Optional[List[str]] = None,
        retrieved_texts: Optional[List[str]] = None,
        agg: Literal["max", "median", "mean"] = "max",
        **kwargs: Any,
    ) -> RetrievalMetricResult:
        """Compute metric."""
        del expected_texts  # unused

        if retrieved_texts is None:
            raise ValueError("Retrieved texts must be provided")

        results = self._client.rerank(
            model=self.model,
            top_n=len(
                retrieved_texts
            ),  # i.e. get a rank score for each retrieved chunk
            query=query,
            documents=retrieved_texts,
        )
        relevance_scores = [r.relevance_score for r in results]
        agg_func = self._get_agg_func(agg)

        return RetrievalMetricResult(
            score=agg_func(relevance_scores), metadata={"agg": agg}
        )


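# Illustrative sketch (not part of the upstream module): scoring retrieved
# chunks against a query with the Cohere reranker. Requires a valid
# COHERE_API_KEY (or an explicit api_key argument); the query and texts are
# hypothetical and the scores come from the live API, so no expected value is
# shown.
def _example_cohere_rerank_relevancy() -> RetrievalMetricResult:
    metric = CohereRerankRelevancyMetric()  # defaults to rerank-english-v2.0
    return metric.compute(
        query="What is the capital of France?",
        retrieved_texts=[
            "Paris is the capital of France.",
            "Berlin is the capital of Germany.",
        ],
        agg="max",  # aggregate the per-chunk relevance scores with np.max
    )

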
METRIC_REGISTRY: Dict[str, Type[BaseRetrievalMetric]] = {
    "hit_rate": HitRate,
    "mrr": MRR,
    "cohere_rerank_relevancy": CohereRerankRelevancyMetric,
}


def resolve_metrics(metrics: List[str]) -> List[Type[BaseRetrievalMetric]]:
    """Resolve metrics from list of metric names."""
    for metric in metrics:
        if metric not in METRIC_REGISTRY:
            raise ValueError(f"Invalid metric name: {metric}")

    return [METRIC_REGISTRY[metric] for metric in metrics]


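# Illustrative sketch (not part of the upstream module): resolve metric classes
# by name, then instantiate and evaluate a single query's retrieval results.
# Runs only when the module is executed directly; the Cohere example above is
# skipped because it needs a live API key.
if __name__ == "__main__":
    for metric_cls in resolve_metrics(["hit_rate", "mrr"]):
        metric = metric_cls()
        result = metric.compute(
            expected_ids=["node_1"],
            retrieved_ids=["node_2", "node_1"],
        )
        print(metric.metric_name, result.score)  # hit_rate 1.0, then mrr 0.5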