2 changes: 2 additions & 0 deletions vlmeval/config.py
@@ -844,6 +844,8 @@
llava_series = {
"llava_v1.5_7b": partial(LLaVA, model_path="liuhaotian/llava-v1.5-7b"),
"llava_v1.5_13b": partial(LLaVA, model_path="liuhaotian/llava-v1.5-13b"),
"llava_v1.5_7b_hf": partial(LLaVA, model_path="llava-hf/llava-1.5-7b-hf"),
"llava_v1.5_13b_hf": partial(LLaVA, model_path="llava-hf/llava-1.5-13b-hf"),
Comment on lines +847 to +848
Collaborator

Why is the LLaVA_HF class not used in config.py?

Author

Thank you for pointing out the typo. It should indeed be LLaVA_HF here.
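For reference, a minimal sketch of the corrected entries (an assumption based on the reply above; it reuses the same partial/model_path pattern as the surrounding llava_series entries and assumes LLaVA_HF is importable in config.py via the vlmeval.vlm export added in this PR):

# sketch only: point the _hf keys at the new LLaVA_HF wrapper instead of LLaVA
"llava_v1.5_7b_hf": partial(LLaVA_HF, model_path="llava-hf/llava-1.5-7b-hf"),
"llava_v1.5_13b_hf": partial(LLaVA_HF, model_path="llava-hf/llava-1.5-13b-hf"),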

"llava_v1_7b": partial(LLaVA, model_path=LLAVA_V1_7B_MODEL_PTH),
"sharegpt4v_7b": partial(LLaVA, model_path="Lin-Chen/ShareGPT4V-7B"),
"sharegpt4v_13b": partial(LLaVA, model_path="Lin-Chen/ShareGPT4V-13B"),
1 change: 1 addition & 0 deletions vlmeval/vlm/__init__.py
@@ -15,6 +15,7 @@
from .kosmos import Kosmos2
from .llava import (
LLaVA,
LLaVA_HF,
LLaVA_Next,
LLaVA_XTuner,
LLaVA_Next2,
3 changes: 2 additions & 1 deletion vlmeval/vlm/llava/__init__.py
@@ -1,4 +1,5 @@
from .llava import LLaVA, LLaVA_Next, LLaVA_Next2, LLaVA_OneVision, LLaVA_OneVision_HF
from .llava_xtuner import LLaVA_XTuner
from .llava_hf import LLaVA_HF

__all__ = ['LLaVA', 'LLaVA_Next', 'LLaVA_XTuner', 'LLaVA_Next2', 'LLaVA_OneVision', 'LLaVA_OneVision_HF']
__all__ = ['LLaVA', 'LLaVA_Next', 'LLaVA_XTuner', 'LLaVA_HF', 'LLaVA_Next2', 'LLaVA_OneVision', 'LLaVA_OneVision_HF']
178 changes: 178 additions & 0 deletions vlmeval/vlm/llava/llava_hf.py
@@ -0,0 +1,178 @@
import torch
import logging
import warnings
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
from ..base import BaseModel
from ...smp import *
from ...dataset import DATASET_TYPE

class LLaVA_HF(BaseModel):
INSTALL_REQ = False
INTERLEAVE = True

def __init__(self, model_path="llava-hf/llava-1.5-7b-hf", **kwargs):

self.model_path = model_path

try:
self.model = LlavaForConditionalGeneration.from_pretrained(
model_path,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
device_map="cuda"
)
self.processor = AutoProcessor.from_pretrained(model_path)
except Exception as err:
logging.critical(f"Failed to load Hugging Face LLaVA model from {model_path}.")
raise err

kwargs_default = dict(
do_sample=False,
temperature=0,
max_new_tokens=2048,
top_p=None,
num_beams=1,
use_cache=True,
)
kwargs_default.update(kwargs)

# transformers warns when temperature/top_p are set while do_sample=False, so drop them from the generation kwargs
if not kwargs_default["do_sample"] and kwargs_default["temperature"] == 0:
kwargs_default.pop("temperature", None)
kwargs_default.pop("top_p", None)

self.kwargs = kwargs_default
warnings.warn(
f"Following kwargs received: {self.kwargs}, will use as generation config. "
)

def use_custom_prompt(self, dataset):
assert dataset is not None
if DATASET_TYPE(dataset) == "MCQ":
return True
return False

def build_prompt(self, line, dataset=None):
assert self.use_custom_prompt(dataset)
assert dataset is None or isinstance(dataset, str)
tgt_path = self.dump_image(line, dataset)

question = line["question"]
hint = line["hint"] if ("hint" in line and not pd.isna(line["hint"])) else None
if hint is not None:
question = hint + "\n" + question

options = {
cand: line[cand]
for cand in string.ascii_uppercase
if cand in line and not pd.isna(line[cand])
}
for key, item in options.items():
question += f"\n{key}. {item}"
prompt = question

if len(options):
prompt += (
"\n请直接回答选项字母。"
if cn_string(prompt)
else "\nAnswer with the option's letter from the given choices directly."
)
else:
prompt += (
"\n请直接回答问题。"
if cn_string(prompt)
else "\nAnswer the question directly."
)

message = [dict(type="image", value=s) for s in tgt_path]
message.append(dict(type="text", value=prompt))
return message

def chat_inner(self, message, dataset=None):
conversation = []
images = []

# Convert framework messages to HF Chat Template format
for utter in message:
content_list = []
for item in utter["content"]:
if item["type"] == "text":
content_list.append({"type": "text", "text": item["value"]})
elif item["type"] == "image":
content_list.append({"type": "image"})
images.append(Image.open(item["value"]).convert("RGB"))

conversation.append({
"role": utter["role"],
"content": content_list
})

prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)

inputs = self.processor(
images=images if images else None,
text=prompt,
return_tensors="pt"
).to(self.model.device, torch.float16)

with torch.inference_mode():
output_ids = self.model.generate(
**inputs,
**self.kwargs
)

# Slice the output to remove the input prompt tokens
input_len = inputs["input_ids"].shape[1]
generated_ids = output_ids[0][input_len:]

output = self.processor.decode(generated_ids, skip_special_tokens=True).strip()
return output

def generate_inner(self, message, dataset=None):

content_list = []
images = []

# Convert single-turn framework message to HF Chat Template format
for item in message:
if item["type"] == "text":
content_list.append({"type": "text", "text": item["value"]})
elif item["type"] == "image":
content_list.append({"type": "image"})
images.append(Image.open(item["value"]).convert("RGB"))

conversation = [
{
"role": "user",
"content": content_list
}
]

prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)

inputs = self.processor(
images=images if images else None,
text=prompt,
return_tensors="pt"
).to(self.model.device, torch.float16)

with torch.inference_mode():
output_ids = self.model.generate(
**inputs,
**self.kwargs
)

# Slice the output to remove the input prompt tokens
input_len = inputs["input_ids"].shape[1]
generated_ids = output_ids[0][input_len:]

output = self.processor.decode(generated_ids, skip_special_tokens=True).strip()
return output
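For context, a minimal usage sketch of the new wrapper outside the evaluation pipeline. The image path and prompt are hypothetical, and it assumes the base class exposes the usual generate entry point that wraps generate_inner, as with the other LLaVA wrappers:

from vlmeval.vlm import LLaVA_HF

# Instantiate the HF-format LLaVA; extra kwargs are merged into the generation config.
model = LLaVA_HF(model_path="llava-hf/llava-1.5-7b-hf", max_new_tokens=512)

# Interleaved message format used throughout VLMEvalKit: a list of type/value dicts.
message = [
    dict(type="image", value="demo.jpg"),  # hypothetical local image path
    dict(type="text", value="Describe this image in one sentence."),
]
print(model.generate(message))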