"""Response Parsing and Evaluation for various models""" import argparse import dataclasses import json import os import pprint import random import re from typing import Dict, Optional import numpy as np import torch from data_utils import ( CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT, construct_prompt, load_yaml, process_single_sample, ) from datasets import concatenate_datasets, load_dataset from tqdm import tqdm @dataclasses.dataclass class EvalArgs: seed: int = 42 split: str = "validation" # Default setting to make the benchmark available on A100 for most 7B models image_pixels_limit: int = 4300000 result_filename: str = "" prompt_format_file: str = "prompt_format.yaml" dataset_path: str = "MMMU/MMMU" extra_request_body: Optional[str] = None @staticmethod def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--result-filename", type=str, default=EvalArgs.result_filename ) parser.add_argument( "--image-pixels-limit", type=int, default=EvalArgs.image_pixels_limit ) parser.add_argument( "--dataset-path", type=str, default=EvalArgs.dataset_path, help="path to the dataset", ) parser.add_argument("--seed", type=int, default=1, help="The random seed.") parser.add_argument( "--prompt-format-file", type=str, help="The path to the prompt format of mmmu. If not, a default format llava_config.yaml will be used", ) parser.add_argument("--split", type=str, default=EvalArgs.split) parser.add_argument( "--extra-request-body", metavar='{"key1": "value1", "key2": "value2"}', type=str, default=EvalArgs.extra_request_body, help="Append given JSON object to the request payload. You can use this to specify" "additional generate params like sampling params.", ) @classmethod def from_cli_args(cls, args: argparse.Namespace): attrs = [attr.name for attr in dataclasses.fields(cls)] return cls(**{attr: getattr(args, attr) for attr in attrs}) def set_seed(seed_value): """ Set the seed for PyTorch (both CPU and CUDA), Python, and NumPy for reproducible results. :param seed_value: An integer value to be used as the seed. """ torch.manual_seed(seed_value) if torch.cuda.is_available(): torch.cuda.manual_seed(seed_value) torch.cuda.manual_seed_all(seed_value) # For multi-GPU setups random.seed(seed_value) np.random.seed(seed_value) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def prepare_samples(eval_args: EvalArgs): print("preparing samples...") # Build prompts set_seed(eval_args.seed) prompt_format_file = ( eval_args.prompt_format_file if eval_args.prompt_format_file is not None else os.path.join(os.path.dirname(__file__), "prompt_format.yaml") ) # load config and process to one value eval_args.config = load_yaml(prompt_format_file) for key, value in eval_args.config.items(): if key != "eval_params" and type(value) == list: assert len(value) == 1, "key {} has more than one value".format(key) eval_args.config[key] = value[0] # run for each subject sub_dataset_list = [] for subject in tqdm(CAT_SHORT2LONG.values()): sub_dataset = load_dataset( eval_args.dataset_path, subject, split=eval_args.split ) sub_dataset_list.append(sub_dataset) # break # merge all dataset dataset = concatenate_datasets(sub_dataset_list) ## prepare images samples = [] skip_count = 0 # use image file as input to ensure the consistency between sglang and hf images_path = os.path.expanduser("~/.cache/mmmu/images") os.makedirs(images_path, exist_ok=True) print(f"Saving images to: {images_path}") for i, sample in enumerate(tqdm(dataset)): sample = process_single_sample(sample) sample = construct_prompt(sample, eval_args.config) image = sample["image"] width, height = image.size if width * height >= eval_args.image_pixels_limit: skip_count += 1 continue image_path = f"{images_path}/image_{i}.png" if not os.path.exists(image_path): image.save(image_path) sample["image_path"] = image_path samples.append(sample) print( f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset" ) print("samples have been prepared") return samples def get_sampling_params(eval_args): max_new_tokens = 30 temperature = 0.001 extra_request_body = {} if eval_args.extra_request_body: extra_request_body = json.loads(eval_args.extra_request_body) return { "temperature": temperature, "max_new_tokens": max_new_tokens, **extra_request_body, } # ----------- Process Multi-choice ------------- def parse_multi_choice_response(response, all_choices, index2ans): """ Parse the prediction from the generated response. Return the predicted index e.g., A, B, C, D. """ for char in [",", ".", "!", "?", ";", ":", "'"]: response = response.strip(char) response = " " + response + " " # add space to avoid partial match index_ans = True ans_with_brack = False candidates = [] for choice in all_choices: # e.g., (A) (B) (C) (D) if f"({choice})" in response: candidates.append(choice) ans_with_brack = True if len(candidates) == 0: for choice in all_choices: # e.g., A B C D if f" {choice} " in response: candidates.append(choice) # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example if len(candidates) == 0 and len(response.split()) > 5: for index, ans in index2ans.items(): if ans.lower() in response.lower(): candidates.append(index) index_ans = False # it's content ans. if len(candidates) == 0: # still not get answer, randomly choose one. pred_index = random.choice(all_choices) elif len(candidates) > 1: start_indexes = [] if index_ans: if ans_with_brack: for can in candidates: index = response.rfind(f"({can})") start_indexes.append(index) # -1 will be ignored anyway # start_indexes = [generated_response.index(f'({can})') for can in candidates] else: for can in candidates: index = response.rfind(f" {can} ") start_indexes.append(index) else: for can in candidates: index = response.lower().rfind(index2ans[can].lower()) start_indexes.append(index) # get the last one pred_index = candidates[np.argmax(start_indexes)] else: # if only one candidate, use it. pred_index = candidates[0] return pred_index # ----------- Process Open ------------- def check_is_number(string): """ Check if the given string a number. """ try: float(string.replace(",", "")) return True except ValueError: # check if there's comma inside return False def normalize_str(string): """ Normalize the str to lower case and make them float numbers if possible. """ # check if characters in the string # if number, numerize it. string = string.strip() is_number = check_is_number(string) if is_number: string = string.replace(",", "") string = float(string) # leave 2 decimal string = round(string, 2) return [string] else: # it's likely to be a string # lower it string = string.lower() if len(string) == 1: return [" " + string, string + " "] # avoid trivial matches return [string] def extract_numbers(string): """ Exact all forms of numbers from a string with regex. """ # Pattern for numbers with commas pattern_commas = r"-?\b\d{1,3}(?:,\d{3})+\b" # Pattern for scientific notation pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+" # Pattern for simple numbers without commas pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])" # Extract numbers with commas numbers_with_commas = re.findall(pattern_commas, string) # Extract numbers in scientific notation numbers_scientific = re.findall(pattern_scientific, string) # Extract simple numbers without commas numbers_simple = re.findall(pattern_simple, string) # Combine all extracted numbers all_numbers = numbers_with_commas + numbers_scientific + numbers_simple return all_numbers def parse_open_response(response): """ Parse the prediction from the generated response. Return a list of predicted strings or numbers. """ # content = content.strip("\n").strip(".").strip(" ") def get_key_subresponses(response): key_responses = [] response = response.strip().strip(".").lower() sub_responses = re.split(r"\.\s(?=[A-Z])|\n", response) indicators_of_keys = [ "could be ", "so ", "is ", "thus ", "therefore ", "final ", "answer ", "result ", ] key_responses = [] for index, resp in enumerate(sub_responses): # if last one, accept it's an equation (the entire response can be just one sentence with equation) if index == len(sub_responses) - 1: indicators_of_keys.extend(["="]) shortest_key_response = None # the shortest response that may contain the answer (tail part of the response) for indicator in indicators_of_keys: if indicator in resp: if not shortest_key_response: shortest_key_response = resp.split(indicator)[-1].strip() else: if len(resp.split(indicator)[-1].strip()) < len( shortest_key_response ): shortest_key_response = resp.split(indicator)[-1].strip() # key_responses.append(resp.split(indicator)[1].strip()) if shortest_key_response: # and it's not trivial if shortest_key_response.strip() not in [ ":", ",", ".", "!", "?", ";", ":", "'", ]: key_responses.append(shortest_key_response) if len(key_responses) == 0: # did not found any return [response] return key_responses # pdb.set_trace() key_responses = get_key_subresponses(response) pred_list = key_responses.copy() # keep the original string response for resp in key_responses: pred_list.extend(extract_numbers(resp)) tmp_pred_list = [] for i in range(len(pred_list)): tmp_pred_list.extend(normalize_str(pred_list[i])) pred_list = tmp_pred_list # remove duplicates pred_list = list(set(pred_list)) return pred_list # ----------- Evaluation ------------- def eval_multi_choice(gold_i, pred_i): """ Evaluate a multiple choice instance. """ correct = False # only they are exactly the same, we consider it as correct if isinstance(gold_i, list): for answer in gold_i: if answer == pred_i: correct = True break else: # gold_i is a string if gold_i == pred_i: correct = True return correct def eval_open(gold_i, pred_i): """ Evaluate an open question instance """ correct = False if isinstance(gold_i, list): # use float to avoid trivial matches norm_answers = [] for answer in gold_i: norm_answers.extend(normalize_str(answer)) else: norm_answers = normalize_str(gold_i) for pred in pred_i: # pred is already normalized in parse response phase if isinstance(pred, str): # if it's a string, then find if ans in the pred_i for norm_ans in norm_answers: # only see if the string answer in the string pred if isinstance(norm_ans, str) and norm_ans in pred: if not correct: correct = True break else: # it's a float number if pred in norm_answers: if not correct: correct = True break return correct # ----------- Batch Evaluation ------------- def evaluate(samples): """ Batch evaluation for multiple choice and open questions. """ pred_correct = 0 judge_dict = dict() for sample in samples: gold_i = sample["answer"] pred_i = sample["parsed_pred"] if sample["question_type"] == "multiple-choice": correct = eval_multi_choice(gold_i, pred_i) else: # open question correct = eval_open(gold_i, pred_i) if correct: judge_dict[sample["id"]] = "Correct" pred_correct += 1 else: # print(f"Wrong! expected {pred_i}, answered with {gold_i}") judge_dict[sample["id"]] = "Wrong" if len(samples) == 0: return {"acc": 0} return judge_dict, {"acc": pred_correct / len(samples)} # ----------- Calculate Accuracy ------------- def calculate_ins_level_acc(results: Dict): """Calculate the instruction level accuracy for given Subject results""" acc = 0 ins_num = 0 for cat_results in results.values(): acc += cat_results["acc"] * cat_results["num_example"] ins_num += cat_results["num_example"] if ins_num == 0: return 0 return acc / ins_num def process_result(response, sample, answer_dict, out_samples): if sample["question_type"] == "multiple-choice": pred_ans = parse_multi_choice_response( response, sample["all_choices"], sample["index2ans"] ) else: # open question pred_ans = response out_samples[sample["id"]] = pred_ans # set ground truth answer answer_dict[sample["id"]] = { "question_type": sample["question_type"], "ground_truth": sample["answer"], } def eval_result(model_answer_path, answer_dict): print("Evaluating...") output_dict = json.load(open(model_answer_path)) # answer_dict = json.load(open(answer_path)) # group by category output_dict_w_cat = {} for data_id, parsed_pred in output_dict.items(): category = "_".join(data_id.split("_")[1:-1]) if category not in output_dict_w_cat: output_dict_w_cat.update({category: {}}) output_dict_w_cat[category].update({data_id: parsed_pred}) # group by category answer_dict_w_cat = {} for data_id, parsed_pred in answer_dict.items(): category = "_".join(data_id.split("_")[1:-1]) if category not in answer_dict_w_cat: answer_dict_w_cat.update({category: {}}) answer_dict_w_cat[category].update({data_id: parsed_pred}) evaluation_result = {} for category in CAT_SHORT2LONG.values(): # print("Evaluating: {}".format(category)) # get cat_outputs and cat_answers try: cat_outputs = output_dict_w_cat[category] cat_answers = answer_dict_w_cat[category] except KeyError: # print("Skipping {} for not found".format(category)) continue exampels_to_eval = [] for data_id, parsed_pred in cat_outputs.items(): question_type = cat_answers[data_id]["question_type"] if question_type != "multiple-choice": parsed_pred = parse_open_response( parsed_pred ) # mainly for type consistency (make it number, etc.) else: parsed_pred = parsed_pred exampels_to_eval.append( { "id": data_id, "question_type": question_type, "answer": cat_answers[data_id]["ground_truth"], "parsed_pred": parsed_pred, } ) judge_dict, metric_dict = evaluate(exampels_to_eval) metric_dict.update({"num_example": len(exampels_to_eval)}) evaluation_result[category] = metric_dict printable_results = {} # pdb.set_trace() # add domain Subject for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items(): in_domain_cat_results = {} for cat_name in in_domain_cats: # use the order in DOMAIN_CAT2SUB_CAT if cat_name in evaluation_result.keys(): in_domain_cat_results[cat_name] = evaluation_result[cat_name] else: pass in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results) in_domain_data_num = sum( [ cat_results["num_example"] for cat_results in in_domain_cat_results.values() ] ) printable_results["Overall-" + domain] = { "num": int(in_domain_data_num), "acc": round(in_domain_ins_acc, 3), } # add sub category for cat_name, cat_results in in_domain_cat_results.items(): printable_results[cat_name] = { "num": int(cat_results["num_example"]), "acc": round(cat_results["acc"], 3), } # table.append(["-----------------------------", "-----", "----"]) all_ins_acc = calculate_ins_level_acc(evaluation_result) overall_acc = round(all_ins_acc, 3) printable_results["Overall"] = { "num": sum( [cat_results["num_example"] for cat_results in evaluation_result.values()] ), "acc": overall_acc, } pprint.pprint(printable_results) out = model_answer_path with open(out, "w", encoding="utf-8") as outfile: json.dump(printable_results, outfile) print(f"eval out saved to {out}") print(f"Overall accuracy: {overall_acc}")