# Source: evalscope_v0.17.0/evalscope.0.17.0/evalscope/metrics/math_parser.py
# (527 lines, 17 KiB, Python)

"""
The logic in this file is largely borrowed from the Qwen2.5-Math codebase: https://github.com/QwenLM/Qwen2.5-Math
"""
# flake8: noqa
import re
import regex
from latex2sympy2_extended import latex2sympy
from math import isclose
from sympy import N, simplify
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
from word2number import w2n
def convert_word_number(text: str) -> str:
    """Turn a spelled-out number into its digit string when possible.

    Args:
        text: Candidate text to convert.

    Returns:
        ``str(w2n.word_to_num(text))`` if that call succeeds; otherwise
        ``text`` is returned unchanged.
    """
    try:
        return str(w2n.word_to_num(text))
    except Exception:
        # Best-effort conversion: any failure means "leave the text alone".
        return text
def _fix_fracs(string):
substrs = string.split('\\frac')
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += '\\frac'
if len(substr) > 0 and substr[0] == '{':
new_str += substr
else:
try:
assert len(substr) >= 2
except Exception:
return string
a = substr[0]
b = substr[1]
if b != '{':
if len(substr) > 2:
post_substr = substr[2:]
new_str += '{' + a + '}{' + b + '}' + post_substr
else:
new_str += '{' + a + '}{' + b + '}'
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += '{' + a + '}' + b + post_substr
else:
new_str += '{' + a + '}' + b
string = new_str
return string
def _fix_a_slash_b(string):
if len(string.split('/')) != 2:
return string
a = string.split('/')[0]
b = string.split('/')[1]
try:
if 'sqrt' not in a:
a = int(a)
if 'sqrt' not in b:
b = int(b)
assert string == '{}/{}'.format(a, b)
new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
return new_string
except Exception:
return string
def _fix_sqrt(string):
_string = re.sub(r'\\sqrt(\w+)', r'\\sqrt{\1}', string)
return _string
def strip_answer_string(string):
    r"""Normalize an answer string so equivalent answers compare equal.

    The input is coerced with ``str()`` and passed through a fixed sequence
    of textual rewrites. The order matters because each step works on the
    output of the previous one. In order, the function:

    * trims whitespace, removes newlines and trailing periods, drops ``\!``;
    * rewrites ``array`` / ``bmatrix`` environments to ``pmatrix``;
    * maps ``tfrac`` / ``dfrac`` to ``frac`` and ``\neq`` / ``\leq`` /
      ``\geq`` to ``\ne`` / ``\le`` / ``\ge``;
    * removes ``\left`` / ``\right`` and unescapes ``\{`` / ``\}``;
    * converts number words (inside ``\text{...}`` or as the whole string)
      to digits via ``convert_word_number``;
    * strips a trailing ``\text{...}`` unit, degree marks, dollar signs,
      ``\(`` / ``\)``, variable prefixes such as ``x=``, and percent signs;
    * unwraps a single pair of enclosing brackets around alphanumeric text;
    * normalizes infinity spellings and removes ``and``, ``\mathbf``,
      ``\mbox{...}`` and quote characters;
    * replaces ``j`` by ``i`` when no ``i`` is present;
    * drops all-zero decimal parts, adds a leading ``0`` to a leading ``.``;
    * drops a short (at most two characters) left-hand side of a single ``=``;
    * braces ``\sqrt`` / ``\frac`` arguments, removes spaces, and converts
      ``a/b`` to ``\frac{a}{b}``;
    * sorts a bare comma-separated list of integers in ascending order.

    Args:
        string: Raw answer; any object is accepted and converted with ``str``.

    Returns:
        The normalized answer string (possibly empty).
    """
    string = str(string).strip()
    string = string.replace('\n', '')
    string = string.rstrip('.')
    # Drop LaTeX negative thin spaces.
    string = string.replace('\\!', '')

    # Treat array and bmatrix environments as pmatrix.
    string = re.sub(r'\\begin\{array\}\{.*?\}', r'\\begin{pmatrix}', string)
    string = re.sub(r'\\end\{array\}', r'\\end{pmatrix}', string)
    string = string.replace('bmatrix', 'pmatrix')

    # Unify fraction and comparison command spellings.
    string = string.replace('tfrac', 'frac')
    string = string.replace('dfrac', 'frac')
    string = string.replace('\\neq', '\\ne').replace('\\leq', '\\le').replace('\\geq', '\\ge')

    # Remove \left / \right sizing commands and unescape literal braces.
    string = string.replace('\\left', '')
    string = string.replace('\\right', '')
    string = string.replace('\\{', '{')
    string = string.replace('\\}', '}')

    def replace_match(match):
        # Replace \text{<number word>} by digits; keep other \text{...} as is.
        word = match.group(1).lower()
        converted = convert_word_number(word)
        return match.group(0) if converted == word else converted

    string = re.sub(r'\\text\{([a-zA-Z]+)\}', replace_match, string)

    # A squared unit (surface area) is reduced to the plain unit first so
    # that the trailing-unit removal below can match it.
    string = re.sub(r'(cm|inches)\}\^2', r'\1}', string)

    # Remove a trailing \text{...} unit unless that would leave nothing.
    without_unit = re.sub(r'\\text{.*?}$', '', string).strip()
    if without_unit != '' and without_unit != string:
        string = without_unit

    # Degrees.
    string = string.replace('^{\\circ}', '')
    string = string.replace('^\\circ', '')

    # Dollar signs and inline-math delimiters.
    string = string.replace('\\$', '')
    string = string.replace('$', '')
    string = string.replace('\\(', '').replace('\\)', '')

    # Whole-string number words, then unwrap any remaining \text{...}.
    string = convert_word_number(string)
    string = re.sub(r'\\text\{(.*?)\}', r'\1', string)

    for key in ['x=', 'y=', 'z=', 'x\\in', 'y\\in', 'z\\in', 'x\\to', 'y\\to', 'z\\to']:
        string = string.replace(key, '')
    string = string.replace('\\emptyset', r'{}')
    string = string.replace('(-\\infty,\\infty)', '\\mathbb{R}')

    # Percent signs, escaped first so the backslash goes too.
    string = string.replace('\\%', '')
    string = string.replace('%', '')

    # " ." and "{." are shorthand for " 0." and "{0.".
    string = string.replace(' .', ' 0.')
    string = string.replace('{.', '{0.')

    # Unwrap one enclosing bracket pair around purely alphanumeric content.
    # The content (not the whole string) must be tested: a string that
    # starts with a bracket can never satisfy isalnum() itself.
    if (len(string) >= 2 and string[0] + string[-1] in ('{}', '()', '[]') and string[1:-1].isalnum()):
        string = string[1:-1]

    # Infinity spellings.
    string = string.replace('infinity', '\\infty')
    if '\\infty' not in string:
        string = string.replace('inf', '\\infty')
    # An explicit plus sign on infinity carries no information.
    string = string.replace('+\\infty', '\\infty')

    string = string.replace('and', '')
    string = string.replace('\\mathbf', '')
    string = re.sub(r'\\mbox{.*?}', '', string)

    # Quote characters; str.replace returns a new string, so assign it.
    string = string.replace("'", '')
    string = string.replace('"', '')

    # Imaginary unit written as j.
    if 'j' in string and 'i' not in string:
        string = string.replace('j', 'i')

    # "a.000b" (b a non-digit or end of string) becomes "ab".
    string = re.sub(r'(\d+)\.0*([^\d])', r'\1\2', string)
    string = re.sub(r'(\d+)\.0*$', r'\1', string)

    if len(string) == 0:
        return string
    if string[0] == '.':
        string = '0' + string

    # Drop a short left-hand side such as "k=" or "q=".
    sides = string.split('=')
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    string = _fix_sqrt(string)
    string = string.replace(' ', '')
    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}; also \frac1{72}.
    string = _fix_fracs(string)
    # Plain X/Y --> \frac{X}{Y} in simple cases.
    string = _fix_a_slash_b(string)

    # Remove an unnecessary backslash directly before an integer.
    string = re.sub(r'\\(?=\-?\d+(\\|\)|,|\]|$))', '', string)
    # "12thgrade" --> "12" (spaces were removed above).
    string = re.sub(r'thgrade$', '', string)

    # A bare comma-separated list of integers is sorted ascending.
    if re.fullmatch(r'(\s*-?\d+\s*,)*\s*-?\d+\s*', string):
        try:
            integer_list = list(map(int, string.split(',')))
        except ValueError:
            integer_list = [-1, -1]
        string = ','.join(map(str, sorted(integer_list)))
    return string
