diff --git a/vlmeval/vlm/granite_vision/granite_vision.py b/vlmeval/vlm/granite_vision/granite_vision.py
old mode 100644
new mode 100755
index 1f46f86cb..eec096e0c
--- a/vlmeval/vlm/granite_vision/granite_vision.py
+++ b/vlmeval/vlm/granite_vision/granite_vision.py
@@ -13,7 +13,6 @@
 flash_attn_flag = False
 try:
     import flash_attn
-    flash_attn_flag = True
 except ImportError:
     pass
 
@@ -58,6 +57,8 @@ def output_process(self, answer, dataset):
 
         if "<|end_of_text|>" in answer:
             answer = answer.split("<|end_of_text|>")[0].strip("\n ")
+        if "answer" in answer.lower():
+            answer = answer.lower().split("answer")[-1].strip(" :.-\n")
         if dataset in [
             "ChartQA_TEST",
             "DocVQA_VAL",
@@ -69,14 +70,16 @@ def output_process(self, answer, dataset):
             "TextVQA_VAL"
         ]:
             answer = answer.strip(".")
-
-        return answer
+        if "ChartMuseum" in dataset:
+            answer = f"{answer}"
+        return answer.strip("\n")
 
     def use_custom_prompt(self, dataset):
         assert dataset is not None
         if DATASET_TYPE(dataset) == "MCQ":
             return True
-        if dataset in ["OCRBench", "COCO_VAL"]:
+        if dataset in ["OCRBench", "COCO_VAL", "ChartQA_TEST", "CharXiv_descriptive_val", "ChartMimic_v1_direct",
+                       "ChartMimic_v2_direct", "ChartMimic_v2_customized",]:
             return True
         return False
 
@@ -87,6 +90,11 @@ def get_pre_post_prompt(self, dataset, chineese=False):
                 "\nReply with only one word or a short phrase or a full address.",
             ),
             "COCO_VAL": ("", "\nReply with one short sentence."),
+            "ChartQA_TEST": ("", "\nAnswer the question with a single word."),
+            "CharXiv_descriptive_val": ("", "\nAnswer the question with a single word or short phrase."),
+            "ChartMimic_v1_direct": ("", "\nAnswer using code only, starting with ```python and ending with ```"),
+            "ChartMimic_v2_direct": ("", "\nAnswer using code only, starting with ```python and ending with ```"),
+            "ChartMimic_v2_customized": ("", "\nAnswer using code only, starting with ```python and ending with ```"),
         }
 
         pre_post_prompt_cn = {}
@@ -172,13 +180,10 @@ def generate_inner(self, message, dataset=None):
                 "content": content,
             }
         ]
-        prompt = self.processor.apply_chat_template(
-            conversation, add_generation_prompt=True
-        )
-        inputs = self.processor(prompt, images, return_tensors="pt").to(
-            "cuda", torch.float16
-        )
-        output = self.model.generate(**inputs, **self.kwargs)
+        prompt = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+        inputs = self.processor(images=images, text=prompt, return_tensors="pt").to(self.model.device, self.model.dtype)
+        with torch.no_grad():
+            output = self.model.generate(**inputs, **self.kwargs)
         answer = self.processor.decode(output[0], skip_special_token=True)
         answer = self.output_process(answer, dataset)
         return answer
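
For reference, a minimal standalone sketch of the inference flow this patch moves to: `apply_chat_template(..., tokenize=False)` to build the prompt string, a single processor call with explicit `images=`/`text=` keywords, generation under `torch.no_grad()`, then decoding. The checkpoint id, image path, question text, and `max_new_tokens` below are illustrative assumptions, not part of the patch (the actual wrapper loads the model and processor in its constructor).

```python
# Sketch only: assumes the public HF checkpoint "ibm-granite/granite-vision-3.2-2b".
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq

model_id = "ibm-granite/granite-vision-3.2-2b"   # assumed checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

images = [Image.open("chart.png")]               # placeholder image path
conversation = [{
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text",
         "text": "What is the highest value in the chart?\nAnswer the question with a single word."},
    ],
}]

# Build the prompt string (tokenize=False), then let the processor tokenize the
# text and preprocess the images together; tensors follow the model's device/dtype.
prompt = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
inputs = processor(images=images, text=prompt, return_tensors="pt").to(model.device, model.dtype)

with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=128)   # assumed generation kwargs
answer = processor.decode(output[0], skip_special_tokens=True)
print(answer)
```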