@@ -202,6 +202,66 @@ def cal_cos_sin(self, rotary_pos_emb):
                                  self.hidden_size_per_attention_head)
        return cos_new, sin_new

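+    # Computes per-token 2D rotary embeddings from the (t, h, w) grid of
+    # each image/video item; the h and w indices are embedded separately
+    # and concatenated along the feature dimension.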
+    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
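+            # Regroup row/col ids into spatial_merge_size x spatial_merge_size
+            # blocks so tokens merged by the patch merger remain contiguous
+            # after flattening.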
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            ).permute(0, 2, 1, 3).flatten()
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            ).permute(0, 2, 1, 3).flatten()
+            pos_ids.append(
+                torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
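+        # Look up the rotary table built for the largest spatial extent and
+        # gather one (h, w) embedding pair per token.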
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
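+    # Returns the permutation that groups merged tokens by local attention
+    # window, plus cumulative per-window lengths (seqlens scaled by
+    # spatial_merge_unit, i.e. in pre-merge token units).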
+    def get_window_index(self, grid_thw):
+        window_index: list = []
+        cu_window_seqlens: list = [0]
+        window_index_id = 0
+        vit_merger_window_size = (self.window_size //
+                                  self.spatial_merge_size // self.patch_size)
+
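+        # Work at the merged-token resolution and pad each grid up to a
+        # whole number of windows; pad slots are marked with -100.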
+        for grid_t, grid_h, grid_w in grid_thw:
+            llm_grid_h = grid_h // self.spatial_merge_size
+            llm_grid_w = grid_w // self.spatial_merge_size
+            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+                grid_t, llm_grid_h, llm_grid_w)
+            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+            index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
+            index_padded = index_padded.reshape(grid_t, num_windows_h,
+                                                vit_merger_window_size,
+                                                num_windows_w,
+                                                vit_merger_window_size)
+            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+                grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
+                vit_merger_window_size)
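+            # Count valid (non-pad) tokens per window and drop the pad
+            # entries so only real token positions remain.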
+            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+            index_padded = index_padded.reshape(-1)
+            index_new = index_padded[index_padded != -100]
+            window_index.append(index_new + window_index_id)
+            cu_seqlens_tmp = seqlens.cumsum(
+                0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+        window_index = torch.cat(window_index, dim=0)
+        return window_index, cu_window_seqlens
+
    def forward(
        self,
        x: torch.Tensor,
@@ -271,3 +331,36 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
            quant_config=self._maybe_ignore_quant_config(quant_config),
            prefix=maybe_prefix(prefix, "visual"),
        )
+
+    def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
+
+        grid_thw = image_input["image_grid_thw"]
+        assert grid_thw.ndim == 2
+
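+        # Pre-computed embeddings bypass the vision tower; raw pixels are
+        # cast to the tower's dtype before encoding.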
+        if image_input["type"] == "image_embeds":
+            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+            image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+
+        # Split concatenated embeddings for each image item.
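+        # Each item contributes t * h * w patches, reduced by a factor of
+        # merge_size ** 2 by the patch merger.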
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+        return image_embeds.split(sizes.tolist())
+
+    def _process_video_input(self, video_input) -> tuple[torch.Tensor, ...]:
+
+        grid_thw = video_input["video_grid_thw"]
+        assert grid_thw.ndim == 2
+
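+        # Mirrors the image path: use pre-computed embeddings when given,
+        # otherwise encode raw video pixels with the vision tower.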
+        if video_input["type"] == "video_embeds":
+            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+        else:
+            pixel_values_videos = video_input["pixel_values_videos"].type(
+                self.visual.dtype)
+            video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+
+        # Split concatenated embeddings for each video item.
+        merge_size = self.visual.spatial_merge_size
+        sizes = grid_thw.prod(-1) // merge_size // merge_size
+        return video_embeds.split(sizes.tolist())