import os
|
|
import json
|
|
import coir
|
|
from transformers import HfArgumentParser
|
|
|
|
from arguments import COIREvalArgs, COIREvalModelArgs
|
|
from prompts import get_task_def_by_task_name
|
|
from FlagEmbedding import FlagLLMModel, FlagModel
|
|
|
|
|
|
def get_model(model_args: COIREvalModelArgs):
    """Build the embedding model described by ``model_args`` and wrap it
    in a COIR-compatible adapter.

    Args:
        model_args: Evaluation model arguments (name/path, pooling, fp16,
            instructions, devices, batch size, max lengths, ...).

    Returns:
        A ``CustomFlagModel`` exposing the ``encode_queries`` /
        ``encode_corpus`` / ``encode`` interface the COIR harness calls.

    Raises:
        ValueError: if ``model_args.embedder_model_class`` is neither
            ``"encoder-only-base"`` nor ``"decoder-only-base"``.
    """
    embedder_name_or_path = model_args.embedder_name_or_path

    if model_args.embedder_model_class == "encoder-only-base":
        embedder = FlagModel(
            model_name_or_path=embedder_name_or_path,
            normalize_embeddings=model_args.normalize_embeddings,
            pooling_method=model_args.pooling_method,
            use_fp16=model_args.use_fp16,
            query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
            query_instruction_format=model_args.query_instruction_format_for_retrieval,
            devices=model_args.devices,
            trust_remote_code=model_args.trust_remote_code,
            cache_dir=model_args.cache_dir,
            batch_size=model_args.embedder_batch_size,
            query_max_length=model_args.embedder_query_max_length,
            passage_max_length=model_args.embedder_passage_max_length,
        )
    elif model_args.embedder_model_class == "decoder-only-base":
        embedder = FlagLLMModel(
            model_name_or_path=embedder_name_or_path,
            normalize_embeddings=model_args.normalize_embeddings,
            pooling_method=model_args.pooling_method,
            use_fp16=model_args.use_fp16,
            query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
            query_instruction_format=model_args.query_instruction_format_for_retrieval,
            devices=model_args.devices,
            examples_for_task=model_args.examples_for_task,
            examples_instruction_format=model_args.examples_instruction_format,
            trust_remote_code=model_args.trust_remote_code,
            cache_dir=model_args.cache_dir,
            batch_size=model_args.embedder_batch_size,
            query_max_length=model_args.embedder_query_max_length,
            passage_max_length=model_args.embedder_passage_max_length,
        )
    else:
        raise ValueError(f"Invalid model class: {model_args.embedder_model_class}")

    # Record the original name/path so main() can derive the output folder
    # from the model config later.
    embedder.model.config._name_or_path = model_args.embedder_name_or_path

    class CustomFlagModel:
        """Adapter giving a FlagEmbedding model the call signature COIR uses.

        ``show_progress_bar`` and ``convert_to_tensor`` are accepted because
        the harness passes them, but they are intentionally not forwarded to
        the underlying model.
        """

        def __init__(self, model):
            self.model = model

        @staticmethod
        def _as_texts(items):
            """Normalize str / list[str] / list[dict] input into list[str].

            Dict entries are assumed to follow the BEIR-style corpus
            convention with a required ``text`` and an optional ``title``.
            ``title`` defaults to "" — the original ``e.get('title')`` raised
            ``TypeError`` (None + str) when the key was absent. An empty
            list is returned unchanged instead of raising IndexError.
            """
            if isinstance(items, str):
                items = [items]
            if items and isinstance(items[0], dict):
                items = [
                    (e.get('title', '') + ' ' + e['text']).strip()
                    for e in items
                ]
            return items

        def encode_queries(self, queries, show_progress_bar, convert_to_tensor, **kwargs):
            return self.model.encode_queries(self._as_texts(queries), **kwargs)

        def encode_corpus(self, corpus, show_progress_bar, convert_to_tensor, **kwargs):
            return self.model.encode_corpus(self._as_texts(corpus), **kwargs)

        def encode(self, corpus, show_progress_bar, convert_to_tensor, **kwargs):
            return self.model.encode(self._as_texts(corpus), **kwargs)

    return CustomFlagModel(embedder)
