vlmeval/vlm/granite_vision/granite_vision.py (16 additions, 11 deletions; file mode 100644 → 100755)
@@ -13,7 +13,6 @@
 flash_attn_flag = False
 try:
     import flash_attn
-
     flash_attn_flag = True
 except ImportError:
     pass
@@ -58,6 +57,8 @@ def output_process(self, answer, dataset):

         if "<|end_of_text|>" in answer:
             answer = answer.split("<|end_of_text|>")[0].strip("\n ")
+        if "answer" in answer.lower():
+            answer = answer.lower().split("answer")[-1].strip(" :.-\n")
         if dataset in [
             "ChartQA_TEST",
             "DocVQA_VAL",
@@ -69,14 +70,16 @@
             "TextVQA_VAL"
         ]:
             answer = answer.strip(".")
-
-        return answer
+        if "ChartMuseum" in dataset:
+            answer = f"<answer>{answer}</answer>"
+        return answer.strip("\n")

     def use_custom_prompt(self, dataset):
         assert dataset is not None
         if DATASET_TYPE(dataset) == "MCQ":
             return True
-        if dataset in ["OCRBench", "COCO_VAL"]:
+        if dataset in ["OCRBench", "COCO_VAL", "ChartQA_TEST", "CharXiv_descriptive_val", "ChartMimic_v1_direct",
+                       "ChartMimic_v2_direct", "ChartMimic_v2_customized"]:
             return True
         return False

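For reviewers: a minimal sketch of what the revised `output_process` does to a raw generation, using a made-up model output (the input string is hypothetical):

```python
# Hypothetical raw generation; illustrates the new extraction rules only.
raw = "The final Answer: 42<|end_of_text|>"

# 1. Truncate at the end-of-text marker.
answer = raw.split("<|end_of_text|>")[0].strip("\n ")   # "The final Answer: 42"

# 2. Keep only the text after the last "answer" token; note this also
#    lowercases the whole string as a side effect.
if "answer" in answer.lower():
    answer = answer.lower().split("answer")[-1].strip(" :.-\n")  # "42"

# 3. ChartMuseum answers are additionally wrapped for the evaluator.
print(f"<answer>{answer}</answer>")                     # <answer>42</answer>
```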
@@ -87,6 +90,11 @@ def get_pre_post_prompt(self, dataset, chineese=False):
                 "\nReply with only one word or a short phrase or a full address.",
             ),
             "COCO_VAL": ("", "\nReply with one short sentence."),
+            "ChartQA_TEST": ("", "\nAnswer the question with a single word."),
+            "CharXiv_descriptive_val": ("", "\nAnswer the question with a single word or short phrase."),
+            "ChartMimic_v1_direct": ("", "\nAnswer using code only, starting with ```python and ending with ```."),
+            "ChartMimic_v2_direct": ("", "\nAnswer using code only, starting with ```python and ending with ```."),
+            "ChartMimic_v2_customized": ("", "\nAnswer using code only, starting with ```python and ending with ```."),
         }
         pre_post_prompt_cn = {}

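A rough sketch of how these pre/post pairs presumably get applied to a question; the `build_prompt` helper and its call site are assumptions, since the actual wiring sits outside this diff:

```python
# Two entries copied from the dict above; the full table lives in
# get_pre_post_prompt.
pre_post_prompt = {
    "ChartQA_TEST": ("", "\nAnswer the question with a single word."),
    "COCO_VAL": ("", "\nReply with one short sentence."),
}

def build_prompt(question: str, dataset: str) -> str:
    # Hypothetical helper: prepend/append the dataset-specific strings.
    pre, post = pre_post_prompt.get(dataset, ("", ""))
    return f"{pre}{question}{post}"

print(build_prompt("What is the peak value?", "ChartQA_TEST"))
# What is the peak value?
# Answer the question with a single word.
```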
@@ -172,13 +180,10 @@ def generate_inner(self, message, dataset=None):
                 "content": content,
             }
         ]
-        prompt = self.processor.apply_chat_template(
-            conversation, add_generation_prompt=True
-        )
-        inputs = self.processor(prompt, images, return_tensors="pt").to(
-            "cuda", torch.float16
-        )
-        output = self.model.generate(**inputs, **self.kwargs)
+        prompt = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+        inputs = self.processor(images=images, text=prompt, return_tensors="pt").to(self.model.device, self.model.dtype)
+        with torch.no_grad():
+            output = self.model.generate(**inputs, **self.kwargs)
         answer = self.processor.decode(output[0], skip_special_tokens=True)
         answer = self.output_process(answer, dataset)
         return answer
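For context, the rewritten `generate_inner` follows the standard Hugging Face multimodal flow. A condensed, self-contained sketch, assuming an `AutoProcessor`/`AutoModelForVision2Seq` pair (the checkpoint id and image path are placeholders):

```python
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

model_id = "ibm-granite/granite-vision-3.1-2b-preview"  # placeholder checkpoint id
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

image = Image.open("chart.png")  # placeholder image
conversation = [{
    "role": "user",
    "content": [{"type": "image"}, {"type": "text", "text": "Describe the chart."}],
}]

# tokenize=False returns the templated string; the processor then tokenizes it
# together with the image, which is why text/images are passed as keywords.
prompt = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
inputs = processor(images=[image], text=prompt, return_tensors="pt").to(model.device, model.dtype)

with torch.no_grad():  # inference only; no gradient buffers
    output = model.generate(**inputs, max_new_tokens=128)
print(processor.decode(output[0], skip_special_tokens=True))
```

Deriving the target from `self.model.device` and `self.model.dtype` instead of hard-coding `"cuda"` and `torch.float16` lets the wrapper follow whatever device map and dtype the model was loaded with.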