run.py (2 changes: 1 addition & 1 deletion)
@@ -378,7 +378,7 @@ def main():
judge_kwargs['model'] = 'gpt-4-turbo'
elif listinstr(['VGRPBench'], dataset_name):
judge_kwargs['model'] = 'gpt-4o'
- elif listinstr(['MathVista', 'MathVerse', 'MathVision', 'DynaMath', 'VL-RewardBench', 'LogicVista', 'MOAT', 'OCR_Reasoning'], dataset_name): # noqa: E501
+ elif listinstr(['MathVista', 'MathVerse', 'MathVision', 'LENS', 'DynaMath', 'VL-RewardBench', 'LogicVista', 'MOAT', 'OCR_Reasoning'], dataset_name): # noqa: E501
judge_kwargs['model'] = 'gpt-4o-mini'
elif listinstr(['OlympiadBench'], dataset_name):
use_api_judger = judge_kwargs.get("olympiad_use_api_judger", False)
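Note on the hunk above: LENS joins the branch that assigns gpt-4o-mini as the judge model. A minimal sketch of the matching logic, assuming listinstr is a simple substring-containment helper (re-implemented below purely for illustration):

# Illustration only: a stand-in for vlmeval's listinstr helper, assumed to
# return True when any candidate string occurs inside dataset_name.
def listinstr(candidates, name):
    return any(c in name for c in candidates)

judge_kwargs = {}
dataset_name = 'LENS-CN-QA_MINI'
if listinstr(['MathVista', 'MathVerse', 'MathVision', 'LENS', 'DynaMath'], dataset_name):
    judge_kwargs['model'] = 'gpt-4o-mini'
print(judge_kwargs)  # {'model': 'gpt-4o-mini'}; both LENS splits match via the 'LENS' substring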
vlmeval/dataset/__init__.py (4 changes: 2 additions & 2 deletions)
@@ -11,7 +11,7 @@
)
from .image_mt import MMDUDataset
from .image_vqa import (
- ImageVQADataset, MathVision, OCRBench, MathVista, LLaVABench, LLaVABench_KO, VGRPBench, MMVet, MTVQADataset,
+ ImageVQADataset, MathVision, LENS, OCRBench, MathVista, LLaVABench, LLaVABench_KO, VGRPBench, MMVet, MTVQADataset,
TableVQABench, CustomVQADataset, CRPE, MathVerse, OlympiadBench, SeePhys, QSpatial, VizWiz, MMNIAH, LogicVista,
MME_CoT, MMSci_Captioning, Physics_yale, TDBenchGrounding, WildDocBenchmark, OCR_Reasoning, PhyX, CountBenchQA,
ZEROBench, Omni3DBench, TallyQA, MMEReasoning, MMVMBench, BMMR, OCRBench_v2, AyaVisionBench, MathCanvas, MMReason
@@ -213,7 +213,7 @@ def evaluate(self, eval_file, **judge_kwargs):
# Add new supported dataset class here
IMAGE_DATASET = [
ImageCaptionDataset, ImageYORNDataset, ImageMCQDataset, ImageVQADataset,
- MathVision, MMMUDataset, OCRBench, MathVista, LLaVABench, LLaVABench_KO, VGRPBench, MMVet,
+ MathVision, LENS, MMMUDataset, OCRBench, MathVista, LLaVABench, LLaVABench_KO, VGRPBench, MMVet,
MTVQADataset, TableVQABench, MMLongBench, VCRDataset, MMDUDataset, DUDE,
SlideVQA, MUIRDataset, CCOCRDataset, GMAIMMBenchDataset, MMERealWorld,
HRBenchDataset, CRPE, MathVerse, NaturalBenchDataset, MIABench,
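For context, importing LENS here and adding it to IMAGE_DATASET is what makes the new splits discoverable by name. A rough sketch of how such a registry lookup could work, assuming each dataset class advertises its splits through the keys of its DATASET_URL dict (the helper below is hypothetical, not part of the codebase):

# Hypothetical helper: resolve a dataset class from a registry list by checking
# which split names each class declares in DATASET_URL.
def resolve_dataset_class(dataset_name, registry):
    for cls in registry:
        if dataset_name in getattr(cls, 'DATASET_URL', {}):
            return cls
    return None

# With LENS registered, resolve_dataset_class('LENS-CN-QA', IMAGE_DATASET)
# would return the LENS class defined in image_vqa.py.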
vlmeval/dataset/image_vqa.py (131 changes: 131 additions & 0 deletions)
@@ -695,6 +695,137 @@ def MathVision_acc_verifier(result_file):
return score


class LENS(ImageBaseDataset):
TYPE = 'VQA'
DATASET_URL = {
'LENS-CN-QA':
'https://huggingface.co/datasets/songlier/LENS/resolve/main/LENS-CN-QA.tsv',
'LENS-CN-QA_MINI':
'https://huggingface.co/datasets/songlier/LENS/resolve/main/LENS-CN-QA_MINI.tsv'
}
DATASET_MD5 = {
'LENS-CN-QA': 'D382365A2C977543BEB890BAC240E731',
'LENS-CN-QA_MINI': '4CEA1BDE46537DE2428C1D05A0B36094'
}

def evaluate(self, eval_file, **judge_kwargs):
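# Two scoring paths: an LLM verifier when use_verifier is set, otherwise the judge-based extraction heuristic.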
if judge_kwargs.get('use_verifier', False):
return self.evaluate_verifier(eval_file, **judge_kwargs)
else:
return self.evaluate_heuristic(eval_file, **judge_kwargs)

def evaluate_heuristic(self, eval_file, **judge_kwargs):
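# Use an LLM judge to extract a final answer from each prediction (LENS_auxeval), cache the extractions, then score with LENS_acc.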
from .utils.lens import LENS_auxeval, LENS_acc

if 'model' in judge_kwargs:
model = judge_kwargs['model']
else:
model = os.path.basename(os.environ.get('LOCAL_LLM'))
storage = get_intermediate_file_path(eval_file, f'_{model}')
tmp_file = get_intermediate_file_path(eval_file, f'_{model}', 'pkl')
nproc = judge_kwargs.pop('nproc', 4)

if not osp.exists(storage):
data = load(eval_file)
model = build_judge(max_tokens=128, **judge_kwargs)
assert model.working(), 'LENS evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE
lt = len(data)
lines = [data.iloc[i] for i in range(lt)]
tups = [(model, line) for line in lines]
indices = [line['index'] for line in lines]

ans = {}
if osp.exists(tmp_file):
ans = load(tmp_file)
tups = [x for x, i in zip(tups, indices) if i not in ans]
indices = [i for i in indices if i not in ans]

if len(indices):
new_results = track_progress_rich(
LENS_auxeval,
tups,
nproc=nproc,
chunksize=nproc,
keys=indices,
save=tmp_file,
)
ans = load(tmp_file)
for k, v in zip(indices, new_results):
assert k in ans
assert ans[k]['log'] == v['log'] and ans[k]['res'] == v['res']

data['res'] = [ans[idx]['res'] for idx in data['index']]
data['log'] = [ans[idx]['log'] for idx in data['index']]
dump(data, storage)

score = LENS_acc(storage)
score_pth = get_intermediate_file_path(storage, '_score', 'csv')
dump(score, score_pth)
return score

# Returns a DataFrame of per-category accuracy
def evaluate_verifier(self, eval_file, **judge_kwargs):
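# Score each prediction with the Verifier, cache per-row results next to the predictions, then aggregate per-category accuracy.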
# Add verifier evaluation for LENS
data = load(eval_file)
if 'verifier_score' not in data.columns:
from .utils.verifier import Verifier
verifier = Verifier(use_vllm=judge_kwargs.get('use_vllm', False))

verifier_scores = []
verifier_matches = []
for idx, row in tqdm(data.iterrows(), total=len(data), desc="Verifier Evaluation Progress"):
question_text = row['question'] if 'question' in row else ""
prediction_text = row['prediction'] if 'prediction' in row else ""
answer_text = row['answer'] if 'answer' in row else ""

score = verifier.evaluate(question_text, prediction_text, answer_text)
verifier_scores.append(score)
verifier_matches.append(1.0 if score else 0.0)

data['verifier_score'] = verifier_scores
data['verifier_match'] = verifier_matches

detailed_result_file = get_intermediate_file_path(eval_file, '_detailed_results')
dump(data, detailed_result_file)

else:
detailed_result_file = get_intermediate_file_path(eval_file, '_detailed_results')
if not osp.exists(detailed_result_file):
dump(data, detailed_result_file)

def LENS_acc_verifier(result_file):
from collections import defaultdict
data = load(result_file)
tot = defaultdict(lambda: 0)
hit = defaultdict(lambda: 0)
lt = len(data)

for i in range(lt):
item = data.iloc[i]
cate = item['category'] if 'category' in item else 'Overall'
tot['Overall'] += 1
tot[cate] += 1

if item['verifier_score'] is True:
hit['Overall'] += 1
hit[cate] += 1

res = defaultdict(list)
for k in tot.keys():
res['Subject'].append(k)
res['tot'].append(tot[k])
res['hit'].append(hit[k])
res['acc'].append(hit[k] / tot[k] * 100)
res = pd.DataFrame(res).sort_values('Subject', ignore_index=True)
return res

score = LENS_acc_verifier(detailed_result_file)
score_pth = get_intermediate_file_path(eval_file, '_score', 'csv')
dump(score, score_pth)
return score


class Physics_yale(ImageBaseDataset):
TYPE = 'VQA'
DATASET_URL = {
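A possible end-to-end use of the new class, sketched under the assumption that ImageBaseDataset subclasses take the split name as their first constructor argument and that predictions are supplied as a result file (the path below is made up):

from vlmeval.dataset import LENS

dataset = LENS('LENS-CN-QA_MINI')  # assumed constructor: split name as first argument

# Heuristic path: an LLM judge extracts answers, then LENS_acc scores them.
score = dataset.evaluate('outputs/MyModel_LENS-CN-QA_MINI.xlsx', model='gpt-4o-mini')

# Verifier path: use_verifier=True routes through evaluate_verifier instead.
score_v = dataset.evaluate('outputs/MyModel_LENS-CN-QA_MINI.xlsx', use_verifier=True)
print(score)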
vlmeval/dataset/utils/lens.py (214 changes: 214 additions & 0 deletions)
@@ -0,0 +1,214 @@
from ...smp import *
from ...utils import can_infer
import timeout_decorator
try:
try:
from latex2sympy2_extended import latex2sympy
except ImportError:
from latex2sympy2 import latex2sympy

except Exception as e:
logging.critical(f'{type(e)}: {e}')
logging.critical('Please install latex2sympy2-extended by running "pip install latex2sympy2-extended"')
raise e


FAIL_MSG = 'Failed to obtain answer via API.'


# @timeout_decorator.timeout(30)
def is_equal(asw: str, gt_asw: str) -> bool:
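# Answers count as equal if they match as lower-cased strings, as numbers via eval, or as parsed LaTeX, within a 1e-6 tolerance.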
if not isinstance(asw, str) or not isinstance(gt_asw, str):
print('Warning: input is not string')
print(asw, gt_asw)
asw = str(asw).lower().strip()
gt_asw = str(gt_asw).lower().strip()
if gt_asw == asw:
return True
try:
a = eval(gt_asw)
b = eval(asw)
if abs(a - b) < 1e-6:
return True
except:
pass
try:
a = latex2sympy(gt_asw)
b = latex2sympy(asw)
if abs(eval(str(a)) - eval(str(b))) < 1e-6:
return True
if abs(a - b) < 1e-6:
return True
except:
pass
return False


def get_gpt4_ICE():
example_1 = """
Hint: Please answer the question and provide the final answer at the end.\n
Question: Which number is missing?\n
Model response: The number missing in the sequence is 14.\n
Extracted answer: 14
"""

example_2 = """
Hint: Please answer the question and provide the final answer at the end.\n
Question: What is the fraction of females facing the camera?\n
Model response: The fraction of females facing the camera is 0.6,
which means that six out of ten females in the group are facing the camera.\n
Extracted answer: 0.6
"""

example_3 = """
Hint: Please answer the question and provide the final answer at the end.\n
Question: How much money does Luca need to buy a sour apple candy and a butter-scotch candy? (Unit: $)\n
Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.\n
Extracted answer: 1.45
"""

example_4 = """
Hint: Please answer the question and provide the final answer at the end.\n
Question: Between which two years does the line graph saw its maximum peak?\n
Model response: The line graph saw its maximum peak between 2007 and 2008.\n
Extracted answer: [2007, 2008]
"""

example_5 = """
Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.\n
Question: What fraction of the shape is blue?\n
Choices: (A) 3/11 (B) 8/11 (C) 6/11 (D) 3/5\n
Model response: The correct answer is (B) 8/11.\n
Extracted answer: B
"""

return [example_1, example_2, example_3, example_4, example_5]


def build_lens_gpt4_prompt(line):
task_description = """
Please read the following example.
Then extract the answer from the model response and type it at the end of the prompt.\n
"""
question = line['question']
prediction = str(line['prediction'])
prompt = task_description
examples = get_gpt4_ICE()
for example in examples:
prompt += example + '\n'
prompt += question + '\n'
prompt += 'Model response: ' + prediction + '\n'
prompt += 'Extracted answer:'
return prompt


def list_to_dict(lst):
return {chr(65 + i): val for i, val in enumerate(lst)}


def post_check(line, prefetch=False):
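# With prefetch=True, compare the raw prediction directly against the ground truth and return it on a hit; otherwise judge the extracted 'res' field and return True/False.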
res = None
ans = line['answer']
response = line['prediction'] if prefetch else line['res']
try:
res = str(response)
ans = str(ans)
except ValueError:
pass

try:
if is_equal(res, ans):
return res if prefetch else True
else:
return False
except Exception as err:
logging.warning(f'{type(err)}: {err}')
return False

def LENS_auxeval(model, line):
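# First try to accept the raw prediction via post_check (prefetch); if that fails, query the judge model up to 5 times with increasing temperature to extract the final answer.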
try:
prompt = build_lens_gpt4_prompt(line)
log = ''
retry = 5
try:
if post_check(line, prefetch=True):
res = post_check(line, prefetch=True)
return dict(log='Prefetch succeed', res=res)
except Exception as e:
logging.warning(f"Prefetch failed for index {line.get('index', 'unknown')}: {e}")

for i in range(retry):
try:
prediction = line['prediction']
res = model.generate(prompt, temperature=i * 0.5)

if res is None:
raise ValueError("Model returned None")

if FAIL_MSG in res:
log += f'Try {i}: output is {prediction}, failed to parse.\n'
else:
log += 'Succeed'
return dict(log=log, res=res)

except Exception as api_err:
log += f'Try {i} Exception: {type(api_err)} - {api_err}\n'
continue

log += 'All 5 retries failed.\n'
return dict(log=log, res='')

except Exception as critical_err:
logging.critical(f"Critical Error in LENS_auxeval: {critical_err}")
return dict(log=f"Critical Error: {critical_err}", res='')


def LENS_acc(result_file):
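# Aggregate per-category totals, prefetch successes, and correct answers into a DataFrame with prefetch-rate and accuracy percentages.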
data = load(result_file)
tot = defaultdict(lambda: 0)
fetch = defaultdict(lambda: 0)
hit = defaultdict(lambda: 0)
lt = len(data)
from tqdm import tqdm
for i in tqdm(range(lt)):
item = data.iloc[i]
cate = item['category']
tot['Overall'] += 1
tot[cate] += 1
if item['log'] == 'Prefetch succeed':
fetch['Overall'] += 1
fetch[cate] += 1
is_correct = post_check(item, prefetch=False)
if is_correct:
hit['Overall'] += 1
hit[cate] += 1

# Print details
# idx = item.get('index', 'N/A')
# q = item.get('question', 'N/A')
# gt = item.get('answer', 'N/A')
# raw_pred = str(item.get('prediction', 'N/A')).replace('\n', ' ')
# processed_res = item.get('res', 'N/A')
# status = "Yes" if is_correct else "No"
# msg = (
# f"\n--------------------------------------------------\n"
# f"Index: {idx}\n"
# f"Question: {q}\n"
# f"Correct answer: {gt}\n"
# f"Model original: {raw_pred[:100]}\n"
# f"Answer after processing: {processed_res}\n"
# f"Judgment result: {status}"
# )
# tqdm.write(msg)

res = defaultdict(list)
for k in tot.keys():
res['Subject'].append(k)
res['tot'].append(tot[k])
res['prefetch'].append(fetch[k])
res['hit'].append(hit[k])
res['prefetch_rate'].append(fetch[k] / tot[k] * 100)
res['acc'].append(hit[k] / tot[k] * 100)
res = pd.DataFrame(res).sort_values('Subject', ignore_index=True)
# res.columns = ['Subject', 'Total', 'Prefetch', 'Hit', 'Prefetch rate (%)', 'Accuracy rate (Acc %)']
return res
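Finally, a small illustration of the matching cascade in is_equal (values chosen here purely for demonstration): exact lower-cased string match, numeric comparison via eval, and LaTeX parsing via latex2sympy all count as a hit within the 1e-6 tolerance.

from vlmeval.dataset.utils.lens import is_equal

print(is_equal('B', 'b'))                # True: case-insensitive string match
print(is_equal('0.5', '1/2'))            # True: both sides eval to 0.5
print(is_equal('\\frac{1}{4}', '0.25'))  # True: matched after LaTeX parsing
print(is_equal('12', '13'))              # False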