 
 import torch
 import torch.distributed as dist
-import torch.nn as nn
 import torch_npu
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
@@ -63,14 +62,11 @@ def __init__(
         *,
         return_bias: bool = True,
     ):
-        self.comm_group = None
-        if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
-            self.comm_group = get_mlp_tp_group()
-        else:
+        # if self already has attr `tp_size`, it has been customized by a subclass
+        if not hasattr(self, "tp_size"):
             self.comm_group = get_tp_group()
-
-        self.tp_size = self.comm_group.world_size
-        self.tp_rank = self.comm_group.rank_in_group
+            self.tp_size = self.comm_group.world_size
+            self.tp_rank = self.comm_group.rank_in_group
 
         self.input_size_per_partition = input_size
         self.output_size_per_partition = divide(output_size, self.tp_size)
@@ -81,6 +77,8 @@ def __init__(
             divide(output_size, self.tp_size)
             for output_size in self.output_sizes
         ]
+        # skip ColumnParallelLinear.__init__, as it would create the weight_loader with the default tp group;
+        # create the weight_loader with the customized comm group instead
         AscendLinearBase.__init__(self,
                                   input_size,
                                   output_size,
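The two comments above describe the pattern this change enables: a subclass installs its own communication group before delegating to the shared constructor, and the hasattr check keeps the base class from overwriting it with the default tp group. Below is a minimal sketch of such a subclass, not part of this diff: the class name is hypothetical, and get_mlp_tp_group() is the helper referenced in the removed branch of the first hunk.

# Sketch only (hypothetical subclass, assumes the same module's imports).
class MlpColumnParallelLinear(AscendColumnParallelLinear):

    def __init__(self, *args, **kwargs):
        # Set these before the parent __init__ runs, so its
        # `if not hasattr(self, "tp_size")` branch is skipped and the
        # default tp group never overrides the customized one.
        self.comm_group = get_mlp_tp_group()
        self.tp_size = self.comm_group.world_size
        self.tp_rank = self.comm_group.rank_in_group
        super().__init__(*args, **kwargs)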
@@ -164,6 +162,7 @@ def __init__(
         self.output_size_per_partition = output_size
         self.output_partition_sizes = [output_size]
 
+        # skip RowParallelLinear.__init__, as it would create the weight_loader with the default tp group
         AscendLinearBase.__init__(self,
                                   input_size,
                                   output_size,
@@ -526,7 +525,7 @@ def __init__(
         self.output_sizes = [
             self.num_heads * self.head_size * tp_size,  # q_proj
             self.num_kv_heads * self.head_size * tp_size,  # k_proj
-            self.num_kv_heads * self.head_size * tp_size,  # v_proj
+            self.num_kv_heads * self.head_size * tp_size,  # v_proj
         ]
         AscendColumnParallelLinear.__init__(self,
                                             input_size=input_size,
@@ -593,22 +592,15 @@ def __init__(
         return_bias: bool = True,
         disable_tp: bool = False,
     ):
-        nn.Module.__init__(self)
-
-        # Keep input parameters
-        self.input_size = input_size
-        self.output_size = output_size
-        self.skip_bias_add = skip_bias_add
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
-        self.quant_config = quant_config
-        self.prefix = prefix
-        if quant_config is None:
-            self.quant_method: Optional[
-                QuantizeMethodBase] = UnquantizedLinearMethod()
+        if hasattr(self, "tp_rank") and hasattr(self, "tp_size"):
+            tp_rank = self.tp_rank
+            tp_size = self.tp_size
+            super().__init__(input_size, output_size, skip_bias_add,
+                             params_dtype, quant_config, prefix, return_bias,
+                             disable_tp)
+            self.tp_rank = tp_rank
+            self.tp_size = tp_size
         else:
-            self.quant_method = quant_config.get_quant_method(self,
                                                               prefix=prefix)
-        self.return_bias = return_bias
-        self.disable_tp = disable_tp
+            super().__init__(input_size, output_size, skip_bias_add,
+                             params_dtype, quant_config, prefix, return_bias,
+                             disable_tp)
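The stash-and-restore around super().__init__() in the last hunk exists because the parent constructor evidently derives tp_rank and tp_size itself from the default tp group, which would clobber values a subclass installed earlier; capturing them first and writing them back afterwards preserves the customization. A rough usage sketch, reusing the hypothetical subclass from the earlier example with arbitrary example sizes:

# Sketch only: after construction, the layer still reports the customized
# group's topology rather than the default tp group's.
layer = MlpColumnParallelLinear(input_size=1024,
                                output_size=4096,
                                prefix="mlp.gate_up_proj")
assert layer.tp_size == get_mlp_tp_group().world_size
assert layer.tp_rank == get_mlp_tp_group().rank_in_group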