diff --git a/vlmeval/config.py b/vlmeval/config.py index fa4d9524..4252e4fb 100644 --- a/vlmeval/config.py +++ b/vlmeval/config.py @@ -168,7 +168,8 @@ 'InternVL2-2B': partial(InternVLChat, model_path='OpenGVLab/InternVL2-2B', version='V2.0'), 'InternVL2-4B': partial(InternVLChat, model_path='OpenGVLab/InternVL2-4B', version='V2.0'), 'InternVL2-8B': partial(InternVLChat, model_path='OpenGVLab/InternVL2-8B', version='V2.0'), - 'InternVL2-8B-MPO': partial(InternVLChat, model_path='OpenGVLab/InternVL2-8B-MPO', version='V2.0', cot_prompt=True), + 'InternVL2-8B-MPO': partial(InternVLChat, model_path='OpenGVLab/InternVL2-8B-MPO', version='V2.0'), + 'InternVL2-8B-MPO-CoT': partial(InternVLChat, model_path='OpenGVLab/InternVL2-8B-MPO', version='V2.0', cot_prompt=True), 'InternVL2-26B': partial(InternVLChat, model_path='OpenGVLab/InternVL2-26B', version='V2.0'), 'InternVL2-40B': partial(InternVLChat, model_path='OpenGVLab/InternVL2-40B', version='V2.0', load_in_8bit=True), 'InternVL2-76B': partial(InternVLChat, model_path='OpenGVLab/InternVL2-Llama3-76B', version='V2.0'), diff --git a/vlmeval/vlm/internvl_chat.py b/vlmeval/vlm/internvl_chat.py index 6ae0333f..69ee922b 100644 --- a/vlmeval/vlm/internvl_chat.py +++ b/vlmeval/vlm/internvl_chat.py @@ -322,6 +322,18 @@ def build_prompt(self, line, dataset=None): question_orig = question_orig.split('Question:', 1)[-1].strip() question_orig = question_orig.replace('Choices:\n', '').strip() + options = { + cand: line[cand] + for cand in string.ascii_uppercase + if cand in line and not pd.isna(line[cand]) + } + options_prompt = '' + for key, item in options.items(): + options_prompt += f'{key}. {item}\n' + + if options_prompt.strip(): + question_orig = f'{question_orig}\n{options_prompt}' + prompt = cot_prompt.format(question=question_orig) message = [dict(type='text', value=prompt)]