|
|
|
|
|
|
def _aggregate_results(all_results):
    """Collapse per-task NDCG@10 metrics into a flat summary dict.

    All ``CodeSearchNet-ccr-*`` sub-tasks are averaged into one
    ``CodeSearchNet_CCR`` entry and all remaining ``CodeSearchNet-*``
    sub-tasks into one ``CodeSearchNet`` entry, so each dataset contributes
    a single score to the overall mean. Returns ``{task: NDCG@10, ...}``
    plus an ``'all'`` average (omitted when there are no scores — the
    original code raised ZeroDivisionError in that case).
    """
    csn_scores, csn_ccr_scores, other = [], [], {}
    for task, metrics in all_results.items():
        score = metrics['NDCG']['NDCG@10']
        if 'CodeSearchNet-ccr' in task:
            csn_ccr_scores.append(score)
        elif 'CodeSearchNet' in task:
            csn_scores.append(score)
        else:
            other[task] = score

    summary = dict(other)
    if csn_scores:
        print('Using CodeSearchNet')
        summary['CodeSearchNet'] = sum(csn_scores) / len(csn_scores)
    if csn_ccr_scores:
        print('Using CodeSearchNet-ccr')
        summary['CodeSearchNet_CCR'] = sum(csn_ccr_scores) / len(csn_ccr_scores)

    # One score per dataset: individual tasks plus each collapsed average.
    scores = list(other.values())
    if csn_scores:
        scores.append(summary['CodeSearchNet'])
    if csn_ccr_scores:
        scores.append(summary['CodeSearchNet_CCR'])
    if scores:
        summary['all'] = sum(scores) / len(scores)
    return summary


def main(
    eval_args: COIREvalArgs,
    model_args: COIREvalModelArgs
):
    """Evaluate the configured embedder on the requested COIR tasks.

    Per-task results are cached as ``<task>.json`` under the output folder
    and reused on re-runs; the aggregated summary is written to
    ``OVERALL-results.json``.
    """
    model = get_model(model_args)

    output_folder = os.path.join(
        eval_args.output_dir,
        os.path.basename(model.model.model.config._name_or_path),
    )

    all_task = eval_args.tasks
    if not isinstance(all_task, list):
        all_task = [all_task]

    all_results = {}
    for task_name in all_task:
        save_path = os.path.join(output_folder, f"{task_name}.json")
        # Reuse cached per-task metrics when a previous run saved them.
        if os.path.exists(save_path):
            with open(save_path, "r", encoding="utf-8") as f:
                results = json.load(f)
            all_results[task_name] = results['metrics']
            continue

        tmp_task = coir.get_tasks(tasks=[task_name])
        evaluation = coir.COIR(tasks=tmp_task,
                               batch_size=model_args.embedder_batch_size)

        model.model.stop_self_pool()

        # Optionally swap in a task-specific retrieval instruction.
        if eval_args.use_special_instructions:
            model.model.query_instruction_for_retrieval = get_task_def_by_task_name(task_name)

        results = evaluation.run(model, output_folder=output_folder)
        all_results[task_name] = results[task_name]

    new_results = _aggregate_results(all_results)

    print(new_results)

    with open(os.path.join(output_folder, 'OVERALL-results.json'), 'w', encoding='utf-8') as f:
        json.dump(new_results, f, indent=4)
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse CLI flags into the two evaluation dataclasses and run.
    parser = HfArgumentParser((COIREvalArgs, COIREvalModelArgs))
    eval_args, model_args = parser.parse_args_into_dataclasses()
    main(eval_args, model_args)
|