"""
The logic in this file largely borrows from the Qwen2.5-Math codebase at
https://github.com/QwenLM/Qwen2.5-Math.
"""
|
|
# flake8: noqa
|
|
import re
|
|
import regex
|
|
from latex2sympy2_extended import latex2sympy
|
|
from math import isclose
|
|
from sympy import N, simplify
|
|
from sympy.parsing.latex import parse_latex
|
|
from sympy.parsing.sympy_parser import parse_expr
|
|
from word2number import w2n
|
|
|
|
|
|
def convert_word_number(text: str) -> str:
    """Turn an English number phrase (e.g. ``"twenty one"``) into digits.

    Conversion is delegated to ``w2n.word_to_num``. If that call (or the
    string conversion of its result) fails for any reason, *text* is returned
    unchanged, so arbitrary non-numeric input round-trips.
    """
    try:
        return str(w2n.word_to_num(text))
    except Exception:
        return text


def _fix_fracs(string):
|
|
substrs = string.split('\\frac')
|
|
new_str = substrs[0]
|
|
if len(substrs) > 1:
|
|
substrs = substrs[1:]
|
|
for substr in substrs:
|
|
new_str += '\\frac'
|
|
if len(substr) > 0 and substr[0] == '{':
|
|
new_str += substr
|
|
else:
|
|
try:
|
|
assert len(substr) >= 2
|
|
except Exception:
|
|
return string
|
|
a = substr[0]
|
|
b = substr[1]
|
|
if b != '{':
|
|
if len(substr) > 2:
|
|
post_substr = substr[2:]
|
|
new_str += '{' + a + '}{' + b + '}' + post_substr
|
|
else:
|
|
new_str += '{' + a + '}{' + b + '}'
|
|
else:
|
|
if len(substr) > 2:
|
|
post_substr = substr[2:]
|
|
new_str += '{' + a + '}' + b + post_substr
|
|
else:
|
|
new_str += '{' + a + '}' + b
|
|
string = new_str
|
|
return string
|
|
|
|
|
|
def _fix_a_slash_b(string):
|
|
if len(string.split('/')) != 2:
|
|
return string
|
|
a = string.split('/')[0]
|
|
b = string.split('/')[1]
|
|
try:
|
|
if 'sqrt' not in a:
|
|
a = int(a)
|
|
if 'sqrt' not in b:
|
|
b = int(b)
|
|
assert string == '{}/{}'.format(a, b)
|
|
new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
|
|
return new_string
|
|
except Exception:
|
|
return string
|
|
|
|
|
|
def _fix_sqrt(string):
|
|
_string = re.sub(r'\\sqrt(\w+)', r'\\sqrt{\1}', string)
|
|
return _string
|
|
|
|
|
|
def strip_answer_string(string):
    r"""Normalize a raw answer string so that equivalent answers compare equal.

    Transformations, applied in order, include:

    - drop newlines, a trailing period and LaTeX ``\!`` spaces;
    - canonicalize LaTeX (``array``/``bmatrix`` -> ``pmatrix``,
      ``tfrac``/``dfrac`` -> ``frac``, ``\neq``/``\leq``/``\geq`` -> short
      forms, remove ``\left``/``\right`` and escaped braces);
    - convert number words to digits;
    - remove a trailing ``\text{...}`` unit, degree marks, dollar signs,
      percent signs and ``x=``/``x\in``/``x\to``-style prefixes;
    - normalize infinity, ``\emptyset`` and ``(-\infty,\infty)``;
    - strip zero-only decimal parts;
    - brace ``\sqrt``/``\frac`` arguments and rewrite ``a/b`` as
      ``\frac{a}{b}``;
    - sort a bare comma-separated list of integers.

    Args:
        string: The answer; non-string values are converted with ``str``.

    Returns:
        The normalized answer string (possibly empty).
    """
    string = str(string).strip()
    # Linebreaks.
    string = string.replace('\n', '')
    # Trailing periods.
    string = string.rstrip('.')
    # LaTeX negative thin spaces.
    string = string.replace('\\!', '')

    # Matrices: treat array and bmatrix environments as pmatrix.
    string = re.sub(r'\\begin\{array\}\{.*?\}', r'\\begin{pmatrix}', string)
    string = re.sub(r'\\end\{array\}', r'\\end{pmatrix}', string)
    string = string.replace('bmatrix', 'pmatrix')

    # Fraction and relation variants.
    string = string.replace('tfrac', 'frac')
    string = string.replace('dfrac', 'frac')
    string = string.replace('\\neq', '\\ne').replace('\\leq', '\\le').replace('\\geq', '\\ge')

    # Sizing commands and escaped braces.
    string = string.replace('\\left', '')
    string = string.replace('\\right', '')
    string = string.replace('\\{', '{')
    string = string.replace('\\}', '}')

    def replace_match(match):
        # \text{<word>} -> digits when <word> is a number word, else unchanged.
        word = match.group(1).lower()
        converted = convert_word_number(word)
        return match.group(0) if converted == word else converted

    string = re.sub(r'\\text\{([a-zA-Z]+)\}', replace_match, string)

    # Drop the square on cm/inches units so the unit removal below applies.
    string = re.sub(r'(cm|inches)\}\^2', r'\1}', string)

    # Remove a trailing \text{...} unit (miles, dollars, ...) unless that
    # would leave nothing.
    _string = re.sub(r'\\text{.*?}$', '', string).strip()
    if _string != '' and _string != string:
        string = _string

    # Degrees.
    string = string.replace('^{\\circ}', '')
    string = string.replace('^\\circ', '')

    # Dollar signs and \( \) delimiters.
    string = string.replace('\\$', '')
    string = string.replace('$', '')
    string = string.replace('\\(', '').replace('\\)', '')

    # Whole-string number words -> digits.
    string = convert_word_number(string)

    # \text{...} -> ...
    string = re.sub(r'\\text\{(.*?)\}', r'\1', string)
    for key in ['x=', 'y=', 'z=', 'x\\in', 'y\\in', 'z\\in', 'x\\to', 'y\\to', 'z\\to']:
        string = string.replace(key, '')
    string = string.replace('\\emptyset', r'{}')
    string = string.replace('(-\\infty,\\infty)', '\\mathbb{R}')

    # Percent signs, escaped and bare. (A former '\%' literal was an invalid
    # escape sequence equal to '\\%' and therefore redundant.)
    string = string.replace('\\%', '')
    string = string.replace('%', '')

    # " .5" -> " 0.5" and "{.5" -> "{0.5".
    string = string.replace(' .', ' 0.')
    string = string.replace('{.', '{0.')

    # Unwrap a single alphanumeric token enclosed in {}, () or [].
    # The inner text must be tested: the whole string can never be
    # alphanumeric because it starts with a bracket.
    if len(string) >= 2 and string[0] + string[-1] in ('{}', '()', '[]') and string[1:-1].isalnum():
        string = string[1:-1]

    # Infinity.
    string = string.replace('infinity', '\\infty')
    if '\\infty' not in string:
        string = string.replace('inf', '\\infty')
    string = string.replace('+\\infty', '\\infty')

    # "and" (every occurrence, including inside longer words) and \mathbf.
    string = string.replace('and', '')
    string = string.replace('\\mathbf', '')

    # Remove \mbox{...}.
    string = re.sub(r'\\mbox{.*?}', '', string)

    # NOTE: quotes are intentionally NOT stripped. The original code called
    # str.replace for ' and " without using the result, so quotes were always
    # kept; stripping ' would also erase primes such as f'(x).

    # Imaginary unit: use i when the answer writes j and never i.
    if 'j' in string and 'i' not in string:
        string = string.replace('j', 'i')

    # Zero-only decimal parts: "2.000" -> "2", "2.0x" -> "2x".
    string = re.sub(r'(\d+)\.0*([^\d])', r'\1\2', string)
    string = re.sub(r'(\d+)\.0*$', r'\1', string)

    if len(string) == 0:
        return string
    if string[0] == '.':
        string = '0' + string

    # Drop a short left-hand side such as "k = " or "q = ".
    sides = string.split('=')
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    string = _fix_sqrt(string)
    string = string.replace(' ', '')

    # \frac12 -> \frac{1}{2}, \frac1{72} -> \frac{1}{72}.
    string = _fix_fracs(string)

    # Simple integer X/Y -> \frac{X}{Y}, in case the model wrote X/Y.
    string = _fix_a_slash_b(string)

    # Remove a stray backslash directly before an integer.
    string = re.sub(r'\\(?=\-?\d+(\\|\)|,|\]|$))', '', string)

    # Grade levels: "12thgrade" -> "12".
    string = re.sub(r'thgrade$', '', string)

    # A bare comma-separated list of integers is order-insensitive: sort it.
    if re.fullmatch(r'(\s*-?\d+\s*,)*\s*-?\d+\s*', string):
        try:
            integer_list = list(map(int, string.split(',')))
        except Exception:
            # int() can still fail, e.g. on integers beyond Python's
            # int-string digit limit; fall back to a fixed sentinel list.
            integer_list = [-1, -1]
        string = ','.join(map(str, sorted(integer_list)))

    return string


