embed-bge-m3/FlagEmbedding/research/BGE_Coder/evaluation/coir_eval/main.py

import os
import json

import coir
from transformers import HfArgumentParser
from FlagEmbedding import FlagLLMModel, FlagModel

from arguments import COIREvalArgs, COIREvalModelArgs
from prompts import get_task_def_by_task_name


def get_model(model_args: COIREvalModelArgs):
    """Instantiate the embedder and wrap it in the interface expected by COIR."""
    embedder_name_or_path = model_args.embedder_name_or_path

    if model_args.embedder_model_class == "encoder-only-base":
        embedder = FlagModel(
            model_name_or_path=embedder_name_or_path,
            normalize_embeddings=model_args.normalize_embeddings,
            pooling_method=model_args.pooling_method,
            use_fp16=model_args.use_fp16,
            query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
            query_instruction_format=model_args.query_instruction_format_for_retrieval,
            devices=model_args.devices,
            trust_remote_code=model_args.trust_remote_code,
            cache_dir=model_args.cache_dir,
            batch_size=model_args.embedder_batch_size,
            query_max_length=model_args.embedder_query_max_length,
            passage_max_length=model_args.embedder_passage_max_length,
        )
    elif model_args.embedder_model_class == "decoder-only-base":
        embedder = FlagLLMModel(
            model_name_or_path=embedder_name_or_path,
            normalize_embeddings=model_args.normalize_embeddings,
            pooling_method=model_args.pooling_method,
            use_fp16=model_args.use_fp16,
            query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
            query_instruction_format=model_args.query_instruction_format_for_retrieval,
            devices=model_args.devices,
            examples_for_task=model_args.examples_for_task,
            examples_instruction_format=model_args.examples_instruction_format,
            trust_remote_code=model_args.trust_remote_code,
            cache_dir=model_args.cache_dir,
            batch_size=model_args.embedder_batch_size,
            query_max_length=model_args.embedder_query_max_length,
            passage_max_length=model_args.embedder_passage_max_length,
        )
    else:
        raise ValueError(f"Invalid model class: {model_args.embedder_model_class}")

    # Record the original name/path so the output folder can be named after the model.
    embedder.model.config._name_or_path = model_args.embedder_name_or_path

    class CustomFlagModel:
        """Adapter exposing the encode interface that COIR expects.

        `show_progress_bar` and `convert_to_tensor` are accepted for API
        compatibility but not forwarded to the underlying FlagEmbedding model.
        """

        def __init__(self, model):
            self.model = model

        def encode_queries(self, queries, show_progress_bar, convert_to_tensor, **kwargs):
            if isinstance(queries, str):
                queries = [queries]
            if isinstance(queries[0], dict):
                # Flatten BEIR-style {"title": ..., "text": ...} entries into plain strings.
                queries = [(e.get('title', '') + ' ' + e['text']).strip() for e in queries]
            return self.model.encode_queries(queries, **kwargs)

        def encode_corpus(self, corpus, show_progress_bar, convert_to_tensor, **kwargs):
            if isinstance(corpus, str):
                corpus = [corpus]
            if isinstance(corpus[0], dict):
                corpus = [(e.get('title', '') + ' ' + e['text']).strip() for e in corpus]
            return self.model.encode_corpus(corpus, **kwargs)

        def encode(self, corpus, show_progress_bar, convert_to_tensor, **kwargs):
            if isinstance(corpus, str):
                corpus = [corpus]
            if isinstance(corpus[0], dict):
                corpus = [(e.get('title', '') + ' ' + e['text']).strip() for e in corpus]
            return self.model.encode(corpus, **kwargs)

    return CustomFlagModel(embedder)


def main(
    eval_args: COIREvalArgs,
    model_args: COIREvalModelArgs
):
    model = get_model(model_args)
    output_folder = os.path.join(eval_args.output_dir, os.path.basename(model.model.model.config._name_or_path))

    all_task = eval_args.tasks
    if not isinstance(all_task, list):
        all_task = [all_task]

    all_results = {}
    for task_name in all_task:
        save_path = os.path.join(output_folder, f"{task_name}.json")
        # Reuse cached per-task results when they already exist.
        if os.path.exists(save_path):
            with open(save_path, "r", encoding="utf-8") as f:
                results = json.load(f)
            all_results[task_name] = results['metrics']
            continue

        tmp_task = coir.get_tasks(tasks=[task_name])
        evaluation = coir.COIR(tasks=tmp_task,
                               batch_size=model_args.embedder_batch_size)
        # Release any multi-process encoding pool left over from the previous task.
        model.model.stop_self_pool()
        if eval_args.use_special_instructions:
            # Replace the default retrieval instruction with a task-specific one.
            model.model.query_instruction_for_retrieval = get_task_def_by_task_name(task_name)
        results = evaluation.run(model, output_folder=output_folder)
        all_results[task_name] = results[task_name]

    # Aggregate NDCG@10: the CodeSearchNet and CodeSearchNet-ccr subtasks are each
    # averaged into a single score before computing the overall mean.
    csn_result = 0
    csn_num = 0
    csn_ccr_result = 0
    csn_ccr_num = 0
    pop_keys = []
    all_result = 0
    all_num = 0
    for k in all_results.keys():
        if 'CodeSearchNet-ccr' in k:
            csn_ccr_result += all_results[k]['NDCG']['NDCG@10']
            csn_ccr_num += 1
            pop_keys.append(k)
        elif 'CodeSearchNet' in k:
            csn_result += all_results[k]['NDCG']['NDCG@10']
            csn_num += 1
            pop_keys.append(k)
        else:
            all_result += all_results[k]['NDCG']['NDCG@10']
            all_num += 1

    if csn_num > 0:
        print('Using CodeSearchNet')
        all_result += csn_result / csn_num
        all_num += 1
    if csn_ccr_num > 0:
        print('Using CodeSearchNet-ccr')
        all_result += csn_ccr_result / csn_ccr_num
        all_num += 1

    new_results = {}
    for k in all_results:
        if k in pop_keys:
            continue
        new_results[k] = all_results[k]['NDCG']['NDCG@10']
    if csn_num > 0:
        new_results['CodeSearchNet'] = csn_result / csn_num
    if csn_ccr_num > 0:
        new_results['CodeSearchNet_CCR'] = csn_ccr_result / csn_ccr_num
    new_results['all'] = all_result / all_num

    print(new_results)
    with open(os.path.join(output_folder, 'OVERALL-results.json'), 'w') as f:
        json.dump(new_results, f)
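    # Illustrative shape of OVERALL-results.json (keys depend on the tasks evaluated;
    # each value is an NDCG@10 score, and "all" is the overall mean):
    #   {"<task>": <NDCG@10>, ..., "CodeSearchNet": <mean NDCG@10>,
    #    "CodeSearchNet_CCR": <mean NDCG@10>, "all": <overall mean>}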


if __name__ == "__main__":
    parser = HfArgumentParser((
        COIREvalArgs,
        COIREvalModelArgs
    ))
    eval_args, model_args = parser.parse_args_into_dataclasses()
    main(eval_args, model_args)
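

# Example invocation (illustrative: the model path, task names, and values are
# placeholders; flag names mirror the COIREvalArgs / COIREvalModelArgs fields
# referenced above and may differ from the full dataclass definitions):
#   python main.py \
#       --embedder_name_or_path BAAI/bge-m3 \
#       --embedder_model_class encoder-only-base \
#       --tasks codetrans-dl stackoverflow-qa \
#       --output_dir ./coir_results \
#       --embedder_batch_size 256 \
#       --use_fp16 True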