def extract_answer(pred_str, use_last_number=True):
    r"""Pull the final answer out of a model response and normalize it.

    The first matching rule wins:

    1. ``final answer is $...$. I hope`` (minerva_math format): the text
       between the two markers.
    2. ``boxed``: the brace-balanced group after the last ``boxed``; if no
       brace follows, the text up to the next ``$``.
    3. ``he answer is`` / ``final answer is``: everything after the last
       occurrence.
    4. ``答案是``: the text after the first occurrence, up to the first blank
       line.
    5. Otherwise the last number in the text (commas removed first) when
       ``use_last_number`` is true, else the empty string.

    The candidate then has newlines (with following whitespace) removed,
    one leading ``:`` and one trailing ``.`` and ``/`` trimmed, and is passed
    through ``strip_answer_string``.

    Args:
        pred_str: Full model output.
        use_last_number: Whether rule 5 may fall back to the last number.

    Returns:
        The normalized answer, or ``''`` when the text ends right after
        ``boxed``.
    """
    # Remove every occurrence of the two-character sequence U+043A U+0438.
    pred_str = pred_str.replace('\u043a\u0438', '')
    if 'final answer is $' in pred_str and '$. I hope' in pred_str:
        # minerva_math
        tmp = pred_str.split('final answer is $', 1)[1]
        pred = tmp.split('$. I hope', 1)[0].strip()
    elif 'boxed' in pred_str:
        ans = pred_str.split('boxed')[-1]
        if not ans:
            return ''
        if ans[0] == '{':
            # Collect characters until the brace opened by ans[0] closes.
            depth = 1
            chars = []
            for c in ans[1:]:
                if c == '{':
                    depth += 1
                elif c == '}':
                    depth -= 1
                    if depth == 0:
                        break
                chars.append(c)
            pred = ''.join(chars)
        else:
            pred = ans.split('$')[0].strip()
    elif 'he answer is' in pred_str:
        pred = pred_str.split('he answer is')[-1].strip()
    elif 'final answer is' in pred_str:
        pred = pred_str.split('final answer is')[-1].strip()
    elif '答案是' in pred_str:
        # Chinese few-shot multiple-choice answers.
        pred = pred_str.split('答案是')[1].strip().split('\n\n')[0].strip()
    elif use_last_number:
        # Raw string: "\d" in a normal literal is an invalid escape sequence.
        numbers = re.findall(r'-?\d*\.?\d+', pred_str.replace(',', ''))
        pred = numbers[-1] if numbers else ''
    else:
        pred = ''

    # Join multi-line answers.
    pred = re.sub(r'\n\s*', '', pred)
    if pred.startswith(':'):
        pred = pred[1:]
    if pred.endswith('.'):
        pred = pred[:-1]
    if pred.endswith('/'):
        pred = pred[:-1]
    return strip_answer_string(pred)
