@@ -207,6 +207,66 @@ def cal_cos_sin(self, rotary_pos_emb):
            self.hidden_size_per_attention_head)
        return cos_new, sin_new

+    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
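+            # Rearrange so the merge_size x merge_size patches that are
+            # fused into one token sit contiguously, matching the order of
+            # the merged patch embeddings.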
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            ).permute(0, 2, 1, 3).flatten()
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            ).permute(0, 2, 1, 3).flatten()
+            pos_ids.append(
+                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = torch.cat(pos_ids, dim=0)
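+        # Build one rotary table for the largest spatial extent across all
+        # items, then gather per-position rows from it.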
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
+    def get_window_index(self, grid_thw):
+        window_index: list = []
+        cu_window_seqlens: list = [0]
+        window_index_id = 0
+        vit_merger_window_size = (self.window_size //
+                                  self.spatial_merge_size // self.patch_size)
+
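+        # Window size in merged-token units, e.g. with window_size=112,
+        # spatial_merge_size=2 and patch_size=14 each window spans 4x4
+        # merged tokens (values assumed from the Qwen2.5-VL defaults).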
+        for grid_t, grid_h, grid_w in grid_thw:
+            llm_grid_h = grid_h // self.spatial_merge_size
+            llm_grid_w = grid_w // self.spatial_merge_size
+            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+                grid_t, llm_grid_h, llm_grid_w)
+            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+            index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
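+            # Split each spatial axis into (num_windows, window) and bring
+            # the two per-window axes together so every window's tokens are
+            # contiguous.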
+            index_padded = index_padded.reshape(grid_t, num_windows_h,
+                                                vit_merger_window_size,
+                                                num_windows_w,
+                                                vit_merger_window_size)
+            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+                grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
+                vit_merger_window_size)
+            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+            index_padded = index_padded.reshape(-1)
+            index_new = index_padded[index_padded != -100]
+            window_index.append(index_new + window_index_id)
+            cu_seqlens_tmp = seqlens.cumsum(
+                0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+        window_index = torch.cat(window_index, dim=0)
+        return window_index, cu_window_seqlens
+
    def forward(
        self,
        x: torch.Tensor,
@@ -258,6 +318,39 @@ def forward(
        x = x[reverse_indices, :]
        return x

+    def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
+
+        grid_thw = image_input["image_grid_thw"]
+        assert grid_thw.ndim == 2
+
+        if image_input["type"] == "image_embeds":
+            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+
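+        # Each (t, h, w) item contributes t * h * w patch embeddings, which
+        # the spatial merger shrinks by merge_size ** 2, hence the double
+        # division below.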
+        # Split concatenated embeddings for each image item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+        return image_embeds.split(sizes.tolist())
+
+    def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
+
+        grid_thw = video_input["video_grid_thw"]
+        assert grid_thw.ndim == 2
+
+        if video_input["type"] == "video_embeds":
+            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values_videos = video_input["pixel_values_videos"].type(
+                self.visual.dtype)
+            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+
+        # Split concatenated embeddings for each video item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+        return video_embeds.split(sizes.tolist())
+


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2_5_VLMultiModalProcessor,