def extract_answer(pred_str, use_last_number=True):
    """Extract the final answer from a model completion and normalize it.

    Strategies, tried in order:

    1. Minerva style: text between ``"final answer is $"`` and ``"$. I hope"``.
    2. The last ``boxed``:
       - if it is followed by ``{``, the brace-balanced content;
       - otherwise, the text up to the next ``$``.
    3. Text after the last ``"he answer is"``, else after the last
       ``"final answer is"``.
    4. Text after the first ``"答案是"`` ("the answer is"), up to the next
       blank line.
    5. Otherwise the last number in the text (commas ignored) when
       *use_last_number* is true, else ``''``.

    The candidate is then cleaned:

    - newlines (with following whitespace) are removed;
    - a leading ``:`` is stripped;
    - a trailing ``.`` and then a trailing ``/`` are stripped.

    The result is finally passed through ``strip_answer_string``.

    Returns:
        The normalized answer. Returns ``''`` directly when ``boxed`` is the
        last thing in *pred_str*.
    """
    # Drop the Cyrillic "ки" token (presumably a step-tag marker present in
    # some model outputs -- TODO confirm).
    pred_str = pred_str.replace('\u043a\u0438', '')
    if 'final answer is $' in pred_str and '$. I hope' in pred_str:
        # minerva_math
        tmp = pred_str.split('final answer is $', 1)[1]
        pred = tmp.split('$. I hope', 1)[0].strip()
    elif 'boxed' in pred_str:
        ans = pred_str.split('boxed')[-1]
        if len(ans) == 0:
            return ''
        elif ans[0] == '{':
            # Keep everything up to the brace that closes the opening '{',
            # or to the end of the text if it is never closed.
            depth = 0
            end = len(ans)
            for i, c in enumerate(ans):
                if c == '{':
                    depth += 1
                elif c == '}':
                    depth -= 1
                    if depth == 0:
                        end = i
                        break
            pred = ans[1:end]
        else:
            pred = ans.split('$')[0].strip()
    elif 'he answer is' in pred_str:
        pred = pred_str.split('he answer is')[-1].strip()
    elif 'final answer is' in pred_str:
        pred = pred_str.split('final answer is')[-1].strip()
    elif '答案是' in pred_str:
        # Chinese few-shot multiple-choice prompts ("答案是" = "the answer is").
        pred = pred_str.split('答案是')[1].strip().split('\n\n')[0].strip()
    elif use_last_number:
        # Raw string: in a plain literal '\d' and '\.' are invalid escapes
        # (SyntaxWarning on Python 3.12+).
        numbers = re.findall(r'-?\d*\.?\d+', pred_str.replace(',', ''))
        pred = numbers[-1] if numbers else ''
    else:
        pred = ''

    # Join multi-line answers.
    pred = re.sub(r'\n\s*', '', pred)
    if pred != '' and pred[0] == ':':
        pred = pred[1:]
    if pred != '' and pred[-1] == '.':
        pred = pred[:-1]
    if pred != '' and pred[-1] == '/':
        pred = pred[:-1]
    pred = strip_answer_string(pred)
    return pred


