import json
import os
import re
import string
from collections import Counter
from shutil import rmtree
from typing import Any, Dict, List, Optional, Tuple

import requests
import tqdm

from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.core.base_retriever import BaseRetriever
from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine
from llama_index.schema import NodeWithScore, QueryBundle, TextNode
from llama_index.utils import get_cache_dir

DEV_DISTRACTOR_URL = """http://curtis.ml.cmu.edu/datasets/\
hotpot/hotpot_dev_distractor_v1.json"""


class HotpotQAEvaluator:
    """
    Refer to https://hotpotqa.github.io/ for more details on the dataset.
    """

    def _download_datasets(self) -> Dict[str, str]:
        """Download the HotpotQA dev (distractor) split into the local cache if missing."""
        cache_dir = get_cache_dir()

        dataset_paths = {}
        dataset = "hotpot_dev_distractor"
        dataset_full_path = os.path.join(cache_dir, "datasets", "HotpotQA")
        if not os.path.exists(dataset_full_path):
            url = DEV_DISTRACTOR_URL
            try:
                os.makedirs(dataset_full_path, exist_ok=True)
                response = requests.get(url, stream=True)

                # Stream the download in 1 KiB chunks so the full file never
                # has to be held in memory.
                chunk_size = 1024
                with open(
                    os.path.join(dataset_full_path, "dev_distractor.json"), "wb"
                ) as save_file:
                    for chunk in tqdm.tqdm(
                        response.iter_content(chunk_size=chunk_size)
                    ):
                        if chunk:
                            save_file.write(chunk)
            except Exception as e:
                if os.path.exists(dataset_full_path):
                    print(
                        "Dataset:", dataset, "not found at:", url, "Removing cached dir"
                    )
                    rmtree(dataset_full_path)
                raise ValueError(f"could not download {dataset} dataset") from e
        dataset_paths[dataset] = os.path.join(dataset_full_path, "dev_distractor.json")
        print("Dataset:", dataset, "downloaded at:", dataset_full_path)
        return dataset_paths

    def run(
        self,
        query_engine: BaseQueryEngine,
        queries: int = 10,
        queries_fraction: Optional[float] = None,
        show_result: bool = False,
    ) -> None:
        """Evaluate `query_engine` on the dev (distractor) split and print mean EM / F1."""
        dataset_paths = self._download_datasets()
        dataset = "hotpot_dev_distractor"
        dataset_path = dataset_paths[dataset]
        print("Evaluating on dataset:", dataset)
        print("-------------------------------------")

        with open(dataset_path) as f:
            query_objects = json.loads(f.read())
        if queries_fraction:
            queries_to_load = int(len(query_objects) * queries_fraction)
        else:
            queries_to_load = queries
            queries_fraction = round(queries / len(query_objects), 5)

        print(
            f"Loading {queries_to_load} queries out of "
            f"{len(query_objects)} (fraction: {queries_fraction})"
        )
        query_objects = query_objects[:queries_to_load]

        assert isinstance(
            query_engine, RetrieverQueryEngine
        ), "query_engine must be a RetrieverQueryEngine for this evaluation"
        # Swap the engine's retriever for a mocked one that serves each
        # question's own gold + distractor passages.
        retriever = HotpotQARetriever(query_objects)
        query_engine = query_engine.with_retriever(retriever=retriever)

        scores = {"exact_match": 0.0, "f1": 0.0}

        for query in query_objects:
            query_bundle = QueryBundle(
                query_str=query["question"]
                + " Give a short factoid answer (as few words as possible).",
                custom_embedding_strs=[query["question"]],
            )
            response = query_engine.query(query_bundle)
            em = int(
                exact_match_score(
                    prediction=str(response), ground_truth=query["answer"]
                )
            )
            f1, _, _ = f1_score(prediction=str(response), ground_truth=query["answer"])
            scores["exact_match"] += em
            scores["f1"] += f1
            if show_result:
                print("Question: ", query["question"])
                print("Response:", response)
                print("Correct answer: ", query["answer"])
                print("EM:", em, "F1:", f1)
                print("-------------------------------------")

        # Average the accumulated scores over the evaluated queries.
        for score in scores:
            scores[score] /= len(query_objects)

        print("Scores: ", scores)


class HotpotQARetriever(BaseRetriever):
    """
    This is a mocked retriever for the HotpotQA dataset. It is only meant to be
    used with the HotpotQA dev set in the distractor setting, which does not
    require retrieval but does require identifying the supporting facts among
    a list of 10 provided sources.
    """

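    # Sketch of the per-question record this retriever assumes (field names
    # follow the HotpotQA distractor JSON; the values shown are placeholders):
    #
    # {
    #     "question": "...",
    #     "answer": "...",
    #     "context": [
    #         ["<paragraph title>", ["<sentence 1>", "<sentence 2>", "..."]],
    #         "...",  # 10 titled paragraphs in total: 2 gold + 8 distractors
    #     ],
    # }
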
    def __init__(self, query_objects: Any) -> None:
        assert isinstance(
            query_objects,
            list,
        ), f"query_objects must be a list, got: {type(query_objects)}"
        # Index the records by question text so _retrieve can look up the
        # matching contexts for each incoming query.
        self._queries = {}
        for query_object in query_objects:
            self._queries[query_object["question"]] = query_object

    def _retrieve(self, query: QueryBundle) -> List[NodeWithScore]:
        # The raw question is carried in custom_embedding_strs; query_str may
        # have the "short factoid answer" instruction appended to it.
        if query.custom_embedding_strs:
            query_str = query.custom_embedding_strs[0]
        else:
            query_str = query.query_str
        contexts = self._queries[query_str]["context"]
        node_with_scores = []
        for ctx in contexts:
            # Each context is a [title, [sentence, ...]] pair.
            text_list = ctx[1]
            text = "\n".join(text_list)
            node = TextNode(text=text, metadata={"title": ctx[0]})
            node_with_scores.append(NodeWithScore(node=node, score=1.0))

        return node_with_scores

    def __str__(self) -> str:
        return "HotpotQARetriever"


"""
|
|
Utils from https://github.com/hotpotqa/hotpot/blob/master/hotpot_evaluate_v1.py
|
|
"""
|
|
|
|
|
|
def normalize_answer(s: str) -> str:
    def remove_articles(text: str) -> str:
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text: str) -> str:
        return " ".join(text.split())

    def remove_punc(text: str) -> str:
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text: str) -> str:
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


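# For reference, a worked example of the normalization above (illustrative
# input, not a dataset answer): "The Beatles!" is lowercased to "the beatles!",
# punctuation-stripped to "the beatles", article-stripped to "  beatles", and
# whitespace-collapsed to "beatles".

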
def f1_score(prediction: str, ground_truth: str) -> Tuple[float, float, float]:
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    ZERO_METRIC = (0, 0, 0)

    # When either side is a yes/no/noanswer answer, only an exact match scores;
    # there is no partial token-overlap credit in that case.
    if (
        normalized_prediction in ["yes", "no", "noanswer"]
        and normalized_prediction != normalized_ground_truth
    ):
        return ZERO_METRIC
    if (
        normalized_ground_truth in ["yes", "no", "noanswer"]
        and normalized_prediction != normalized_ground_truth
    ):
        return ZERO_METRIC

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return ZERO_METRIC
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1, precision, recall


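# Worked example for f1_score (illustrative strings, not dataset entries):
# prediction "Barack Obama" vs. ground truth "Obama" shares one token after
# normalization, so precision = 1/2, recall = 1/1, and
# F1 = 2 * 0.5 * 1.0 / 1.5 ≈ 0.667.

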
def exact_match_score(prediction: str, ground_truth: str) -> bool:
    return normalize_answer(prediction) == normalize_answer(ground_truth)
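

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original module: it assumes a
    # llama_index install matching the imports above, with default settings
    # (e.g. an OpenAI API key in the environment) so that as_query_engine()
    # returns the RetrieverQueryEngine this evaluator expects.
    from llama_index import Document, VectorStoreIndex

    # The index contents are irrelevant here: run() swaps the engine's
    # retriever for HotpotQARetriever, which serves each question's own
    # gold + distractor passages.
    index = VectorStoreIndex.from_documents([Document(text="placeholder")])
    engine = index.as_query_engine()

    HotpotQAEvaluator().run(engine, queries=5, show_result=True)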