[mxfp8 moe training] add MX MoE model converter using torchao mxfp8 moe training #1701
Conversation
```python
# TODO: add warning in torchao when this happens, or find a better way to avoid this.
if self.moe_fqns:
    self._convert_moe_layers(model)
```
it feels like the converter registered for self.config should handle this case specifically
as in having separate converters registered for mxfp8 moe, fp8 moe etc? If so, I kind of agree actually, the converters are becoming a bit of a mess, and having separate converters would also allow users to convert just dense or just MoE (or both), rather than the current state of having to convert dense in order to convert MoE.
Yeah, that is one way. Another thing I was thinking is that https://github.com/pytorch/ao/blob/93030e750186ace1c1c2ee7a849e2818a9f0ffde/torchao/prototype/moe_training/conversion_utils.py#L50 should be able to gracefully handle the case where the module has already been converted.
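For illustration, a minimal sketch of such an idempotence guard, assuming a generic parameter-swapping converter: `ConvertedParam` below is just a stand-in for a real swapped tensor subclass such as ScaledGroupedMMTensor, and `convert_param_once` is a hypothetical helper, not torchao's API.

```python
import torch.nn as nn


class ConvertedParam(nn.Parameter):
    """Stand-in marker for an already-swapped parameter (e.g. ScaledGroupedMMTensor)."""


def convert_param_once(module: nn.Module, name: str = "weight") -> bool:
    """Swap module.<name> in place; return False if it was already swapped."""
    param = getattr(module, name)
    if isinstance(param, ConvertedParam):
        return False  # gracefully no-op instead of double-wrapping or erroring
    setattr(module, name, ConvertedParam(param.detach().clone()))
    return True


lin = nn.Linear(4, 4)
assert convert_param_once(lin) is True   # first pass swaps the parameter
assert convert_param_once(lin) is False  # second pass is a no-op
```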
Separated the dense and MoE converters
In the log "Swapped w1.weight to ScaledGroupedMMTensor", it's hard to tell which w1 is converted. Can we include the full fqn?
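A rough sketch of what surfacing the full FQN could look like, assuming the converter iterates the model with named_parameters(); the logger setup, the target_suffix filter, and the example FQN in the comment are illustrative, not the actual converter code.

```python
import logging

import torch.nn as nn

logger = logging.getLogger(__name__)


def log_converted_params(model: nn.Module, target_suffix: str = "experts") -> None:
    # named_parameters() yields fully qualified names (e.g. something like
    # "layers.3.moe.experts.w1"), so logging the fqn pins down exactly
    # which parameter was converted.
    for fqn, _param in model.named_parameters():
        if target_suffix in fqn:
            logger.info("Swapped %s to ScaledGroupedMMTensor", fqn)
```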
> (eager only currently - compile support in progress)
Does it mean we shouldn't care too much about performance with this PR?
How about numerics?
I'll try to find some time to look into #1651
Yes, I've been meaning to make this update in torchao... will do.
```toml
filter_fqns = ["output", "router.gate"]
moe_fqns_prototype = ["experts"]

[mx]
```
so you prefer to not update the mx_moe configs in the toml configs?
```python
moe_fqns_prototype: list[str] | str = field(default_factory=list)

@dataclass
class MXMoE:
```
Separating Linear and MoE treatment is reasonable. I have two comments:
- To make naming consistent, should we rename MX to MXLinear?
- In general, as we are creating more and more quantization dataclasses in job_config.py, I somehow prefer we group them under the same root-level dataclass, say Quantization, matching the folder organization. I think this can be achieved by using hierarchical dataclasses and nested toml (https://toml.io/en/v1.0.0#array-of-tables). We may have to update config/manager.py, so it doesn't need to be addressed in this PR.
> To make naming consistent, should we rename MX to MXLinear?

That's fine with me, can do in a follow-up.

> I somehow prefer we group them under the same root level dataclass, say Quantization, matching the folder organization. I think this can be achieved by using hierarchical dataclasses and nested toml

I like this idea!
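For illustration, a sketch of that grouping, using hypothetical field names (the real dataclasses in job_config.py may differ); nested dataclasses would then map onto nested TOML tables such as [quantization.mx] and [quantization.mx_moe].

```python
from dataclasses import dataclass, field


@dataclass
class MXLinear:
    filter_fqns: list[str] = field(default_factory=list)


@dataclass
class MXMoE:
    fqns: list[str] = field(default_factory=list)


@dataclass
class Quantization:
    # Nested dataclasses map onto nested TOML tables, e.g.:
    #   [quantization.mx]
    #   filter_fqns = ["output", "router.gate"]
    #   [quantization.mx_moe]
    #   fqns = ["experts"]
    mx: MXLinear = field(default_factory=MXLinear)
    mx_moe: MXMoE = field(default_factory=MXMoE)
```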
```python
# The quantization modules are intended to be run under `torch.compile` for competitive performance

# Module level global constants
MXFP8_GROUP_ALIGNMENT_SIZE = 32
```
How about float8? Should we also put that constant here?
Yes, I plan to refactor float8 rowwise MoE to use the same pattern in a self-contained follow up PR.
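As an aside on the constant quoted above: a sketch of how an alignment constant like this is presumably used, padding per-expert token-group sizes up to the mxfp8 scaling block size. This is illustrative only, not the code in this PR.

```python
# Illustrative only: round a per-expert token-group size up to the next
# multiple of the alignment size (mxfp8 scales in blocks of 32 elements).
MXFP8_GROUP_ALIGNMENT_SIZE = 32


def round_up_to_alignment(group_size: int, alignment: int = MXFP8_GROUP_ALIGNMENT_SIZE) -> int:
    """Smallest multiple of `alignment` that is >= `group_size`."""
    return ((group_size + alignment - 1) // alignment) * alignment


assert round_up_to_alignment(1) == 32
assert round_up_to_alignment(64) == 64
assert round_up_to_alignment(65) == 96
```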
```python
return True
return False

config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
```
I think the torchtitan feature name here should match the torchao one. If you want mxtraining, make the torchao workflow adhere to that too.
Context

Changes
- mx model converter: MoE conversion is split out. Using low precision training on MoE layers should not be dependent on converting dense layers (plus, putting all the logic in one converter was getting messy).
- mx_moe: new model converter with a user facing API to convert MoE layers to use mxfp8 scaled grouped gemms, instead of high precision grouped gemms.

Test plan
NGPU=2 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=1 --model.converters="mx_moe" --mx_moe.fqns="experts" --model.print-after-conversion --metrics.log_freq=10 --training.steps=30 --compile.enable