 from torch import nn
 from torch.nn.parameter import Parameter
 from vllm.distributed import divide, tensor_model_parallel_all_reduce
-from vllm.distributed.parallel_state import get_tp_group
-import torch.distributed as dist
+from vllm.distributed.parallel_state import get_dp_group, get_tp_group
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod,
     VocabParallelEmbedding, pad_vocab_size)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.distributed.parallel_state import get_dp_group
-from vllm.forward_context import get_forward_context
 from vllm.utils import logger
 
-from vllm_ascend.distributed.parallel_state import get_lmhead_tp_group, get_emtp_group
-from vllm_ascend.utils import lmhead_tp_enable, embedding_tp_enable
+from vllm_ascend.distributed.parallel_state import (get_emtp_group,
+                                                    get_lmhead_tp_group)
+from vllm_ascend.utils import embedding_tp_enable, lmhead_tp_enable
 
 
 class AscendVocabParallelEmbedding(VocabParallelEmbedding):
@@ -150,30 +149,41 @@ def _get_masked_input_and_mask(
         input_ = vocab_mask * (input_ - valid_offset)
         return input_, ~vocab_mask
 
-    def _get_local_batch_slice(self, tensor: torch.Tensor,
-                               batch_sizes: list,
-                               local_batch_size: int,
-                               rank: int) -> torch.Tensor:
+    def _get_local_batch_slice(self, tensor: torch.Tensor, batch_sizes: list,
+                               local_batch_size: int,
+                               rank: int) -> torch.Tensor:
         end_idx = batch_sizes[rank]
         start_idx = end_idx - local_batch_size
         return tensor[start_idx:end_idx]
-
+
     def forward(self, input_):
         if embedding_tp_enable():
-            logger.info(f"rank:{get_dp_group().rank_in_group} embedding_tp_enable")
+            logger.info(
+                f"rank:{get_dp_group().rank_in_group} embedding_tp_enable")
             return self._forward_embed_tp(input_)
         else:
             return self._forward_normal(input_)
-
+
     def _forward_embed_tp(self, input_):
-        cu_tokens_across_dp_cpu = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
-        global_dp_batch_size = torch.diff(cu_tokens_across_dp_cpu, prepend=cu_tokens_across_dp_cpu.new_zeros(1))
-        logger.info(f"debug input_: {input_.shape} \n global_dp_batch_size: {global_dp_batch_size} \n")
-        lmhead_group_batch_size = [global_dp_batch_size[x] for x in get_lmhead_tp_group().ranks]
+        cu_tokens_across_dp_cpu = get_forward_context(
+        ).dp_metadata.cu_tokens_across_dp_cpu
+        global_dp_batch_size = torch.diff(
+            cu_tokens_across_dp_cpu,
+            prepend=cu_tokens_across_dp_cpu.new_zeros(1))
+        logger.info(
+            f"debug input_: {input_.shape} \n global_dp_batch_size: {global_dp_batch_size} \n"
+        )
+        lmhead_group_batch_size = [
+            global_dp_batch_size[x] for x in get_lmhead_tp_group().ranks
+        ]
         local_batch_size = input_.size(0)
-        gathered_input = [torch.empty(batch_size, dtype=input_.dtype, device='npu') for batch_size in lmhead_group_batch_size]
-        torch.distributed.all_gather(
-            gathered_input, input_, group=get_lmhead_tp_group().device_group)
+        gathered_input = [
+            torch.empty(batch_size, dtype=input_.dtype, device='npu')
+            for batch_size in lmhead_group_batch_size
+        ]
+        torch.distributed.all_gather(gathered_input,
+                                     input_,
+                                     group=get_lmhead_tp_group().device_group)
         complete_input = torch.cat(gathered_input, dim=0)
         masked_input, input_mask = self._get_masked_input_and_mask(
             complete_input, self.shard_indices.org_vocab_start_index,
@@ -182,20 +192,18 @@ def _forward_embed_tp(self, input_):
             self.shard_indices.added_vocab_start_index,
             self.shard_indices.added_vocab_end_index)
         logger.info(f"all_gather_down complete_input: {complete_input.shape}")
-
+
         output = self.quant_method.embedding(self, masked_input.long())
         output.masked_fill_(input_mask.unsqueeze(-1), 0)
         output = tensor_model_parallel_all_reduce(output)
         # output = output[lmhead_group_batch_size[get_lmhead_tp_group().rank_in_group]-local_batch_size :lmhead_group_batch_size[get_lmhead_tp_group().rank_in_group]]
         # Extract the local batch portion from the gathered output
         lmhead_tp_group = get_lmhead_tp_group()
-        output = self._get_local_batch_slice(
-            output,
-            lmhead_group_batch_size,
-            local_batch_size,
-            lmhead_tp_group.rank_in_group
-        )
-        logger.info(f"rank:{get_dp_group().rank_in_group} output: {output.shape}")
+        output = self._get_local_batch_slice(output, lmhead_group_batch_size,
+                                             local_batch_size,
+                                             lmhead_tp_group.rank_in_group)
+        logger.info(
+            f"rank:{get_dp_group().rank_in_group} output: {output.shape}")
         return output
 
     def _forward_normal(self, input_):
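Aside (not part of the diff): the `torch.diff` call added in `_forward_embed_tp` turns the cumulative per-DP-rank token counts in `cu_tokens_across_dp_cpu` into per-rank batch sizes, which is what the pre-allocated gather buffers and the local-batch bookkeeping rely on. A minimal standalone sketch with made-up counts (the tensor values below are illustrative assumptions, not taken from vLLM):

import torch

# Hypothetical cumulative token counts across 4 DP ranks (3, 5, 2 and 4 tokens each).
cu_tokens_across_dp_cpu = torch.tensor([3, 8, 10, 14])

# Same pattern as the diff: prepend a zero, then take adjacent differences.
per_rank_batch_size = torch.diff(cu_tokens_across_dp_cpu,
                                 prepend=cu_tokens_across_dp_cpu.new_zeros(1))
print(per_rank_batch_size)  # tensor([3, 5, 2, 4])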
@@ -209,16 +217,23 @@ def _forward_normal(self, input_):
                 self.shard_indices.added_vocab_end_index)
         else:
             masked_input = input_
-        logger.info(f"rank:{get_dp_group().rank_in_group} masked_input:{masked_input.shape}")
+        logger.info(
+            f"rank:{get_dp_group().rank_in_group} masked_input:{masked_input.shape}"
+        )
         # Get the embeddings.
-        output_parallel = self.quant_method.embedding(self, masked_input.long())
-        logger.info(f"rank:{get_dp_group().rank_in_group} output_parallel:{output_parallel.shape}")
+        output_parallel = self.quant_method.embedding(self,
+                                                      masked_input.long())
+        logger.info(
+            f"rank:{get_dp_group().rank_in_group} output_parallel:{output_parallel.shape}"
+        )
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
         # Reduce across all the model parallel GPUs.
         output = tensor_model_parallel_all_reduce(output_parallel)
-        logger.info(f"rank:{get_dp_group().rank_in_group} forward_normal output:{output.shape}")
+        logger.info(
+            f"rank:{get_dp_group().rank_in_group} forward_normal output:{output.shape}"
+        )
         return output
 
 
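For context on the unchanged masking logic that `_forward_normal` inherits from vLLM's `VocabParallelEmbedding`: each TP rank owns one contiguous slice of the vocabulary, token ids outside that slice are remapped into the local table and their output rows zeroed, and the final `tensor_model_parallel_all_reduce` sums the per-rank partial results into the full embeddings. A single-process sketch of that idea with a hypothetical 2-way shard and made-up ids (the all-reduce is only indicated in a comment, since no process group is initialized here):

import torch

vocab_size, hidden = 8, 4
tp_rank, tp_size = 1, 2                       # pretend we are rank 1 of 2
shard = vocab_size // tp_size                 # 4 vocab rows per rank
start, end = tp_rank * shard, (tp_rank + 1) * shard   # this rank owns ids [4, 8)

local_table = torch.randn(shard, hidden)      # this rank's slice of the embedding table
input_ids = torch.tensor([1, 5, 7, 2])        # mix of in-shard and out-of-shard ids

mask = (input_ids < start) | (input_ids >= end)       # True where another rank owns the id
local_ids = (input_ids - start).clamp(0, shard - 1)   # map ids into the local table
out = local_table[local_ids]
out.masked_fill_(mask.unsqueeze(-1), 0)       # zero the rows this rank does not own
# In the real layer, tensor_model_parallel_all_reduce(out) would sum the per-rank
# partial outputs so every rank ends up with the complete embedding vectors.
print(out.shape)  # torch.Size([4, 4])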