
Commit 541645c

refactor linear
Signed-off-by: realliujiaxu <realliujiaxu@163.com>
1 parent 93e28e6 commit 541645c

vllm_ascend/ops/linear.py

1 file changed: 19 additions & 27 deletions
@@ -19,7 +19,6 @@
 
 import torch
 import torch.distributed as dist
-import torch.nn as nn
 import torch_npu
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
@@ -63,14 +62,11 @@ def __init__(
         *,
         return_bias: bool = True,
     ):
-        self.comm_group = None
-        if prefix.find("gate_up_proj") != -1 and mlp_tp_enable():
-            self.comm_group = get_mlp_tp_group()
-        else:
+        # if self has attr `tp_size`, this means it has been customized by subclass
+        if not hasattr(self, "tp_size"):
             self.comm_group = get_tp_group()
-
-        self.tp_size = self.comm_group.world_size
-        self.tp_rank = self.comm_group.rank_in_group
+            self.tp_size = self.comm_group.world_size
+            self.tp_rank = self.comm_group.rank_in_group
 
         self.input_size_per_partition = input_size
         self.output_size_per_partition = divide(output_size, self.tp_size)
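
With this guard in place, a subclass that wants a non-default communication group only has to assign self.comm_group, self.tp_size and self.tp_rank before delegating to this __init__; hasattr(self, "tp_size") is then true and the default get_tp_group() branch is skipped. A minimal sketch of that pattern, reusing get_mlp_tp_group() from the removed branch; the subclass name and wiring below are illustrative, not part of this commit:

# Hypothetical subclass (illustration only): shard gate_up_proj over the MLP TP
# group instead of the default TP group. AscendColumnParallelLinear and
# get_mlp_tp_group come from the surrounding vllm_ascend code.
class MlpTPColumnParallelLinear(AscendColumnParallelLinear):

    def __init__(self, input_size: int, output_size: int, **kwargs):
        # Customize the comm group first, so the base __init__ leaves it alone.
        self.comm_group = get_mlp_tp_group()
        self.tp_size = self.comm_group.world_size
        self.tp_rank = self.comm_group.rank_in_group
        super().__init__(input_size, output_size, **kwargs)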
@@ -81,6 +77,8 @@ def __init__(
             divide(output_size, self.tp_size)
             for output_size in self.output_sizes
         ]
+        # skip ColumnParallelLinear.__init__, as it will create weight_loader with default tp group
+        # we will create weight_loader by customized comm group
         AscendLinearBase.__init__(self,
                                   input_size,
                                   output_size,
@@ -164,6 +162,7 @@ def __init__(
         self.output_size_per_partition = output_size
         self.output_partition_sizes = [output_size]
 
+        # skip RowParallelLinear.__init__, as it will create weight_loader with default tp group
         AscendLinearBase.__init__(self,
                                   input_size,
                                   output_size,
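
Both comments describe the same motivation: ColumnParallelLinear.__init__ and RowParallelLinear.__init__ in upstream vLLM register a weight_loader tied to the default TP group, so AscendLinearBase.__init__ is called directly and the loader is built against the layer's own comm group. As a rough, hedged illustration of what such a loader does (the function name and signature are hypothetical, not from this commit), a column-parallel loader slices the checkpoint tensor along the output dimension with the layer's own rank and world size:

import torch
from torch.nn.parameter import Parameter

# Illustration only: shard a full checkpoint weight for a column-parallel layer
# using the layer's own tp_rank/tp_size rather than the global TP group's.
def custom_column_weight_loader(param: Parameter, loaded_weight: torch.Tensor,
                                tp_rank: int, tp_size: int) -> None:
    output_dim = 0  # column parallelism splits the output dimension
    shard_size = loaded_weight.shape[output_dim] // tp_size
    shard = loaded_weight.narrow(output_dim, tp_rank * shard_size, shard_size)
    param.data.copy_(shard)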
@@ -526,7 +525,7 @@ def __init__(
         self.output_sizes = [
             self.num_heads * self.head_size * tp_size,  # q_proj
             self.num_kv_heads * self.head_size * tp_size,  # k_proj
-            self.num_kv_heads * self.head_size * tp_size,  # v_proj
+            self.num_kv_heads * self.head_size * tp_size,  # v_proj
         ]
         AscendColumnParallelLinear.__init__(self,
                                             input_size=input_size,
@@ -593,22 +592,15 @@ def __init__(
         return_bias: bool = True,
         disable_tp: bool = False,
     ):
-        nn.Module.__init__(self)
-
-        # Keep input parameters
-        self.input_size = input_size
-        self.output_size = output_size
-        self.skip_bias_add = skip_bias_add
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
-        self.quant_config = quant_config
-        self.prefix = prefix
-        if quant_config is None:
-            self.quant_method: Optional[
-                QuantizeMethodBase] = UnquantizedLinearMethod()
+        if hasattr(self, "tp_rank") and hasattr(self, "tp_size"):
+            tp_rank = self.tp_rank
+            tp_size = self.tp_size
+            super().__init__(input_size, output_size, skip_bias_add,
+                             params_dtype, quant_config, prefix, return_bias,
+                             disable_tp)
+            self.tp_rank = tp_rank
+            self.tp_size = tp_size
         else:
-            self.quant_method = quant_config.get_quant_method(self,
-                                                              prefix=prefix)
-        self.return_bias = return_bias
-        self.disable_tp = disable_tp
+            super().__init__(input_size, output_size, skip_bias_add,
+                             params_dtype, quant_config, prefix, return_bias,
+                             disable_tp)
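
The save-and-restore around super().__init__ is needed because the upstream LinearBase.__init__ derives tp_rank and tp_size from the default TP group and would overwrite values a subclass installed earlier. Stripped of the vLLM specifics, the idiom reduces to the following sketch (Base and Derived are placeholder names, not classes from this commit):

class Base:
    def __init__(self) -> None:
        # Stands in for vLLM's LinearBase.__init__, which resets these
        # attributes from the default TP group.
        self.tp_rank = 0
        self.tp_size = 1

class Derived(Base):
    def __init__(self, tp_rank: int, tp_size: int) -> None:
        self.tp_rank = tp_rank                  # customized before delegating
        self.tp_size = tp_size
        saved = (self.tp_rank, self.tp_size)    # stash the customized values
        super().__init__()                      # parent overwrites them
        self.tp_rank, self.tp_size = saved      # restore the customization

Derived(tp_rank=3, tp_size=8) therefore ends up with tp_rank == 3 and tp_size == 8, which is exactly the behavior the refactored linear layers rely on.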
