Commit 4804a51

EnayatUllah authored and facebook-github-bot committed
One backward function for Ghost Clipping (#661)
Summary: Pull Request resolved: #661

Simplified training loop for ghost clipping using only one "double backward" function.

Reviewed By: HuanyuZhang
Differential Revision: D60427371
fbshipit-source-id: 73c016a31f0692adcfa3f6838e74315fbed26bb1
1 parent a059670 commit 4804a51
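
In practice, the simplified ghost-clipping training step this commit introduces collapses to a single helper call per batch. A minimal sketch, assuming a model already wrapped in GradSampleModuleFastGradientClipping, an optimizer wrapped in DPOptimizerFastGradientClipping, and a criterion that returns per-sample losses (reduction="none"); the variable names are illustrative and not part of this diff:

    # Sketch of one training step with the new double_backward helper (assumed setup).
    output = grad_sample_module(input_data)              # forward pass with hooks enabled
    loss_per_sample = criterion(output, target_data)     # per-sample loss, shape [batch_size]
    double_backward(grad_sample_module, optimizer, loss_per_sample)  # both backward passes + rescaling
    optimizer.step()                                      # add noise and apply the update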

File tree

3 files changed: +53 -14 lines changed


opacus/optimizers/ddpoptimizer_fast_gradient_clipping.py

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@

 class DistributedDPOptimizerFastGradientClipping(DPOptimizerFastGradientClipping):
     """
-    :class:`~opacus.optimizers.optimizer.DPOptimizer` compatible with
+    :class:`opacus.optimizers.optimizer.DPOptimizer` compatible with
     distributed data processing
     """

opacus/tests/grad_sample_module_fast_gradient_clipping_test.py

Lines changed: 5 additions & 13 deletions

@@ -22,6 +22,7 @@
 from hypothesis import given, settings
 from opacus.grad_sample import GradSampleModule, GradSampleModuleFastGradientClipping
 from opacus.optimizers import DPOptimizer, DPOptimizerFastGradientClipping
+from opacus.utils.fast_gradient_clipping_utils import double_backward
 from opacus.utils.per_sample_gradients_utils import clone_module
 from torch.utils.data import DataLoader, Dataset

@@ -108,7 +109,7 @@ def setUp_data_sequantial(self, size, length, dim):
     @settings(deadline=1000000)
     def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
         """
-        Tests if norm calculation is same between standard (opacus) and fast gradient clipping"
+        Tests if norm calculation is the same between standard (opacus) and fast gradient clipping"
         """
         self.length = length
         self.size = size

@@ -189,7 +190,7 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
     @settings(deadline=1000000)
     def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
         """
-        Tests if gradients are same between standard (opacus) and fast gradient clipping"
+        Tests if gradients are the same between standard (opacus) and fast gradient clipping, using double_backward function"
         """

         noise_multiplier = 0.0

@@ -237,19 +238,10 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
         ]
         flat_grads_normal = torch.cat([p.flatten() for p in all_grads_normal])

-        self.grad_sample_module.enable_hooks()
         output_gc = self.grad_sample_module(input_data)

         first_loss_per_sample = self.criterion(output_gc, target_data)
-        first_loss = torch.mean(first_loss_per_sample)
-        first_loss.backward(retain_graph=True)
-
-        optimizer_gc.zero_grad()
-        coeff = self.grad_sample_module.get_coeff()
-        second_loss_per_sample = coeff * first_loss_per_sample
-        second_loss = torch.sum(second_loss_per_sample)
-        self.grad_sample_module.disable_hooks()
-        second_loss.backward()
+        double_backward(self.grad_sample_module, optimizer_gc, first_loss_per_sample)

         all_grads_gc = [param.grad for param in self.grad_sample_module.parameters()]
         flat_grads_gc = torch.cat([p.flatten() for p in all_grads_gc])

@@ -261,5 +253,5 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
             ]
         )
         logging.info(f"Diff = {diff}")
-        msg = "FAIL: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
+        msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg
opacus/utils/fast_gradient_clipping_utils.py

Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from opacus.grad_sample.grad_sample_module_fast_gradient_clipping import (
+    GradSampleModuleFastGradientClipping,
+)
+from opacus.optimizers import DPOptimizerFastGradientClipping
+
+
+def double_backward(
+    module: GradSampleModuleFastGradientClipping,
+    optimizer: DPOptimizerFastGradientClipping,
+    loss_per_sample: torch.Tensor,
+) -> None:
+    """
+    Packages the training loop for Fast Gradient and Ghost Clipping. It does the two backward passes, as well as the loss rescaling and hook operations in between.
+    This function also works with DistributedDPOptimizer.
+
+    Args:
+        module: The DP gradient sample module to train
+        optimizer: The DP optimizer used to train the module
+        loss_per_sample: loss on each sample in the mini-batch of size [batch_size, 1]
+
+    Returns:
+        None
+    """
+
+    torch.mean(loss_per_sample).backward(retain_graph=True)
+    optimizer.zero_grad()
+    rescaled_loss_per_sample = module.get_coeff() * loss_per_sample
+    rescaled_loss = torch.sum(rescaled_loss_per_sample)
+    module.disable_hooks()
+    rescaled_loss.backward()
+    module.enable_hooks()
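
For reference, a hypothetical end-to-end usage of the new utility. The imports of GradSampleModuleFastGradientClipping, DPOptimizerFastGradientClipping, and double_backward come straight from this diff; the constructor arguments shown below (max_grad_norm, use_ghost_clipping, noise_multiplier, expected_batch_size) and the toy model are assumptions for illustration, not something this commit specifies:

    import torch
    import torch.nn as nn
    from opacus.grad_sample import GradSampleModuleFastGradientClipping
    from opacus.optimizers import DPOptimizerFastGradientClipping
    from opacus.utils.fast_gradient_clipping_utils import double_backward

    model = nn.Linear(16, 2)
    criterion = nn.CrossEntropyLoss(reduction="none")  # must yield per-sample losses

    # Assumed wrapping; check the Opacus documentation for the exact constructor arguments.
    gs_module = GradSampleModuleFastGradientClipping(
        model, max_grad_norm=1.0, use_ghost_clipping=True
    )
    dp_optimizer = DPOptimizerFastGradientClipping(
        torch.optim.SGD(model.parameters(), lr=0.1),
        noise_multiplier=1.0,
        max_grad_norm=1.0,
        expected_batch_size=8,
    )

    x, y = torch.randn(8, 16), torch.randint(0, 2, (8,))
    output = gs_module(x)
    loss_per_sample = criterion(output, y)
    double_backward(gs_module, dp_optimizer, loss_per_sample)  # two backward passes, rescaling, hook toggling
    dp_optimizer.step()  # add noise and update parameters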
