
Conversation

Contributor

@lidenghui1110 commented Aug 1, 2025

What this PR does / why we need it?

This PR introduces Oproj matrix tensor model parallelism to reduce memory consumption. It only supports graph mode in the pure-DP scenario.

In a DeepSeek R1 W8A8 PD-disaggregated Decode instance using pure DP, with oproj_tensor_parallel_size = 8, TPOT increases by 1 ms while 5.8 GB of NPU memory is saved per rank. We got the best performance at oproj_tensor_parallel_size = 4, with no TPOT increase.

performance data:
<img width="1442" height="442" alt="image" src="https://github.com/user-attachments/assets/83270fc5-868a-4387-b0a9-fac29b4a376d" />

Does this PR introduce any user-facing change?

This PR introduces one new config in `additional_config`.

| Name | Effect | Required | Type | Constraints |
| :--- | :--- | :--- | :--- | :--- |
| oproj_tensor_parallel_size | Split the o_proj matrix along the row dimension (head num * head dim) into oproj_tensor_parallel_size pieces. | No | int | Default is None; once this value is set, the feature is enabled. head num * head dim must be divisible by this value. |

Example:

`--additional_config={"oproj_tensor_parallel_size": 8}`
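
For illustration, a minimal offline sketch, assuming a recent vLLM build that forwards `additional_config` through `EngineArgs` (the model name here is a placeholder). Since the o_proj row dimension is head num * head dim, a model with 128 heads of head dim 128 has 16384 rows, so a value of 8 gives each rank a 2048-row shard.

```python
# Hedged sketch: enable oproj tensor parallel via additional_config.
# Assumes vllm-ascend is installed and the feature's constraints hold
# (graph mode, pure DP, head num * head dim divisible by the value).
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-R1",  # placeholder model
    enforce_eager=False,              # the feature only supports graph mode
    additional_config={"oproj_tensor_parallel_size": 8},
)
```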

How was this patch tested?

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@eddaafc


github-actions bot commented Aug 1, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write a commit message that fulfills the PR description, to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
Contributor

This function seems to be identical to that of RowParallelLinear; why do we need to rewrite it here?

Contributor

In the original weight_loader,

tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()

we need to replace these with the custom comm group:

tp_rank = self.tp_rank
tp_size = self.tp_size

It seems that the latest vllm does not have this problem.
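
A minimal sketch of what the override amounts to, mirroring RowParallelLinear's loader but reading the rank from the layer's own group (names follow the snippet above; the sharding details are assumptions):

```python
import torch
from torch.nn import Parameter

def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
    # Key difference from the stock loader: the rank comes from the
    # layer's custom oproj comm group (self.tp_rank), not the global
    # TP group.
    tp_rank = self.tp_rank
    input_dim = getattr(param, "input_dim", None)
    if input_dim is not None:
        # Row-parallel: take this rank's slice of the input dimension.
        shard_size = param.data.shape[input_dim]
        loaded_weight = loaded_weight.narrow(input_dim,
                                             tp_rank * shard_size,
                                             shard_size)
    assert param.data.shape == loaded_weight.shape
    param.data.copy_(loaded_weight)
```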

Contributor

Understood, thanks

This pull request has conflicts, please resolve those before we can evaluate the pull request.

else:
    tp_rank = get_tensor_model_parallel_rank()
else:
    tp_rank = 0
Collaborator

What does tp_rank = 0 mean?

Contributor

This is the original code:

if isinstance(layer, RowParallelLinear):
    tp_rank = get_tensor_model_parallel_rank()
    return self.quant_method.apply(layer, x, bias, tp_rank)
return self.quant_method.apply(layer, x, bias)

In the default case, no tp_rank is passed, which corresponds to tp_rank = 0.
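
A hedged sketch of how the dispatch could look once the custom layer is in play (`AscendRowParallelLinear` is a hypothetical stand-in for the PR's oproj class; the other names come from the snippet above):

```python
def apply(self, layer, x, bias=None):
    # Hypothetical dispatch: take the rank from the layer's own comm
    # group for the custom oproj layer, from the global TP group for a
    # plain RowParallelLinear, and let tp_rank default to 0 otherwise.
    if isinstance(layer, AscendRowParallelLinear):
        return self.quant_method.apply(layer, x, bias, layer.tp_rank)
    if isinstance(layer, RowParallelLinear):
        tp_rank = get_tensor_model_parallel_rank()
        return self.quant_method.apply(layer, x, bias, tp_rank)
    return self.quant_method.apply(layer, x, bias)
```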

@zzhx1 zzhx1 force-pushed the oproj branch 8 times, most recently from 51d1def to 1e76d68 Compare September 4, 2025 12:39
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
@zzhx1 zzhx1 force-pushed the oproj branch 2 times, most recently from 065be1d to 6baf96b Compare September 5, 2025 06:46
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
@zzhx1 zzhx1 force-pushed the oproj branch 3 times, most recently from dcb84e0 to e7aaaf8 Compare September 6, 2025 02:43
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Contributor

@zzhx1 commented Sep 6, 2025

@wangxiyuan This PR is ready, and it also fixes the bug related to LinearBase.

@@ -0,0 +1,15 @@
import vllm
Collaborator

It looks like these 3 files can be merged into one.

@wangxiyuan wangxiyuan merged commit 5a71815 into vllm-project:main Sep 7, 2025
25 checks passed
Angazenn pushed a commit to Angazenn/vllm-ascend that referenced this pull request Sep 10, 2025
offline893 pushed a commit to offline893/vllm-ascend that referenced this pull request Sep 16, 2025
wangxiaoteng888 pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Sep 25, 2025
chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Sep 26, 2025