diff --git a/run.py b/run.py
index 6152aa88e..64edcf6fc 100644
--- a/run.py
+++ b/run.py
@@ -378,7 +378,7 @@ def main():
                 judge_kwargs['model'] = 'gpt-4-turbo'
             elif listinstr(['VGRPBench'], dataset_name):
                 judge_kwargs['model'] = 'gpt-4o'
-            elif listinstr(['MathVista', 'MathVerse', 'MathVision', 'DynaMath', 'VL-RewardBench', 'LogicVista', 'MOAT', 'OCR_Reasoning'], dataset_name):  # noqa: E501
+            elif listinstr(['MathVista', 'MathVerse', 'MathVision', 'LENS', 'DynaMath', 'VL-RewardBench', 'LogicVista', 'MOAT', 'OCR_Reasoning'], dataset_name):  # noqa: E501
                 judge_kwargs['model'] = 'gpt-4o-mini'
             elif listinstr(['OlympiadBench'], dataset_name):
                 use_api_judger = judge_kwargs.get("olympiad_use_api_judger", False)
diff --git a/vlmeval/dataset/__init__.py b/vlmeval/dataset/__init__.py
index 15be72455..7e682ea45 100644
--- a/vlmeval/dataset/__init__.py
+++ b/vlmeval/dataset/__init__.py
@@ -11,7 +11,7 @@
 )
 from .image_mt import MMDUDataset
 from .image_vqa import (
-    ImageVQADataset, MathVision, OCRBench, MathVista, LLaVABench, LLaVABench_KO, VGRPBench, MMVet, MTVQADataset,
+    ImageVQADataset, MathVision, LENS, OCRBench, MathVista, LLaVABench, LLaVABench_KO, VGRPBench, MMVet, MTVQADataset,
     TableVQABench, CustomVQADataset, CRPE, MathVerse, OlympiadBench, SeePhys, QSpatial, VizWiz, MMNIAH, LogicVista,
     MME_CoT, MMSci_Captioning, Physics_yale, TDBenchGrounding, WildDocBenchmark, OCR_Reasoning, PhyX, CountBenchQA,
     ZEROBench, Omni3DBench, TallyQA, MMEReasoning, MMVMBench, BMMR, OCRBench_v2, AyaVisionBench, MathCanvas, MMReason
@@ -213,7 +213,7 @@ def evaluate(self, eval_file, **judge_kwargs):
 # Add new supported dataset class here
 IMAGE_DATASET = [
     ImageCaptionDataset, ImageYORNDataset, ImageMCQDataset, ImageVQADataset,
-    MathVision, MMMUDataset, OCRBench, MathVista, LLaVABench, LLaVABench_KO, VGRPBench, MMVet,
+    MathVision, LENS, MMMUDataset, OCRBench, MathVista, LLaVABench, LLaVABench_KO, VGRPBench, MMVet,
     MTVQADataset, TableVQABench, MMLongBench, VCRDataset, MMDUDataset, DUDE,
     SlideVQA, MUIRDataset, CCOCRDataset, GMAIMMBenchDataset, MMERealWorld,
     HRBenchDataset, CRPE, MathVerse, NaturalBenchDataset, MIABench,
diff --git a/vlmeval/dataset/image_vqa.py b/vlmeval/dataset/image_vqa.py
index 23aefb946..cedb3573c 100644
--- a/vlmeval/dataset/image_vqa.py
+++ b/vlmeval/dataset/image_vqa.py
@@ -695,6 +695,137 @@ def MathVision_acc_verifier(result_file):
         return score
 
 
+class LENS(ImageBaseDataset):
+    TYPE = 'VQA'
+    DATASET_URL = {
+        'LENS-CN-QA':
+        'https://huggingface.co/datasets/songlier/LENS/resolve/main/LENS-CN-QA.tsv',
+        'LENS-CN-QA_MINI':
+        'https://huggingface.co/datasets/songlier/LENS/resolve/main/LENS-CN-QA_MINI.tsv'
+    }
+    DATASET_MD5 = {
+        'LENS-CN-QA': 'D382365A2C977543BEB890BAC240E731',
+        'LENS-CN-QA_MINI': '4CEA1BDE46537DE2428C1D05A0B36094'
+    }
+
+    def evaluate(self, eval_file, **judge_kwargs):
+        if judge_kwargs.get('use_verifier', False):
+            return self.evaluate_verifier(eval_file, **judge_kwargs)
+        else:
+            return self.evaluate_heuristic(eval_file, **judge_kwargs)
+
+    def evaluate_heuristic(self, eval_file, **judge_kwargs):
+        from .utils.lens import LENS_auxeval, LENS_acc
+
+        if 'model' in judge_kwargs:
+            model = judge_kwargs['model']
+        else:
+            model = os.path.basename(os.environ.get('LOCAL_LLM'))
+        storage = get_intermediate_file_path(eval_file, f'_{model}')
+        tmp_file = get_intermediate_file_path(eval_file, f'_{model}', 'pkl')
+        nproc = judge_kwargs.pop('nproc', 4)
+
+        if not osp.exists(storage):
+            data = load(eval_file)
+            model = build_judge(max_tokens=128, **judge_kwargs)
+            assert model.working(), 'LENS evaluation requires a working OPENAI API\n' + DEBUG_MESSAGE
+            lt = len(data)
+            lines = [data.iloc[i] for i in range(lt)]
+            tups = [(model, line) for line in lines]
+            indices = [line['index'] for line in lines]
+
+            ans = {}
+            if osp.exists(tmp_file):
+                ans = load(tmp_file)
+                tups = [x for x, i in zip(tups, indices) if i not in ans]
+                indices = [i for i in indices if i not in ans]
+
+            if len(indices):
+                new_results = track_progress_rich(
+                    LENS_auxeval,
+                    tups,
+                    nproc=nproc,
+                    chunksize=nproc,
+                    keys=indices,
+                    save=tmp_file,
+                )
+                ans = load(tmp_file)
+                for k, v in zip(indices, new_results):
+                    assert k in ans
+                    assert ans[k]['log'] == v['log'] and ans[k]['res'] == v['res']
+
+            data['res'] = [ans[idx]['res'] for idx in data['index']]
+            data['log'] = [ans[idx]['log'] for idx in data['index']]
+            dump(data, storage)
+
+        score = LENS_acc(storage)
+        score_pth = get_intermediate_file_path(storage, '_score', 'csv')
+        dump(score, score_pth)
+        return score
+
+    # Returns a DataFrame with per-category accuracy
+    @classmethod
+    def evaluate_verifier(cls, eval_file, **judge_kwargs):
+        # Add verifier evaluation for LENS
+        data = load(eval_file)
+        if 'verifier_score' not in data.columns:
+            from .utils.verifier import Verifier
+            verifier = Verifier(use_vllm=judge_kwargs.get('use_vllm', False))
+
+            verifier_scores = []
+            verifier_matches = []
+            for idx, row in tqdm(data.iterrows(), total=len(data), desc="Verifier Evaluation Progress"):
+                question_text = row['question'] if 'question' in row else ""
+                prediction_text = row['prediction'] if 'prediction' in row else ""
+                answer_text = row['answer'] if 'answer' in row else ""
+
+                score = verifier.evaluate(question_text, prediction_text, answer_text)
+                verifier_scores.append(score)
+                verifier_matches.append(1.0 if score else 0.0)
+
+            data['verifier_score'] = verifier_scores
+            data['verifier_match'] = verifier_matches
+
+            detailed_result_file = get_intermediate_file_path(eval_file, '_detailed_results')
+            dump(data, detailed_result_file)
+
+        else:
+            detailed_result_file = get_intermediate_file_path(eval_file, '_detailed_results')
+            if not osp.exists(detailed_result_file):
+                dump(data, detailed_result_file)
+
+        def LENS_acc_verifier(result_file):
+            from collections import defaultdict
+            data = load(result_file)
+            tot = defaultdict(lambda: 0)
+            hit = defaultdict(lambda: 0)
+            lt = len(data)
+
+            for i in range(lt):
+                item = data.iloc[i]
+                cate = item['category'] if 'category' in item else 'Overall'
+                tot['Overall'] += 1
+                tot[cate] += 1
+
+                if item['verifier_score'] is True:
+                    hit['Overall'] += 1
+                    hit[cate] += 1
+
+            res = defaultdict(list)
+            for k in tot.keys():
+                res['Subject'].append(k)
+                res['tot'].append(tot[k])
+                res['hit'].append(hit[k])
+                res['acc'].append(hit[k] / tot[k] * 100)
+            res = pd.DataFrame(res).sort_values('Subject', ignore_index=True)
+            return res
+
+        score = LENS_acc_verifier(detailed_result_file)
+        score_pth = get_intermediate_file_path(eval_file, '_score', 'csv')
+        dump(score, score_pth)
+        return score
+
+
 class Physics_yale(ImageBaseDataset):
     TYPE = 'VQA'
     DATASET_URL = {
diff --git a/vlmeval/dataset/utils/lens.py b/vlmeval/dataset/utils/lens.py
new file mode 100644
index 000000000..debed4ff9
--- /dev/null
+++ b/vlmeval/dataset/utils/lens.py
@@ -0,0 +1,214 @@
+from ...smp import *
+from ...utils import can_infer
+import timeout_decorator
+try:
+    try:
+        from latex2sympy2_extended import latex2sympy
+    except ImportError:
+        from latex2sympy2 import latex2sympy
+
+except Exception as e:
+    logging.critical(f'{type(e)}: {e}')
+    logging.critical('Please install latex2sympy2-extended by running "pip install latex2sympy2-extended"')
+    raise e
+
+
+FAIL_MSG = 'Failed to obtain answer via API.'
+
+
+# @timeout_decorator.timeout(30)
+def is_equal(asw: str, gt_asw: str) -> bool:
+    if not isinstance(asw, str) or not isinstance(gt_asw, str):
+        print('Warning: input is not string')
+        print(asw, gt_asw)
+    asw = str(asw).lower().strip()
+    gt_asw = str(gt_asw).lower().strip()
+    if gt_asw == asw:
+        return True
+    try:
+        a = eval(gt_asw)
+        b = eval(asw)
+        if abs(a - b) < 1e-6:
+            return True
+    except Exception:
+        pass
+    try:
+        a = latex2sympy(gt_asw)
+        b = latex2sympy(asw)
+        if abs(eval(str(a)) - eval(str(b))) < 1e-6:
+            return True
+        if abs(a - b) < 1e-6:
+            return True
+    except Exception:
+        pass
+    return False
+
+
+def get_gpt4_ICE():
+    example_1 = """
+Hint: Please answer the question and provide the final answer at the end.\n
+Question: Which number is missing?\n
+Model response: The number missing in the sequence is 14.\n
+Extracted answer: 14
+"""
+
+    example_2 = """
+Hint: Please answer the question and provide the final answer at the end.\n
+Question: What is the fraction of females facing the camera?\n
+Model response: The fraction of females facing the camera is 0.6,
+which means that six out of ten females in the group are facing the camera.\n
+Extracted answer: 0.6
+"""
+
+    example_3 = """
+Hint: Please answer the question and provide the final answer at the end.\n
+Question: How much money does Luca need to buy a sour apple candy and a butter-scotch candy? (Unit: $)\n
+Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.\n
+Extracted answer: 1.45
+"""
+
+    example_4 = """
+Hint: Please answer the question and provide the final answer at the end.\n
+Question: Between which two years does the line graph saw its maximum peak?\n
+Model response: The line graph saw its maximum peak between 2007 and 2008.\n
+Extracted answer: [2007, 2008]
+"""
+
+    example_5 = """
+Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.\n
+Question: What fraction of the shape is blue?\n
+Choices: (A) 3/11 (B) 8/11 (C) 6/11 (D) 3/5\n
+Model response: The correct answer is (B) 8/11.\n
+Extracted answer: B
+"""
+
+    return [example_1, example_2, example_3, example_4, example_5]
+
+
+def build_lens_gpt4_prompt(line):
+    task_description = """
+Please read the following example.
+Then extract the answer from the model response and type it at the end of the prompt.\n
+"""
+    question = line['question']
+    prediction = str(line['prediction'])
+    prompt = task_description
+    examples = get_gpt4_ICE()
+    for example in examples:
+        prompt += example + '\n'
+    prompt += question + '\n'
+    prompt += 'Model response: ' + prediction + '\n'
+    prompt += 'Extracted answer:'
+    return prompt
+
+
+def list_to_dict(lst):
+    return {chr(65 + i): val for i, val in enumerate(lst)}
+
+
+def post_check(line, prefetch=False):
+    res = None
+    ans = line['answer']
+    response = line['prediction'] if prefetch else line['res']
+    try:
+        res = str(response)
+        ans = str(ans)
+    except ValueError:
+        pass
+
+    try:
+        if is_equal(res, ans):
+            return res if prefetch else True
+        else:
+            return False
+    except Exception as err:
+        logging.warning(f'{type(err)}: {err}')
+        return False
+
+
+def LENS_auxeval(model, line):
+    try:
+        prompt = build_lens_gpt4_prompt(line)
+        log = ''
+        retry = 5
+        try:
+            if post_check(line, prefetch=True):
+                res = post_check(line, prefetch=True)
+                return dict(log='Prefetch succeed', res=res)
+        except Exception as e:
+            logging.warning(f"Prefetch failed for index {line.get('index', 'unknown')}: {e}")
+
+        for i in range(retry):
+            try:
+                prediction = line['prediction']
+                res = model.generate(prompt, temperature=i * 0.5)
+
+                if res is None:
+                    raise ValueError("Model returned None")
+
+                if FAIL_MSG in res:
+                    log += f'Try {i}: output is {prediction}, failed to parse.\n'
+                else:
+                    log += 'Succeed'
+                    return dict(log=log, res=res)
+
+            except Exception as api_err:
+                log += f'Try {i} Exception: {type(api_err)} - {api_err}\n'
+                continue
+
+        log += 'All 5 retries failed.\n'
+        return dict(log=log, res='')
+
+    except Exception as critical_err:
+        logging.critical(f"Critical Error in LENS_auxeval: {critical_err}")
+        return dict(log=f"Critical Error: {critical_err}", res='')
+
+
+def LENS_acc(result_file):
+    data = load(result_file)
+    tot = defaultdict(lambda: 0)
+    fetch = defaultdict(lambda: 0)
+    hit = defaultdict(lambda: 0)
+    lt = len(data)
+    from tqdm import tqdm
+    for i in tqdm(range(lt)):
+        item = data.iloc[i]
+        cate = item['category']
+        tot['Overall'] += 1
+        tot[cate] += 1
+        if item['log'] == 'Prefetch succeed':
+            fetch['Overall'] += 1
+            fetch[cate] += 1
+        is_correct = post_check(item, prefetch=False)
+        if is_correct:
+            hit['Overall'] += 1
+            hit[cate] += 1
+
+        # Print details
+        # idx = item.get('index', 'N/A')
+        # q = item.get('question', 'N/A')
+        # gt = item.get('answer', 'N/A')
+        # raw_pred = str(item.get('prediction', 'N/A')).replace('\n', ' ')
+        # processed_res = item.get('res', 'N/A')
+        # status = "Yes" if is_correct else "No"
+        # msg = (
+        #     f"\n--------------------------------------------------\n"
+        #     f"Index: {idx}\n"
+        #     f"Question: {q}\n"
+        #     f"Correct answer: {gt}\n"
+        #     f"Model original: {raw_pred[:100]}\n"
+        #     f"Answer after processing: {processed_res}\n"
+        #     f"Judgment result: {status}"
+        # )
+        # tqdm.write(msg)
+
+    res = defaultdict(list)
+    for k in tot.keys():
+        res['Subject'].append(k)
+        res['tot'].append(tot[k])
+        res['prefetch'].append(fetch[k])
+        res['hit'].append(hit[k])
+        res['prefetch_rate'].append(fetch[k] / tot[k] * 100)
+        res['acc'].append(hit[k] / tot[k] * 100)
+    res = pd.DataFrame(res).sort_values('Subject', ignore_index=True)
+    # res.columns = ['Subject', 'Total', 'Prefetch', 'Hit', 'Prefetch rate (%)', 'Accuracy rate (Acc %)']
+    return res
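
For reviewers who want to exercise the new dataset end to end, below is a minimal usage sketch; it is not part of the patch. It assumes the existing build_dataset helper exported from vlmeval.dataset, a prediction file already produced by an inference run (the path below is purely illustrative), and a reachable judge (OPENAI_API_KEY, or a LOCAL_LLM endpoint) so the gpt-4o-mini judge selected for LENS in run.py can actually respond.

# Minimal sketch (not part of the patch): build the dataset and run the
# judge-based heuristic scoring path that this PR adds.
from vlmeval.dataset import build_dataset

dataset = build_dataset('LENS-CN-QA_MINI')   # fetches the TSV listed in DATASET_URL
print(dataset.TYPE)                          # 'VQA'

# 'outputs/MyVLM/MyVLM_LENS-CN-QA_MINI.xlsx' is a hypothetical prediction file.
score = dataset.evaluate(
    'outputs/MyVLM/MyVLM_LENS-CN-QA_MINI.xlsx',
    model='gpt-4o-mini',   # judge model, mirroring the run.py default for LENS
    nproc=4,
)
print(score)  # DataFrame with Subject / tot / prefetch / hit / prefetch_rate / acc columns

Passing use_verifier=True in the judge kwargs would instead route through LENS.evaluate_verifier, which scores with the local Verifier rather than the extraction-plus-is_equal heuristic.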