first commit
This commit is contained in:
commit
cb54502fae
|
|
@ -0,0 +1,26 @@
|
|||
FROM python:3.10-slim
|
||||
|
||||
# 安装系统依赖
|
||||
RUN apt-get update && apt-get install -y gcc libglib2.0-0 && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 设置工作目录
|
||||
WORKDIR /app
|
||||
|
||||
# 安装 Python 依赖
|
||||
COPY requirements.txt .
|
||||
RUN pip install --upgrade pip && pip install -r requirements.txt
|
||||
|
||||
# 安装本地 FlagEmbedding 源码
|
||||
COPY FlagEmbedding /opt/FlagEmbedding
|
||||
RUN pip install --no-deps --upgrade -e /opt/FlagEmbedding
|
||||
|
||||
# 拷贝应用代码和模型权重
|
||||
COPY app /app/app
|
||||
COPY model/bge-m3 /app/model/bge-m3
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE 8001
|
||||
|
||||
# 启动 FastAPI 服务
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8001"]
|
||||
|
||||
|
|
@ -0,0 +1,146 @@
|
|||
*.memmap
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
.idea/
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
../docs/_build/
|
||||
../docs/build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
Untitled.ipynb
|
||||
try.py
|
||||
update_model_card.py
|
||||
model_card.md
|
||||
pic.py
|
||||
pic2.py
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# MacOS associated
|
||||
.DS_Store
|
||||
|
||||
# results
|
||||
en_results
|
||||
zh_results
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
from .abc.inference import *
|
||||
from .inference import *
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
from .arguments import AbsEvalArgs, AbsEvalModelArgs
|
||||
from .evaluator import AbsEvaluator
|
||||
from .data_loader import AbsEvalDataLoader
|
||||
from .searcher import EvalRetriever, EvalDenseRetriever, EvalReranker
|
||||
from .runner import AbsEvalRunner
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AbsEvalArgs",
|
||||
"AbsEvalModelArgs",
|
||||
"AbsEvaluator",
|
||||
"AbsEvalDataLoader",
|
||||
"EvalRetriever",
|
||||
"EvalDenseRetriever",
|
||||
"EvalReranker",
|
||||
"AbsEvalRunner",
|
||||
]
|
||||
|
|
@ -0,0 +1,190 @@
|
|||
"""
|
||||
Adapted from https://github.com/AIR-Bench/AIR-Bench/blob/0.1.0/air_benchmark/evaluation_utils/evaluation_arguments.py
|
||||
"""
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class AbsEvalArgs:
|
||||
"""
|
||||
Base class for evaluation arguments.
|
||||
"""
|
||||
eval_name: str = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of the evaluation task, such as msmarco, beir, miracl, etc."}
|
||||
)
|
||||
dataset_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "1) If you want to perform evaluation on your own dataset, you can provide the path to the dataset directory (must exists in local). "
|
||||
"The dataset directory should contain the following files: corpus.jsonl, <split>_queries.jsonl, <split>_qrels.jsonl, or contain multiple directories, each of which contains the following files: corpus.jsonl, <split>_queries.jsonl, <split>_qrels.jsonl."
|
||||
"2) If you want to perform evaluation on the datasets we provide evaluation APIs for, you can provide the path to saving the downloaded dataset. If you provide None, the dataset will be only downloaded to the cache directory."
|
||||
}
|
||||
)
|
||||
force_redownload: bool = field(
|
||||
default=False, metadata={"help": "Whether to force redownload the dataset. This is useful when you load dataset from remote and want to update the dataset."}
|
||||
)
|
||||
dataset_names: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The names of the datasets to evaluate. Default: None. If None, all available datasets will be evaluated. The name can be a specific dataset name (BEIR), a specific language (MIRACL), etc.",
|
||||
"nargs": "+"
|
||||
}
|
||||
)
|
||||
splits: str = field(
|
||||
default="test",
|
||||
metadata={"help": "Splits to evaluate. Default: test", "nargs": "+"}
|
||||
)
|
||||
corpus_embd_save_dir: str = field(
|
||||
default=None, metadata={"help": "Path to save corpus embeddings. If None, embeddings are not saved."}
|
||||
)
|
||||
output_dir: str = field(
|
||||
default="./search_results", metadata={"help": "Path to save results."}
|
||||
)
|
||||
search_top_k: int = field(
|
||||
default=1000, metadata={"help": "Top k for retrieving."}
|
||||
)
|
||||
rerank_top_k: int = field(default=100, metadata={"help": "Top k for reranking."})
|
||||
cache_path: str = field(
|
||||
default=None, metadata={"help": "Cache directory for loading datasets."}
|
||||
)
|
||||
token: str = field(
|
||||
default_factory=lambda: os.getenv('HF_TOKEN', None),
|
||||
metadata={"help": "The token to use when accessing the model."}
|
||||
)
|
||||
overwrite: bool = field(
|
||||
default=False, metadata={"help": "whether to overwrite evaluation results"}
|
||||
)
|
||||
ignore_identical_ids: bool = field(
|
||||
default=False, metadata={"help": "whether to ignore identical ids in search results"}
|
||||
)
|
||||
# ================ for evaluation ===============
|
||||
k_values: int = field(
|
||||
default_factory=lambda: [1, 3, 5, 10, 100, 1000],
|
||||
metadata={"help": "k values for evaluation. Default: [1, 3, 5, 10, 100, 1000]", "nargs": "+"}
|
||||
)
|
||||
eval_output_method: str = field(
|
||||
default="markdown",
|
||||
metadata={"help": "The output method for evaluation results. Available methods: ['json', 'markdown']. Default: markdown.", "choices": ["json", "markdown"]}
|
||||
)
|
||||
eval_output_path: str = field(
|
||||
default="./eval_results.md", metadata={"help": "The path to save evaluation results."}
|
||||
)
|
||||
eval_metrics: str = field(
|
||||
default_factory=lambda: ["ndcg_at_10", "recall_at_10"],
|
||||
metadata={"help": "The metrics to evaluate. Default: ['ndcg_at_10', 'recall_at_10']", "nargs": "+"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AbsEvalModelArgs:
|
||||
"""
|
||||
Base class for model arguments during evaluation.
|
||||
"""
|
||||
embedder_name_or_path: str = field(
|
||||
metadata={"help": "The embedder name or path.", "required": True}
|
||||
)
|
||||
embedder_model_class: Optional[str] = field(
|
||||
default=None, metadata={"help": "The embedder model class. Available classes: ['encoder-only-base', 'encoder-only-m3', 'decoder-only-base', 'decoder-only-icl']. Default: None. For the custom model, you need to specifiy the model class.", "choices": ["encoder-only-base", "encoder-only-m3", "decoder-only-base", "decoder-only-icl"]}
|
||||
)
|
||||
normalize_embeddings: bool = field(
|
||||
default=True, metadata={"help": "whether to normalize the embeddings"}
|
||||
)
|
||||
pooling_method: str = field(
|
||||
default="cls", metadata={"help": "The pooling method fot the embedder."}
|
||||
)
|
||||
use_fp16: bool = field(
|
||||
default=True, metadata={"help": "whether to use fp16 for inference"}
|
||||
)
|
||||
devices: Optional[str] = field(
|
||||
default=None, metadata={"help": "Devices to use for inference.", "nargs": "+"}
|
||||
)
|
||||
query_instruction_for_retrieval: Optional[str] = field(
|
||||
default=None, metadata={"help": "Instruction for query"}
|
||||
)
|
||||
query_instruction_format_for_retrieval: str = field(
|
||||
default="{}{}", metadata={"help": "Format for query instruction"}
|
||||
)
|
||||
examples_for_task: Optional[str] = field(
|
||||
default=None, metadata={"help": "Examples for task"}
|
||||
)
|
||||
examples_instruction_format: str = field(
|
||||
default="{}{}", metadata={"help": "Format for examples instruction"}
|
||||
)
|
||||
trust_remote_code: bool = field(
|
||||
default=False, metadata={"help": "Trust remote code"}
|
||||
)
|
||||
reranker_name_or_path: Optional[str] = field(
|
||||
default=None, metadata={"help": "The reranker name or path."}
|
||||
)
|
||||
reranker_model_class: Optional[str] = field(
|
||||
default=None, metadata={"help": "The reranker model class. Available classes: ['encoder-only-base', 'decoder-only-base', 'decoder-only-layerwise', 'decoder-only-lightweight']. Default: None. For the custom model, you need to specify the model class.", "choices": ["encoder-only-base", "decoder-only-base", "decoder-only-layerwise", "decoder-only-lightweight"]}
|
||||
)
|
||||
reranker_peft_path: Optional[str] = field(
|
||||
default=None, metadata={"help": "The reranker peft path."}
|
||||
)
|
||||
use_bf16: bool = field(
|
||||
default=False, metadata={"help": "whether to use bf16 for inference"}
|
||||
)
|
||||
query_instruction_for_rerank: Optional[str] = field(
|
||||
default=None, metadata={"help": "Instruction for query"}
|
||||
)
|
||||
query_instruction_format_for_rerank: str = field(
|
||||
default="{}{}", metadata={"help": "Format for query instruction"}
|
||||
)
|
||||
passage_instruction_for_rerank: Optional[str] = field(
|
||||
default=None, metadata={"help": "Instruction for passage"}
|
||||
)
|
||||
passage_instruction_format_for_rerank: str = field(
|
||||
default="{}{}", metadata={"help": "Format for passage instruction"}
|
||||
)
|
||||
cache_dir: str = field(
|
||||
default=None, metadata={"help": "Cache directory for models."}
|
||||
)
|
||||
# ================ for inference ===============
|
||||
embedder_batch_size: int = field(
|
||||
default=3000, metadata={"help": "Batch size for inference."}
|
||||
)
|
||||
reranker_batch_size: int = field(
|
||||
default=3000, metadata={"help": "Batch size for inference."}
|
||||
)
|
||||
embedder_query_max_length: int = field(
|
||||
default=512, metadata={"help": "Max length for query."}
|
||||
)
|
||||
embedder_passage_max_length: int = field(
|
||||
default=512, metadata={"help": "Max length for passage."}
|
||||
)
|
||||
reranker_query_max_length: Optional[int] = field(
|
||||
default=None, metadata={"help": "Max length for reranking."}
|
||||
)
|
||||
reranker_max_length: int = field(
|
||||
default=512, metadata={"help": "Max length for reranking."}
|
||||
)
|
||||
normalize: bool = field(
|
||||
default=False, metadata={"help": "whether to normalize the reranking scores"}
|
||||
)
|
||||
prompt: Optional[str] = field(
|
||||
default=None, metadata={"help": "The prompt for the reranker."}
|
||||
)
|
||||
cutoff_layers: List[int] = field(
|
||||
default=None, metadata={"help": "The output layers of layerwise/lightweight reranker."}
|
||||
)
|
||||
compress_ratio: int = field(
|
||||
default=1, metadata={"help": "The compress ratio of lightweight reranker."}
|
||||
)
|
||||
compress_layers: Optional[int] = field(
|
||||
default=None, metadata={"help": "The compress layers of lightweight reranker.", "nargs": "+"}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# replace "\\n" with "\n"
|
||||
if "\\n" in self.query_instruction_format_for_retrieval:
|
||||
self.query_instruction_format_for_retrieval = self.query_instruction_format_for_retrieval.replace("\\n", "\n")
|
||||
if "\\n" in self.examples_instruction_format:
|
||||
self.examples_instruction_format = self.examples_instruction_format.replace("\\n", "\n")
|
||||
if "\\n" in self.query_instruction_format_for_rerank:
|
||||
self.query_instruction_format_for_rerank = self.query_instruction_format_for_rerank.replace("\\n", "\n")
|
||||
if "\\n" in self.passage_instruction_format_for_rerank:
|
||||
self.passage_instruction_format_for_rerank = self.passage_instruction_format_for_rerank.replace("\\n", "\n")
|
||||
|
|
@ -0,0 +1,423 @@
|
|||
"""
|
||||
Adapted from https://github.com/AIR-Bench/AIR-Bench/blob/0.1.0/air_benchmark/evaluation_utils/data_loader.py
|
||||
"""
|
||||
import os
|
||||
import logging
|
||||
import datasets
|
||||
import subprocess
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbsEvalDataLoader(ABC):
|
||||
"""
|
||||
Base class of data loader for evaluation.
|
||||
|
||||
Args:
|
||||
eval_name (str): The experiment name of current evaluation.
|
||||
dataset_dir (str, optional): path to the datasets. Defaults to ``None``.
|
||||
cache_dir (str, optional): Path to HuggingFace cache directory. Defaults to ``None``.
|
||||
token (str, optional): HF_TOKEN to access the private datasets/models in HF. Defaults to ``None``.
|
||||
force_redownload: If True, will force redownload the dataset to cover the local dataset. Defaults to ``False``.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
eval_name: str,
|
||||
dataset_dir: Optional[str] = None,
|
||||
cache_dir: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
force_redownload: bool = False
|
||||
):
|
||||
self.eval_name = eval_name
|
||||
self.dataset_dir = dataset_dir
|
||||
if cache_dir is None:
|
||||
cache_dir = os.getenv('HF_HUB_CACHE', '~/.cache/huggingface/hub')
|
||||
self.cache_dir = os.path.join(cache_dir, eval_name)
|
||||
self.token = token
|
||||
self.force_redownload = force_redownload
|
||||
self.hf_download_mode = None if not force_redownload else "force_redownload"
|
||||
|
||||
def available_dataset_names(self) -> List[str]:
|
||||
"""
|
||||
Returns: List[str]: Available dataset names.
|
||||
"""
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
def available_splits(self, dataset_name: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Returns: List[str]: Available splits in the dataset.
|
||||
"""
|
||||
pass
|
||||
|
||||
def check_dataset_names(self, dataset_names: Union[str, List[str]]) -> List[str]:
|
||||
"""Check the validity of dataset names
|
||||
|
||||
Args:
|
||||
dataset_names (Union[str, List[str]]): a dataset name (str) or a list of dataset names (List[str])
|
||||
|
||||
Raises:
|
||||
ValueError
|
||||
|
||||
Returns:
|
||||
List[str]: List of valid dataset names.
|
||||
"""
|
||||
available_dataset_names = self.available_dataset_names()
|
||||
if isinstance(dataset_names, str):
|
||||
dataset_names = [dataset_names]
|
||||
|
||||
for dataset_name in dataset_names:
|
||||
if dataset_name not in available_dataset_names:
|
||||
raise ValueError(f"Dataset name '{dataset_name}' not found in the dataset. Available dataset names: {available_dataset_names}")
|
||||
return dataset_names
|
||||
|
||||
def check_splits(self, splits: Union[str, List[str]], dataset_name: Optional[str] = None) -> List[str]:
|
||||
"""Check whether the splits are available in the dataset.
|
||||
|
||||
Args:
|
||||
splits (Union[str, List[str]]): Splits to check.
|
||||
dataset_name (Optional[str], optional): Name of dataset to check. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
List[str]: The available splits.
|
||||
"""
|
||||
available_splits = self.available_splits(dataset_name=dataset_name)
|
||||
if isinstance(splits, str):
|
||||
splits = [splits]
|
||||
checked_splits = []
|
||||
for split in splits:
|
||||
if split not in available_splits:
|
||||
logger.warning(f"Split '{split}' not found in the dataset. Removing it from the list.")
|
||||
else:
|
||||
checked_splits.append(split)
|
||||
return checked_splits
|
||||
|
||||
def load_corpus(self, dataset_name: Optional[str] = None) -> datasets.DatasetDict:
|
||||
"""Load the corpus from the dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
|
||||
"""
|
||||
if self.dataset_dir is not None:
|
||||
if dataset_name is None:
|
||||
save_dir = self.dataset_dir
|
||||
else:
|
||||
save_dir = os.path.join(self.dataset_dir, dataset_name)
|
||||
return self._load_local_corpus(save_dir, dataset_name=dataset_name)
|
||||
else:
|
||||
return self._load_remote_corpus(dataset_name=dataset_name)
|
||||
|
||||
def load_qrels(self, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
|
||||
"""Load the qrels from the dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
|
||||
split (str, optional): The split to load relevance from. Defaults to ``'test'``.
|
||||
|
||||
Raises:
|
||||
ValueError
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: A dict of relevance of query and document.
|
||||
"""
|
||||
if self.dataset_dir is not None:
|
||||
if dataset_name is None:
|
||||
save_dir = self.dataset_dir
|
||||
else:
|
||||
checked_dataset_names = self.check_dataset_names(dataset_name)
|
||||
if len(checked_dataset_names) == 0:
|
||||
raise ValueError(f"Dataset name {dataset_name} not found in the dataset.")
|
||||
dataset_name = checked_dataset_names[0]
|
||||
|
||||
save_dir = os.path.join(self.dataset_dir, dataset_name)
|
||||
|
||||
return self._load_local_qrels(save_dir, dataset_name=dataset_name, split=split)
|
||||
else:
|
||||
return self._load_remote_qrels(dataset_name=dataset_name, split=split)
|
||||
|
||||
def load_queries(self, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
|
||||
"""Load the queries from the dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
|
||||
split (str, optional): The split to load queries from. Defaults to ``'test'``.
|
||||
|
||||
Raises:
|
||||
ValueError
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: A dict of queries with id as key, query text as value.
|
||||
"""
|
||||
if self.dataset_dir is not None:
|
||||
if dataset_name is None:
|
||||
save_dir = self.dataset_dir
|
||||
else:
|
||||
checked_dataset_names = self.check_dataset_names(dataset_name)
|
||||
if len(checked_dataset_names) == 0:
|
||||
raise ValueError(f"Dataset name {dataset_name} not found in the dataset.")
|
||||
dataset_name = checked_dataset_names[0]
|
||||
|
||||
save_dir = os.path.join(self.dataset_dir, dataset_name)
|
||||
|
||||
return self._load_local_queries(save_dir, dataset_name=dataset_name, split=split)
|
||||
else:
|
||||
return self._load_remote_queries(dataset_name=dataset_name, split=split)
|
||||
|
||||
def _load_remote_corpus(
|
||||
self,
|
||||
dataset_name: Optional[str] = None,
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Abstract method to load corpus from remote dataset, to be overrode in child class.
|
||||
|
||||
Args:
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
|
||||
save_dir (Optional[str], optional): Path to save the new downloaded corpus. Defaults to ``None``.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Loading remote corpus is not implemented.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
|
||||
"""
|
||||
raise NotImplementedError("Loading remote corpus is not implemented.")
|
||||
|
||||
def _load_remote_qrels(
|
||||
self,
|
||||
dataset_name: Optional[str] = None,
|
||||
split: str = 'test',
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Abstract method to load relevance from remote dataset, to be overrode in child class.
|
||||
|
||||
Args:
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
|
||||
split (str, optional): Split to load from the remote dataset. Defaults to ``'test'``.
|
||||
save_dir (Optional[str], optional): Path to save the new downloaded relevance. Defaults to ``None``.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Loading remote qrels is not implemented.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: A dict of relevance of query and document.
|
||||
"""
|
||||
raise NotImplementedError("Loading remote qrels is not implemented.")
|
||||
|
||||
def _load_remote_queries(
|
||||
self,
|
||||
dataset_name: Optional[str] = None,
|
||||
split: str = 'test',
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Abstract method to load queries from remote dataset, to be overrode in child class.
|
||||
|
||||
Args:
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
|
||||
split (str, optional): Split to load from the remote dataset. Defaults to ``'test'``.
|
||||
save_dir (Optional[str], optional): Path to save the new downloaded queries. Defaults to ``None``.
|
||||
|
||||
Raises:
|
||||
NotImplementedError
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: A dict of queries with id as key, query text as value.
|
||||
"""
|
||||
raise NotImplementedError("Loading remote queries is not implemented.")
|
||||
|
||||
def _load_local_corpus(self, save_dir: str, dataset_name: Optional[str] = None) -> datasets.DatasetDict:
|
||||
"""Load corpus from local dataset.
|
||||
|
||||
Args:
|
||||
save_dir (str): Path to save the loaded corpus.
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
|
||||
"""
|
||||
corpus_path = os.path.join(save_dir, 'corpus.jsonl')
|
||||
if self.force_redownload or not os.path.exists(corpus_path):
|
||||
logger.warning(f"Corpus not found in {corpus_path}. Trying to download the corpus from the remote and save it to {save_dir}.")
|
||||
return self._load_remote_corpus(dataset_name=dataset_name, save_dir=save_dir)
|
||||
else:
|
||||
corpus_data = datasets.load_dataset('json', data_files=corpus_path, cache_dir=self.cache_dir)['train']
|
||||
|
||||
corpus = {}
|
||||
for e in corpus_data:
|
||||
corpus[e['id']] = {'title': e.get('title', ""), 'text': e['text']}
|
||||
|
||||
return datasets.DatasetDict(corpus)
|
||||
|
||||
def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
|
||||
"""Load relevance from local dataset.
|
||||
|
||||
Args:
|
||||
save_dir (str): Path to save the loaded relevance.
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
|
||||
split (str, optional): Split to load from the local dataset. Defaults to ``'test'``.
|
||||
|
||||
Raises:
|
||||
ValueError
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: A dict of relevance of query and document.
|
||||
"""
|
||||
checked_split = self.check_splits(split, dataset_name=dataset_name)
|
||||
if len(checked_split) == 0:
|
||||
raise ValueError(f"Split {split} not found in the dataset.")
|
||||
split = checked_split[0]
|
||||
|
||||
qrels_path = os.path.join(save_dir, f"{split}_qrels.jsonl")
|
||||
if self.force_redownload or not os.path.exists(qrels_path):
|
||||
logger.warning(f"Qrels not found in {qrels_path}. Trying to download the qrels from the remote and save it to {save_dir}.")
|
||||
return self._load_remote_qrels(dataset_name=dataset_name, split=split, save_dir=save_dir)
|
||||
else:
|
||||
qrels_data = datasets.load_dataset('json', data_files=qrels_path, cache_dir=self.cache_dir)['train']
|
||||
|
||||
qrels = {}
|
||||
for data in qrels_data:
|
||||
qid = data['qid']
|
||||
if qid not in qrels:
|
||||
qrels[qid] = {}
|
||||
qrels[qid][data['docid']] = data['relevance']
|
||||
|
||||
return datasets.DatasetDict(qrels)
|
||||
|
||||
def _load_local_queries(self, save_dir: str, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
|
||||
"""Load queries from local dataset.
|
||||
|
||||
Args:
|
||||
save_dir (str): Path to save the loaded queries.
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
|
||||
split (str, optional): Split to load from the local dataset. Defaults to ``'test'``.
|
||||
|
||||
Raises:
|
||||
ValueError
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: A dict of queries with id as key, query text as value.
|
||||
"""
|
||||
checked_split = self.check_splits(split, dataset_name=dataset_name)
|
||||
if len(checked_split) == 0:
|
||||
raise ValueError(f"Split {split} not found in the dataset.")
|
||||
split = checked_split[0]
|
||||
|
||||
queries_path = os.path.join(save_dir, f"{split}_queries.jsonl")
|
||||
if self.force_redownload or not os.path.exists(queries_path):
|
||||
logger.warning(f"Queries not found in {queries_path}. Trying to download the queries from the remote and save it to {save_dir}.")
|
||||
return self._load_remote_queries(dataset_name=dataset_name, split=split, save_dir=save_dir)
|
||||
else:
|
||||
queries_data = datasets.load_dataset('json', data_files=queries_path, cache_dir=self.cache_dir)['train']
|
||||
|
||||
queries = {e['id']: e['text'] for e in queries_data}
|
||||
return datasets.DatasetDict(queries)
|
||||
|
||||
def _download_file(self, download_url: str, save_dir: str):
|
||||
"""Download file from provided URL.
|
||||
|
||||
Args:
|
||||
download_url (str): Source URL of the file.
|
||||
save_dir (str): Path to the directory to save the zip file.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError
|
||||
|
||||
Returns:
|
||||
str: The path of the downloaded file.
|
||||
"""
|
||||
save_path = os.path.join(save_dir, download_url.split('/')[-1])
|
||||
|
||||
if self.force_redownload or (not os.path.exists(save_path) or os.path.getsize(save_path) == 0):
|
||||
cmd = ["wget", "-O", save_path, download_url]
|
||||
else:
|
||||
cmd = ["wget", "-nc", "-O", save_path, download_url]
|
||||
|
||||
try:
|
||||
subprocess.run(cmd, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.warning(e.output)
|
||||
|
||||
if not os.path.exists(save_path) or os.path.getsize(save_path) == 0:
|
||||
raise FileNotFoundError(f"Failed to download file from {download_url} to {save_path}")
|
||||
else:
|
||||
logger.info(f"Downloaded file from {download_url} to {save_path}")
|
||||
return save_path
|
||||
|
||||
def _get_fpath_size(self, fpath: str) -> int:
|
||||
"""Get the total size of the files in provided path.
|
||||
|
||||
Args:
|
||||
fpath (str): path of files to compute the size.
|
||||
|
||||
Returns:
|
||||
int: The total size in bytes.
|
||||
"""
|
||||
if not os.path.isdir(fpath):
|
||||
return os.path.getsize(fpath)
|
||||
else:
|
||||
total_size = 0
|
||||
for dirpath, _, filenames in os.walk(fpath):
|
||||
for f in filenames:
|
||||
fp = os.path.join(dirpath, f)
|
||||
total_size += os.path.getsize(fp)
|
||||
return total_size
|
||||
|
||||
def _download_gz_file(self, download_url: str, save_dir: str):
|
||||
"""Download and unzip the gzip file from provided URL.
|
||||
|
||||
Args:
|
||||
download_url (str): Source URL of the gzip file.
|
||||
save_dir (str): Path to the directory to save the gzip file.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError
|
||||
|
||||
Returns:
|
||||
str: The path to the file after unzip.
|
||||
"""
|
||||
gz_file_path = self._download_file(download_url, save_dir)
|
||||
cmd = ["gzip", "-d", gz_file_path]
|
||||
try:
|
||||
subprocess.run(cmd, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.warning(e.output)
|
||||
|
||||
file_path = gz_file_path.replace(".gz", "")
|
||||
if not os.path.exists(file_path) or self._get_fpath_size(file_path) == 0:
|
||||
raise FileNotFoundError(f"Failed to unzip file {gz_file_path}")
|
||||
|
||||
return file_path
|
||||
|
||||
def _download_zip_file(self, download_url: str, save_dir: str):
|
||||
"""Download and unzip the zip file from provided URL.
|
||||
|
||||
Args:
|
||||
download_url (str): Source URL of the zip file.
|
||||
save_dir (str): Path to the directory to save the zip file.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError
|
||||
|
||||
Returns:
|
||||
str: The path to the file after unzip.
|
||||
"""
|
||||
zip_file_path = self._download_file(download_url, save_dir)
|
||||
file_path = zip_file_path.replace(".zip", "")
|
||||
if self.force_redownload or not os.path.exists(file_path):
|
||||
cmd = ["unzip", "-o", zip_file_path, "-d", file_path]
|
||||
else:
|
||||
cmd = ["unzip", "-n", zip_file_path, "-d", file_path]
|
||||
|
||||
try:
|
||||
subprocess.run(cmd, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.warning(e.output)
|
||||
|
||||
if not os.path.exists(file_path) or self._get_fpath_size(file_path) == 0:
|
||||
raise FileNotFoundError(f"Failed to unzip file {zip_file_path}")
|
||||
|
||||
return file_path
|
||||
|
|
@ -0,0 +1,494 @@
|
|||
"""
|
||||
Adapted from https://github.com/AIR-Bench/AIR-Bench/blob/0.1.0/air_benchmark/evaluation_utils/evaluator.py
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import pandas as pd
|
||||
from typing import Dict, Optional, List, Union
|
||||
|
||||
from .data_loader import AbsEvalDataLoader
|
||||
from .searcher import EvalRetriever, EvalReranker
|
||||
from .utils import evaluate_metrics, evaluate_mrr
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbsEvaluator:
|
||||
"""
|
||||
Base class of Evaluator.
|
||||
|
||||
Args:
|
||||
eval_name (str): The experiment name of current evaluation.
|
||||
data_loader (AbsEvalDataLoader): The data_loader to deal with data.
|
||||
overwrite (bool): If true, will overwrite the existing results.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
eval_name: str,
|
||||
data_loader: AbsEvalDataLoader,
|
||||
overwrite: bool = False,
|
||||
):
|
||||
self.eval_name = eval_name
|
||||
self.data_loader = data_loader
|
||||
self.overwrite = overwrite
|
||||
|
||||
def check_data_info(
|
||||
self,
|
||||
data_info: Dict[str, str],
|
||||
model_name: str,
|
||||
reranker_name: str,
|
||||
split: str,
|
||||
dataset_name: Optional[str] = None,
|
||||
):
|
||||
"""Check the validity of data info.
|
||||
|
||||
Args:
|
||||
data_info (Dict[str, str]): The loaded data info to be check.
|
||||
model_name (str): Name of model used.
|
||||
reranker_name (str): Name of reranker used.
|
||||
split (str): Split used in searching.
|
||||
dataset_name (Optional[str], optional): Name of dataset used. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: eval_name mismatch
|
||||
ValueError: model_name or reranker_name mismatch
|
||||
ValueError: split mismatch
|
||||
ValueError: dataset_name mismatch
|
||||
"""
|
||||
if data_info["eval_name"] != self.eval_name:
|
||||
raise ValueError(
|
||||
f'eval_name mismatch: {data_info["eval_name"]} vs {self.eval_name}'
|
||||
)
|
||||
if (
|
||||
data_info["model_name"] != model_name
|
||||
or data_info["reranker_name"] != reranker_name
|
||||
):
|
||||
raise ValueError(
|
||||
f'model_name or reranker_name mismatch: {data_info["model_name"]} vs {model_name} or {data_info["reranker_name"]} vs {reranker_name}'
|
||||
)
|
||||
if (data_info["split"] != split):
|
||||
raise ValueError(
|
||||
f'split mismatch: {data_info["split"]} vs {split}'
|
||||
)
|
||||
if dataset_name is not None and data_info["dataset_name"] != dataset_name:
|
||||
raise ValueError(
|
||||
f'dataset_name mismatch: {data_info["dataset_name"]} vs {dataset_name}'
|
||||
)
|
||||
|
||||
def get_corpus_embd_save_dir(
|
||||
self,
|
||||
retriever_name: str,
|
||||
corpus_embd_save_dir: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
If corpus_embd_save_dir is not None, then it will be used as the base directory to save the corpus embeddings. For dataset such as MKQA,
|
||||
the corpus for all languages is the same, so the subclass can override this method to save the corpus embeddings in the same directory.
|
||||
|
||||
Args:
|
||||
retriever_name (str): Name of the retriever.
|
||||
corpus_embd_save_dir (str, optional): Directory that saving the corpus embedding.
|
||||
dataset_name (str, optional):
|
||||
"""
|
||||
if corpus_embd_save_dir is not None:
|
||||
if dataset_name is not None:
|
||||
corpus_embd_save_dir = os.path.join(corpus_embd_save_dir, retriever_name, dataset_name)
|
||||
else:
|
||||
corpus_embd_save_dir = os.path.join(corpus_embd_save_dir, retriever_name)
|
||||
return corpus_embd_save_dir
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
splits: Union[str, List[str]],
|
||||
search_results_save_dir: str,
|
||||
retriever: EvalRetriever,
|
||||
reranker: Optional[EvalReranker] = None,
|
||||
corpus_embd_save_dir: Optional[str] = None,
|
||||
ignore_identical_ids: bool = False,
|
||||
k_values: List[int] = [1, 3, 5, 10, 100, 1000],
|
||||
dataset_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""This is called during the evaluation process.
|
||||
|
||||
Args:
|
||||
splits (Union[str, List[str]]): Splits of datasets.
|
||||
search_results_save_dir (str): Directory to save the search results.
|
||||
retriever (EvalRetriever): object of :class:EvalRetriever.
|
||||
reranker (Optional[EvalReranker], optional): Object of :class:EvalReranker. Defaults to :data:`None`.
|
||||
corpus_embd_save_dir (Optional[str], optional): Directory to save the embedded corpus. Defaults to :data:`None`.
|
||||
ignore_identical_ids (bool, optional): If True, will ignore identical ids in search results. Defaults to :data:`False`.
|
||||
k_values (List[int], optional): Cutoffs. Defaults to :data:`[1, 3, 5, 10, 100, 1000]`.
|
||||
dataset_name (Optional[str], optional): Name of the datasets. Defaults to :data:`None`.
|
||||
"""
|
||||
# Check Splits
|
||||
checked_splits = self.data_loader.check_splits(splits, dataset_name=dataset_name)
|
||||
if len(checked_splits) == 0:
|
||||
logger.warning(f"{splits} not found in the dataset. Skipping evaluation.")
|
||||
return
|
||||
splits = checked_splits
|
||||
|
||||
if dataset_name is not None:
|
||||
save_name = f"{dataset_name}-" + "{split}.json"
|
||||
else:
|
||||
save_name = "{split}.json"
|
||||
|
||||
corpus_embd_save_dir = self.get_corpus_embd_save_dir(
|
||||
retriever_name=str(retriever),
|
||||
corpus_embd_save_dir=corpus_embd_save_dir,
|
||||
dataset_name=dataset_name
|
||||
)
|
||||
|
||||
# Retrieval Stage
|
||||
no_reranker_search_results_save_dir = os.path.join(
|
||||
search_results_save_dir, str(retriever), "NoReranker"
|
||||
)
|
||||
os.makedirs(no_reranker_search_results_save_dir, exist_ok=True)
|
||||
|
||||
flag = False
|
||||
for split in splits:
|
||||
split_no_reranker_search_results_save_path = os.path.join(
|
||||
no_reranker_search_results_save_dir, save_name.format(split=split)
|
||||
)
|
||||
if not os.path.exists(split_no_reranker_search_results_save_path) or self.overwrite:
|
||||
flag = True
|
||||
break
|
||||
|
||||
no_reranker_search_results_dict = {}
|
||||
if flag:
|
||||
corpus = self.data_loader.load_corpus(dataset_name=dataset_name)
|
||||
|
||||
queries_dict = {
|
||||
split: self.data_loader.load_queries(dataset_name=dataset_name, split=split)
|
||||
for split in splits
|
||||
}
|
||||
|
||||
all_queries = {}
|
||||
for _, split_queries in queries_dict.items():
|
||||
all_queries.update(split_queries)
|
||||
|
||||
all_no_reranker_search_results = retriever(
|
||||
corpus=corpus,
|
||||
queries=all_queries,
|
||||
corpus_embd_save_dir=corpus_embd_save_dir,
|
||||
ignore_identical_ids=ignore_identical_ids,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
for split in splits:
|
||||
split_queries = queries_dict[split]
|
||||
no_reranker_search_results_dict[split] = {
|
||||
qid: all_no_reranker_search_results[qid] for qid in split_queries
|
||||
}
|
||||
split_no_reranker_search_results_save_path = os.path.join(
|
||||
no_reranker_search_results_save_dir, save_name.format(split=split)
|
||||
)
|
||||
|
||||
self.save_search_results(
|
||||
eval_name=self.eval_name,
|
||||
model_name=str(retriever),
|
||||
reranker_name="NoReranker",
|
||||
search_results=no_reranker_search_results_dict[split],
|
||||
output_path=split_no_reranker_search_results_save_path,
|
||||
split=split,
|
||||
dataset_name=dataset_name,
|
||||
)
|
||||
else:
|
||||
for split in splits:
|
||||
split_no_reranker_search_results_save_path = os.path.join(
|
||||
no_reranker_search_results_save_dir, save_name.format(split=split)
|
||||
)
|
||||
data_info, search_results = self.load_search_results(split_no_reranker_search_results_save_path)
|
||||
|
||||
self.check_data_info(
|
||||
data_info=data_info,
|
||||
model_name=str(retriever),
|
||||
reranker_name="NoReranker",
|
||||
split=split,
|
||||
dataset_name=dataset_name,
|
||||
)
|
||||
no_reranker_search_results_dict[split] = search_results
|
||||
retriever.stop_multi_process_pool()
|
||||
eval_results_save_path = os.path.join(no_reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
|
||||
if not os.path.exists(eval_results_save_path) or self.overwrite or flag:
|
||||
retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir, k_values=k_values)
|
||||
self.output_eval_results_to_json(retriever_eval_results, eval_results_save_path)
|
||||
|
||||
# Reranking Stage
|
||||
if reranker is not None:
|
||||
reranker_search_results_save_dir = os.path.join(
|
||||
search_results_save_dir, str(retriever), str(reranker)
|
||||
)
|
||||
os.makedirs(reranker_search_results_save_dir, exist_ok=True)
|
||||
|
||||
corpus = self.data_loader.load_corpus(dataset_name=dataset_name)
|
||||
|
||||
queries_dict = {
|
||||
split: self.data_loader.load_queries(dataset_name=dataset_name, split=split)
|
||||
for split in splits
|
||||
}
|
||||
|
||||
flag = False
|
||||
for split in splits:
|
||||
rerank_search_results_save_path = os.path.join(
|
||||
reranker_search_results_save_dir, save_name.format(split=split)
|
||||
)
|
||||
|
||||
if os.path.exists(rerank_search_results_save_path) and not self.overwrite:
|
||||
continue
|
||||
|
||||
flag = True
|
||||
rerank_search_results = reranker(
|
||||
corpus=corpus,
|
||||
queries=queries_dict[split],
|
||||
search_results=no_reranker_search_results_dict[split],
|
||||
ignore_identical_ids=ignore_identical_ids,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.save_search_results(
|
||||
eval_name=self.eval_name,
|
||||
model_name=str(retriever),
|
||||
reranker_name=str(reranker),
|
||||
search_results=rerank_search_results,
|
||||
output_path=rerank_search_results_save_path,
|
||||
split=split,
|
||||
dataset_name=dataset_name,
|
||||
)
|
||||
reranker.stop_multi_process_pool()
|
||||
eval_results_save_path = os.path.join(reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
|
||||
if not os.path.exists(eval_results_save_path) or self.overwrite or flag:
|
||||
reranker_eval_results = self.evaluate_results(reranker_search_results_save_dir, k_values=k_values)
|
||||
self.output_eval_results_to_json(reranker_eval_results, eval_results_save_path)
|
||||
|
||||
@staticmethod
|
||||
def save_search_results(
|
||||
eval_name: str,
|
||||
model_name: str,
|
||||
reranker_name: str,
|
||||
search_results: Dict[str, Dict[str, float]],
|
||||
output_path: str,
|
||||
split: str,
|
||||
dataset_name: Optional[str] = None,
|
||||
):
|
||||
"""Save the metadata and search results into a file.
|
||||
|
||||
Args:
|
||||
eval_name (str): The experiment name of current evaluation.
|
||||
model_name (str): Name of model used.
|
||||
reranker_name (str): Name of reranker used.
|
||||
search_results (Dict[str, Dict[str, float]]): Dictionary of search results.
|
||||
output_path (str): Output path to write the results.
|
||||
split (str): Split used in searching.
|
||||
dataset_name (Optional[str], optional): Name of dataset used. Defaults to :data:`None`.
|
||||
"""
|
||||
data = {
|
||||
"eval_name": eval_name,
|
||||
"model_name": model_name,
|
||||
"reranker_name": reranker_name,
|
||||
"split": split,
|
||||
"dataset_name": dataset_name,
|
||||
"search_results": search_results,
|
||||
}
|
||||
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=4)
|
||||
|
||||
@staticmethod
|
||||
def load_search_results(input_path: str):
|
||||
"""Load search results from path.
|
||||
|
||||
Args:
|
||||
input_path (str): Path to load from.
|
||||
|
||||
Returns:
|
||||
dict, dict: data info that contains metadata and search results.
|
||||
"""
|
||||
with open(input_path, "r", encoding="utf-8") as f:
|
||||
data_info = json.load(f)
|
||||
|
||||
search_results = data_info.pop("search_results")
|
||||
return data_info, search_results
|
||||
|
||||
@staticmethod
|
||||
def compute_metrics(
|
||||
qrels: Dict[str, Dict[str, int]],
|
||||
search_results: Dict[str, Dict[str, float]],
|
||||
k_values: List[int],
|
||||
):
|
||||
"""Evaluate the model with metrics.
|
||||
|
||||
Args:
|
||||
qrels (Dict[str, Dict[str, int]]): Ground truth relevance of queries and documents.
|
||||
search_results (Dict[str, Dict[str, float]]): Dictionary of search results
|
||||
k_values (List[int]): Cutoffs.
|
||||
|
||||
Returns:
|
||||
dict: The results of the metrics.
|
||||
"""
|
||||
ndcg, _map, recall, precision = evaluate_metrics(
|
||||
qrels=qrels,
|
||||
results=search_results,
|
||||
k_values=k_values,
|
||||
)
|
||||
mrr = evaluate_mrr(
|
||||
qrels=qrels,
|
||||
results=search_results,
|
||||
k_values=k_values,
|
||||
)
|
||||
scores = {
|
||||
**{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
|
||||
**{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
|
||||
**{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
|
||||
**{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
|
||||
**{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr.items()},
|
||||
}
|
||||
return scores
|
||||
|
||||
def evaluate_results(
|
||||
self,
|
||||
search_results_save_dir: str,
|
||||
k_values: List[int] = [1, 3, 5, 10, 100, 1000]
|
||||
):
|
||||
"""Compute metrics according to the results in the directory.
|
||||
|
||||
Args:
|
||||
search_results_save_dir (str): Path to the search results.
|
||||
k_values (List[int], optional): Cutoffs. Defaults to :data:`[1, 3, 5, 10, 100, 1000]`.
|
||||
|
||||
Returns:
|
||||
dict: Evaluation results.
|
||||
"""
|
||||
eval_results_dict = {}
|
||||
|
||||
for file in os.listdir(search_results_save_dir):
|
||||
if not file.endswith('.json'):
|
||||
continue
|
||||
|
||||
file_path = os.path.join(search_results_save_dir, file)
|
||||
data_info, search_results = self.load_search_results(file_path)
|
||||
|
||||
_eval_name = data_info['eval_name']
|
||||
assert _eval_name == self.eval_name, f'Mismatch eval_name: {_eval_name} vs {self.eval_name} in {file_path}'
|
||||
|
||||
split = data_info['split']
|
||||
dataset_name = data_info.get('dataset_name', None)
|
||||
qrels = self.data_loader.load_qrels(dataset_name=dataset_name, split=split)
|
||||
|
||||
eval_results = self.compute_metrics(
|
||||
qrels=qrels,
|
||||
search_results=search_results,
|
||||
k_values=k_values
|
||||
)
|
||||
|
||||
if dataset_name is not None:
|
||||
key = f"{dataset_name}-{split}"
|
||||
else:
|
||||
key = split
|
||||
eval_results_dict[key] = eval_results
|
||||
|
||||
return eval_results_dict
|
||||
|
||||
@staticmethod
|
||||
def output_eval_results_to_json(eval_results_dict: dict, output_path: str):
|
||||
"""Write the evaluation results into a json file.
|
||||
|
||||
Args:
|
||||
eval_results_dict (dict): Dictionary of the evaluation results.
|
||||
output_path (str): Output path to write the json file.
|
||||
"""
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(eval_results_dict, f, indent=4)
|
||||
logger.info(f"Results saved to {output_path}")
|
||||
|
||||
@staticmethod
|
||||
def get_results_df(metric: str, eval_results_dict: dict):
|
||||
"""Get the results from dictionary to a DataFrame.
|
||||
|
||||
Args:
|
||||
metric (str): Selected metric.
|
||||
eval_results_dict (dict): Dictionary of the evaluation results.
|
||||
|
||||
Returns:
|
||||
DataFrame: DataFrame of the results.
|
||||
"""
|
||||
results_dict = {}
|
||||
|
||||
for model_name, model_results in eval_results_dict.items():
|
||||
results_dict[model_name] = {}
|
||||
for reranker_name, reranker_results in model_results.items():
|
||||
results_dict[model_name][reranker_name] = {}
|
||||
for split, split_results in reranker_results.items():
|
||||
if metric in split_results:
|
||||
results_dict[model_name][reranker_name][split] = split_results[metric]
|
||||
else:
|
||||
results_dict[model_name][reranker_name][split] = None
|
||||
|
||||
model_reranker_pairs = set()
|
||||
all_splits = set()
|
||||
for model_name, model_results in results_dict.items():
|
||||
for reranker_name, reranker_results in model_results.items():
|
||||
model_reranker_pairs.add((model_name, reranker_name))
|
||||
all_splits.update(reranker_results.keys())
|
||||
|
||||
index = [(model, reranker) for model, reranker in model_reranker_pairs]
|
||||
multi_index = pd.MultiIndex.from_tuples(index, names=['Model', 'Reranker'])
|
||||
|
||||
all_splits = sorted(list(all_splits))
|
||||
overall_columns = ['average'] + all_splits
|
||||
overall_df = pd.DataFrame(index=multi_index, columns=overall_columns)
|
||||
|
||||
for model, reranker in model_reranker_pairs:
|
||||
for split in all_splits:
|
||||
if model in results_dict and reranker in results_dict[model] and split in results_dict[model][reranker]:
|
||||
overall_df.loc[(model, reranker), split] = results_dict[model][reranker][split]
|
||||
else:
|
||||
overall_df.loc[(model, reranker), split] = None
|
||||
if overall_df.loc[(model, reranker), all_splits].isnull().any():
|
||||
overall_df.loc[(model, reranker), 'average'] = None
|
||||
else:
|
||||
overall_df.loc[(model, reranker), 'average'] = overall_df.loc[(model, reranker), all_splits].mean()
|
||||
|
||||
return overall_df
|
||||
|
||||
@staticmethod
|
||||
def output_eval_results_to_markdown(eval_results_dict: dict, output_path: str, metrics: Union[List[str], str]):
|
||||
"""Write the evaluation results to a markdown file.
|
||||
|
||||
Args:
|
||||
eval_results_dict (dict): Dictionary that contains evaluation results.
|
||||
output_path (str): Path to write the output to.
|
||||
metrics (Union[List[str], str]): The metrics that will be written in the markdown file.
|
||||
"""
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
|
||||
if isinstance(metrics, str):
|
||||
metrics = [metrics]
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
for metric in metrics:
|
||||
f.write(f"## {metric}\n\n")
|
||||
results_df = AbsEvaluator.get_results_df(metric, eval_results_dict)
|
||||
max_index = dict(results_df.idxmax(axis=0))
|
||||
splits = results_df.columns
|
||||
f.write(f"| Model | Reranker | {' | '.join(splits)} |\n")
|
||||
f.write(f"| :---- | :---- | {' | '.join([':---:' for _ in splits])} |\n")
|
||||
for i, row in results_df.iterrows():
|
||||
line = f"| {i[0]} | {i[1]} | "
|
||||
for s, v in row.items():
|
||||
if v is None:
|
||||
line += "- | "
|
||||
else:
|
||||
if i != max_index[s]:
|
||||
line += f'{v*100:.3f} | '
|
||||
else:
|
||||
line += f'**{v*100:.3f}** | '
|
||||
f.write(line + "\n")
|
||||
f.write("\n")
|
||||
logger.info(f"Results saved to {output_path}")
|
||||
|
|
@ -0,0 +1,225 @@
|
|||
import os
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Union, Tuple
|
||||
|
||||
from FlagEmbedding import FlagAutoModel, FlagAutoReranker, AbsEmbedder, AbsReranker
|
||||
|
||||
from .arguments import AbsEvalArgs, AbsEvalModelArgs
|
||||
from .evaluator import AbsEvaluator
|
||||
from .searcher import EvalDenseRetriever, EvalReranker
|
||||
from .data_loader import AbsEvalDataLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbsEvalRunner:
|
||||
"""
|
||||
Abstract class of evaluation runner.
|
||||
|
||||
Args:
|
||||
eval_args (AbsEvalArgs): :class:AbsEvalArgs object with the evaluation arguments.
|
||||
model_args (AbsEvalModelArgs): :class:AbsEvalModelArgs object with the model arguments.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
eval_args: AbsEvalArgs,
|
||||
model_args: AbsEvalModelArgs,
|
||||
):
|
||||
self.eval_args = eval_args
|
||||
self.model_args = model_args
|
||||
|
||||
self.retriever, self.reranker = self.load_retriever_and_reranker()
|
||||
self.data_loader = self.load_data_loader()
|
||||
self.evaluator = self.load_evaluator()
|
||||
|
||||
@staticmethod
|
||||
def get_models(model_args: AbsEvalModelArgs) -> Tuple[AbsEmbedder, Union[AbsReranker, None]]:
|
||||
"""Get the embedding and reranker model
|
||||
|
||||
Args:
|
||||
model_args (AbsEvalModelArgs): :class:AbsEvalModelArgs object with the model arguments.
|
||||
|
||||
Returns:
|
||||
Tuple[AbsEmbedder, Union[AbsReranker, None]]: A :class:AbsEmbedder object of embedding model, and
|
||||
:class:AbsReranker object of reranker model if path provided.
|
||||
"""
|
||||
embedder = FlagAutoModel.from_finetuned(
|
||||
model_name_or_path=model_args.embedder_name_or_path,
|
||||
model_class=model_args.embedder_model_class,
|
||||
normalize_embeddings=model_args.normalize_embeddings,
|
||||
pooling_method=model_args.pooling_method,
|
||||
use_fp16=model_args.use_fp16,
|
||||
query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
|
||||
query_instruction_format=model_args.query_instruction_format_for_retrieval,
|
||||
devices=model_args.devices,
|
||||
examples_for_task=model_args.examples_for_task,
|
||||
examples_instruction_format=model_args.examples_instruction_format,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
cache_dir=model_args.cache_dir,
|
||||
batch_size=model_args.embedder_batch_size,
|
||||
query_max_length=model_args.embedder_query_max_length,
|
||||
passage_max_length=model_args.embedder_passage_max_length,
|
||||
)
|
||||
embedder.model.config._name_or_path = model_args.embedder_name_or_path
|
||||
reranker = None
|
||||
if model_args.reranker_name_or_path is not None:
|
||||
reranker = FlagAutoReranker.from_finetuned(
|
||||
model_name_or_path=model_args.reranker_name_or_path,
|
||||
model_class=model_args.reranker_model_class,
|
||||
peft_path=model_args.reranker_peft_path,
|
||||
use_fp16=model_args.use_fp16,
|
||||
use_bf16=model_args.use_bf16,
|
||||
query_instruction_for_rerank=model_args.query_instruction_for_rerank,
|
||||
query_instruction_format=model_args.query_instruction_format_for_rerank,
|
||||
passage_instruction_for_rerank=model_args.passage_instruction_for_rerank,
|
||||
passage_instruction_format=model_args.passage_instruction_format_for_rerank,
|
||||
cache_dir=model_args.cache_dir,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
devices=model_args.devices,
|
||||
normalize=model_args.normalize,
|
||||
prompt=model_args.prompt,
|
||||
cutoff_layers=model_args.cutoff_layers,
|
||||
compress_layers=model_args.compress_layers,
|
||||
compress_ratio=model_args.compress_ratio,
|
||||
batch_size=model_args.reranker_batch_size,
|
||||
query_max_length=model_args.reranker_query_max_length,
|
||||
max_length=model_args.reranker_max_length,
|
||||
)
|
||||
reranker.model.config._name_or_path = model_args.reranker_name_or_path
|
||||
return embedder, reranker
|
||||
|
||||
def load_retriever_and_reranker(self) -> Tuple[EvalDenseRetriever, Union[EvalReranker, None]]:
|
||||
"""Load retriever and reranker for evaluation
|
||||
|
||||
Returns:
|
||||
Tuple[EvalDenseRetriever, Union[EvalReranker, None]]: A :class:EvalDenseRetriever object for retrieval, and a
|
||||
:class:EvalReranker object if reranker provided.
|
||||
"""
|
||||
embedder, reranker = self.get_models(self.model_args)
|
||||
retriever = EvalDenseRetriever(
|
||||
embedder,
|
||||
search_top_k=self.eval_args.search_top_k,
|
||||
overwrite=self.eval_args.overwrite
|
||||
)
|
||||
if reranker is not None:
|
||||
reranker = EvalReranker(reranker, rerank_top_k=self.eval_args.rerank_top_k)
|
||||
return retriever, reranker
|
||||
|
||||
def load_data_loader(self) -> AbsEvalDataLoader:
|
||||
"""Load the data loader
|
||||
|
||||
Returns:
|
||||
AbsEvalDataLoader: Data loader object for that specific task.
|
||||
"""
|
||||
data_loader = AbsEvalDataLoader(
|
||||
eval_name=self.eval_args.eval_name,
|
||||
dataset_dir=self.eval_args.dataset_dir,
|
||||
cache_dir=self.eval_args.cache_path,
|
||||
token=self.eval_args.token,
|
||||
force_redownload=self.eval_args.force_redownload,
|
||||
)
|
||||
return data_loader
|
||||
|
||||
def load_evaluator(self) -> AbsEvaluator:
|
||||
"""Load the evaluator for evaluation
|
||||
|
||||
Returns:
|
||||
AbsEvaluator: the evaluator to run the evaluation.
|
||||
"""
|
||||
evaluator = AbsEvaluator(
|
||||
eval_name=self.eval_args.eval_name,
|
||||
data_loader=self.data_loader,
|
||||
overwrite=self.eval_args.overwrite,
|
||||
)
|
||||
return evaluator
|
||||
|
||||
@staticmethod
|
||||
def evaluate_metrics(
|
||||
search_results_save_dir: str,
|
||||
output_method: str = "markdown",
|
||||
output_path: str = "./eval_dev_results.md",
|
||||
metrics: Union[str, List[str]] = ["ndcg_at_10", "recall_at_10"]
|
||||
):
|
||||
"""Evaluate the provided metrics and write the results.
|
||||
|
||||
Args:
|
||||
search_results_save_dir (str): Path to save the search results.
|
||||
output_method (str, optional): Output results to `json` or `markdown`. Defaults to :data:`"markdown"`.
|
||||
output_path (str, optional): Path to write the output. Defaults to :data:`"./eval_dev_results.md"`.
|
||||
metrics (Union[str, List[str]], optional): metrics to use. Defaults to :data:`["ndcg_at_10", "recall_at_10"]`.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: Eval results not found
|
||||
ValueError: Invalid output method
|
||||
"""
|
||||
eval_results_dict = {}
|
||||
for model_name in sorted(os.listdir(search_results_save_dir)):
|
||||
model_search_results_save_dir = os.path.join(search_results_save_dir, model_name)
|
||||
if not os.path.isdir(model_search_results_save_dir):
|
||||
continue
|
||||
for reranker_name in sorted(os.listdir(model_search_results_save_dir)):
|
||||
reranker_search_results_save_dir = os.path.join(model_search_results_save_dir, reranker_name)
|
||||
if not os.path.isdir(reranker_search_results_save_dir):
|
||||
continue
|
||||
eval_results_path = os.path.join(reranker_search_results_save_dir, 'EVAL', "eval_results.json")
|
||||
if os.path.exists(eval_results_path):
|
||||
eval_results = json.load(open(eval_results_path, encoding='utf-8'))
|
||||
else:
|
||||
raise FileNotFoundError(f"Eval results not found: {eval_results_path}")
|
||||
|
||||
if model_name not in eval_results_dict:
|
||||
eval_results_dict[model_name] = {}
|
||||
eval_results_dict[model_name][reranker_name] = eval_results
|
||||
|
||||
if output_method == "json":
|
||||
AbsEvaluator.output_eval_results_to_json(eval_results_dict, output_path)
|
||||
elif output_method == "markdown":
|
||||
AbsEvaluator.output_eval_results_to_markdown(eval_results_dict, output_path, metrics)
|
||||
else:
|
||||
raise ValueError(f"Invalid output method: {output_method}. Available methods: ['json', 'markdown']")
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
Run the whole evaluation.
|
||||
"""
|
||||
if self.eval_args.dataset_names is None:
|
||||
dataset_names = self.data_loader.available_dataset_names()
|
||||
else:
|
||||
dataset_names = self.data_loader.check_dataset_names(self.eval_args.dataset_names)
|
||||
|
||||
if len(dataset_names) == 0:
|
||||
logger.info(f"Running {self.eval_args.eval_name} evaluation on the default dataset.")
|
||||
self.evaluator(
|
||||
splits=self.eval_args.splits,
|
||||
search_results_save_dir=self.eval_args.output_dir,
|
||||
retriever=self.retriever,
|
||||
reranker=self.reranker,
|
||||
corpus_embd_save_dir=self.eval_args.corpus_embd_save_dir,
|
||||
ignore_identical_ids=self.eval_args.ignore_identical_ids,
|
||||
k_values=self.eval_args.k_values
|
||||
)
|
||||
logger.info(f"{self.eval_args.eval_name} evaluation completed.")
|
||||
else:
|
||||
logger.info(f"Running {self.eval_args.eval_name} evaluation on the following dataset names: {dataset_names}")
|
||||
for dataset_name in dataset_names:
|
||||
logger.info(f"Running {self.eval_args.eval_name} evaluation on: {dataset_name}")
|
||||
self.evaluator(
|
||||
splits=self.eval_args.splits,
|
||||
search_results_save_dir=self.eval_args.output_dir,
|
||||
retriever=self.retriever,
|
||||
reranker=self.reranker,
|
||||
corpus_embd_save_dir=self.eval_args.corpus_embd_save_dir,
|
||||
ignore_identical_ids=self.eval_args.ignore_identical_ids,
|
||||
k_values=self.eval_args.k_values,
|
||||
dataset_name=dataset_name,
|
||||
)
|
||||
logger.info(f"{self.eval_args.eval_name} evaluation on {dataset_names} completed.")
|
||||
|
||||
logger.info("Start computing metrics.")
|
||||
self.evaluate_metrics(
|
||||
search_results_save_dir=self.eval_args.output_dir,
|
||||
output_method=self.eval_args.eval_output_method,
|
||||
output_path=self.eval_args.eval_output_path,
|
||||
metrics=self.eval_args.eval_metrics
|
||||
)
|
||||
|
|
@ -0,0 +1,248 @@
|
|||
"""
|
||||
Adapted from https://github.com/AIR-Bench/AIR-Bench/blob/0.1.0/air_benchmark/evaluation_utils/searcher.py
|
||||
"""
|
||||
import os
|
||||
import logging
|
||||
import gc
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Any, Dict, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from FlagEmbedding.abc.inference import AbsEmbedder, AbsReranker
|
||||
from FlagEmbedding.abc.evaluation.utils import index, search
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EvalRetriever(ABC):
|
||||
"""
|
||||
This is the base class for retriever.
|
||||
"""
|
||||
def __init__(self, embedder: AbsEmbedder, search_top_k: int = 1000, overwrite: bool = False):
|
||||
self.embedder = embedder
|
||||
self.search_top_k = search_top_k
|
||||
self.overwrite = overwrite
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
Returns: str: Name of the retriever.
|
||||
"""
|
||||
return os.path.basename(self.embedder.model.config._name_or_path)
|
||||
|
||||
def stop_multi_process_pool(self):
|
||||
self.embedder.stop_self_pool()
|
||||
# if self.embedder.pool is not None:
|
||||
# self.embedder.stop_multi_process_pool(self.embedder.pool)
|
||||
# self.embedder.pool = None
|
||||
# self.embedder.model.to('cpu')
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
@abstractmethod
|
||||
def __call__(
|
||||
self,
|
||||
corpus: Dict[str, Dict[str, Any]],
|
||||
queries: Dict[str, str],
|
||||
corpus_embd_save_dir: Optional[str] = None,
|
||||
ignore_identical_ids: bool = False,
|
||||
**kwargs,
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
"""
|
||||
Abstract method to be overrode. This is called during the retrieval process.
|
||||
|
||||
Parameters:
|
||||
corpus: Dict[str, Dict[str, Any]]: Corpus of documents.
|
||||
Structure: {<docid>: {"text": <text>}}.
|
||||
Example: {"doc-0": {"text": "This is a document."}}
|
||||
queries: Dict[str, str]: Queries to search for.
|
||||
Structure: {<qid>: <query>}.
|
||||
Example: {"q-0": "This is a query."}
|
||||
corpus_embd_save_dir (Optional[str]): Defaults to :data:`None`.
|
||||
ignore_identical_ids (bool): Defaults to :data:`False`.
|
||||
**kwargs: Any: Additional arguments.
|
||||
|
||||
Returns: Dict[str, Dict[str, float]]: Top-k search results for each query. k is specified by search_top_k.
|
||||
Structure: {qid: {docid: score}}. The higher is the score, the more relevant is the document.
|
||||
Example: {"q-0": {"doc-0": 0.9}}
|
||||
"""
|
||||
|
||||
|
||||
class EvalDenseRetriever(EvalRetriever):
|
||||
"""
|
||||
Child class of :class:EvalRetriever for dense retrieval.
|
||||
"""
|
||||
def __call__(
|
||||
self,
|
||||
corpus: Dict[str, Dict[str, Any]],
|
||||
queries: Dict[str, str],
|
||||
corpus_embd_save_dir: Optional[str] = None,
|
||||
ignore_identical_ids: bool = False,
|
||||
**kwargs,
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
"""
|
||||
This is called during the retrieval process.
|
||||
|
||||
Parameters:
|
||||
corpus: Dict[str, Dict[str, Any]]: Corpus of documents.
|
||||
Structure: {<docid>: {"text": <text>}}.
|
||||
Example: {"doc-0": {"text": "This is a document."}}
|
||||
queries: Dict[str, str]: Queries to search for.
|
||||
Structure: {<qid>: <query>}.
|
||||
Example: {"q-0": "This is a query."}
|
||||
corpus_embd_save_dir (Optional[str]): Defaults to :data:`None`.
|
||||
ignore_identical_ids (bool): Defaults to :data:`False`.
|
||||
**kwargs: Any: Additional arguments.
|
||||
|
||||
Returns: Dict[str, Dict[str, float]]: Top-k search results for each query. k is specified by search_top_k.
|
||||
Structure: {qid: {docid: score}}. The higher is the score, the more relevant is the document.
|
||||
Example: {"q-0": {"doc-0": 0.9}}
|
||||
"""
|
||||
if ignore_identical_ids:
|
||||
logger.warning("ignore_identical_ids is set to True. This means that the search results will not contain identical ids. Note: Dataset such as MIRACL should NOT set this to True.")
|
||||
|
||||
# dense embedding models do not require language as input: AIRBench evaluation
|
||||
kwargs.pop("language", None)
|
||||
|
||||
corpus_ids = []
|
||||
corpus_texts = []
|
||||
for docid, doc in corpus.items():
|
||||
corpus_ids.append(docid)
|
||||
corpus_texts.append(
|
||||
doc["text"] if "title" not in doc
|
||||
else f"{doc['title']} {doc['text']}".strip()
|
||||
)
|
||||
queries_ids = []
|
||||
queries_texts = []
|
||||
for qid, query in queries.items():
|
||||
queries_ids.append(qid)
|
||||
queries_texts.append(query)
|
||||
|
||||
if corpus_embd_save_dir is not None:
|
||||
if os.path.exists(os.path.join(corpus_embd_save_dir, "doc.npy")) and not self.overwrite:
|
||||
corpus_emb = np.load(os.path.join(corpus_embd_save_dir, "doc.npy"))
|
||||
else:
|
||||
corpus_emb = self.embedder.encode_corpus(corpus_texts, **kwargs)
|
||||
else:
|
||||
corpus_emb = self.embedder.encode_corpus(corpus_texts, **kwargs)
|
||||
|
||||
queries_emb = self.embedder.encode_queries(queries_texts, **kwargs)
|
||||
|
||||
# check if the embeddings are in dictionary format: M3Embedder
|
||||
if isinstance(corpus_emb, dict):
|
||||
corpus_emb = corpus_emb["dense_vecs"]
|
||||
if isinstance(queries_emb, dict):
|
||||
queries_emb = queries_emb["dense_vecs"]
|
||||
|
||||
if corpus_embd_save_dir is not None and \
|
||||
(not os.path.exists(os.path.join(corpus_embd_save_dir, "doc.npy")) or self.overwrite):
|
||||
os.makedirs(corpus_embd_save_dir, exist_ok=True)
|
||||
np.save(os.path.join(corpus_embd_save_dir, "doc.npy"), corpus_emb)
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
faiss_index = index(corpus_embeddings=corpus_emb)
|
||||
all_scores, all_indices = search(query_embeddings=queries_emb, faiss_index=faiss_index, k=self.search_top_k)
|
||||
|
||||
results = {}
|
||||
for idx, (scores, indices) in enumerate(zip(all_scores, all_indices)):
|
||||
results[queries_ids[idx]] = {}
|
||||
for score, indice in zip(scores, indices):
|
||||
if indice != -1:
|
||||
if ignore_identical_ids and corpus_ids[indice] == queries_ids[idx]:
|
||||
continue
|
||||
results[queries_ids[idx]][corpus_ids[indice]] = float(score)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class EvalReranker:
|
||||
"""
|
||||
Class for reranker during evaluation.
|
||||
"""
|
||||
def __init__(self, reranker: AbsReranker, rerank_top_k: int = 100):
|
||||
self.reranker = reranker
|
||||
self.rerank_top_k = rerank_top_k
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
Returns: str: Name of the reranker.
|
||||
"""
|
||||
return os.path.basename(self.reranker.model.config._name_or_path)
|
||||
|
||||
def stop_multi_process_pool(self):
|
||||
self.reranker.stop_self_pool()
|
||||
# if self.reranker.pool is not None:
|
||||
# self.reranker.stop_multi_process_pool(self.reranker.pool)
|
||||
# self.reranker.pool = None
|
||||
# self.reranker.model.to('cpu')
|
||||
# gc.collect()
|
||||
# torch.cuda.empty_cache()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
corpus: Dict[str, Dict[str, Any]],
|
||||
queries: Dict[str, str],
|
||||
search_results: Dict[str, Dict[str, float]],
|
||||
ignore_identical_ids: bool = False,
|
||||
**kwargs,
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
"""
|
||||
This is called during the reranking process.
|
||||
|
||||
Parameters:
|
||||
corpus: Dict[str, Dict[str, Any]]: Corpus of documents.
|
||||
Structure: {<docid>: {"text": <text>}}.
|
||||
Example: {"doc-0": {"text": "This is a document."}}
|
||||
queries: Dict[str, str]: Queries to search for.
|
||||
Structure: {<qid>: <query>}.
|
||||
Example: {"q-0": "This is a query."}
|
||||
search_results: Dict[str, Dict[str, float]]: Search results for each query.
|
||||
Structure: {qid: {docid: score}}. The higher is the score, the more relevant is the document.
|
||||
Example: {"q-0": {"doc-0": 0.9}}
|
||||
**kwargs: Any: Additional arguments.
|
||||
|
||||
Returns: Dict[str, Dict[str, float]]: Reranked search results for each query. k is specified by rerank_top_k.
|
||||
Structure: {qid: {docid: score}}. The higher is the score, the more relevant is the document.
|
||||
Example: {"q-0": {"doc-0": 0.9}}
|
||||
"""
|
||||
# truncate search results to top_k
|
||||
for qid in search_results:
|
||||
search_results[qid] = dict(
|
||||
sorted(search_results[qid].items(), key=lambda x: x[1], reverse=True)[
|
||||
:self.rerank_top_k
|
||||
]
|
||||
)
|
||||
# generate sentence pairs
|
||||
sentence_pairs = []
|
||||
pairs = []
|
||||
for qid in search_results:
|
||||
for docid in search_results[qid]:
|
||||
if ignore_identical_ids and qid == docid:
|
||||
continue
|
||||
sentence_pairs.append(
|
||||
{
|
||||
"qid": qid,
|
||||
"docid": docid,
|
||||
"query": queries[qid],
|
||||
"doc": corpus[docid]["text"] if "title" not in corpus[docid]
|
||||
else f"{corpus[docid]['title']} {corpus[docid]['text']}".strip(),
|
||||
}
|
||||
)
|
||||
pairs.append(
|
||||
(
|
||||
queries[qid],
|
||||
corpus[docid]["text"] if "title" not in corpus[docid]
|
||||
else f"{corpus[docid]['title']} {corpus[docid]['text']}".strip()
|
||||
)
|
||||
)
|
||||
# compute scores
|
||||
scores = self.reranker.compute_score(pairs)
|
||||
for i, score in enumerate(scores):
|
||||
sentence_pairs[i]["score"] = float(score)
|
||||
# rerank
|
||||
reranked_results = {qid: {} for qid in search_results}
|
||||
for pair in sentence_pairs:
|
||||
reranked_results[pair["qid"]][pair["docid"]] = pair["score"]
|
||||
return reranked_results
|
||||
|
|
@ -0,0 +1,189 @@
|
|||
import faiss
|
||||
import torch
|
||||
import logging
|
||||
import numpy as np
|
||||
import pytrec_eval
|
||||
from tqdm import tqdm
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Modified from https://github.com/beir-cellar/beir/blob/f062f038c4bfd19a8ca942a9910b1e0d218759d4/beir/retrieval/custom_metrics.py#L4
|
||||
def evaluate_mrr(
|
||||
qrels: Dict[str, Dict[str, int]],
|
||||
results: Dict[str, Dict[str, float]],
|
||||
k_values: List[int],
|
||||
) -> Tuple[Dict[str, float]]:
|
||||
"""Compute mean reciprocal rank (MRR).
|
||||
|
||||
Args:
|
||||
qrels (Dict[str, Dict[str, int]]): Ground truth relevance.
|
||||
results (Dict[str, Dict[str, float]]): Search results to evaluate.
|
||||
k_values (List[int]): Cutoffs.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, float]]: MRR results at provided k values.
|
||||
"""
|
||||
mrr = defaultdict(list)
|
||||
|
||||
k_max, top_hits = max(k_values), {}
|
||||
|
||||
for query_id, doc_scores in results.items():
|
||||
top_hits[query_id] = sorted(
|
||||
doc_scores.items(), key=lambda item: item[1], reverse=True
|
||||
)[0:k_max]
|
||||
|
||||
for query_id in top_hits:
|
||||
query_relevant_docs = {
|
||||
doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0
|
||||
}
|
||||
for k in k_values:
|
||||
rr = 0
|
||||
for rank, hit in enumerate(top_hits[query_id][0:k], 1):
|
||||
if hit[0] in query_relevant_docs:
|
||||
rr = 1.0 / rank
|
||||
break
|
||||
mrr[f"MRR@{k}"].append(rr)
|
||||
|
||||
for k in k_values:
|
||||
mrr[f"MRR@{k}"] = round(sum(mrr[f"MRR@{k}"]) / len(qrels), 5)
|
||||
return mrr
|
||||
|
||||
|
||||
# Modified from https://github.com/embeddings-benchmark/mteb/blob/18f730696451a5aaa026494cecf288fd5cde9fd0/mteb/evaluation/evaluators/RetrievalEvaluator.py#L501
|
||||
def evaluate_metrics(
|
||||
qrels: Dict[str, Dict[str, int]],
|
||||
results: Dict[str, Dict[str, float]],
|
||||
k_values: List[int],
|
||||
) -> Tuple[
|
||||
Dict[str, float],
|
||||
Dict[str, float],
|
||||
Dict[str, float],
|
||||
Dict[str, float],
|
||||
]:
|
||||
"""Evaluate the main metrics.
|
||||
|
||||
Args:
|
||||
qrels (Dict[str, Dict[str, int]]): Ground truth relevance.
|
||||
results (Dict[str, Dict[str, float]]): Search results to evaluate.
|
||||
k_values (List[int]): Cutoffs.
|
||||
|
||||
Returns:
|
||||
Tuple[ Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float], ]: Results of different metrics at
|
||||
different provided k values.
|
||||
"""
|
||||
all_ndcgs, all_aps, all_recalls, all_precisions = defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list)
|
||||
|
||||
map_string = "map_cut." + ",".join([str(k) for k in k_values])
|
||||
ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
|
||||
recall_string = "recall." + ",".join([str(k) for k in k_values])
|
||||
precision_string = "P." + ",".join([str(k) for k in k_values])
|
||||
evaluator = pytrec_eval.RelevanceEvaluator(
|
||||
qrels, {map_string, ndcg_string, recall_string, precision_string}
|
||||
)
|
||||
scores = evaluator.evaluate(results)
|
||||
|
||||
for query_id in scores.keys():
|
||||
for k in k_values:
|
||||
all_ndcgs[f"NDCG@{k}"].append(scores[query_id]["ndcg_cut_" + str(k)])
|
||||
all_aps[f"MAP@{k}"].append(scores[query_id]["map_cut_" + str(k)])
|
||||
all_recalls[f"Recall@{k}"].append(scores[query_id]["recall_" + str(k)])
|
||||
all_precisions[f"P@{k}"].append(scores[query_id]["P_" + str(k)])
|
||||
|
||||
ndcg, _map, recall, precision = (
|
||||
all_ndcgs.copy(),
|
||||
all_aps.copy(),
|
||||
all_recalls.copy(),
|
||||
all_precisions.copy(),
|
||||
)
|
||||
|
||||
for k in k_values:
|
||||
ndcg[f"NDCG@{k}"] = round(sum(ndcg[f"NDCG@{k}"]) / len(scores), 5)
|
||||
_map[f"MAP@{k}"] = round(sum(_map[f"MAP@{k}"]) / len(scores), 5)
|
||||
recall[f"Recall@{k}"] = round(sum(recall[f"Recall@{k}"]) / len(scores), 5)
|
||||
precision[f"P@{k}"] = round(sum(precision[f"P@{k}"]) / len(scores), 5)
|
||||
|
||||
return ndcg, _map, recall, precision
|
||||
|
||||
|
||||
def index(
|
||||
index_factory: str = "Flat",
|
||||
corpus_embeddings: Optional[np.ndarray] = None,
|
||||
load_path: Optional[str] = None,
|
||||
device: Optional[str] = None
|
||||
):
|
||||
"""Create and add embeddings into a Faiss index.
|
||||
|
||||
Args:
|
||||
index_factory (str, optional): Type of Faiss index to create. Defaults to "Flat".
|
||||
corpus_embeddings (Optional[np.ndarray], optional): The embedding vectors of the corpus. Defaults to None.
|
||||
load_path (Optional[str], optional): Path to load embeddings from. Defaults to None.
|
||||
device (Optional[str], optional): Device to hold Faiss index. Defaults to None.
|
||||
|
||||
Returns:
|
||||
faiss.Index: The Faiss index that contains all the corpus embeddings.
|
||||
"""
|
||||
if corpus_embeddings is None:
|
||||
corpus_embeddings = np.load(load_path)
|
||||
|
||||
logger.info(f"Shape of embeddings: {corpus_embeddings.shape}")
|
||||
# create faiss index
|
||||
logger.info(f'Indexing {corpus_embeddings.shape[0]} documents...')
|
||||
faiss_index = faiss.index_factory(corpus_embeddings.shape[-1], index_factory, faiss.METRIC_INNER_PRODUCT)
|
||||
|
||||
if device is None and torch.cuda.is_available():
|
||||
try:
|
||||
co = faiss.GpuMultipleClonerOptions()
|
||||
co.shard = True
|
||||
co.useFloat16 = True
|
||||
faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)
|
||||
except:
|
||||
print('faiss do not support GPU, please uninstall faiss-cpu, faiss-gpu and install faiss-gpu again.')
|
||||
|
||||
logger.info('Adding embeddings ...')
|
||||
corpus_embeddings = corpus_embeddings.astype(np.float32)
|
||||
faiss_index.train(corpus_embeddings)
|
||||
faiss_index.add(corpus_embeddings)
|
||||
logger.info('Embeddings add over...')
|
||||
return faiss_index
|
||||
|
||||
|
||||
def search(
|
||||
faiss_index: faiss.Index,
|
||||
k: int = 100,
|
||||
query_embeddings: Optional[np.ndarray] = None,
|
||||
load_path: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
1. Encode queries into dense embeddings;
|
||||
2. Search through faiss index
|
||||
|
||||
Args:
|
||||
faiss_index (faiss.Index): The Faiss index that contains all the corpus embeddings.
|
||||
k (int, optional): Top k numbers of closest neighbours. Defaults to :data:`100`.
|
||||
query_embeddings (Optional[np.ndarray], optional): The embedding vectors of queries. Defaults to :data:`None`.
|
||||
load_path (Optional[str], optional): Path to load embeddings from. Defaults to :data:`None`.
|
||||
|
||||
Returns:
|
||||
Tuple[np.ndarray, np.ndarray]: The scores of search results and their corresponding indices.
|
||||
"""
|
||||
if query_embeddings is None:
|
||||
query_embeddings = np.load(load_path)
|
||||
|
||||
query_size = len(query_embeddings)
|
||||
|
||||
all_scores = []
|
||||
all_indices = []
|
||||
|
||||
for i in tqdm(range(0, query_size, 32), desc="Searching"):
|
||||
j = min(i + 32, query_size)
|
||||
query_embedding = query_embeddings[i: j]
|
||||
score, indice = faiss_index.search(query_embedding.astype(np.float32), k=k)
|
||||
all_scores.append(score)
|
||||
all_indices.append(indice)
|
||||
|
||||
all_scores = np.concatenate(all_scores, axis=0)
|
||||
all_indices = np.concatenate(all_indices, axis=0)
|
||||
return all_scores, all_indices
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from transformers import TrainingArguments
|
||||
|
||||
|
||||
@dataclass
|
||||
class AbsEmbedderModelArguments:
|
||||
"""
|
||||
Abstract class for model arguments.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "The model checkpoint for initialization."}
|
||||
)
|
||||
config_name: str = field(
|
||||
default=None,
|
||||
metadata={"help": "Pretrained config name or path if not the same as model_name."}
|
||||
)
|
||||
tokenizer_name: str = field(
|
||||
default=None,
|
||||
metadata={"help": "Pretrained tokenizer name or path if not the same as model_name."}
|
||||
)
|
||||
cache_dir: str = field(
|
||||
default=None,
|
||||
metadata={"help": "Where do you want to store the pre-trained models downloaded from s3."}
|
||||
)
|
||||
trust_remote_code: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Trust remote code"}
|
||||
)
|
||||
token: str = field(
|
||||
default_factory=lambda: os.getenv('HF_TOKEN', None),
|
||||
metadata={"help": "The token to use when accessing the model."}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AbsEmbedderDataArguments:
|
||||
"""
|
||||
Abstract class for data arguments.
|
||||
"""
|
||||
train_data: str = field(
|
||||
default=None, metadata={
|
||||
"help": "One or more paths to training data. `query: str`, `pos: List[str]`, `neg: List[str]` are required in the training data.",
|
||||
"nargs": "+"
|
||||
}
|
||||
)
|
||||
cache_path: Optional[str] = field(
|
||||
default=None, metadata={"help": "Where do you want to store the cached data"}
|
||||
)
|
||||
train_group_size: int = field(default=8)
|
||||
|
||||
query_max_len: int = field(
|
||||
default=32,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer than this will be truncated."
|
||||
},
|
||||
)
|
||||
|
||||
passage_max_len: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer than this will be truncated."
|
||||
},
|
||||
)
|
||||
|
||||
pad_to_multiple_of: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set will pad the sequence to be a multiple of the provided value."
|
||||
},
|
||||
)
|
||||
|
||||
max_example_num_per_dataset: int = field(
|
||||
default=100000000, metadata={"help": "the max number of examples for each dataset"}
|
||||
)
|
||||
|
||||
query_instruction_for_retrieval: str= field(
|
||||
default=None, metadata={"help": "instruction for query"}
|
||||
)
|
||||
query_instruction_format: str = field(
|
||||
default="{}{}", metadata={"help": "format for query instruction"}
|
||||
)
|
||||
|
||||
knowledge_distillation: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use knowledge distillation when `pos_scores: List[float]` and `neg_scores: List[float]` are in features of training data"}
|
||||
)
|
||||
|
||||
passage_instruction_for_retrieval: Optional[str] = field(
|
||||
default=None, metadata={"help": "instruction for passage"}
|
||||
)
|
||||
passage_instruction_format: Optional[str] = field(
|
||||
default="{}{}", metadata={"help": "format for passage instruction"}
|
||||
)
|
||||
|
||||
shuffle_ratio: float = field(
|
||||
default=0.0, metadata={"help": "The ratio of shuffling the text"}
|
||||
)
|
||||
|
||||
# Parameters for SameDatasetDataArguments
|
||||
same_dataset_within_batch: bool = field(
|
||||
default=False, metadata={"help": "All samples in the same batch comes from the same dataset."}
|
||||
)
|
||||
small_threshold: int = field(
|
||||
default=0,
|
||||
metadata={"help": "The threshold of small dataset. All small dataset in the same directory will be merged into one dataset."}
|
||||
)
|
||||
drop_threshold: int = field(
|
||||
default=0,
|
||||
metadata={"help": "The threshold for dropping merged small dataset. If the number of examples in the merged small dataset is less than this threshold, it will be dropped."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# replace "\\n" with "\n"
|
||||
if "\\n" in self.query_instruction_format:
|
||||
self.query_instruction_format = self.query_instruction_format.replace("\\n", "\n")
|
||||
if "\\n" in self.passage_instruction_format:
|
||||
self.passage_instruction_format = self.passage_instruction_format.replace("\\n", "\n")
|
||||
|
||||
# check the existence of train data
|
||||
for train_dir in self.train_data:
|
||||
if not os.path.exists(train_dir):
|
||||
raise FileNotFoundError(f"cannot find file: {train_dir}, please set a true path")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AbsEmbedderTrainingArguments(TrainingArguments):
|
||||
negatives_cross_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
|
||||
temperature: Optional[float] = field(default=0.02, metadata={"help": "temperature used for similarity score"})
|
||||
fix_position_embedding: bool = field(default=False, metadata={"help": "Freeze the parameters of position embeddings"})
|
||||
sentence_pooling_method: str = field(default='cls', metadata={"help": "the pooling method. Available options: cls, mean, last_token. Default: cls", "choices": ['cls', 'mean', 'last_token']})
|
||||
normalize_embeddings: bool = field(default=True, metadata={"help": "whether to normalize the embeddings"})
|
||||
sub_batch_size: Optional[int] = field(default=None, metadata={"help": "sub batch size for training"})
|
||||
kd_loss_type: str = field(default='kl_div', metadata={"help": "the loss type for knowledge distillation. Available options: kl_div, m3_kd_loss. Default: kl_div.", "choices": ['kl_div', 'm3_kd_loss']})
|
||||
|
|
@ -0,0 +1,616 @@
|
|||
import os
|
||||
import math
|
||||
import random
|
||||
import logging
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch.distributed as dist
|
||||
from dataclasses import dataclass
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import (
|
||||
PreTrainedTokenizer,
|
||||
DataCollatorWithPadding,
|
||||
TrainerCallback,
|
||||
TrainerState,
|
||||
TrainerControl
|
||||
)
|
||||
|
||||
from .AbsArguments import AbsEmbedderDataArguments, AbsEmbedderTrainingArguments
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbsEmbedderTrainDataset(Dataset):
|
||||
"""Abstract class for training dataset.
|
||||
|
||||
Args:
|
||||
args (AbsEmbedderDataArguments): Data arguments.
|
||||
tokenizer (PreTrainedTokenizer): Tokenizer to use.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
args: AbsEmbedderDataArguments,
|
||||
tokenizer: PreTrainedTokenizer
|
||||
):
|
||||
self.args = args
|
||||
self.tokenizer = tokenizer
|
||||
self.shuffle_ratio = args.shuffle_ratio
|
||||
|
||||
train_datasets = []
|
||||
for data_dir in args.train_data:
|
||||
if not os.path.isdir(data_dir):
|
||||
if not (data_dir.endswith('.json') or data_dir.endswith('.jsonl')): continue
|
||||
temp_dataset = self._load_dataset(data_dir)
|
||||
if len(temp_dataset) == 0: continue
|
||||
train_datasets.append(temp_dataset)
|
||||
else:
|
||||
for file in os.listdir(data_dir):
|
||||
if not (file.endswith('.json') or file.endswith('.jsonl')): continue
|
||||
temp_dataset = self._load_dataset(os.path.join(data_dir, file))
|
||||
if len(temp_dataset) == 0: continue
|
||||
train_datasets.append(temp_dataset)
|
||||
self.dataset = datasets.concatenate_datasets(train_datasets)
|
||||
|
||||
def _load_dataset(self, file_path: str):
|
||||
"""Load dataset from path.
|
||||
|
||||
Args:
|
||||
file_path (str): Path to load the datasets from.
|
||||
|
||||
Raises:
|
||||
ValueError: `pos_scores` and `neg_scores` not found in the features of training data
|
||||
|
||||
Returns:
|
||||
datasets.Dataset: Loaded HF dataset.
|
||||
"""
|
||||
if dist.get_rank() == 0:
|
||||
logger.info(f'loading data from {file_path} ...')
|
||||
|
||||
temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=self.args.cache_path)
|
||||
if len(temp_dataset) > self.args.max_example_num_per_dataset:
|
||||
temp_dataset = temp_dataset.select(random.sample(list(range(len(temp_dataset))), self.args.max_example_num_per_dataset))
|
||||
if not self.args.knowledge_distillation:
|
||||
if 'pos_scores' in temp_dataset.column_names:
|
||||
temp_dataset = temp_dataset.remove_columns(['pos_scores'])
|
||||
if 'neg_scores' in temp_dataset.column_names:
|
||||
temp_dataset = temp_dataset.remove_columns(['neg_scores'])
|
||||
else:
|
||||
if 'pos_scores' not in temp_dataset.column_names or 'neg_scores' not in temp_dataset.column_names:
|
||||
raise ValueError(f"`pos_scores` and `neg_scores` not found in the features of training data in {file_path}, which is necessary when using knowledge distillation.")
|
||||
return temp_dataset
|
||||
|
||||
def _shuffle_text(self, text):
|
||||
"""shuffle the input text.
|
||||
|
||||
Args:
|
||||
text (str): Input text.
|
||||
|
||||
Returns:
|
||||
str: Shuffled text.
|
||||
"""
|
||||
if self.shuffle_ratio > 0 and len(text) > 100 and random.random() < self.shuffle_ratio:
|
||||
split_text = []
|
||||
chunk_size = len(text)//3 + 1
|
||||
for i in range(0, len(text), chunk_size):
|
||||
split_text.append(text[i:i+chunk_size])
|
||||
random.shuffle(split_text)
|
||||
return " ".join(split_text)
|
||||
else:
|
||||
return text
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, item):
|
||||
data = self.dataset[item]
|
||||
train_group_size = self.args.train_group_size
|
||||
|
||||
query = data['query']
|
||||
if self.args.query_instruction_for_retrieval is not None:
|
||||
query = self.args.query_instruction_format.format(
|
||||
data['prompt'] if 'prompt' in data else self.args.query_instruction_for_retrieval,
|
||||
query
|
||||
)
|
||||
|
||||
passages = []
|
||||
teacher_scores = []
|
||||
|
||||
assert isinstance(data['pos'], list) and isinstance(data['neg'], list)
|
||||
|
||||
pos_idx = random.choice(list(range(len(data['pos']))))
|
||||
passages.append(self._shuffle_text(data['pos'][pos_idx]))
|
||||
|
||||
neg_all_idx = list(range(len(data['neg'])))
|
||||
if len(data['neg']) < train_group_size - 1:
|
||||
num = math.ceil((train_group_size - 1) / len(data['neg']))
|
||||
neg_idxs = random.sample(neg_all_idx * num, train_group_size - 1)
|
||||
else:
|
||||
neg_idxs = random.sample(neg_all_idx, self.args.train_group_size - 1)
|
||||
for neg_idx in neg_idxs:
|
||||
passages.append(data['neg'][neg_idx])
|
||||
|
||||
if self.args.knowledge_distillation:
|
||||
assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores'], list)
|
||||
teacher_scores.append(data['pos_scores'][pos_idx])
|
||||
for neg_idx in neg_idxs:
|
||||
teacher_scores.append(data['neg_scores'][neg_idx])
|
||||
if not all(isinstance(score, (int, float)) for score in teacher_scores):
|
||||
raise ValueError(f"pos_score or neg_score must be digit")
|
||||
else:
|
||||
teacher_scores = None
|
||||
|
||||
if self.args.passage_instruction_for_retrieval is not None:
|
||||
passages = [
|
||||
self.args.passage_instruction_format.format(
|
||||
self.args.passage_instruction_for_retrieval, p
|
||||
)
|
||||
for p in passages
|
||||
]
|
||||
|
||||
return query, passages, teacher_scores
|
||||
|
||||
@dataclass
|
||||
class AbsEmbedderCollator(DataCollatorWithPadding):
|
||||
"""
|
||||
The abstract embedder collator.
|
||||
"""
|
||||
query_max_len: int = 32
|
||||
passage_max_len: int = 128
|
||||
sub_batch_size: int = -1
|
||||
|
||||
def __call__(self, features):
|
||||
queries = [f[0] for f in features]
|
||||
passages = [f[1] for f in features]
|
||||
teacher_scores = [f[2] for f in features]
|
||||
if teacher_scores[0] is None:
|
||||
teacher_scores = None
|
||||
elif isinstance(teacher_scores[0], list):
|
||||
teacher_scores = sum(teacher_scores, [])
|
||||
|
||||
if isinstance(queries[0], list):
|
||||
queries = sum(queries, [])
|
||||
if isinstance(passages[0], list):
|
||||
passages = sum(passages, [])
|
||||
|
||||
queries_inputs = self.tokenizer(
|
||||
queries,
|
||||
truncation=True,
|
||||
max_length=self.query_max_len,
|
||||
return_tensors=None
|
||||
)
|
||||
passages_inputs = self.tokenizer(
|
||||
passages,
|
||||
truncation=True,
|
||||
max_length=self.passage_max_len,
|
||||
return_tensors=None
|
||||
)
|
||||
|
||||
if self.sub_batch_size is None or self.sub_batch_size <= 0:
|
||||
q_collated = self.tokenizer.pad(
|
||||
queries_inputs,
|
||||
padding=self.padding,
|
||||
max_length=self.query_max_len,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=self.return_tensors
|
||||
)
|
||||
d_collated = self.tokenizer.pad(
|
||||
passages_inputs,
|
||||
padding=self.padding,
|
||||
max_length=self.passage_max_len,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=self.return_tensors
|
||||
)
|
||||
else:
|
||||
batch_size = self.sub_batch_size
|
||||
|
||||
q_collated = []
|
||||
for i in range(0, len(queries_inputs['attention_mask']), batch_size):
|
||||
start = i
|
||||
end = min(len(queries_inputs['attention_mask']), i + batch_size)
|
||||
sub_features = {}
|
||||
for k, v in queries_inputs.items():
|
||||
sub_features[k] = v[start:end]
|
||||
q_collated.append(self.tokenizer.pad(
|
||||
sub_features,
|
||||
padding=self.padding,
|
||||
max_length=self.query_max_len,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=self.return_tensors
|
||||
))
|
||||
|
||||
d_collated = []
|
||||
for i in range(0, len(passages_inputs['attention_mask']), batch_size):
|
||||
start = i
|
||||
end = min(len(passages_inputs['attention_mask']), i + batch_size)
|
||||
sub_features = {}
|
||||
|
||||
for k, v in passages_inputs.items():
|
||||
sub_features[k] = v[start:end]
|
||||
d_collated.append(self.tokenizer.pad(
|
||||
sub_features,
|
||||
padding=self.padding,
|
||||
max_length=self.passage_max_len,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=self.return_tensors
|
||||
))
|
||||
return {
|
||||
"queries": q_collated,
|
||||
"passages": d_collated,
|
||||
"teacher_scores": teacher_scores,
|
||||
"no_in_batch_neg_flag": False
|
||||
}
|
||||
|
||||
|
||||
class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
|
||||
"""Abstract class for training dataset that samples batches from same dataset.
|
||||
|
||||
Args:
|
||||
args (AbsEmbedderDataArguments): Data arguments.
|
||||
default_batch_size (int): The default batch size for training.
|
||||
seed (int): Random seed.
|
||||
tokenizer (PreTrainedTokenizer): Tokenizer to use.
|
||||
process_index (int, optional): Current process index. Defaults to 0.
|
||||
num_processes (int, optional): Total number of processes. Defaults to 1.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
args: AbsEmbedderDataArguments,
|
||||
default_batch_size: int,
|
||||
seed: int,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
process_index: int=0,
|
||||
num_processes: int=1
|
||||
):
|
||||
self.args = args
|
||||
self.shuffle_ratio = args.shuffle_ratio
|
||||
self.defaut_batch_size = default_batch_size
|
||||
self.deterministic_generator = np.random.default_rng(seed)
|
||||
self.tokenizer = tokenizer
|
||||
self.process_index = process_index
|
||||
self.num_processes = num_processes
|
||||
|
||||
self.step = 0
|
||||
|
||||
train_datasets = []
|
||||
each_data_idxs = []
|
||||
batch_size_idxs = []
|
||||
no_in_batch_neg_flags = []
|
||||
cur_all_num = 0
|
||||
|
||||
small_threshold = args.small_threshold
|
||||
drop_threshold = args.drop_threshold
|
||||
|
||||
for data_dir in args.train_data:
|
||||
if not os.path.isdir(data_dir):
|
||||
# Add `no_in_batch_neg` **suffix** to `data_dir` to indicate that this dataset does not use in-batch negatives
|
||||
no_in_batch_neg_flag = data_dir.split('.')[-2].endswith('no_in_batch_neg')
|
||||
if not (data_dir.endswith('.json') or data_dir.endswith('.jsonl')): continue
|
||||
temp_dataset = self._load_dataset(data_dir)
|
||||
|
||||
if len(temp_dataset) == 0 or len(temp_dataset) < small_threshold: continue
|
||||
else:
|
||||
train_datasets.append(temp_dataset)
|
||||
each_data_idxs.append(np.arange(len(temp_dataset)) + cur_all_num)
|
||||
cur_all_num += len(temp_dataset)
|
||||
batch_size_idxs.append(self._get_file_batch_size(temp_dataset, default_batch_size))
|
||||
no_in_batch_neg_flags.append(no_in_batch_neg_flag)
|
||||
|
||||
else:
|
||||
small_datasets = []
|
||||
small_batch_size = math.inf
|
||||
|
||||
# Add `no_in_batch_neg` **suffix** to `data_dir` to indicate that this dataset does not use in-batch negatives
|
||||
no_in_batch_neg_flag = data_dir.endswith('no_in_batch_neg')
|
||||
for file in os.listdir(data_dir):
|
||||
if not (file.endswith('.json') or file.endswith('.jsonl')): continue
|
||||
temp_dataset = self._load_dataset(os.path.join(data_dir, file))
|
||||
|
||||
if len(temp_dataset) == 0: continue
|
||||
elif len(temp_dataset) < small_threshold:
|
||||
small_datasets.append(temp_dataset)
|
||||
small_batch_size = min(small_batch_size, self._get_file_batch_size(temp_dataset, default_batch_size))
|
||||
else:
|
||||
train_datasets.append(temp_dataset)
|
||||
each_data_idxs.append(np.arange(len(temp_dataset)) + cur_all_num)
|
||||
cur_all_num += len(temp_dataset)
|
||||
batch_size_idxs.append(self._get_file_batch_size(temp_dataset, default_batch_size))
|
||||
no_in_batch_neg_flags.append(no_in_batch_neg_flag)
|
||||
|
||||
if len(small_datasets) > 0:
|
||||
small_dataset = datasets.concatenate_datasets(small_datasets)
|
||||
if len(small_dataset) >= drop_threshold:
|
||||
train_datasets.append(small_dataset)
|
||||
each_data_idxs.append(np.arange(len(small_dataset)) + cur_all_num)
|
||||
cur_all_num += len(small_dataset)
|
||||
batch_size_idxs.append(small_batch_size)
|
||||
no_in_batch_neg_flags.append(no_in_batch_neg_flag)
|
||||
|
||||
self.dataset = datasets.concatenate_datasets(train_datasets)
|
||||
self.each_data_idxs = each_data_idxs
|
||||
self.datasets_inxs = np.arange(len(each_data_idxs))
|
||||
self.batch_size_idxs = batch_size_idxs
|
||||
self.no_in_batch_neg_flags = no_in_batch_neg_flags
|
||||
|
||||
self.refresh_epoch()
|
||||
|
||||
def _load_dataset(self, file_path: str):
|
||||
"""Load datset from given path.
|
||||
|
||||
Args:
|
||||
file_path (str): The path to load or download from HF hub.
|
||||
|
||||
Returns:
|
||||
datasets.Dataset: The loaded dataset.
|
||||
"""
|
||||
if dist.get_rank() == 0:
|
||||
logger.info(f'loading data from {file_path} ...')
|
||||
|
||||
temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=self.args.cache_path)
|
||||
if len(temp_dataset) > self.args.max_example_num_per_dataset:
|
||||
temp_dataset = temp_dataset.select(random.sample(list(range(len(temp_dataset))), self.args.max_example_num_per_dataset))
|
||||
if not self.args.knowledge_distillation:
|
||||
if 'pos_scores' in temp_dataset.column_names:
|
||||
temp_dataset = temp_dataset.remove_columns(['pos_scores'])
|
||||
if 'neg_scores' in temp_dataset.column_names:
|
||||
temp_dataset = temp_dataset.remove_columns(['neg_scores'])
|
||||
return temp_dataset
|
||||
|
||||
@staticmethod
|
||||
def _get_file_batch_size(temp_dataset: datasets.Dataset, default_batch_size: int):
|
||||
"""Get the appropriate batch size for the dataset.
|
||||
|
||||
Args:
|
||||
temp_dataset (datasets.Dataset): Loaded :data:`datasets.Dataset` object.
|
||||
default_batch_size (int): The default batch size to use if not specified in the dataset.
|
||||
|
||||
Returns:
|
||||
int: The final batch size to use.
|
||||
"""
|
||||
if 'batch_size' in temp_dataset.column_names:
|
||||
return temp_dataset['batch_size'][0]
|
||||
if 'type' in temp_dataset.column_names:
|
||||
data_type = temp_dataset['type'][0]
|
||||
if 'symmetric' in data_type:
|
||||
return default_batch_size // 2 # make the symmetric data have smaller batch size
|
||||
return default_batch_size
|
||||
|
||||
def refresh_epoch(self):
|
||||
"""
|
||||
Refresh data for epoch.
|
||||
"""
|
||||
logger.info(f'-- Rank {self.process_index}: refresh data --')
|
||||
self.deterministic_generator.shuffle(self.datasets_inxs)
|
||||
|
||||
batch_datas = []
|
||||
for dataset_inx in self.datasets_inxs:
|
||||
self.deterministic_generator.shuffle(self.each_data_idxs[dataset_inx])
|
||||
cur_batch_size = self.batch_size_idxs[dataset_inx]*self.num_processes
|
||||
no_in_batch_neg_flag = self.no_in_batch_neg_flags[dataset_inx]
|
||||
for start_index in range(0, len(self.each_data_idxs[dataset_inx]), cur_batch_size):
|
||||
# judge the last batch's length
|
||||
if len(self.each_data_idxs[dataset_inx]) - start_index < cur_batch_size:
|
||||
break
|
||||
batch_datas.append((
|
||||
self.each_data_idxs[dataset_inx][start_index:start_index+cur_batch_size],
|
||||
no_in_batch_neg_flag
|
||||
))
|
||||
self.deterministic_generator.shuffle(batch_datas)
|
||||
self.batch_datas = batch_datas
|
||||
self.step = 0
|
||||
|
||||
def __len__(self):
|
||||
return len(self.batch_datas) * self.num_processes
|
||||
|
||||
def __getitem__(self, _):
|
||||
batch_indices, no_in_batch_neg_flag = self.batch_datas[self.step] # extend here
|
||||
cur_batch_size = int(len(batch_indices) / self.num_processes)
|
||||
batch_indices = batch_indices[self.process_index * cur_batch_size: (self.process_index + 1) * cur_batch_size]
|
||||
batch_data = self.dataset[batch_indices]
|
||||
self.step += 1
|
||||
queries, passages, teacher_scores = self._create_batch_data(batch_raw_data=batch_data)
|
||||
return queries, passages, teacher_scores, no_in_batch_neg_flag
|
||||
|
||||
def _get_train_group_size(self, batch_raw_data):
|
||||
"""Get the training group size and data type.
|
||||
|
||||
Args:
|
||||
batch_raw_data (datasets.Dataset): One batch of raw data.
|
||||
|
||||
Returns:
|
||||
int: The training group size.
|
||||
str: The type of data for the task.
|
||||
"""
|
||||
if 'type' in batch_raw_data:
|
||||
data_type = batch_raw_data['type'][0]
|
||||
if data_type in ['only_1neg']:
|
||||
return 2, data_type
|
||||
elif data_type in ['symmetric_class']:
|
||||
return min(len(batch_raw_data['neg'][0]) + 1, self.args.train_group_size), data_type
|
||||
else:
|
||||
return self.args.train_group_size, data_type
|
||||
return self.args.train_group_size, None
|
||||
|
||||
def _create_batch_data(self, batch_raw_data):
|
||||
"""Create a comple batch of data with queries, documents and teacher scores.
|
||||
|
||||
Args:
|
||||
batch_raw_data (datasets.Dataset): One batch of raw data.
|
||||
|
||||
Returns:
|
||||
List[str]: Queries with instruction format.
|
||||
List[str]: Documents with instruction format.
|
||||
List[float]: Teacher scores for model distillation.
|
||||
"""
|
||||
queries, passages, teacher_scores = [], [], []
|
||||
|
||||
train_group_size, data_type = self._get_train_group_size(batch_raw_data)
|
||||
|
||||
for i in range(len(batch_raw_data['query'])):
|
||||
if data_type is not None:
|
||||
assert batch_raw_data['type'][i] == data_type, f"Data type is not consistent in the same batch"
|
||||
|
||||
queries.append(
|
||||
self.args.query_instruction_format.format(
|
||||
batch_raw_data['prompt'][i] if 'prompt' in batch_raw_data else self.args.query_instruction_for_retrieval,
|
||||
batch_raw_data['query'][i]
|
||||
)
|
||||
)
|
||||
tmp_passages = []
|
||||
pos_idx = random.choice(list(range(len(batch_raw_data['pos'][i]))))
|
||||
pos = self._shuffle_text(batch_raw_data['pos'][i][pos_idx])
|
||||
tmp_passages.append(pos)
|
||||
|
||||
neg_all_idx = list(range(len(batch_raw_data['neg'][i])))
|
||||
if len(batch_raw_data['neg'][i]) < train_group_size - 1:
|
||||
num = math.ceil((train_group_size - 1) / len(batch_raw_data['neg'][i]))
|
||||
neg_idxs = random.sample(neg_all_idx * num, train_group_size - 1)
|
||||
else:
|
||||
neg_idxs = random.sample(neg_all_idx, train_group_size - 1)
|
||||
for neg_idx in neg_idxs:
|
||||
tmp_passages.append(batch_raw_data['neg'][i][neg_idx])
|
||||
|
||||
if self.args.knowledge_distillation:
|
||||
if 'pos_scores' in batch_raw_data and batch_raw_data['pos_scores'][i] is not None:
|
||||
teacher_scores.append(batch_raw_data['pos_scores'][i][pos_idx])
|
||||
for neg_idx in neg_idxs:
|
||||
if 'neg_scores' in batch_raw_data and batch_raw_data['neg_scores'][i] is not None:
|
||||
teacher_scores.append(batch_raw_data['neg_scores'][i][neg_idx])
|
||||
else:
|
||||
teacher_scores = None
|
||||
|
||||
if data_type is not None and data_type in ['symmetric_sts', 'symmetric_clustering']:
|
||||
tmp_passages = [
|
||||
self.args.query_instruction_format.format(
|
||||
batch_raw_data['prompt'][i] if 'prompt' in batch_raw_data else self.args.query_instruction_for_retrieval,
|
||||
p
|
||||
) for p in tmp_passages
|
||||
]
|
||||
else:
|
||||
if self.args.passage_instruction_for_retrieval is not None:
|
||||
tmp_passages = [
|
||||
self.args.passage_instruction_format.format(
|
||||
self.args.passage_instruction_for_retrieval, p
|
||||
) for p in tmp_passages
|
||||
]
|
||||
|
||||
passages.extend(tmp_passages)
|
||||
|
||||
if teacher_scores is not None:
|
||||
if len(teacher_scores) > 0 and len(passages) > 0:
|
||||
assert len(teacher_scores) == len(passages)
|
||||
|
||||
return queries, passages, teacher_scores
|
||||
|
||||
|
||||
@dataclass
|
||||
class AbsEmbedderSameDatasetCollator(DataCollatorWithPadding):
|
||||
"""
|
||||
EmbedCollator for SameDataset.
|
||||
Note that after using this collator, the training_args should be set as:
|
||||
|
||||
``training_args.per_device_train_batch_size = 1``
|
||||
|
||||
``training_args.dataloader_num_workers = 0 # avoid multi-processing``
|
||||
"""
|
||||
query_max_len: int = 32
|
||||
passage_max_len: int = 128
|
||||
sub_batch_size: int = -1
|
||||
|
||||
def __call__(self, features):
|
||||
queries = features[0][0]
|
||||
passages = features[0][1]
|
||||
teacher_scores = features[0][2]
|
||||
no_in_batch_neg_flag = features[0][3]
|
||||
|
||||
queries_inputs = self.tokenizer(
|
||||
queries,
|
||||
truncation=True,
|
||||
max_length=self.query_max_len,
|
||||
return_tensors=None
|
||||
)
|
||||
passages_inputs = self.tokenizer(
|
||||
passages,
|
||||
truncation=True,
|
||||
max_length=self.passage_max_len,
|
||||
return_tensors=None
|
||||
)
|
||||
|
||||
if self.sub_batch_size is None or self.sub_batch_size <= 0:
|
||||
q_collated = self.tokenizer.pad(
|
||||
queries_inputs,
|
||||
padding=self.padding,
|
||||
max_length=self.query_max_len,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=self.return_tensors,
|
||||
)
|
||||
|
||||
d_collated = self.tokenizer.pad(
|
||||
passages_inputs,
|
||||
padding=self.padding,
|
||||
max_length=self.passage_max_len,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=self.return_tensors,
|
||||
)
|
||||
else:
|
||||
batch_size = self.sub_batch_size
|
||||
|
||||
q_collated = []
|
||||
for i in range(0, len(queries_inputs['attention_mask']), batch_size):
|
||||
start = i
|
||||
end = min(len(queries_inputs['attention_mask']), i + batch_size)
|
||||
sub_features = {}
|
||||
for k, v in queries_inputs.items():
|
||||
sub_features[k] = v[start:end]
|
||||
q_collated.append(self.tokenizer.pad(
|
||||
sub_features,
|
||||
padding=self.padding,
|
||||
max_length=self.query_max_len,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=self.return_tensors,
|
||||
))
|
||||
|
||||
d_collated = []
|
||||
for i in range(0, len(passages_inputs['attention_mask']), batch_size):
|
||||
start = i
|
||||
end = min(len(passages_inputs['attention_mask']), i + batch_size)
|
||||
sub_features = {}
|
||||
|
||||
for k, v in passages_inputs.items():
|
||||
sub_features[k] = v[start:end]
|
||||
d_collated.append(self.tokenizer.pad(
|
||||
sub_features,
|
||||
padding=self.padding,
|
||||
max_length=self.passage_max_len,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=self.return_tensors,
|
||||
))
|
||||
|
||||
if isinstance(teacher_scores, list) and len(teacher_scores) == 0:
|
||||
teacher_scores = None
|
||||
|
||||
return {
|
||||
"queries": q_collated,
|
||||
"passages": d_collated,
|
||||
"teacher_scores": teacher_scores,
|
||||
"no_in_batch_neg_flag": no_in_batch_neg_flag
|
||||
}
|
||||
|
||||
|
||||
class EmbedderTrainerCallbackForDataRefresh(TrainerCallback):
|
||||
"""
|
||||
Callback class to inspect the state of the training loop and take decision.
|
||||
"""
|
||||
def __init__(self, train_dataset: AbsEmbedderSameDatasetTrainDataset):
|
||||
self.train_dataset = train_dataset
|
||||
|
||||
def on_epoch_end(
|
||||
self,
|
||||
args: AbsEmbedderTrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Event called at the end of an epoch.
|
||||
"""
|
||||
self.train_dataset.refresh_epoch()
|
||||
|
|
@ -0,0 +1,340 @@
|
|||
import torch
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers.file_utils import ModelOutput
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional, List, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbedderOutput(ModelOutput):
|
||||
"""
|
||||
Output information returned by the model.
|
||||
"""
|
||||
q_reps: Optional[Tensor] = None
|
||||
p_reps: Optional[Tensor] = None
|
||||
loss: Optional[Tensor] = None
|
||||
scores: Optional[Tensor] = None
|
||||
|
||||
|
||||
class AbsEmbedderModel(ABC, nn.Module):
|
||||
"""Abstract class of embedding model for training.
|
||||
|
||||
Args:
|
||||
base_model: The base model to train on.
|
||||
tokenizer (PreTrainedTokenizer, optional): The tokenizer to use. Defaults to ``None``.
|
||||
negatives_cross_device (bool, optional): If True, will compute cross devices negative loss. Defaults to ``False``.
|
||||
temperature (float, optional): Temperature to control the scale of scores. Defaults to ``1.0``.
|
||||
sub_batch_size (int, optional): Sub-batch size during encoding. If negative, will not split to sub-batch.
|
||||
Defaults to ``-1``.
|
||||
kd_loss_type (str, optional): Type of knowledge distillation loss. Defaults to ``"kl_div"``.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
base_model,
|
||||
tokenizer: PreTrainedTokenizer = None,
|
||||
negatives_cross_device: bool = False,
|
||||
temperature: float = 1.0,
|
||||
sub_batch_size: int = -1,
|
||||
kd_loss_type: str = 'kl_div',
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
self.model = base_model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.temperature = temperature
|
||||
self.negatives_cross_device = negatives_cross_device
|
||||
if self.negatives_cross_device:
|
||||
if not dist.is_initialized():
|
||||
raise ValueError('Distributed training has not been initialized for representation all gather.')
|
||||
self.process_rank = dist.get_rank()
|
||||
self.world_size = dist.get_world_size()
|
||||
|
||||
self.sub_batch_size = sub_batch_size
|
||||
self.kd_loss_type = kd_loss_type
|
||||
|
||||
@abstractmethod
|
||||
def encode(self, features):
|
||||
"""Abstract method encode and get the embedding.
|
||||
|
||||
Args:
|
||||
features (Union[list, dict]): Features feed to the model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute_loss(self, scores, target):
|
||||
"""Abstract method compute the loss.
|
||||
|
||||
Args:
|
||||
scores (torch.Tensor): Computed score.
|
||||
target (torch.Tensor): The target value.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute_score(self, q_reps, p_reps):
|
||||
"""Abstract method to compute the score.
|
||||
|
||||
Args:
|
||||
q_reps (torch.Tensor): Queries representations.
|
||||
p_reps (torch.Tensor): Passages rerpresentations.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, output_dir: str):
|
||||
"""Abstract method to save the model.
|
||||
|
||||
Args:
|
||||
output_dir (str): Directory for saving the model.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_local_score(self, q_reps, p_reps, all_scores):
|
||||
"""Get the local score of queries and passages.
|
||||
|
||||
Args:
|
||||
q_reps (torch.Tensor): Queries representations.
|
||||
p_reps (torch.Tensor): Passages rerpresentations.
|
||||
all_scores (torch.Tensor): All the query-passage scores computed.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Local scores to compute loss.
|
||||
"""
|
||||
group_size = p_reps.size(0) // q_reps.size(0)
|
||||
indices = torch.arange(0, q_reps.size(0), device=q_reps.device) * group_size
|
||||
specific_scores = []
|
||||
for i in range(group_size):
|
||||
specific_scores.append(
|
||||
all_scores[torch.arange(q_reps.size(0), device=q_reps.device), indices + i]
|
||||
)
|
||||
return torch.stack(specific_scores, dim=1).view(q_reps.size(0), -1)
|
||||
|
||||
def compute_local_score(self, q_reps, p_reps, compute_score_func=None, **kwargs):
|
||||
"""Compute the local score of queries and passages.
|
||||
|
||||
Args:
|
||||
q_reps (torch.Tensor): Queries representations.
|
||||
p_reps (torch.Tensor): Passages rerpresentations.
|
||||
compute_score_func (function, optional): Function to compute score. Defaults to ``None``, which will use the
|
||||
:meth:`self.compute_score`.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Local scores to compute loss.
|
||||
"""
|
||||
if compute_score_func is None:
|
||||
all_scores = self.compute_score(q_reps, p_reps)
|
||||
else:
|
||||
all_scores = compute_score_func(q_reps, p_reps, **kwargs)
|
||||
loacl_scores = self.get_local_score(q_reps, p_reps, all_scores)
|
||||
return loacl_scores
|
||||
|
||||
def _compute_no_in_batch_neg_loss(self, q_reps, p_reps, teacher_targets=None, compute_score_func=None, **kwargs):
|
||||
"""
|
||||
Compute loss when using no in-batch negatives and no cross-device negatives
|
||||
"""
|
||||
group_size = p_reps.size(0) // q_reps.size(0)
|
||||
|
||||
local_scores = self.compute_local_score(q_reps, p_reps, compute_score_func, **kwargs) # (batch_size, group_size)
|
||||
|
||||
if teacher_targets is not None:
|
||||
# compute kd loss
|
||||
loss = self.distill_loss(self.kd_loss_type, teacher_targets, local_scores, group_size=group_size)
|
||||
|
||||
# add normal loss if needed
|
||||
if self.kd_loss_type == "kl_div":
|
||||
local_targets = torch.zeros(local_scores.size(0), device=local_scores.device, dtype=torch.long) # (batch_size)
|
||||
loss += self.compute_loss(local_scores, local_targets)
|
||||
else:
|
||||
local_targets = torch.zeros(local_scores.size(0), device=local_scores.device, dtype=torch.long) # (batch_size)
|
||||
loss = self.compute_loss(local_scores, local_targets)
|
||||
|
||||
return local_scores, loss
|
||||
|
||||
def _compute_in_batch_neg_loss(self, q_reps, p_reps, teacher_targets=None, compute_score_func=None, **kwargs):
|
||||
"""
|
||||
Compute loss when only using in-batch negatives
|
||||
"""
|
||||
group_size = p_reps.size(0) // q_reps.size(0)
|
||||
|
||||
if compute_score_func is None:
|
||||
scores = self.compute_score(q_reps, p_reps) # (batch_size, batch_size * group_size)
|
||||
else:
|
||||
scores = compute_score_func(q_reps, p_reps, **kwargs) # (batch_size, batch_size * group_size)
|
||||
|
||||
if teacher_targets is not None:
|
||||
# compute kd loss
|
||||
if self.kd_loss_type == "kl_div":
|
||||
student_scores = self.get_local_score(q_reps, p_reps, scores) # (batch_size, group_size)
|
||||
|
||||
loss = self.distill_loss(self.kd_loss_type, teacher_targets, student_scores, group_size)
|
||||
|
||||
idxs = torch.arange(q_reps.size(0), device=q_reps.device, dtype=torch.long)
|
||||
targets = idxs * (p_reps.size(0) // q_reps.size(0)) # (batch_size)
|
||||
loss += self.compute_loss(scores, targets)
|
||||
elif self.kd_loss_type == "m3_kd_loss":
|
||||
loss = self.distill_loss(self.kd_loss_type, teacher_targets, scores, group_size)
|
||||
else:
|
||||
raise ValueError(f"Invalid kd_loss_type: {self.kd_loss_type}")
|
||||
else:
|
||||
idxs = torch.arange(q_reps.size(0), device=q_reps.device, dtype=torch.long)
|
||||
targets = idxs * group_size # (batch_size)
|
||||
loss = self.compute_loss(scores, targets)
|
||||
|
||||
return scores, loss
|
||||
|
||||
def _compute_cross_device_neg_loss(self, q_reps, p_reps, teacher_targets=None, compute_score_func=None, **kwargs):
|
||||
"""
|
||||
Compute loss when using both in-batch negatives and cross-device negatives
|
||||
"""
|
||||
group_size = p_reps.size(0) // q_reps.size(0)
|
||||
|
||||
cross_q_reps = self._dist_gather_tensor(q_reps) # (world_size * batch_size, dim)
|
||||
cross_p_reps = self._dist_gather_tensor(p_reps) # (world_size * batch_size * group_size, dim)
|
||||
|
||||
if compute_score_func is None:
|
||||
cross_scores = self.compute_score(cross_q_reps, cross_p_reps) # (world_size * batch_size, world_size * batch_size * group_size)
|
||||
else:
|
||||
cross_scores = compute_score_func(cross_q_reps, cross_p_reps, **kwargs) # (world_size * batch_size, world_size * batch_size * group_size)
|
||||
|
||||
if teacher_targets is not None:
|
||||
# compute kd loss
|
||||
if self.kd_loss_type == "kl_div":
|
||||
student_scores = self.get_local_score(cross_q_reps, cross_p_reps, cross_scores) # (world_size * batch_size, group_size)
|
||||
student_scores = student_scores[
|
||||
q_reps.size(0)*self.process_rank : q_reps.size(0)*(self.process_rank+1)
|
||||
] # (batch_size, group_size)
|
||||
|
||||
loss = self.distill_loss(self.kd_loss_type, teacher_targets, student_scores, group_size)
|
||||
|
||||
cross_idxs = torch.arange(cross_q_reps.size(0), device=cross_q_reps.device, dtype=torch.long)
|
||||
cross_targets = cross_idxs * group_size # (world_size * batch_size)
|
||||
loss += self.compute_loss(cross_scores, cross_targets)
|
||||
elif self.kd_loss_type == "m3_kd_loss":
|
||||
cross_teacher_targets = self._dist_gather_tensor(teacher_targets) # (world_size * batch_size, group_size)
|
||||
|
||||
loss = self.distill_loss(self.kd_loss_type, cross_teacher_targets, cross_scores, group_size)
|
||||
else:
|
||||
raise ValueError(f"Invalid kd_loss_type: {self.kd_loss_type}")
|
||||
else:
|
||||
cross_idxs = torch.arange(cross_q_reps.size(0), device=cross_q_reps.device, dtype=torch.long)
|
||||
cross_targets = cross_idxs * group_size # (world_size * batch_size)
|
||||
loss = self.compute_loss(cross_scores, cross_targets)
|
||||
|
||||
return cross_scores, loss
|
||||
|
||||
def forward(
|
||||
self,
|
||||
queries: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None,
|
||||
passages: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None,
|
||||
teacher_scores: Union[None, List[float]] = None,
|
||||
no_in_batch_neg_flag: bool = False,
|
||||
):
|
||||
"""The computation performed at every call.
|
||||
|
||||
Args:
|
||||
queries (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): Input queries. Defaults to ``None``.
|
||||
passages (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): Input passages. Defaults to ``None``.
|
||||
teacher_scores (Union[None, List[float]], optional): Teacher scores for distillation. Defaults to ``None``.
|
||||
no_in_batch_neg_flag (bool, optional): If True, use no in-batch negatives and no cross-device negatives. Defaults to ``False``.
|
||||
|
||||
Returns:
|
||||
EmbedderOutput: Output of the forward call of model.
|
||||
"""
|
||||
q_reps = self.encode(queries) # (batch_size, dim)
|
||||
p_reps = self.encode(passages) # (batch_size * group_size, dim)
|
||||
|
||||
if self.training:
|
||||
if teacher_scores is not None:
|
||||
teacher_scores = torch.tensor(teacher_scores, device=q_reps.device)
|
||||
teacher_scores = teacher_scores.view(q_reps.size(0), -1).detach() # (batch_size, group_size)
|
||||
teacher_targets = F.softmax(teacher_scores, dim=-1) # (batch_size, group_size)
|
||||
else:
|
||||
teacher_targets = None
|
||||
|
||||
if no_in_batch_neg_flag:
|
||||
compute_loss_func = self._compute_no_in_batch_neg_loss
|
||||
else:
|
||||
if self.negatives_cross_device:
|
||||
compute_loss_func = self._compute_cross_device_neg_loss
|
||||
else:
|
||||
compute_loss_func = self._compute_in_batch_neg_loss
|
||||
|
||||
scores, loss = compute_loss_func(q_reps, p_reps, teacher_targets=teacher_targets)
|
||||
else:
|
||||
loss = None
|
||||
|
||||
return EmbedderOutput(
|
||||
loss=loss,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def distill_loss(kd_loss_type, teacher_targets, student_scores, group_size=None):
|
||||
"""Compute the distillation loss.
|
||||
|
||||
Args:
|
||||
kd_loss_type (str): Type of knowledge distillation loss, supports "kl_div" and "m3_kd_loss".
|
||||
teacher_targets (torch.Tensor): Targets from the teacher model.
|
||||
student_scores (torch.Tensor): Score of student model.
|
||||
group_size (int, optional): Number of groups for . Defaults to ``None``.
|
||||
|
||||
Raises:
|
||||
ValueError: Invalid kd_loss_type
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A scalar of computed distillation loss.
|
||||
"""
|
||||
if kd_loss_type == 'kl_div':
|
||||
# teacher_targets: (batch_size, group_size) / (world_size * batch_size, group_size)
|
||||
# student_scores: (batch_size, group_size) / (world_size * batch_size, group_size)
|
||||
return - torch.mean(
|
||||
torch.sum(torch.log_softmax(student_scores, dim=-1) * teacher_targets, dim=-1)
|
||||
)
|
||||
elif kd_loss_type == 'm3_kd_loss':
|
||||
# teacher_targets: (batch_size, group_size) / (world_size * batch_size, group_size)
|
||||
# student_scores: (batch_size, batch_size * group_size) / (world_size * batch_size, world_size * batch_size * group_size)
|
||||
labels = torch.arange(student_scores.size(0), device=student_scores.device, dtype=torch.long)
|
||||
labels = labels * group_size
|
||||
|
||||
loss = 0
|
||||
mask = torch.zeros_like(student_scores)
|
||||
for i in range(group_size):
|
||||
temp_target = labels + i
|
||||
temp_scores = student_scores + mask
|
||||
temp_loss = F.cross_entropy(temp_scores, temp_target, reduction="none") # B
|
||||
loss += torch.mean(teacher_targets[:, i] * temp_loss)
|
||||
mask = torch.scatter(mask, dim=-1, index=temp_target.unsqueeze(-1),
|
||||
value=torch.finfo(student_scores.dtype).min)
|
||||
return loss
|
||||
else:
|
||||
raise ValueError(f"Invalid kd_loss_type: {kd_loss_type}")
|
||||
|
||||
def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
|
||||
"""Gather a tensor from all processes in a distributed setting.
|
||||
|
||||
Args:
|
||||
t (Optional[torch.Tensor]): The input tensor to be gathered. If `None`, no gathering is performed.
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, None]: A concatenated tensor from all processes if ``t`` is not ``None``,
|
||||
otherwise returns ``None``.
|
||||
"""
|
||||
if t is None:
|
||||
return None
|
||||
t = t.contiguous()
|
||||
|
||||
all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
|
||||
dist.all_gather(all_tensors, t)
|
||||
|
||||
all_tensors[self.process_rank] = t
|
||||
all_tensors = torch.cat(all_tensors, dim=0)
|
||||
|
||||
return all_tensors
|
||||
|
|
@ -0,0 +1,150 @@
|
|||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
from abc import ABC, abstractmethod
|
||||
from transformers import set_seed, PreTrainedTokenizer
|
||||
|
||||
|
||||
from .AbsArguments import (
|
||||
AbsEmbedderModelArguments,
|
||||
AbsEmbedderDataArguments,
|
||||
AbsEmbedderTrainingArguments
|
||||
)
|
||||
from .AbsTrainer import AbsEmbedderTrainer
|
||||
from .AbsModeling import AbsEmbedderModel
|
||||
from .AbsDataset import (
|
||||
AbsEmbedderTrainDataset, AbsEmbedderCollator,
|
||||
AbsEmbedderSameDatasetTrainDataset, AbsEmbedderSameDatasetCollator
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbsEmbedderRunner(ABC):
|
||||
"""Abstract class to run embedding model fine-tuning.
|
||||
|
||||
Args:
|
||||
model_args (AbsEmbedderModelArguments): Model arguments
|
||||
data_args (AbsEmbedderDataArguments): Data arguments.
|
||||
training_args (AbsEmbedderTrainingArguments): Training arguments.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model_args: AbsEmbedderModelArguments,
|
||||
data_args: AbsEmbedderDataArguments,
|
||||
training_args: AbsEmbedderTrainingArguments
|
||||
):
|
||||
self.model_args = model_args
|
||||
self.data_args = data_args
|
||||
self.training_args = training_args
|
||||
|
||||
if (
|
||||
os.path.exists(training_args.output_dir)
|
||||
and os.listdir(training_args.output_dir)
|
||||
and training_args.do_train
|
||||
and not training_args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
training_args.local_rank,
|
||||
training_args.device,
|
||||
training_args.n_gpu,
|
||||
bool(training_args.local_rank != -1),
|
||||
training_args.fp16,
|
||||
)
|
||||
logger.info("Training/evaluation parameters %s", training_args)
|
||||
logger.info("Model parameters %s", model_args)
|
||||
logger.info("Data parameters %s", data_args)
|
||||
|
||||
# Set seed
|
||||
set_seed(training_args.seed)
|
||||
|
||||
self.tokenizer, self.model = self.load_tokenizer_and_model()
|
||||
self.train_dataset = self.load_train_dataset()
|
||||
self.data_collator = self.load_data_collator()
|
||||
self.trainer = self.load_trainer()
|
||||
|
||||
@abstractmethod
|
||||
def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderModel]:
|
||||
"""Abstract method to load the tokenizer and model.
|
||||
|
||||
Returns:
|
||||
Tuple[PreTrainedTokenizer, AbsEmbedderModel]: Loaded tokenizer and model instances.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_trainer(self) -> AbsEmbedderTrainer:
|
||||
"""Abstract method to load the trainer.
|
||||
|
||||
Returns:
|
||||
AbsEmbedderTrainer: The loaded trainer instance.
|
||||
"""
|
||||
pass
|
||||
|
||||
def load_train_dataset(self) -> AbsEmbedderTrainDataset:
|
||||
"""Loads the training dataset based on data arguments.
|
||||
|
||||
Returns:
|
||||
AbsEmbedderTrainDataset: The loaded dataset instance.
|
||||
"""
|
||||
if self.data_args.same_dataset_within_batch:
|
||||
train_dataset = AbsEmbedderSameDatasetTrainDataset(
|
||||
args=self.data_args,
|
||||
default_batch_size=self.training_args.per_device_train_batch_size,
|
||||
seed=self.training_args.seed,
|
||||
tokenizer=self.tokenizer,
|
||||
process_index=self.training_args.process_index,
|
||||
num_processes=self.training_args.world_size
|
||||
)
|
||||
self.training_args.per_device_train_batch_size = 1
|
||||
self.training_args.dataloader_num_workers = 0 # avoid multi-processing
|
||||
else:
|
||||
train_dataset = AbsEmbedderTrainDataset(
|
||||
args=self.data_args,
|
||||
tokenizer=self.tokenizer
|
||||
)
|
||||
return train_dataset
|
||||
|
||||
def load_data_collator(self) -> AbsEmbedderCollator:
|
||||
"""Loads the appropriate data collator.
|
||||
|
||||
Returns:
|
||||
AbsEmbedderCollator: Loaded data collator.
|
||||
"""
|
||||
if self.data_args.same_dataset_within_batch:
|
||||
EmbedCollator = AbsEmbedderSameDatasetCollator
|
||||
else:
|
||||
EmbedCollator = AbsEmbedderCollator
|
||||
|
||||
data_collator = EmbedCollator(
|
||||
tokenizer=self.tokenizer,
|
||||
query_max_len=self.data_args.query_max_len,
|
||||
passage_max_len=self.data_args.passage_max_len,
|
||||
sub_batch_size=self.training_args.sub_batch_size,
|
||||
pad_to_multiple_of=self.data_args.pad_to_multiple_of,
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
return data_collator
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
Executes the training process.
|
||||
"""
|
||||
Path(self.training_args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Training
|
||||
self.trainer.train(resume_from_checkpoint=self.training_args.resume_from_checkpoint)
|
||||
self.trainer.save_model()
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
from abc import ABC, abstractmethod
|
||||
from transformers.trainer import Trainer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbsEmbedderTrainer(ABC, Trainer):
|
||||
"""
|
||||
Abstract class for the trainer of embedder.
|
||||
"""
|
||||
@abstractmethod
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
||||
pass
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
||||
"""
|
||||
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
||||
|
||||
Subclass and override for custom behavior.
|
||||
|
||||
Args:
|
||||
model (AbsEmbedderModel): The model being trained.
|
||||
inputs (dict): A dictionary of input tensors to be passed to the model.
|
||||
return_outputs (bool, optional): If ``True``, returns both the loss and the model's outputs. Otherwise,
|
||||
returns only the loss.
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, tuple(torch.Tensor, EmbedderOutput)]: The computed loss. If ``return_outputs`` is ``True``,
|
||||
also returns the model's outputs in a tuple ``(loss, outputs)``.
|
||||
"""
|
||||
|
||||
outputs = model(**inputs)
|
||||
loss = outputs.loss
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
from .AbsArguments import (
|
||||
AbsEmbedderDataArguments,
|
||||
AbsEmbedderModelArguments,
|
||||
AbsEmbedderTrainingArguments,
|
||||
)
|
||||
from .AbsDataset import (
|
||||
AbsEmbedderCollator, AbsEmbedderSameDatasetCollator,
|
||||
AbsEmbedderSameDatasetTrainDataset,
|
||||
AbsEmbedderTrainDataset,
|
||||
EmbedderTrainerCallbackForDataRefresh,
|
||||
)
|
||||
from .AbsModeling import AbsEmbedderModel, EmbedderOutput
|
||||
from .AbsTrainer import AbsEmbedderTrainer
|
||||
from .AbsRunner import AbsEmbedderRunner
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AbsEmbedderModelArguments",
|
||||
"AbsEmbedderDataArguments",
|
||||
"AbsEmbedderTrainingArguments",
|
||||
"AbsEmbedderModel",
|
||||
"AbsEmbedderTrainer",
|
||||
"AbsEmbedderRunner",
|
||||
"AbsEmbedderTrainDataset",
|
||||
"AbsEmbedderCollator",
|
||||
"AbsEmbedderSameDatasetTrainDataset",
|
||||
"AbsEmbedderSameDatasetCollator",
|
||||
"EmbedderOutput",
|
||||
"EmbedderTrainerCallbackForDataRefresh",
|
||||
]
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from transformers import TrainingArguments
|
||||
|
||||
|
||||
@dataclass
|
||||
class AbsRerankerModelArguments:
|
||||
"""
|
||||
Abstract class for reranker model arguments.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "The model checkpoint for initialization."}
|
||||
)
|
||||
config_name: str = field(
|
||||
default=None,
|
||||
metadata={"help": "Pretrained config name or path if not the same as model_name."}
|
||||
)
|
||||
tokenizer_name: str = field(
|
||||
default=None,
|
||||
metadata={"help": "Pretrained tokenizer name or path if not the same as model_name."}
|
||||
)
|
||||
cache_dir: str = field(
|
||||
default=None,
|
||||
metadata={"help": "Where do you want to store the pre-trained models downloaded from s3."}
|
||||
)
|
||||
trust_remote_code: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Trust remote code"}
|
||||
)
|
||||
model_type: str = field(
|
||||
default='encoder',
|
||||
metadata={"help": "Type of finetune, ['encoder', 'decoder']"}
|
||||
)
|
||||
token: str = field(
|
||||
default_factory=lambda: os.getenv('HF_TOKEN', None),
|
||||
metadata={"help": "The token to use when accessing the model."}
|
||||
)
|
||||
# finetune_type: str = field(
|
||||
# default='sratch',
|
||||
# metadata={"help": "Type of finetune, ['sratch', 'finetune']"}
|
||||
# )
|
||||
|
||||
|
||||
@dataclass
|
||||
class AbsRerankerDataArguments:
|
||||
"""
|
||||
Abstract class for reranker data arguments.
|
||||
"""
|
||||
train_data: str = field(
|
||||
default=None, metadata={
|
||||
"help": "One or more paths to training data. `query: str`, `pos: List[str]`, `neg: List[str]` are required in the training data.",
|
||||
"nargs": "+"
|
||||
}
|
||||
)
|
||||
cache_path: Optional[str] = field(
|
||||
default=None, metadata={"help": "Where do you want to store the cached data"}
|
||||
)
|
||||
train_group_size: int = field(default=8)
|
||||
|
||||
query_max_len: int = field(
|
||||
default=32,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer than this will be truncated."
|
||||
},
|
||||
)
|
||||
|
||||
passage_max_len: int = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer than this will be truncated."
|
||||
},
|
||||
)
|
||||
|
||||
max_len: int = field(
|
||||
default=512,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated."
|
||||
},
|
||||
)
|
||||
|
||||
pad_to_multiple_of: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set will pad the sequence to be a multiple of the provided value."
|
||||
},
|
||||
)
|
||||
|
||||
max_example_num_per_dataset: int = field(
|
||||
default=100000000, metadata={"help": "the max number of examples for each dataset"}
|
||||
)
|
||||
|
||||
query_instruction_for_rerank: str= field(
|
||||
default=None, metadata={"help": "instruction for query"}
|
||||
)
|
||||
query_instruction_format: str = field(
|
||||
default="{}{}", metadata={"help": "format for query instruction"}
|
||||
)
|
||||
|
||||
knowledge_distillation: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use knowledge distillation when `pos_scores: List[float]` and `neg_scores: List[float]` are in features of training data"}
|
||||
)
|
||||
|
||||
passage_instruction_for_rerank: Optional[str] = field(
|
||||
default=None, metadata={"help": "instruction for passage"}
|
||||
)
|
||||
passage_instruction_format: Optional[str] = field(
|
||||
default="{}{}", metadata={"help": "format for passage instruction"}
|
||||
)
|
||||
|
||||
shuffle_ratio: float = field(
|
||||
default=0.0, metadata={"help": "The ratio of shuffling the text"}
|
||||
)
|
||||
|
||||
sep_token: str = field(
|
||||
default='\n', metadata={"help": "The sep token for LLM reranker to discriminate between query and passage"}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# replace "\\n" with "\n"
|
||||
if "\\n" in self.query_instruction_format:
|
||||
self.query_instruction_format = self.query_instruction_format.replace("\\n", "\n")
|
||||
if "\\n" in self.passage_instruction_format:
|
||||
self.passage_instruction_format = self.passage_instruction_format.replace("\\n", "\n")
|
||||
|
||||
# check the existence of train data
|
||||
for train_dir in self.train_data:
|
||||
if not os.path.exists(train_dir):
|
||||
raise FileNotFoundError(f"cannot find file: {train_dir}, please set a true path")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AbsRerankerTrainingArguments(TrainingArguments):
|
||||
sub_batch_size: Optional[int] = field(default=None, metadata={"help": "sub batch size for training, not implemented yet"})
|
||||
|
|
@ -0,0 +1,400 @@
|
|||
import os
|
||||
import math
|
||||
import random
|
||||
import logging
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch.distributed as dist
|
||||
from dataclasses import dataclass
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import (
|
||||
PreTrainedTokenizer,
|
||||
DataCollatorWithPadding,
|
||||
BatchEncoding,
|
||||
DataCollatorForSeq2Seq
|
||||
)
|
||||
from typing import List
|
||||
|
||||
from .AbsArguments import AbsRerankerDataArguments
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbsRerankerTrainDataset(Dataset):
|
||||
"""Abstract class for reranker training dataset.
|
||||
|
||||
Args:
|
||||
args (AbsRerankerDataArguments): Data arguments.
|
||||
tokenizer (PreTrainedTokenizer): Tokenizer to use.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
args: AbsRerankerDataArguments,
|
||||
tokenizer: PreTrainedTokenizer
|
||||
):
|
||||
self.args = args
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
train_datasets = []
|
||||
for data_dir in args.train_data:
|
||||
if not os.path.isdir(data_dir):
|
||||
if not (data_dir.endswith('.json') or data_dir.endswith('.jsonl')): continue
|
||||
temp_dataset = self._load_dataset(data_dir)
|
||||
if len(temp_dataset) == 0: continue
|
||||
train_datasets.append(temp_dataset)
|
||||
else:
|
||||
for file in os.listdir(data_dir):
|
||||
if not (file.endswith('.json') or file.endswith('.jsonl')): continue
|
||||
temp_dataset = self._load_dataset(os.path.join(data_dir, file))
|
||||
if len(temp_dataset) == 0: continue
|
||||
train_datasets.append(temp_dataset)
|
||||
self.dataset = datasets.concatenate_datasets(train_datasets)
|
||||
|
||||
self.max_length = self.args.query_max_len + self.args.passage_max_len
|
||||
|
||||
def _load_dataset(self, file_path: str):
|
||||
"""Load dataset from path.
|
||||
|
||||
Args:
|
||||
file_path (str): Path to load the datasets from.
|
||||
|
||||
Raises:
|
||||
ValueError: `pos_scores` and `neg_scores` not found in the features of training data
|
||||
|
||||
Returns:
|
||||
datasets.Dataset: Loaded HF dataset.
|
||||
"""
|
||||
if dist.get_rank() == 0:
|
||||
logger.info(f'loading data from {file_path} ...')
|
||||
|
||||
temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=self.args.cache_path)
|
||||
if len(temp_dataset) > self.args.max_example_num_per_dataset:
|
||||
temp_dataset = temp_dataset.select(random.sample(list(range(len(temp_dataset))), self.args.max_example_num_per_dataset))
|
||||
if not self.args.knowledge_distillation:
|
||||
if 'pos_scores' in temp_dataset.column_names:
|
||||
temp_dataset = temp_dataset.remove_columns(['pos_scores'])
|
||||
if 'neg_scores' in temp_dataset.column_names:
|
||||
temp_dataset = temp_dataset.remove_columns(['neg_scores'])
|
||||
else:
|
||||
if 'pos_scores' not in temp_dataset.column_names or 'neg_scores' not in temp_dataset.column_names:
|
||||
raise ValueError(f"`pos_scores` and `neg_scores` not found in the features of training data in {file_path}, which is necessary when using knowledge distillation.")
|
||||
return temp_dataset
|
||||
|
||||
def _shuffle_text(self, text):
|
||||
"""shuffle the input text.
|
||||
|
||||
Args:
|
||||
text (str): Input text.
|
||||
|
||||
Returns:
|
||||
str: Shuffled text.
|
||||
"""
|
||||
if self.args.shuffle_ratio > 0 and len(text) > 100 and random.random() < self.args.shuffle_ratio:
|
||||
split_text = []
|
||||
chunk_size = len(text)//3 + 1
|
||||
for i in range(0, len(text), chunk_size):
|
||||
split_text.append(text[i:i+chunk_size])
|
||||
random.shuffle(split_text)
|
||||
return " ".join(split_text)
|
||||
else:
|
||||
return text
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def create_one_example(self, qry_encoding: str, doc_encoding: str):
|
||||
"""Creates a single input example by encoding and preparing a query and document pair for the model.
|
||||
|
||||
Args:
|
||||
qry_encoding (str): Query to be encoded.
|
||||
doc_encoding (str): Document to be encoded.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing tokenized and prepared inputs, ready for model consumption.
|
||||
"""
|
||||
qry_inputs = self.tokenizer.encode(qry_encoding, truncation=True, max_length=self.args.query_max_len + self.args.passage_max_len // 4, add_special_tokens=False)
|
||||
doc_inputs = self.tokenizer.encode(doc_encoding, truncation=True, max_length=self.args.passage_max_len + self.args.query_max_len // 2, add_special_tokens=False)
|
||||
item = self.tokenizer.prepare_for_model(
|
||||
qry_inputs,
|
||||
doc_inputs,
|
||||
truncation='only_second',
|
||||
max_length=self.args.query_max_len + self.args.passage_max_len,
|
||||
padding=False,
|
||||
)
|
||||
return item
|
||||
|
||||
def __getitem__(self, item):
|
||||
data = self.dataset[item]
|
||||
train_group_size = self.args.train_group_size
|
||||
|
||||
query = data['query']
|
||||
if self.args.query_instruction_for_rerank is not None:
|
||||
query = self.args.query_instruction_format.format(
|
||||
data['query_prompt'] if 'query_prompt' in data else self.args.query_instruction_for_rerank,
|
||||
query
|
||||
)
|
||||
|
||||
passages = []
|
||||
teacher_scores = []
|
||||
|
||||
assert isinstance(data['pos'], list) and isinstance(data['neg'], list)
|
||||
|
||||
pos_idx = random.choice(list(range(len(data['pos']))))
|
||||
passages.append(self._shuffle_text(data['pos'][pos_idx]))
|
||||
|
||||
neg_all_idx = list(range(len(data['neg'])))
|
||||
if len(data['neg']) < train_group_size - 1:
|
||||
num = math.ceil((train_group_size - 1) / len(data['neg']))
|
||||
neg_idxs = random.sample(neg_all_idx * num, train_group_size - 1)
|
||||
else:
|
||||
neg_idxs = random.sample(neg_all_idx, self.args.train_group_size - 1)
|
||||
for neg_idx in neg_idxs:
|
||||
passages.append(data['neg'][neg_idx])
|
||||
|
||||
if self.args.knowledge_distillation:
|
||||
assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores'], list)
|
||||
teacher_scores.append(data['pos_scores'][pos_idx])
|
||||
for neg_idx in neg_idxs:
|
||||
teacher_scores.append(data['neg_scores'][neg_idx])
|
||||
if not all(isinstance(score, (int, float)) for score in teacher_scores):
|
||||
raise ValueError(f"pos_score or neg_score must be digit")
|
||||
else:
|
||||
teacher_scores = None
|
||||
|
||||
if self.args.passage_instruction_for_rerank is not None:
|
||||
passages = [
|
||||
self.args.passage_instruction_format.format(
|
||||
data['passage_prompt'] if 'passage_prompt' in data else self.args.passage_instruction_for_rerank, p
|
||||
)
|
||||
for p in passages
|
||||
]
|
||||
|
||||
batch_data = []
|
||||
for passage in passages:
|
||||
batch_data.append(self.create_one_example(query, passage))
|
||||
|
||||
return batch_data, teacher_scores
|
||||
|
||||
@dataclass
|
||||
class AbsRerankerCollator(DataCollatorWithPadding):
|
||||
"""
|
||||
The abstract reranker collator.
|
||||
"""
|
||||
query_max_len: int = 32
|
||||
passage_max_len: int = 128
|
||||
|
||||
def __call__(self, features) -> List[BatchEncoding]:
|
||||
teacher_scores = [f[1] for f in features]
|
||||
if teacher_scores[0] is None:
|
||||
teacher_scores = None
|
||||
elif isinstance(teacher_scores[0], list):
|
||||
teacher_scores = sum(teacher_scores, [])
|
||||
|
||||
features = [f[0] for f in features]
|
||||
if isinstance(features[0], list):
|
||||
features = sum(features, [])
|
||||
|
||||
collated = self.tokenizer.pad(
|
||||
features,
|
||||
padding=self.padding,
|
||||
max_length=self.query_max_len + self.passage_max_len,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=self.return_tensors,
|
||||
)
|
||||
|
||||
return {
|
||||
"pair": collated,
|
||||
"teacher_scores": teacher_scores,
|
||||
}
|
||||
|
||||
class AbsLLMRerankerTrainDataset(AbsRerankerTrainDataset):
|
||||
"""Abstract class for LLM reranker training dataset.
|
||||
|
||||
Args:
|
||||
args (AbsRerankerDataArguments): Data arguments.
|
||||
tokenizer (PreTrainedTokenizer): Tokenizer to use.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
args: AbsRerankerDataArguments,
|
||||
tokenizer: PreTrainedTokenizer
|
||||
):
|
||||
super().__init__(args, tokenizer)
|
||||
sep = self.args.sep_token
|
||||
self.sep_inputs = self.tokenizer(
|
||||
sep,
|
||||
return_tensors=None,
|
||||
add_special_tokens=False
|
||||
)['input_ids']
|
||||
|
||||
def __getitem__(self, item) -> List[BatchEncoding]:
|
||||
data = self.dataset[item]
|
||||
train_group_size = self.args.train_group_size
|
||||
|
||||
query = data['query']
|
||||
if self.args.query_instruction_for_rerank is not None:
|
||||
query = self.args.query_instruction_format.format(
|
||||
data['query_prompt'] if 'query_prompt' in data else self.args.query_instruction_for_rerank,
|
||||
query
|
||||
)
|
||||
|
||||
passages = []
|
||||
teacher_scores = []
|
||||
|
||||
assert isinstance(data['pos'], list) and isinstance(data['neg'], list)
|
||||
|
||||
pos_idx = random.choice(list(range(len(data['pos']))))
|
||||
passages.append(self._shuffle_text(data['pos'][pos_idx]))
|
||||
|
||||
neg_all_idx = list(range(len(data['neg'])))
|
||||
if len(data['neg']) < train_group_size - 1:
|
||||
num = math.ceil((train_group_size - 1) / len(data['neg']))
|
||||
neg_idxs = random.sample(neg_all_idx * num, train_group_size - 1)
|
||||
else:
|
||||
neg_idxs = random.sample(neg_all_idx, self.args.train_group_size - 1)
|
||||
for neg_idx in neg_idxs:
|
||||
passages.append(data['neg'][neg_idx])
|
||||
|
||||
if self.args.knowledge_distillation:
|
||||
assert isinstance(data['pos_scores'], list) and isinstance(data['neg_scores'], list)
|
||||
teacher_scores.append(data['pos_scores'][pos_idx])
|
||||
for neg_idx in neg_idxs:
|
||||
teacher_scores.append(data['neg_scores'][neg_idx])
|
||||
if not all(isinstance(score, (int, float)) for score in teacher_scores):
|
||||
raise ValueError(f"pos_score or neg_score must be digit")
|
||||
else:
|
||||
teacher_scores = None
|
||||
|
||||
if self.args.passage_instruction_for_rerank is not None:
|
||||
passages = [
|
||||
self.args.passage_instruction_format.format(
|
||||
data['passage_prompt'] if 'passage_prompt' in data else self.args.passage_instruction_for_rerank, p
|
||||
)
|
||||
for p in passages
|
||||
]
|
||||
|
||||
prompt = self.dataset[item].get('prompt', "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'.")
|
||||
|
||||
query_inputs = self.tokenizer(
|
||||
query,
|
||||
return_tensors=None,
|
||||
max_length=self.args.query_max_len + self.args.passage_max_len // 4,
|
||||
truncation=True,
|
||||
add_special_tokens=False
|
||||
)
|
||||
|
||||
prompt_inputs = self.tokenizer(
|
||||
prompt,
|
||||
return_tensors=None,
|
||||
add_special_tokens=False
|
||||
)['input_ids']
|
||||
|
||||
max_length = self.max_length - len(prompt_inputs) - len(self.sep_inputs)
|
||||
|
||||
passages_inputs = []
|
||||
for i, passage in enumerate(passages):
|
||||
passage_inputs = self.tokenizer(
|
||||
passage,
|
||||
return_tensors=None,
|
||||
max_length=self.args.passage_max_len + self.args.query_max_len // 2,
|
||||
truncation=True,
|
||||
add_special_tokens=False
|
||||
)
|
||||
if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.tokenizer.pad_token_id:
|
||||
item = self.tokenizer.prepare_for_model(
|
||||
[self.tokenizer.bos_token_id] + query_inputs['input_ids'],
|
||||
self.sep_inputs + passage_inputs['input_ids'],
|
||||
truncation='only_second',
|
||||
max_length=max_length,
|
||||
padding=False,
|
||||
return_attention_mask=False,
|
||||
return_token_type_ids=False,
|
||||
add_special_tokens=False
|
||||
)
|
||||
else:
|
||||
item = self.tokenizer.prepare_for_model(
|
||||
query_inputs['input_ids'],
|
||||
self.sep_inputs + passage_inputs['input_ids'],
|
||||
truncation='only_second',
|
||||
max_length=max_length,
|
||||
padding=False,
|
||||
return_attention_mask=False,
|
||||
return_token_type_ids=False,
|
||||
add_special_tokens=False
|
||||
)
|
||||
|
||||
passage_inputs['input_ids'] = item['input_ids'] + self.sep_inputs + prompt_inputs
|
||||
|
||||
passage_inputs['attention_mask'] = [1] * len(passage_inputs['input_ids'])
|
||||
# passage_inputs['labels'] = passage_inputs['input_ids'].copy()
|
||||
# passage_inputs['labels'] = [-100] * (len(passage_inputs['input_ids']) - 1) + passage_inputs['labels'][(len(passage_inputs['input_ids']) - 1):]
|
||||
passage_inputs.pop('token_type_ids') if 'token_type_ids' in passage_inputs.keys() else None
|
||||
if 'position_ids' in passage_inputs.keys():
|
||||
passage_inputs['position_ids'] = list(range(len(passage_inputs['input_ids'])))
|
||||
passages_inputs.append(passage_inputs)
|
||||
|
||||
return passages_inputs, teacher_scores
|
||||
|
||||
|
||||
@dataclass
|
||||
class AbsLLMRerankerCollator(DataCollatorForSeq2Seq):
|
||||
"""
|
||||
Wrapper that does conversion from List[Tuple[encode_qry, encode_psg]] to List[qry], List[psg]
|
||||
and pass batch separately to the actual collator.
|
||||
Abstract out data detail for the model.
|
||||
"""
|
||||
query_max_len: int = 32
|
||||
passage_max_len: int = 128
|
||||
|
||||
def __call__(self, features, return_tensors='pt'):
|
||||
if return_tensors is None:
|
||||
return_tensors = self.return_tensors
|
||||
|
||||
teacher_scores = [f[1] for f in features]
|
||||
if teacher_scores[0] is None:
|
||||
teacher_scores = None
|
||||
elif isinstance(teacher_scores[0], list):
|
||||
teacher_scores = sum(teacher_scores, [])
|
||||
|
||||
features = [f[0] for f in features]
|
||||
if isinstance(features[0], list):
|
||||
features = sum(features, [])
|
||||
|
||||
labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None
|
||||
# We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
|
||||
# same length to return tensors.
|
||||
if labels is not None:
|
||||
max_label_length = max(len(l) for l in labels)
|
||||
# print(max_label_length)
|
||||
if self.pad_to_multiple_of is not None:
|
||||
max_label_length = (
|
||||
(max_label_length + self.pad_to_multiple_of - 1)
|
||||
// self.pad_to_multiple_of
|
||||
* self.pad_to_multiple_of
|
||||
)
|
||||
|
||||
padding_side = self.tokenizer.padding_side
|
||||
for feature in features:
|
||||
remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"]))
|
||||
if isinstance(feature["labels"], list):
|
||||
feature["labels"] = (
|
||||
feature["labels"] + remainder
|
||||
if padding_side == "right" else remainder + feature["labels"]
|
||||
)
|
||||
elif padding_side == "right":
|
||||
feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64)
|
||||
else:
|
||||
feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)
|
||||
|
||||
collated = self.tokenizer.pad(
|
||||
features,
|
||||
padding=self.padding,
|
||||
max_length=self.query_max_len + self.passage_max_len,
|
||||
return_tensors=return_tensors,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
)
|
||||
|
||||
return {
|
||||
"pair": collated,
|
||||
"teacher_scores": teacher_scores,
|
||||
}
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
import torch
|
||||
from torch import nn, Tensor
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers.file_utils import ModelOutput
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional, List, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RerankerOutput(ModelOutput):
|
||||
loss: Optional[Tensor] = None
|
||||
scores: Optional[Tensor] = None
|
||||
|
||||
|
||||
class AbsRerankerModel(ABC, nn.Module):
|
||||
"""Abstract class of embedding model for training.
|
||||
|
||||
Args:
|
||||
base_model: The base model to train on.
|
||||
tokenizer (PreTrainedTokenizer, optional): The tokenizer to use. Defaults to ``None``.
|
||||
train_batch_size (int, optional): Batch size used for training. Defaults to ``4``.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
base_model: None,
|
||||
tokenizer: PreTrainedTokenizer = None,
|
||||
train_batch_size: int = 4,
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
self.model = base_model
|
||||
self.tokenizer = tokenizer
|
||||
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
|
||||
|
||||
if self.model.config.pad_token_id is None:
|
||||
self.model.config.pad_token_id = self.tokenizer.pad_token_id
|
||||
self.config = self.model.config
|
||||
|
||||
self.train_batch_size = train_batch_size
|
||||
|
||||
self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][-1]
|
||||
|
||||
def gradient_checkpointing_enable(self, **kwargs):
|
||||
"""
|
||||
Activates gradient checkpointing for the current model.
|
||||
"""
|
||||
self.model.gradient_checkpointing_enable(**kwargs)
|
||||
|
||||
def enable_input_require_grads(self, **kwargs):
|
||||
"""
|
||||
Enables the gradients for the input embeddings.
|
||||
"""
|
||||
self.model.enable_input_require_grads(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def encode(self, features):
|
||||
"""Abstract method of encode.
|
||||
|
||||
Args:
|
||||
features (dict): Teatures to pass to the model.
|
||||
"""
|
||||
pass
|
||||
|
||||
def forward(self, pair: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None, teacher_scores: Optional[Tensor] = None):
|
||||
"""The computation performed at every call.
|
||||
|
||||
Args:
|
||||
pair (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): The query-document pair. Defaults to ``None``.
|
||||
teacher_scores (Optional[Tensor], optional): Teacher scores of knowledge distillation. Defaults to None.
|
||||
|
||||
Returns:
|
||||
RerankerOutput: Output of reranker model.
|
||||
"""
|
||||
ranker_logits = self.encode(pair) # (batch_size * num, dim)
|
||||
if teacher_scores is not None:
|
||||
teacher_scores = torch.Tensor(teacher_scores)
|
||||
teacher_targets = teacher_scores.view(self.train_batch_size, -1)
|
||||
teacher_targets = torch.softmax(teacher_targets.detach(), dim=-1)
|
||||
|
||||
if self.training:
|
||||
grouped_logits = ranker_logits.view(self.train_batch_size, -1)
|
||||
target = torch.zeros(self.train_batch_size, device=grouped_logits.device, dtype=torch.long)
|
||||
loss = self.compute_loss(grouped_logits, target)
|
||||
if teacher_scores is not None:
|
||||
teacher_targets = teacher_targets.to(grouped_logits.device)
|
||||
# print(teacher_targets, torch.mean(torch.sum(torch.log_softmax(grouped_logits, dim=-1) * teacher_targets, dim=-1)))
|
||||
loss += - torch.mean(torch.sum(torch.log_softmax(grouped_logits, dim=-1) * teacher_targets, dim=-1))
|
||||
else:
|
||||
loss = None
|
||||
|
||||
# print(loss)
|
||||
return RerankerOutput(
|
||||
loss=loss,
|
||||
scores=ranker_logits,
|
||||
)
|
||||
|
||||
def compute_loss(self, scores, target):
|
||||
"""Compute the loss.
|
||||
|
||||
Args:
|
||||
scores (torch.Tensor): Computed scores.
|
||||
target (torch.Tensor): The target value.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The computed loss.
|
||||
"""
|
||||
return self.cross_entropy(scores, target)
|
||||
|
||||
def save(self, output_dir: str):
|
||||
"""Save the model.
|
||||
|
||||
Args:
|
||||
output_dir (str): Directory for saving the model.
|
||||
"""
|
||||
# self.model.save_pretrained(output_dir)
|
||||
state_dict = self.model.state_dict()
|
||||
state_dict = type(state_dict)(
|
||||
{k: v.clone().cpu()
|
||||
for k,
|
||||
v in state_dict.items()})
|
||||
self.model.save_pretrained(output_dir, state_dict=state_dict)
|
||||
|
||||
def save_pretrained(self, *args, **kwargs):
|
||||
"""
|
||||
Save the tokenizer and model.
|
||||
"""
|
||||
self.tokenizer.save_pretrained(*args, **kwargs)
|
||||
return self.model.save_pretrained(*args, **kwargs)
|
||||
|
|
@ -0,0 +1,143 @@
|
|||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
from abc import ABC, abstractmethod
|
||||
from transformers import set_seed, PreTrainedTokenizer
|
||||
|
||||
|
||||
from .AbsArguments import (
|
||||
AbsRerankerModelArguments,
|
||||
AbsRerankerDataArguments,
|
||||
AbsRerankerTrainingArguments
|
||||
)
|
||||
from .AbsTrainer import AbsRerankerTrainer
|
||||
from .AbsModeling import AbsRerankerModel
|
||||
from .AbsDataset import (
|
||||
AbsRerankerTrainDataset, AbsRerankerCollator,
|
||||
AbsLLMRerankerTrainDataset, AbsLLMRerankerCollator
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbsRerankerRunner(ABC):
|
||||
"""Abstract class to run reranker model fine-tuning.
|
||||
|
||||
Args:
|
||||
model_args (AbsRerankerModelArguments): Model arguments
|
||||
data_args (AbsRerankerDataArguments): Data arguments.
|
||||
training_args (AbsRerankerTrainingArguments): Training arguments.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model_args: AbsRerankerModelArguments,
|
||||
data_args: AbsRerankerDataArguments,
|
||||
training_args: AbsRerankerTrainingArguments
|
||||
):
|
||||
self.model_args = model_args
|
||||
self.data_args = data_args
|
||||
self.training_args = training_args
|
||||
|
||||
if (
|
||||
os.path.exists(training_args.output_dir)
|
||||
and os.listdir(training_args.output_dir)
|
||||
and training_args.do_train
|
||||
and not training_args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.warning(
|
||||
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||
training_args.local_rank,
|
||||
training_args.device,
|
||||
training_args.n_gpu,
|
||||
bool(training_args.local_rank != -1),
|
||||
training_args.fp16,
|
||||
)
|
||||
logger.info("Training/evaluation parameters %s", training_args)
|
||||
logger.info("Model parameters %s", model_args)
|
||||
logger.info("Data parameters %s", data_args)
|
||||
|
||||
# Set seed
|
||||
set_seed(training_args.seed)
|
||||
|
||||
self.tokenizer, self.model = self.load_tokenizer_and_model()
|
||||
self.train_dataset = self.load_train_dataset()
|
||||
self.data_collator = self.load_data_collator()
|
||||
self.trainer = self.load_trainer()
|
||||
|
||||
@abstractmethod
|
||||
def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsRerankerModel]:
|
||||
"""Abstract method to load the tokenizer and model.
|
||||
|
||||
Returns:
|
||||
Tuple[PreTrainedTokenizer, AbsRerankerModel]: Loaded tokenizer and model instances.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_trainer(self) -> AbsRerankerTrainer:
|
||||
"""Abstract method to load the trainer.
|
||||
|
||||
Returns:
|
||||
AbsRerankerTrainer: The loaded trainer instance.
|
||||
"""
|
||||
pass
|
||||
|
||||
def load_train_dataset(self) -> AbsRerankerTrainDataset:
|
||||
"""Loads the training dataset based on data arguments.
|
||||
|
||||
Returns:
|
||||
AbsRerankerTrainDataset: The loaded dataset instance.
|
||||
"""
|
||||
if self.model_args.model_type == 'encoder':
|
||||
train_dataset = AbsRerankerTrainDataset(
|
||||
args=self.data_args,
|
||||
tokenizer=self.tokenizer
|
||||
)
|
||||
else:
|
||||
train_dataset = AbsLLMRerankerTrainDataset(
|
||||
args=self.data_args,
|
||||
tokenizer=self.tokenizer
|
||||
)
|
||||
return train_dataset
|
||||
|
||||
def load_data_collator(self) -> AbsRerankerCollator:
|
||||
"""Loads the appropriate data collator.
|
||||
|
||||
Returns:
|
||||
AbsRerankerCollator: Loaded data collator.
|
||||
"""
|
||||
if self.model_args.model_type == 'encoder':
|
||||
RerankerCollator = AbsRerankerCollator
|
||||
else:
|
||||
RerankerCollator = AbsLLMRerankerCollator
|
||||
|
||||
data_collator = RerankerCollator(
|
||||
tokenizer=self.tokenizer,
|
||||
query_max_len=self.data_args.query_max_len,
|
||||
passage_max_len=self.data_args.passage_max_len,
|
||||
pad_to_multiple_of=self.data_args.pad_to_multiple_of,
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
return data_collator
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
Executes the training process.
|
||||
"""
|
||||
Path(self.training_args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Training
|
||||
self.trainer.train(resume_from_checkpoint=self.training_args.resume_from_checkpoint)
|
||||
self.trainer.save_model()
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
from abc import ABC, abstractmethod
|
||||
from transformers.trainer import Trainer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbsRerankerTrainer(ABC, Trainer):
|
||||
"""
|
||||
Abstract class for the trainer of reranker.
|
||||
"""
|
||||
@abstractmethod
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
||||
pass
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
||||
"""
|
||||
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
||||
|
||||
Subclass and override for custom behavior.
|
||||
|
||||
Args:
|
||||
model (AbsRerankerModel): The model being trained.
|
||||
inputs (dict): A dictionary of input tensors to be passed to the model.
|
||||
return_outputs (bool, optional): If ``True``, returns both the loss and the model's outputs. Otherwise,
|
||||
returns only the loss. Defaults to ``False``.
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, tuple(torch.Tensor, RerankerOutput)]: The computed loss. If ``return_outputs`` is ``True``,
|
||||
also returns the model's outputs in a tuple ``(loss, outputs)``.
|
||||
"""
|
||||
|
||||
outputs = model(**inputs)
|
||||
loss = outputs.loss
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
from .AbsArguments import AbsRerankerDataArguments, AbsRerankerModelArguments, AbsRerankerTrainingArguments
|
||||
from .AbsDataset import (
|
||||
AbsRerankerTrainDataset, AbsRerankerCollator,
|
||||
AbsLLMRerankerTrainDataset, AbsLLMRerankerCollator
|
||||
)
|
||||
from .AbsModeling import AbsRerankerModel, RerankerOutput
|
||||
from .AbsTrainer import AbsRerankerTrainer
|
||||
from .AbsRunner import AbsRerankerRunner
|
||||
|
||||
__all__ = [
|
||||
"AbsRerankerDataArguments",
|
||||
"AbsRerankerModelArguments",
|
||||
"AbsRerankerTrainingArguments",
|
||||
"AbsRerankerTrainDataset",
|
||||
"AbsRerankerCollator",
|
||||
"AbsLLMRerankerTrainDataset",
|
||||
"AbsLLMRerankerCollator",
|
||||
"AbsRerankerModel",
|
||||
"RerankerOutput",
|
||||
"AbsRerankerTrainer",
|
||||
"AbsRerankerRunner",
|
||||
]
|
||||
|
|
@ -0,0 +1,443 @@
|
|||
import logging
|
||||
from tqdm import tqdm, trange
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Union, List, Dict, Literal, Optional
|
||||
|
||||
import queue
|
||||
import multiprocessing as mp
|
||||
from multiprocessing import Queue
|
||||
|
||||
import math
|
||||
import gc
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import is_torch_npu_available
|
||||
|
||||
try:
|
||||
import torch_musa
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbsEmbedder(ABC):
|
||||
"""
|
||||
Base class for embedder.
|
||||
Extend this class and implement :meth:`encode_queries`, :meth:`encode_corpus`, :meth:`encode` for custom embedders.
|
||||
|
||||
Args:
|
||||
model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
|
||||
load a model from HuggingFace Hub with the name.
|
||||
normalize_embeddings (bool, optional): If True, normalize the embedding vector. Defaults to :data:`True`.
|
||||
use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
|
||||
degradation. Defaults to :data:`True`.
|
||||
query_instruction_for_retrieval: (Optional[str], optional): Query instruction for retrieval tasks, which will be used with
|
||||
with :attr:`query_instruction_format`. Defaults to :data:`None`.
|
||||
query_instruction_format: (str, optional): The template for :attr:`query_instruction_for_retrieval`. Defaults to :data:`"{}{}"`.
|
||||
devices (Optional[Union[str, int, List[str], List[int]]], optional): Devices to use for model inference. Defaults to :data:`None`.
|
||||
batch_size (int, optional): Batch size for inference. Defaults to :data:`256`.
|
||||
query_max_length (int, optional): Maximum length for query. Defaults to :data:`512`.
|
||||
passage_max_length (int, optional): Maximum length for passage. Defaults to :data:`512`.
|
||||
convert_to_numpy (bool, optional): If True, the output embedding will be a Numpy array. Otherwise, it will be a Torch Tensor.
|
||||
Defaults to :data:`True`.
|
||||
kwargs (Dict[Any], optional): Additional parameters for HuggingFace Transformers config or children classes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
normalize_embeddings: bool = True,
|
||||
use_fp16: bool = True,
|
||||
query_instruction_for_retrieval: Optional[str] = None,
|
||||
query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_retrieval
|
||||
devices: Optional[Union[str, int, List[str], List[int]]] = None,
|
||||
# inference
|
||||
batch_size: int = 256,
|
||||
query_max_length: int = 512,
|
||||
passage_max_length: int = 512,
|
||||
convert_to_numpy: bool = True,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.model_name_or_path = model_name_or_path
|
||||
self.normalize_embeddings = normalize_embeddings
|
||||
self.use_fp16 = use_fp16
|
||||
self.query_instruction_for_retrieval = query_instruction_for_retrieval
|
||||
self.query_instruction_format = query_instruction_format
|
||||
self.target_devices = self.get_target_devices(devices)
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.query_max_length = query_max_length
|
||||
self.passage_max_length = passage_max_length
|
||||
self.convert_to_numpy = convert_to_numpy
|
||||
|
||||
for k in kwargs:
|
||||
setattr(self, k, kwargs[k])
|
||||
|
||||
self.kwargs = kwargs
|
||||
|
||||
# tokenizer and model are initialized in the child class
|
||||
self.tokenizer = None
|
||||
self.model = None
|
||||
self.pool = None
|
||||
|
||||
def stop_self_pool(self):
|
||||
if self.pool is not None:
|
||||
self.stop_multi_process_pool(self.pool)
|
||||
self.pool = None
|
||||
try:
|
||||
self.model.to('cpu')
|
||||
torch.cuda.empty_cache()
|
||||
except:
|
||||
pass
|
||||
if gc is not None and callable(gc.collect):
|
||||
gc.collect()
|
||||
|
||||
@staticmethod
|
||||
def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[str]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
devices (Union[str, int, List[str], List[int]]): specified devices, can be `str`, `int`, list of `str`, or list of `int`.
|
||||
|
||||
Raises:
|
||||
ValueError: Devices should be a string or an integer or a list of strings or a list of integers.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of target devices in format.
|
||||
"""
|
||||
if devices is None:
|
||||
if torch.cuda.is_available():
|
||||
return [f"cuda:{i}" for i in range(torch.cuda.device_count())]
|
||||
elif is_torch_npu_available():
|
||||
return [f"npu:{i}" for i in range(torch.npu.device_count())]
|
||||
elif hasattr(torch, "musa") and torch.musa.is_available():
|
||||
return [f"musa:{i}" for i in range(torch.musa.device_count())]
|
||||
elif torch.backends.mps.is_available():
|
||||
try:
|
||||
return [f"mps:{i}" for i in range(torch.mps.device_count())]
|
||||
except:
|
||||
return ["mps"]
|
||||
else:
|
||||
return ["cpu"]
|
||||
elif isinstance(devices, str):
|
||||
return [devices]
|
||||
elif isinstance(devices, int):
|
||||
if hasattr(torch, "musa") and torch.musa.is_available():
|
||||
return [f"musa:{devices}"]
|
||||
else:
|
||||
return [f"cuda:{devices}"]
|
||||
elif isinstance(devices, list):
|
||||
if isinstance(devices[0], str):
|
||||
return devices
|
||||
elif isinstance(devices[0], int):
|
||||
if hasattr(torch, "musa") and torch.musa.is_available():
|
||||
return [f"musa:{device}" for device in devices]
|
||||
else:
|
||||
return [f"cuda:{device}" for device in devices]
|
||||
else:
|
||||
raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")
|
||||
else:
|
||||
raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")
|
||||
|
||||
@staticmethod
|
||||
def get_detailed_instruct(instruction_format: str, instruction: str, sentence: str):
|
||||
"""Combine the instruction and sentence along with the instruction format.
|
||||
|
||||
Args:
|
||||
instruction_format (str): Format for instruction.
|
||||
instruction (str): The text of instruction.
|
||||
sentence (str): The sentence to concatenate with.
|
||||
|
||||
Returns:
|
||||
str: The complete sentence with instruction
|
||||
"""
|
||||
if "\\n" in instruction_format:
|
||||
instruction_format = instruction_format.replace("\\n", "\n")
|
||||
return instruction_format.format(instruction, sentence)
|
||||
|
||||
def encode_queries(
|
||||
self,
|
||||
queries: Union[List[str], str],
|
||||
batch_size: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
convert_to_numpy: Optional[bool] = None,
|
||||
**kwargs: Any
|
||||
):
|
||||
"""encode the queries using the instruction if provided.
|
||||
|
||||
Args:
|
||||
queries (Union[List[str], str]): Input queries to encode.
|
||||
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
|
||||
max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
|
||||
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
|
||||
be a Torch Tensor. Defaults to :data:`None`.
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
|
||||
"""
|
||||
if batch_size is None: batch_size = self.batch_size
|
||||
if max_length is None: max_length = self.query_max_length
|
||||
if convert_to_numpy is None: convert_to_numpy = self.convert_to_numpy
|
||||
|
||||
return self.encode(
|
||||
queries,
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
convert_to_numpy=convert_to_numpy,
|
||||
instruction=self.query_instruction_for_retrieval,
|
||||
instruction_format=self.query_instruction_format,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def encode_corpus(
|
||||
self,
|
||||
corpus: Union[List[str], str],
|
||||
batch_size: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
convert_to_numpy: Optional[bool] = None,
|
||||
**kwargs: Any
|
||||
):
|
||||
"""encode the corpus using the instruction if provided.
|
||||
|
||||
Args:
|
||||
corpus (Union[List[str], str]): Input corpus to encode.
|
||||
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
|
||||
max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
|
||||
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
|
||||
be a Torch Tensor. Defaults to :data:`None`.
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, np.ndarray]: Return the embedding vectors in a numpy array or tensor.
|
||||
"""
|
||||
passage_instruction_for_retrieval = self.kwargs.get("passage_instruction_for_retrieval", None)
|
||||
passage_instruction_format = self.kwargs.get("passage_instruction_format", "{}{}")
|
||||
|
||||
if batch_size is None: batch_size = self.batch_size
|
||||
if max_length is None: max_length = self.passage_max_length
|
||||
if convert_to_numpy is None: convert_to_numpy = self.convert_to_numpy
|
||||
|
||||
return self.encode(
|
||||
corpus,
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
convert_to_numpy=convert_to_numpy,
|
||||
instruction=passage_instruction_for_retrieval,
|
||||
instruction_format=passage_instruction_format,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def encode(
|
||||
self,
|
||||
sentences: Union[List[str], str],
|
||||
batch_size: Optional[int] = None,
|
||||
max_length: Optional[int] = None,
|
||||
convert_to_numpy: Optional[bool] = None,
|
||||
instruction: Optional[str] = None,
|
||||
instruction_format: Optional[str] = None,
|
||||
**kwargs: Any
|
||||
):
|
||||
"""encode the input sentences with the embedding model.
|
||||
|
||||
Args:
|
||||
sentences (Union[List[str], str]): Input sentences to encode.
|
||||
batch_size (Optional[int], optional): Number of sentences for each iter. Defaults to :data:`None`.
|
||||
max_length (Optional[int], optional): Maximum length of tokens. Defaults to :data:`None`.
|
||||
convert_to_numpy (Optional[bool], optional): If True, the output embedding will be a Numpy array. Otherwise, it will
|
||||
be a Torch Tensor. Defaults to :data:`None`.
|
||||
instruction (Optional[str], optional): The text of instruction. Defaults to :data:`None`.
|
||||
instruction_format (Optional[str], optional): Format for instruction. Defaults to :data:`None`.
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, np.ndarray]: return the embedding vectors in a numpy array or tensor.
|
||||
"""
|
||||
if batch_size is None: batch_size = self.batch_size
|
||||
if max_length is None: max_length = self.passage_max_length
|
||||
if convert_to_numpy is None: convert_to_numpy = self.convert_to_numpy
|
||||
|
||||
if instruction is not None:
|
||||
if isinstance(sentences, str):
|
||||
sentences = self.get_detailed_instruct(instruction_format, instruction, sentences)
|
||||
else:
|
||||
sentences = [self.get_detailed_instruct(instruction_format, instruction, sentence) for sentence in
|
||||
sentences]
|
||||
|
||||
if isinstance(sentences, str) or len(self.target_devices) == 1:
|
||||
return self.encode_single_device(
|
||||
sentences,
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
convert_to_numpy=convert_to_numpy,
|
||||
device=self.target_devices[0],
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if self.pool is None:
|
||||
self.pool = self.start_multi_process_pool(AbsEmbedder._encode_multi_process_worker)
|
||||
embeddings = self.encode_multi_process(
|
||||
sentences,
|
||||
self.pool,
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
convert_to_numpy=convert_to_numpy,
|
||||
**kwargs
|
||||
)
|
||||
return embeddings
|
||||
|
||||
def __del__(self):
|
||||
self.stop_self_pool()
|
||||
|
||||
@abstractmethod
|
||||
def encode_single_device(
|
||||
self,
|
||||
sentences: Union[List[str], str],
|
||||
batch_size: int = 256,
|
||||
max_length: int = 512,
|
||||
convert_to_numpy: bool = True,
|
||||
device: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""
|
||||
This method should encode sentences and return embeddings on a single device.
|
||||
"""
|
||||
pass
|
||||
|
||||
# adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L807
|
||||
def start_multi_process_pool(
|
||||
self,
|
||||
process_target_func: Any,
|
||||
) -> Dict[Literal["input", "output", "processes"], Any]:
|
||||
"""
|
||||
Starts a multi-process pool to process the encoding with several independent processes
|
||||
via :meth:`SentenceTransformer.encode_multi_process <sentence_transformers.SentenceTransformer.encode_multi_process>`.
|
||||
|
||||
This method is recommended if you want to encode on multiple GPUs or CPUs. It is advised
|
||||
to start only one process per GPU. This method works together with encode_multi_process
|
||||
and stop_multi_process_pool.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary with the target processes, an input queue, and an output queue.
|
||||
"""
|
||||
if self.model is None:
|
||||
raise ValueError("Model is not initialized.")
|
||||
|
||||
logger.info("Start multi-process pool on devices: {}".format(", ".join(map(str, self.target_devices))))
|
||||
|
||||
self.model.to("cpu")
|
||||
self.model.share_memory()
|
||||
ctx = mp.get_context("spawn")
|
||||
input_queue = ctx.Queue()
|
||||
output_queue = ctx.Queue()
|
||||
processes = []
|
||||
|
||||
for device_id in tqdm(self.target_devices, desc='initial target device'):
|
||||
p = ctx.Process(
|
||||
target=process_target_func,
|
||||
args=(device_id, self, input_queue, output_queue),
|
||||
daemon=True,
|
||||
)
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
return {"input": input_queue, "output": output_queue, "processes": processes}
|
||||
|
||||
# adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L976
|
||||
@staticmethod
|
||||
def _encode_multi_process_worker(
|
||||
target_device: str, model: 'AbsEmbedder', input_queue: Queue, results_queue: Queue
|
||||
) -> None:
|
||||
"""
|
||||
Internal working process to encode sentences in multi-process setup
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
chunk_id, sentences, kwargs = (
|
||||
input_queue.get()
|
||||
)
|
||||
embeddings = model.encode_single_device(
|
||||
sentences,
|
||||
device=target_device,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
results_queue.put([chunk_id, embeddings])
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# copied from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L857
|
||||
@staticmethod
|
||||
def stop_multi_process_pool(pool: Dict[Literal["input", "output", "processes"], Any]) -> None:
|
||||
"""
|
||||
Stops all processes started with start_multi_process_pool.
|
||||
|
||||
Args:
|
||||
pool (Dict[str, object]): A dictionary containing the input queue, output queue, and process list.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for p in pool["processes"]:
|
||||
p.terminate()
|
||||
|
||||
for p in pool["processes"]:
|
||||
p.join()
|
||||
p.close()
|
||||
|
||||
pool["input"].close()
|
||||
pool["output"].close()
|
||||
pool = None
|
||||
|
||||
# adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L877
|
||||
def encode_multi_process(
|
||||
self,
|
||||
sentences: List[str],
|
||||
pool: Dict[Literal["input", "output", "processes"], Any],
|
||||
**kwargs
|
||||
):
|
||||
chunk_size = math.ceil(len(sentences) / len(pool["processes"]))
|
||||
|
||||
input_queue = pool["input"]
|
||||
last_chunk_id = 0
|
||||
chunk = []
|
||||
|
||||
for sentence in sentences:
|
||||
chunk.append(sentence)
|
||||
if len(chunk) >= chunk_size:
|
||||
input_queue.put(
|
||||
[last_chunk_id, chunk, kwargs]
|
||||
)
|
||||
last_chunk_id += 1
|
||||
chunk = []
|
||||
|
||||
if len(chunk) > 0:
|
||||
input_queue.put([last_chunk_id, chunk, kwargs])
|
||||
last_chunk_id += 1
|
||||
|
||||
output_queue = pool["output"]
|
||||
results_list = sorted(
|
||||
[output_queue.get() for _ in trange(last_chunk_id, desc="Chunks")],
|
||||
key=lambda x: x[0],
|
||||
)
|
||||
embeddings = self._concatenate_results_from_multi_process([result[1] for result in results_list])
|
||||
return embeddings
|
||||
|
||||
def _concatenate_results_from_multi_process(self, results_list: List[Union[torch.Tensor, np.ndarray, Any]]):
|
||||
"""concatenate and return the results from all the processes
|
||||
|
||||
Args:
|
||||
results_list (List[Union[torch.Tensor, np.ndarray, Any]]): A list of results from all the processes.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Unsupported type for results_list
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, np.ndarray]: return the embedding vectors in a numpy array or tensor.
|
||||
"""
|
||||
if isinstance(results_list[0], torch.Tensor):
|
||||
# move all tensors to the same device
|
||||
results_list = [res.to(self.target_devices[0]) for res in results_list]
|
||||
return torch.cat(results_list, dim=0)
|
||||
elif isinstance(results_list[0], np.ndarray):
|
||||
return np.concatenate(results_list, axis=0)
|
||||
else:
|
||||
raise NotImplementedError("Unsupported type for results_list")
|
||||
|
|
@ -0,0 +1,360 @@
|
|||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Union, List, Tuple, Dict, Literal, Optional
|
||||
|
||||
import multiprocessing as mp
|
||||
from multiprocessing import Queue
|
||||
|
||||
import math
|
||||
import gc
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm, trange
|
||||
from transformers import is_torch_npu_available
|
||||
|
||||
try:
|
||||
import torch_musa
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbsReranker(ABC):
|
||||
"""
|
||||
Base class for Reranker.
|
||||
Extend this class and implement :meth:`compute_score_single_gpu` for custom rerankers.
|
||||
|
||||
Args:
|
||||
model_name_or_path (str): If it's a path to a local model, it loads the model from the path. Otherwise tries to download and
|
||||
load a model from HuggingFace Hub with the name.
|
||||
use_fp16 (bool, optional): If true, use half-precision floating-point to speed up computation with a slight performance
|
||||
degradation. Defaults to :data:`False`.
|
||||
query_instruction_for_rerank: (Optional[str], optional): Query instruction for reranking, which will be used with
|
||||
with :attr:`query_instruction_format`. Defaults to :data:`None`.
|
||||
query_instruction_format: (str, optional): The template for :attr:`query_instruction_for_rerank`. Defaults to :data:`"{}{}"`.
|
||||
passage_instruction_for_rerank (Optional[str], optional): Passage instruction for reranking. Defaults to :data:`None`.
|
||||
passage_instruction_format (str, optional): Passage instruction format when using :attr:`passage_instruction_for_rerank`.
|
||||
Defaults to :data:`"{}{}"`.
|
||||
devices (Optional[Union[str, int, List[str], List[int]]], optional): Devices to use for model inference. Defaults to :data:`None`.
|
||||
batch_size (int, optional): Batch size for inference. Defaults to :data:`128`.
|
||||
query_max_length (int, optional): Maximum length for query. Defaults to :data:`None`.
|
||||
max_length (int, optional): Maximum length. Defaults to :data:`512`.
|
||||
normalize (bool, optional): If true, normalize the result. Defaults to :data:`False`.
|
||||
kwargs (Dict[Any], optional): Additional parameters for HuggingFace Transformers config or children classes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
use_fp16: bool = False,
|
||||
query_instruction_for_rerank: Optional[str] = None,
|
||||
query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_rerank
|
||||
passage_instruction_for_rerank: Optional[str] = None,
|
||||
passage_instruction_format: str = "{}{}", # specify the format of passage_instruction_for_rerank
|
||||
devices: Optional[Union[str, int, List[str], List[int]]] = None,
|
||||
# inference
|
||||
batch_size: int = 128,
|
||||
query_max_length: Optional[int] = None,
|
||||
max_length: int = 512,
|
||||
normalize: bool = False,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.model_name_or_path = model_name_or_path
|
||||
self.use_fp16 = use_fp16
|
||||
self.query_instruction_for_rerank = query_instruction_for_rerank
|
||||
self.query_instruction_format = query_instruction_format
|
||||
self.passage_instruction_for_rerank = passage_instruction_for_rerank
|
||||
self.passage_instruction_format = passage_instruction_format
|
||||
self.target_devices = self.get_target_devices(devices)
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.query_max_length = query_max_length
|
||||
self.max_length = max_length
|
||||
self.normalize = normalize
|
||||
|
||||
for k in kwargs:
|
||||
setattr(self, k, kwargs[k])
|
||||
|
||||
self.kwargs = kwargs
|
||||
|
||||
# tokenizer and model are initialized in the child class
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.pool = None
|
||||
|
||||
def stop_self_pool(self):
|
||||
if self.pool is not None:
|
||||
self.stop_multi_process_pool(self.pool)
|
||||
self.pool = None
|
||||
try:
|
||||
self.model.to('cpu')
|
||||
torch.cuda.empty_cache()
|
||||
except:
|
||||
pass
|
||||
if gc is not None and callable(gc.collect):
|
||||
gc.collect()
|
||||
|
||||
@staticmethod
|
||||
def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[str]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
devices (Union[str, int, List[str], List[int]]): Specified devices, can be `str`, `int`, list of `str`, or list of `int`.
|
||||
|
||||
Raises:
|
||||
ValueError: Devices should be a string or an integer or a list of strings or a list of integers.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of target devices in format
|
||||
"""
|
||||
if devices is None:
|
||||
if torch.cuda.is_available():
|
||||
return [f"cuda:{i}" for i in range(torch.cuda.device_count())]
|
||||
elif is_torch_npu_available():
|
||||
return [f"npu:{i}" for i in range(torch.npu.device_count())]
|
||||
elif hasattr(torch, "musa") and torch.musa.is_available():
|
||||
return [f"musa:{i}" for i in range(torch.musa.device_count())]
|
||||
elif torch.backends.mps.is_available():
|
||||
return ["mps"]
|
||||
else:
|
||||
return ["cpu"]
|
||||
elif isinstance(devices, str):
|
||||
return [devices]
|
||||
elif isinstance(devices, int):
|
||||
if hasattr(torch, "musa") and torch.musa.is_available():
|
||||
return [f"musa:{devices}"]
|
||||
else:
|
||||
return [f"cuda:{devices}"]
|
||||
elif isinstance(devices, list):
|
||||
if isinstance(devices[0], str):
|
||||
return devices
|
||||
elif isinstance(devices[0], int):
|
||||
if hasattr(torch, "musa") and torch.musa.is_available():
|
||||
return [f"musa:{device}" for device in devices]
|
||||
else:
|
||||
return [f"cuda:{device}" for device in devices]
|
||||
else:
|
||||
raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")
|
||||
else:
|
||||
raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")
|
||||
|
||||
def get_detailed_instruct(self, instruction_format: str, instruction: str, sentence: str):
|
||||
"""Combine the instruction and sentence along with the instruction format.
|
||||
|
||||
Args:
|
||||
instruction_format (str): Format for instruction.
|
||||
instruction (str): The text of instruction.
|
||||
sentence (str): The sentence to concatenate with.
|
||||
|
||||
Returns:
|
||||
str: The complete sentence with instruction
|
||||
"""
|
||||
if "\\n" in instruction_format:
|
||||
instruction_format = instruction_format.replace("\\n", "\n")
|
||||
return instruction_format.format(instruction, sentence)
|
||||
|
||||
def get_detailed_inputs(self, sentence_pairs: Union[str, List[str]]):
|
||||
"""get detailed instruct for all the inputs
|
||||
|
||||
Args:
|
||||
sentence_pairs (Union[str, List[str]]): Input sentence pairs
|
||||
|
||||
Returns:
|
||||
list[list[str]]: The complete sentence pairs with instruction
|
||||
"""
|
||||
if isinstance(sentence_pairs, str):
|
||||
sentence_pairs = [sentence_pairs]
|
||||
|
||||
if self.query_instruction_for_rerank is not None:
|
||||
if self.passage_instruction_for_rerank is None:
|
||||
return [
|
||||
[
|
||||
self.get_detailed_instruct(self.query_instruction_format, self.query_instruction_for_rerank, sentence_pair[0]),
|
||||
sentence_pair[1]
|
||||
] for sentence_pair in sentence_pairs
|
||||
]
|
||||
else:
|
||||
return [
|
||||
[
|
||||
self.get_detailed_instruct(self.query_instruction_format, self.query_instruction_for_rerank, sentence_pair[0]),
|
||||
self.get_detailed_instruct(self.passage_instruction_format, self.passage_instruction_for_rerank, sentence_pair[1])
|
||||
] for sentence_pair in sentence_pairs
|
||||
]
|
||||
else:
|
||||
if self.passage_instruction_for_rerank is None:
|
||||
return [
|
||||
[
|
||||
sentence_pair[0],
|
||||
sentence_pair[1]
|
||||
] for sentence_pair in sentence_pairs
|
||||
]
|
||||
else:
|
||||
return [
|
||||
[
|
||||
sentence_pair[0],
|
||||
self.get_detailed_instruct(self.passage_instruction_format, self.passage_instruction_for_rerank, sentence_pair[1])
|
||||
] for sentence_pair in sentence_pairs
|
||||
]
|
||||
|
||||
def compute_score(
|
||||
self,
|
||||
sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
|
||||
**kwargs
|
||||
):
|
||||
"""Compute score for each sentence pair
|
||||
|
||||
Args:
|
||||
sentence_pairs (Union[List[Tuple[str, str]], Tuple[str, str]]): Input sentence pairs to compute.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: scores of all the sentence pairs.
|
||||
"""
|
||||
if isinstance(sentence_pairs[0], str):
|
||||
sentence_pairs = [sentence_pairs]
|
||||
sentence_pairs = self.get_detailed_inputs(sentence_pairs)
|
||||
|
||||
if isinstance(sentence_pairs, str) or len(self.target_devices) == 1:
|
||||
return self.compute_score_single_gpu(
|
||||
sentence_pairs,
|
||||
device=self.target_devices[0],
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if self.pool is None:
|
||||
self.pool = self.start_multi_process_pool()
|
||||
scores = self.encode_multi_process(sentence_pairs,
|
||||
self.pool,
|
||||
**kwargs)
|
||||
return scores
|
||||
|
||||
def __del__(self):
|
||||
self.stop_self_pool()
|
||||
|
||||
@abstractmethod
|
||||
def compute_score_single_gpu(
|
||||
self,
|
||||
sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
|
||||
batch_size: int = 256,
|
||||
query_max_length: Optional[int] = None,
|
||||
max_length: int = 512,
|
||||
normalize: bool = False,
|
||||
device: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""
|
||||
This method should compute the scores of sentence_pair and return scores.
|
||||
"""
|
||||
pass
|
||||
|
||||
# copied from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L857
|
||||
def start_multi_process_pool(self) -> Dict[Literal["input", "output", "processes"], Any]:
|
||||
"""
|
||||
Starts a multi-process pool to process the encoding with several independent processes
|
||||
via :meth:`SentenceTransformer.encode_multi_process <sentence_transformers.SentenceTransformer.encode_multi_process>`.
|
||||
|
||||
This method is recommended if you want to encode on multiple GPUs or CPUs. It is advised
|
||||
to start only one process per GPU. This method works together with encode_multi_process
|
||||
and stop_multi_process_pool.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary with the target processes, an input queue, and an output queue.
|
||||
"""
|
||||
logger.info("Start multi-process pool on devices: {}".format(", ".join(map(str, self.target_devices))))
|
||||
|
||||
self.model.to("cpu")
|
||||
self.model.share_memory()
|
||||
ctx = mp.get_context("spawn")
|
||||
input_queue = ctx.Queue()
|
||||
output_queue = ctx.Queue()
|
||||
processes = []
|
||||
|
||||
for device_id in tqdm(self.target_devices, desc='initial target device'):
|
||||
p = ctx.Process(
|
||||
target=AbsReranker._encode_multi_process_worker,
|
||||
args=(device_id, self, input_queue, output_queue),
|
||||
daemon=True,
|
||||
)
|
||||
p.start()
|
||||
processes.append(p)
|
||||
|
||||
return {"input": input_queue, "output": output_queue, "processes": processes}
|
||||
|
||||
# copied from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L857
|
||||
def encode_multi_process(
|
||||
self,
|
||||
sentence_pairs: List,
|
||||
pool: Dict[Literal["input", "output", "processes"], Any],
|
||||
**kwargs
|
||||
) -> np.ndarray:
|
||||
chunk_size = math.ceil(len(sentence_pairs) / len(pool["processes"]))
|
||||
|
||||
input_queue = pool["input"]
|
||||
last_chunk_id = 0
|
||||
chunk = []
|
||||
|
||||
for sentence_pair in sentence_pairs:
|
||||
chunk.append(sentence_pair)
|
||||
if len(chunk) >= chunk_size:
|
||||
input_queue.put(
|
||||
[last_chunk_id, chunk, kwargs]
|
||||
)
|
||||
last_chunk_id += 1
|
||||
chunk = []
|
||||
|
||||
if len(chunk) > 0:
|
||||
input_queue.put([last_chunk_id, chunk, kwargs])
|
||||
last_chunk_id += 1
|
||||
|
||||
output_queue = pool["output"]
|
||||
results_list = sorted(
|
||||
[output_queue.get() for _ in trange(last_chunk_id, desc="Chunks")],
|
||||
key=lambda x: x[0],
|
||||
)
|
||||
scores = np.concatenate([result[1] for result in results_list])
|
||||
return scores
|
||||
|
||||
# copied from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L857
|
||||
@staticmethod
|
||||
def _encode_multi_process_worker(
|
||||
target_device: str, model: 'AbsReranker', input_queue: Queue, results_queue: Queue
|
||||
) -> None:
|
||||
"""
|
||||
Internal working process to encode sentences in multi-process setup
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
chunk_id, sentences, kwargs = (
|
||||
input_queue.get()
|
||||
)
|
||||
embeddings = model.compute_score_single_gpu(
|
||||
sentences,
|
||||
device=target_device,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
results_queue.put([chunk_id, embeddings])
|
||||
except:
|
||||
break
|
||||
|
||||
# copied from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L857
|
||||
@staticmethod
|
||||
def stop_multi_process_pool(pool: Dict[Literal["input", "output", "processes"], Any]) -> None:
|
||||
"""
|
||||
Stops all processes started with start_multi_process_pool.
|
||||
|
||||
Args:
|
||||
pool (Dict[str, object]): A dictionary containing the input queue, output queue, and process list.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for p in pool["processes"]:
|
||||
p.terminate()
|
||||
|
||||
for p in pool["processes"]:
|
||||
p.join()
|
||||
p.close()
|
||||
|
||||
pool["input"].close()
|
||||
pool["output"].close()
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
from .AbsEmbedder import AbsEmbedder
|
||||
from .AbsReranker import AbsReranker
|
||||
|
||||
__all__ = [
|
||||
'AbsEmbedder',
|
||||
'AbsReranker'
|
||||
]
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
from .arguments import AIRBenchEvalModelArgs, AIRBenchEvalArgs
|
||||
from .runner import AIRBenchEvalRunner
|
||||
|
||||
__all__ = [
|
||||
"AIRBenchEvalModelArgs",
|
||||
"AIRBenchEvalArgs",
|
||||
"AIRBenchEvalRunner"
|
||||
]
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
from transformers import HfArgumentParser
|
||||
|
||||
from FlagEmbedding.evaluation.air_bench import (
|
||||
AIRBenchEvalArgs, AIRBenchEvalModelArgs,
|
||||
AIRBenchEvalRunner
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = HfArgumentParser((
|
||||
AIRBenchEvalArgs,
|
||||
AIRBenchEvalModelArgs
|
||||
))
|
||||
|
||||
eval_args, model_args = parser.parse_args_into_dataclasses()
|
||||
eval_args: AIRBenchEvalArgs
|
||||
model_args: AIRBenchEvalModelArgs
|
||||
|
||||
runner = AIRBenchEvalRunner(
|
||||
eval_args=eval_args,
|
||||
model_args=model_args
|
||||
)
|
||||
|
||||
runner.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
print("==============================================")
|
||||
print("Search results have been generated.")
|
||||
print("For computing metrics, please refer to the official AIR-Bench docs:")
|
||||
print("- https://github.com/AIR-Bench/AIR-Bench/blob/main/docs/submit_to_leaderboard.md")
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
from air_benchmark import EvalArgs as AIRBenchEvalArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIRBenchEvalModelArgs:
|
||||
"""
|
||||
Evaluation Model arguments for AIR Bench.
|
||||
"""
|
||||
embedder_name_or_path: str = field(
|
||||
metadata={"help": "The embedder name or path.", "required": True}
|
||||
)
|
||||
embedder_model_class: Optional[str] = field(
|
||||
default=None, metadata={"help": "The embedder model class. Available classes: ['encoder-only-base', 'encoder-only-m3', 'decoder-only-base', 'decoder-only-icl']. Default: None. For the custom model, you need to specifiy the model class.", "choices": ["encoder-only-base", "encoder-only-m3", "decoder-only-base", "decoder-only-icl"]}
|
||||
)
|
||||
normalize_embeddings: bool = field(
|
||||
default=True, metadata={"help": "whether to normalize the embeddings"}
|
||||
)
|
||||
pooling_method: str = field(
|
||||
default="cls", metadata={"help": "The pooling method fot the embedder."}
|
||||
)
|
||||
use_fp16: bool = field(
|
||||
default=True, metadata={"help": "whether to use fp16 for inference"}
|
||||
)
|
||||
devices: Optional[str] = field(
|
||||
default=None, metadata={"help": "Devices to use for inference.", "nargs": "+"}
|
||||
)
|
||||
query_instruction_for_retrieval: Optional[str] = field(
|
||||
default=None, metadata={"help": "Instruction for query"}
|
||||
)
|
||||
query_instruction_format_for_retrieval: str = field(
|
||||
default="{}{}", metadata={"help": "Format for query instruction"}
|
||||
)
|
||||
examples_for_task: Optional[str] = field(
|
||||
default=None, metadata={"help": "Examples for task"}
|
||||
)
|
||||
examples_instruction_format: str = field(
|
||||
default="{}{}", metadata={"help": "Format for examples instruction"}
|
||||
)
|
||||
trust_remote_code: bool = field(
|
||||
default=False, metadata={"help": "Trust remote code"}
|
||||
)
|
||||
reranker_name_or_path: Optional[str] = field(
|
||||
default=None, metadata={"help": "The reranker name or path."}
|
||||
)
|
||||
reranker_model_class: Optional[str] = field(
|
||||
default=None, metadata={"help": "The reranker model class. Available classes: ['encoder-only-base', 'decoder-only-base', 'decoder-only-layerwise', 'decoder-only-lightweight']. Default: None. For the custom model, you need to specify the model class.", "choices": ["encoder-only-base", "decoder-only-base", "decoder-only-layerwise", "decoder-only-lightweight"]}
|
||||
)
|
||||
reranker_peft_path: Optional[str] = field(
|
||||
default=None, metadata={"help": "The reranker peft path."}
|
||||
)
|
||||
use_bf16: bool = field(
|
||||
default=False, metadata={"help": "whether to use bf16 for inference"}
|
||||
)
|
||||
query_instruction_for_rerank: Optional[str] = field(
|
||||
default=None, metadata={"help": "Instruction for query"}
|
||||
)
|
||||
query_instruction_format_for_rerank: str = field(
|
||||
default="{}{}", metadata={"help": "Format for query instruction"}
|
||||
)
|
||||
passage_instruction_for_rerank: Optional[str] = field(
|
||||
default=None, metadata={"help": "Instruction for passage"}
|
||||
)
|
||||
passage_instruction_format_for_rerank: str = field(
|
||||
default="{}{}", metadata={"help": "Format for passage instruction"}
|
||||
)
|
||||
model_cache_dir: str = field(
|
||||
default=None, metadata={"help": "Cache directory for models."}
|
||||
)
|
||||
# ================ for inference ===============
|
||||
embedder_batch_size: int = field(
|
||||
default=3000, metadata={"help": "Batch size for inference."}
|
||||
)
|
||||
reranker_batch_size: int = field(
|
||||
default=3000, metadata={"help": "Batch size for inference."}
|
||||
)
|
||||
embedder_query_max_length: int = field(
|
||||
default=512, metadata={"help": "Max length for query."}
|
||||
)
|
||||
embedder_passage_max_length: int = field(
|
||||
default=512, metadata={"help": "Max length for passage."}
|
||||
)
|
||||
reranker_query_max_length: Optional[int] = field(
|
||||
default=None, metadata={"help": "Max length for reranking."}
|
||||
)
|
||||
reranker_max_length: int = field(
|
||||
default=512, metadata={"help": "Max length for reranking."}
|
||||
)
|
||||
normalize: bool = field(
|
||||
default=False, metadata={"help": "whether to normalize the reranking scores"}
|
||||
)
|
||||
prompt: Optional[str] = field(
|
||||
default=None, metadata={"help": "The prompt for the reranker."}
|
||||
)
|
||||
cutoff_layers: List[int] = field(
|
||||
default=None, metadata={"help": "The output layers of layerwise/lightweight reranker."}
|
||||
)
|
||||
compress_ratio: int = field(
|
||||
default=1, metadata={"help": "The compress ratio of lightweight reranker."}
|
||||
)
|
||||
compress_layers: Optional[int] = field(
|
||||
default=None, metadata={"help": "The compress layers of lightweight reranker.", "nargs": "+"}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# replace "\\n" with "\n"
|
||||
if "\\n" in self.query_instruction_format_for_retrieval:
|
||||
self.query_instruction_format_for_retrieval = self.query_instruction_format_for_retrieval.replace("\\n", "\n")
|
||||
if "\\n" in self.examples_instruction_format:
|
||||
self.examples_instruction_format = self.examples_instruction_format.replace("\\n", "\n")
|
||||
if "\\n" in self.query_instruction_format_for_rerank:
|
||||
self.query_instruction_format_for_rerank = self.query_instruction_format_for_rerank.replace("\\n", "\n")
|
||||
if "\\n" in self.passage_instruction_format_for_rerank:
|
||||
self.passage_instruction_format_for_rerank = self.passage_instruction_format_for_rerank.replace("\\n", "\n")
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "So, which AI model did the best on the MMMU benchmark according to Yue and his team back in 2023?", "pos": "MMMU (val) Gemini Ultra (0-shot) GPT-4V (0-shot)\nMaj@32 pass@1 pass@1\nArt & Design 74.2 70.0 65.8\nBusiness 62.7 56.7 59.3\nScience 49.3 48.0 54.7\nHealth & Medicine 71.3 67.3 64.7\nHumanities & Social Science 78.3 78.3 72.5\nTechnology & Engineering 53.0 47.1 36.7\nOverall 62.4 59.4 56.8\nTable 8|Gemini Ultra performance on the MMMU benchmark (Yue et al., 2023) per discipline."}
|
||||
{"query": "The GSPMD partitioner, part of the XLA compiler, is responsible for dividing the training step calculation.", "pos": "The GSPMD partitioner (Xu et al., 2021) in the XLA compiler\npartitions the training step computation, and the MegaScale XLA compiler (XLA, 2019) pass statically\nschedules appropriate collectives so that they maximally overlap with the computation with very little\nvariation in step time.\nMaintaining a high goodput2at this scale would have been impossible using the conventional\napproach of periodic checkpointing of weights to persistent cluster storage. For Gemini models, we\ninstead made use of redundant in-memory copies of the model state, and on any unplanned hardware\nfailures, we rapidly recover directly from an intact model replica."}
|
||||
{"query": "What's the impact of where you live and your social status on how well AI image labeling tech works?", "pos": "Thoughwedo\nnot see large discrepancies across different groups, we note that this metric is imperfect as the human\nreference captions could be inherently biased. Additionally, we perform a zero-shot classification style\nevaluation with the Dollarstreet dataset (Rojas et al., 2022) to measure discrepancies in performance\nacross images which come from different geographic locations. As is seen in previous work, we find\nthat models work less effectively for images from lower socioeconomic regions and regions outside\nNorth America and Europe. This is an area where we need further research and work to improve in\nfuture iterations of our models.\nIn addition to comparing performance on tasks across groups, we also consider how people are\ndescribed in captions."}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "What gauges the effects of data contamination?", "pos": "We also undertake a systematic study of “data contamination” – a growing problem when training high capacity models\non datasets such as Common Crawl, which can potentially include content from test datasets simply because such\ncontent often exists on the web. In this paper we develop systematic tools to measure data contamination and quantify\nits distorting effects. Although we find that data contamination has a minimal effect on GPT-3’s performance on most\ndatasets, we do identify a few datasets where it could be inflating results, and we either do not report results on these\ndatasets or we note them with an asterisk, depending on the severity."}
|
||||
{"query": "What strategies did the United States employ to convince Pakistan to exercise its influence over the Taliban?", "pos": "Direct\npressure on the Taliban had proved unsuccessful. As one NSC staff note\nput it, \"Under the Taliban, Afghanistan is not so much a state sponsor\nof terrorism as it is a state sponsored by terrorists.\" In early 2000,\nthe United States began a high-level effort to persuade Pakistan to use\nits influence over the Taliban. In January 2000, Assistant Secretary\nof State Karl Inderfurth and the State Department’s counterterrorism\ncoordinator, Michael Sheehan, met with General Musharraf in Islamabad,\ndangling before him the possibility of a presidential visit in March as a\nreward for Pakistani cooperation. Such a visit was coveted by Musharraf,\npartly as a sign of his government’s legitimacy."}
|
||||
{"query": "What does carrying rotten potatoes symbolize?", "pos": "The children started complaining about the\ntrouble loudly.\nThen Mrs. Smith told them why she asked them to play the game. She\nsaid,\"This is exactly the situation when you carry your hatred for somebody\ninside your heart. The terrible smell of the hatred will pollute your\nheart and you will carry something unnecessary with you all the time. If\nyou cannot stand the smell of the rotten potatoes for just two weeks, can\nyou imagine how heavy it would be to have the hatred in your heart for your\nlifetime? So throw away any hatred from your heart, and you’ll be really\nhappy.\"\nQ: Which of the following is True according to the passage?\nA: If a kid hated four people,he or she had to carry four potatoes."}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "Could you elucidate on the values of temperature and top-p that are utilized for pass@1 scores?", "pos": "8 62.8\n13B 18.3 60.2 30.6 69.0\n34B 22.6 77.2 33.0 76.1\n70B29.9 89.0 45.0 81.4\nTable 21: Code generation results on Human-Eval and MBPP . We report 0-shot and 3-shot results for\nHuman-Eval and MBPP respectively. For pass@100 and pass@80 scores, we use a temperature of 0.8 and\ntop-p=0.95. For pass@1 scores, we use a temperature of 0.1 and top- p=0.95.\n49"}
|
||||
{"query": "What do high safety scores and low helpfulness ratings suggest?", "pos": "Here we show more evidence and\nqualitative results to manifest this tension. Figure32 are two scatter plots of helpfulness and safety reward\nmodel scores on the safety test set for safe and unsafe responses. The tension can be observed at the bottom\nright corner (i.e., high safety score but low helpfulness score) in the safe response plot (left) and the top left\ncorner (i.e., low safety score but high helpfulness score) in the unsafe response plot (right). We also list two\nqualitative examples where safety and helpfulness reward models don’t agree with each other in Table 35."}
|
||||
{"query": "The process of carefully adjusting precautions relies on using challenging stimuli together with protected displays to make its operation run more smoothly.", "pos": "4.2 Safety Fine-Tuning\nIn this section, we describe our approach to safety fine-tuning, including safety categories, annotation\nguidelines,and the techniques we use to mitigate safety risks. We employ a process similar to the general\nfine-tuning methods as described in Section 3, with some notable differences related to safety concerns.\nSpecifically, we use the following techniques in safety fine-tuning:\n1.Supervised Safety Fine-Tuning : We initialize by gathering adversarial prompts and safe demonstra-\ntions that are then included in the general supervised fine-tuning process (Section 3.1). This teaches\nthe model to align with our safety guidelines even before RLHF,and thus lays the foundation for\nhigh-quality human preference data annotation."}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "What are the pre-training challenges for large language models?", "pos": "To make this survey more self-contained, we present the\ndetailed formulations for these configurations in Table 6.\nNormalization Methods. Training instability is a challeng-\ning issue for pre-training LLMs. To alleviate this issue,\nnormalization is a widely adopted strategy to stabilize the\ntraining of neural networks. In the vanilla Transformer [22],\nLayerNorm [256] is employed. Recently, several advanced\nnormalization techniques have been proposed as alterna-\ntives to LayerNorm, e.g., RMSNorm, and DeepNorm.\n•LayerNorm. In the early research, BatchNorm [265] is\na commonly used normalization method. However, it is\ndifficult to deal with sequence data of variable lengths and\nsmall-batch data."}
|
||||
{"query": "Language learning models seriously struggle to grasp complex symbols when they're thrown in scenarios they don't know jack about.", "pos": "For an example of\nthe out-of-domain test, LLMs could only see the examples\nwith two words in context, but it requires LLMs to concate-\nnate the last letters of three or more words. Typically, the\naccuracy of the generated symbols is adopted to evaluate\nthe performance of LLMs on these tasks. Thus, LLMs need\nto understand the semantic relations among the symbolic\noperations and their composition in complex scenarios.\nHowever, under the out-of-domain setting, as LLMs have\nnot seen the complex compositions of symbolic operations\nand rules ( e.g., twice the number of operations in context\nexamples), it is hard for LLMs to capture their accurate\nmeanings."}
|
||||
{"query": "Could you shed some light on the two primary ways that LLMs employ demonstrations as discussed in document 493?", "pos": "How LLMs Perform ICL? At the inference stage, researchers\nfocus on analyzing how the ICL capability operates based\non given demonstrations since no explicit learning or updat-\ning is involved. According to the discussion in [493], there\nare two main ways for LLMs to utilize demonstrations: task\nrecognition and task learning.\n•Task recognition. In the first way, LLMs recognize the\ntask from demonstrations and utilize the prior knowledge\nobtained from pre-training to solve new test tasks. A Proba-\nbly Approximately Correct (PAC) framework [494] has been\nproposed to assess the learnability of ICL."}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "Why is it believed the universe began at a particular time?", "pos": "According to a number of earlycosmologies and the Jewish/Christian/Muslim tradition, the universe started at a finite, and not very distant,time in the past. One argument for such a beginning was the feeling that it was necessary to have “First Cause”to explain the existence of the universe. (Within the universe, you always explained one event as being causedby some earlier event, but the existence of the universe itself could be explained in this way only if it had somebeginning.) Another argument was put forward by St. Augustine in his book The City of God. He pointed out\nthat civilization is progressing and we remember who performed this deed or developed that technique."}
|
||||
{"query": "Could you elucidate on the intricate procedure of stellar constitution?", "pos": "Andeven then it was a long time before the implications of the theory for massive stars were understood.To understand how a black hole might be formed, we first need an understanding of the life cycle of a star. A star isformed when a large amount of gas (mostly hydrogen) starts to collapse in on itself due to its gravitational attraction. Asit contracts, the atoms of the gas collide with each other more and more frequently and at greater and greater speeds –the gas heats up. Eventually, the gas will be so hot that when the hydrogen atoms collide they no longer bounce offeach other, but instead coalesce to form helium. The heat released in this reaction, which is like a controlled hydrogenbomb explosion, is what makes the star shine."}
|
||||
{"query": "Black hole existence evidence?", "pos": "the body that has collapsed must be lost when a black hole is formed, because afterward all we can possibly measureabout the body is its mass and rate of rotation. The significance of this will be seen in the next chapter.Black holes are one of only a fairly small number of cases in the history of science in which a theory was developed ingreat detail as a mathematical model before there was any evidence from observations that it was correct. Indeed, thisused to be the main argument of opponents of black holes: how could one believe in objects for which the onlyevidence was calculations based on the dubious theory of general relativity?"}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "So, like, what's the big deal about the top species from the bigger groups talked about in Chapter 4?", "pos": "In our second and fourth chapters, on Variation and on Natural Selection, I have attempted\nto show that it is the widely ranging, the much diffused and common, that is the dominant species\nbelonging to the larger genera, which vary most. The varieties, or incipient species, thus produced\nultimately become converted, as I believe, into new and distinct species; and these, on the principle\nof inheritance, tend to produce other new and dominant species. Consequently the groups which\nare now large, and which generally include many dominant species, tend to go on increasing\nindefinitely in size."}
|
||||
{"query": "Identify the unique species in the Chthamalinae subfamily of sessile cirripedes and the location of its fossil discovery.", "pos": "I suspect that but few of\nthe very many animals which live on the beach between high and low watermark are preserved.\nFor instance, the several species of the Chthamalinae (a sub-family of sessile cirripedes) coat the\nrocks all over the world in infinite numbers: they are all strictly littoral, with the exception of a\nsingle Mediterranean species, which inhabits deep water and has been found fossil in Sicily,\nwhereas not one other species has hitherto been found in any tertiary formation: yet it is now\nknown that the genus Chthamalus existed during the chalk period. The molluscan genus Chiton\noffers a partially analogous case."}
|
||||
{"query": "Why are there flaws in the geological record?", "pos": "Nor is their rarity surprising, when we remember how large a proportion of the\nbones of tertiary mammals have been discovered either in caves or in lacustrine deposits; and that\nnot a cave or true lacustrine bed is known belonging to the age of our secondary or palaeozoic\nformations.\nBut the imperfection in the geological record mainly results from another and more important cause\nthan any of the foregoing; namely, from the several formations being separated from each other by\nwide intervals of time. When we see the formations tabulated in written works, or when we follow\nthem in nature, it is difficult to avoid believing that they are closely consecutive."}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "What does 'O' represent in peptides?", "pos": "by changing the peptides to be amphiphilic or completely polar, they systematically synthesized several derived peptides. each of them has a different polar uncharged group : p11-8 (423, based on glutamine q, sequence ac-qqrfowofeqq-nh2 ; o represents ornithine), p11-12 (424, based on serine s, sequence ac-ssrfowofess- nh2), p11-16 (427, based on asparagine n, sequence ac-nnrfowofenn- nh2), and p11-18 (428, based on threonine t, sequence ac-ttrfowofett- nh2)."}
|
||||
{"query": "Could you elucidate on the system that was demonstrated by Van Esch and his team utilizing 1,3,5-triamide cyclohexane-based hydrogelators 67 for the alignment of nanofibers?", "pos": "used an electrical field to assist the alignment of the nanofibers and demonstrated that the application of a voltage bias, indeed, helps the directional orientation of the fibrils. using the 1,3,5-triamide cyclohexane-based hydrogelators 67, van esch et al. demonstrated an elegant system that forms well-defined nanostructures by the orthogonal self-assembly of hydrogelators and surfactants."}
|
||||
{"query": "What environmental factors influence peptide self-assembly?", "pos": "as pointed out by the authors, the hydrophobic effect between 267 molecules favors axial assembly and their electrostatic forces modulate lateral assembly. at a concentration of 0.05 wt %, the peptide self-assembles to form a filament consisting of about 120 molecules of 267. the authors also reported that various environmental factors (e.g., ph, salt, molecular crowding reagents, and peptides) can regulate the self-assembled filaments in an assembly of predictable manner, which provides useful insights for developing coiled coils as peptide-based materials. it would be interesting to know the proteolytic stability of these self-assembled filaments. besides native peptides acting as hydrogelators, peptide derivatives can also self-assemble in water to form hydrogels."}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "Why do we get different parameter sets when looking at how ion and water-oxygen atoms interact?", "pos": "the problem here is which one to chose to obtain a consistent set of parameters. the multiple parameter sets arise because similar aij and bij terms can be obtained between the ion and water-oxygen atoms : for a certain rmin/2 and , there will be a corresponding bigger rmin/2 and smaller , or a smaller rmin/2 and bigger , which yield similar aij and bij terms (see eqs 54 and 55). when two different ion parameter sets, which give similar aij terms between an ion and oxygen in water, are applied to the same biomolecule, they may give quite different aij terms between the same atom type on the biomolecule and the metal ion after applying the combining rules."}
|
||||
{"query": "Chemical physicists managed to mock-up ions in common force fields using the 12-6 Lennard-Jones model without any direct bonding, which verified the structure traits of water-based potassium.", "pos": "the 12-6 lj nonbonded model remains a fast and practical way to simulate ions using classical force fields. the blyp functional was used for the system containing a k ion and 59 water molecules. in total, 0.168 ps of equilibration and 1.98 ps of sampling were performed in the nve ensemble. good agreement between the cpmd and classical md simulations was obtained for the structural properties of aqueous k, validating in part the classical representation of the k ion. moreover, it has also shown that it is possible to simultaneously simulate two or more experimental properties for some of the monovalent ions (e.g., na, k, rb, cs) using the 12-6 lj nonbonded model."}
|
||||
{"query": "Which concepts are encapsulated within classical models in AMOEBA?", "pos": "they also proposed that the ct effect may need to be included to improve the model. ponder, ren, and co-workers have created the atomic multipole optimized energetics for biomolecular simulation (amoeba) force field. it has bonded terms (bond, angle, dihedral, and improper torsion terms) represented using classical models. the bond and angle parameters are fit on the basis of qm-derived values (e.g., geometries and vibrational frequencies). the electrostatic interaction is represented by permanent monopoles (point charges), dipoles, and quadrupoles derived from the distributed multipole analysis (dma) procedure, along with the polarizable dipoles."}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "In what manner does the transference of electric charge impact the process of dimerization?", "pos": "the natural bond orbital analysis suggests that the n(py) *(r x) charge transfer plays a key role in the formation of these dimers, while the symmetry-adapted perturbation theory energy decomposition analysis indicates that the xb in r xpyridine complexes is predominantly inductive in nature. halogen-bonded systems containing one or two xbs were analyzed by using the natural orbitals for chemical valence (nocv) method combined with the extended-transition-state (ets) method."}
|
||||
{"query": "What affects XB's susceptibility to steric hindrance?", "pos": "for the reason stated above, xb is, in general, more sensitive to steric hindrance than hb. in the infinite chain formed by 1,4-diiodotetrafluorobenzene with 4,4- and 2,2-bipyridine, the c in distances are 2.864 and 3.158 , respectively ; when 2,4-bipyridine forms heteromeric crystals with the same xb donor, only the 4-pyridyl nitrogen is halogen-bonded, and trimers are formed wherein the c , we will see how, in the formation of dna base pairs wherein xb substitutes for hb, the most stable pairing was given by bromine as the advantage offered by the greater polarizability of iodine was overwhelmed by the disadvantage resulting from its greater size."}
|
||||
{"query": "Are there any haloheteroarenes with iodine atoms?", "pos": "color code : carbon, gray ; nitrogen, blue ; iodine, purple ; fluorine, yellow. the most commonly used classes of haloheteroarenes are those containing nitrogen atom(s) in the ring. both neutral and positively charged haloheteroarenes can function as scaffolds for an xb donor site (figure 55) ; the cationic form is typically obtained by reacting the neutral form with an alkyl halide or a hydrogen halide, and the released anion works as an xb acceptor for the activated xb donor site."}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "How do intrinsic and extrinsic factors impact A42 aggregation, and why is drug development challenging for this process?", "pos": "indeed, it has been shown that the dominant mechanism for catalyzing the formation of toxic a42 species is surface-catalyzed secondary nucleation. in other words, once a small but critical concentration of a42 aggregates has been generated through primary nucleation of monomers, surface-catalyzed secondary nucleation becomes the dominant process where the surface of the existing fibrils serve as catalytic sites for the generation of toxic oligomeric species [ 54, 57 ]. furthermore, the role of intrinsic and extrinsic factors on the aggregation process of a42 has been partly unveiled and a great effort has been focused on drug development against a42 aggregation, which has proven to be very difficult [ 100, 101 ]."}
|
||||
{"query": "Excitons transfer energy and can jump to a higher state, then lose energy quickly, causing them to vanish.", "pos": "this process occurs at high excitation densities when one exciton transfers its energy to another exciton and brings it to a higher-energy excited state. the higher-energy excited state relaxes rapidly, and overall an exciton is lost. in so far as quenching occurs throughout the volume of the sample, this measurement resembles volume quenching, but with excitons acting as their own quenchers. in these experiments, typically high light intensities are used, far higher than used under solar illumination conditions, and the consequences of this have to be taken into account when analyzing the data."}
|
||||
{"query": "Causes of artifacts?", "pos": "however, a limiting step in the measurements of the intrinsic young s modulus of amyloid fibrillar aggregates on a surface is the correct evaluation of the cross-sectional moment of inertia i. recently, it was presented a general approach based on theory of elasticity and an innovative calculation of the polymorphic fibrillar aggregates cross-sectional moment of inertia i in order to evaluate correctly the nanomechanical properties of amyloids. this method enables to calculate bending rigidities b and matching the measured experimental values of young s modulus of amyloid fibrils [ 149, 165 ]. however, fibril imaging by afm requires deposition on a surface and drying, which can potentially lead to artifacts in the evaluation of the persistence length and bending rigidity."}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "What is the primary product of photodimerization?", "pos": "sensitization is also the preferred way to promote coumarin and many of its derivatives into the excited state. in the absence of an external olefin, [ 2 + 2 ] photodimerization occurs with the hh cis-anti-cis product (rac-317, figure 15) being the major product. in benzene as the solvent and with benzophenone as triplet sensitizer, yields over 90% were achieved by ding and co-workers. compound rac-317 served as the starting material for the synthesis of new phosphane ligands."}
|
||||
{"query": "What's a basic challenge in the [2 + 2] photocycloaddition reactions?", "pos": "the requirement of an aryl enone was a fundamental obstacle in the [ 2 + 2 ] photocycloaddition reactions, which limited the application of this methodology. in order to overcome this problem, the yoon group described a visible-light-induced [ 2 + 2 ] photocycloaddition reaction of,-unsaturated 2-imidazolyl ketones such as 483 (scheme 163, dbu = 1,8-diazabicyclo[5.4.0]undec-7-ene)."}
|
||||
{"query": "Dry AMD causes photoreceptors to break down because the retinal pigment epithelium, which supports retinal neurons, isn't working properly.", "pos": "intravitreal anti-vegf therapies have emerged as a standard of care to treat wet amd ; however, there is currently no fda-approved treatment available for the dry form. thus, safe and effective treatment of dry amd remains a critical unmet need. atrophic (dry) form of amd represents a slowly progressing neurodegenerative disorder of the eye in which specialized retinal neurons (rod and cone photoreceptors) degenerate in the central part of the retina called macula. histopathological and clinical data suggest that photoreceptor degeneration in dry amd is triggered by abnormalities in the retinal pigment epithelium (rpe) that lies beneath photoreceptors and provides critical metabolic support to these light-sensing neuronal cells."}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "Where can I get the NIST Standard Reference Materials Catalog?", "pos": "The calibration services, standard reference materials and related measurement services along with changes and fees are published in two Special Publications (SP's) and their supplements. These are SP 250 “Calibration and Related Measurement Services of the National Institute of Standards & Technology” 1\n\n and SP 260 “NIST Standard Reference Materials Catalog.” 1 A complete catalog of all publications by NIST authors is issued annually as a supplement to SP 305 “Publications of the National Institute of Standards & Technology.” Announcements and listings of recent NIST publications and services are published in each issue of the bimonthly “NIST Journal of Research” 2\n\n and the NIST monthly magazine, “Dimensions/NIST” 2."}
|
||||
{"query": "What is the acceptable tolerance level for cranberries?", "pos": "(1) Having determined the errors on each dimension and given to each its proper sign (see § 241.5), add the errors on the effective diameter of head and the distance between heads algebraically and multiply the result by 1.67 (or 5/3). Then add this result to the error on the circumference of bulge algebraically. If the result obtained is not greater than the tolerance given in the following table for the proper subdivision, then the barrel is within the tolerance allowed; if the result is greater than this tolerance, then the barrel is not within the tolerance allowed.\n\n \n\n \n\n \n\nSize of subdivision\n\nTolerance\n\nFor fruits, vegetables, and other dry commodities (inches)\n\nFor cranberries (inches)"}
|
||||
{"query": "What tools and techniques might be listed in solicitation announcements for specific industry sectors?", "pos": "Specific industry sectors to be addressed and sub-categories of tools and techniques may be specified in solicitations. These sectors or sub-categories will be specified in the solicitation announcement. Examples of tools and techniques include, but are not limited to, manufacturing assessment tools, environmental benchmarking tools, training delivery programs, electronically accessible environmental information resources, environmental demonstration facilities, software tools, etc. Projects must be completed within the scope of the effort proposed and should not require on-going federal support.\n\n (c) Award period. Projects initiated under this category may be carried out over up to three years. Proposals selected for award will receive all funding from currently available funds. If an application is selected for funding, DOC has no obligation to provide any additional future funding in connection with that award."}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "When is the deadline for quarterly returns per § 53.153(a)?", "pos": "[T.D. ATF-308, 56 FR 303, Jan. 3, 1991, as amended by T.D. ATF-330, 57 FR 40325, Sept. 3, 1992. Redesignated in part by T.D. ATF-365, 60 FR 33670, June 28, 1995]\n\n\n\n\n\n§ 53.153\n\nTime for filing returns.\n\n(a) Quarterly returns. Each return required to be made under § 53.151(a) for a return period of one calendar quarter shall be filed on or before the last day of the first calendar month following the close of the period for which it is made."}
|
||||
{"query": "How are federal tax liens enforced?", "pos": "The satisfaction of the levy described in paragraph (b) of this section by an insuring organization shall be without prejudice to any civil action for the enforcement of any Federal tax lien with respect to a life insurance or endowment contract. Thus, this levy procedure is not the exclusive means of subjecting the life insurance and endowment contracts of the person against whom a tax is assessed to the collection of the person's unpaid assessment. The United States may choose to foreclose the tax lien in any case where it is appropriate, as, for example, to reach the cash surrender value (as distinguished from cash loan value) of a life insurance or endowment contract.\n\n(e) Cross references."}
|
||||
{"query": "Taxpayers must compile detailed lists for each tax jurisdiction, including names, addresses, and tax classifications.", "pos": "(b) Multiple locations and/or classes of tax. A taxpayer subject to special tax for the same period at more than one location or for more than one class of tax must—\n\n(1) File one special tax return, TTB Form 5630.5t, with payment of tax, to cover all such locations and classes of tax; and\n\n(2) Prepare, in duplicate, a list identified with the taxpayer's name, address (as shown on TTB Form 5630.5t), employer identification number, and period covered by the return. The list must show, by State, the name, address, and tax class of each location for which special tax is being paid."}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "Define \"project.\"", "pos": "Private, as applied to an agency, organization, or institution, means that it is not under Federal or public supervision or control.\n\n\n\nProject means the activity described in an application.\n\n\n\nProject component means an activity, strategy, intervention, process, product, practice, or policy included in a project. Evidence may pertain to an individual project component or to a combination of project components (e.g., training teachers on instructional practices for English learners and follow-on coaching for these teachers).\n\n\n\nProject period means the period established in the award document during which Federal sponsorship begins and ends (See, 2 CFR 200.77 Period of performance)."}
|
||||
{"query": "How do you file and respond to written motions in legal proceedings?", "pos": "The ALJ may require that oral motions be reduced to writing.\n\n(c) Within 15 days after a written motion is served, or such other time as may be fixed by the ALJ, any party may file a response to the motion.\n\n(d) The ALJ may not grant a written motion before the time for filing responses to the motion has expired, except upon consent of the parties or following a hearing on the motion, but may overrule or deny the motion without awaiting a response.\n\n(e) The ALJ shall make a reasonable effort to dispose of all outstanding motions prior to the beginning of the hearing.\n\n(Authority: 31 U.S.C. 3803(g)(3)(A))"}
|
||||
{"query": "What does the Credit Enhancement for Charter School Facilities Program do?", "pos": "(3) Assist charter schools with the predevelopment costs required to assess sites for the purpose of acquiring (by purchase, lease, donation, or otherwise) an interest (including an interest held by a third party for the benefit of a charter school) in improved or unimproved real property or constructing new facilities, or renovating, repairing, or altering existing facilities, and that are necessary to commence or continue the operation of a charter school.\n\n (c) Grantees may demonstrate innovative credit enhancement initiatives while meeting the program purposes under paragraph (b) of this section.\n\n (d) For the purposes of these regulations, the Credit Enhancement for Charter School Facilities Program includes grants made under the Charter School Facilities Financing Demonstration Grant Program.\n\n [70 FR 15003, Mar."}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "What happens to the loss?", "pos": "This loss shall be the measured loss less the net gain of any voice frequency repeaters in the circuit. Testing shall also be conducted to verify that the loss increases gradually as the frequency increases. The loss on H88 loaded loops should be down only slightly at 2.8 kHz but drop rapidly above 2.8 kHz. The loss on D66 loaded loops shall be fairly constant to about 3.4 kHz and there shall be good response at 4.0 kHz. When voice frequency repeaters are in the circuit there will be some frequency weighting in the build-out network and the loss at the higher frequencies will be greater than for nonrepeatered loops."}
|
||||
{"query": "We'll let all borrowers know by mail or email whenever there's a new Federal Register document about contract forms.", "pos": "The amendment may change the existing identification of a listed contract form; for example, changing the issuance date of a listed contract form or by identifying a new required contract form. The notice of rulemaking will describe the new standard contract form or the substantive change in the listed contract form, as the case may be, and the issues involved. The standard contract form or relevant portions thereof may be appended to the supplementary information section of the notice of rulemaking. As appropriate, the notice of rulemaking shall provide an opportunity for interested persons to provide comments. A copy of each such Federal Register document shall be sent by regular or electronic mail to all borrowers.\n\n[63 FR 58285, Oct. 30, 1998]"}
|
||||
{"query": "Which renewable energy projects can get funding through grant proposals?", "pos": "A grant project is eligible if it improves, or maintains energy services, or reduces the costs of providing energy services to eligible communities. Examples of eligible activities include, but are not limited to, the acquisition, construction, replacement, repair, or improvement of:\n\n(a) Electric generation, transmission, and distribution facilities, equipment, and services serving the eligible community;\n\n(b) Natural gas distribution or storage facilities and associated equipment and activities serving the eligible community;\n\n(c) Petroleum product storage and handling facilities serving residential or community use.\n\n(d) Renewable energy facilities used for on-grid or off-grid electric power generation, water or space heating, or process heating and power for the eligible community;\n\n(e) Backup up or emergency power generation or energy storage equipment, including distributed generation, to serve the eligible community; and"}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "How does the amount of energy affect how stuff scatters?", "pos": "it has been suggested that one may construct a lorentz - invariant noncommutative field theory by extending the coordinate algebra to additional, fictitious coordinates that transform nontrivially under the lorentz group. integration over these coordinates in the action produces a four - dimensional effective theory with lorentz invariance intact. previous applications of this approach, in particular to a specific construction of noncommutative qed, have been studied only in a low - momentum approximation. here we discuss lorentz - invariant field theories in which the relevant physics can be studied without requiring an expansion in the inverse scale of noncommutativity. qualitatively, we find that tree - level scattering cross sections are dramatically suppressed as the center - of - mass energy exceeds the scale of noncommutativity, that cross sections that are isotropic in the commutative limit can develop a pronounced angular dependence, and that nonrelativistic potentials (for example, the coloumb potential) become nonsingular at the origin. we consider a number of processes in noncommutative qed that may be studied at a future linear collider. we also give an example of scattering via a four - fermion operator in which the noncommutative modifications of the interaction can unitarize the tree - level amplitude, without requiring any other new physics in the ultraviolet."}
|
||||
{"query": "Why go for canonical instead of grand-canonical?", "pos": "the production of hadrons in relativistic heavy ion collisions is studied using a statistical ensemble with thermal and chemical equilibrium. special attention is given to exact conservation laws, i.e. certain charges are treated canonically instead of using the usual grand canonical approach. for small systems, the exact conservation of baryon number, strangeness and electric charge is to be taken into account. we have derived compact, analytical expressions for particle abundances in such ensemble. as an application, the change in @xmath0 ratios in ags experiments with different interaction system sizes is well reproduced. the canonical treatment of three charges becomes impractical very quickly with increasing system size. thus, we draw our attention to exact conservation of strangeness, and treat baryon number and electric charge grand canonically. we present expressions for particle abundances in such ensemble as well, and apply them to reproduce the large variety of particle ratios in gsi sis 2 a gev ni ni experiments. at the energies considered here, the exact strangeness conservation fully accounts for strange particle suppression, and no extra chemical factor is needed. [ on the exact conservation laws in thermal models ]"}
|
||||
{"query": "How good is the mean-field approximation at guessing the ground state features of molecular stuff?", "pos": "we present a model for molecular materials made up of polar and polarizable molecular units. a simple two state model is adopted for each molecular site and only classical intermolecular interactions are accounted for, neglecting any intermolecular overlap. the complex and interesting physics driven by interactions among polar and polarizable molecules becomes fairly transparent in the adopted model. collective effects are recognized in the large variation of the molecular polarity with supramolecular interactions, and cooperative behavior shows up with the appearance, in attractive lattices, of discontinuous charge crossovers. the mf approximation proves fairly accurate in the description of the gs properties of mm, including static linear and non - linear optical susceptibilities, apart from the region in the close proximity of the discontinuous charge crossover. sizeable deviations from the excitonic description are recognized both in the excitation spectrum and in linear and non - linear optical responses. new and interesting phenomena are recognized near the discontinuous charge crossover for non - centrosymmetric clusters, where the primary photoexcitation event corresponds to a multielectron transfer."}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "What is the effect of the LME Singapore Contract on trade dynamics?", "pos": "The London Metal Exchange's, LME,\ndecision to introduce a dollar-denominated aluminium contract,\nwith the Port of Singapore listed as a delivery point, is a\npositive move, physical traders and LME dealers said.\n Earlier this week the LME declared that a 99.70 pct minimum\npurity aluminium contract would commence trading on June 1,\n1987, alongside its long-established sterling-based 99.50 pct\ncontract.\n This is the LME's first dollar contract and non-European\ndelivery point, and the Board and Committee are looking at\nSingapore as a delivery point for other contracts.\n Trade sources said the LME's new contract will conform with\nexisting industry practice, where 99.70 standard re-melt\nmaterial, priced in dollars, is most commonly traded.\n The location of a warehouse in Singapore is also a positive\nmove by the LME, given its ideal location for Australian and\nJapanese traders, who would be able to place metal on to\nwarrant speedily and relatively inexpensively, they said.\n Hedging during the LME ring sessions becomes much simpler\nwith a dollar contract. At present pre-market trading is almost\nexclusively dollar-based, but currency conversions have to be\ndone during the sterling rings, they added.\n LME ring dealers said the new contract would match more\nclosely trade requirements and possibly alleviate some of the\nrecent wide backwardations.\n Very little physical business is now done in 99.50 pct\npurity metal, nearly all of which is produced in Eastern Bloc\ncountries, such as Romania.\n The Soviet Union also produces 99.50 pct, but has declined\nas an exporter recently, they said.\n Some dealers said the new 99.70 contract may suffer from\nliquidity problems initially, as business may continue to\ncentre on the present good ordinary brand (gob) contract, where\nthere are many holders of large short positions on the LME.\n But others said the new contract would soon attract trading\ninterest, given that much 99.70 metal has already been\nattracted to the LME's warehouses by backwardations.\n The LME also has a much more viable liquidity base for a\nnew contract, compared to the Comex market in New York, where\nhigh grade aluminium futures are not particularly active, they\nsaid.\n Thus, it seems likely that the sterling contract will\neventually lose trading interest and volumes will decline. Like\nstandard zinc, which was superseded by a high grade contract,\ngob aluminium will probably be replaced, although the process\nin this case may take longer, they added.\n Forming a new contract and establishing a Singapore\nwarehouse are constructive moves by the LME but backwardations,\nwhich make physical trading difficult, would not totally\ndisappear as a result, the trade sources said.\n These premiums for prompt metal have become a\nsemi-permanent feature over the last year, due to increased\nbusiness and volatility in traded options, and are presently\naround 50 stg.\n Increasingly large granting of option positions has been\ntaking place. When some of these are declared and exercised at\nthe end of the relevant month, physical tightness and squeezes\naround these dates are commonplace, they said.\n Listing Singapore as a delivery point allows Far Eastern\noperators to deliver aluminium into a LME warehouse instead of\nhaving to cover.\n But tightness and backwardations are seen continuing, even\nthough the LME's new option contracts widen the gap between the\ndeclaration and prompt dates.\n These will be due on the first and third Wednesday of the\nmonth, whereas at present most fall on the 20th and 25th.\n Backwardations will remain while operators continue to\ngrant options where potential tonnage to be delivered exceeds\naluminium stock levels, an LME option trader said.\n Reuter\n"}
|
||||
{"query": "Please provide the estimated quantity of the broad monetary aggregate designated as M-3, which encompasses the extensive range of financial assets held principally by households, as recorded in the month of February.", "pos": "South African year-on-year broadly\ndefined M-3 money supply growth slowed to 8.62 pct in January\nfrom 9.32 pct in December, Reserve Bank figures show.\n M-3 fell to 77.98 billion rand in January from 79.31\nbillion in December, while preliminary February figures show\nM-3 at 79.42 billion rand for a year-on-year rise of 10.63 pct.\n M-2 showed a rise of 5.09 pct for January at 55.68 billion\nrand after 4.30 pct in December, M-1 16.72 pct at 5.12 billion\nafter 12.80 pct and M-1A 22.79 pct at 14.30 billion rand after\n20.54 pct.\n REUTER\n"}
|
||||
{"query": "When did Reagan impose tariffs?", "pos": "The White House issued a\nlist of Japanese exports to covered by the 100 pct tariffs\nimposed by President Reagan.\n - Automatic data processing machines (1986 imports worth\n180 mln dlrs), including certain desk and lap models with\nmicroprocessor-based calculating mechanism capable of handling\nwords of at least 16-bits off the microprocessor;\n - Complete color television sets, with 18, 19 or 20 inch\nscreens (1986 imports 90 mln dlrs);\n - Power tools, including certain drills, percussion\nhammers, sanders, polishers, grinders.\n Reuter\n"}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "Which technique was employed to assess the blood pressure in Wistar rats subjected to various sodium intake regimens?", "pos": "Male Wistar rats were fed on normal- (0.5% Na(+); NS), high- (3.12% Na(+); HS),or low-sodium (0.06% Na(+); LS) diets for 3, 6, and 9 weeks after weaning. Blood pressure (BP) was measured using a computerized tail-cuff system. An intravenous insulin tolerance test (ivITT) was performed in fasted animals. At the end of each period, rats were killed and blood samples were collected for glucose and insulin determinations. The white adipose tissue (WAT) from abdominal and inguinal subcutaneous (SC) and periepididymal (PE) depots were weighed and processed for adipocyte isolation and measurement of in vitro rates of insulin-stimulated 2-deoxy-D-[(3)H]-glucose uptake (2DGU) and conversion of -[U-(14)C]-glucose into (14)CO(2)."}
|
||||
{"query": "How long were the kids treated with chemo for their stomach lymphoma?", "pos": "Only two patients, 5 and 12 years old, with primary gastric NHL were found. Upper gastroduodenal endoscopy detected an ulcer in the lesser curvature of the body of the stomach, in both cases. Endoscopy revealed a moderate chronic gastritis in the antrum of both patients that was H. pylori associated in one of them who also suffered from chronic gastritis. Biopsy specimens demonstrated infiltration by Burkitt lymphoma (BL). The two patients received chemotherapy for 6 months. Additionally, one of the two patients received a triple therapy regimen with bismuth, amoxicillin, and metronidazole for H. pylori. Fifteen and six years later they are in complete remission, free of symptoms."}
|
||||
{"query": "What are the correlations between the volume of tissue resected and the resulting clinical outcomes?", "pos": "Between May 2011 and April 2013, LSG was performed in 102 consecutive patients undergoing bariatric surgery. Two patients were excluded, and data from the remaining 100 patients were analyzed in this study. Patients were divided into three groups according to the following resected stomach volume: 700-1,200 mL (group A, n = 21), 1,200-1,700 mL (group B, n = 62), and>1,700 mL (group C, n = 17). Mean values were compared among the groups by analysis of variance."}
|
||||
File diff suppressed because one or more lines are too long
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "Corn on the cob boiling time?", "pos": "Corn on the Cob - Boiled In a large pot, enough to hold the corn, fill it with water to cover the corn (the corn should float). On a medium heat allow the pot of water to boil. Once the water is boiled, add in the corn into the pot and cover. Cook for 10-15 minutes depending on how soft you want your corn. Drain water and remove corn on the cob."}
|
||||
{"query": "Nitrous oxide is commonly used as an anesthetic or analgesic in medical and dental procedures.", "pos": "Nitrous oxide Nitrous oxide has significant medical uses, especially in surgery and dentistry, for its anaesthetic and analgesic effects. Its name laughing gas is due to the euphoric effects of inhaling it, a property that has led to its recreational use as a dissociative anaesthetic."}
|
||||
{"query": "At what temp do you start to roast?", "pos": "How long to cook 2.3 lb pork tenderloin in oven? Best Answer: For your seasoned pork loin, preheat your oven to 400 degrees F or (200C). Place the seasoned pork in the preheated oven and immediately turn the oven down to 350F (175C). Roast the pork loin or tenderloin for about 70-90 minutes or until it reaches an internal temperature of 145-150F (73-75C) degrees. If you prefer your pork cooked to medium well, cook it to an internal temperature of 155-160F (78-80C) degrees."}
|
||||
File diff suppressed because one or more lines are too long
|
|
@ -0,0 +1,3 @@
|
|||
{"query": "Identify the principal landmarks that are emblematic of the Baha'i religious tradition.", "pos": "House of Baha'u'llah, Baghdad\n\"Grieve not, O House of God, if the veil of thy sanctity be rent asunder by the infidels. God hath, in the world of creation, adorned thee with the jewel of His remembrance. Such an ornament no man can, at any time, profane. Towards thee the eyes of thy Lord shall, under all conditions, remain directed. He, verily, will incline His ear to the prayer of every one that visiteth thee, who will circle around thee, and calleth upon Him in thy name. He, in truth, is the Forgiving, the All-Merciful.\"\n(Gleanings from the Writings of Bahá’u’lláh, LVII, part 7)"}
|
||||
{"query": "Which type of healthcare professional should one consult regarding the sensation of tingling in the feet?", "pos": "“Tingly feet\" can be a sign of nerve loss. The nerves in the feet come from the lower back. Pressure or chemical change in the nerve can cause a tingling sensation in the feet. Any sensation that is out of the ordinary can be an early sign of neurologic or vascular problems. In addition to tingling, feet may feel numb or feel like they are \"falling asleep.\" There may also be a burning sensation in the feet.\nDiabetes is one of the most common medical conditions with which \"tingly feet\" can be associated. A thorough evaluation by a foot and ankle surgeon is advised to determine the cause of \"tingly feet.\"\nSee also Diabetic Peripheral Neuropathy."}
|
||||
{"query": "How big is the old Kaguru Basket?", "pos": "Home — Vintage Kaguru Basket from Tanzania - 15\" x 10.5\"\nVintage Kaguru Basket from Tanzania - 15\" x 10.5\"\nFrom Tanzania, East Africa, these baskets are used the same way all other similar baskets are used everywhere else in Africa. They are the primary vessel for storage and transportation of grain, fruit, vegetables and any other food item. Baskets of this type are often seen in markets containing food for sale there. The patina on the rims and side of this basket suggests it was handled often. There is residue of some sort on the interior surfaces which proves they were often in use in the way I have described.\nAfter 36 years of traveling in Africa and buying similar African items, we thought we had seen it all. These are truly amazing baskets, at least to us.\nThis basket measures 15\" x 10.5\" (38cm x 26.75cm)"}
|
||||
File diff suppressed because one or more lines are too long
|
|
@ -0,0 +1,65 @@
|
|||
from typing import Union, Tuple
|
||||
from air_benchmark import AIRBench
|
||||
|
||||
from FlagEmbedding.abc.evaluation import (
|
||||
AbsEvalRunner,
|
||||
EvalDenseRetriever, EvalReranker
|
||||
)
|
||||
|
||||
from .arguments import AIRBenchEvalArgs, AIRBenchEvalModelArgs
|
||||
|
||||
|
||||
class AIRBenchEvalRunner:
|
||||
"""
|
||||
Evaluation runner for AIR Bench.
|
||||
|
||||
Args:
|
||||
eval_args (AIRBenchEvalArgs): :class:AIRBenchEvalArgs object with the evaluation arguments.
|
||||
model_args (AIRBenchEvalModelArgs): :class:AIRBenchEvalModelArgs object with the model arguments.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
eval_args: AIRBenchEvalArgs,
|
||||
model_args: AIRBenchEvalModelArgs,
|
||||
):
|
||||
self.eval_args = eval_args
|
||||
self.model_args = model_args
|
||||
self.model_args.cache_dir = model_args.model_cache_dir
|
||||
|
||||
self.retriever, self.reranker = self.load_retriever_and_reranker()
|
||||
|
||||
def load_retriever_and_reranker(self) -> Tuple[EvalDenseRetriever, Union[EvalReranker, None]]:
|
||||
"""Load retriever and reranker for evaluation
|
||||
|
||||
Returns:
|
||||
Tuple[EvalDenseRetriever, Union[EvalReranker, None]]: A :class:EvalDenseRetriever object for retrieval, and a
|
||||
:class:EvalReranker object if reranker provided.
|
||||
"""
|
||||
embedder, reranker = AbsEvalRunner.get_models(self.model_args)
|
||||
retriever = EvalDenseRetriever(
|
||||
embedder,
|
||||
search_top_k=self.eval_args.search_top_k,
|
||||
overwrite=self.eval_args.overwrite
|
||||
)
|
||||
if reranker is not None:
|
||||
reranker = EvalReranker(reranker, rerank_top_k=self.eval_args.rerank_top_k)
|
||||
return retriever, reranker
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
Run the whole evaluation.
|
||||
"""
|
||||
evaluation = AIRBench(
|
||||
benchmark_version=self.eval_args.benchmark_version,
|
||||
task_types=self.eval_args.task_types,
|
||||
domains=self.eval_args.domains,
|
||||
languages=self.eval_args.languages,
|
||||
splits=self.eval_args.splits,
|
||||
cache_dir=self.eval_args.cache_dir,
|
||||
)
|
||||
evaluation.run(
|
||||
self.retriever,
|
||||
reranker=self.reranker,
|
||||
output_dir=self.eval_args.output_dir,
|
||||
overwrite=self.eval_args.overwrite,
|
||||
)
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
from FlagEmbedding.abc.evaluation import (
|
||||
AbsEvalModelArgs as BEIREvalModelArgs,
|
||||
)
|
||||
|
||||
from .data_loader import BEIREvalDataLoader
|
||||
from .arguments import BEIREvalArgs
|
||||
from .runner import BEIREvalRunner
|
||||
|
||||
__all__ = [
|
||||
"BEIREvalArgs",
|
||||
"BEIREvalModelArgs",
|
||||
"BEIREvalRunner",
|
||||
"BEIREvalDataLoader",
|
||||
]
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
from transformers import HfArgumentParser
|
||||
|
||||
from FlagEmbedding.evaluation.beir import (
|
||||
BEIREvalArgs, BEIREvalModelArgs,
|
||||
BEIREvalRunner
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = HfArgumentParser((
|
||||
BEIREvalArgs,
|
||||
BEIREvalModelArgs
|
||||
))
|
||||
|
||||
eval_args, model_args = parser.parse_args_into_dataclasses()
|
||||
eval_args: BEIREvalArgs
|
||||
model_args: BEIREvalModelArgs
|
||||
|
||||
runner = BEIREvalRunner(
|
||||
eval_args=eval_args,
|
||||
model_args=model_args
|
||||
)
|
||||
|
||||
runner.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
from dataclasses import dataclass, field
|
||||
|
||||
from FlagEmbedding.abc.evaluation.arguments import AbsEvalArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class BEIREvalArgs(AbsEvalArgs):
|
||||
"""
|
||||
Argument class for BEIR evaluation.
|
||||
"""
|
||||
use_special_instructions: bool = field(
|
||||
default=False, metadata={"help": "Whether to use specific instructions in `prompts.py` for evaluation. Default: False"}
|
||||
)
|
||||
|
|
@ -0,0 +1,471 @@
|
|||
import os
|
||||
import json
|
||||
import logging
|
||||
import datasets
|
||||
from tqdm import tqdm
|
||||
from typing import List, Optional
|
||||
from beir import util
|
||||
from beir.datasets.data_loader import GenericDataLoader
|
||||
|
||||
from FlagEmbedding.abc.evaluation import AbsEvalDataLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BEIREvalDataLoader(AbsEvalDataLoader):
|
||||
"""
|
||||
Data loader class for BEIR.
|
||||
"""
|
||||
def available_dataset_names(self) -> List[str]:
|
||||
"""
|
||||
Get the available dataset names.
|
||||
|
||||
Returns:
|
||||
List[str]: All the available dataset names.
|
||||
"""
|
||||
return ['arguana', 'climate-fever', 'cqadupstack', 'dbpedia-entity', 'fever', 'fiqa', 'hotpotqa', 'msmarco', 'nfcorpus', 'nq', 'quora', 'scidocs', 'scifact', 'trec-covid', 'webis-touche2020']
|
||||
|
||||
def available_sub_dataset_names(self, dataset_name: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Get the available sub-dataset names.
|
||||
|
||||
Args:
|
||||
dataset_name (Optional[str], optional): All the available sub-dataset names. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
List[str]: All the available sub-dataset names.
|
||||
"""
|
||||
if dataset_name == 'cqadupstack':
|
||||
return ['android', 'english', 'gaming', 'gis', 'mathematica', 'physics', 'programmers', 'stats', 'tex', 'unix', 'webmasters', 'wordpress']
|
||||
return None
|
||||
|
||||
def available_splits(self, dataset_name: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Get the avaialble splits.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Dataset name.
|
||||
|
||||
Returns:
|
||||
List[str]: All the available splits for the dataset.
|
||||
"""
|
||||
if dataset_name == 'msmarco':
|
||||
return ['dev']
|
||||
return ['test']
|
||||
|
||||
def _load_remote_corpus(
|
||||
self,
|
||||
dataset_name: str,
|
||||
sub_dataset_name: Optional[str] = None,
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Load the corpus dataset from HF.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the dataset.
|
||||
sub_dataset_name (Optional[str]): Name of the sub-dataset. Defaults to ``None``.
|
||||
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of corpus.
|
||||
"""
|
||||
if dataset_name != 'cqadupstack':
|
||||
corpus = datasets.load_dataset(
|
||||
'BeIR/{d}'.format(d=dataset_name),
|
||||
'corpus',
|
||||
trust_remote_code=True,
|
||||
cache_dir=self.cache_dir,
|
||||
download_mode=self.hf_download_mode
|
||||
)['corpus']
|
||||
|
||||
if save_dir is not None:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
save_path = os.path.join(save_dir, "corpus.jsonl")
|
||||
corpus_dict = {}
|
||||
with open(save_path, "w", encoding="utf-8") as f:
|
||||
for data in tqdm(corpus, desc="Loading and Saving corpus"):
|
||||
_data = {
|
||||
"id": data["_id"],
|
||||
"title": data["title"],
|
||||
"text": data["text"]
|
||||
}
|
||||
corpus_dict[data["_id"]] = {
|
||||
"title": data["title"],
|
||||
"text": data["text"]
|
||||
}
|
||||
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
logging.info(f"{self.eval_name} {dataset_name} corpus saved to {save_path}")
|
||||
else:
|
||||
corpus_dict = {data["docid"]: {"title": data["title"], "text": data["text"]} for data in tqdm(corpus, desc="Loading corpus")}
|
||||
else:
|
||||
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset_name)
|
||||
data_path = util.download_and_unzip(url, self.cache_dir)
|
||||
full_path = os.path.join(data_path, sub_dataset_name)
|
||||
corpus, _, _ = GenericDataLoader(data_folder=full_path).load(split="test")
|
||||
if save_dir is not None:
|
||||
new_save_dir = os.path.join(save_dir, sub_dataset_name)
|
||||
os.makedirs(new_save_dir, exist_ok=True)
|
||||
save_path = os.path.join(new_save_dir, "corpus.jsonl")
|
||||
corpus_dict = {}
|
||||
with open(save_path, "w", encoding="utf-8") as f:
|
||||
for _id in tqdm(corpus.keys(), desc="Loading corpus"):
|
||||
_data = {
|
||||
"id": _id,
|
||||
"title": corpus[_id]["title"],
|
||||
"text": corpus[_id]["text"]
|
||||
}
|
||||
corpus_dict[_id] = {
|
||||
"title": corpus[_id]["title"],
|
||||
"text": corpus[_id]["text"]
|
||||
}
|
||||
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
logging.info(f"{self.eval_name} {dataset_name} corpus saved to {save_path}")
|
||||
else:
|
||||
corpus_dict = {_id: {"title": corpus[_id]["title"], "text": corpus[_id]["text"]} for _id in tqdm(corpus.keys(), desc="Loading corpus")}
|
||||
return datasets.DatasetDict(corpus_dict)
|
||||
|
||||
def _load_remote_qrels(
|
||||
self,
|
||||
dataset_name: Optional[str] = None,
|
||||
sub_dataset_name: Optional[str] = None,
|
||||
split: str = 'dev',
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Load the qrels from HF.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the dataset.
|
||||
sub_dataset_name (Optional[str]): Name of the sub-dataset. Defaults to ``None``.
|
||||
split (str, optional): Split of the dataset. Defaults to ``'dev'``.
|
||||
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of qrel.
|
||||
"""
|
||||
if dataset_name != 'cqadupstack':
|
||||
qrels = datasets.load_dataset(
|
||||
'BeIR/{d}-qrels'.format(d=dataset_name),
|
||||
split=split if split != 'dev' else 'validation',
|
||||
trust_remote_code=True,
|
||||
cache_dir=self.cache_dir,
|
||||
download_mode=self.hf_download_mode
|
||||
)
|
||||
|
||||
if save_dir is not None:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
save_path = os.path.join(save_dir, f"{split}_qrels.jsonl")
|
||||
qrels_dict = {}
|
||||
with open(save_path, "w", encoding="utf-8") as f:
|
||||
for data in tqdm(qrels, desc="Loading and Saving qrels"):
|
||||
qid, docid, rel = str(data['query-id']), str(data['corpus-id']), int(data['score'])
|
||||
_data = {
|
||||
"qid": qid,
|
||||
"docid": docid,
|
||||
"relevance": rel
|
||||
}
|
||||
if qid not in qrels_dict:
|
||||
qrels_dict[qid] = {}
|
||||
qrels_dict[qid][docid] = rel
|
||||
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
logging.info(f"{self.eval_name} {dataset_name} qrels saved to {save_path}")
|
||||
else:
|
||||
qrels_dict = {}
|
||||
for data in tqdm(qrels, desc="Loading queries"):
|
||||
qid, docid, rel = str(data['query-id']), str(data['corpus-id']), int(data['score'])
|
||||
if qid not in qrels_dict:
|
||||
qrels_dict[qid] = {}
|
||||
qrels_dict[qid][docid] = rel
|
||||
else:
|
||||
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset_name)
|
||||
data_path = util.download_and_unzip(url, self.cache_dir)
|
||||
full_path = os.path.join(data_path, sub_dataset_name)
|
||||
_, _, qrels = GenericDataLoader(data_folder=full_path).load(split="test")
|
||||
if save_dir is not None:
|
||||
new_save_dir = os.path.join(save_dir, sub_dataset_name)
|
||||
os.makedirs(new_save_dir, exist_ok=True)
|
||||
save_path = os.path.join(new_save_dir, f"{split}_qrels.jsonl")
|
||||
qrels_dict = {}
|
||||
with open(save_path, "w", encoding="utf-8") as f:
|
||||
for qid in tqdm(qrels.keys(), desc="Loading and Saving qrels"):
|
||||
for docid in tqdm(qrels[qid].keys()):
|
||||
rel = int(qrels[qid][docid])
|
||||
_data = {
|
||||
"qid": qid,
|
||||
"docid": docid,
|
||||
"relevance": rel
|
||||
}
|
||||
if qid not in qrels_dict:
|
||||
qrels_dict[qid] = {}
|
||||
qrels_dict[qid][docid] = rel
|
||||
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
logging.info(f"{self.eval_name} {dataset_name} qrels saved to {save_path}")
|
||||
else:
|
||||
qrels_dict = {}
|
||||
for qid in tqdm(qrels.keys(), desc="Loading qrels"):
|
||||
for docid in tqdm(qrels[qid].keys()):
|
||||
rel = int(qrels[qid][docid])
|
||||
if qid not in qrels_dict:
|
||||
qrels_dict[qid] = {}
|
||||
qrels_dict[qid][docid] = rel
|
||||
return datasets.DatasetDict(qrels_dict)
|
||||
|
||||
def _load_remote_queries(
|
||||
self,
|
||||
dataset_name: Optional[str] = None,
|
||||
sub_dataset_name: Optional[str] = None,
|
||||
split: str = 'test',
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Load the queries from HF.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the dataset.
|
||||
sub_dataset_name (Optional[str]): Name of the sub-dataset. Defaults to ``None``.
|
||||
split (str, optional): Split of the dataset. Defaults to ``'dev'``.
|
||||
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of queries.
|
||||
"""
|
||||
qrels = self.load_qrels(dataset_name=dataset_name, sub_dataset_name=sub_dataset_name, split=split)
|
||||
|
||||
if dataset_name != 'cqadupstack':
|
||||
queries = datasets.load_dataset(
|
||||
'BeIR/{d}'.format(d=dataset_name),
|
||||
'queries',
|
||||
trust_remote_code=True,
|
||||
cache_dir=self.cache_dir,
|
||||
download_mode=self.hf_download_mode
|
||||
)['queries']
|
||||
|
||||
if save_dir is not None:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
save_path = os.path.join(save_dir, f"{split}_queries.jsonl")
|
||||
queries_dict = {}
|
||||
with open(save_path, "w", encoding="utf-8") as f:
|
||||
for data in tqdm(queries, desc="Loading and Saving queries"):
|
||||
qid, query = data['_id'], data['text']
|
||||
if qid not in qrels.keys(): continue
|
||||
_data = {
|
||||
"id": qid,
|
||||
"text": query
|
||||
}
|
||||
queries_dict[qid] = query
|
||||
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
logging.info(f"{self.eval_name} {dataset_name} queries saved to {save_path}")
|
||||
else:
|
||||
queries_dict = {}
|
||||
for data in tqdm(queries, desc="Loading queries"):
|
||||
qid, query = data['_id'], data['text']
|
||||
if qid not in qrels.keys(): continue
|
||||
queries_dict[qid] = query
|
||||
else:
|
||||
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset_name)
|
||||
data_path = util.download_and_unzip(url, self.cache_dir)
|
||||
full_path = os.path.join(data_path, sub_dataset_name)
|
||||
_, queries, _ = GenericDataLoader(data_folder=full_path).load(split="test")
|
||||
if save_dir is not None:
|
||||
new_save_dir = os.path.join(save_dir, sub_dataset_name)
|
||||
os.makedirs(new_save_dir, exist_ok=True)
|
||||
save_path = os.path.join(new_save_dir, f"{split}_queries.jsonl")
|
||||
queries_dict = {}
|
||||
with open(save_path, "w", encoding="utf-8") as f:
|
||||
for qid in tqdm(queries.keys(), desc="Loading and Saving queries"):
|
||||
query = queries[qid]
|
||||
if qid not in qrels.keys(): continue
|
||||
_data = {
|
||||
"id": qid,
|
||||
"text": query
|
||||
}
|
||||
queries_dict[qid] = query
|
||||
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
logging.info(f"{self.eval_name} {dataset_name} queries saved to {save_path}")
|
||||
else:
|
||||
queries_dict = {}
|
||||
for qid in tqdm(queries.keys(), desc="Loading queries"):
|
||||
query = queries[qid]
|
||||
if qid not in qrels.keys(): continue
|
||||
queries_dict[qid] = query
|
||||
return datasets.DatasetDict(queries_dict)
|
||||
|
||||
def load_corpus(self, dataset_name: Optional[str] = None, sub_dataset_name: Optional[str] = None) -> datasets.DatasetDict:
|
||||
"""Load the corpus from the dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
|
||||
sub_dataset_name (Optional[str], optional): Name of the sub-dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
|
||||
"""
|
||||
if self.dataset_dir is not None:
|
||||
if dataset_name is None:
|
||||
save_dir = self.dataset_dir
|
||||
else:
|
||||
save_dir = os.path.join(self.dataset_dir, dataset_name)
|
||||
return self._load_local_corpus(save_dir, dataset_name=dataset_name, sub_dataset_name=sub_dataset_name)
|
||||
else:
|
||||
return self._load_remote_corpus(dataset_name=dataset_name, sub_dataset_name=sub_dataset_name)
|
||||
|
||||
def load_qrels(self, dataset_name: Optional[str] = None, sub_dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
|
||||
"""Load the qrels from the dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
|
||||
sub_dataset_name (Optional[str], optional): Name of the sub-dataset. Defaults to ``None``.
|
||||
split (str, optional): The split to load relevance from. Defaults to ``'test'``.
|
||||
|
||||
Raises:
|
||||
ValueError
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: A dict of relevance of query and document.
|
||||
"""
|
||||
if self.dataset_dir is not None:
|
||||
if dataset_name is None:
|
||||
save_dir = self.dataset_dir
|
||||
else:
|
||||
checked_dataset_names = self.check_dataset_names(dataset_name)
|
||||
if len(checked_dataset_names) == 0:
|
||||
raise ValueError(f"Dataset name {dataset_name} not found in the dataset.")
|
||||
dataset_name = checked_dataset_names[0]
|
||||
|
||||
save_dir = os.path.join(self.dataset_dir, dataset_name)
|
||||
|
||||
return self._load_local_qrels(save_dir, dataset_name=dataset_name, sub_dataset_name=sub_dataset_name, split=split)
|
||||
else:
|
||||
return self._load_remote_qrels(dataset_name=dataset_name, sub_dataset_name=sub_dataset_name, split=split)
|
||||
|
||||
def load_queries(self, dataset_name: Optional[str] = None, sub_dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
|
||||
"""Load the queries from the dataset.
|
||||
|
||||
Args:
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
|
||||
sub_dataset_name (Optional[str], optional): Name of the sub-dataset. Defaults to ``None``.
|
||||
split (str, optional): The split to load queries from. Defaults to ``'test'``.
|
||||
|
||||
Raises:
|
||||
ValueError
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: A dict of queries with id as key, query text as value.
|
||||
"""
|
||||
if self.dataset_dir is not None:
|
||||
if dataset_name is None:
|
||||
save_dir = self.dataset_dir
|
||||
else:
|
||||
checked_dataset_names = self.check_dataset_names(dataset_name)
|
||||
if len(checked_dataset_names) == 0:
|
||||
raise ValueError(f"Dataset name {dataset_name} not found in the dataset.")
|
||||
dataset_name = checked_dataset_names[0]
|
||||
|
||||
save_dir = os.path.join(self.dataset_dir, dataset_name)
|
||||
|
||||
return self._load_local_queries(save_dir, dataset_name=dataset_name, sub_dataset_name=sub_dataset_name, split=split)
|
||||
else:
|
||||
return self._load_remote_queries(dataset_name=dataset_name, sub_dataset_name=sub_dataset_name, split=split)
|
||||
|
||||
def _load_local_corpus(self, save_dir: str, dataset_name: Optional[str] = None, sub_dataset_name: Optional[str] = None) -> datasets.DatasetDict:
|
||||
"""Load corpus from local dataset.
|
||||
|
||||
Args:
|
||||
save_dir (str): Path to save the loaded corpus.
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
|
||||
sub_dataset_name (Optional[str], optional): Name of the sub-dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
|
||||
"""
|
||||
if sub_dataset_name is None:
|
||||
corpus_path = os.path.join(save_dir, 'corpus.jsonl')
|
||||
else:
|
||||
corpus_path = os.path.join(save_dir, sub_dataset_name, 'corpus.jsonl')
|
||||
if self.force_redownload or not os.path.exists(corpus_path):
|
||||
logger.warning(f"Corpus not found in {corpus_path}. Trying to download the corpus from the remote and save it to {save_dir}.")
|
||||
return self._load_remote_corpus(dataset_name=dataset_name, save_dir=save_dir, sub_dataset_name=sub_dataset_name)
|
||||
else:
|
||||
if sub_dataset_name is not None:
|
||||
save_dir = os.path.join(save_dir, sub_dataset_name)
|
||||
corpus_data = datasets.load_dataset('json', data_files=corpus_path, cache_dir=self.cache_dir)['train']
|
||||
|
||||
corpus = {}
|
||||
for e in corpus_data:
|
||||
corpus[e['id']] = {'title': e.get('title', ""), 'text': e['text']}
|
||||
|
||||
return datasets.DatasetDict(corpus)
|
||||
|
||||
def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, sub_dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
|
||||
"""Load relevance from local dataset.
|
||||
|
||||
Args:
|
||||
save_dir (str): Path to save the loaded relevance.
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
|
||||
sub_dataset_name (Optional[str], optional): Name of the sub-dataset. Defaults to ``None``.
|
||||
split (str, optional): Split to load from the local dataset. Defaults to ``'test'``.
|
||||
|
||||
Raises:
|
||||
ValueError
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: A dict of relevance of query and document.
|
||||
"""
|
||||
checked_split = self.check_splits(split, dataset_name=dataset_name)
|
||||
if len(checked_split) == 0:
|
||||
raise ValueError(f"Split {split} not found in the dataset.")
|
||||
split = checked_split[0]
|
||||
|
||||
if sub_dataset_name is None:
|
||||
qrels_path = os.path.join(save_dir, f"{split}_qrels.jsonl")
|
||||
else:
|
||||
qrels_path = os.path.join(save_dir, sub_dataset_name, f"{split}_qrels.jsonl")
|
||||
if self.force_redownload or not os.path.exists(qrels_path):
|
||||
logger.warning(f"Qrels not found in {qrels_path}. Trying to download the qrels from the remote and save it to {save_dir}.")
|
||||
return self._load_remote_qrels(dataset_name=dataset_name, split=split, sub_dataset_name=sub_dataset_name, save_dir=save_dir)
|
||||
else:
|
||||
if sub_dataset_name is not None:
|
||||
save_dir = os.path.join(save_dir, sub_dataset_name)
|
||||
qrels_data = datasets.load_dataset('json', data_files=qrels_path, cache_dir=self.cache_dir)['train']
|
||||
|
||||
qrels = {}
|
||||
for data in qrels_data:
|
||||
qid = data['qid']
|
||||
if qid not in qrels:
|
||||
qrels[qid] = {}
|
||||
qrels[qid][data['docid']] = data['relevance']
|
||||
|
||||
return datasets.DatasetDict(qrels)
|
||||
|
||||
def _load_local_queries(self, save_dir: str, dataset_name: Optional[str] = None, sub_dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
|
||||
"""Load queries from local dataset.
|
||||
|
||||
Args:
|
||||
save_dir (str): Path to save the loaded queries.
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
|
||||
sub_dataset_name (Optional[str], optional): Name of the sub-dataset. Defaults to ``None``.
|
||||
split (str, optional): Split to load from the local dataset. Defaults to ``'test'``.
|
||||
|
||||
Raises:
|
||||
ValueError
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: A dict of queries with id as key, query text as value.
|
||||
"""
|
||||
checked_split = self.check_splits(split, dataset_name=dataset_name)
|
||||
if len(checked_split) == 0:
|
||||
raise ValueError(f"Split {split} not found in the dataset.")
|
||||
split = checked_split[0]
|
||||
|
||||
if sub_dataset_name is None:
|
||||
queries_path = os.path.join(save_dir, f"{split}_queries.jsonl")
|
||||
else:
|
||||
queries_path = os.path.join(save_dir, sub_dataset_name, f"{split}_queries.jsonl")
|
||||
if self.force_redownload or not os.path.exists(queries_path):
|
||||
logger.warning(f"Queries not found in {queries_path}. Trying to download the queries from the remote and save it to {save_dir}.")
|
||||
return self._load_remote_queries(dataset_name=dataset_name, split=split, sub_dataset_name=sub_dataset_name, save_dir=save_dir)
|
||||
else:
|
||||
if sub_dataset_name is not None:
|
||||
save_dir = os.path.join(save_dir, sub_dataset_name)
|
||||
queries_data = datasets.load_dataset('json', data_files=queries_path, cache_dir=self.cache_dir)['train']
|
||||
|
||||
queries = {e['id']: e['text'] for e in queries_data}
|
||||
return datasets.DatasetDict(queries)
|
||||
|
|
@ -0,0 +1,454 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, Optional, List, Union
|
||||
|
||||
from FlagEmbedding.abc.evaluation import AbsEvaluator, EvalRetriever, EvalReranker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BEIREvaluator(AbsEvaluator):
|
||||
"""
|
||||
Evaluator class of BEIR
|
||||
"""
|
||||
def check_data_info(
|
||||
self,
|
||||
data_info: Dict[str, str],
|
||||
model_name: str,
|
||||
reranker_name: str,
|
||||
split: str,
|
||||
dataset_name: Optional[str] = None,
|
||||
sub_dataset_name: Optional[str] = None,
|
||||
):
|
||||
"""Check the validity of data info.
|
||||
|
||||
Args:
|
||||
data_info (Dict[str, str]): The loaded data info to be check.
|
||||
model_name (str): Name of model used.
|
||||
reranker_name (str): Name of reranker used.
|
||||
split (str): Split used in searching.
|
||||
dataset_name (Optional[str], optional): Name of dataset used. Defaults to None.
|
||||
sub_dataset_name (Optional[str], optional): Name of the sub-dataset. Defaults to ``None``.
|
||||
|
||||
Raises:
|
||||
ValueError: eval_name mismatch
|
||||
ValueError: model_name or reranker_name mismatch
|
||||
ValueError: split mismatch
|
||||
ValueError: dataset_name mismatch
|
||||
ValueError: sub_dataset_name mismatch
|
||||
"""
|
||||
if data_info["eval_name"] != self.eval_name:
|
||||
raise ValueError(
|
||||
f'eval_name mismatch: {data_info["eval_name"]} vs {self.eval_name}'
|
||||
)
|
||||
if (
|
||||
data_info["model_name"] != model_name
|
||||
or data_info["reranker_name"] != reranker_name
|
||||
):
|
||||
raise ValueError(
|
||||
f'model_name or reranker_name mismatch: {data_info["model_name"]} vs {model_name} or {data_info["reranker_name"]} vs {reranker_name}'
|
||||
)
|
||||
if (data_info["split"] != split):
|
||||
raise ValueError(
|
||||
f'split mismatch: {data_info["split"]} vs {split}'
|
||||
)
|
||||
if dataset_name is not None and data_info["dataset_name"] != dataset_name:
|
||||
raise ValueError(
|
||||
f'dataset_name mismatch: {data_info["dataset_name"]} vs {dataset_name}'
|
||||
)
|
||||
if sub_dataset_name is not None and data_info["sub_dataset_name"] != sub_dataset_name:
|
||||
raise ValueError(
|
||||
f'sub_dataset_name mismatch: {data_info["sub_dataset_name"]} vs {sub_dataset_name}'
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
splits: Union[str, List[str]],
|
||||
search_results_save_dir: str,
|
||||
retriever: EvalRetriever,
|
||||
reranker: Optional[EvalReranker] = None,
|
||||
corpus_embd_save_dir: Optional[str] = None,
|
||||
ignore_identical_ids: bool = False,
|
||||
k_values: List[int] = [1, 3, 5, 10, 100, 1000],
|
||||
dataset_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
sub_dataset_name = None
|
||||
sub_dataset_names = self.data_loader.available_sub_dataset_names(dataset_name=dataset_name)
|
||||
# Check Splits
|
||||
checked_splits = self.data_loader.check_splits(splits, dataset_name=dataset_name)
|
||||
if len(checked_splits) == 0:
|
||||
logger.warning(f"{splits} not found in the dataset. Skipping evaluation.")
|
||||
return
|
||||
splits = checked_splits
|
||||
|
||||
if sub_dataset_names is None:
|
||||
if dataset_name is not None:
|
||||
save_name = f"{dataset_name}-" + "{split}.json"
|
||||
if corpus_embd_save_dir is not None:
|
||||
corpus_embd_save_dir = os.path.join(corpus_embd_save_dir, str(retriever), dataset_name)
|
||||
else:
|
||||
save_name = "{split}.json"
|
||||
|
||||
# Retrieval Stage
|
||||
no_reranker_search_results_save_dir = os.path.join(
|
||||
search_results_save_dir, str(retriever), "NoReranker"
|
||||
)
|
||||
os.makedirs(no_reranker_search_results_save_dir, exist_ok=True)
|
||||
|
||||
flag = False
|
||||
for split in splits:
|
||||
split_no_reranker_search_results_save_path = os.path.join(
|
||||
no_reranker_search_results_save_dir, save_name.format(split=split)
|
||||
)
|
||||
if not os.path.exists(split_no_reranker_search_results_save_path) or self.overwrite:
|
||||
flag = True
|
||||
break
|
||||
|
||||
no_reranker_search_results_dict = {}
|
||||
if flag:
|
||||
corpus = self.data_loader.load_corpus(dataset_name=dataset_name)
|
||||
|
||||
queries_dict = {
|
||||
split: self.data_loader.load_queries(dataset_name=dataset_name, split=split)
|
||||
for split in splits
|
||||
}
|
||||
|
||||
all_queries = {}
|
||||
for _, split_queries in queries_dict.items():
|
||||
all_queries.update(split_queries)
|
||||
|
||||
all_no_reranker_search_results = retriever(
|
||||
corpus=corpus,
|
||||
queries=all_queries,
|
||||
corpus_embd_save_dir=corpus_embd_save_dir,
|
||||
ignore_identical_ids=ignore_identical_ids,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
for split in splits:
|
||||
split_queries = queries_dict[split]
|
||||
no_reranker_search_results_dict[split] = {
|
||||
qid: all_no_reranker_search_results[qid] for qid in split_queries
|
||||
}
|
||||
split_no_reranker_search_results_save_path = os.path.join(
|
||||
no_reranker_search_results_save_dir, save_name.format(split=split)
|
||||
)
|
||||
self.save_search_results(
|
||||
eval_name=self.eval_name,
|
||||
model_name=str(retriever),
|
||||
reranker_name="NoReranker",
|
||||
search_results=no_reranker_search_results_dict[split],
|
||||
output_path=split_no_reranker_search_results_save_path,
|
||||
split=split,
|
||||
dataset_name=dataset_name,
|
||||
sub_dataset_name=sub_dataset_name,
|
||||
)
|
||||
else:
|
||||
for split in splits:
|
||||
split_no_reranker_search_results_save_path = os.path.join(
|
||||
no_reranker_search_results_save_dir, save_name.format(split=split)
|
||||
)
|
||||
data_info, search_results = self.load_search_results(split_no_reranker_search_results_save_path)
|
||||
|
||||
self.check_data_info(
|
||||
data_info=data_info,
|
||||
model_name=str(retriever),
|
||||
reranker_name="NoReranker",
|
||||
split=split,
|
||||
dataset_name=dataset_name,
|
||||
sub_dataset_name=sub_dataset_name,
|
||||
)
|
||||
no_reranker_search_results_dict[split] = search_results
|
||||
retriever.stop_multi_process_pool()
|
||||
eval_results_save_path = os.path.join(no_reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
|
||||
if not os.path.exists(eval_results_save_path) or self.overwrite or flag:
|
||||
retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir, k_values=k_values)
|
||||
self.output_eval_results_to_json(retriever_eval_results, eval_results_save_path)
|
||||
|
||||
# Reranking Stage
|
||||
if reranker is not None:
|
||||
reranker_search_results_save_dir = os.path.join(
|
||||
search_results_save_dir, str(retriever), str(reranker)
|
||||
)
|
||||
os.makedirs(reranker_search_results_save_dir, exist_ok=True)
|
||||
|
||||
corpus = self.data_loader.load_corpus(dataset_name=dataset_name)
|
||||
|
||||
queries_dict = {
|
||||
split: self.data_loader.load_queries(dataset_name=dataset_name, split=split)
|
||||
for split in splits
|
||||
}
|
||||
|
||||
flag = False
|
||||
for split in splits:
|
||||
rerank_search_results_save_path = os.path.join(
|
||||
reranker_search_results_save_dir, save_name.format(split=split)
|
||||
)
|
||||
|
||||
if os.path.exists(rerank_search_results_save_path) and not self.overwrite:
|
||||
continue
|
||||
|
||||
flag = True
|
||||
rerank_search_results = reranker(
|
||||
corpus=corpus,
|
||||
queries=queries_dict[split],
|
||||
search_results=no_reranker_search_results_dict[split],
|
||||
ignore_identical_ids=ignore_identical_ids,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.save_search_results(
|
||||
eval_name=self.eval_name,
|
||||
model_name=str(retriever),
|
||||
reranker_name=str(reranker),
|
||||
search_results=rerank_search_results,
|
||||
output_path=rerank_search_results_save_path,
|
||||
split=split,
|
||||
dataset_name=dataset_name,
|
||||
sub_dataset_name=sub_dataset_name,
|
||||
)
|
||||
eval_results_save_path = os.path.join(reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
|
||||
if not os.path.exists(eval_results_save_path) or self.overwrite or flag:
|
||||
reranker_eval_results = self.evaluate_results(reranker_search_results_save_dir, k_values=k_values)
|
||||
self.output_eval_results_to_json(reranker_eval_results, eval_results_save_path)
|
||||
else:
|
||||
for sub_dataset_name in sub_dataset_names:
|
||||
if dataset_name is not None:
|
||||
save_name = f"{dataset_name}-{sub_dataset_name}-" + "{split}.json"
|
||||
if corpus_embd_save_dir is not None:
|
||||
corpus_embd_save_dir = os.path.join(corpus_embd_save_dir, str(retriever), dataset_name, sub_dataset_name)
|
||||
else:
|
||||
save_name = f"{sub_dataset_name}-" + "{split}.json"
|
||||
|
||||
# Retrieval Stage
|
||||
no_reranker_search_results_save_dir = os.path.join(
|
||||
search_results_save_dir, str(retriever), "NoReranker"
|
||||
)
|
||||
os.makedirs(no_reranker_search_results_save_dir, exist_ok=True)
|
||||
|
||||
flag = False
|
||||
for split in splits:
|
||||
split_no_reranker_search_results_save_path = os.path.join(
|
||||
no_reranker_search_results_save_dir, save_name.format(split=split)
|
||||
)
|
||||
if not os.path.exists(split_no_reranker_search_results_save_path) or self.overwrite:
|
||||
flag = True
|
||||
break
|
||||
|
||||
no_reranker_search_results_dict = {}
|
||||
if flag:
|
||||
corpus = self.data_loader.load_corpus(dataset_name=dataset_name, sub_dataset_name=sub_dataset_name)
|
||||
|
||||
queries_dict = {
|
||||
split: self.data_loader.load_queries(dataset_name=dataset_name, sub_dataset_name=sub_dataset_name, split=split)
|
||||
for split in splits
|
||||
}
|
||||
|
||||
all_queries = {}
|
||||
for _, split_queries in queries_dict.items():
|
||||
all_queries.update(split_queries)
|
||||
|
||||
all_no_reranker_search_results = retriever(
|
||||
corpus=corpus,
|
||||
queries=all_queries,
|
||||
corpus_embd_save_dir=corpus_embd_save_dir,
|
||||
ignore_identical_ids=ignore_identical_ids,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
for split in splits:
|
||||
split_queries = queries_dict[split]
|
||||
no_reranker_search_results_dict[split] = {
|
||||
qid: all_no_reranker_search_results[qid] for qid in split_queries
|
||||
}
|
||||
split_no_reranker_search_results_save_path = os.path.join(
|
||||
no_reranker_search_results_save_dir, save_name.format(split=split)
|
||||
)
|
||||
|
||||
self.save_search_results(
|
||||
eval_name=self.eval_name,
|
||||
model_name=str(retriever),
|
||||
reranker_name="NoReranker",
|
||||
search_results=no_reranker_search_results_dict[split],
|
||||
output_path=split_no_reranker_search_results_save_path,
|
||||
split=split,
|
||||
dataset_name=dataset_name,
|
||||
sub_dataset_name=sub_dataset_name,
|
||||
)
|
||||
else:
|
||||
for split in splits:
|
||||
split_no_reranker_search_results_save_path = os.path.join(
|
||||
no_reranker_search_results_save_dir, save_name.format(split=split)
|
||||
)
|
||||
data_info, search_results = self.load_search_results(split_no_reranker_search_results_save_path)
|
||||
|
||||
self.check_data_info(
|
||||
data_info=data_info,
|
||||
model_name=str(retriever),
|
||||
reranker_name="NoReranker",
|
||||
split=split,
|
||||
dataset_name=dataset_name,
|
||||
sub_dataset_name=sub_dataset_name,
|
||||
)
|
||||
no_reranker_search_results_dict[split] = search_results
|
||||
eval_results_save_path = os.path.join(no_reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
|
||||
if not os.path.exists(eval_results_save_path) or self.overwrite or flag:
|
||||
retriever_eval_results = self.evaluate_results(no_reranker_search_results_save_dir, k_values=k_values)
|
||||
self.output_eval_results_to_json(retriever_eval_results, eval_results_save_path)
|
||||
|
||||
# Reranking Stage
|
||||
if reranker is not None:
|
||||
reranker_search_results_save_dir = os.path.join(
|
||||
search_results_save_dir, str(retriever), str(reranker)
|
||||
)
|
||||
os.makedirs(reranker_search_results_save_dir, exist_ok=True)
|
||||
|
||||
corpus = self.data_loader.load_corpus(dataset_name=dataset_name, sub_dataset_name=sub_dataset_name)
|
||||
|
||||
queries_dict = {
|
||||
split: self.data_loader.load_queries(dataset_name=dataset_name, sub_dataset_name=sub_dataset_name, split=split)
|
||||
for split in splits
|
||||
}
|
||||
|
||||
flag = False
|
||||
for split in splits:
|
||||
rerank_search_results_save_path = os.path.join(
|
||||
reranker_search_results_save_dir, save_name.format(split=split)
|
||||
)
|
||||
|
||||
if os.path.exists(rerank_search_results_save_path) and not self.overwrite:
|
||||
continue
|
||||
|
||||
flag = True
|
||||
rerank_search_results = reranker(
|
||||
corpus=corpus,
|
||||
queries=queries_dict[split],
|
||||
search_results=no_reranker_search_results_dict[split],
|
||||
ignore_identical_ids=ignore_identical_ids,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.save_search_results(
|
||||
eval_name=self.eval_name,
|
||||
model_name=str(retriever),
|
||||
reranker_name=str(reranker),
|
||||
search_results=rerank_search_results,
|
||||
output_path=rerank_search_results_save_path,
|
||||
split=split,
|
||||
dataset_name=dataset_name,
|
||||
sub_dataset_name=sub_dataset_name,
|
||||
)
|
||||
eval_results_save_path = os.path.join(reranker_search_results_save_dir, 'EVAL', 'eval_results.json')
|
||||
if not os.path.exists(eval_results_save_path) or self.overwrite or flag:
|
||||
reranker_eval_results = self.evaluate_results(reranker_search_results_save_dir, k_values=k_values)
|
||||
self.output_eval_results_to_json(reranker_eval_results, eval_results_save_path)
|
||||
if reranker is not None:
|
||||
reranker.stop_multi_process_pool()
|
||||
|
||||
def evaluate_results(
|
||||
self,
|
||||
search_results_save_dir: str,
|
||||
k_values: List[int] = [1, 3, 5, 10, 100, 1000]
|
||||
):
|
||||
"""Compute metrics according to the results in the directory.
|
||||
|
||||
Args:
|
||||
search_results_save_dir (str): Path to the search results.
|
||||
k_values (List[int], optional): Cutoffs. Defaults to :data:`[1, 3, 5, 10, 100, 1000]`.
|
||||
|
||||
Returns:
|
||||
dict: Evaluation results.
|
||||
"""
|
||||
eval_results_dict = {}
|
||||
cqadupstack_results = None
|
||||
cqadupstack_num = 0
|
||||
|
||||
for file in os.listdir(search_results_save_dir):
|
||||
if not file.endswith('.json'):
|
||||
continue
|
||||
|
||||
file_path = os.path.join(search_results_save_dir, file)
|
||||
data_info, search_results = self.load_search_results(file_path)
|
||||
|
||||
_eval_name = data_info['eval_name']
|
||||
assert _eval_name == self.eval_name, f'Mismatch eval_name: {_eval_name} vs {self.eval_name} in {file_path}'
|
||||
|
||||
split = data_info['split']
|
||||
dataset_name = data_info.get('dataset_name', None)
|
||||
sub_dataset_name = data_info.get('sub_dataset_name', None)
|
||||
qrels = self.data_loader.load_qrels(dataset_name=dataset_name, sub_dataset_name=sub_dataset_name, split=split)
|
||||
|
||||
eval_results = self.compute_metrics(
|
||||
qrels=qrels,
|
||||
search_results=search_results,
|
||||
k_values=k_values
|
||||
)
|
||||
|
||||
if dataset_name is not None:
|
||||
if sub_dataset_name is None:
|
||||
key = f"{dataset_name}-{split}"
|
||||
else:
|
||||
key = f"{dataset_name}-{sub_dataset_name}-{split}"
|
||||
else:
|
||||
if sub_dataset_name is None:
|
||||
key = split
|
||||
else:
|
||||
key = f"{sub_dataset_name}-{split}"
|
||||
if sub_dataset_name is None:
|
||||
eval_results_dict[key] = eval_results
|
||||
else:
|
||||
if cqadupstack_results is None:
|
||||
cqadupstack_results = eval_results
|
||||
cqadupstack_num += 1
|
||||
else:
|
||||
for k, v in eval_results.items():
|
||||
cqadupstack_results[k] += v
|
||||
cqadupstack_num += 1
|
||||
|
||||
if cqadupstack_num > 0:
|
||||
for k in cqadupstack_results.keys():
|
||||
cqadupstack_results[k] /= cqadupstack_num
|
||||
eval_results_dict['cqadupstack-test'] = cqadupstack_results
|
||||
|
||||
return eval_results_dict
|
||||
|
||||
def save_search_results(
|
||||
self,
|
||||
eval_name: str,
|
||||
model_name: str,
|
||||
reranker_name: str,
|
||||
search_results: Dict[str, Dict[str, float]],
|
||||
output_path: str,
|
||||
split: str,
|
||||
dataset_name: Optional[str] = None,
|
||||
sub_dataset_name: Optional[str] = None,
|
||||
):
|
||||
"""Save the metadata and search results into a file.
|
||||
|
||||
Args:
|
||||
eval_name (str): The experiment name of current evaluation.
|
||||
model_name (str): Name of model used.
|
||||
reranker_name (str): Name of reranker used.
|
||||
search_results (Dict[str, Dict[str, float]]): Dictionary of search results.
|
||||
output_path (str): Output path to write the results.
|
||||
split (str): Split used in searching.
|
||||
dataset_name (Optional[str], optional): Name of dataset used. Defaults to ``None``.
|
||||
sub_dataset_name (Optional[str], optional): Name of the sub-dataset. Defaults to ``None``.
|
||||
"""
|
||||
data = {
|
||||
"eval_name": eval_name,
|
||||
"model_name": model_name,
|
||||
"reranker_name": reranker_name,
|
||||
"split": split,
|
||||
"dataset_name": dataset_name,
|
||||
"sub_dataset_name": sub_dataset_name,
|
||||
"search_results": search_results,
|
||||
}
|
||||
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=4)
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
BEIRInstructions = {
|
||||
'dbpedia-entity': 'Given a query, retrieve relevant entity descriptions from DBPedia.',
|
||||
'arguana': 'Given a claim, find documents that refute the claim.',
|
||||
'climate-fever': 'Given a claim about climate change, retrieve documents that support or refute the claim.',
|
||||
'cqadupstack': 'Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question.',
|
||||
'fever': 'Given a claim, retrieve documents that support or refute the claim.',
|
||||
'fiqa': 'Given a financial question, retrieve user replies that best answer the question.',
|
||||
'hotpotqa': 'Given a multi-hop question, retrieve documents that can help answer the question.',
|
||||
'msmarco': 'Given a web search query, retrieve relevant passages that answer the query.',
|
||||
'nfcorpus': 'Given a question, retrieve relevant documents that best answer the question.',
|
||||
'nq': 'Given a question, retrieve Wikipedia passages that answer the question.',
|
||||
'quora': 'Given a question, retrieve questions that are semantically equivalent to the given question.',
|
||||
'scidocs': 'Given a scientific paper title, retrieve paper abstracts that are cited by the given paper.',
|
||||
'scifact': 'Given a scientific claim, retrieve documents that support or refute the claim.',
|
||||
'webis-touche2020': 'Given a question, retrieve detailed and persuasive arguments that answer the question.',
|
||||
'trec-covid': 'Given a query on COVID-19, retrieve documents that answer the query.',
|
||||
}
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
import logging
|
||||
from FlagEmbedding.abc.evaluation import AbsEvalRunner
|
||||
|
||||
from .data_loader import BEIREvalDataLoader
|
||||
from .prompts import BEIRInstructions
|
||||
from .evaluator import BEIREvaluator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BEIREvalRunner(AbsEvalRunner):
|
||||
"""
|
||||
Runner class of BEIR evaluation.
|
||||
"""
|
||||
def run(self):
|
||||
"""
|
||||
Run the whole evaluation.
|
||||
"""
|
||||
if self.eval_args.dataset_names is None:
|
||||
dataset_names = self.data_loader.available_dataset_names()
|
||||
else:
|
||||
dataset_names = self.data_loader.check_dataset_names(self.eval_args.dataset_names)
|
||||
|
||||
if len(dataset_names) == 0:
|
||||
logger.info(f"Running {self.eval_args.eval_name} evaluation on the default dataset.")
|
||||
self.evaluator(
|
||||
splits=self.eval_args.splits,
|
||||
search_results_save_dir=self.eval_args.output_dir,
|
||||
retriever=self.retriever,
|
||||
reranker=self.reranker,
|
||||
corpus_embd_save_dir=self.eval_args.corpus_embd_save_dir,
|
||||
ignore_identical_ids=self.eval_args.ignore_identical_ids,
|
||||
k_values=self.eval_args.k_values
|
||||
)
|
||||
logger.info(f"{self.eval_args.eval_name} evaluation completed.")
|
||||
else:
|
||||
logger.info(f"Running {self.eval_args.eval_name} evaluation on the following dataset names: {dataset_names}")
|
||||
for dataset_name in dataset_names:
|
||||
if self.eval_args.use_special_instructions:
|
||||
self.retriever.stop_multi_process_pool()
|
||||
self.retriever.embedder.query_instruction_for_retrieval = BEIRInstructions[dataset_name]
|
||||
logger.info(f"Running {self.eval_args.eval_name} evaluation on: {dataset_name}")
|
||||
self.evaluator(
|
||||
splits=self.eval_args.splits,
|
||||
search_results_save_dir=self.eval_args.output_dir,
|
||||
retriever=self.retriever,
|
||||
reranker=self.reranker,
|
||||
corpus_embd_save_dir=self.eval_args.corpus_embd_save_dir,
|
||||
ignore_identical_ids=self.eval_args.ignore_identical_ids,
|
||||
k_values=self.eval_args.k_values,
|
||||
dataset_name=dataset_name,
|
||||
)
|
||||
logger.info(f"{self.eval_args.eval_name} evaluation on {dataset_names} completed.")
|
||||
|
||||
logger.info("Start computing metrics.")
|
||||
self.evaluate_metrics(
|
||||
search_results_save_dir=self.eval_args.output_dir,
|
||||
output_method=self.eval_args.eval_output_method,
|
||||
output_path=self.eval_args.eval_output_path,
|
||||
metrics=self.eval_args.eval_metrics
|
||||
)
|
||||
|
||||
def load_data_loader(self) -> BEIREvalDataLoader:
|
||||
"""Load the data loader
|
||||
|
||||
Returns:
|
||||
BEIREvalDataLoader: BEIR data loader object.
|
||||
"""
|
||||
data_loader = BEIREvalDataLoader(
|
||||
eval_name=self.eval_args.eval_name,
|
||||
dataset_dir=self.eval_args.dataset_dir,
|
||||
cache_dir=self.eval_args.cache_path,
|
||||
token=self.eval_args.token,
|
||||
force_redownload=self.eval_args.force_redownload,
|
||||
)
|
||||
return data_loader
|
||||
|
||||
def load_evaluator(self) -> BEIREvaluator:
|
||||
"""Load the evaluator for evaluation
|
||||
|
||||
Returns:
|
||||
BEIREvaluator: The BEIR evaluator to run the evaluation.
|
||||
"""
|
||||
evaluator = BEIREvaluator(
|
||||
eval_name=self.eval_args.eval_name,
|
||||
data_loader=self.data_loader,
|
||||
overwrite=self.eval_args.overwrite,
|
||||
)
|
||||
return evaluator
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
from FlagEmbedding.abc.evaluation import (
|
||||
AbsEvalArgs as CustomEvalArgs,
|
||||
AbsEvalModelArgs as CustomEvalModelArgs,
|
||||
)
|
||||
|
||||
from .data_loader import CustomEvalDataLoader
|
||||
from .runner import CustomEvalRunner
|
||||
|
||||
__all__ = [
|
||||
"CustomEvalArgs",
|
||||
"CustomEvalModelArgs",
|
||||
"CustomEvalRunner",
|
||||
"CustomEvalDataLoader",
|
||||
]
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
from transformers import HfArgumentParser
|
||||
|
||||
from FlagEmbedding.evaluation.custom import (
|
||||
CustomEvalArgs, CustomEvalModelArgs,
|
||||
CustomEvalRunner
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = HfArgumentParser((
|
||||
CustomEvalArgs,
|
||||
CustomEvalModelArgs
|
||||
))
|
||||
|
||||
eval_args, model_args = parser.parse_args_into_dataclasses()
|
||||
eval_args: CustomEvalArgs
|
||||
model_args: CustomEvalModelArgs
|
||||
|
||||
runner = CustomEvalRunner(
|
||||
eval_args=eval_args,
|
||||
model_args=model_args
|
||||
)
|
||||
|
||||
runner.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
import logging
|
||||
from tqdm import tqdm
|
||||
from typing import List, Optional
|
||||
|
||||
from FlagEmbedding.abc.evaluation import AbsEvalDataLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomEvalDataLoader(AbsEvalDataLoader):
|
||||
def available_dataset_names(self) -> List[str]:
|
||||
return []
|
||||
|
||||
def available_splits(self, dataset_name: Optional[str] = None) -> List[str]:
|
||||
return ["test"]
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
from FlagEmbedding.abc.evaluation import AbsEvalRunner
|
||||
|
||||
from .data_loader import CustomEvalDataLoader
|
||||
|
||||
|
||||
class CustomEvalRunner(AbsEvalRunner):
|
||||
def load_data_loader(self) -> CustomEvalDataLoader:
|
||||
data_loader = CustomEvalDataLoader(
|
||||
eval_name=self.eval_args.eval_name,
|
||||
dataset_dir=self.eval_args.dataset_dir,
|
||||
cache_dir=self.eval_args.cache_path,
|
||||
token=self.eval_args.token,
|
||||
force_redownload=self.eval_args.force_redownload,
|
||||
)
|
||||
return data_loader
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
from FlagEmbedding.abc.evaluation import (
|
||||
AbsEvalArgs as MIRACLEvalArgs,
|
||||
AbsEvalModelArgs as MIRACLEvalModelArgs,
|
||||
)
|
||||
|
||||
from .data_loader import MIRACLEvalDataLoader
|
||||
from .runner import MIRACLEvalRunner
|
||||
|
||||
__all__ = [
|
||||
"MIRACLEvalArgs",
|
||||
"MIRACLEvalModelArgs",
|
||||
"MIRACLEvalRunner",
|
||||
"MIRACLEvalDataLoader",
|
||||
]
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
from transformers import HfArgumentParser
|
||||
|
||||
from FlagEmbedding.evaluation.miracl import (
|
||||
MIRACLEvalArgs, MIRACLEvalModelArgs,
|
||||
MIRACLEvalRunner
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = HfArgumentParser((
|
||||
MIRACLEvalArgs,
|
||||
MIRACLEvalModelArgs
|
||||
))
|
||||
|
||||
eval_args, model_args = parser.parse_args_into_dataclasses()
|
||||
eval_args: MIRACLEvalArgs
|
||||
model_args: MIRACLEvalModelArgs
|
||||
|
||||
runner = MIRACLEvalRunner(
|
||||
eval_args=eval_args,
|
||||
model_args=model_args
|
||||
)
|
||||
|
||||
runner.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,179 @@
|
|||
import os
|
||||
import json
|
||||
import logging
|
||||
import datasets
|
||||
from tqdm import tqdm
|
||||
from typing import List, Optional
|
||||
|
||||
from FlagEmbedding.abc.evaluation import AbsEvalDataLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MIRACLEvalDataLoader(AbsEvalDataLoader):
|
||||
"""
|
||||
Data loader class for MIRACL.
|
||||
"""
|
||||
def available_dataset_names(self) -> List[str]:
|
||||
"""
|
||||
Get the available dataset names.
|
||||
|
||||
Returns:
|
||||
List[str]: All the available dataset names.
|
||||
"""
|
||||
return ["ar", "bn", "en", "es", "fa", "fi", "fr", "hi", "id", "ja", "ko", "ru", "sw", "te", "th", "zh", "de", "yo"]
|
||||
|
||||
def available_splits(self, dataset_name: str) -> List[str]:
|
||||
"""
|
||||
Get the avaialble splits.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Dataset name.
|
||||
|
||||
Returns:
|
||||
List[str]: All the available splits for the dataset.
|
||||
"""
|
||||
if dataset_name in ["de", "yo"]:
|
||||
return ["dev"]
|
||||
else:
|
||||
return ["train", "dev"]
|
||||
|
||||
def _load_remote_corpus(
|
||||
self,
|
||||
dataset_name: str,
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Load the corpus dataset from HF.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the dataset.
|
||||
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of corpus.
|
||||
"""
|
||||
corpus = datasets.load_dataset(
|
||||
"miracl/miracl-corpus", dataset_name,
|
||||
cache_dir=self.cache_dir,
|
||||
trust_remote_code=True,
|
||||
download_mode=self.hf_download_mode
|
||||
)["train"]
|
||||
|
||||
if save_dir is not None:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
save_path = os.path.join(save_dir, "corpus.jsonl")
|
||||
corpus_dict = {}
|
||||
with open(save_path, "w", encoding="utf-8") as f:
|
||||
for data in tqdm(corpus, desc="Loading and Saving corpus"):
|
||||
docid, title, text = str(data["docid"]), data["title"], data["text"]
|
||||
_data = {
|
||||
"id": docid,
|
||||
"title": title,
|
||||
"text": text
|
||||
}
|
||||
corpus_dict[docid] = {
|
||||
"title": title,
|
||||
"text": text
|
||||
}
|
||||
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
logging.info(f"{self.eval_name} {dataset_name} corpus saved to {save_path}")
|
||||
else:
|
||||
corpus_dict = {str(data["docid"]): {"title": data["title"], "text": data["text"]} for data in tqdm(corpus, desc="Loading corpus")}
|
||||
return datasets.DatasetDict(corpus_dict)
|
||||
|
||||
def _load_remote_qrels(
|
||||
self,
|
||||
dataset_name: str,
|
||||
split: str = 'dev',
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Load the qrels from HF.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the dataset.
|
||||
split (str, optional): Split of the dataset. Defaults to ``'dev'``.
|
||||
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of qrel.
|
||||
"""
|
||||
endpoint = f"{os.getenv('HF_ENDPOINT', 'https://huggingface.co')}/datasets/miracl/miracl"
|
||||
qrels_download_url = f"{endpoint}/resolve/main/miracl-v1.0-{dataset_name}/qrels/qrels.miracl-v1.0-{dataset_name}-{split}.tsv"
|
||||
|
||||
qrels_save_path = self._download_file(qrels_download_url, self.cache_dir)
|
||||
|
||||
if save_dir is not None:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
save_path = os.path.join(save_dir, f"{split}_qrels.jsonl")
|
||||
qrels_dict = {}
|
||||
with open(save_path, "w", encoding="utf-8") as f1:
|
||||
with open(qrels_save_path, "r", encoding="utf-8") as f2:
|
||||
for line in tqdm(f2.readlines(), desc="Loading and Saving qrels"):
|
||||
qid, _, docid, rel = line.strip().split("\t")
|
||||
qid, docid, rel = str(qid), str(docid), int(rel)
|
||||
_data = {
|
||||
"qid": qid,
|
||||
"docid": docid,
|
||||
"relevance": rel
|
||||
}
|
||||
if qid not in qrels_dict:
|
||||
qrels_dict[qid] = {}
|
||||
qrels_dict[qid][docid] = rel
|
||||
f1.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
logging.info(f"{self.eval_name} {dataset_name} qrels saved to {save_path}")
|
||||
else:
|
||||
qrels_dict = {}
|
||||
with open(qrels_save_path, "r", encoding="utf-8") as f:
|
||||
for line in tqdm(f.readlines(), desc="Loading qrels"):
|
||||
qid, _, docid, rel = line.strip().split("\t")
|
||||
qid, docid, rel = str(qid), str(docid), int(rel)
|
||||
if qid not in qrels_dict:
|
||||
qrels_dict[qid] = {}
|
||||
qrels_dict[qid][docid] = rel
|
||||
return datasets.DatasetDict(qrels_dict)
|
||||
|
||||
def _load_remote_queries(
|
||||
self,
|
||||
dataset_name: str,
|
||||
split: str = 'dev',
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Load the queries from HF.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the dataset.
|
||||
split (str, optional): Split of the dataset. Defaults to ``'dev'``.
|
||||
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of queries.
|
||||
"""
|
||||
endpoint = f"{os.getenv('HF_ENDPOINT', 'https://huggingface.co')}/datasets/miracl/miracl"
|
||||
queries_download_url = f"{endpoint}/resolve/main/miracl-v1.0-{dataset_name}/topics/topics.miracl-v1.0-{dataset_name}-{split}.tsv"
|
||||
|
||||
queries_save_path = self._download_file(queries_download_url, self.cache_dir)
|
||||
|
||||
if save_dir is not None:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
save_path = os.path.join(save_dir, f"{split}_queries.jsonl")
|
||||
queries_dict = {}
|
||||
with open(save_path, "w", encoding="utf-8") as f1:
|
||||
with open(queries_save_path, "r", encoding="utf-8") as f2:
|
||||
for line in tqdm(f2.readlines(), desc="Loading and Saving queries"):
|
||||
qid, query = line.strip().split("\t")
|
||||
qid = str(qid)
|
||||
_data = {
|
||||
"id": qid,
|
||||
"text": query
|
||||
}
|
||||
queries_dict[qid] = query
|
||||
f1.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
logging.info(f"{self.eval_name} {dataset_name} queries saved to {save_path}")
|
||||
else:
|
||||
queries_dict = {}
|
||||
with open(queries_save_path, "r", encoding="utf-8") as f:
|
||||
for line in tqdm(f.readlines(), desc="Loading queries"):
|
||||
qid, query = line.strip().split("\t")
|
||||
qid = str(qid)
|
||||
queries_dict[qid] = query
|
||||
return datasets.DatasetDict(queries_dict)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
from FlagEmbedding.abc.evaluation import AbsEvalRunner
|
||||
|
||||
from .data_loader import MIRACLEvalDataLoader
|
||||
|
||||
|
||||
class MIRACLEvalRunner(AbsEvalRunner):
|
||||
"""
|
||||
Evaluation runner of MIRACL.
|
||||
"""
|
||||
def load_data_loader(self) -> MIRACLEvalDataLoader:
|
||||
"""Load the data loader instance by args.
|
||||
|
||||
Returns:
|
||||
MIRACLEvalDataLoader: The MIRACL data loader instance.
|
||||
"""
|
||||
data_loader = MIRACLEvalDataLoader(
|
||||
eval_name=self.eval_args.eval_name,
|
||||
dataset_dir=self.eval_args.dataset_dir,
|
||||
cache_dir=self.eval_args.cache_path,
|
||||
token=self.eval_args.token,
|
||||
force_redownload=self.eval_args.force_redownload,
|
||||
)
|
||||
return data_loader
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
from FlagEmbedding.abc.evaluation import (
|
||||
AbsEvalArgs as MKQAEvalArgs,
|
||||
AbsEvalModelArgs as MKQAEvalModelArgs,
|
||||
)
|
||||
|
||||
from .data_loader import MKQAEvalDataLoader
|
||||
from .evaluator import MKQAEvaluator
|
||||
from .runner import MKQAEvalRunner
|
||||
|
||||
__all__ = [
|
||||
"MKQAEvalArgs",
|
||||
"MKQAEvalModelArgs",
|
||||
"MKQAEvalRunner",
|
||||
"MKQAEvalDataLoader",
|
||||
"MKQAEvaluator"
|
||||
]
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
from transformers import HfArgumentParser
|
||||
|
||||
from FlagEmbedding.evaluation.mkqa import (
|
||||
MKQAEvalArgs, MKQAEvalModelArgs,
|
||||
MKQAEvalRunner
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = HfArgumentParser((
|
||||
MKQAEvalArgs,
|
||||
MKQAEvalModelArgs
|
||||
))
|
||||
|
||||
eval_args, model_args = parser.parse_args_into_dataclasses()
|
||||
eval_args: MKQAEvalArgs
|
||||
model_args: MKQAEvalModelArgs
|
||||
|
||||
runner = MKQAEvalRunner(
|
||||
eval_args=eval_args,
|
||||
model_args=model_args
|
||||
)
|
||||
|
||||
runner.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,228 @@
|
|||
import os
|
||||
import json
|
||||
import logging
|
||||
import datasets
|
||||
from tqdm import tqdm
|
||||
from typing import List, Optional
|
||||
|
||||
from FlagEmbedding.abc.evaluation import AbsEvalDataLoader
|
||||
|
||||
from .utils.normalize_text import normalize_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MKQAEvalDataLoader(AbsEvalDataLoader):
|
||||
"""
|
||||
Data loader class for MKQA.
|
||||
"""
|
||||
def available_dataset_names(self) -> List[str]:
|
||||
"""
|
||||
Get the available dataset names.
|
||||
|
||||
Returns:
|
||||
List[str]: All the available dataset names.
|
||||
"""
|
||||
return ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw']
|
||||
|
||||
def available_splits(self, dataset_name: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Get the avaialble splits.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Dataset name.
|
||||
|
||||
Returns:
|
||||
List[str]: All the available splits for the dataset.
|
||||
"""
|
||||
return ["test"]
|
||||
|
||||
def load_corpus(self, dataset_name: Optional[str] = None) -> datasets.DatasetDict:
|
||||
"""Load the corpus.
|
||||
|
||||
Args:
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of corpus.
|
||||
"""
|
||||
if self.dataset_dir is not None:
|
||||
# same corpus for all languages
|
||||
save_dir = self.dataset_dir
|
||||
return self._load_local_corpus(save_dir, dataset_name=dataset_name)
|
||||
else:
|
||||
return self._load_remote_corpus(dataset_name=dataset_name)
|
||||
|
||||
def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
|
||||
"""Try to load qrels from local datasets.
|
||||
|
||||
Args:
|
||||
save_dir (str): Directory that save the data files.
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
|
||||
split (str, optional): Split of the dataset. Defaults to ``'test'``.
|
||||
|
||||
Raises:
|
||||
ValueError: No local qrels found, will try to download from remote.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of qrels.
|
||||
"""
|
||||
checked_split = self.check_splits(split)
|
||||
if len(checked_split) == 0:
|
||||
raise ValueError(f"Split {split} not found in the dataset.")
|
||||
split = checked_split[0]
|
||||
|
||||
qrels_path = os.path.join(save_dir, f"{split}_qrels.jsonl")
|
||||
if self.force_redownload or not os.path.exists(qrels_path):
|
||||
logger.warning(f"Qrels not found in {qrels_path}. Trying to download the qrels from the remote and save it to {save_dir}.")
|
||||
return self._load_remote_qrels(dataset_name=dataset_name, split=split, save_dir=save_dir)
|
||||
else:
|
||||
qrels_data = datasets.load_dataset('json', data_files=qrels_path, cache_dir=self.cache_dir)['train']
|
||||
|
||||
qrels = {}
|
||||
for data in qrels_data:
|
||||
qid = data['qid']
|
||||
qrels[qid] = data['answers']
|
||||
|
||||
return datasets.DatasetDict(qrels)
|
||||
|
||||
def _load_remote_corpus(
|
||||
self,
|
||||
dataset_name: Optional[str] = None,
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""
|
||||
Refer to: https://arxiv.org/pdf/2402.03216. We use the corpus from the BeIR dataset.
|
||||
"""
|
||||
corpus = datasets.load_dataset(
|
||||
"BeIR/nq", "corpus",
|
||||
cache_dir=self.cache_dir,
|
||||
trust_remote_code=True,
|
||||
download_mode=self.hf_download_mode
|
||||
)["corpus"]
|
||||
|
||||
if save_dir is not None:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
save_path = os.path.join(save_dir, "corpus.jsonl")
|
||||
corpus_dict = {}
|
||||
with open(save_path, "w", encoding="utf-8") as f:
|
||||
for data in tqdm(corpus, desc="Loading and Saving corpus"):
|
||||
docid, title, text = str(data["_id"]), normalize_text(data["title"]).lower(), normalize_text(data["text"]).lower()
|
||||
_data = {
|
||||
"id": docid,
|
||||
"title": title,
|
||||
"text": text
|
||||
}
|
||||
corpus_dict[docid] = {
|
||||
"title": title,
|
||||
"text": text
|
||||
}
|
||||
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
logging.info(f"{self.eval_name} corpus saved to {save_path}")
|
||||
else:
|
||||
corpus_dict = {}
|
||||
for data in tqdm(corpus, desc="Loading corpus"):
|
||||
docid, title, text = str(data["_id"]), normalize_text(data["title"]), normalize_text(data["text"])
|
||||
corpus_dict[docid] = {
|
||||
"title": title,
|
||||
"text": text
|
||||
}
|
||||
return datasets.DatasetDict(corpus_dict)
|
||||
|
||||
def _load_remote_qrels(
|
||||
self,
|
||||
dataset_name: str,
|
||||
split: str = 'test',
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Load remote qrels from HF.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the dataset.
|
||||
split (str, optional): Split of the dataset. Defaults to ``'test'``.
|
||||
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of qrel.
|
||||
"""
|
||||
endpoint = f"{os.getenv('HF_ENDPOINT', 'https://huggingface.co')}/datasets/Shitao/bge-m3-data"
|
||||
queries_download_url = f"{endpoint}/resolve/main/MKQA_test-data.zip"
|
||||
|
||||
qrels_save_dir = self._download_zip_file(queries_download_url, self.cache_dir)
|
||||
qrels_save_path = os.path.join(qrels_save_dir, f"{dataset_name}.jsonl")
|
||||
|
||||
if save_dir is not None:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
save_path = os.path.join(save_dir, f"{split}_qrels.jsonl")
|
||||
qrels_dict = {}
|
||||
with open(save_path, "w", encoding="utf-8") as f1:
|
||||
with open(qrels_save_path, "r", encoding="utf-8") as f2:
|
||||
for line in tqdm(f2.readlines(), desc="Loading and Saving qrels"):
|
||||
data = json.loads(line)
|
||||
qid, answers = str(data["id"]), data["answers"]
|
||||
_data = {
|
||||
"qid": qid,
|
||||
"answers": answers
|
||||
}
|
||||
if qid not in qrels_dict:
|
||||
qrels_dict[qid] = {}
|
||||
qrels_dict[qid] = answers
|
||||
f1.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
logging.info(f"{self.eval_name} {dataset_name} qrels saved to {save_path}")
|
||||
else:
|
||||
qrels_dict = {}
|
||||
with open(qrels_save_path, "r", encoding="utf-8") as f:
|
||||
for line in tqdm(f.readlines(), desc="Loading qrels"):
|
||||
data = json.loads(line)
|
||||
qid, answers = str(data["id"]), data["answers"]
|
||||
if qid not in qrels_dict:
|
||||
qrels_dict[qid] = {}
|
||||
qrels_dict[qid] = answers
|
||||
return datasets.DatasetDict(qrels_dict)
|
||||
|
||||
def _load_remote_queries(
|
||||
self,
|
||||
dataset_name: str,
|
||||
split: str = 'test',
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Load the queries from HF.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the dataset.
|
||||
split (str, optional): Split of the dataset. Defaults to ``'test'``.
|
||||
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of queries.
|
||||
"""
|
||||
endpoint = f"{os.getenv('HF_ENDPOINT', 'https://huggingface.co')}/datasets/Shitao/bge-m3-data"
|
||||
queries_download_url = f"{endpoint}/resolve/main/MKQA_test-data.zip"
|
||||
|
||||
queries_save_dir = self._download_zip_file(queries_download_url, self.cache_dir)
|
||||
queries_save_path = os.path.join(queries_save_dir, f"{dataset_name}.jsonl")
|
||||
|
||||
if save_dir is not None:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
save_path = os.path.join(save_dir, f"{split}_queries.jsonl")
|
||||
queries_dict = {}
|
||||
with open(save_path, "w", encoding="utf-8") as f1:
|
||||
with open(queries_save_path, "r", encoding="utf-8") as f2:
|
||||
for line in tqdm(f2.readlines(), desc="Loading and Saving queries"):
|
||||
data = json.loads(line)
|
||||
qid, query = str(data["id"]), data["question"]
|
||||
_data = {
|
||||
"id": qid,
|
||||
"text": query
|
||||
}
|
||||
queries_dict[qid] = query
|
||||
f1.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
logging.info(f"{self.eval_name} {dataset_name} queries saved to {save_path}")
|
||||
else:
|
||||
queries_dict = {}
|
||||
with open(queries_save_path, "r", encoding="utf-8") as f:
|
||||
for line in tqdm(f.readlines(), desc="Loading queries"):
|
||||
data = json.loads(line)
|
||||
qid, query = str(data["id"]), data["question"]
|
||||
queries_dict[qid] = query
|
||||
return datasets.DatasetDict(queries_dict)
|
||||
|
|
@ -0,0 +1,117 @@
|
|||
import os
|
||||
from tqdm import tqdm
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from FlagEmbedding.abc.evaluation import AbsEvaluator
|
||||
|
||||
from .utils.compute_metrics import evaluate_qa_recall
|
||||
|
||||
|
||||
class MKQAEvaluator(AbsEvaluator):
|
||||
"""
|
||||
The evaluator class of MKQA.
|
||||
"""
|
||||
def get_corpus_embd_save_dir(
|
||||
self,
|
||||
retriever_name: str,
|
||||
corpus_embd_save_dir: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None
|
||||
):
|
||||
"""Get the directory to save the corpus embedding.
|
||||
|
||||
Args:
|
||||
retriever_name (str): Name of the retriever.
|
||||
corpus_embd_save_dir (Optional[str], optional): Directory to save the corpus embedding. Defaults to ``None``.
|
||||
dataset_name (Optional[str], optional): Name of the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
str: The final directory to save the corpus embedding.
|
||||
"""
|
||||
if corpus_embd_save_dir is not None:
|
||||
# Save the corpus embeddings in the same directory for all dataset_name
|
||||
corpus_embd_save_dir = os.path.join(corpus_embd_save_dir, retriever_name)
|
||||
return corpus_embd_save_dir
|
||||
|
||||
def evaluate_results(
|
||||
self,
|
||||
search_results_save_dir: str,
|
||||
k_values: List[int] = [1, 3, 5, 10, 100, 1000]
|
||||
):
|
||||
"""Compute the metrics and get the eval results.
|
||||
|
||||
Args:
|
||||
search_results_save_dir (str): Directory that saves the search results.
|
||||
k_values (List[int], optional): Cutoffs. Defaults to ``[1, 3, 5, 10, 100, 1000]``.
|
||||
|
||||
Returns:
|
||||
dict: The evaluation results.
|
||||
"""
|
||||
eval_results_dict = {}
|
||||
|
||||
corpus = self.data_loader.load_corpus()
|
||||
corpus_dict = {}
|
||||
for docid, data in tqdm(corpus.items(), desc="Loading corpus for evaluation"):
|
||||
title, text = data["title"], data["text"]
|
||||
corpus_dict[docid] = f"{title} {text}".strip()
|
||||
|
||||
for file in os.listdir(search_results_save_dir):
|
||||
if not file.endswith('.json'):
|
||||
continue
|
||||
|
||||
file_path = os.path.join(search_results_save_dir, file)
|
||||
data_info, search_results = self.load_search_results(file_path)
|
||||
|
||||
_eval_name = data_info['eval_name']
|
||||
assert _eval_name == self.eval_name, f'Mismatch eval_name: {_eval_name} vs {self.eval_name} in {file_path}'
|
||||
|
||||
split = data_info['split']
|
||||
dataset_name = data_info.get('dataset_name', None)
|
||||
qrels = self.data_loader.load_qrels(dataset_name=dataset_name, split=split)
|
||||
|
||||
eval_results = self.compute_metrics(
|
||||
corpus_dict=corpus_dict,
|
||||
qrels=qrels,
|
||||
search_results=search_results,
|
||||
k_values=k_values
|
||||
)
|
||||
|
||||
if dataset_name is not None:
|
||||
key = f"{dataset_name}-{split}"
|
||||
else:
|
||||
key = split
|
||||
eval_results_dict[key] = eval_results
|
||||
|
||||
return eval_results_dict
|
||||
|
||||
@staticmethod
|
||||
def compute_metrics(
|
||||
corpus_dict: Dict[str, str],
|
||||
qrels: Dict[str, List[str]],
|
||||
search_results: Dict[str, Dict[str, float]],
|
||||
k_values: List[int],
|
||||
):
|
||||
"""
|
||||
Compute Recall@k for QA task. The definition of recall in QA task is different from the one in IR task. Please refer to the paper of RocketQA: https://aclanthology.org/2021.naacl-main.466.pdf.
|
||||
|
||||
Args:
|
||||
corpus_dict (Dict[str, str]): Dictionary of the corpus with doc id and contents.
|
||||
qrels (Dict[str, List[str]]): Relevances of queries and passage.
|
||||
search_results (Dict[str, Dict[str, float]]): Search results of the model to evaluate.
|
||||
|
||||
Returns:
|
||||
dict: The model's scores of the metrics.
|
||||
"""
|
||||
contexts = []
|
||||
answers = []
|
||||
top_k = max(k_values)
|
||||
for qid, doc_score_dict in search_results.items():
|
||||
doc_score_pair = sorted(doc_score_dict.items(), key=lambda x: x[1], reverse=True)
|
||||
_ctxs = [corpus_dict[docid] for docid, _ in doc_score_pair[:top_k]]
|
||||
contexts.append(_ctxs)
|
||||
answers.append(qrels[qid])
|
||||
|
||||
recall = evaluate_qa_recall(contexts, answers, k_values=k_values)
|
||||
scores = {f"qa_recall_at_{k}": v for k, v in zip(k_values, recall)}
|
||||
|
||||
return scores
|
||||
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
from FlagEmbedding.abc.evaluation import AbsEvalRunner
|
||||
|
||||
from .data_loader import MKQAEvalDataLoader
|
||||
from .evaluator import MKQAEvaluator
|
||||
|
||||
|
||||
class MKQAEvalRunner(AbsEvalRunner):
|
||||
"""
|
||||
Evaluation runner of MKQA.
|
||||
"""
|
||||
def load_data_loader(self) -> MKQAEvalDataLoader:
|
||||
"""Load the data loader instance by args.
|
||||
|
||||
Returns:
|
||||
MKQAEvalDataLoader: The MKQA data loader instance.
|
||||
"""
|
||||
data_loader = MKQAEvalDataLoader(
|
||||
eval_name=self.eval_args.eval_name,
|
||||
dataset_dir=self.eval_args.dataset_dir,
|
||||
cache_dir=self.eval_args.cache_path,
|
||||
token=self.eval_args.token,
|
||||
force_redownload=self.eval_args.force_redownload,
|
||||
)
|
||||
return data_loader
|
||||
|
||||
def load_evaluator(self) -> MKQAEvaluator:
|
||||
"""Load the evaluator instance by args.
|
||||
|
||||
Returns:
|
||||
MKQAEvaluator: The MKQA evaluator instance.
|
||||
"""
|
||||
evaluator = MKQAEvaluator(
|
||||
eval_name=self.eval_args.eval_name,
|
||||
data_loader=self.data_loader,
|
||||
overwrite=self.eval_args.overwrite,
|
||||
)
|
||||
return evaluator
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
"""
|
||||
Ref: https://github.com/facebookresearch/contriever
|
||||
"""
|
||||
import regex
|
||||
import unicodedata
|
||||
from functools import partial
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
class SimpleTokenizer:
|
||||
ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
|
||||
NON_WS = r'[^\p{Z}\p{C}]'
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Args:
|
||||
annotators: None or empty set (only tokenizes).
|
||||
"""
|
||||
self._regexp = regex.compile(
|
||||
'(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
|
||||
flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
|
||||
)
|
||||
|
||||
def tokenize(self, text, uncased=False):
|
||||
matches = [m for m in self._regexp.finditer(text)]
|
||||
if uncased:
|
||||
tokens = [m.group().lower() for m in matches]
|
||||
else:
|
||||
tokens = [m.group() for m in matches]
|
||||
return tokens
|
||||
|
||||
|
||||
def _normalize(text):
|
||||
return unicodedata.normalize('NFD', text)
|
||||
|
||||
|
||||
def has_answer(answers, text, tokenizer) -> bool:
|
||||
"""Check if a document contains an answer string."""
|
||||
text = _normalize(text)
|
||||
text = tokenizer.tokenize(text, uncased=True)
|
||||
|
||||
for answer in answers:
|
||||
answer = _normalize(answer)
|
||||
answer = tokenizer.tokenize(answer, uncased=True)
|
||||
for i in range(0, len(text) - len(answer) + 1):
|
||||
if answer == text[i: i + len(answer)]:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def check_answer(example, tokenizer) -> List[bool]:
|
||||
"""Search through all the top docs to see if they have any of the answers."""
|
||||
answers = example['answers']
|
||||
ctxs = example['ctxs']
|
||||
|
||||
hits = []
|
||||
for i, text in enumerate(ctxs):
|
||||
if text is None: # cannot find the document for some reason
|
||||
hits.append(False)
|
||||
continue
|
||||
hits.append(has_answer(answers, text, tokenizer))
|
||||
return hits
|
||||
|
||||
|
||||
def evaluate_qa_recall(ctxs, answers, k_values: Union[int, List[int]]=100):
|
||||
# compute Recall@k for QA task
|
||||
data = []
|
||||
assert len(ctxs) == len(answers)
|
||||
for i in range(len(ctxs)):
|
||||
_ctxs, _answers = ctxs[i], answers[i]
|
||||
data.append({
|
||||
'answers': _answers,
|
||||
'ctxs': _ctxs,
|
||||
})
|
||||
tokenizer = SimpleTokenizer()
|
||||
get_score_partial = partial(check_answer, tokenizer=tokenizer)
|
||||
|
||||
scores = map(get_score_partial, data)
|
||||
|
||||
n_docs = len(data[0]['ctxs'])
|
||||
top_k_hits = [0] * n_docs
|
||||
for question_hits in scores:
|
||||
best_hit = next((i for i, x in enumerate(question_hits) if x), None)
|
||||
if best_hit is not None:
|
||||
top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]
|
||||
|
||||
if isinstance(k_values, int):
|
||||
k = min(k_values, len(top_k_hits))
|
||||
return top_k_hits[k - 1] / len(data)
|
||||
else:
|
||||
scores = []
|
||||
for k in k_values:
|
||||
k = min(k, len(top_k_hits))
|
||||
scores.append(top_k_hits[k - 1] / len(data))
|
||||
return scores
|
||||
|
|
@ -0,0 +1,162 @@
|
|||
"""
|
||||
adapted from chemdataextractor.text.normalize
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
Tools for normalizing text.
|
||||
https://github.com/mcs07/ChemDataExtractor
|
||||
:copyright: Copyright 2016 by Matt Swain.
|
||||
:license: MIT
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining
|
||||
a copy of this software and associated documentation files (the
|
||||
'Software'), to deal in the Software without restriction, including
|
||||
without limitation the rights to use, copy, modify, merge, publish,
|
||||
distribute, sublicense, and/or sell copies of the Software, and to
|
||||
permit persons to whom the Software is furnished to do so, subject to
|
||||
the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be
|
||||
included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
#: Control characters.
|
||||
CONTROLS = {
|
||||
'\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u000e', '\u000f', '\u0011',
|
||||
'\u0012', '\u0013', '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001a', '\u001b',
|
||||
}
|
||||
# There are further control characters, but they are instead replaced with a space by unicode normalization
|
||||
# '\u0009', '\u000a', '\u000b', '\u000c', '\u000d', '\u001c', '\u001d', '\u001e', '\u001f'
|
||||
|
||||
|
||||
#: Hyphen and dash characters.
|
||||
HYPHENS = {
|
||||
'-', # \u002d Hyphen-minus
|
||||
'‐', # \u2010 Hyphen
|
||||
'‑', # \u2011 Non-breaking hyphen
|
||||
'⁃', # \u2043 Hyphen bullet
|
||||
'‒', # \u2012 figure dash
|
||||
'–', # \u2013 en dash
|
||||
'—', # \u2014 em dash
|
||||
'―', # \u2015 horizontal bar
|
||||
}
|
||||
|
||||
#: Minus characters.
|
||||
MINUSES = {
|
||||
'-', # \u002d Hyphen-minus
|
||||
'−', # \u2212 Minus
|
||||
'-', # \uff0d Full-width Hyphen-minus
|
||||
'⁻', # \u207b Superscript minus
|
||||
}
|
||||
|
||||
#: Plus characters.
|
||||
PLUSES = {
|
||||
'+', # \u002b Plus
|
||||
'+', # \uff0b Full-width Plus
|
||||
'⁺', # \u207a Superscript plus
|
||||
}
|
||||
|
||||
#: Slash characters.
|
||||
SLASHES = {
|
||||
'/', # \u002f Solidus
|
||||
'⁄', # \u2044 Fraction slash
|
||||
'∕', # \u2215 Division slash
|
||||
}
|
||||
|
||||
#: Tilde characters.
|
||||
TILDES = {
|
||||
'~', # \u007e Tilde
|
||||
'˜', # \u02dc Small tilde
|
||||
'⁓', # \u2053 Swung dash
|
||||
'∼', # \u223c Tilde operator #in mbert vocab
|
||||
'∽', # \u223d Reversed tilde
|
||||
'∿', # \u223f Sine wave
|
||||
'〜', # \u301c Wave dash #in mbert vocab
|
||||
'~', # \uff5e Full-width tilde #in mbert vocab
|
||||
}
|
||||
|
||||
#: Apostrophe characters.
|
||||
APOSTROPHES = {
|
||||
"'", # \u0027
|
||||
'’', # \u2019
|
||||
'՚', # \u055a
|
||||
'Ꞌ', # \ua78b
|
||||
'ꞌ', # \ua78c
|
||||
''', # \uff07
|
||||
}
|
||||
|
||||
#: Single quote characters.
|
||||
SINGLE_QUOTES = {
|
||||
"'", # \u0027
|
||||
'‘', # \u2018
|
||||
'’', # \u2019
|
||||
'‚', # \u201a
|
||||
'‛', # \u201b
|
||||
|
||||
}
|
||||
|
||||
#: Double quote characters.
|
||||
DOUBLE_QUOTES = {
|
||||
'"', # \u0022
|
||||
'“', # \u201c
|
||||
'”', # \u201d
|
||||
'„', # \u201e
|
||||
'‟', # \u201f
|
||||
}
|
||||
|
||||
#: Accent characters.
|
||||
ACCENTS = {
|
||||
'`', # \u0060
|
||||
'´', # \u00b4
|
||||
}
|
||||
|
||||
#: Prime characters.
|
||||
PRIMES = {
|
||||
'′', # \u2032
|
||||
'″', # \u2033
|
||||
'‴', # \u2034
|
||||
'‵', # \u2035
|
||||
'‶', # \u2036
|
||||
'‷', # \u2037
|
||||
'⁗', # \u2057
|
||||
}
|
||||
|
||||
#: Quote characters, including apostrophes, single quotes, double quotes, accents and primes.
|
||||
QUOTES = APOSTROPHES | SINGLE_QUOTES | DOUBLE_QUOTES | ACCENTS | PRIMES
|
||||
|
||||
def normalize_text(text: str):
|
||||
for control in CONTROLS:
|
||||
text = text.replace(control, '')
|
||||
text = text.replace('\u000b', ' ').replace('\u000c', ' ').replace(u'\u0085', ' ')
|
||||
|
||||
for hyphen in HYPHENS | MINUSES:
|
||||
text = text.replace(hyphen, '-')
|
||||
text = text.replace('\u00ad', '')
|
||||
|
||||
for double_quote in DOUBLE_QUOTES:
|
||||
text = text.replace(double_quote, '"') # \u0022
|
||||
for single_quote in (SINGLE_QUOTES | APOSTROPHES | ACCENTS):
|
||||
text = text.replace(single_quote, "'") # \u0027
|
||||
text = text.replace('′', "'") # \u2032 prime
|
||||
text = text.replace('‵', "'") # \u2035 reversed prime
|
||||
text = text.replace('″', "''") # \u2033 double prime
|
||||
text = text.replace('‶', "''") # \u2036 reversed double prime
|
||||
text = text.replace('‴', "'''") # \u2034 triple prime
|
||||
text = text.replace('‷', "'''") # \u2037 reversed triple prime
|
||||
text = text.replace('⁗', "''''") # \u2057 quadruple prime
|
||||
|
||||
text = text.replace('…', '...').replace(' . . . ', ' ... ') # \u2026
|
||||
|
||||
for slash in SLASHES:
|
||||
text = text.replace(slash, '/')
|
||||
|
||||
#for tilde in TILDES:
|
||||
# text = text.replace(tilde, '~')
|
||||
|
||||
return text
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
from FlagEmbedding.abc.evaluation import (
|
||||
AbsEvalArgs as MLDREvalArgs,
|
||||
AbsEvalModelArgs as MLDREvalModelArgs,
|
||||
)
|
||||
|
||||
from .data_loader import MLDREvalDataLoader
|
||||
from .runner import MLDREvalRunner
|
||||
|
||||
__all__ = [
|
||||
"MLDREvalArgs",
|
||||
"MLDREvalModelArgs",
|
||||
"MLDREvalRunner",
|
||||
"MLDREvalDataLoader",
|
||||
]
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
from transformers import HfArgumentParser
|
||||
|
||||
from FlagEmbedding.evaluation.mldr import (
|
||||
MLDREvalArgs, MLDREvalModelArgs,
|
||||
MLDREvalRunner
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = HfArgumentParser((
|
||||
MLDREvalArgs,
|
||||
MLDREvalModelArgs
|
||||
))
|
||||
|
||||
eval_args, model_args = parser.parse_args_into_dataclasses()
|
||||
eval_args: MLDREvalArgs
|
||||
model_args: MLDREvalModelArgs
|
||||
|
||||
runner = MLDREvalRunner(
|
||||
eval_args=eval_args,
|
||||
model_args=model_args
|
||||
)
|
||||
|
||||
runner.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,184 @@
|
|||
import os
|
||||
import json
|
||||
import logging
|
||||
import datasets
|
||||
from tqdm import tqdm
|
||||
from typing import List, Optional
|
||||
|
||||
from FlagEmbedding.abc.evaluation import AbsEvalDataLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MLDREvalDataLoader(AbsEvalDataLoader):
|
||||
"""
|
||||
Data loader class for MLDR.
|
||||
"""
|
||||
def available_dataset_names(self) -> List[str]:
|
||||
"""
|
||||
Get the available dataset names.
|
||||
|
||||
Returns:
|
||||
List[str]: All the available dataset names.
|
||||
"""
|
||||
return ["ar", "de", "en", "es", "fr", "hi", "it", "ja", "ko", "pt", "ru", "th", "zh"]
|
||||
|
||||
def available_splits(self, dataset_name: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Get the avaialble splits.
|
||||
|
||||
Args:
|
||||
dataset_name (Optional[str], optional): Dataset name. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
List[str]: All the available splits for the dataset.
|
||||
"""
|
||||
return ["train", "dev", "test"]
|
||||
|
||||
def _load_remote_corpus(
|
||||
self,
|
||||
dataset_name: str,
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Load the corpus dataset from HF.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the dataset.
|
||||
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of corpus.
|
||||
"""
|
||||
corpus = datasets.load_dataset(
|
||||
"Shitao/MLDR", f"corpus-{dataset_name}",
|
||||
cache_dir=self.cache_dir,
|
||||
trust_remote_code=True,
|
||||
download_mode=self.hf_download_mode
|
||||
)["corpus"]
|
||||
|
||||
if save_dir is not None:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
save_path = os.path.join(save_dir, "corpus.jsonl")
|
||||
corpus_dict = {}
|
||||
with open(save_path, "w", encoding="utf-8") as f:
|
||||
for data in tqdm(corpus, desc="Loading and Saving corpus"):
|
||||
docid, text = str(data["docid"]), data["text"]
|
||||
_data = {
|
||||
"id": docid,
|
||||
"text": text
|
||||
}
|
||||
corpus_dict[docid] = {"text": text}
|
||||
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
logging.info(f"{self.eval_name} {dataset_name} corpus saved to {save_path}")
|
||||
else:
|
||||
corpus_dict = {str(data["docid"]): {"text": data["text"]} for data in tqdm(corpus, desc="Loading corpus")}
|
||||
return datasets.DatasetDict(corpus_dict)
|
||||
|
||||
def _load_remote_qrels(
|
||||
self,
|
||||
dataset_name: str,
|
||||
split: str = "test",
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Load the qrels from HF.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the dataset.
|
||||
split (str, optional): Split of the dataset. Defaults to ``'test'``.
|
||||
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of qrel.
|
||||
"""
|
||||
qrels_data = datasets.load_dataset(
|
||||
"Shitao/MLDR", dataset_name,
|
||||
cache_dir=self.cache_dir,
|
||||
trust_remote_code=True,
|
||||
download_mode=self.hf_download_mode
|
||||
)[split]
|
||||
|
||||
if save_dir is not None:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
save_path = os.path.join(save_dir, f"{split}_qrels.jsonl")
|
||||
qrels_dict = {}
|
||||
with open(save_path, "w", encoding="utf-8") as f:
|
||||
for data in tqdm(qrels_data, desc="Loading and Saving qrels"):
|
||||
qid = str(data["query_id"])
|
||||
if qid not in qrels_dict:
|
||||
qrels_dict[qid] = {}
|
||||
for doc in data["positive_passages"]:
|
||||
docid = str(doc["docid"])
|
||||
_data = {
|
||||
"qid": qid,
|
||||
"docid": docid,
|
||||
"relevance": 1
|
||||
}
|
||||
qrels_dict[qid][docid] = 1
|
||||
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
for doc in data["negative_passages"]:
|
||||
docid = str(doc["docid"])
|
||||
_data = {
|
||||
"qid": qid,
|
||||
"docid": docid,
|
||||
"relevance": 0
|
||||
}
|
||||
qrels_dict[qid][docid] = 0
|
||||
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
logging.info(f"{self.eval_name} {dataset_name} qrels saved to {save_path}")
|
||||
else:
|
||||
qrels_dict = {}
|
||||
for data in tqdm(qrels_data, desc="Loading qrels"):
|
||||
qid = str(data["query_id"])
|
||||
if qid not in qrels_dict:
|
||||
qrels_dict[qid] = {}
|
||||
for doc in data["positive_passages"]:
|
||||
docid = str(doc["docid"])
|
||||
qrels_dict[qid][docid] = 1
|
||||
for doc in data["negative_passages"]:
|
||||
docid = str(doc["docid"])
|
||||
qrels_dict[qid][docid] = 0
|
||||
return datasets.DatasetDict(qrels_dict)
|
||||
|
||||
def _load_remote_queries(
|
||||
self,
|
||||
dataset_name: str,
|
||||
split: str = "test",
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Load the queries from HF.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the dataset.
|
||||
split (str, optional): Split of the dataset. Defaults to ``'test'``.
|
||||
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of queries.
|
||||
"""
|
||||
queries_data = datasets.load_dataset(
|
||||
"Shitao/MLDR", dataset_name,
|
||||
cache_dir=self.cache_dir,
|
||||
trust_remote_code=True,
|
||||
download_mode=self.hf_download_mode
|
||||
)[split]
|
||||
|
||||
if save_dir is not None:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
save_path = os.path.join(save_dir, f"{split}_queries.jsonl")
|
||||
queries_dict = {}
|
||||
with open(save_path, "w", encoding="utf-8") as f:
|
||||
for data in tqdm(queries_data, desc="Loading and Saving queries"):
|
||||
qid, query = str(data["query_id"]), data["query"]
|
||||
_data = {
|
||||
"id": qid,
|
||||
"text": query
|
||||
}
|
||||
queries_dict[qid] = query
|
||||
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
logging.info(f"{self.eval_name} {dataset_name} queries saved to {save_path}")
|
||||
else:
|
||||
queries_dict = {}
|
||||
for data in tqdm(queries_data, desc="Loading queries"):
|
||||
qid, query = str(data["query_id"]), data["query"]
|
||||
queries_dict[qid] = query
|
||||
return datasets.DatasetDict(queries_dict)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
from FlagEmbedding.abc.evaluation import AbsEvalRunner
|
||||
|
||||
from .data_loader import MLDREvalDataLoader
|
||||
|
||||
|
||||
class MLDREvalRunner(AbsEvalRunner):
|
||||
"""
|
||||
Evaluation runner of MIRACL.
|
||||
"""
|
||||
def load_data_loader(self) -> MLDREvalDataLoader:
|
||||
"""Load the data loader instance by args.
|
||||
|
||||
Returns:
|
||||
MLDREvalDataLoader: The MLDR data loader instance.
|
||||
"""
|
||||
data_loader = MLDREvalDataLoader(
|
||||
eval_name=self.eval_args.eval_name,
|
||||
dataset_dir=self.eval_args.dataset_dir,
|
||||
cache_dir=self.eval_args.cache_path,
|
||||
token=self.eval_args.token,
|
||||
force_redownload=self.eval_args.force_redownload,
|
||||
)
|
||||
return data_loader
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
from FlagEmbedding.abc.evaluation import (
|
||||
AbsEvalArgs as MSMARCOEvalArgs,
|
||||
AbsEvalModelArgs as MSMARCOEvalModelArgs,
|
||||
)
|
||||
|
||||
from .data_loader import MSMARCOEvalDataLoader
|
||||
from .runner import MSMARCOEvalRunner
|
||||
|
||||
__all__ = [
|
||||
"MSMARCOEvalArgs",
|
||||
"MSMARCOEvalModelArgs",
|
||||
"MSMARCOEvalRunner",
|
||||
"MSMARCOEvalDataLoader",
|
||||
]
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
from transformers import HfArgumentParser
|
||||
|
||||
from FlagEmbedding.evaluation.msmarco import (
|
||||
MSMARCOEvalArgs, MSMARCOEvalModelArgs,
|
||||
MSMARCOEvalRunner
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = HfArgumentParser((
|
||||
MSMARCOEvalArgs,
|
||||
MSMARCOEvalModelArgs
|
||||
))
|
||||
|
||||
eval_args, model_args = parser.parse_args_into_dataclasses()
|
||||
eval_args: MSMARCOEvalArgs
|
||||
model_args: MSMARCOEvalModelArgs
|
||||
|
||||
runner = MSMARCOEvalRunner(
|
||||
eval_args=eval_args,
|
||||
model_args=model_args
|
||||
)
|
||||
|
||||
runner.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,277 @@
|
|||
import os
|
||||
import json
|
||||
import logging
|
||||
import datasets
|
||||
from tqdm import tqdm
|
||||
from typing import List, Optional
|
||||
|
||||
from FlagEmbedding.abc.evaluation import AbsEvalDataLoader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MSMARCOEvalDataLoader(AbsEvalDataLoader):
|
||||
"""
|
||||
Data loader class for MSMARCO.
|
||||
"""
|
||||
def available_dataset_names(self) -> List[str]:
|
||||
"""
|
||||
Get the available dataset names.
|
||||
|
||||
Returns:
|
||||
List[str]: All the available dataset names.
|
||||
"""
|
||||
return ["passage", "document"]
|
||||
|
||||
def available_splits(self, dataset_name: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Get the avaialble splits.
|
||||
|
||||
Args:
|
||||
dataset_name (Optional[str], optional): Dataset name. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
List[str]: All the available splits for the dataset.
|
||||
"""
|
||||
return ["dev", "dl19", "dl20"]
|
||||
|
||||
def _load_remote_corpus(
|
||||
self,
|
||||
dataset_name: str,
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Load the corpus dataset from HF.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the dataset.
|
||||
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of corpus.
|
||||
"""
|
||||
if dataset_name == 'passage':
|
||||
corpus = datasets.load_dataset(
|
||||
'Tevatron/msmarco-passage-corpus',
|
||||
'default',
|
||||
trust_remote_code=True,
|
||||
cache_dir=self.cache_dir,
|
||||
download_mode=self.hf_download_mode
|
||||
)['train']
|
||||
else:
|
||||
corpus = datasets.load_dataset(
|
||||
'irds/msmarco-document',
|
||||
'docs',
|
||||
trust_remote_code=True,
|
||||
cache_dir=self.cache_dir,
|
||||
download_mode=self.hf_download_mode
|
||||
)
|
||||
|
||||
if save_dir is not None:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
save_path = os.path.join(save_dir, "corpus.jsonl")
|
||||
corpus_dict = {}
|
||||
with open(save_path, "w", encoding="utf-8") as f:
|
||||
for data in tqdm(corpus, desc="Loading and Saving corpus"):
|
||||
if dataset_name == 'passage':
|
||||
_data = {
|
||||
"id": data["docid"],
|
||||
"title": data["title"],
|
||||
"text": data["text"]
|
||||
}
|
||||
corpus_dict[data["docid"]] = {
|
||||
"title": data["title"],
|
||||
"text": data["text"]
|
||||
}
|
||||
else:
|
||||
_data = {
|
||||
"id": data["doc_id"],
|
||||
"title": data["title"],
|
||||
"text": data["body"]
|
||||
}
|
||||
corpus_dict[data["doc_id"]] = {
|
||||
"title": data["title"],
|
||||
"text": data["body"]
|
||||
}
|
||||
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
logging.info(f"{self.eval_name} {dataset_name} corpus saved to {save_path}")
|
||||
else:
|
||||
if dataset_name == 'passage':
|
||||
corpus_dict = {data["docid"]: {"title": data["title"], "text": data["text"]} for data in tqdm(corpus, desc="Loading corpus")}
|
||||
else:
|
||||
corpus_dict = {data["doc_id"]: {"title": data["title"], "text": data["body"]} for data in tqdm(corpus, desc="Loading corpus")}
|
||||
return datasets.DatasetDict(corpus_dict)
|
||||
|
||||
def _load_remote_qrels(
|
||||
self,
|
||||
dataset_name: Optional[str] = None,
|
||||
split: str = 'dev',
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Load the qrels from HF.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the dataset.
|
||||
split (str, optional): Split of the dataset. Defaults to ``'dev'``.
|
||||
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of qrel.
|
||||
"""
|
||||
if dataset_name == 'passage':
|
||||
if split == 'dev':
|
||||
qrels = datasets.load_dataset(
|
||||
'BeIR/msmarco-qrels',
|
||||
split='validation',
|
||||
trust_remote_code=True,
|
||||
cache_dir=self.cache_dir,
|
||||
download_mode=self.hf_download_mode
|
||||
)
|
||||
qrels_download_url = None
|
||||
elif split == 'dl19':
|
||||
qrels_download_url = "https://trec.nist.gov/data/deep/2019qrels-pass.txt"
|
||||
else:
|
||||
qrels_download_url = "https://trec.nist.gov/data/deep/2020qrels-pass.txt"
|
||||
else:
|
||||
if split == 'dev':
|
||||
qrels_download_url = "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docdev-qrels.tsv.gz"
|
||||
elif split == 'dl19':
|
||||
qrels_download_url = "https://trec.nist.gov/data/deep/2019qrels-docs.txt"
|
||||
else:
|
||||
qrels_download_url = "https://trec.nist.gov/data/deep/2020qrels-docs.txt"
|
||||
|
||||
if qrels_download_url is not None:
|
||||
qrels_save_path = self._download_file(qrels_download_url, self.cache_dir)
|
||||
else:
|
||||
qrels_save_path = None
|
||||
|
||||
if save_dir is not None:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
save_path = os.path.join(save_dir, f"{split}_qrels.jsonl")
|
||||
qrels_dict = {}
|
||||
if qrels_save_path is not None:
|
||||
with open(save_path, "w", encoding="utf-8") as f1:
|
||||
with open(qrels_save_path, "r", encoding="utf-8") as f2:
|
||||
for line in tqdm(f2.readlines(), desc="Loading and Saving qrels"):
|
||||
qid, _, docid, rel = line.strip().split()
|
||||
qid, docid, rel = str(qid), str(docid), int(rel)
|
||||
_data = {
|
||||
"qid": qid,
|
||||
"docid": docid,
|
||||
"relevance": rel
|
||||
}
|
||||
if qid not in qrels_dict:
|
||||
qrels_dict[qid] = {}
|
||||
qrels_dict[qid][docid] = rel
|
||||
f1.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
else:
|
||||
with open(save_path, "w", encoding="utf-8") as f:
|
||||
for data in tqdm(qrels, desc="Loading and Saving qrels"):
|
||||
qid, docid, rel = str(data['query-id']), str(data['corpus-id']), int(data['score'])
|
||||
_data = {
|
||||
"qid": qid,
|
||||
"docid": docid,
|
||||
"relevance": rel
|
||||
}
|
||||
if qid not in qrels_dict:
|
||||
qrels_dict[qid] = {}
|
||||
qrels_dict[qid][docid] = rel
|
||||
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
logging.info(f"{self.eval_name} {dataset_name} qrels saved to {save_path}")
|
||||
else:
|
||||
qrels_dict = {}
|
||||
if qrels_save_path is None:
|
||||
with open(qrels_save_path, "r", encoding="utf-8") as f:
|
||||
for line in tqdm(f.readlines(), desc="Loading qrels"):
|
||||
qid, _, docid, rel = line.strip().split()
|
||||
qid, docid, rel = str(qid), str(docid), int(rel)
|
||||
if qid not in qrels_dict:
|
||||
qrels_dict[qid] = {}
|
||||
qrels_dict[qid][docid] = rel
|
||||
else:
|
||||
for data in tqdm(qrels, desc="Loading queries"):
|
||||
qid, docid, rel = str(data['query-id']), str(data['corpus-id']), int(data['score'])
|
||||
if qid not in qrels_dict:
|
||||
qrels_dict[qid] = {}
|
||||
qrels_dict[qid][docid] = rel
|
||||
return datasets.DatasetDict(qrels_dict)
|
||||
|
||||
def _load_remote_queries(
|
||||
self,
|
||||
dataset_name: Optional[str] = None,
|
||||
split: str = 'test',
|
||||
save_dir: Optional[str] = None
|
||||
) -> datasets.DatasetDict:
|
||||
"""Load the queries from HF.
|
||||
|
||||
Args:
|
||||
dataset_name (str): Name of the dataset.
|
||||
split (str, optional): Split of the dataset. Defaults to ``'test'``.
|
||||
save_dir (Optional[str], optional): Directory to save the dataset. Defaults to ``None``.
|
||||
|
||||
Returns:
|
||||
datasets.DatasetDict: Loaded datasets instance of queries.
|
||||
"""
|
||||
if split == 'dev':
|
||||
if dataset_name == 'passage':
|
||||
queries = datasets.load_dataset(
|
||||
'BeIR/msmarco',
|
||||
'queries',
|
||||
trust_remote_code=True,
|
||||
cache_dir=self.cache_dir,
|
||||
download_mode=self.hf_download_mode
|
||||
)['queries']
|
||||
queries_save_path = None
|
||||
else:
|
||||
queries_download_url = "https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-docdev-qrels.tsv.gz"
|
||||
queries_save_path = self._download_gz_file(queries_download_url, self.cache_dir)
|
||||
else:
|
||||
year = split.replace("dl", "")
|
||||
queries_download_url = f"https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco-test20{year}-queries.tsv.gz"
|
||||
queries_save_path = self._download_gz_file(queries_download_url, self.cache_dir)
|
||||
|
||||
qrels = self.load_qrels(dataset_name=dataset_name, split=split)
|
||||
|
||||
if save_dir is not None:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
save_path = os.path.join(save_dir, f"{split}_queries.jsonl")
|
||||
queries_dict = {}
|
||||
if queries_save_path is not None:
|
||||
with open(save_path, "w", encoding="utf-8") as f1:
|
||||
with open(queries_save_path, "r", encoding="utf-8") as f2:
|
||||
for line in tqdm(f2.readlines(), desc="Loading and Saving queries"):
|
||||
qid, query = line.strip().split("\t")
|
||||
if qid not in qrels.keys(): continue
|
||||
qid = str(qid)
|
||||
_data = {
|
||||
"id": qid,
|
||||
"text": query
|
||||
}
|
||||
queries_dict[qid] = query
|
||||
f1.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
else:
|
||||
with open(save_path, "w", encoding="utf-8") as f:
|
||||
for data in tqdm(queries, desc="Loading and Saving queries"):
|
||||
qid, query = data['_id'], data['text']
|
||||
if qid not in qrels.keys(): continue
|
||||
_data = {
|
||||
"id": qid,
|
||||
"text": query
|
||||
}
|
||||
queries_dict[qid] = query
|
||||
f.write(json.dumps(_data, ensure_ascii=False) + "\n")
|
||||
logging.info(f"{self.eval_name} {dataset_name} queries saved to {save_path}")
|
||||
else:
|
||||
queries_dict = {}
|
||||
if queries_save_path is not None:
|
||||
with open(queries_save_path, "r", encoding="utf-8") as f:
|
||||
for line in tqdm(f.readlines(), desc="Loading queries"):
|
||||
qid, query = line.strip().split("\t")
|
||||
qid = str(qid)
|
||||
if qid not in qrels.keys(): continue
|
||||
queries_dict[qid] = query
|
||||
else:
|
||||
for data in tqdm(queries, desc="Loading queries"):
|
||||
qid, query = data['_id'], data['text']
|
||||
if qid not in qrels.keys(): continue
|
||||
queries_dict[qid] = query
|
||||
return datasets.DatasetDict(queries_dict)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
from FlagEmbedding.abc.evaluation import AbsEvalRunner
|
||||
|
||||
from .data_loader import MSMARCOEvalDataLoader
|
||||
|
||||
|
||||
class MSMARCOEvalRunner(AbsEvalRunner):
|
||||
"""
|
||||
Evaluation runner of MSMARCO.
|
||||
"""
|
||||
def load_data_loader(self) -> MSMARCOEvalDataLoader:
|
||||
"""Load the data loader instance by args.
|
||||
|
||||
Returns:
|
||||
MSMARCOEvalDataLoader: The MSMARCO data loader instance.
|
||||
"""
|
||||
data_loader = MSMARCOEvalDataLoader(
|
||||
eval_name=self.eval_args.eval_name,
|
||||
dataset_dir=self.eval_args.dataset_dir,
|
||||
cache_dir=self.eval_args.cache_path,
|
||||
token=self.eval_args.token,
|
||||
force_redownload=self.eval_args.force_redownload,
|
||||
)
|
||||
return data_loader
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
from FlagEmbedding.abc.evaluation import (
|
||||
AbsEvalModelArgs as MTEBEvalModelArgs,
|
||||
)
|
||||
|
||||
from .arguments import MTEBEvalArgs
|
||||
from .runner import MTEBEvalRunner
|
||||
|
||||
__all__ = [
|
||||
"MTEBEvalArgs",
|
||||
"MTEBEvalModelArgs",
|
||||
"MTEBEvalRunner",
|
||||
]
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
from transformers import HfArgumentParser
|
||||
|
||||
from FlagEmbedding.evaluation.mteb import (
|
||||
MTEBEvalArgs, MTEBEvalModelArgs,
|
||||
MTEBEvalRunner
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = HfArgumentParser((
|
||||
MTEBEvalArgs,
|
||||
MTEBEvalModelArgs
|
||||
))
|
||||
|
||||
eval_args, model_args = parser.parse_args_into_dataclasses()
|
||||
eval_args: MTEBEvalArgs
|
||||
model_args: MTEBEvalModelArgs
|
||||
|
||||
runner = MTEBEvalRunner(
|
||||
eval_args=eval_args,
|
||||
model_args=model_args
|
||||
)
|
||||
|
||||
runner.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
from FlagEmbedding.abc.evaluation.arguments import AbsEvalArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
class MTEBEvalArgs(AbsEvalArgs):
|
||||
"""
|
||||
Argument class for MTEB evaluation.
|
||||
"""
|
||||
languages: List[str] = field(
|
||||
default=None, metadata={"help": "Languages to evaluate. Default: eng"}
|
||||
)
|
||||
tasks: List[str] = field(
|
||||
default=None, metadata={"help": "Tasks to evaluate. Default: None"}
|
||||
)
|
||||
task_types: List[str] = field(
|
||||
default=None, metadata={"help": "The task types to evaluate. Default: None"}
|
||||
)
|
||||
use_special_instructions: bool = field(
|
||||
default=False, metadata={"help": "Whether to use specific instructions in `prompts.py` for evaluation. Default: False"}
|
||||
)
|
||||
examples_path: str = field(
|
||||
default=None, metadata={"help": "Use specific examples in the path. Default: None"}
|
||||
)
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
text,label
|
||||
"I wish I could have used this head set but the day I received it it wouldn't even turn on and I really wanted this product to work I'm very disappointed.","counterfactual"
|
||||
"I would advise that instead of trying to follow these poor instructions, Google it.","not-counterfactual"
|
||||
"I wrote to Monster customer service before ordering and they told me it would be fine to use without a converter and it was absolutely true.","not-counterfactual"
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
text,label
|
||||
"Hunting the Hard Way Thia was a gift for my Husband, who loved the book. It arrived on the date we were told it would.",positive
|
||||
"Poor DVD Has too many interviews with people at the Live THomas day in Penn. My kids were annoyed and hated this DVD.",negative
|
||||
"Ludicrous and silly I remember getting this book so faintly that that says alot about my opinion of it. Basically, while I will entertain lots of odd ideas and theories, this book was basically silly.",negative
|
||||
"Artistry I think that the Deodato concerts are very rich, as he used real strings and band musicians, as well as you can appreciate the John Tropea excelent renditions on guitar.",positive
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
text,label
|
||||
"DO NOT ORDER THIS\n\nThis isn't what's described at all. Taking it out of the package lace was cut upon arrival, wig was cut to like 14 inch, not curly, and smelled like cigarettes. I obviously was sent what someone returned, disgusting.Not what I ordered at all, not pleased at all. I want my money back DO NOT ORDER","1 star"
|
||||
"And I can’t return it\n\nThis product seemed like good quality but it does not stay stuck to the soles at all. You walk a few steps and then you find the black shoe grip somewhere on the floor.","2 star"
|
||||
"Three Stars\n\nnew yearly subscription plan is horrible, but the product still works as it did in the past","3 star"
|
||||
"I like how it has lots of pockets to put stuff ...\n\nI like how it has lots of pockets to put stuff in. I would have liked to have a shorter securing strap so it would not slide around so much. Good product.","4 star"
|
||||
"Great\n\nIt is really good. That's my favorite. THANK YOU","5 star"
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
query,pos
|
||||
"People will die if we don’t do animal testing Every year, 23 new drugs are introduced in the UK alone.[13] Almost all will be tested on animals. A new drug will be used for a long time. Think of all the people saved by the use of penicillin. If drugs cost more to test, that means drug companies will develop less. This means more people suffering and dying","animals science science general ban animal testing junior Many of these drugs are “me too” drugs – ones with a slight change that doesn’t make much difference to an existing drug. [14] So often the benefits from animal testing are marginal, and even if there was a slight increase in human suffering, it would be worth it based on the animal suffering saved."
|
||||
"Survival of the fittest It is natural for human beings to farm, kill, and eat other species. In the wild there is a brutal struggle for existence as is shown by Darwin’s On the Origin of the Species. The fact that we humans have succeeded in that struggle by exploiting our natural environment means that we have a natural right over lower species. The concept of survival of the fittest may seem outdated but it is still the defining order of nature. In fact farming animals is much less brutal than the pain and hardship that animals inflict on each other naturally in the wild.","The claim of human entitlement over other species based on 'survival of the fittest' is flawed. While Darwin's theory highlights competition, it doesn't justify exploitation. Our capacity for empathy and moral reasoning surpasses mere survival instincts. Farming still inflicts suffering, contradicting the notion of human superiority. Ethical considerations should guide our treatment of animals, not outdated notions of natural selection."
|
||||
"Underground Nuclear Storage is Expensive. Underground nuclear storage is expensive. This is because the deep geological repositories needed to deal with such waste are difficult to construct. This is because said repositories need to be 300m underground and also need failsafe systems so that they can be sealed off should there be a leak. For smaller countries, implementing this idea is almost completely impossible. Further, the maintenance of the facilities also requires a lot of long-term investment as the structural integrity of the facilities must consistently be monitored and maintained so that if there is a leak, the relevant authorities can be informed quickly and efficiently. This is seen with the Yucca mountain waste repository site which has cost billions of dollars since the 1990s and was eventually halted due to public fears about nuclear safety.","While initial construction and maintenance entail significant costs, advancements in technology offer more cost-effective solutions. Modular storage designs and improved monitoring systems mitigate expenses. Collaborative international efforts can also distribute costs. Additionally, public concerns can be addressed through transparent safety protocols and community engagement, ensuring responsible nuclear waste management without exorbitant expenditure. Underground nuclear storage isn't inherently prohibitive."
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
text,label
|
||||
"A Novel Approach to Enhancing Cybersecurity in Smart Grids through Deep Reinforcement Learning The integration of renewable energy sources and advanced metering infrastructure in smart grids introduces complex cybersecurity challenges. In this paper, we propose a novel approach utilizing deep reinforcement learning (DRL) to enhance the resilience of smart grids against cyber attacks. Our method leverages DRL agents to dynamically optimize intrusion detection and response strategies based on real-time grid conditions and attack patterns. We demonstrate through simulations on a realistic smart grid testbed that our approach effectively reduces the impact of cyber threats while maintaining grid operational efficiency and reliability. The results highlight significant improvements in security posture compared to traditional rule-based and anomaly detection approaches.",cs
|
||||
"Dynamics of Frobenius Endomorphisms in Characteristic p This paper investigates the dynamics of Frobenius endomorphisms in characteristic 𝑝, focusing on their algebraic and arithmetic properties. We explore the behavior of Frobenius endomorphisms on varieties over finite fields and delve into their applications in number theory and algebraic geometry. Specifically, we analyze the distribution of fixed points, the growth rates of orbits under iteration, and connections to zeta functions and L-functions. Theoretical results are complemented by computational experiments that illustrate the interplay between Frobenius endomorphisms and geometric structures. Our findings contribute to a deeper understanding of the arithmetic nature of varieties and their representations in characteristic 𝑝, offering insights into fundamental questions in modern algebraic and arithmetic geometry.",math
|
||||
"Probing Exoplanetary Atmospheres Using Transmission Spectroscopy with the James Webb Space Telescope Transmission spectroscopy has revolutionized our understanding of exoplanetary atmospheres, revealing key insights into their chemical compositions and physical properties. With the upcoming launch of the James Webb Space Telescope (JWST), we explore the potential of this technique to characterize exoplanetary atmospheres across a wide range of wavelengths and planetary types. We present a comprehensive analysis framework that incorporates high-resolution spectroscopic data and advanced atmospheric models to interpret transmission spectra obtained by JWST. Our simulations predict detectability thresholds for key molecular species and atmospheric features, offering critical guidance for future observational campaigns aimed at unraveling the diversity and origins of exoplanetary atmospheres.",astro-ph
|
||||
"Quantum Coherence and Information Transfer in Photosynthetic Complexes: Insights from Coherent Spectroscopy Photosynthetic complexes are renowned for their efficient energy transfer mechanisms, driven by quantum coherence phenomena over femtosecond timescales. This paper explores the role of coherent spectroscopy techniques in elucidating the quantum dynamics underlying energy transfer processes in natural photosynthetic systems. We review recent experimental findings and theoretical models that highlight the significance of quantum coherence in optimizing energy capture and transport efficiency in photosynthetic complexes. Our analysis integrates insights from ultrafast spectroscopy experiments with advanced quantum mechanical simulations, providing a comprehensive framework for understanding the interplay between coherence, environmental influences, and biological functionality in photosynthesis.",quant-ph
|
||||
"Quantum Hall Effect in Moiré Superlattices of Twisted Bilayer Graphene The discovery of the quantum Hall effect in moiré superlattices formed by twisted bilayer graphene has opened new avenues in the study of correlated electron systems. This paper investigates the emergence of fractional quantum Hall states and their robustness against disorder and varying twist angles in twisted bilayer graphene. We analyze experimental observations of Landau level spectra and magnetotransport measurements, revealing distinctive features such as enhanced localization and unconventional symmetry breaking effects. Our theoretical framework integrates effective model descriptions and numerical simulations to elucidate the underlying mechanisms driving the quantum Hall phenomena in moiré superlattices, paving the way for future applications in quantum devices and topological materials.",cond-mat
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
text,label
|
||||
"A Survey on Graph Neural Networks: Algorithms and Applications",cs
|
||||
"Hamiltonian Dynamics and KAM Theory for Infinite-Dimensional Systems",math
|
||||
"Dark Matter Distribution in Dwarf Spheroidal Galaxies: Constraints from Stellar Kinematics",astro-ph
|
||||
"Decoherence and Quantum Error Correction in Topological Quantum Computers",quant-ph
|
||||
"Spin-Orbit Coupling Effects in Low-Dimensional Quantum Materials",cond-mat
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
query,positive
|
||||
angularjs infinite scroll in a container,AngularJS ng-infinite-scroll not working on a specific container/div
|
||||
Java: Efficiently converting an array of longs to an array of bytes,Most Compact way to Serialize an Array of Longs in Java
|
||||
PyVISA missing methods,NI VISA + pyVisa on Mac OS X (Snow Leopard)
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
sent1,sent2
|
||||
"Recent studies have highlighted the crucial role of p53 in regulating cell cycle progression.","Recent research underscores p53's pivotal function in controlling cellular division."
|
||||
"Neuroscience has revealed intricate pathways linking dopamine to reward and motivation.","Recent neuroscientific findings have illuminated complex dopamine pathways associated with motivation and reward."
|
||||
"Stem cell research holds promise for treating a variety of degenerative diseases.","The potential of stem cell research in combating degenerative illnesses is widely recognized."
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
text,label
|
||||
"What is my money worth in other countries?",exchange_rate
|
||||
"What can I do if my card still hasn't arrived after 2 weeks?",card_arrival
|
||||
"Would I be able to open an account for my daughter?",age_limit
|
||||
"My address details have changed and I want to update them",edit_personal_details
|
||||
"If my cash withdrawal is still not showing, is something wrong?",pending_cash_withdrawal
|
||||
"How long do transfers typically take? Is there a way of speeding the process up? My friend needs the money I sent her desperately.",transfer_not_received_by_recipient
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
text,label
|
||||
"Neural Mechanisms of Social Cognition: A Study on Mirror Neurons and EmpathySocial cognition is the mental process involved in understanding, recognizing, and predicting others' behavior and emotions. In this study, we investigate the role of mirror neurons in the process of empathy by using a combination of functional magnetic resonance imaging (fMRI) and electroencephalography (EEG). Our experiments involve observing the neural activation of participants as they watch videos of individuals experiencing various emotional states. We demonstrate that specific mirror neuron systems in the premotor cortex and the inferior parietal lobule are significantly activated when participants empathize with others. This suggests that mirror neurons might be fundamental to the neural basis of empathy, facilitating an understanding of others' emotions by simulating them internally. These findings provide insights into the neural mechanisms underlying social cognition and offer potential pathways for therapeutic interventions in conditions like autism and psychopathy, where social cognition is often impaired.",neuroscience
|
||||
"Methicillin-resistant Staphylococcus aureus (MRSA) is a major health threat due to its resistance to multiple antibiotics. This study analyzed 50 clinical MRSA isolates using whole-genome sequencing and phenotypic assays. We identified mecA and mecC genes encoding beta-lactam-resistant penicillin-binding proteins. Mutations in rpoB conferred rifampicin resistance, while changes in gyrA and grlA were linked to fluoroquinolone resistance. Biofilm formation was also found to enhance antibiotic resistance. These findings highlight genetic mechanisms and suggest potential targets for developing new treatments against MRSA infections.",microbiology
|
||||
"Deep Learning Approaches for Predicting Protein-Protein Interactions from Sequence Data\nProtein-protein interactions (PPIs) are fundamental to numerous biological processes, and understanding these interactions is critical for uncovering cellular mechanisms and developing therapeutic strategies. Traditional experimental methods for identifying PPIs are labor-intensive and time-consuming, highlighting the need for computational approaches. In this study, we present DeepPPI, a deep learning-based framework designed to predict PPIs directly from protein sequence data. DeepPPI employs a combination of convolutional neural networks (CNNs) and recurrent neural networks (RNNs) to capture both local and global sequence features. We trained DeepPPI on a comprehensive dataset of known PPIs and benchmarked its performance against existing methods, demonstrating superior accuracy and generalizability. Additionally, we applied DeepPPI to predict novel interactions in the human proteome and validated a subset of these predictions experimentally. Our results indicate that DeepPPI not only achieves high prediction accuracy but also provides insights into the structural and functional basis of protein interactions, making it a valuable tool for the bioinformatics community.",bioinformatics
|
||||
"Cell migration, pivotal in wound healing, immune responses, and cancer metastasis, relies on the actin cytoskeleton for membrane protrusions and movement. We explore phosphoinositides' role—key membrane phospholipids—in this process. Using live-cell imaging and FRET-based biosensors, we track phosphoinositide dynamics during migration. Our findings reveal distinct distributions: phosphatidylinositol 4,5-bisphosphate (PIP2) enriches actin polymerization sites, while phosphatidylinositol 3,4,5-trisphosphate (PIP3) predominates in membrane ruffles and lamellipodia. Modulating these phosphoinositides via kinases and phosphatases alters actin filament organization and migration speed, suggesting therapeutic targets for diseases involving abnormal cell migration.",cell biology
|
||||
"Cell membranes, comprising lipids and proteins, regulate molecular transport and signaling. Lipid rafts, enriched in cholesterol and sphingolipids, organize membrane proteins and influence cellular functions. Using AFM and fluorescence microscopy, we studied how lipid rafts and cholesterol impact membrane mechanics. Manipulating cholesterol levels and disrupting rafts with MβCD revealed changes in stiffness and lipid density. Rafts enhance rigidity and resistance to deformation, while cholesterol depletion increases fluidity and reduces stability. Lipid-protein interactions in rafts maintain membrane integrity. These insights into membrane organization offer strategies for manipulating cellular responses through lipid raft modulation.",biophysics
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
text,label
|
||||
"Neural Circuit Dynamics in Decision-Making: A Computational Model of Prefrontal-Striatal Interactions",neuroscience
|
||||
"Metagenomic Insights into Extreme Environments: Microbial Diversity and Functional Adaptations in Antarctic Lakes",microbiology
|
||||
"Machine Learning Approaches for Predicting Protein Structure and Function from Sequence Data",bioinformatics
|
||||
"Regulation of Stem Cell Fate Decisions by the Hippo Signaling Pathway: Implications for Tissue Regeneration and Cancer Therapy",cell biology
|
||||
"Optical Tweezers and Single-Molecule Force Spectroscopy: Probing Protein Folding Dynamics and Mechanical Properties of Biomolecules",biophysics
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
query,positive
|
||||
angularjs infinite scroll in a container,AngularJS ng-infinite-scroll not working on a specific container/div
|
||||
Java: Efficiently converting an array of longs to an array of bytes,Most Compact way to Serialize an Array of Longs in Java
|
||||
PyVISA missing methods,NI VISA + pyVisa on Mac OS X (Snow Leopard)
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue