
[llm] support tensorwise fp8/int8 training #10612

Open
lugimzzz wants to merge 10 commits into develop

Conversation

lugimzzz (Contributor) commented May 19, 2025

PR types

New features

PR changes

APIs

Description

Newly supported features:
1. Add all_reduce_max over weight scales and activation scales, so that weights can be re-sharded under different TP strategies (see the first sketch after this list).
2. Support FP8/INT8 training with TP + PP parallelism, storing weights via Unified Checkpoint.
3. Replace the full Hadamard matmul with a block-diagonal Hadamard matrix (see the second sketch after this list).
4. Unify the FP8/INT8 training code paths.
5. Add a Triton implementation of the AdamW optimizer for FP8 weights (with bf16 moments and offload support).
6. Support FP8/INT8 LoRA on the backbone model.
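
For item 1, the sketch below is illustrative only (the function name `sync_tensorwise_scale` and its arguments are assumptions, not this PR's API); it shows how all-reducing the per-shard abs-max with MAX yields a tensorwise scale that is identical on every TP rank, regardless of how the weight is split:

```python
import paddle
import paddle.distributed as dist

def sync_tensorwise_scale(local_abs_max, tp_group=None, qmax=448.0):
    """Derive one shared tensorwise scale from per-shard abs-max values.

    local_abs_max: abs().max() of this rank's weight or activation shard.
    qmax: largest magnitude of the target format, e.g. 448 for float8_e4m3
          or 127 for int8.
    """
    global_abs_max = local_abs_max.clone()
    if dist.get_world_size() > 1:
        # After the all-reduce every rank holds the maximum over all shards,
        # so the resulting scale does not depend on the TP partitioning.
        dist.all_reduce(global_abs_max, op=dist.ReduceOp.MAX, group=tp_group)
    return global_abs_max / qmax
```

For item 3, a minimal sketch of a block-diagonal Hadamard transform, assuming a power-of-two block size (names are illustrative, not the PR's hadamard_utils code): instead of one large n x n Hadamard matmul, the last dimension is rotated block by block, which is equivalent to multiplying by diag(H_b, ..., H_b).

```python
import paddle

def sylvester_hadamard(block_size):
    """Build an orthonormal power-of-two Hadamard matrix (Sylvester construction)."""
    h = paddle.to_tensor([[1.0]])
    while h.shape[0] < block_size:
        h = paddle.concat([paddle.concat([h, h], axis=1),
                           paddle.concat([h, -h], axis=1)], axis=0)
    return h * (block_size ** -0.5)

def block_hadamard_transform(x, block_size=32):
    """Rotate the last dim of x block by block; cost is O(numel * block_size)."""
    orig_shape = x.shape
    h = sylvester_hadamard(block_size).astype(x.dtype)
    x = x.reshape([-1, orig_shape[-1] // block_size, block_size])
    x = paddle.matmul(x, h)  # every block multiplied by the same small H_b
    return x.reshape(orig_shape)
```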

Features left for follow-up PRs:
1. Support saving optimizer state with Unified Checkpoint.
2. FP8 weights are currently represented as paddle.int8 and stored as np.int8; switch to a float8 representation later.
3. Support data-parallel training for FP8/INT8.
4. Speed up the FP8/INT8 quant-matmul-dequant path (a reference sketch follows this list) and adapt the acceleration to the MoE structure.
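
For context on the quant-matmul-dequant path mentioned above, here is a slow reference sketch (assumed names, int8 for concreteness; real training would call a fused low-precision GEMM instead of the float32 matmul used here):

```python
import paddle

def quant_matmul_dequant(x, w_q, w_scale):
    """Tensorwise int8 quant -> matmul -> dequant, written out for clarity."""
    x_scale = x.abs().max() / 127.0  # tensorwise activation scale
    x_q = paddle.clip(paddle.round(x / x_scale), -127.0, 127.0).astype("int8")
    out = paddle.matmul(x_q.astype("float32"), w_q.astype("float32"))
    return (out * x_scale * w_scale).astype(x.dtype)  # back to the training dtype
```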

paddle-bot bot commented May 19, 2025

Thanks for your contribution!

codecov bot commented May 19, 2025

Codecov Report

Attention: Patch coverage is 18.47826% with 225 lines in your changes missing coverage. Please review.

Project coverage is 46.93%. Comparing base (ddcb722) to head (fb0224a).
Report is 11 commits behind head on develop.

Current head fb0224a differs from pull request most recent head 1bfb4d9

Please upload reports for the commit 1bfb4d9 to get more accurate results.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| paddlenlp/quantization/qat_utils.py | 7.69% | 72 Missing ⚠️ |
| paddlenlp/utils/optimizer.py | 9.09% | 40 Missing ⚠️ |
| paddlenlp/utils/adamw_triton.py | 16.27% | 36 Missing ⚠️ |
| paddlenlp/transformers/model_utils.py | 21.62% | 29 Missing ⚠️ |
| paddlenlp/quantization/hadamard_utils.py | 16.66% | 25 Missing ⚠️ |
| paddlenlp/transformers/conversion_utils.py | 13.63% | 19 Missing ⚠️ |
| paddlenlp/quantization/quantization_utils.py | 25.00% | 3 Missing ⚠️ |
| paddlenlp/trainer/trainer.py | 50.00% | 1 Missing ⚠️ |

❌ Your patch check has failed because the patch coverage (18.47%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.
❌ Your project check has failed because the head coverage (46.93%) is below the target coverage (58.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop   #10612      +/-   ##
===========================================
- Coverage    46.94%   46.93%   -0.02%     
===========================================
  Files          799      800       +1     
  Lines       132348   132416      +68     
===========================================
+ Hits         62137    62147      +10     
- Misses       70211    70269      +58     

☔ View full report in Codecov by Sentry.

lugimzzz changed the title from "add uc" to "[llm] support tensorwise fp8 training" on May 21, 2025
@@ -478,8 +525,8 @@ def load_state_dict(
scale_dict.update(res_scale_dict)

         if device == "cpu":
-            for k in list(state_dict.keys()):
-                with device_guard():
+            with device_guard():
+                for k in list(state_dict.keys()):
lugimzzz (Contributor, Author):

Avoids the extra time cost of repeatedly calling set_device inside the loop.

"weight_only_int4",
"weight_only_int8",
]
elif isinstance(config.quantization_config.weight_quantize_algo, dict):
lugimzzz (Contributor, Author):

weight_only_int8 does not support different TP shards sharing the same scale, so flexible conversion of wint8 weights between TP strategies is not supported for now.

lugimzzz (Contributor, Author):

post_quantize means the weight is first split by TP and then quantized (applies to wint4/wint8).
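
A minimal sketch of what that ordering implies (the helpers `tp_split` and `post_quantize_shard` are hypothetical, not this PR's API): each rank first takes its TP shard and then quantizes it, so every shard carries its own scale.

```python
import paddle

def tp_split(weight, tp_rank, tp_degree, axis=-1):
    # Take this rank's shard of the full weight along the split axis.
    return paddle.split(weight, tp_degree, axis=axis)[tp_rank]

def post_quantize_shard(weight, tp_rank, tp_degree, qmax=127.0):
    shard = tp_split(weight, tp_rank, tp_degree)  # 1. split by TP first
    scale = shard.abs().max() / qmax              # 2. per-shard scale
    shard_q = paddle.clip(paddle.round(shard / scale), -qmax, qmax).astype("int8")
    return shard_q, scale                         # 3. quantize afterwards
```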

@@ -2537,6 +2615,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         # load pt weights early so that we know which dtype to init the model under
         if not is_sharded and state_dict is None:
             # 4. loading non-sharded ckpt from the state dict
+            # Quantization: Loading non-sharded ckpt does not support saving with merge_tensor_parallel
lugimzzz (Contributor, Author):

Quantized loading and saving of non-safetensors weights is not considered for now.

lugimzzz changed the title from "[llm] support tensorwise fp8 training" to "[llm] support tensorwise fp8/int8 training" on May 21, 2025