def choice_answer_clean(pred: str):
    """Reduce a multiple-choice prediction to a single option letter.

    After trimming newlines, trailing ``.``/``/``, spaces and a leading ``:``,
    the last standalone letter A-E (case-insensitive, returned upper-case) is
    chosen. If there is none, the trimmed text itself is returned. A trailing
    ``.`` and ``/`` are removed once more from the result.
    """
    cleaned = pred.strip('\n').rstrip('.').rstrip('/').strip(' ').lstrip(':')
    letters = re.findall(r'\b(A|B|C|D|E)\b', cleaned.upper())
    answer = letters[-1] if letters else cleaned.strip().strip('.')
    return answer.rstrip('.').rstrip('/')


def parse_digits(num):
    r"""Parse *num* as a float, tolerating thousands separators and percentages.

    Commas are removed first (``"1,000"`` -> ``1000.0``). If the result is
    not a plain float but ends with ``%`` (optionally LaTeX-escaped as
    ``\%``), the numeric part is divided by 100 (``"50%"`` -> ``0.5``).

    Args:
        num: Any value; it is converted with ``str`` first.

    Returns:
        The parsed float, or ``None`` if *num* cannot be interpreted.
    """
    # Literal comma removal; no regex needed.
    num = str(num).replace(',', '')
    try:
        return float(num)
    except ValueError:
        pass
    # Percentage handling applies only when a '%' is actually present.
    if num.endswith('%'):
        num = num[:-1]
        # An escaped percent sign (\%) leaves a trailing backslash.
        if num.endswith('\\'):
            num = num[:-1]
        try:
            return float(num) / 100
        except ValueError:
            pass
    return None


