Commit 0fe30f8

Sigrid Jin (Sionic AI) authored and committed
refactor: review
Signed-off-by: Sigrid Jin (Sionic AI) <sigrid@sionic.ai>
1 parent efa8b04 commit 0fe30f8

File tree

2 files changed: +47 -47 lines changed


vllm/model_executor/layers/pooler.py

Lines changed: 45 additions & 0 deletions
@@ -18,6 +18,10 @@
 from vllm.utils import resolve_obj_by_qualname
 from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata
 
+from vllm.triton_utils import tl, triton
+HAS_TRITON = triton is not None
+
+
 PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata]
 
 
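Note: `HAS_TRITON = triton is not None` only works if `vllm.triton_utils` hands back `triton` as `None` (rather than raising) when Triton is not installed; that behavior is assumed here, not shown in this diff. A minimal sketch of what such a guarded re-export amounts to, mirroring the try/except block this commit removes from jina_embeddings_v4.py below:

# Sketch of a guarded Triton import (an assumption about vllm.triton_utils'
# behavior; it mirrors the try/except removed from jina_embeddings_v4.py).
try:
    import triton
    import triton.language as tl
except ImportError:
    triton = None  # callers detect availability via "triton is not None"
    tl = None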

@@ -658,3 +662,44 @@ def forward(
         ])
 
         return build_output(scores)
+
+
+if HAS_TRITON:
+    @triton.jit
+    def extract_vision_tokens_kernel(
+        hidden_states_ptr,
+        token_ids_ptr,
+        output_ptr,
+        seq_start,
+        seq_len,
+        hidden_size,
+        vision_start_id: tl.constexpr,
+        vision_end_id: tl.constexpr,
+        BLOCK_SIZE: tl.constexpr,
+    ):
+        """Triton kernel to extract and pool vision tokens efficiently."""
+        pid = tl.program_id(0)
+
+        if pid >= hidden_size:
+            return
+
+        # Find vision token range
+        vision_count = 0
+        accumulator = 0.0
+
+        for i in range(seq_len):
+            token_id = tl.load(token_ids_ptr + seq_start + i)
+            if token_id >= vision_start_id and token_id <= vision_end_id:
+                hidden_val = tl.load(
+                    hidden_states_ptr + (seq_start + i) * hidden_size + pid
+                )
+                accumulator += hidden_val
+                vision_count += 1
+
+        # Store mean pooled result
+        if vision_count > 0:
+            result = accumulator / vision_count
+        else:
+            result = 0.0
+
+        tl.store(output_ptr + pid, result)
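The kernel launches one program per hidden dimension: program `pid` walks the sequence's token IDs and mean-pools column `pid` of the hidden states over every token whose ID falls in [vision_start_id, vision_end_id]. A hedged sketch of a host-side launch follows; the tensor shapes, token-ID range, and BLOCK_SIZE are illustrative and not taken from this commit:

# Hypothetical launch sketch -- shapes, the vision-token ID range, and
# BLOCK_SIZE are illustrative; the actual call site is not part of this diff.
import torch

from vllm.model_executor.layers.pooler import (HAS_TRITON,
                                               extract_vision_tokens_kernel)

seq_len, hidden_size = 32, 1024
hidden_states = torch.randn(seq_len, hidden_size, device="cuda")  # token states
token_ids = torch.randint(0, 200_000, (seq_len,), device="cuda")  # prompt IDs
pooled = torch.zeros(hidden_size, dtype=torch.float32, device="cuda")

if HAS_TRITON:
    # Grid of (hidden_size,): each program reduces one hidden dimension.
    extract_vision_tokens_kernel[(hidden_size,)](
        hidden_states, token_ids, pooled,
        0,            # seq_start: offset of this sequence in the flattened batch
        seq_len,
        hidden_size,
        vision_start_id=151652,  # placeholder ID range for vision tokens
        vision_end_id=151656,
        BLOCK_SIZE=1024,         # unused in the loop body, but a required constexpr
    )

Each program re-reads the full token-ID vector, so the kernel trades redundant loads for a very simple launch; that appears acceptable for a single sequence's worth of tokens.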

vllm/model_executor/models/jina_embeddings_v4.py

Lines changed: 2 additions & 47 deletions
@@ -9,14 +9,7 @@
 import torch.nn.functional as F
 from torch import nn
 
-try:
-    import triton
-    import triton.language as tl
-    HAS_TRITON = True
-except ImportError:
-    HAS_TRITON = False
-    triton = None
-    tl = None
+from vllm.model_executor.layers.pooler import HAS_TRITON, extract_vision_tokens_kernel
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -44,45 +37,7 @@
 
 
 # Triton kernel for optimized vision token extraction
-if HAS_TRITON:
-    @triton.jit
-    def extract_vision_tokens_kernel(
-        hidden_states_ptr,
-        token_ids_ptr,
-        output_ptr,
-        seq_start,
-        seq_len,
-        hidden_size,
-        vision_start_id: tl.constexpr,
-        vision_end_id: tl.constexpr,
-        BLOCK_SIZE: tl.constexpr,
-    ):
-        """Triton kernel to extract and pool vision tokens efficiently."""
-        pid = tl.program_id(0)
-
-        if pid >= hidden_size:
-            return
-
-        # Find vision token range
-        vision_count = 0
-        accumulator = 0.0
-
-        for i in range(seq_len):
-            token_id = tl.load(token_ids_ptr + seq_start + i)
-            if token_id >= vision_start_id and token_id <= vision_end_id:
-                hidden_val = tl.load(
-                    hidden_states_ptr + (seq_start + i) * hidden_size + pid
-                )
-                accumulator += hidden_val
-                vision_count += 1
-
-        # Store mean pooled result
-        if vision_count > 0:
-            result = accumulator / vision_count
-        else:
-            result = 0.0
-
-        tl.store(output_ptr + pid, result)
+
 
 
 @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
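For completeness, when `HAS_TRITON` is false the same reduction can be written directly in PyTorch. A hedged sketch of an equivalent fallback follows; the function name and the zero-fill behavior for sequences with no vision tokens are illustrative, since this commit does not show the model's actual non-Triton path:

# Hypothetical PyTorch equivalent of extract_vision_tokens_kernel for the
# non-Triton path; the name and zero-fill behavior are illustrative.
import torch


def pool_vision_tokens_torch(hidden_states: torch.Tensor,
                             token_ids: torch.Tensor,
                             vision_start_id: int,
                             vision_end_id: int) -> torch.Tensor:
    """Mean-pool the hidden states of tokens whose ID is in the vision range."""
    mask = (token_ids >= vision_start_id) & (token_ids <= vision_end_id)
    if mask.any():
        return hidden_states[mask].mean(dim=0)
    # The Triton kernel stores 0.0 when no vision tokens are found; match that.
    return hidden_states.new_zeros(hidden_states.shape[-1])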