def choice_answer_clean(pred: str):
    """Reduce a multiple-choice prediction to its final option letter.

    The prediction is trimmed (newlines, trailing ``.`` and ``/``, spaces,
    leading ``:``) and upper-cased for the search. If it contains any
    standalone letter ``A``-``E``, the last one is returned. Otherwise the
    trimmed text is returned with surrounding whitespace, periods, and any
    trailing ``/`` removed.

    Args:
        pred: Raw prediction text.

    Returns:
        The selected option letter, or the cleaned text when no letter is found.
    """
    trimmed = pred.strip('\n').rstrip('.').rstrip('/').strip(' ').lstrip(':')
    letters = re.findall(r'\b(A|B|C|D|E)\b', trimmed.upper())
    chosen = letters[-1] if letters else trimmed.strip().strip('.')
    # Trailing punctuation is stripped once more on the final value.
    return chosen.rstrip('.').rstrip('/')
def parse_digits(num):
    r"""Parse a number, tolerating thousands separators and a percent sign.

    Commas are removed from ``str(num)`` and the result is parsed with
    ``float``. If that fails and the text ends with ``%`` (optionally
    written ``\%``), the percent sign is dropped and the value is divided
    by 100.

    Args:
        num: Value to parse; converted with ``str`` first.

    Returns:
        The parsed ``float``, or ``None`` when the text is not a number.
    """
    # A literal comma needs no regex; str.replace does the same job.
    text = str(num).replace(',', '')
    try:
        return float(text)
    except ValueError:
        pass
    if not text.endswith('%'):
        return None
    text = text[:-1]
    if text.endswith('\\'):
        # "\%" form: drop the escaping backslash as well.
        text = text[:-1]
    try:
        return float(text) / 100
    except ValueError:
        return None
