
Conversation

@danielvegamyhre (Contributor) commented on Sep 12, 2025:

Context

Changes

  • Remove MoE conversion logic from the mx model converter. Using low precision training on MoE layers should not depend on converting dense layers (plus, putting all the logic in one converter was getting messy).
  • Add a new model converter, mx_moe, with a user-facing API to convert MoE layers to use mxfp8 scaled grouped GEMMs instead of high precision grouped GEMMs (see the sketch after this list).
  • This clean separation gives users more flexibility and control: mxfp8 can be applied to dense layers, MoE layers, or both.
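
For illustration, a minimal sketch of what an mx_moe-style converter could look like. The names below (MXMoEConfig, MXMoEConverter, the _mxfp8_moe_converted marker) are assumptions for this sketch, not the actual torchtitan/torchao APIs; the real converter defers to torchao's MoE training conversion utilities.

# Hypothetical sketch of an "mx_moe"-style model converter (illustrative only).
from dataclasses import dataclass, field

import torch.nn as nn


@dataclass
class MXMoEConfig:
    # FQN substrings selecting which MoE submodules to convert, e.g. ["experts"].
    fqns: list[str] = field(default_factory=lambda: ["experts"])


class MXMoEConverter:
    """Converts matching MoE expert modules so their grouped GEMMs can run in mxfp8."""

    def __init__(self, config: MXMoEConfig):
        self.fqns = config.fqns

    def convert(self, model: nn.Module) -> None:
        for name, module in model.named_modules():
            if any(fqn in name for fqn in self.fqns):
                # The real implementation swaps expert weights to torchao's
                # ScaledGroupedMMTensor; here we only tag the module so the
                # sketch stays self-contained.
                module._mxfp8_moe_converted = True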

Test plan

  • NGPU=2 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=1 --model.converters="mx_moe" --mx_moe.fqns="experts" --model.print-after-conversion --metrics.log_freq=10 --training.steps=30 --compile.enable

The meta-cla bot added the "CLA Signed" label on Sep 12, 2025
@danielvegamyhre marked this pull request as draft on September 12, 2025 00:21
@danielvegamyhre marked this pull request as ready for review on September 12, 2025 00:27
@danielvegamyhre force-pushed the mx-moe-update branch 2 times, most recently from 960840d to 90582f1 on September 12, 2025 00:43
@danielvegamyhre (Contributor, Author) commented:

cc @tianyu-l @drisspg for review

# TODO: add warning in torchao when this happens, or find a better way to avoid this.
if self.moe_fqns:
    self._convert_moe_layers(model)

Contributor:

it feels like the converter registered for self.config should handle this case specifically

Contributor Author (@danielvegamyhre):

As in having separate converters registered for mxfp8 MoE, fp8 MoE, etc.? If so, I kind of agree, actually: the converters are becoming a bit of a mess, and separate converters would also allow users to convert just dense layers or just MoE layers (or both), rather than the current state of having to convert dense in order to convert MoE.

Contributor:

Yeah, that is one way. Another thing I was thinking is that https://github.com/pytorch/ao/blob/93030e750186ace1c1c2ee7a849e2818a9f0ffde/torchao/prototype/moe_training/conversion_utils.py#L50 should be able to gracefully handle the case where the module has already been converted.
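
A minimal sketch of that "gracefully handle already-converted modules" idea, assuming a hypothetical convert_moe_module_ helper and a marker attribute standing in for the real isinstance check against ScaledGroupedMMTensor:

# Illustrative only; not torchao's actual conversion_utils code.
import torch.nn as nn


def convert_moe_module_(module: nn.Module) -> None:
    if getattr(module, "_mxfp8_converted", False):
        # Already converted (e.g. another converter touched it first): no-op
        # instead of raising or double-wrapping.
        return
    # ... the real conversion would swap expert params to the scaled
    # grouped-GEMM tensor subclass here ...
    module._mxfp8_converted = True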

Contributor Author (@danielvegamyhre):

Separated the dense and MoE converters

@tianyu-l (Contributor) left a comment:

In the log

Swapped w1.weight to ScaledGroupedMMTensor

it's hard to tell which w1 is converted. Can we include full fqn?

(eager only currently - compile support in progress)

Does it mean we shouldn't care too much about performance with this PR?
How about numerics?

I'll try to find some time to look into #1651

@danielvegamyhre marked this pull request as draft on September 14, 2025 00:23
@danielvegamyhre changed the title from "[mxfp8 moe training] add torchao MXFP8 MoE training integration; bump version guard" to "[WIP] [mxfp8 moe training] add torchao MXFP8 MoE training integration; bump version guard" on Sep 15, 2025
@danielvegamyhre changed the title from "[WIP] [mxfp8 moe training] add torchao MXFP8 MoE training integration; bump version guard" to "[mxfp8 moe training] add MX MoE model converter using torchao mxfp8 moe training" on Sep 20, 2025
@danielvegamyhre marked this pull request as ready for review on September 20, 2025 18:15
@danielvegamyhre force-pushed the mx-moe-update branch 6 times, most recently from d674008 to 5c45908 on September 20, 2025 19:19
@danielvegamyhre (Contributor, Author) commented:

@tianyu-l @drisspg this is ready for another look

@danielvegamyhre (Contributor, Author) commented:

it's hard to tell which w1 is converted. Can we include full fqn?

Yes, I've been meaning to make this update in torchao... will do.
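
One way the swap could report the full FQN (illustrative sketch only; _log_param_swaps is a hypothetical helper, not the real torchao function, and the actual tensor-subclass swap is omitted):

# Hypothetical sketch: report "<module fqn>.<param name>" so "w1" in one expert
# layer is distinguishable from "w1" in another.
import logging

import torch.nn as nn

logger = logging.getLogger(__name__)


def _log_param_swaps(model: nn.Module, target_fqns: list[str]) -> None:
    for fqn, module in model.named_modules():
        if not any(t in fqn for t in target_fqns):
            continue
        for param_name, _ in module.named_parameters(recurse=False):
            logger.info("Swapped %s.%s to ScaledGroupedMMTensor", fqn, param_name)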

filter_fqns = ["output", "router.gate"]
moe_fqns_prototype = ["experts"]

[mx]
Contributor:

So you prefer not to update the mx_moe configs in the toml files?

moe_fqns_prototype: list[str] | str = field(default_factory=list)

@dataclass
class MXMoE:
Contributor:

Separating Linear and MoE treatment is reasonable. I have two comments:

  1. To make naming consistent, should we rename MX to MXLinear?
  2. In general, as we are creating more and more quantization dataclasses in job_config.py, I somehow prefer we group them under the same root-level dataclass, say Quantization, matching the folder organization. I think this can be achieved by using hierarchical dataclasses and nested toml (https://toml.io/en/v1.0.0#array-of-tables). We may have to update config/manager.py, so this doesn't need to be addressed in this PR.

Contributor Author (@danielvegamyhre) commented on Sep 23, 2025:

To make naming consistent, should we rename MX to MXLinear?

That's fine with me, can do in a follow-up.

I somehow prefer we group them under the same root-level dataclass, say Quantization, matching the folder organization. I think this can be achieved by using hierarchical dataclasses and nested toml

I like this idea!
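
A rough sketch of that grouping, assuming hypothetical field names rather than torchtitan's actual job_config.py definitions; the matching nested TOML shape is shown in the trailing comments:

# Hedged sketch of a root-level Quantization dataclass grouping the per-recipe configs.
from dataclasses import dataclass, field


@dataclass
class MXLinear:
    filter_fqns: list[str] = field(default_factory=lambda: ["output", "router.gate"])


@dataclass
class MXMoE:
    fqns: list[str] = field(default_factory=lambda: ["experts"])


@dataclass
class Quantization:
    mx_linear: MXLinear = field(default_factory=MXLinear)
    mx_moe: MXMoE = field(default_factory=MXMoE)


# The corresponding nested TOML would then look like:
#   [quantization.mx_linear]
#   filter_fqns = ["output", "router.gate"]
#
#   [quantization.mx_moe]
#   fqns = ["experts"]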

# The quantization modules are intended to be run under `torch.compile` for competitive performance

# Module-level global constants
MXFP8_GROUP_ALIGNMENT_SIZE = 32
Contributor:

How about float8? Should we also put the constant here?

Contributor Author (@danielvegamyhre):

Yes, I plan to refactor float8 rowwise MoE to use the same pattern in a self-contained follow-up PR.
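
For context, a small sketch of what a group-alignment constant like MXFP8_GROUP_ALIGNMENT_SIZE is typically used for: rounding per-expert token-group sizes up to a multiple of the mxfp8 block size before the scaled grouped GEMM. round_up_group_sizes is a hypothetical helper, not the PR's code.

# Illustrative only: align per-expert token counts to the mxfp8 block size.
MXFP8_GROUP_ALIGNMENT_SIZE = 32


def round_up_group_sizes(
    group_sizes: list[int], alignment: int = MXFP8_GROUP_ALIGNMENT_SIZE
) -> list[int]:
    # Round each per-expert token count up to the nearest multiple of `alignment`.
    return [((g + alignment - 1) // alignment) * alignment for g in group_sizes]


# Example: per-expert token counts [5, 40, 0, 17] -> [32, 64, 0, 32]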

return True
return False

config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
Contributor:

I think the torchtitan feature name here should match the torchao one. If you want mxtraining, make the torchao workflow adhere to that too.
