@@ -18,7 +18,9 @@
 from typing import Optional, Union
 
 import torch
+import torch.nn.functional as F
 from torch.nn.parameter import Parameter
+import torch_npu
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
@@ -28,14 +30,19 @@
     ColumnParallelLinear,
     LinearBase,
     MergedColumnParallelLinear,
-    RowParallelLinear)
+    QKVParallelLinear,
+    RowParallelLinear,
+    UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.forward_context import get_forward_context
 
 from vllm_ascend.distributed.parallel_state import (
     get_mlp_tensor_model_parallel_rank,
     get_mlp_tensor_model_parallel_world_size, get_mlp_tp_group)
+from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod, quant_per_tensor
+from vllm_ascend.utils import all_gather_and_maybe_unpad, maybe_pad_and_reduce_scatter
 
 
 class AscendMlpColumnParallelLinear(ColumnParallelLinear):
@@ -307,3 +314,148 @@ def forward(
         if not self.return_bias:
             return output
         return output, output_bias
+
+
+class AscendDenseMergedColumnParallelLinear(MergedColumnParallelLinear):
+    """Linear layer with column parallelism.
+
+    Implements optimizations for dense models, such as FlashComm and
+    communication-computation fusion.
+    """
+
+    def forward(
+        self,
+        input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        forward_context = get_forward_context()
+        flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
+        ag_matmal_enabled = forward_context.ag_matmal_enabled
+        pad_size = forward_context.pad_size
+        if not flashcomm_v1_enabled:
+            output_parallel = self.quant_method.apply(self, input_, bias)
+        # Unquantized (fp16/bf16) path
+        elif ag_matmal_enabled and isinstance(self.quant_method, UnquantizedLinearMethod):
+            raise NotImplementedError("AllGather_MatMul with UnquantizedLinearMethod is not implemented yet.")
+        # W8A8 quantized path
+        elif ag_matmal_enabled and isinstance(self.quant_method.quant_method, AscendW8A8LinearMethod):
+            raise NotImplementedError("AllGather_MatMul with AscendW8A8LinearMethod is not implemented yet.")
+        else:
+            input_ = all_gather_and_maybe_unpad(input_, pad_size, 0)
+            output_parallel = self.quant_method.apply(self, input_, bias)
+
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = tensor_model_parallel_all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
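
For readers of this diff, here is a minimal, hypothetical sketch of the contract the two vllm_ascend.utils helpers used above are assumed to follow; the _sketch names, the 2-D [tokens, hidden] layout, and the use of the default process group (instead of the tensor-parallel group) are assumptions of this note, not the real implementation. Under FlashComm v1, activations between a row-parallel and the next column-parallel linear stay sharded along the token dimension, padded so the token count divides the TP world size.

# Hypothetical sketch only -- not the vllm_ascend implementation.
import torch
import torch.distributed as dist
import torch.nn.functional as F


def all_gather_and_maybe_unpad_sketch(x: torch.Tensor, pad_size: int, dim: int = 0) -> torch.Tensor:
    """Rebuild the full (padded) token dimension from per-rank shards, then drop the padding."""
    assert dim == 0, "sketch assumes sharding along the token dimension"
    world_size = dist.get_world_size()
    gathered = torch.empty((x.shape[0] * world_size, *x.shape[1:]),
                           dtype=x.dtype, device=x.device)
    dist.all_gather_into_tensor(gathered, x.contiguous())
    return gathered[:-pad_size] if pad_size > 0 else gathered


def maybe_pad_and_reduce_scatter_sketch(x: torch.Tensor, pad_size: int, dim: int = 0) -> torch.Tensor:
    """Pad the token dimension so it divides the world size, then reduce-scatter the partial sums."""
    assert dim == 0, "sketch assumes sharding along the token dimension"
    world_size = dist.get_world_size()
    if pad_size > 0:
        # Pad rows (tokens) at the end of a [tokens, hidden] tensor.
        x = F.pad(x, (0, 0, 0, pad_size))
    out = torch.empty((x.shape[0] // world_size, *x.shape[1:]),
                      dtype=x.dtype, device=x.device)
    dist.reduce_scatter_tensor(out, x.contiguous())
    return out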
+
+
+class AscendDenseQKVParallelLinear(QKVParallelLinear):
+    """Linear layer with column parallelism.
+
+    Implements optimizations for dense models, such as FlashComm and
+    communication-computation fusion.
+    """
+
+    def forward(
+        self,
+        input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        forward_context = get_forward_context()
+        # FlashComm v1 is skipped for the first decoder layer, presumably because
+        # its input has not been reduce-scattered by a preceding row-parallel
+        # layer yet (see the prefix note below this class).
+        layer_num = self.prefix.split('.')[2]
+        if layer_num == '0':
+            flashcomm_v1_enabled = False
+        else:
+            flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
+        ag_matmal_enabled = forward_context.ag_matmal_enabled
+        pad_size = forward_context.pad_size
+        if not flashcomm_v1_enabled:
+            output_parallel = self.quant_method.apply(self, input_, bias)
+        # Unquantized (fp16/bf16) path
+        elif ag_matmal_enabled and isinstance(self.quant_method, UnquantizedLinearMethod):
+            raise NotImplementedError("AllGather_MatMul with UnquantizedLinearMethod is not implemented yet.")
+        # W8A8 quantized path
+        elif ag_matmal_enabled and isinstance(self.quant_method.quant_method, AscendW8A8LinearMethod):
+            raise NotImplementedError("AllGather_MatMul with AscendW8A8LinearMethod is not implemented yet.")
+        else:
+            input_ = all_gather_and_maybe_unpad(input_, pad_size, 0)
+            output_parallel = self.quant_method.apply(self, input_, bias)
+
+        if self.gather_output:
+            # All-gather across the partitions.
+            output = tensor_model_parallel_all_gather(output_parallel)
+        else:
+            output = output_parallel
+        output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
+        return output, output_bias
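
The layer-index check in AscendDenseQKVParallelLinear reads element 2 of the dotted module prefix that vLLM assigns when the model is built. A hedged example of the assumed convention follows; the exact path is hypothetical and depends on the model implementation:

# Hypothetical module prefix for the QKV projection of the first decoder layer.
prefix = "model.layers.0.self_attn.qkv_proj"
assert prefix.split('.')[2] == '0'  # layer 0 keeps the default, non-FlashComm path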
+
+
+class AscendDenseRowParallelLinear(RowParallelLinear):
+    """Linear layer with row parallelism.
+
+    Implements optimizations for dense models, such as FlashComm and
+    communication-computation fusion.
+    """
+
+    def forward(
+        self,
+        input_: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        tp_rank = get_tensor_model_parallel_rank()
+        forward_context = get_forward_context()
+        flashcomm_v1_enabled = forward_context.flashcomm_v1_enabled
+        matmul_rs_enabled = forward_context.matmul_rs_enabled
+        pad_size = forward_context.pad_size
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[tp_rank].contiguous()
+
+        # Matrix multiply.
+        assert self.quant_method is not None
+        # Only fuse bias add into GEMM for rank 0 (this ensures that
+        # bias will not get added more than once in TP>1 case)
+        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        if self.tp_size == 1 or not self.reduce_results:
+            output = self.quant_method.apply(self,
+                                             input_parallel,
+                                             bias=bias_)
+        elif not flashcomm_v1_enabled:
+            output_parallel = self.quant_method.apply(self,
+                                                      input_parallel,
+                                                      bias=bias_)
+            output = tensor_model_parallel_all_reduce(output_parallel)
+        # Unquantized (fp16/bf16) path
+        elif matmul_rs_enabled and isinstance(self.quant_method, UnquantizedLinearMethod):
+            raise NotImplementedError("Matmul_ReduceScatter with UnquantizedLinearMethod is not implemented yet.")
+        # W8A8 quantized path
+        elif matmul_rs_enabled and isinstance(self.quant_method.quant_method, AscendW8A8LinearMethod):
+            raise NotImplementedError("Matmul_ReduceScatter with AscendW8A8LinearMethod is not implemented yet.")
+        else:
+            output_parallel = self.quant_method.apply(self,
+                                                      input_parallel,
+                                                      bias=bias_)
+            output = maybe_pad_and_reduce_scatter(output_parallel, pad_size, 0)
+
+        output_bias = self.bias if self.skip_bias_add else None
+
+        if not self.return_bias:
+            return output
+        return output, output_bias
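
Taken together, these classes replace the row-parallel all-reduce with a reduce-scatter right after the row-parallel matmul and an all-gather right before the next column-parallel matmul, so the elementwise work in between runs on a 1/tp_size token shard. The rewrite is numerically equivalent because all_reduce(x) equals all_gather(reduce_scatter(x)) for a sum reduction. A single-process sketch of that identity; tp_size and the random tensors are illustrative only, with a plain Python list standing in for the TP ranks:

# Single-process sketch of the collective identity FlashComm v1 relies on:
# all_reduce(x) == all_gather(reduce_scatter(x)) for a sum reduction.
import torch

tp_size = 4
# Partial results that a row-parallel linear would produce on each rank.
per_rank_partial = [torch.randn(8, 16) for _ in range(tp_size)]

# Baseline path: all-reduce leaves every rank holding the full summed tensor.
all_reduced = torch.stack(per_rank_partial).sum(dim=0)

# FlashComm v1 path: reduce-scatter leaves each rank a 1/tp_size slice of the sum ...
scattered = list(all_reduced.chunk(tp_size, dim=0))
# ... and the next column-parallel layer all-gathers the slices back before its matmul.
regathered = torch.cat(scattered, dim=0)

torch.testing.assert_close(regathered, all_reduced)

Because the reduce-scatter needs the token count to divide the TP world size, the forward context also carries pad_size, which the helpers use to pad before scattering and to trim after gathering.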