
Conversation

nuclearwu (Contributor) commented Jul 23, 2025

Signed-off-by: wuzhongjian wuzhongjian_yewu@cmss.chinamobile.com

What this PR does / why we need it?

self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE:
            self.hidden_size_per_attention_head = MAX_PAD_SIZE

The intent of this code in the __init__ method is: when the hidden size per attention head is between 64 and 128, it is padded to 128 to optimize compute performance on the Ascend platform.
However, in the forward method, when torch_npu._npu_flash_attention_unpad is called, scale_value uses the original origin_hidden_size_per_attention_head rather than the hidden_size_per_attention_head that may have been padded:

scale_value=self.origin_hidden_size_per_attention_head**-0.5,

If hidden_size_per_attention_head is padded to 128 but scale_value still uses origin_hidden_size_per_attention_head (for example, 84), the scaling ratio is incorrect, which affects the accuracy of the attention weights.
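For illustration, a minimal sketch of the scaling mismatch in plain torch (the NPU kernel is not modeled here; the head dim of 80 and the 64/128 pad bounds are assumed values chosen for the example):

import torch

# Hypothetical sizes: original head dim 80, padded up to 128
# (assuming MIN_PAD_SIZE = 64 and MAX_PAD_SIZE = 128).
orig_dim, padded_dim = 80, 128

# Random logits stand in for the raw q @ k.T scores of one attention head.
logits = torch.randn(4, 4)

# The two candidate scale_value choices differ by a constant factor of
# sqrt(padded_dim / orig_dim) ≈ 1.26, so the softmax inputs are rescaled.
weights_orig_scale = torch.softmax(logits * orig_dim ** -0.5, dim=-1)   # current scale_value
weights_pad_scale = torch.softmax(logits * padded_dim ** -0.5, dim=-1)  # scale from the padded size

print((weights_orig_scale - weights_pad_scale).abs().max())  # non-zero: attention weights change

Which of the two denominators is correct depends on how the kernel treats the zero-padded part of the head dimension, which is what the review discussion below turns on.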

Does this PR introduce any user-facing change?

How was this patch tested?

nuclearwu (Contributor, Author) commented Jul 23, 2025

@wangxiyuan @Yikun @ApsarasX Please review. Thank you!

codecov bot commented Jul 23, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 65.78%. Comparing base (ac0bf13) to head (458a0b5).
⚠️ Report is 39 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1958   +/-   ##
=======================================
  Coverage   65.78%   65.78%           
=======================================
  Files          78       78           
  Lines        8406     8406           
=======================================
  Hits         5530     5530           
  Misses       2876     2876           
Flag Coverage Δ
unittests 65.78% <ø> (ø)


Signed-off-by: wuzhongjian <wuzhongjian_yewu@cmss.chinamobile.com>
wangxiyuan (Collaborator) commented

Let me check and run it locally, thanks for the fix

nuclearwu (Contributor, Author) commented

Let me check and run it locally, thanks for the fix

@wangxiyuan Could you share the results after running it on your end?

wangxiyuan added the accuracy-test (enable all accuracy test for PR) and ready-for-test (start test by label for PR) labels on Jul 30, 2025
zouyida2052 (Contributor) commented

If convenient, please provide specific results on the dataset under both precisions. The scale value here is used for normalization; we want to avoid having the zero-padded regions affect the original data. If the size is increased from 80 to 128, the padded areas will influence the normalization process, which could ultimately impact the model's accuracy.
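A quick check of this normalization point in plain torch, assuming the kernel zero-fills the padded region and computes an ordinary dot product over the full padded dimension (the 80 to 128 padding is the hypothetical case mentioned above):

import torch

orig_dim, padded_dim, seq_len = 80, 128, 4
q = torch.randn(seq_len, orig_dim)
k = torch.randn(seq_len, orig_dim)

# Zero-pad the head dimension from 80 to 128.
q_pad = torch.nn.functional.pad(q, (0, padded_dim - orig_dim))
k_pad = torch.nn.functional.pad(k, (0, padded_dim - orig_dim))

# The zero-padded region contributes nothing to q @ k.T, so attention weights
# computed on the padded tensors match the unpadded reference only when the
# scale uses the original head dim.
ref = torch.softmax((q @ k.T) * orig_dim ** -0.5, dim=-1)
w_orig_scale = torch.softmax((q_pad @ k_pad.T) * orig_dim ** -0.5, dim=-1)
w_pad_scale = torch.softmax((q_pad @ k_pad.T) * padded_dim ** -0.5, dim=-1)

print(torch.allclose(ref, w_orig_scale))  # True
print(torch.allclose(ref, w_pad_scale))   # False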

wangxiyuan (Collaborator) commented

@nuclearwu any feedback about the accuracy problem?

moguizhizi commented

Since we are optimizing performance, why isn't padding applied in the non-VL pipeline?