def is_digit(num):
    """Report whether *num* is numeric according to ``parse_digits``.

    This is the predicate counterpart of ``parse_digits``: anything that
    function turns into a float (rather than ``None``) counts as a digit.
    """
    parsed = parse_digits(num)
    return parsed is not None


def str_to_pmatrix(input_str):
    r"""Convert brace-delimited tuples such as ``{1,2}`` into LaTeX column vectors.

    The regex match (from the first ``{`` to the last ``}``, requiring a comma
    in between) is stripped of its outer braces. Its commas are turned into
    the LaTeX row separator ``\\``, giving e.g.
    ``\begin{pmatrix}1\\2\end{pmatrix}``. Results are joined with ``", "``;
    if nothing matches, the empty string is returned.
    """
    input_str = input_str.strip()
    matrix_str = re.findall(r'\{.*,.*\}', input_str)
    pmatrix_list = []

    for m in matrix_str:
        m = m.strip('{}')
        # r'\\' is the two-character LaTeX row break. A single backslash
        # ('\\' in a plain literal) is not valid LaTeX and never matches a
        # row split on two backslashes.
        pmatrix = r'\begin{pmatrix}' + m.replace(',', r'\\') + r'\end{pmatrix}'
        pmatrix_list.append(pmatrix)

    return ', '.join(pmatrix_list)


def math_equal(
    prediction,
    reference,
    include_percentage: bool = True,
    is_close: bool = True,
    timeout: bool = False,
) -> bool:
    """Return True if *prediction* and *reference* denote the same answer.

    Checks, in order:

    0. Case-insensitive string equality. Then, if *reference* is one of the
       letters A-E, equality with ``choice_answer_clean(prediction)``.
    1. Numerical equality. If both parse with ``parse_digits``, compare
       against the reference and, when *include_percentage*, also against
       ``reference / 100`` and ``reference * 100``. Use ``numeric_equal``
       when *is_close*, exact ``==`` otherwise. The outcome of this step is
       final.
    2. Structural / symbolic equality:
       - string match with brackets removed;
       - element-wise comparison of bracketed tuples/intervals and of
         pmatrix/bmatrix entries;
       - equations compared as ``lhs - (rhs)``, up to sign;
       - a short (<= 2 char) left-hand side on one side is ignored;
       - finally ``symbolic_equal``.

    Args:
        prediction: Predicted answer (usually a string; numbers are accepted).
        reference: Ground-truth answer.
        include_percentage: Also accept answers off by a factor of 100.
        is_close: Use relative tolerance instead of exact float equality.
        timeout: Accepted for interface compatibility; unused in this body.

    Returns:
        Whether the two answers are judged equal.
    """
    if prediction is None or reference is None:
        return False
    # Convert to str before .strip() so non-string (e.g. numeric) inputs do
    # not raise AttributeError.
    if str(prediction).strip().lower() == str(reference).strip().lower():
        return True
    if reference in ['A', 'B', 'C', 'D', 'E'] and choice_answer_clean(str(prediction)) == reference:
        return True

    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if numeric_equal(prediction, item):
                            return True
                    elif item == prediction:
                        return True
                except Exception:
                    continue
            return False
    except Exception:
        pass

    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    # A pmatrix prediction against a "{a,b}" reference: convert the reference.
    if 'pmatrix' in prediction and 'pmatrix' not in reference:
        reference = str_to_pmatrix(reference)

    # Compare with [], (), {} removed. Outer brackets are stripped from both
    # sides only when the prediction's bracket style does not clash with the
    # reference's opening bracket.
    pred_str, ref_str = prediction, reference
    if (prediction.startswith('[') and prediction.endswith(']') and not reference.startswith('(')) or (
            prediction.startswith('(') and prediction.endswith(')') and not reference.startswith('[')):
        pred_str = pred_str.strip('[]()')
        ref_str = ref_str.strip('[]()')
    for s in ['{', '}', '(', ')']:
        ref_str = ref_str.replace(s, '')
        pred_str = pred_str.replace(s, '')
    if pred_str.lower() == ref_str.lower():
        return True

    # [a, b] vs. [c, d]: equal when every component is equal.
    bracketed = r'(\(|\[).+(\)|\])'
    if regex.match(bracketed, prediction) is not None and regex.match(bracketed, reference) is not None:
        pred_parts = prediction[1:-1].split(',')
        ref_parts = reference[1:-1].split(',')
        if len(pred_parts) == len(ref_parts) and all(
                math_equal(p, r, include_percentage, is_close) for p, r in zip(pred_parts, ref_parts)):
            return True

    # Matrices: compare row by row, entry by entry. 'pmatrix' and 'bmatrix'
    # have the same length, so one slice fits both environments.
    matrix_begins = ('\\begin{pmatrix}', '\\begin{bmatrix}')
    matrix_ends = ('\\end{pmatrix}', '\\end{bmatrix}')
    if (prediction.startswith(matrix_begins) and prediction.endswith(matrix_ends)
            and reference.startswith(matrix_begins) and reference.endswith(matrix_ends)):
        body = slice(len('\\begin{pmatrix}'), -len('\\end{pmatrix}'))
        pred_lines = [line.strip() for line in prediction[body].split('\\\\') if line.strip()]
        ref_lines = [line.strip() for line in reference[body].split('\\\\') if line.strip()]
        matched = len(pred_lines) == len(ref_lines)
        if matched:
            for pred_line, ref_line in zip(pred_lines, ref_lines):
                pred_parts = pred_line.split('&')
                ref_parts = ref_line.split('&')
                if len(pred_parts) != len(ref_parts) or not all(
                        math_equal(p, r, include_percentage, is_close) for p, r in zip(pred_parts, ref_parts)):
                    matched = False
                    break
        if matched:
            return True

    if prediction.count('=') == 1 and reference.count('=') == 1:
        # Compare equations as "lhs - (rhs)", accepting either sign.
        pred_lhs, pred_rhs = prediction.split('=')
        ref_lhs, ref_rhs = reference.split('=')
        pred = f'{pred_lhs.strip()} - ({pred_rhs.strip()})'
        ref = f'{ref_lhs.strip()} - ({ref_rhs.strip()})'
        if symbolic_equal(pred, ref) or symbolic_equal(f'-({pred})', ref):
            return True
    elif prediction.count('=') == 1 and len(prediction.split('=')[0].strip()) <= 2 and '=' not in reference:
        # "x = 5" vs "5".
        if math_equal(prediction.split('=')[1], reference, include_percentage, is_close):
            return True
    elif reference.count('=') == 1 and len(reference.split('=')[0].strip()) <= 2 and '=' not in prediction:
        # "5" vs "x = 5".
        if math_equal(prediction, reference.split('=')[1], include_percentage, is_close):
            return True

    if symbolic_equal(prediction, reference):
        return True

    return False


