Commit 2277552

add data preprocess functions to qwen2.5_vl_without_padding

Signed-off-by: zheliuyu <15750543867@163.com>
1 parent 57664f0

File tree: 1 file changed, +93 −0 lines changed


vllm_ascend/models/qwen2_5_vl_without_padding.py

Lines changed: 93 additions & 0 deletions
@@ -202,6 +202,66 @@ def cal_cos_sin(self, rotary_pos_emb):
                                self.hidden_size_per_attention_head)
         return cos_new, sin_new
 
+    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            ).permute(0, 2, 1, 3).flatten()
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            ).permute(0, 2, 1, 3).flatten()
+            pos_ids.append(
+                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
+    def get_window_index(self, grid_thw):
+        window_index: list = []
+        cu_window_seqlens: list = [0]
+        window_index_id = 0
+        vit_merger_window_size = (self.window_size //
+                                  self.spatial_merge_size // self.patch_size)
+
+        for grid_t, grid_h, grid_w in grid_thw:
+            llm_grid_h = grid_h // self.spatial_merge_size
+            llm_grid_w = grid_w // self.spatial_merge_size
+            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+                grid_t, llm_grid_h, llm_grid_w)
+            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+            index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
+            index_padded = index_padded.reshape(grid_t, num_windows_h,
+                                                vit_merger_window_size,
+                                                num_windows_w,
+                                                vit_merger_window_size)
+            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+                grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
+                vit_merger_window_size)
+            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+            index_padded = index_padded.reshape(-1)
+            index_new = index_padded[index_padded != -100]
+            window_index.append(index_new + window_index_id)
+            cu_seqlens_tmp = seqlens.cumsum(
+                0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+        window_index = torch.cat(window_index, dim=0)
+        return window_index, cu_window_seqlens
+
     def forward(
         self,
         x: torch.Tensor,
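
For context (not part of the commit): the snippet below is a minimal standalone sketch of the position-id reordering that rot_pos_emb performs for a single image. The values are made up for illustration (one frame, a 4x6 patch grid, spatial_merge_size of 2); the real method applies the same reshape/permute per entry of grid_thw and then indexes a precomputed rotary table.

import torch

# Hypothetical toy values: one frame, 4x6 patch grid, merge window of 2.
spatial_merge_size = 2
t, h, w = 1, 4, 6

hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)

# Regroup positions into spatial_merge_size x spatial_merge_size blocks so
# that patches merged into one LLM token stay adjacent after flattening.
hpos_ids = hpos_ids.reshape(h // spatial_merge_size, spatial_merge_size,
                            w // spatial_merge_size, spatial_merge_size
                            ).permute(0, 2, 1, 3).flatten()
wpos_ids = wpos_ids.reshape(h // spatial_merge_size, spatial_merge_size,
                            w // spatial_merge_size, spatial_merge_size
                            ).permute(0, 2, 1, 3).flatten()

pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
print(pos_ids[:8].tolist())
# [[0, 0], [0, 1], [1, 0], [1, 1], [0, 2], [0, 3], [1, 2], [1, 3]]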
@@ -271,3 +331,36 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             quant_config=self._maybe_ignore_quant_config(quant_config),
             prefix=maybe_prefix(prefix, "visual"),
         )
+
+    def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
+
+        grid_thw = image_input["image_grid_thw"]
+        assert grid_thw.ndim == 2
+
+        if image_input["type"] == "image_embeds":
+            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+
+        # Split concatenated embeddings for each image item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+        return image_embeds.split(sizes.tolist())
+
+    def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
+
+        grid_thw = video_input["video_grid_thw"]
+        assert grid_thw.ndim == 2
+
+        if video_input["type"] == "video_embeds":
+            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values_videos = video_input["pixel_values_videos"].type(
+                self.visual.dtype)
+            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+
+        # Split concatenated embeddings for each video item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+        return video_embeds.split(sizes.tolist())
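
For context (not part of the commit): a minimal standalone sketch of the size arithmetic that _process_image_input and _process_video_input use to split the concatenated visual embeddings back into per-item tensors. The grid values and the hidden size of 8 are made up; merge_size stands in for self.visual.spatial_merge_size, and the random tensor stands in for the output of self.visual(...).

import torch

# Hypothetical grids: item 0 is 1 frame of 4x6 patches, item 1 is 1 frame
# of 8x8 patches; assume a spatial merge size of 2.
grid_thw = torch.tensor([[1, 4, 6],
                         [1, 8, 8]])
merge_size = 2

# Each output token covers merge_size x merge_size patches, so an item with
# grid (t, h, w) contributes t*h*w // merge_size**2 rows of the embedding.
sizes = grid_thw.prod(-1) // merge_size // merge_size
print(sizes.tolist())  # [6, 16]

image_embeds = torch.randn(int(sizes.sum()), 8)    # stand-in for self.visual(...)
per_item = image_embeds.split(sizes.tolist())
print([tuple(e.shape) for e in per_item])          # [(6, 8), (16, 8)]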
