
Commit 968e679

[Misc] Add data preprocess functions to qwen2.5_vl_without_padding (#2148)
### What this PR does / why we need it?

Cherry-pick of #1705 from v0.9.1-dev. Compared with qwen2_5_vl.py, qwen2_5_vl_without_padding.py is missing some functions; this PR supplements them.

Added:
- rot_pos_emb(self, grid_thw: torch.Tensor)
- get_window_index(self, grid_thw)
- _process_image_input(self, image_input)
- _process_video_input(self, video_input)

Co-authored-by: zheliuyu [15750543867@163.com](mailto:15750543867@163.com)
Co-authored-by: wangli [wangli858794774@gmail.com](mailto:wangli858794774@gmail.com)

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@207b750

Signed-off-by: wangli <wangli858794774@gmail.com>
1 parent e3b3ffb commit 968e679

File tree: 1 file changed (+93, -0 lines)

vllm_ascend/models/qwen2_5_vl_without_padding.py

Lines changed: 93 additions & 0 deletions
```diff
@@ -207,6 +207,66 @@ def cal_cos_sin(self, rotary_pos_emb):
                                 self.hidden_size_per_attention_head)
         return cos_new, sin_new
 
+    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            ).permute(0, 2, 1, 3).flatten()
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            ).permute(0, 2, 1, 3).flatten()
+            pos_ids.append(
+                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
+    def get_window_index(self, grid_thw):
+        window_index: list = []
+        cu_window_seqlens: list = [0]
+        window_index_id = 0
+        vit_merger_window_size = (self.window_size //
+                                  self.spatial_merge_size // self.patch_size)
+
+        for grid_t, grid_h, grid_w in grid_thw:
+            llm_grid_h = grid_h // self.spatial_merge_size
+            llm_grid_w = grid_w // self.spatial_merge_size
+            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+                grid_t, llm_grid_h, llm_grid_w)
+            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+            index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
+            index_padded = index_padded.reshape(grid_t, num_windows_h,
+                                                vit_merger_window_size,
+                                                num_windows_w,
+                                                vit_merger_window_size)
+            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+                grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
+                vit_merger_window_size)
+            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+            index_padded = index_padded.reshape(-1)
+            index_new = index_padded[index_padded != -100]
+            window_index.append(index_new + window_index_id)
+            cu_seqlens_tmp = seqlens.cumsum(
+                0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+        window_index = torch.cat(window_index, dim=0)
+        return window_index, cu_window_seqlens
+
     def forward(
         self,
         x: torch.Tensor,
```
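Aside (not part of the diff): a minimal standalone sketch of what `rot_pos_emb` builds per image. For each (t, h, w) grid it enumerates (row, column) position pairs block by block in `spatial_merge_size x spatial_merge_size` order; the real method then uses these pairs to index the rotary table via `rotary_pos_emb_full[pos_ids]`. The grid and merge size below are illustrative values only.

```python
import torch

# Illustrative values only (not from this PR): one frame, a 4x4 patch grid,
# spatial_merge_size = 2.
spatial_merge_size = 2
t, h, w = 1, 4, 4

hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
# Reorder row/column ids so patches of the same 2x2 merge block are
# contiguous, mirroring what rot_pos_emb does.
hpos_ids = hpos_ids.reshape(h // spatial_merge_size, spatial_merge_size,
                            w // spatial_merge_size,
                            spatial_merge_size).permute(0, 2, 1, 3).flatten()
wpos_ids = wpos_ids.reshape(h // spatial_merge_size, spatial_merge_size,
                            w // spatial_merge_size,
                            spatial_merge_size).permute(0, 2, 1, 3).flatten()
pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)

print(pos_ids[:4].tolist())  # [[0, 0], [0, 1], [1, 0], [1, 1]]: top-left merge block first
```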
```diff
@@ -258,6 +318,39 @@ def forward(
         x = x[reverse_indices, :]
         return x
 
+    def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
+
+        grid_thw = image_input["image_grid_thw"]
+        assert grid_thw.ndim == 2
+
+        if image_input["type"] == "image_embeds":
+            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+
+        # Split concatenated embeddings for each image item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+        return image_embeds.split(sizes.tolist())
+
+    def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
+
+        grid_thw = video_input["video_grid_thw"]
+        assert grid_thw.ndim == 2
+
+        if video_input["type"] == "video_embeds":
+            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values_videos = video_input["pixel_values_videos"].type(
+                self.visual.dtype)
+            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+
+        # Split concatenated embeddings for each video item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+        return video_embeds.split(sizes.tolist())
+
 
 @MULTIMODAL_REGISTRY.register_processor(
     Qwen2_5_VLMultiModalProcessor,
```
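Aside (not part of the diff): a minimal standalone sketch of the split-size arithmetic used by `_process_image_input` and `_process_video_input`. Each item contributes t*h*w patches, merged in `spatial_merge_size**2` groups into LLM tokens, and the concatenated visual embeddings are split back per item. The grids, merge size, and hidden size below are illustrative values only.

```python
import torch

# Illustrative values only: two images with patch grids (1, 32, 32) and
# (1, 16, 24), spatial_merge_size = 2, and a toy hidden size of 8.
grid_thw = torch.tensor([[1, 32, 32], [1, 16, 24]])
merge_size = 2

# Same arithmetic as in the added methods.
sizes = grid_thw.prod(-1) // merge_size // merge_size
print(sizes.tolist())  # [256, 96] merged tokens per image

# Stand-in for the concatenated output of the visual encoder.
image_embeds = torch.randn(int(sizes.sum()), 8)
per_image = image_embeds.split(sizes.tolist())
print([tuple(e.shape) for e in per_image])  # [(256, 8), (96, 8)]
```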
