Skip to content

Commit 3124d1b

Browse files
authored
[qwen-vl] fix beam search with videos (#39726)
* fix
* fix copies
1 parent 1372a5b commit 3124d1b

File tree

3 files changed

+9
-24
lines changed

3 files changed

+9
-24
lines changed

src/transformers/models/glm4v/modeling_glm4v.py

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1614,14 +1614,9 @@ def _repeat_interleave_samples(x, lengths, repeat_times):
16141614
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
16151615
)
16161616
elif key == "second_per_grid_ts":
1617-
if not isinstance(dict_to_expand[key], list):
1618-
raise TypeError(
1619-
f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead."
1620-
)
1621-
tensor = torch.tensor(dict_to_expand[key])
1622-
lengths = list(video_nums)
1623-
tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size)
1624-
dict_to_expand[key] = tensor.tolist()
1617+
dict_to_expand[key] = _repeat_interleave_samples(
1618+
dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
1619+
)
16251620
return dict_to_expand
16261621

16271622
def _expand_dict_for_generation(dict_to_expand):

src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1713,14 +1713,9 @@ def _repeat_interleave_samples(x, lengths, repeat_times):
17131713
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
17141714
)
17151715
elif key == "second_per_grid_ts":
1716-
if not isinstance(dict_to_expand[key], list):
1717-
raise TypeError(
1718-
f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead."
1719-
)
1720-
tensor = torch.tensor(dict_to_expand[key])
1721-
lengths = list(video_nums)
1722-
tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size)
1723-
dict_to_expand[key] = tensor.tolist()
1716+
dict_to_expand[key] = _repeat_interleave_samples(
1717+
dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
1718+
)
17241719
return dict_to_expand
17251720

17261721
def _expand_dict_for_generation(dict_to_expand):

src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1599,14 +1599,9 @@ def _repeat_interleave_samples(x, lengths, repeat_times):
15991599
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
16001600
)
16011601
elif key == "second_per_grid_ts":
1602-
if not isinstance(dict_to_expand[key], list):
1603-
raise TypeError(
1604-
f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead."
1605-
)
1606-
tensor = torch.tensor(dict_to_expand[key])
1607-
lengths = list(video_nums)
1608-
tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size)
1609-
dict_to_expand[key] = tensor.tolist()
1602+
dict_to_expand[key] = _repeat_interleave_samples(
1603+
dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
1604+
)
16101605
return dict_to_expand
16111606

16121607
def _expand_dict_for_generation(dict_to_expand):

0 commit comments

Comments
 (0)