def is_digit(num):
    """Return True when ``parse_digits(num)`` yields a value.

    Paired with ``parse_digits``: this is exactly the check that the parser
    did not return ``None``.
    """
    parsed = parse_digits(num)
    return parsed is not None
def str_to_pmatrix(input_str):
    r"""Convert brace-delimited, comma-separated groups to pmatrix text.

    Every match of ``{...,...}`` (greedy, within one line) in the stripped
    input has its outer braces removed and each comma replaced by a single
    backslash, then is wrapped in ``\begin{pmatrix}`` / ``\end{pmatrix}``.
    The results are joined with ``', '``.

    NOTE(review): commas become ONE backslash here, whereas ``math_equal``
    splits pmatrix rows on TWO backslashes; confirm this is intended.

    Args:
        input_str: Text such as ``'{1,2}'``.

    Returns:
        The joined pmatrix strings, or ``''`` when nothing matches.
    """
    groups = re.findall(r'\{.*,.*\}', input_str.strip())
    return ', '.join(
        r'\begin{pmatrix}' + group.strip('{}').replace(',', '\\') + r'\end{pmatrix}' for group in groups)
def math_equal(
    prediction,
    reference,
    include_percentage: bool = True,
    is_close: bool = True,
    timeout: bool = False,
) -> bool:
    """Decide whether a predicted answer matches the reference answer.

    Checks are tried in this order and the first success returns True:

    0. Case-insensitive string equality, or a matching choice letter when
       the reference is one of ``A``-``E``.
    1. Numerical equality: both values parse via ``parse_digits``. When this
       branch applies, its verdict is final.
    2. Structural and symbolic equality: text equality after removing
       brackets and braces, element-wise comparison of bracketed lists and
       of pmatrix/bmatrix entries, comparison of single-``=`` equations,
       and finally ``symbolic_equal``.

    Args:
        prediction: Predicted answer. ``None`` never matches.
        reference: Ground-truth answer. ``None`` never matches.
        include_percentage: If true, the numeric check also accepts the
            reference divided by 100 and multiplied by 100.
        is_close: If true, numbers are compared with ``numeric_equal``;
            otherwise with ``==``.
        timeout: Not read by this implementation; kept in the signature for
            call compatibility.

    Returns:
        True if any check succeeds, otherwise False.
    """
    if prediction is None or reference is None:
        return False
    # Convert with str() before .strip() so non-string inputs (floats, ints)
    # are compared instead of raising AttributeError.
    if str(prediction).strip().lower() == str(reference).strip().lower():
        return True
    if reference in ['A', 'B', 'C', 'D', 'E'] and choice_answer_clean(str(prediction)) == reference:
        return True

    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if numeric_equal(prediction, item):
                            return True
                    elif item == prediction:
                        return True
                except Exception:
                    continue
            return False
    except Exception:
        pass

    # An empty prediction cannot match; literal 0 / False are still compared.
    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    # pmatrix (amps): bring a brace-style reference into pmatrix form.
    if 'pmatrix' in prediction and 'pmatrix' not in reference:
        reference = str_to_pmatrix(reference)

    # Compare with [], (), {} ignored.
    pred_str, ref_str = prediction, reference
    if (prediction.startswith('[') and prediction.endswith(']')
            and not reference.startswith('(')) or (prediction.startswith('(') and prediction.endswith(')')
                                                   and not reference.startswith('[')):
        pred_str = pred_str.strip('[]()')
        ref_str = ref_str.strip('[]()')
    for s in ['{', '}', '(', ')']:
        ref_str = ref_str.replace(s, '')
        pred_str = pred_str.replace(s, '')
    if pred_str.lower() == ref_str.lower():
        return True

    # [a, b] vs. [c, d]: equal when a == c and b == d.
    if (regex.match(r'(\(|\[).+(\)|\])', prediction) is not None
            and regex.match(r'(\(|\[).+(\)|\])', reference) is not None):
        pred_parts = prediction[1:-1].split(',')
        ref_parts = reference[1:-1].split(',')
        if len(pred_parts) == len(ref_parts) and all(
                math_equal(p, r, include_percentage, is_close) for p, r in zip(pred_parts, ref_parts)):
            return True

    # Matrices: compare row by row and entry by entry.
    begin_tags = ('\\begin{pmatrix}', '\\begin{bmatrix}')
    end_tags = ('\\end{pmatrix}', '\\end{bmatrix}')
    if (prediction.startswith(begin_tags) and prediction.endswith(end_tags) and reference.startswith(begin_tags)
            and reference.endswith(end_tags)):
        # pmatrix and bmatrix tags have the same length, so one slice fits both.
        start, stop = len('\\begin{pmatrix}'), -len('\\end{pmatrix}')
        pred_lines = [line.strip() for line in prediction[start:stop].split('\\\\') if line.strip()]
        ref_lines = [line.strip() for line in reference[start:stop].split('\\\\') if line.strip()]
        if len(pred_lines) == len(ref_lines):
            matched = True
            for pred_line, ref_line in zip(pred_lines, ref_lines):
                pred_parts = pred_line.split('&')
                ref_parts = ref_line.split('&')
                if len(pred_parts) != len(ref_parts) or not all(
                        math_equal(p, r, include_percentage, is_close) for p, r in zip(pred_parts, ref_parts)):
                    matched = False
                    break
            if matched:
                return True

    if prediction.count('=') == 1 and reference.count('=') == 1:
        # Equations: compare "lhs - (rhs)" of both sides, up to sign.
        pred_lhs, pred_rhs = prediction.split('=')
        pred = f'{pred_lhs.strip()} - ({pred_rhs.strip()})'
        ref_lhs, ref_rhs = reference.split('=')
        ref = f'{ref_lhs.strip()} - ({ref_rhs.strip()})'
        if symbolic_equal(pred, ref) or symbolic_equal(f'-({pred})', ref):
            return True
    elif (prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference):
        # "x = value" vs. "value".
        if math_equal(prediction.split('=')[1], reference, include_percentage, is_close):
            return True
    elif (reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction):
        if math_equal(prediction, reference.split('=')[1], include_percentage, is_close):
            return True

    if symbolic_equal(prediction, reference):
        return True
    return False
