Commit 6dca835

Refactor scattered w8a8 dynamic quantization operations
AscendW8A8DynamicLinearMethod is integrated into CustomDeepseekV2MLP in a very awkward way, causing quantization operations to be scattered all over the model scripts. Refactor to solve this problem.

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
1 parent 5cd5d64 commit 6dca835

File tree

2 files changed (+73 −75 lines):

vllm_ascend/models/deepseek_v2.py
vllm_ascend/quantization/w8a8_dynamic.py
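
At a glance, the refactor makes the MLP forward path quantization-agnostic: the quantized linear method hands back a (quantized values, per-token scale) pair, the activation recognizes that pair, and the next linear consumes it, so the model code never calls torch_npu quantization ops itself. Below is a minimal, framework-free sketch of that convention; ToyQuantLinear, ToySiluAndMul, and the scalar arithmetic are illustrative stand-ins, not code from this patch.

import math
from typing import List, Tuple, Union

QuantPair = Tuple[List[int], float]  # (int8-like values, per-token scale)


def toy_dynamic_quant(x: List[float]) -> QuantPair:
    # Symmetric per-token quantization, standing in for npu_dynamic_quant.
    scale = max(abs(v) for v in x) / 127.0 or 1.0
    return [round(v / scale) for v in x], scale


class ToyQuantLinear:
    # Stands in for a linear layer whose quant method follows the new
    # `_dynamic_quant_config` contract ("return_scale" controls the output form).
    def __init__(self, return_scale: bool):
        self._dynamic_quant_config = {"return_scale": return_scale}

    def __call__(self, x: Union[List[float], QuantPair]):
        pair = x if isinstance(x, tuple) else toy_dynamic_quant(x)
        values, scale = pair
        dequantized = [v * scale for v in values]  # the matmul itself is elided
        out = pair if self._dynamic_quant_config["return_scale"] else dequantized
        return out, None  # second element mimics the linear layer's bias output


class ToySiluAndMul:
    # Stands in for CustomDeepseekV2SiluAndMul: dequantize if given a pair,
    # apply silu(gate) * up, then hand a fresh pair to the next layer.
    def __call__(self, x: Union[List[float], QuantPair]):
        if isinstance(x, tuple):
            values, scale = x
            x = [v * scale for v in values]
        half = len(x) // 2
        act = [g / (1.0 + math.exp(-g)) * u for g, u in zip(x[:half], x[half:])]
        return toy_dynamic_quant(act)


gate_up, act_fn, down = ToyQuantLinear(True), ToySiluAndMul(), ToyQuantLinear(False)
h, _ = gate_up([0.5, -1.0, 2.0, 0.25])  # -> (int values, scale)
h = act_fn(h)                           # -> (int values, scale)
h, _ = down(h)                          # -> plain floats again
print(h)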

vllm_ascend/models/deepseek_v2.py

Lines changed: 48 additions & 36 deletions
@@ -25,7 +25,7 @@
 # # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
 # """Inference-only DeepseekV2/DeepseekV3 model."""
 
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -75,6 +75,29 @@
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 
 
+class CustomDeepseekV2SiluAndMul(SiluAndMul):
+
+    def __init__(self, *, weight_scale: Optional[torch.Tensor] = None):
+        super().__init__()
+        self.weight_scale = weight_scale
+
+    def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
+                                                       torch.Tensor]]):
+        if isinstance(x, tuple):
+            assert self.weight_scale is not None
+            # For AscendW8A8DynamicLinearMethod:
+            # a dynamic scale is passed along with the quantized value.
+            quantized_x, dynamic_scale = x
+            return torch_npu.npu_dequant_swiglu_quant(
+                x=quantized_x,
+                weight_scale=self.weight_scale,
+                activation_scale=dynamic_scale,
+                activate_left=True,
+                quant_mode=1)
+        else:
+            return super().forward_oot(x)
+
+
 class CustomDeepseekV2MLP(nn.Module):
 
     def __init__(
@@ -101,44 +124,33 @@ def __init__(
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
 
-        # NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant
-        self.is_dynamic_quant = not isinstance(
-            self.gate_up_proj.quant_method,
-            UnquantizedLinearMethod) and isinstance(
-                self.gate_up_proj.quant_method.quant_method,
-                AscendW8A8DynamicLinearMethod)
+        quant_method = self.gate_up_proj.quant_method
+        if isinstance(quant_method, UnquantizedLinearMethod):
+            self.act_fn = CustomDeepseekV2SiluAndMul()
+        elif isinstance(quant_method, AscendW8A8DynamicLinearMethod):
+            # TODO(sdmyzlp): Currently preserved as before:
+            # 1. The only quantization supported for silu is W8A8Dynamic
+            # 2. Output dtype of gate_up/down is fixed to be int32/bfloat16
+            #
+            # Maybe one can implement a better and more general configuration
+            # scheme, e.g. by somehow passing around the tweaked `quant_config`
+            self.act_fn = CustomDeepseekV2SiluAndMul(
+                weight_scale=self.gate_up_proj.weight_scale_fp32)
+            # To be consumed by AscendW8A8DynamicLinearMethod.apply()
+            self.gate_up_proj._dynamic_quant_config = {
+                "output_dtype": torch.int32,
+                "return_scale": True,
+            }
+            self.down_proj._dynamic_quant_config = {
+                "output_dtype": torch.bfloat16,
+                "return_scale": False,
+            }
+        else:
+            raise NotImplementedError(
+                f"Quantization with [{type(quant_method)}] is NOT supported")
 
     def forward(self, x):
-        if self.is_dynamic_quant:
-            x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.gate_up_proj.weight,
-                self.gate_up_proj.weight_scale,
-                output_dtype=torch.int32,
-            )
-            x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-                x=x,
-                weight_scale=self.gate_up_proj.weight_scale_fp32,
-                activation_scale=dynamic_scale,
-                bias=None,
-                quant_scale=None,
-                quant_offset=None,
-                group_index=None,
-                activate_left=True,
-                quant_mode=1)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.down_proj.weight,
-                self.down_proj.weight_scale,
-                pertoken_scale=dynamic_scale,
-                output_dtype=torch.bfloat16,
-            )
-            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
-                x = tensor_model_parallel_all_reduce(x)
-            return x
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
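
For readers unfamiliar with the fused kernel, the tuple branch of CustomDeepseekV2SiluAndMul above delegates dequantization, swiglu, and re-quantization to torch_npu.npu_dequant_swiglu_quant in a single call. The unfused fp32 reference below is only an assumption of what that kernel computes (per-channel weight scale times per-token activation scale for dequantization, silu applied to the left half for activate_left=True, dynamic per-token re-quantization for quant_mode=1); rounding and saturation details of the real NPU op may differ, so this is a conceptual sketch rather than a drop-in replacement.

import torch
import torch.nn.functional as F


def dequant_swiglu_quant_ref(quantized_x: torch.Tensor,      # int32 gate_up output
                             weight_scale: torch.Tensor,     # fp32, per output channel
                             activation_scale: torch.Tensor  # fp32, per token
                             ):
    # Dequantize the int32 matmul result back to fp32.
    x = quantized_x.float() * weight_scale * activation_scale.unsqueeze(-1)
    # activate_left=True: silu on the left (gate) half, multiply by the right half.
    gate, up = x.chunk(2, dim=-1)
    act = F.silu(gate) * up
    # quant_mode=1: dynamically re-quantize per token for the following int8 matmul.
    scale = act.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    return torch.round(act / scale).to(torch.int8), scale.squeeze(-1)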

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 25 additions & 39 deletions
@@ -15,13 +15,13 @@
 # limitations under the License.
 #
 
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
 import torch_npu
 import torchair as tng  # type: ignore
-from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce
+from vllm.distributed import GroupCoordinator
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
@@ -77,19 +77,9 @@ def apply_mlp(hidden_states: torch.Tensor,
     shared_experts = kwargs.get('shared_experts', None)
     if shared_experts:
         shared_gate_up = kwargs.get('shared_gate_up', None)
-        shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
         with tng.scope.npu_stream_switch('cv'):
-            tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
-            shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-                x=shared_gate_up,
-                weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
-                activation_scale=shared_dynamic_scale,
-                bias=None,
-                quant_scale=None,
-                quant_offset=None,
-                group_index=None,
-                activate_left=True,
-                quant_mode=1)
+            tng.scope.npu_wait_tensor(shared_gate_up[0], hidden_states)
+            shared_act = shared_experts.act_fn(shared_gate_up)
 
     # gmm1: gate_up_proj
     hidden_states = torch_npu.npu_grouped_matmul(
@@ -122,16 +112,9 @@ def apply_mlp(hidden_states: torch.Tensor,
 
     if shared_experts:
         with tng.scope.npu_stream_switch('cv'):
-            tng.scope.npu_wait_tensor(shared_x, hidden_states)
-            shared_output = torch_npu.npu_quant_matmul(
-                shared_x,
-                shared_experts.down_proj.weight,
-                shared_experts.down_proj.weight_scale,
-                pertoken_scale=shared_dynamic_scale,
-                output_dtype=torch.bfloat16,
-            )
-            if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
-                shared_output = tensor_model_parallel_all_reduce(shared_output)
+            tng.scope.npu_wait_tensor(shared_act[0], hidden_states)
+            shared_output, _ = shared_experts.down_proj(shared_act)
+
     if shared_experts:
         return hidden_states, shared_output
     return hidden_states
@@ -193,17 +176,10 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
         shared_hidden_states = kwargs.get('shared_hidden_states', None)
         with tng.scope.npu_stream_switch('cv'):
             tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states)
-            shared_x, shared_dynamic_scale = torch_npu.npu_dynamic_quant(
+            shared_gate_up, _ = shared_experts.gate_up_proj(
                 shared_hidden_states)
-            shared_gate_up = torch_npu.npu_quant_matmul(
-                shared_x,
-                shared_experts.gate_up_proj.weight,
-                shared_experts.gate_up_proj.weight_scale,
-                output_dtype=torch.int32,
-            )
         kwargs.update({
             "shared_gate_up": shared_gate_up,
-            "shared_dynamic_scale": shared_dynamic_scale,
         })
 
     output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
@@ -541,21 +517,31 @@ def get_perchannel_param(
     @staticmethod
     def apply(
         layer: torch.nn.Module,
-        x: torch.Tensor,
+        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
         bias: Optional[torch.Tensor] = None,
         tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
-        original_dtype = x.dtype
-        # use ATB quantize
-        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-        return torch_npu.npu_quant_matmul(
-            quant_out,
+        config = getattr(layer, "_dynamic_quant_config", {})
+        if not isinstance(x, tuple):
+            output_dtype = config.get("output_dtype", x.dtype)
+            quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
+        else:
+            assert "output_dtype" in config.keys(), (
+                f"DynamicLinearMethod needs explicitly specified `output_dtype`"
+                f"for pre-quantized input, got config [{config}]")
+            output_dtype = config["output_dtype"]
+            quantized_x, dynamic_scale = x
+
+        output = torch_npu.npu_quant_matmul(
+            quantized_x,
             layer.weight,
             layer.weight_scale,
            pertoken_scale=dynamic_scale,
             bias=bias,
-            output_dtype=original_dtype,
+            output_dtype=output_dtype,
        )
+        return ((output, dynamic_scale)
+                if config.get("return_scale", False) else output)
 
     def process_weights_after_loading(self, layer):
         if self.transpose_weight:
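
To make the new contract concrete, here is a small hypothetical sketch of how `_dynamic_quant_config` drives the refactored `apply()`; ToyLayer and describe_apply are illustrative inventions, only the attribute name and dictionary keys come from the diff above. A layer feeding the fused activation asks for int32 output plus the scale, while a layer receiving a pre-quantized pair must state its output dtype explicitly and gets a plain tensor back.

import torch


class ToyLayer(torch.nn.Module):
    # Hypothetical stand-in for the linear layers that carry the config attribute.
    pass


gate_up, down = ToyLayer(), ToyLayer()

# Producer side (set in CustomDeepseekV2MLP.__init__ for the W8A8-dynamic case).
gate_up._dynamic_quant_config = {"output_dtype": torch.int32, "return_scale": True}
down._dynamic_quant_config = {"output_dtype": torch.bfloat16, "return_scale": False}


def describe_apply(layer, x_is_prequantized: bool):
    # Mirrors the branching of the refactored apply(), minus the NPU kernels.
    config = getattr(layer, "_dynamic_quant_config", {})
    if x_is_prequantized:
        # Caller already holds (quantized_x, dynamic_scale): dtype must be explicit.
        output_dtype = config["output_dtype"]
    else:
        # Plain tensor in: quantize on the spot; the real method defaults to x.dtype
        # (a placeholder default is used here since there is no tensor).
        output_dtype = config.get("output_dtype", torch.bfloat16)
    return output_dtype, config.get("return_scale", False)


print(describe_apply(gate_up, x_is_prequantized=False))  # (torch.int32, True)
print(describe_apply(down, x_is_prequantized=True))      # (torch.bfloat16, False)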
