
Commit 9d1d9d6

alex-jw-brooks authored and DarkLight1337 committed
[CI/Build] Add Model Tests for Qwen2-VL (vllm-project#9846)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Loc Huynh <jc1da.3011@gmail.com>
1 parent 4511fcb commit 9d1d9d6

9 files changed (+106, -52 lines)

.buildkite/test-pipeline.yaml

Lines changed: 14 additions & 3 deletions
@@ -9,6 +9,7 @@
 # label(str): the name of the test. emoji allowed.
 # fast_check(bool): whether to run this on each commit on fastcheck pipeline.
 # fast_check_only(bool): run this test on fastcheck pipeline only
+# nightly(bool): run this test in nightly pipeline only
 # optional(bool): never run this test by default (i.e. need to unblock manually)
 # command(str): the single command to run for tests. incompatible with commands.
 # commands(list): the list of commands to run for test. incompatbile with command.
@@ -330,18 +331,28 @@ steps:
   commands:
   - pytest -v -s models/decoder_only/language --ignore=models/decoder_only/language/test_models.py --ignore=models/decoder_only/language/test_big_models.py
 
-- label: Decoder-only Multi-Modal Models Test # 1h31min
+- label: Decoder-only Multi-Modal Models Test (Standard)
   #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
   - tests/models/decoder_only/audio_language
   - tests/models/decoder_only/vision_language
   commands:
-  - pytest -v -s models/decoder_only/audio_language
+  - pytest -v -s models/decoder_only/audio_language -m core_model
+  - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m core_model
+
+- label: Decoder-only Multi-Modal Models Test (Extended)
+  nightly: true
+  source_file_dependencies:
+  - vllm/
+  - tests/models/decoder_only/audio_language
+  - tests/models/decoder_only/vision_language
+  commands:
+  - pytest -v -s models/decoder_only/audio_language -m 'not core_model'
   # HACK - run phi3v tests separately to sidestep this transformers bug
   # https://github.com/huggingface/transformers/issues/34307
   - pytest -v -s models/decoder_only/vision_language/test_phi3v.py
-  - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language
+  - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model'
 
 - label: Other Models Test # 6min
   #mirror_hardwares: [amd]
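The Standard/Extended split above keys off a "core_model" pytest marker. As a minimal sketch (a hypothetical conftest.py; the real repository may declare the marker elsewhere, e.g. in a pyproject.toml or pytest.ini), registering it keeps "-m core_model" and "-m 'not core_model'" selections free of unknown-marker warnings:

# conftest.py (sketch) -- register the marker the CI split filters on.
def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "core_model: tests run on every commit in the Standard job",
    )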

examples/offline_inference_vision_language.py

Lines changed: 1 addition & 2 deletions
@@ -262,10 +262,9 @@ def run_qwen2_vl(question: str, modality: str):
 
     model_name = "Qwen/Qwen2-VL-7B-Instruct"
 
-    # Tested on L40
     llm = LLM(
         model=model_name,
-        max_model_len=8192,
+        max_model_len=4096,
         max_num_seqs=5,
         # Note - mm_processor_kwargs can also be passed to generate/chat calls
         mm_processor_kwargs={
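For reference, a condensed sketch of the example's configuration after this change (this assumes vLLM's offline LLM API; the mm_processor_kwargs passed in the real example sit outside this hunk and are omitted here):

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2-VL-7B-Instruct",
    max_model_len=4096,  # reduced from 8192 by this commit
    max_num_seqs=5,
)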

tests/models/decoder_only/audio_language/test_ultravox.py

Lines changed: 2 additions & 0 deletions
@@ -158,6 +158,7 @@ def run_multi_audio_test(
     assert all(tokens for tokens, *_ in vllm_outputs)
 
 
+@pytest.mark.core_model
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
@@ -178,6 +179,7 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
     )
 
 
+@pytest.mark.core_model
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])

tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 
 
 # Fixtures lazy import to avoid initializing CUDA during test collection
-# NOTE: Qwen2vl supports multiple input modalities, so it registers multiple
+# NOTE: Qwen2VL supports multiple input modalities, so it registers multiple
 # input mappers.
 @pytest.fixture()
 def image_input_mapper_for_qwen2_vl():

tests/models/decoder_only/vision_language/test_models.py

Lines changed: 62 additions & 39 deletions
@@ -75,6 +75,63 @@
 # this is a good idea for checking your command first, since tests are slow.
 
 VLM_TEST_SETTINGS = {
+    #### Core tests to always run in the CI
+    "llava": VLMTestInfo(
+        models=["llava-hf/llava-1.5-7b-hf"],
+        test_type=(
+            VLMTestType.EMBEDDING,
+            VLMTestType.IMAGE,
+            VLMTestType.CUSTOM_INPUTS
+        ),
+        prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
+        convert_assets_to_embeddings=model_utils.get_llava_embeddings,
+        max_model_len=4096,
+        auto_cls=AutoModelForVision2Seq,
+        vllm_output_post_proc=model_utils.llava_image_vllm_to_hf_output,
+        custom_test_opts=[CustomTestOptions(
+            inputs=custom_inputs.multi_image_multi_aspect_ratio_inputs(
+                formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:"
+            ),
+            limit_mm_per_prompt={"image": 4},
+        )],
+        marks=[pytest.mark.core_model],
+    ),
+    "paligemma": VLMTestInfo(
+        models=["google/paligemma-3b-mix-224"],
+        test_type=VLMTestType.IMAGE,
+        prompt_formatter=identity,
+        img_idx_to_prompt = lambda idx: "",
+        # Paligemma uses its own sample prompts because the default one fails
+        single_image_prompts=IMAGE_ASSETS.prompts({
+            "stop_sign": "caption es",
+            "cherry_blossom": "What is in the picture?",
+        }),
+        auto_cls=AutoModelForVision2Seq,
+        postprocess_inputs=model_utils.get_key_type_post_processor(
+            "pixel_values"
+        ),
+        vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
+        dtype="half" if current_platform.is_rocm() else ("half", "float"),
+        marks=[pytest.mark.core_model],
+    ),
+    "qwen2_vl": VLMTestInfo(
+        models=["Qwen/Qwen2-VL-2B-Instruct"],
+        test_type=(
+            VLMTestType.IMAGE,
+            VLMTestType.MULTI_IMAGE,
+            VLMTestType.VIDEO
+        ),
+        prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>",  # noqa: E501
+        video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>",  # noqa: E501
+        max_model_len=4096,
+        max_num_seqs=2,
+        auto_cls=AutoModelForVision2Seq,
+        vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
+        marks=[pytest.mark.core_model],
+        image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
+    ),
+    #### Extended model tests
     "blip2": VLMTestInfo(
         models=["Salesforce/blip2-opt-2.7b"],
         test_type=VLMTestType.IMAGE,
@@ -151,25 +208,6 @@
         use_tokenizer_eos=True,
         patch_hf_runner=model_utils.internvl_patch_hf_runner,
     ),
-    "llava": VLMTestInfo(
-        models=["llava-hf/llava-1.5-7b-hf"],
-        test_type=(
-            VLMTestType.EMBEDDING,
-            VLMTestType.IMAGE,
-            VLMTestType.CUSTOM_INPUTS
-        ),
-        prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
-        convert_assets_to_embeddings=model_utils.get_llava_embeddings,
-        max_model_len=4096,
-        auto_cls=AutoModelForVision2Seq,
-        vllm_output_post_proc=model_utils.llava_image_vllm_to_hf_output,
-        custom_test_opts=[CustomTestOptions(
-            inputs=custom_inputs.multi_image_multi_aspect_ratio_inputs(
-                formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:"
-            ),
-            limit_mm_per_prompt={"image": 4},
-        )],
-    ),
     "llava_next": VLMTestInfo(
         models=["llava-hf/llava-v1.6-mistral-7b-hf"],
         test_type=(VLMTestType.IMAGE, VLMTestType.CUSTOM_INPUTS),
@@ -200,12 +238,12 @@
         vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output,
         # Llava-one-vision tests fixed sizes & the default size factors
         image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
-        runner_mm_key="videos",
         custom_test_opts=[CustomTestOptions(
             inputs=custom_inputs.multi_video_multi_aspect_ratio_inputs(
                 formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
             ),
             limit_mm_per_prompt={"video": 4},
+            runner_mm_key="videos",
         )],
     ),
     # FIXME
@@ -218,9 +256,11 @@
         auto_cls=AutoModelForVision2Seq,
         vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output,
         image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
-        runner_mm_key="videos",
         marks=[
-            pytest.mark.skip(reason="LLava next video tests currently fail.")
+            pytest.mark.skipif(
+                transformers.__version__.startswith("4.46"),
+                reason="Model broken with changes in transformers 4.46"
+            )
         ],
     ),
     "minicpmv": VLMTestInfo(
@@ -234,23 +274,6 @@
         postprocess_inputs=model_utils.wrap_inputs_post_processor,
         hf_output_post_proc=model_utils.minicmpv_trunc_hf_output,
     ),
-    "paligemma": VLMTestInfo(
-        models=["google/paligemma-3b-mix-224"],
-        test_type=VLMTestType.IMAGE,
-        prompt_formatter=identity,
-        img_idx_to_prompt = lambda idx: "",
-        # Paligemma uses its own sample prompts because the default one fails
-        single_image_prompts=IMAGE_ASSETS.prompts({
-            "stop_sign": "caption es",
-            "cherry_blossom": "What is in the picture?",
-        }),
-        auto_cls=AutoModelForVision2Seq,
-        postprocess_inputs=model_utils.get_key_type_post_processor(
-            "pixel_values"
-        ),
-        vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
-        dtype="half" if current_platform.is_rocm() else ("half", "float"),
-    ),
     # Tests for phi3v currently live in another file because of a bug in
     # transformers. Once this issue is fixed, we can enable them here instead.
     # https://github.com/huggingface/transformers/issues/34307
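The marks=[pytest.mark.core_model] entries above only take effect if they end up attached to the generated test cases. A simplified, hypothetical sketch of that mechanism (plain dicts instead of the real VLMTestInfo, illustrative names only):

import pytest

_SETTINGS = {
    "qwen2_vl": {"model": "Qwen/Qwen2-VL-2B-Instruct",
                 "marks": [pytest.mark.core_model]},
    "blip2": {"model": "Salesforce/blip2-opt-2.7b", "marks": []},
}


def _to_params(settings):
    # Wrap each entry in pytest.param so its marks travel with the case and
    # become selectable via -m core_model / -m 'not core_model'.
    return [pytest.param(name, cfg["model"], marks=cfg["marks"])
            for name, cfg in settings.items()]


@pytest.mark.parametrize("name,model", _to_params(_SETTINGS))
def test_marker_propagation(name, model):
    assert model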

tests/models/decoder_only/vision_language/vlm_utils/model_utils.py

Lines changed: 11 additions & 0 deletions
@@ -56,6 +56,17 @@ def qwen_vllm_to_hf_output(
     return output_ids, hf_output_str, out_logprobs
 
 
+def qwen2_vllm_to_hf_output(
+        vllm_output: RunnerOutput,
+        model: str) -> Tuple[List[int], str, Optional[SampleLogprobs]]:
+    """Sanitize vllm output [qwen2 models] to be comparable with hf output."""
+    output_ids, output_str, out_logprobs = vllm_output
+
+    hf_output_str = output_str + "<|im_end|>"
+
+    return output_ids, hf_output_str, out_logprobs
+
+
 def llava_image_vllm_to_hf_output(vllm_output: RunnerOutput,
                                   model: str) -> RunnerOutput:
     config = AutoConfig.from_pretrained(model)
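A quick illustration of what the new sanitizer does (made-up output values; the import path assumes the test-package layout shown above): vLLM stops before emitting <|im_end|>, while the HF reference output keeps it, so the string is padded before comparison.

from tests.models.decoder_only.vision_language.vlm_utils import model_utils

vllm_out = ([101, 102, 103], "A stop sign on a street corner.", None)
ids, hf_comparable, logprobs = model_utils.qwen2_vllm_to_hf_output(
    vllm_out, "Qwen/Qwen2-VL-2B-Instruct")
assert hf_comparable == "A stop sign on a street corner.<|im_end|>"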

tests/models/decoder_only/vision_language/vlm_utils/runners.py

Lines changed: 10 additions & 1 deletion
@@ -29,6 +29,7 @@ def run_single_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo,
         num_logprobs=test_case.num_logprobs,
         limit_mm_per_prompt={"image": 1},
         distributed_executor_backend=test_case.distributed_executor_backend,
+        runner_mm_key="images",
         **model_test_info.get_non_parametrized_runner_kwargs())
 
 
@@ -51,6 +52,7 @@ def run_multi_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo,
         num_logprobs=test_case.num_logprobs,
         limit_mm_per_prompt={"image": len(image_assets)},
         distributed_executor_backend=test_case.distributed_executor_backend,
+        runner_mm_key="images",
         **model_test_info.get_non_parametrized_runner_kwargs())
 
 
@@ -74,6 +76,7 @@ def run_embedding_test(*, model_test_info: VLMTestInfo,
         limit_mm_per_prompt={"image": 1},
         vllm_embeddings=vllm_embeddings,
         distributed_executor_backend=test_case.distributed_executor_backend,
+        runner_mm_key="images",
         **model_test_info.get_non_parametrized_runner_kwargs())
 
 
@@ -101,6 +104,7 @@ def run_video_test(
         num_logprobs=test_case.num_logprobs,
         limit_mm_per_prompt={"video": len(video_assets)},
         distributed_executor_backend=test_case.distributed_executor_backend,
+        runner_mm_key="videos",
         **model_test_info.get_non_parametrized_runner_kwargs())
 
 
@@ -115,7 +119,11 @@ def run_custom_inputs_test(*, model_test_info: VLMTestInfo,
 
     inputs = test_case.custom_test_opts.inputs
     limit_mm_per_prompt = test_case.custom_test_opts.limit_mm_per_prompt
-    assert inputs is not None and limit_mm_per_prompt is not None
+    runner_mm_key = test_case.custom_test_opts.runner_mm_key
+    # Inputs, limit_mm_per_prompt, and runner_mm_key should all be set
+    assert inputs is not None
+    assert limit_mm_per_prompt is not None
+    assert runner_mm_key is not None
 
     core.run_test(
         hf_runner=hf_runner,
@@ -127,4 +135,5 @@ def run_custom_inputs_test(*, model_test_info: VLMTestInfo,
         num_logprobs=test_case.num_logprobs,
         limit_mm_per_prompt=limit_mm_per_prompt,
         distributed_executor_backend=test_case.distributed_executor_backend,
+        runner_mm_key=runner_mm_key,
         **model_test_info.get_non_parametrized_runner_kwargs())
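Conceptually, runner_mm_key just names the keyword under which multimodal data reaches the HF/vLLM runner, which is why the image runners above now pass "images" explicitly and the video runner passes "videos". A simplified sketch (hypothetical helper, not the real core.run_test):

from typing import Any, Dict, List


def pack_runner_inputs(prompts: List[str], mm_data: List[Any],
                       runner_mm_key: str) -> Dict[str, Any]:
    """Bundle prompts plus multimodal data under the configured kwarg name."""
    return {"prompts": prompts, runner_mm_key: mm_data}


image_call = pack_runner_inputs(["<image> Describe it"], ["pil_image"], "images")
video_call = pack_runner_inputs(["<video> Describe it"], ["frame_array"], "videos")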

tests/models/decoder_only/vision_language/vlm_utils/types.py

Lines changed: 2 additions & 4 deletions
@@ -52,6 +52,8 @@ class SizeType(Enum):
 class CustomTestOptions(NamedTuple):
     inputs: List[Tuple[List[str], List[Union[List[Image], Image]]]]
     limit_mm_per_prompt: Dict[str, int]
+    # kwarg to pass multimodal data in as to vllm/hf runner instances.
+    runner_mm_key: str = "images"
 
 
 class ImageSizeWrapper(NamedTuple):
@@ -141,9 +143,6 @@ class VLMTestInfo(NamedTuple):
                   Callable[[PosixPath, str, Union[List[ImageAsset], _ImageAssets]],
                            str]] = None  # noqa: E501
 
-    # kwarg to pass multimodal data in as to vllm/hf runner instances
-    runner_mm_key: str = "images"
-
     # Allows configuring a test to run with custom inputs
     custom_test_opts: Optional[List[CustomTestOptions]] = None
 
@@ -168,7 +167,6 @@ def get_non_parametrized_runner_kwargs(self):
         "get_stop_token_ids": self.get_stop_token_ids,
         "model_kwargs": self.model_kwargs,
         "patch_hf_runner": self.patch_hf_runner,
-        "runner_mm_key": self.runner_mm_key,
     }
 
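A condensed, self-contained sketch of the relocated field (types simplified relative to the real definitions above) and how a video-based custom input would override the default:

from typing import Any, Dict, List, NamedTuple


class CustomTestOptions(NamedTuple):
    inputs: List[Any]  # prompt / multimodal-data pairs
    limit_mm_per_prompt: Dict[str, int]
    # kwarg to pass multimodal data in as to vllm/hf runner instances.
    runner_mm_key: str = "images"


# Mirrors the llava-onevision change earlier in this diff.
video_opts = CustomTestOptions(inputs=[], limit_mm_per_prompt={"video": 4},
                               runner_mm_key="videos")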

tests/models/embedding/vision_language/test_llava_next.py

Lines changed: 3 additions & 2 deletions
@@ -2,6 +2,7 @@
 
 import pytest
 import torch.nn.functional as F
+import transformers
 from transformers import AutoModelForVision2Seq
 
 from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
@@ -85,8 +86,8 @@ def _run_test(
     )
 
 
-# FIXME
-@pytest.mark.skip(reason="LLava next embedding tests currently fail")
+@pytest.mark.skipif(transformers.__version__.startswith("4.46"),
+                    reason="Model broken with changes in transformers 4.46")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 def test_models_text(
