@@ -195,24 +195,25 @@ def forward(
         input_,
         is_prefill: bool = True,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        # Choose different forward function according to the type of TP group
+        if self.forward_type == "oproj_tp":
+            return self._forward_oproj_tp(input_)
+        elif self.forward_type == "mlp_tp":
+            return self._forward_mlp_tp(input_)
+        else:
+            return super().forward(input_)
 
+    # enable custom MLP tensor parallel
+    def _forward_mlp_tp(self, input_: torch.Tensor) -> torch.Tensor:
+
         if self.input_is_parallel:
             input_parallel = input_
         else:
             splitted_input = split_tensor_along_last_dim(
                 input_, num_partitions=self.tp_size)
             input_parallel = splitted_input[self.tp_rank].contiguous()
+
         assert self.quant_method is not None
-        # Choose different forward function according to the type of TP group
-        if self.forward_type == "oproj_tp":
-            return self._forward_oproj_tp(input_parallel)
-        elif self.forward_type == "mlp_tp":
-            return self._forward_mlp_tp(input_parallel)
-        else:
-            return RowParallelLinear.forward(self, input_)
-
-    #
-    def _forward_mlp_tp(self, input_parallel: torch.Tensor) -> torch.Tensor:
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
         output_parallel = self.quant_method.apply(self,
                                                   input_parallel,
@@ -224,12 +225,19 @@ def _forward_mlp_tp(self, input_parallel: torch.Tensor) -> torch.Tensor:
             return output
         return output, output_bias
 
-    # enable oproj tp forward function
+    # enable custom Oproj tensor parallel
     def _forward_oproj_tp(
         self,
-        input_parallel: torch.Tensor,
+        input_: torch.Tensor,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-
+
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[self.tp_rank].contiguous()
+
         # Prepare tensors for all-to-all communication
         local_batch_size = input_parallel.size(0)
         chunk_size = self.input_size_per_partition
@@ -254,6 +262,7 @@ def _forward_oproj_tp(
 
         # Only fuse bias add for rank 0 to avoid duplicate bias addition in TP>1
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        assert self.quant_method is not None
         output_parallel = self.quant_method.apply(self,
                                                   input_parallel,
                                                   bias=bias_)
@@ -267,30 +276,6 @@ def _forward_oproj_tp(
             return output
         return output, output_bias
 
-    # original forward function of RowParallelLinear
-    # def _forward_normal(
-    #     self,
-    #     input_parallel: torch.Tensor,
-    # ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-
-    #     # Matrix multiply with quantized method
-    #     assert self.quant_method is not None
-    #     # Only fuse bias add for rank 0 to avoid duplicate bias addition in TP>1
-    #     bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-    #     output_parallel = self.quant_method.apply(self,
-    #                                               input_parallel,
-    #                                               bias=bias_)
-
-    #     if self.reduce_results and self.tp_size > 1:
-    #         output = tensor_model_parallel_all_reduce(output_parallel)
-    #     else:
-    #         output = output_parallel
-    #     # Handle bias return based on configuration
-    #     output_bias = self.bias if self.skip_bias_add else None
-    #     if not self.return_bias:
-    #         return output
-    #     return output, output_bias
-
 
 class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
     """Packed linear layers with column parallelism.
@@ -349,8 +334,7 @@ def forward(
         if self.forward_type == "mlp_tp":
             return self._forward_mlp_tp(input_)
         else:
-            # same as origin ColumnParallelLinear forward
-            return MergedColumnParallelLinear.forward(self, input_)
+            return super().forward(input_)
 
     def _forward_mlp_tp(
         self,
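
Note on the shared pattern above: both `_forward_mlp_tp` and `_forward_oproj_tp` now begin with the same `input_is_parallel` check, so a full (unsharded) activation is sliced along its last dimension and each TP rank keeps only its own chunk. A minimal sketch of that slicing, using plain `torch.chunk` in place of vLLM's `split_tensor_along_last_dim` and made-up `tp_size`/`tp_rank` values:

    import torch

    # Hypothetical 4-way tensor parallelism; pretend this process is rank 1.
    tp_size, tp_rank = 4, 1

    # Full activation of shape [num_tokens, hidden_size].
    input_ = torch.randn(8, 1024)

    # Equivalent of split_tensor_along_last_dim(input_, num_partitions=tp_size):
    # cut the last dimension into tp_size equal chunks.
    splitted_input = torch.chunk(input_, chunks=tp_size, dim=-1)

    # Keep only this rank's shard, contiguous in memory, mirroring
    # `splitted_input[self.tp_rank].contiguous()` in the diff.
    input_parallel = splitted_input[tp_rank].contiguous()
    print(input_parallel.shape)  # torch.Size([8, 256])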
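The `bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias` lines kept in both paths encode the reason given in the diff comment: the per-rank partial outputs are later summed across the TP group, so only rank 0 may fuse the bias, otherwise the result would contain `tp_size` copies of it. A toy check of that arithmetic with made-up partial outputs (not the module's real communication path):

    import torch

    tp_size = 4
    bias = torch.tensor([1.0, 1.0])
    # Per-rank partial matmul results that an all-reduce would sum.
    partials = [torch.tensor([0.5, 0.5]) for _ in range(tp_size)]

    # Wrong: every rank fuses the bias before the (simulated) all-reduce.
    wrong = sum(p + bias for p in partials)
    # Right: only rank 0 fuses the bias, as in the code above.
    right = sum(p + (bias if rank == 0 else 0) for rank, p in enumerate(partials))

    print(wrong)  # tensor([6., 6.])  -> matmul sum 2.0 plus 4 copies of the bias
    print(right)  # tensor([3., 3.])  -> matmul sum 2.0 plus a single bias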