@@ -28,6 +28,7 @@
 
 from vllm_ascend import envs
 from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
+from vllm_ascend.utils import npu_stream_switch
 
 
 def pad(tensor, x):
@@ -130,6 +131,7 @@ def __init__(self,
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
         self.enable_fc = envs.VLLM_ASCEND_ENABLE_FLASHCOMM
+        self.alt_stream = torch.npu.Stream()
         if self.enable_fc == 2:
             self.o_proj = ReplicatedLinear(
                 self.total_num_heads * self.head_dim,
@@ -156,15 +158,17 @@ def forward(
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        # Add qk-norm
-        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
-                           self.head_dim)
-        q_by_head = self.q_norm(q_by_head)
-        q = q_by_head.view(q.shape)
+        with npu_stream_switch(self.alt_stream):
+            # Add qk-norm
+            q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
+                               self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            q = q_by_head.view(q.shape)
         k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                            self.head_dim)
         k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
+        torch.npu.current_stream().wait_stream(self.alt_stream)
         q, k = self.rotary_emb(positions,
                                q,
                                k,
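For context, the change above overlaps the q-norm with the k-norm by issuing the q-norm on a secondary NPU stream and joining back before the rotary embedding. Below is a minimal sketch of the same overlap-and-join pattern written against the CUDA stream API, which mirrors the Ascend calls used in the diff (`torch.npu.Stream`, `npu_stream_switch`, `wait_stream`); the tensor shapes, the `rms_norm` helper, and the explicit pre-`with` wait are illustrative assumptions, not code from this model.

```python
# Sketch only: CUDA-stream analogue of the dual-stream qk-norm overlap above.
# Shapes and the rms_norm helper are made up for illustration.
import torch


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMSNorm without learned weights, just for the sketch.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


num_heads, head_dim = 8, 128
q = torch.randn(4, num_heads * head_dim, device="cuda")
k = torch.randn(4, num_heads * head_dim, device="cuda")

alt_stream = torch.cuda.Stream()  # created once, like self.alt_stream in __init__

# Let the side stream see q as produced on the default stream, then issue
# q-norm there (the diff relies on npu_stream_switch for the stream switch).
alt_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(alt_stream):
    q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
    q = rms_norm(q_by_head).view(q.shape)

# Meanwhile, k-norm runs on the default stream and can overlap with q-norm.
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
k = rms_norm(k_by_head).view(k.shape)

# Join: the default stream must not run later ops (the rotary embedding in the
# diff) until the side-stream q-norm has finished.
torch.cuda.current_stream().wait_stream(alt_stream)
```

The final `wait_stream` in this sketch plays the same role as the added `torch.npu.current_stream().wait_stream(self.alt_stream)` line before `self.rotary_emb`: it is the synchronization point that makes the overlap safe.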