def numeric_equal(prediction: float, reference: float):
    """Return True when the two numbers agree within a relative tolerance.

    The comparison is ``math.isclose`` with ``rel_tol=1e-4`` and the default
    absolute tolerance.
    """
    relative_tolerance = 1e-4
    return isclose(reference, prediction, rel_tol=relative_tolerance)
def symbolic_equal(a, b):
    """Return True when ``a`` and ``b`` are found to be equivalent.

    Each input is parsed by the first of ``parse_latex``, ``parse_expr`` and
    ``latex2sympy`` that succeeds. Each parser is tried first on the text
    with doubled backslashes collapsed, then on the raw text. If all
    parsers fail, the input is used as is.

    The parsed values are then compared by a sequence of checks. A check
    that raises is treated as "not equal" and the next one is tried:

    1. equal string forms or ``==``;
    2. ``.equals`` or a difference that simplifies to 0;
    3. as equations: ``|lhs - rhs|`` of both are ``.equals``;
    4. numeric evaluation agrees per ``numeric_equal``;
    5. as matrices of the same shape: entries rounded to 3 places are
       ``.equals``.

    Args:
        a: First expression.
        b: Second expression.

    Returns:
        True if any check succeeds, otherwise False.
    """

    def _parse(s):
        for parser in (parse_latex, parse_expr, latex2sympy):
            try:
                return parser(s.replace('\\\\', '\\'))
            except Exception:
                try:
                    return parser(s)
                except Exception:
                    continue
        return s

    expr_a = _parse(a)
    expr_b = _parse(b)

    def _identical():
        return str(expr_a) == str(expr_b) or expr_a == expr_b

    def _difference_is_zero():
        return expr_a.equals(expr_b) or simplify(expr_a - expr_b) == 0

    def _same_equation():
        return (abs(expr_a.lhs - expr_a.rhs)).equals(abs(expr_b.lhs - expr_b.rhs))

    def _numerically_close():
        return numeric_equal(float(N(expr_a)), float(N(expr_b)))

    def _same_matrix():
        if expr_a.shape == expr_b.shape:
            rounded_a = expr_a.applyfunc(lambda x: round(x, 3))
            rounded_b = expr_b.applyfunc(lambda x: round(x, 3))
            return rounded_a.equals(rounded_b)
        return False

    checks = (_identical, _difference_is_zero, _same_equation, _numerically_close, _same_matrix)
    for check in checks:
        try:
            if check():
                return True
        except Exception:
            # A check that does not apply to these operand types is skipped.
            pass
    return False