
Commit 36f7a34

EnayatUllah authored and facebook-github-bot committed
Towards making the interface of ghost clipping the same as that of PyTorch (#668)

Summary:
Pull Request resolved: #668

We define two classes, DPLossFastGradientClipping and DPTensorFastGradientClipping, in the utils file (opacus/utils/fast_gradient_clipping_utils.py). They allow us to repurpose loss.backward() to perform the two backward passes and the loss scaling required to implement ghost clipping.

Reviewed By: HuanyuZhang

Differential Revision: D61162530

fbshipit-source-id: 9b832469e1645513a13e1c962a13500169a3806b
1 parent 27e6a1d commit 36f7a34
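
To make the new interface concrete, here is a minimal, hypothetical training-loop sketch of the ghost clipping flow this commit moves toward; the toy model, data, and hyperparameter values are illustrative placeholders, not code from the commit.

# Illustrative sketch only; assumes the ghost clipping path wired up in this commit.
import torch
import torch.nn as nn
from opacus import PrivacyEngine
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
data_loader = DataLoader(
    TensorDataset(torch.randn(32, 16), torch.randint(0, 2, (32,))), batch_size=8
)

privacy_engine = PrivacyEngine()
# With grad_sample_mode="ghost", make_private now wraps and returns the criterion as well.
model, optimizer, criterion, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    criterion=criterion,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="ghost",
)

for x, y in data_loader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)  # a DPTensorFastGradientClipping, not a plain tensor
    loss.backward()                # runs both backward passes and the loss rescaling
    optimizer.step()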

File tree

6 files changed (+141, -42 lines)


opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py

Lines changed: 17 additions & 8 deletions
@@ -47,12 +47,21 @@ def create_norm_sample(
     """
 
     if param.requires_grad:
-        param._norm_sample = torch.zeros(
-            torch.Size([max_batch_len, 1]),
-            device=grad_sample.device,
-            dtype=grad_sample.dtype,
-        )
-        param._norm_sample = grad_sample.reshape(len(grad_sample), -1).norm(2, dim=-1)
+        if (
+            max_batch_len == 0
+        ):  # To handle the case of empty batch that may arise from Poisson sampling
+            param._norm_sample = torch.tensor(
+                [], device=grad_sample.device, dtype=grad_sample.dtype
+            )
+        else:
+            param._norm_sample = torch.zeros(
+                torch.Size([max_batch_len, 1]),
+                device=grad_sample.device,
+                dtype=grad_sample.dtype,
+            )
+            param._norm_sample = grad_sample.reshape(len(grad_sample), -1).norm(
+                2, dim=-1
+            )
 
 
 class GradSampleModuleFastGradientClipping(GradSampleModule):

@@ -110,7 +119,7 @@ def __init__(
         self.max_grad_norm = max_grad_norm
         self.use_ghost_clipping = use_ghost_clipping
 
-    def get_coeff(self) -> torch.Tensor:
+    def get_clipping_coef(self) -> torch.Tensor:
         """Get per-example gradient scaling factor for clipping."""
         norm_sample = self.get_norm_sample()
         return (self.max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0)

@@ -175,6 +184,7 @@ def capture_backprops_hook(
             return
 
         backprops = forward_output[0].detach()
+
        activations, backprops = self.rearrange_grad_samples(
            module=module,
            backprops=backprops,

@@ -216,7 +226,6 @@ def capture_backprops_hook(
                     max_batch_len=module.max_batch_len,
                 )
             del p.grad_sample
-
         if len(module.activations) == 0:
             if hasattr(module, "max_batch_len"):
                 del module.max_batch_len

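
As a standalone illustration of what the renamed get_clipping_coef returns (a sketch with made-up norms, not code from this commit): each per-sample loss is rescaled by min(1, C / (norm + 1e-6)), so samples whose gradient norm exceeds the clipping threshold C contribute at most C.

# Sketch with invented per-sample gradient norms; mirrors the formula in get_clipping_coef.
import torch

max_grad_norm = 1.0                           # clipping threshold C
norm_sample = torch.tensor([0.5, 2.0, 10.0])  # pretend per-sample gradient norms
coeff = (max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0)
print(coeff)  # roughly tensor([1.0, 0.5, 0.1]); large-norm samples are scaled down
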
opacus/optimizers/__init__.py

Lines changed: 10 additions & 5 deletions
@@ -39,12 +39,17 @@
 
 
 def get_optimizer_class(clipping: str, distributed: bool, grad_sample_mode: str = None):
-    if clipping == "flat" and distributed is False:
+    if grad_sample_mode == "ghost":
+        if clipping == "flat" and distributed is False:
+            return DPOptimizerFastGradientClipping
+        elif clipping == "flat" and distributed is True:
+            return DistributedDPOptimizerFastGradientClipping
+        else:
+            raise ValueError(
+                f"Unsupported combination of parameters. Clipping: {clipping} and grad_sample_mode: {grad_sample_mode}"
+            )
+    elif clipping == "flat" and distributed is False:
         return DPOptimizer
-    elif clipping == "ghost" and distributed is False:
-        return DPOptimizerFastGradientClipping
-    elif clipping == "ghost" and distributed is True:
-        return DistributedDPOptimizerFastGradientClipping
     elif clipping == "flat" and distributed is True:
         return DistributedDPOptimizer
     elif clipping == "per_layer" and distributed is False:

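
A short sketch of how the reworked dispatch behaves (illustrative calls, assuming the branch added above): the fast gradient clipping optimizers are now selected via grad_sample_mode="ghost" together with clipping="flat", rather than via a dedicated "ghost" clipping string.

# Expected return values follow the branches in get_optimizer_class above.
from opacus.optimizers import get_optimizer_class

get_optimizer_class(clipping="flat", distributed=False, grad_sample_mode="ghost")
# -> DPOptimizerFastGradientClipping
get_optimizer_class(clipping="flat", distributed=True, grad_sample_mode="ghost")
# -> DistributedDPOptimizerFastGradientClipping
get_optimizer_class(clipping="flat", distributed=False)
# -> DPOptimizer (unchanged when grad_sample_mode is not "ghost")
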
opacus/privacy_engine.py

Lines changed: 9 additions & 0 deletions
@@ -30,6 +30,7 @@
 )
 from opacus.optimizers import DPOptimizer, get_optimizer_class
 from opacus.schedulers import _GradClipScheduler, _NoiseScheduler
+from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping
 from opacus.validators.module_validator import ModuleValidator
 from torch import nn, optim
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -277,6 +278,7 @@ def make_private(
         *,
         module: nn.Module,
         optimizer: optim.Optimizer,
+        criterion=nn.CrossEntropyLoss(),  # Added default for backward compatibility
         data_loader: DataLoader,
         noise_multiplier: float,
         max_grad_norm: Union[float, List[float]],

@@ -400,6 +402,11 @@ def make_private(
         optimizer.attach_step_hook(
             self.accountant.get_optimizer_hook_fn(sample_rate=sample_rate)
         )
+        if grad_sample_mode == "ghost":
+            criterion = DPLossFastGradientClipping(
+                module, optimizer, criterion, loss_reduction
+            )
+            return module, optimizer, criterion, data_loader
 
         return module, optimizer, data_loader

@@ -408,6 +415,7 @@ def make_private_with_epsilon(
         *,
         module: nn.Module,
         optimizer: optim.Optimizer,
+        criterion=nn.CrossEntropyLoss(),  # Added default for backward compatibility
         data_loader: DataLoader,
         target_epsilon: float,
         target_delta: float,

@@ -487,6 +495,7 @@
             module=module,
             optimizer=optimizer,
             data_loader=data_loader,
+            criterion=criterion,
             noise_multiplier=get_noise_multiplier(
                 target_epsilon=target_epsilon,
                 target_delta=target_delta,

opacus/tests/grad_sample_module_fast_gradient_clipping_test.py

Lines changed: 13 additions & 8 deletions
@@ -22,7 +22,7 @@
 from hypothesis import given, settings
 from opacus.grad_sample import GradSampleModule, GradSampleModuleFastGradientClipping
 from opacus.optimizers import DPOptimizer, DPOptimizerFastGradientClipping
-from opacus.utils.fast_gradient_clipping_utils import double_backward
+from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping
 from opacus.utils.per_sample_gradients_utils import clone_module
 from torch.utils.data import DataLoader, Dataset
 

@@ -146,7 +146,7 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
         (input_data, target_data) = list(self.dl)[0]
         optimizer_normal.zero_grad()
         output_normal = self.model_normal(input_data)
-        loss_normal = torch.mean(self.criterion(output_normal, target_data))
+        loss_normal = torch.mean(self.criterion(output_normal, target_data), dim=0)
         loss_normal.backward()
         all_norms_normal = torch.stack(
             [

@@ -165,7 +165,7 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
         first_loss.backward(retain_graph=True)
 
         optimizer_gc.zero_grad()
-        coeff = self.grad_sample_module.get_coeff()
+        coeff = self.grad_sample_module.get_clipping_coef()
         second_loss_per_sample = coeff * first_loss_per_sample
         second_loss = torch.sum(second_loss_per_sample)
         self.grad_sample_module.disable_hooks()

@@ -190,7 +190,7 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
     @settings(deadline=1000000)
     def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
         """
-        Tests if gradients are the same between standard (opacus) and fast gradient clipping, using double_backward function"
+        Tests if gradients are the same between standard (opacus) and fast gradient clipping"
         """
 
         noise_multiplier = 0.0

@@ -200,7 +200,7 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
         self.dim = dim
         self.setUp_data_sequantial(self.size, self.length, self.dim)
         max_grad_norm = 1.0
-        self.criterion = torch.nn.CrossEntropyLoss(reduction="none")
+        self.criterion = torch.nn.CrossEntropyLoss()
 
         sample_module = SampleModule()
         self.model_normal = GradSampleModule(clone_module(sample_module))

@@ -226,10 +226,14 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
             expected_batch_size=batch_size,
         )
 
+        criterion_gc = DPLossFastGradientClipping(
+            self.grad_sample_module, optimizer_gc, self.criterion
+        )
+
         (input_data, target_data) = list(self.dl)[0]
         optimizer_normal.zero_grad()
         output_normal = self.model_normal(input_data)
-        loss_normal = torch.mean(self.criterion(output_normal, target_data))
+        loss_normal = torch.mean(self.criterion(output_normal, target_data), dim=0)
         loss_normal.backward()
         optimizer_normal.step()

@@ -240,8 +244,9 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
 
         output_gc = self.grad_sample_module(input_data)
 
-        first_loss_per_sample = self.criterion(output_gc, target_data)
-        double_backward(self.grad_sample_module, optimizer_gc, first_loss_per_sample)
+        loss_gc = criterion_gc(output_gc, target_data)
+        loss_gc.backward()
+        # double_backward(self.grad_sample_module, optimizer_gc, first_loss_per_sample)
 
         all_grads_gc = [param.grad for param in self.grad_sample_module.parameters()]
         flat_grads_gc = torch.cat([p.flatten() for p in all_grads_gc])

opacus/tests/multigpu_gradcheck.py

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ def run_ghost_clipping_test(
     loss_per_sample = loss_fn(outputs, y)
     torch.mean(loss_per_sample).backward(retain_graph=True)
    optimizer.zero_grad()
-    rescaled_loss_per_sample = ddp_model.get_coeff() * loss_per_sample
+    rescaled_loss_per_sample = ddp_model.get_clipping_coef() * loss_per_sample
     rescaled_loss = torch.sum(rescaled_loss_per_sample)
     ddp_model.disable_hooks()
     rescaled_loss.backward()

opacus/utils/fast_gradient_clipping_utils.py

Lines changed: 91 additions & 20 deletions
@@ -20,28 +20,99 @@
 from opacus.optimizers import DPOptimizerFastGradientClipping
 
 
-def double_backward(
-    module: GradSampleModuleFastGradientClipping,
-    optimizer: DPOptimizerFastGradientClipping,
-    loss_per_sample: torch.Tensor,
-) -> None:
+class DPTensorFastGradientClipping:
     """
-    Packages the training loop for Fast Gradient and Ghost Clipping. It does the two backward passes, as well as the loss rescaling and hook operations in between.
-    This function also works with DistributedDPOptimizer.
+    Packages the training loop for Fast Gradient and Ghost Clipping into loss.backward().
+    """
+
+    def __init__(
+        self,
+        module: GradSampleModuleFastGradientClipping,
+        optimizer: DPOptimizerFastGradientClipping,
+        loss_per_sample: torch.Tensor,
+        loss_reduction: str = "mean",
+    ):
+        """
+        Args:
+            module: the module to train
+            optimizer: the optimizer used to train the module
+            loss_per_sample: loss on each sample in the mini-batch of size [batch_size, 1]
+        """
+
+        self.module = module
+        self.optimizer = optimizer
+        self.loss_per_sample = loss_per_sample
+        self.loss_reduction = loss_reduction
+
+    def item(self):
+        if self.loss_reduction == "mean":
+            return torch.mean(self.loss_per_sample).detach().item()
+        elif self.loss_reduction == "sum":
+            return torch.sum(self.loss_per_sample).detach().item()
 
-    Args:
-        module: The DP gradient sample module to train
-        optimizer: The DP optimizer used to train the module
-        loss_per_sample: loss on each sample in the mini-batch of size [batch_size, 1]
+    def backward(self):
+        """
+        Repurposes loss.backward() to perform the two backward passes, as well as the loss rescaling and hook operations in between.
+        """
 
-    Returns:
-        None
+        if self.loss_reduction == "mean":
+            reduced_loss = torch.mean(self.loss_per_sample, dim=0)
+        elif self.loss_reduction == "sum":
+            reduced_loss = torch.sum(self.loss_per_sample, dim=0)
+        else:
+            raise ValueError(
+                f"loss_reduction = {self.loss_reduction}. Only 'sum' and 'mean' losses are supported"
+            )
+        reduced_loss.backward(retain_graph=True)
+        self.optimizer.zero_grad()
+        coeff = self.module.get_clipping_coef()
+        second_loss_per_sample = coeff * self.loss_per_sample
+        second_loss = torch.sum(second_loss_per_sample)
+        self.module.disable_hooks()
+        second_loss.backward()
+        self.module.enable_hooks()
+
+
+class DPLossFastGradientClipping:
+    """
+    Wrapper on the loss function to be used with Fast Gradient and Ghost Clipping. It computes the per-sample loss and wraps it in DPTensorFastGradientClipping.
     """
 
-    torch.mean(loss_per_sample).backward(retain_graph=True)
-    optimizer.zero_grad()
-    rescaled_loss_per_sample = module.get_coeff() * loss_per_sample
-    rescaled_loss = torch.sum(rescaled_loss_per_sample)
-    module.disable_hooks()
-    rescaled_loss.backward()
-    module.enable_hooks()
+    def __init__(
+        self,
+        module: GradSampleModuleFastGradientClipping,
+        optimizer: DPOptimizerFastGradientClipping,
+        criterion,
+        loss_reduction: str = "mean",
+    ):
+        assert loss_reduction in [
+            "mean",
+            "sum",
+        ], "loss_reduction should be either 'mean' or 'sum'"
+        assert (
+            loss_reduction
+            == criterion.reduction
+            == module.loss_reduction
+            == optimizer.loss_reduction
+        ), "loss_reduction should be the same across GradSampleModule, Optimizer, and Criterion"
+
+        self.optimizer = optimizer
+        self.module = module
+        self.criterion = criterion
+        self.loss_reduction = loss_reduction
+        self.criterion.reduction = "none"
+
+    def __call__(self, input, target) -> DPTensorFastGradientClipping:
+        """
+        Redefines the forward call to compute the per-sample loss and wrap it in DPTensorFastGradientClipping.
+        """
+
+        loss_per_sample = self.criterion(
+            input,
+            target,
+        )
+        return DPTensorFastGradientClipping(
+            self.module, self.optimizer, loss_per_sample, self.loss_reduction
+        )

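
Rounding out the diff, a minimal sketch of using the two new classes directly, without the PrivacyEngine, mirroring the updated test above; the toy module, data, and hyperparameter values are illustrative placeholders.

# Hypothetical direct usage of DPLossFastGradientClipping / DPTensorFastGradientClipping.
import torch
import torch.nn as nn
from opacus.grad_sample import GradSampleModuleFastGradientClipping
from opacus.optimizers import DPOptimizerFastGradientClipping
from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping

model = GradSampleModuleFastGradientClipping(
    nn.Linear(16, 2), max_grad_norm=1.0, use_ghost_clipping=True
)
optimizer = DPOptimizerFastGradientClipping(
    torch.optim.SGD(model.parameters(), lr=0.1),
    noise_multiplier=0.0,
    max_grad_norm=1.0,
    expected_batch_size=8,
)
criterion = DPLossFastGradientClipping(model, optimizer, nn.CrossEntropyLoss())

x, y = torch.randn(8, 16), torch.randint(0, 2, (8,))
optimizer.zero_grad()
loss = criterion(model(x), y)  # returns a DPTensorFastGradientClipping
loss.backward()                # first pass computes per-sample norms, second applies clipping
optimizer.step()
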