# Adapted from https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py
"""Helpers for deciding whether two math answers (numbers, tuples, matrices,
equations or symbolic expressions) denote the same value."""
from math import isclose

try:
    import regex
except ImportError:
    # Only ``re``-compatible features (compile / fullmatch) are used below,
    # so the standard library is a safe stand-in when ``regex`` is missing.
    import re as regex
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr

# A whole string wrapped in round/square brackets, e.g. "(1, 2)" or "[0, 1)".
_BRACKETED = regex.compile(r"(\(|\[).+(\)|\])")

_MATRIX_BEGIN = ("\\begin{pmatrix}", "\\begin{bmatrix}")
_MATRIX_END = ("\\end{pmatrix}", "\\end{bmatrix}")
# Both supported environments have delimiters of identical length, so one
# pair of offsets is enough to cut out the matrix body.
_MATRIX_BEGIN_LEN = len(_MATRIX_BEGIN[0])
_MATRIX_END_LEN = len(_MATRIX_END[0])


def parse_digits(num):
    """Convert *num* to a float, or return ``None`` if it is not numeric.

    Accepted formats are plain numbers with optional thousands separators
    ("1,234.5") and percentages ("23%" or the LaTeX-escaped "23\\%"), the
    latter being divided by 100.
    """
    text = str(num).replace(",", "")
    try:
        return float(text)
    except ValueError:
        pass

    if text.endswith("%"):
        text = text[:-1]
        # Drop the backslash of a LaTeX-escaped percent sign ("\%").
        if text.endswith("\\"):
            text = text[:-1]
        try:
            return float(text) / 100
        except ValueError:
            pass
    return None


def is_digit(num):
    """Return True when :func:`parse_digits` can turn *num* into a float."""
    return parse_digits(num) is not None


def symbolic_equal(a, b):
    """Return True if *a* and *b* are equal as sympy expressions.

    Each operand is parsed as LaTeX first and as a sympy expression string
    second; if both parsers fail the raw value is kept.  The operands are
    considered equal when their difference simplifies to zero or when their
    numeric evaluations agree within an absolute tolerance of 1e-3.  Any
    failure during parsing or comparison counts as "not equal".
    """

    def _parse(s):
        for parser in (parse_latex, parse_expr):
            try:
                return parser(s)
            except Exception:
                # Best effort: the parsers raise many unrelated exception
                # types (including ImportError when the LaTeX backend is not
                # installed), so fall through to the next option.
                continue
        return s

    a = _parse(a)
    b = _parse(b)

    try:
        if simplify(a - b) == 0:
            return True
    except Exception:
        # Unparsed operands (plain strings) or unsupported expressions.
        pass

    try:
        if isclose(N(a), N(b), abs_tol=1e-3):
            return True
    except Exception:
        # Non-numeric results cannot be compared as floats.
        pass

    return False


def _parts_equal(pred_parts, ref_parts, include_percentage, is_close):
    """Return True if both lists have equal length and match pairwise."""
    return len(pred_parts) == len(ref_parts) and all(
        math_equal(pred, ref, include_percentage, is_close)
        for pred, ref in zip(pred_parts, ref_parts)
    )


def _is_matrix(text):
    """Return True if *text* is a full pmatrix/bmatrix LaTeX environment."""
    return text.startswith(_MATRIX_BEGIN) and text.endswith(_MATRIX_END)


def _matrix_rows(text):
    """Split the body of a matrix environment into non-empty, stripped rows."""
    body = text[_MATRIX_BEGIN_LEN:-_MATRIX_END_LEN]
    return [row.strip() for row in body.split("\\\\") if row.strip()]


def math_equal(prediction, reference, include_percentage=True, is_close=True):
    """Return True if *prediction* and *reference* denote the same answer.

    The checks are tried in order:

    1. Identical string representations.
    2. Numeric equality when both values parse as numbers.  With
       ``include_percentage`` the reference also matches when scaled by
       1/100 or 100; with ``is_close`` an absolute tolerance of 1e-3 is
       used instead of exact equality.  If both are numbers but none of the
       candidates matches, the result is False without trying later checks.
    3. Element-wise equality of bracketed, comma-separated lists.  The kind
       of bracket is not compared, so "(1,2)" matches "[1,2]".
    4. Element-wise equality of ``pmatrix``/``bmatrix`` LaTeX matrices.
    5. Equations: ``lhs = rhs`` on both sides are compared as
       ``lhs - (rhs)`` up to sign; a short ``x = value`` on one side is
       compared by its right-hand side against the other value.
    6. Symbolic equality via :func:`symbolic_equal`.
    """
    if str(prediction) == str(reference):
        return True

    # 1. numerical equal
    pred_value = parse_digits(prediction)
    ref_value = parse_digits(reference)
    if pred_value is not None and ref_value is not None:
        if include_percentage:
            candidates = [ref_value / 100, ref_value, ref_value * 100]
        else:
            candidates = [ref_value]
        if is_close:
            return any(
                isclose(item, pred_value, abs_tol=1e-3) for item in candidates
            )
        return any(item == pred_value for item in candidates)

    # An empty prediction can never match; 0 and False are real answers.
    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    # Tuples / intervals.  The whole string must be bracketed, because the
    # first and last characters are stripped as the brackets below.
    if _BRACKETED.fullmatch(prediction) and _BRACKETED.fullmatch(reference):
        if _parts_equal(
            prediction[1:-1].split(","),
            reference[1:-1].split(","),
            include_percentage,
            is_close,
        ):
            return True

    # Matrices: same number of rows, and every row matches cell by cell.
    if _is_matrix(prediction) and _is_matrix(reference):
        pred_rows = _matrix_rows(prediction)
        ref_rows = _matrix_rows(reference)
        if len(pred_rows) == len(ref_rows) and all(
            _parts_equal(
                pred_row.split("&"),
                ref_row.split("&"),
                include_percentage,
                is_close,
            )
            for pred_row, ref_row in zip(pred_rows, ref_rows)
        ):
            return True

    # Equations.
    if prediction.count("=") == 1 and reference.count("=") == 1:
        pred_lhs, pred_rhs = prediction.split("=")
        ref_lhs, ref_rhs = reference.split("=")
        pred = f"{pred_lhs.strip()} - ({pred_rhs.strip()})"
        ref = f"{ref_lhs.strip()} - ({ref_rhs.strip()})"
        # "a = b" and "b = a" describe the same equation, hence the negation.
        if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
            return True
    elif (
        prediction.count("=") == 1
        and len(prediction.split("=")[0].strip()) <= 2
        and "=" not in reference
    ):
        # "x = 5" versus "5": compare only the assigned value.
        if math_equal(
            prediction.split("=")[1], reference, include_percentage, is_close
        ):
            return True
    elif (
        reference.count("=") == 1
        and len(reference.split("=")[0].strip()) <= 2
        and "=" not in prediction
    ):
        if math_equal(
            prediction, reference.split("=")[1], include_percentage, is_close
        ):
            return True

    # symbolic equal with sympy
    if symbolic_equal(prediction, reference):
        return True

    return False