fix: pp grad accumulation is broken #1732
Conversation
[problem]
Using gradient accumulation is incompatible with the PipelineSchedule(..., scale_grads=True) option, which defaults to True. When this option is set, all gradients are scaled down by the number of micro-batches at every step. This works fine for a single gradient accumulation step, but with multiple steps it rescales the total accumulated gradient each time, not just once at the end of accumulation. The result is that the accumulated gradient becomes an exponential moving average rather than a sum. The resulting gradients are much smaller than they should be, and gradient accumulation with PP is not equivalent to gradient accumulation without PP -- the loss curves diverge substantially and the gradient norms are far off. A secondary consequence is that dividing all gradients by n_microbatches at every step is computationally expensive for a large model.
[solution]
Set "scale_grads=False" when creating the scheduler instance. Compute "n_microbatches" in the constructor and apply this factor, along with gradient_accumulation_steps, to the scale factor in "rescale_accumulated_loss()". This scales the loss, rather than the gradients, by the correct factor at each step. A secondary benefit of this approach is that it avoids modifying every gradient tensor, which is much cheaper computationally -- and it is correct, which the previous behavior was not. A side effect of this change is that the loss values returned by the pipeline are scaled by this factor, which makes them too small by a factor of n_microbatches. We correct this by rescaling the returned loss by the same factor.
[testing]
Without these changes, a baseline run with 10 gradient accumulation steps on a single GPU is compared against a 2-GPU pipeline run using 1F1B. The effective batch size is 320 in both cases, with all other variables controlled. The result is a substantial divergence between the loss curves and gradient norms of the two runs. With this change applied, the results are nearly identical, ignoring minor differences from non-determinism.
[references]
scale_grads option: https://github.com/pytorch/pytorch/blob/281bb56cc50073159c8418c5c99c7459c914c4db/torch/distributed/pipelining/schedules.py#L286
scale_grads implementation: https://github.com/pytorch/pytorch/blob/281bb56cc50073159c8418c5c99c7459c914c4db/torch/distributed/pipelining/stage.py#L567
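For illustration, a minimal sketch of what the [solution] above amounts to. It is not the actual diff: build_schedule is a hypothetical name, the inner wrapper stands in for torchtitan's rescale_accumulated_loss helper (assumed to divide the loss by its factor), and the other names follow the schedule-construction snippet discussed later in this thread.

def build_schedule(schedule_class, stages, looped_schedule, n_microbatches,
                   gradient_accumulation_steps, loss_fn):
    # Fold both scale factors into the loss function (stand-in for
    # torchtitan's rescale_accumulated_loss helper)...
    def scaled_loss_fn(*args, **kwargs):
        return loss_fn(*args, **kwargs) / (n_microbatches * gradient_accumulation_steps)

    # ...and stop the schedule from rescaling the gradients on every step.
    return schedule_class(
        stages if looped_schedule else stages[0],
        n_microbatches=n_microbatches,
        loss_fn=scaled_loss_fn,
        scale_grads=False,
    )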
Hi @jdinalt! Thank you for your pull request and welcome to our community. Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you. Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
I'm not familiar with
Thank you for the fix. I left some suggestions; please see if they make sense to you.
As a sanity check, could you also share some experiment results?
You could use your own setup, or follow https://github.com/pytorch/torchtitan/blob/main/docs/debugging.md#seed-checkpoint-based-reproducibility
schedule = schedule_class(
    stages if looped_schedule else stages[0],
    n_microbatches=n_microbatches,
    loss_fn=loss_fn,
can we not change the current loss_fn in train.py, and call rescale_accumulated_loss(loss_fn, n_microbatches) here?
That seems perfectly feasible. My concern is that it is double-wrapping the loss function, which has already been wrapped in train.py.
It will probably work just fine, while being slightly less efficient, having added one more layer.
This raises the question of how we then scale the reported loss back up by the same factor in train.py.
With the present change, I cache the computed value of n_microbatches for rescaling the logged loss. In theory, we could wrap the scheduler too, but is it worth the trouble?
I'm happy to defer to your preferences on this. Let me know how you would like to see it addressed.
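To make the double-wrapping point concrete, a toy illustration (assuming, as above, that the wrapper simply divides the loss by its factor): composing the two wrappers gives the same result as a single wrap by the product of the factors, up to floating-point rounding, at the cost of one extra call layer per microbatch.

import math

# Illustrative stand-in for the rescale wrapper discussed above.
def rescale_accumulated_loss(loss_fn, factor):
    def wrapped(*args, **kwargs):
        return loss_fn(*args, **kwargs) / factor
    return wrapped

def raw_loss(pred, labels):  # dummy loss for the illustration
    return 12.0

G, M = 10, 4  # grad accumulation steps, pipeline microbatches

inner = rescale_accumulated_loss(raw_loss, G)       # what train.py already does
outer = rescale_accumulated_loss(inner, M)          # wrapping again for PP
single = rescale_accumulated_loss(raw_loss, G * M)  # one combined wrap

assert math.isclose(outer(None, None), single(None, None))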
torchtitan/train.py
  # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
  loss = (
-     torch.mean(torch.stack(losses)).to(self.device)
+     torch.mean(torch.stack(losses)).multiply(self.pp_n_microbatches).to(self.device)
can we just do
-    torch.mean(torch.stack(losses)).multiply(self.pp_n_microbatches).to(self.device)
+    torch.sum(torch.stack(losses)).to(self.device)
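For reference, the suggestion works because the mean of a stack of k losses times k is exactly their sum, so the explicit pp_n_microbatches factor is no longer needed at this call site. A toy check (illustrative, not part of the PR):

import torch

losses = [torch.tensor(0.4), torch.tensor(0.5), torch.tensor(0.6)]
stacked = torch.stack(losses)

# mean * count == sum, so summing absorbs the n_microbatches factor
assert torch.allclose(torch.mean(stacked) * len(losses), torch.sum(stacked))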
I believe that would do the trick. I'll make the change and test it.
I have checked references to "self.loss_fn" and I don't see anything that will malfunction if the outer loss_fn is wrapped once while the one used by the scheduler is wrapped twice and is therefore not the same object. It's probably fine, but there is a non-zero risk of someone assuming it is the final loss function and using it in a way that breaks.
Overall, I think it's worth the risk and the changes should be kept out of train.py.
Thanks for the suggestions. I'm presently working on the requested changes. WRT "experiment results," I'm happy to provide them. Are you requesting a set of native test configurations? Tensorboard data? Something else?
I have applied the requested changes locally and I'm running the test again. So far, both grad-norm and train-loss precisely follow the trajectory of my previous commit. All of the test configurations already have full determinism enabled.
I have attached the Tensorboard runs; the most recent "PP" run reflects the latest changes. The results are identical to those from the prior commit.
LGTM. Could you show a screenshot of the tensorboard check? Thanks!
> while the one used by the scheduler is not the same, being wrapped twice. It's probably fine, but has a non-zero risk of someone assuming that this is the final loss function and using it in a way which breaks.
> Overall, I think it's worth the risk and the changes should be kept out of train.py.
Totally agree. I think the problem is, as you said, that the loss is summed twice in train.py (once for PP, once for grad accumulation), but scaled down once in train.py (for grad accumulation) and once in PP.
The reason we are taking this risk is simply that n_microbatches is a bit cumbersome to get. I think for now it's OK. Otherwise, we can do it here: https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/validate.py#L150
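To spell out that accounting with toy numbers (illustrative values only; this assumes each microbatch loss is divided by both factors before backward, and that the logged loss is summed once over microbatches and once over accumulation steps, as described above):

G = 10           # gradient accumulation steps (scale-down applied in train.py)
M = 4            # pipeline microbatches (scale-down applied in PP)
true_loss = 4.0  # pretend every microbatch sees this loss

scaled = true_loss / (G * M)  # what the wrapped loss_fn returns per microbatch
per_step = M * scaled         # sum over microbatches (PP)  -> true_loss / G
reported = G * per_step       # sum over accumulation steps -> true_loss

assert abs(reported - true_loss) < 1e-9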
Before we can merge, please
- add a comment here https://github.com/pytorch/torchtitan/pull/1732/files#diff-ea620cebba782ef8545fcfc700627348c15bb4cbb8ef5c5b4f417ddff955668bR455 stating "using sum instead of mean because we already rescale the loss_fn down by a factor of n_microbatches in torchtitan/distributed/pipeline_parallel.py"
- also fix the occurrence here https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/validate.py#L150
This change makes the pipeline validation consistent with training: we change the loss reduction from "mean" to "sum" in the pipeline-parallel case, since the loss is already rescaled by the number of microbatches internally via rescale_accumulated_loss(). It also adds comments around these changes explaining why they were made.
I have pushed a new commit with the requested changes and tested with validation enabled. When validation reports loss, it is reporting the value scaled by the number of gradient-accumulation steps. Thus, a loss of 4.0, with 10 steps, is reported as 0.4. I observed this issue before making any changes and had left it as something to investigate later.
Just to be sure, I switched back to a clean mainline branch and ran a pipeline test without any of my modified code. I have confirmed that the issue is present without touching anything. As expected, the original issue is also observed, with grad-norm readings below 0.1. I'll see if I can identify the source of the issue, but again, it does not appear to be related to my changes.
If you wish to reproduce, this describes the delta between the standard debug llama3 config and the configuration I am using. It should not be that difficult to map it back to the original .toml config format.
[metrics]
== super()
enable_tensorboard: true
save_tb_folder: "pp"
log_freq: 10
[training]
=> super()
steps: 500
local_batch_size: 32
global_batch_size: 320
seq_len: 512
dataset: "c4"
[validation]
== super()
freq: 100
local_batch_size: 32
steps: 14
enable: true
seq_len: 512
dataset: "c4_validation"
[parallelism]
== super()
data_parallel_shard_degree: 1
pipeline_parallel_degree: 2
pipeline_parallel_microbatch_size: 8
pipeline_parallel_layers_per_stage: 4
pipeline_parallel_schedule: "1F1B"
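For convenience, a rough guess at how that delta maps back to the debug llama3 .toml layout (section and key names are copied from the delta above; treat it as an unverified approximation, not a tested config):

[metrics]
enable_tensorboard = true
save_tb_folder = "pp"
log_freq = 10

[training]
steps = 500
local_batch_size = 32
global_batch_size = 320
seq_len = 512
dataset = "c4"

[validation]
enable = true
freq = 100
local_batch_size = 32
steps = 14
seq_len = 512
dataset = "c4_validation"

[parallelism]
data_parallel_shard_degree = 1
pipeline_parallel_degree = 2
pipeline_parallel_microbatch_size = 8
pipeline_parallel_layers_per_stage = 4
pipeline_parallel_schedule = "1F1B"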
I have kicked off a test run to log the full validation plot. It will be a little while. On a side note, I have noticed that what should be an equivalent pipeline configuration runs significantly slower in torchtitan than in my own pipeline trainer. I'm using an equivalent Llama model, although with a different implementation. I'm not sure what the key difference is, but it's probably worth looking into, because torchtitan should be able to run much faster than it currently does. Some possible differences:
Anyhow, when I have a few spare cycles, I'll see if I can figure out what the difference is.
LGTM. H100 CI failure is unrelated.
I think I know the root cause. Sadly, I can't think of a super clean fix. As a workaround, I'm OK with passing in the number of grad accum steps to
We might also need to apply the same fix here: https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/forge/example_train.py#L200
I have investigated the failed test. I have sufficient GPUs to run it and I see the same error with or without my changes. I just want to confirm whether this is a known issue with this test.
It was pointed out in the discussion for the overall fix that we had missed this file, which also needs to be changed.
Absolutely. I have made the same change to the requested file. I don't know how to immediately test this component, so please take a close look and let me know if it looks good to you.
@jdinalt H100 test failure happens on other PRs as well. And it is an async TP issue, which is not related to this PR.
I have found another copy of the pipeline code here.
I'll update this one as well. I believe that's it?
@jdinalt
Are you OK with my workaround and would you incorporate it in this PR or another one?
I'm OK with your workaround, although I believe it would be cleaner not to entangle the two changes, as they are separate issues. I'm not strictly opposed to it either. If you are OK with keeping them separate, I'll create another PR and an issue to track it; otherwise, I'll go ahead and fix it in this PR.
Let's do it in a separate PR/issue. Thanks a lot!!
Test code for reproducing the issue and testing the fix:
https://github.com/jdinalt/forgather/tree/main/examples/torchtitan/test_parallelisms