
Conversation

kunpengW-code (Contributor) commented Aug 15, 2025

What this PR does / why we need it?

The DeepSeek w4a8 weights we supported previously were in MindIE format, which uses one int8 to represent each int4 value. As a result the weight size is similar to w8a8, and a few extra steps were needed before vllm-ascend could load them.

Now we can use the new weight format directly: it packs two int4 values into one int8, so the weight size is reduced and no extra conversion is needed to use it on vllm-ascend. Weights in the previous MindIE format remain supported.

Weight changes in the new version:

  1. The weights are packed (two int4 values per int8); see the packing sketch after this list.
  2. The bias required in the apply method is generated directly by modelslim.
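
For clarity, here is a minimal sketch of what "two int4 values packed into one int8" means. This is an illustration only, not the modelslim/vllm-ascend implementation; the actual nibble order and sign handling may differ.

```python
# Sketch: pack pairs of int4 values (stored one per int8 element) into single int8 bytes.
# Assumption: the real packing layout used by modelslim/vllm-ascend may differ.
import torch

def pack_int4_pairs(w: torch.Tensor) -> torch.Tensor:
    """Pack adjacent int4 values along the last dim: (..., 2n) values -> (..., n) bytes."""
    assert w.shape[-1] % 2 == 0
    nibbles = w.to(torch.int16) & 0x0F   # keep only the low 4 bits of each value
    lo = nibbles[..., 0::2]              # even positions -> low nibble
    hi = nibbles[..., 1::2] << 4         # odd positions  -> high nibble
    packed = (hi | lo).to(torch.uint8)   # one byte now holds two int4 values
    return packed.view(torch.int8)       # reinterpret as int8, halving the last dim

vals = torch.tensor([1, -2, 7, -8], dtype=torch.int8)   # four int4 values
print(vals.shape, "->", pack_int4_pairs(vals).shape)    # torch.Size([4]) -> torch.Size([2])
```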

Does this PR introduce any user-facing change?

no

How was this patch tested?

Added a UT case in tests/ut/quantization/test_w4a8_dynamic.py.

1. How to get weights using Modelslim

Installation steps

Use the branch br_release_MindStudio_8.1.RC2_TR5_20260624:
git clone -b br_release_MindStudio_8.1.RC2_TR5_20260624 https://gitee.com/ascend/msit.git
cd msit/msmodelslim
bash install.sh

Generate w4a8 weights

cd example/DeepSeek
Command reference: msmodelslim/example/DeepSeek/README.md. Follow the pre-check chapter (https://gitee.com/ascend/msit/blob/br_release_MindStudio_8.1.RC2_TR5_20260624/msmodelslim/example/DeepSeek/README.md#%E8%BF%90%E8%A1%8C%E5%89%8D%E5%BF%85%E6%A3%80) and the DeepSeek-R1 w4a8 mixed quantization chapter (https://gitee.com/ascend/msit/blob/br_release_MindStudio_8.1.RC2_TR5_20260624/msmodelslim/example/DeepSeek/README.md#deepseek-r1-w4a8-%E6%B7%B7%E5%90%88%E9%87%8F%E5%8C%96%E5%89%8D%E4%B8%89%E5%B1%82-mlpw8a8-dynamic-%E9%87%8F%E5%8C%96mla%E5%85%B1%E4%BA%AB%E4%B8%93%E5%AE%B6w8a8%E9%87%8F%E5%8C%96%E8%B7%AF%E7%94%B1%E4%B8%93%E5%AE%B6w4a8-dynamic%E9%87%8F%E5%8C%96).
Reference command: python3 quant_deepseek_w4a8.py --model_path {original weight path} --save_path {output weight path}

Adapt to vllm-ascend

In config.json, change "model_type":deepseekv2 to "model_type":deepseek_v3. A minimal sketch of this edit follows.
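
A minimal sketch of that config.json edit. The weight path is a placeholder; point it at the --save_path directory produced by modelslim.

```python
# Hedged sketch: switch model_type in the generated weights' config.json.
# The path below is a placeholder, not a path from this PR.
import json
from pathlib import Path

cfg_path = Path("/path/to/generated_w4a8_weights/config.json")
cfg = json.loads(cfg_path.read_text())
cfg["model_type"] = "deepseek_v3"   # originally emitted as deepseekv2
cfg_path.write_text(json.dumps(cfg, indent=2))
```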

2. How to run w4a8

a. How to run eager mode

export VLLM_ASCEND_MLA_PA=1

python -m vllm.entrypoints.openai.api_server --model=$1 --trust-remote-code -tp $2 -dp $3 --enable_expert_parallel --quantization ascend --port $4 --max-model-len $5 --max-num-seqs $6 --enforce-eager
e.g.: python -m vllm.entrypoints.openai.api_server --model=/weightpath/w4a8_4_layer --trust-remote-code -tp 4 -dp 4 --enable_expert_parallel --quantization ascend --port 8002 --max-model-len 5120 --max-num-seqs 128 --enforce-eager
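
Once the server is up, a quick smoke test is to send a completion request through the OpenAI-compatible API. This is a minimal sketch, assuming the server listens on localhost:8002 and the model name equals the --model path used above; it is not part of this PR.

```python
# Hedged smoke-test sketch against the OpenAI-compatible endpoint started above.
# Assumptions: port 8002, model name equal to the --model path.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8002/v1", api_key="EMPTY")
resp = client.completions.create(
    model="/weightpath/w4a8_4_layer",
    prompt="The capital of France is",
    max_tokens=16,
)
print(resp.choices[0].text)
```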

b. How to run graph mode

export HCCL_BUFFSIZE=1024

python -m vllm.entrypoints.openai.api_server --model=$1 --trust-remote-code -tp $2 -dp $3 --enable_expert_parallel --quantization ascend --port $4 --max-model-len $5 --additional_config='{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true}}'
e.g.: python -m vllm.entrypoints.openai.api_server --model=/weight/dsr1_w4a8_vllm --trust-remote-code -tp 4 -dp 4 --enable_expert_parallel --quantization ascend --port 8002 --max-model-len 5120 --additional_config='{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true}}'

…weights

Signed-off-by: Wang Kunpeng <1289706727@qq.com>

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description, to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@kunpengW-code kunpengW-code changed the title [main][quantization] Adapt to the new format of ds w4a8 quantization … [main][quantization] Adapt to the new format of ds w4a8 weight Aug 15, 2025

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adapts the w4a8 dynamic quantization method to support a new weight format, introducing version-dependent logic for weight and parameter creation and processing. The changes are mostly in vllm_ascend/quantization/w4a8_dynamic.py and are accompanied by updates to the unit tests. While the changes are generally well-structured, I found a critical bug related to a shape inconsistency in the new format's w13_scale_bias parameter, which would likely lead to runtime errors. I've provided suggestions to fix this in both the implementation and the tests.

Comment on lines +231 to +235
param_dict["w13_scale_bias"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32)

critical

There seems to be a shape inconsistency for w13_scale_bias in the new quantization format. For the new version, w13_weight is created with w13_output_size = intermediate_size_per_partition. However, w13_scale_bias is created with a dimension of 2 * intermediate_size_per_partition. This is inconsistent with the corresponding weight and will likely cause runtime errors or incorrect results. The dimension should probably be intermediate_size_per_partition to match w13_weight.

Suggested change
param_dict["w13_scale_bias"] = torch.empty(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32)
param_dict["w13_scale_bias"] = torch.empty(
num_experts,
intermediate_size_per_partition,
1,
dtype=torch.float32)

Comment on lines +157 to +160
w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
dtype=torch.float32)
new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
requires_grad=False)

high

The setup for w13_scale_bias seems to be based on an incorrect shape definition in get_dynamic_quant_param. The dimension 2 * self.input_size is inconsistent with the corresponding w13_weight's dimension for the new quantization version. This should be self.input_size to match the weight.

Suggested change
w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
dtype=torch.float32)
new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
requires_grad=False)
w13_scale_bias = torch.zeros((self.experts, self.input_size, 1),
dtype=torch.float32)
new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
requires_grad=False)

Comment on lines +167 to +168
self.assertEqual(new_layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size))

high

This assertion checks for a shape that is inconsistent with the corresponding weight's shape. Following the correction in the w13_scale_bias setup, this assertion should be updated to check for the correct shape, which should use self.input_size instead of 2 * self.input_size.

Suggested change
self.assertEqual(new_layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size))
self.assertEqual(new_layer.w13_scale_bias.data.shape,
(self.experts, self.input_size))


codecov bot commented Aug 15, 2025

Codecov Report

❌ Patch coverage is 96.47059% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 77.37%. Comparing base (eccfb71) to head (547075a).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
vllm_ascend/quantization/quant_config.py 0.00% 2 Missing ⚠️
vllm_ascend/quantization/w4a8_dynamic.py 97.77% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2392      +/-   ##
==========================================
+ Coverage   77.31%   77.37%   +0.05%     
==========================================
  Files         128      128              
  Lines       16405    16455      +50     
==========================================
+ Hits        12684    12732      +48     
- Misses       3721     3723       +2     
Flag Coverage Δ
unittests 77.37% <96.47%> (+0.05%) ⬆️

Flags with carried forward coverage won't be shown.

@wangxiyuan (Collaborator) commented:

e2e passed here: https://github.com/vllm-project/vllm-ascend/actions/runs/17088776648
The new commit is just a rebase; let's merge this one since UT passed.

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
@wangxiyuan wangxiyuan merged commit c40d417 into vllm-project:main Aug 20, 2025
21 of 22 checks passed
wangxiaoteng888 pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Sep 25, 2025
…project#2392)

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@103f1ec

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Sep 26, 2025