File tree Expand file tree Collapse file tree 2 files changed +47
-47
lines changed Expand file tree Collapse file tree 2 files changed +47
-47
lines changed Original file line number Diff line number Diff line change 18
18
from vllm .utils import resolve_obj_by_qualname
19
19
from vllm .v1 .pool .metadata import PoolingMetadata as V1PoolingMetadata
20
20
21
+ from vllm .triton_utils import tl , triton
22
+ HAS_TRITON = triton is not None
23
+
24
+
21
25
PoolingMetadata = Union [V0PoolingMetadata , V1PoolingMetadata ]
22
26
23
27
@@ -658,3 +662,44 @@ def forward(
658
662
])
659
663
660
664
return build_output (scores )
665
+
666
+
667
+ if HAS_TRITON :
668
+ @triton .jit
669
+ def extract_vision_tokens_kernel (
670
+ hidden_states_ptr ,
671
+ token_ids_ptr ,
672
+ output_ptr ,
673
+ seq_start ,
674
+ seq_len ,
675
+ hidden_size ,
676
+ vision_start_id : tl .constexpr ,
677
+ vision_end_id : tl .constexpr ,
678
+ BLOCK_SIZE : tl .constexpr ,
679
+ ):
680
+ """Triton kernel to extract and pool vision tokens efficiently."""
681
+ pid = tl .program_id (0 )
682
+
683
+ if pid >= hidden_size :
684
+ return
685
+
686
+ # Find vision token range
687
+ vision_count = 0
688
+ accumulator = 0.0
689
+
690
+ for i in range (seq_len ):
691
+ token_id = tl .load (token_ids_ptr + seq_start + i )
692
+ if token_id >= vision_start_id and token_id <= vision_end_id :
693
+ hidden_val = tl .load (
694
+ hidden_states_ptr + (seq_start + i ) * hidden_size + pid
695
+ )
696
+ accumulator += hidden_val
697
+ vision_count += 1
698
+
699
+ # Store mean pooled result
700
+ if vision_count > 0 :
701
+ result = accumulator / vision_count
702
+ else :
703
+ result = 0.0
704
+
705
+ tl .store (output_ptr + pid , result )
Original file line number Diff line number Diff line change 9
9
import torch .nn .functional as F
10
10
from torch import nn
11
11
12
- try :
13
- import triton
14
- import triton .language as tl
15
- HAS_TRITON = True
16
- except ImportError :
17
- HAS_TRITON = False
18
- triton = None
19
- tl = None
12
+ from vllm .model_executor .layers .pooler import HAS_TRITON , extract_vision_tokens_kernel
20
13
21
14
from vllm .config import VllmConfig
22
15
from vllm .logger import init_logger
44
37
45
38
46
39
# Triton kernel for optimized vision token extraction
47
- if HAS_TRITON :
48
- @triton .jit
49
- def extract_vision_tokens_kernel (
50
- hidden_states_ptr ,
51
- token_ids_ptr ,
52
- output_ptr ,
53
- seq_start ,
54
- seq_len ,
55
- hidden_size ,
56
- vision_start_id : tl .constexpr ,
57
- vision_end_id : tl .constexpr ,
58
- BLOCK_SIZE : tl .constexpr ,
59
- ):
60
- """Triton kernel to extract and pool vision tokens efficiently."""
61
- pid = tl .program_id (0 )
62
-
63
- if pid >= hidden_size :
64
- return
65
-
66
- # Find vision token range
67
- vision_count = 0
68
- accumulator = 0.0
69
-
70
- for i in range (seq_len ):
71
- token_id = tl .load (token_ids_ptr + seq_start + i )
72
- if token_id >= vision_start_id and token_id <= vision_end_id :
73
- hidden_val = tl .load (
74
- hidden_states_ptr + (seq_start + i ) * hidden_size + pid
75
- )
76
- accumulator += hidden_val
77
- vision_count += 1
78
-
79
- # Store mean pooled result
80
- if vision_count > 0 :
81
- result = accumulator / vision_count
82
- else :
83
- result = 0.0
84
-
85
- tl .store (output_ptr + pid , result )
40
+
86
41
87
42
88
43
@MULTIMODAL_REGISTRY .register_processor (Qwen2VLMultiModalProcessor ,
You can’t perform that action at this time.
0 commit comments