num_beams > 1 leads to exception for Qwen2.5VL (Qwen family or all VLM models?) #39723

@iglaweb

Description

System Info

  • transformers version: 4.53.2
  • Platform: Windows-10-10.0.26100-SP0
  • Python version: 3.10.18
  • Huggingface_hub version: 0.33.4
  • Safetensors version: 0.5.3
  • Accelerate version: 1.9.0
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.5.1+cu118 (CUDA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no
  • Using GPU in script?: yes
  • GPU type: NVIDIA GeForce RTX 4080

Who can help?

@amyeroberts @qubvel @zucchini-nlp

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

num_beams = 2
do_sample = False
max_token_length = 8192
video_path = 'some video file path'

# load model
model_id = 'Qwen/Qwen2.5-VL-7B-Instruct'
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map='auto',
)
processor = AutoProcessor.from_pretrained(model_id)

# build a single user turn containing text plus a video
messages = []
messages.append({
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe the video in detail."},
        {"type": "video", "path": video_path},
    ],
})

# apply_chat_template returns a BatchFeature; .to() moves all tensors to the
# device and casts only the floating-point ones (pixel values) to model.dtype
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt"
).to(model.device, dtype=model.dtype)

# only pass num_beams when beam search is actually requested
if num_beams > 1:
    inp_model_kwargs = {'num_beams': num_beams}
else:
    inp_model_kwargs = {}

outputs = model.generate(
    **inputs,
    do_sample=do_sample,
    max_new_tokens=max_token_length,
    **inp_model_kwargs
    # num_return_sequences=2,
)
# ...

Whenever I set num_beams to a value greater than 1, I get the following exception:

Traceback (most recent call last):
  File "hf_vlm_run_exps_eval.py", line 750, in <module>
    run_main()
  File "hf_vlm_run_exps_eval.py", line 734, in run_main
    answers_dict = exec_llm_on_segment_videos(
  File "hf_vlm_run_exps_eval.py", line 473, in exec_llm_on_segment_videos
    output_text = extract_answer_from_llm(
  File "hf_vlm_run_exps_eval.py", line 517, in extract_answer_from_llm
    output_text = hf_base_model_wrapper.run_model_single_inference(
  File "hf_base_model_wrapper.py", line 94, in run_model_single_inference
    outputs = model.generate(
  File "C:\Users\User\.conda\envs\env_grounded_sam2\lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "C:\Users\User\.conda\envs\env_grounded_sam2\lib\site-packages\transformers\generation\utils.py", line 2637, in generate
    input_ids, model_kwargs = self._expand_inputs_for_generation(
  File "C:\Users\User\.conda\envs\env_grounded_sam2\lib\site-packages\transformers\models\qwen2_5_vl\modeling_qwen2_5_vl.py", line 1695, in _expand_inputs_for_generation
    model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
  File "C:\Users\User\.conda\envs\env_grounded_sam2\lib\site-packages\transformers\models\qwen2_5_vl\modeling_qwen2_5_vl.py", line 1672, in _expand_dict_for_generation_visual
    raise TypeError(
TypeError: Expected value for key 'second_per_grid_ts' to be a list, but got <class 'torch.Tensor'> instead.

The same error occurs if I add num_return_sequences=2 to .generate().
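
From the traceback, the model-specific _expand_dict_for_generation_visual requires 'second_per_grid_ts' to be a Python list, but apply_chat_template with return_tensors="pt" appears to return it as a torch.Tensor. As a temporary client-side workaround (an untested sketch; the key name is taken from the traceback above), converting that entry back to a list before calling generate should avoid the TypeError:

# untested workaround sketch: _expand_dict_for_generation_visual expects a
# list for 'second_per_grid_ts', so convert the tensor the processor returned
sp_grid_ts = inputs.get("second_per_grid_ts")
if isinstance(sp_grid_ts, torch.Tensor):
    inputs["second_per_grid_ts"] = sp_grid_ts.tolist()

outputs = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=max_token_length,
    num_beams=num_beams,
)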

Expected behavior

I'd expect the function to run successfully and return multiple (different) sequences.
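
Concretely, a successful run should look something like this (illustrative sketch):

# expected usage: decode num_return_sequences distinct beam hypotheses
outputs = model.generate(
    **inputs,
    do_sample=False,
    num_beams=2,
    num_return_sequences=2,
    max_new_tokens=128,
)
texts = processor.batch_decode(outputs, skip_special_tokens=True)
assert len(texts) == 2  # one decoded string per returned sequence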