def numeric_equal(prediction: float, reference: float):
    """Compare two numbers with a relative tolerance of 1e-4.

    No absolute tolerance is used, so a non-zero value never equals ``0``.
    """
    relative_tolerance = 1e-4
    return isclose(reference, prediction, rel_tol=relative_tolerance)


def symbolic_equal(a, b):
    """Decide whether two answer strings denote the same mathematical object.

    Parsing: each side goes through the first parser that succeeds, in the
    order ``parse_latex``, ``parse_expr``, ``latex2sympy``. Each parser is
    tried first with doubled backslashes collapsed, then on the raw input.
    If none succeeds, the original value is kept.

    Comparison: the following checks run in order. A check that raises
    simply counts as a miss, and the next one is tried.

    1. identical ``str`` form, or ``==``;
    2. ``a.equals(b)`` or ``simplify(a - b) == 0``;
    3. equations: ``abs(lhs - rhs)`` of both sides are equal;
    4. numeric evaluation via ``N`` agrees under ``numeric_equal``;
    5. matrices of the same shape agree after rounding entries to 3 places.

    Returns:
        True if any check succeeds, otherwise False.
    """

    def _parse(s):
        for parser in (parse_latex, parse_expr, latex2sympy):
            # NOTE(review): sympy's parse_expr evaluates its input with eval();
            # model outputs reaching this point are untrusted text -- consider
            # sandboxing or dropping parse_expr for untrusted input.
            try:
                return parser(s.replace('\\\\', '\\'))
            except Exception:
                pass
            try:
                return parser(s)
            except Exception:
                pass
        return s

    a = _parse(a)
    b = _parse(b)

    def _rounded(matrix):
        return matrix.applyfunc(lambda x: round(x, 3))

    checks = (
        # direct equal
        lambda: str(a) == str(b) or a == b,
        # simplify equal
        lambda: a.equals(b) or simplify(a - b) == 0,
        # equation equal
        lambda: (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)),
        # numeric equal
        lambda: numeric_equal(float(N(a)), float(N(b))),
        # matrix equal
        lambda: a.shape == b.shape and _rounded(a).equals(_rounded(b)),
    )
    for check in checks:
        try:
            if check():
                return True
        except Exception:
            continue
    return False