
Commit af2a886

realliujiaxu and weijinqian0 authored
refactor linear (#2867)
### What this PR does / why we need it?

The current linear.py has the following issues:

- There is redundant conditional logic in the `comm_group` and `forward` selection for classes such as `AscendMergedColumnParallelLinear`.
- The `comm_group` selection logic is inconsistent across `AscendMergedColumnParallelLinear`, `AscendColumnParallelLinear`, and `AscendQKVParallelLinear`.

To address these two issues, this PR encapsulates `comm_group` and `forward` into op classes and extracts the class-selection logic into common functions. To add a custom communication group or forward method in the future, it is only necessary to extend `CustomColumnParallelOp` or `CustomRowParallelOp` and add the new selection logic.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

- vLLM version: v0.10.2
- vLLM main: vllm-project/vllm@dd39baf

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
Co-authored-by: weijinqian0 <weijinqian@huawei.com>
1 parent a7f8ed3 commit af2a886
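
To make the extension pattern described in the PR text concrete, here is a minimal, self-contained sketch, not the actual vllm-ascend code: it assumes a base op class that owns `comm_group` and the forward path, plus a common selection function keyed on the layer prefix. Only `CustomColumnParallelOp` is a name taken from the description; `MlpColumnParallelOp`, `get_column_parallel_op`, and the selection rule are hypothetical illustrations.

```python
# Minimal sketch of the extension pattern described in the PR text; it is NOT
# the vllm-ascend implementation. Only CustomColumnParallelOp is taken from the
# description; everything else here is a hypothetical stand-in.


class CustomColumnParallelOp:
    """Simplified stand-in for the base op class that owns comm_group/forward."""

    def __init__(self, layer):
        self.layer = layer

    @property
    def comm_group(self):
        raise NotImplementedError

    def apply(self, input_):
        raise NotImplementedError


class MlpColumnParallelOp(CustomColumnParallelOp):
    """Example custom op that routes a layer through an MLP-specific TP group."""

    @property
    def comm_group(self):
        # A real op would return the MLP tensor-parallel group object;
        # a placeholder string keeps this sketch dependency-free.
        return "mlp_tp_group"

    def apply(self, input_):
        # A real op would call the layer's quant method and the group's
        # collective communication; a plain matrix multiply stands in here.
        return input_ @ self.layer["weight"]


def get_column_parallel_op(prefix, layer):
    """Common selection logic: choose a custom op from the layer prefix."""
    if prefix.endswith("gate_up_proj"):  # hypothetical selection rule
        return MlpColumnParallelOp(layer)
    return None  # None means: fall back to the default parallel linear forward
```

With selection confined to one function like this, adding a new communication group means writing one more op subclass and one more branch, instead of editing the conditional logic inside every `Ascend*ParallelLinear` class, which is the redundancy the PR removes.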

File tree

4 files changed: +688 −503 lines changed

tests/ut/models/test_qwen2_5_vl.py

Lines changed: 1 addition & 1 deletion

@@ -295,7 +295,7 @@ def init_vision_transformer(
         mock_group.rank_in_group = 0
         mock_group.world_size = 2
         mocker.patch(
-            "vllm_ascend.ops.linear.get_tp_group",
+            "vllm_ascend.ops.linear_op.get_tp_group",
             return_value=mock_group,
         )
 
tests/ut/ops/test_linear.py

Lines changed: 7 additions & 20 deletions

@@ -7,8 +7,7 @@
 
 from vllm_ascend import ascend_config
 from vllm_ascend.distributed import parallel_state
-from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
-                                    AscendMergedColumnParallelLinear,
+from vllm_ascend.ops.linear import (AscendMergedColumnParallelLinear,
                                     AscendRowParallelLinear)
 
 
@@ -32,7 +31,7 @@ def setUp(self):
                   return_value=self.mock_group),
            patch("vllm_ascend.distributed.parallel_state.get_mlp_tp_group",
                  return_value=self.mock_group),
-           patch("vllm_ascend.ops.linear.get_tp_group",
+           patch("vllm_ascend.ops.linear_op.get_tp_group",
                  return_value=self.mock_group),
            patch("vllm_ascend.utils.mlp_tp_enable", return_value=True),
            patch("vllm_ascend.utils.oproj_tp_enable", return_value=True)
@@ -56,8 +55,7 @@ def test_mlp_optimize(self):
             output_size=8,
             prefix="down_proj",
         )
-        self.assertEqual(linear.comm_group, parallel_state._MLP_TP)
-        self.assertEqual(linear.forward_type, "mlp_tp")
+        self.assertEqual(linear.custom_op.comm_group, parallel_state._MLP_TP)
 
         input_tensor = torch.randn(16, 8)
         linear(input_tensor)
@@ -71,34 +69,23 @@ def test_oproj_tp(self):
             output_size=8,
             prefix="o_proj",
         )
-        self.assertEqual(linear.comm_group, parallel_state._OTP)
-        self.assertEqual(linear.forward_type, "oproj_tp")
+        self.assertEqual(linear.custom_op.comm_group, parallel_state._OTP)
 
         input_tensor = torch.randn(16, 8)
         linear(input_tensor)
 
 
-class TestAscendColumnParallelLinear(BaseLinearTest):
-
-    def test_mlp_tp_init(self):
-        linear = AscendColumnParallelLinear(
-            input_size=16,
-            output_size=8,
-            prefix="down_proj",
-        )
-        self.assertEqual(linear.comm_group, parallel_state._MLP_TP)
-
-
 class TestAscendMergedColumnParallelLinear(BaseLinearTest):
 
     def test_merged_mlp_tp_init(self):
+        os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
+
         linear = AscendMergedColumnParallelLinear(
             input_size=16,
             output_sizes=[8, 8],
             prefix="gate_up_proj",
         )
-        self.assertEqual(linear.comm_group, parallel_state._MLP_TP)
-        self.assertEqual(linear.forward_type, "mlp_tp")
+        self.assertEqual(linear.custom_op.comm_group, parallel_state._MLP_TP)
 
 
 if __name__ == '__main__':
