
Commit f1412fa

HuanyuZhang authored and facebook-github-bot committed
Opacus release v1.5.2 (#663)
Summary: Pull Request resolved: #663. Release a new version of Opacus. Furthermore, we replace "opt_einsum.contract" with "torch.einsum" to avoid errors when "opt_einsum" is not available. This does not hurt performance, since torch automatically switches to "opt_einsum" for acceleration when the package is available (https://pytorch.org/docs/stable/generated/torch.einsum.html). Code pointer: https://pytorch.org/docs/stable/_modules/torch/backends/opt_einsum.html#is_available

Reviewed By: EnayatUllah

Differential Revision: D60672828

fbshipit-source-id: f8bbc0aa404e48f15ce129689a6e55af68daa5e4
1 parent eb94674 commit f1412fa
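The dispatch behavior described in the summary can be sanity-checked directly. The sketch below is not part of the commit; it assumes a recent PyTorch release that ships the `torch.backends.opt_einsum` module referenced in the code pointer above:

```python
import torch

# torch.einsum uses opt_einsum to optimize the contraction path when the
# package is installed; otherwise it falls back to its default left-to-right
# contraction order, so the same call works either way.
print(torch.backends.opt_einsum.is_available())  # True iff opt_einsum is importable
print(torch.backends.opt_einsum.enabled)         # path-optimization toggle

backprops = torch.randn(8, 16, 32)
activations = torch.randn(8, 24, 32)
out = torch.einsum("noq,npq->nop", backprops, activations)
print(out.shape)  # torch.Size([8, 16, 24])
```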

File tree: 11 files changed, +35 −32 lines


CHANGELOG.md

Lines changed: 10 additions & 0 deletions

@@ -1,5 +1,15 @@
 # Changelog
 
+## v1.5.2
+
+### New features
+* Add a function of "double_backward" simplifying the training loop (#661)
+
+### Bug fixes
+* Fix issue with setting of param_group for the DPOptimizer wrapper (issue 649) (#660)
+* Fix issue of DDP optimizer for FGC. The step function incorrectly called "original_optimizer.original_optimizer" (#662)
+* Replace "opt_einsum.contract" by "torch.einsum"(#663)
+
 ## v1.5.1
 
 ### Bug fixes

opacus/grad_sample/conv.py

Lines changed: 2 additions & 3 deletions

@@ -21,7 +21,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from opacus.utils.tensor_utils import unfold2d, unfold3d
-from opt_einsum import contract
 
 from .utils import register_grad_sampler
 
@@ -90,7 +89,7 @@ def compute_conv_grad_sample(
     ret = {}
     if layer.weight.requires_grad:
         # n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz
-        grad_sample = contract("noq,npq->nop", backprops, activations)
+        grad_sample = torch.einsum("noq,npq->nop", backprops, activations)
         # rearrange the above tensor and extract diagonals.
         grad_sample = grad_sample.view(
             n,
@@ -100,7 +99,7 @@ def compute_conv_grad_sample(
             int(layer.in_channels / layer.groups),
             np.prod(layer.kernel_size),
         )
-        grad_sample = contract("ngrg...->ngr...", grad_sample).contiguous()
+        grad_sample = torch.einsum("ngrg...->ngr...", grad_sample).contiguous()
         shape = [n] + list(layer.weight.shape)
         ret[layer.weight] = grad_sample.view(shape)
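As an aside (not part of the diff), the two einsum patterns used here reduce to familiar tensor operations; the sketch below uses invented shapes purely to illustrate the contraction semantics:

```python
import torch

n, o, p, q = 4, 6, 10, 25  # batch, out-channels, (in-channels/groups)*kernel, positions
backprops = torch.randn(n, o, q)
activations = torch.randn(n, p, q)

# "noq,npq->nop": contract over the spatial axis q, i.e. a per-sample
# matrix product backprops @ activations^T.
gs = torch.einsum("noq,npq->nop", backprops, activations)
assert torch.allclose(gs, torch.bmm(backprops, activations.transpose(1, 2)), atol=1e-5)

# "ngrg...->ngr...": take the diagonal over the two group axes, keeping
# only blocks where the output group matches the input group.
g, r, k = 2, 3, 5
x = torch.randn(n, g, r, g, k)
diag = torch.einsum("ngrg...->ngr...", x)
manual = torch.stack([x[:, i, :, i, :] for i in range(g)], dim=1)
assert torch.allclose(diag, manual)
```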

opacus/grad_sample/dp_rnn.py

Lines changed: 2 additions & 3 deletions

@@ -19,7 +19,6 @@
 import torch
 import torch.nn as nn
 from opacus.layers.dp_rnn import RNNLinear
-from opt_einsum import contract
 
 from .utils import register_grad_sampler
 
@@ -42,8 +41,8 @@ def compute_rnn_linear_grad_sample(
     activations = activations[0]
     ret = {}
     if layer.weight.requires_grad:
-        gs = contract("n...i,n...j->nij", backprops, activations)
+        gs = torch.einsum("n...i,n...j->nij", backprops, activations)
         ret[layer.weight] = gs
     if layer.bias is not None and layer.bias.requires_grad:
-        ret[layer.bias] = contract("n...k->nk", backprops)
+        ret[layer.bias] = torch.einsum("n...k->nk", backprops)
     return ret
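The same substitution appears here; for reference (again with invented shapes), the two patterns are a per-sample outer product accumulated over time and a per-sample sum over time:

```python
import torch

n, t, d_out, d_in = 4, 7, 5, 3  # batch, time steps, output dim, input dim (illustrative)
backprops = torch.randn(n, t, d_out)
activations = torch.randn(n, t, d_in)

# "n...i,n...j->nij": sum of outer products g_t a_t^T over the broadcast
# (time) dimensions -- one weight gradient per sample.
gs = torch.einsum("n...i,n...j->nij", backprops, activations)
assert torch.allclose(gs, torch.bmm(backprops.transpose(1, 2), activations), atol=1e-5)

# "n...k->nk": per-sample bias gradient, i.e. backprops summed over time.
assert torch.allclose(torch.einsum("n...k->nk", backprops), backprops.sum(dim=1), atol=1e-5)
```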

opacus/grad_sample/group_norm.py

Lines changed: 2 additions & 3 deletions

@@ -19,7 +19,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from opt_einsum import contract
 
 from .utils import register_grad_sampler
 
@@ -42,7 +41,7 @@ def compute_group_norm_grad_sample(
     ret = {}
     if layer.weight.requires_grad:
         gs = F.group_norm(activations, layer.num_groups, eps=layer.eps) * backprops
-        ret[layer.weight] = contract("ni...->ni", gs)
+        ret[layer.weight] = torch.einsum("ni...->ni", gs)
     if layer.bias is not None and layer.bias.requires_grad:
-        ret[layer.bias] = contract("ni...->ni", backprops)
+        ret[layer.bias] = torch.einsum("ni...->ni", backprops)
     return ret
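For completeness, `"ni...->ni"` simply sums every dimension after the channel axis; a quick check with arbitrary shapes (not from the diff):

```python
import torch

gs = torch.randn(4, 8, 5, 5)  # (batch, channels, H, W), shapes invented for the sketch
reduced = torch.einsum("ni...->ni", gs)
assert torch.allclose(reduced, gs.flatten(2).sum(-1), atol=1e-5)
```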

opacus/grad_sample/instance_norm.py

Lines changed: 2 additions & 3 deletions

@@ -18,7 +18,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from opt_einsum import contract
 
 from .utils import register_grad_sampler
 
@@ -51,7 +50,7 @@ def compute_instance_norm_grad_sample(
     ret = {}
     if layer.weight.requires_grad:
         gs = F.instance_norm(activations, eps=layer.eps) * backprops
-        ret[layer.weight] = contract("ni...->ni", gs)
+        ret[layer.weight] = torch.einsum("ni...->ni", gs)
     if layer.bias is not None and layer.bias.requires_grad:
-        ret[layer.bias] = contract("ni...->ni", backprops)
+        ret[layer.bias] = torch.einsum("ni...->ni", backprops)
     return ret

opacus/grad_sample/linear.py

Lines changed: 12 additions & 11 deletions

@@ -18,7 +18,6 @@
 
 import torch
 import torch.nn as nn
-from opt_einsum.contract import contract
 
 from .utils import register_grad_sampler, register_norm_sampler
 
@@ -42,10 +41,10 @@ def compute_linear_grad_sample(
     activations = activations[0]
     ret = {}
     if layer.weight.requires_grad:
-        gs = contract("n...i,n...j->nij", backprops, activations)
+        gs = torch.einsum("n...i,n...j->nij", backprops, activations)
         ret[layer.weight] = gs
     if layer.bias is not None and layer.bias.requires_grad:
-        ret[layer.bias] = contract("n...k->nk", backprops)
+        ret[layer.bias] = torch.einsum("n...k->nk", backprops)
     return ret
 
 
@@ -66,23 +65,25 @@ def compute_linear_norm_sample(
 
     if backprops.dim() == 2:
         if layer.weight.requires_grad:
-            g = contract("n...i,n...i->n", backprops, backprops)
-            a = contract("n...j,n...j->n", activations, activations)
+            g = torch.einsum("n...i,n...i->n", backprops, backprops)
+            a = torch.einsum("n...j,n...j->n", activations, activations)
             ret[layer.weight] = torch.sqrt((g * a).flatten())
         if layer.bias is not None and layer.bias.requires_grad:
             ret[layer.bias] = torch.sqrt(
-                contract("n...i,n...i->n", backprops, backprops).flatten()
+                torch.einsum("n...i,n...i->n", backprops, backprops).flatten()
             )
     elif backprops.dim() == 3:
         if layer.weight.requires_grad:
 
-            ggT = contract("nik,njk->nij", backprops, backprops)  # batchwise g g^T
-            aaT = contract("nik,njk->nij", activations, activations)  # batchwise a a^T
-            ga = contract("n...i,n...i->n", ggT, aaT).clamp(min=0)
+            ggT = torch.einsum("nik,njk->nij", backprops, backprops)  # batchwise g g^T
+            aaT = torch.einsum(
+                "nik,njk->nij", activations, activations
+            )  # batchwise a a^T
+            ga = torch.einsum("n...i,n...i->n", ggT, aaT).clamp(min=0)
 
             ret[layer.weight] = torch.sqrt(ga)
         if layer.bias is not None and layer.bias.requires_grad:
-            ggT = contract("nik,njk->nij", backprops, backprops)
-            gg = contract("n...i,n...i->n", ggT, ggT).clamp(min=0)
+            ggT = torch.einsum("nik,njk->nij", backprops, backprops)
+            gg = torch.einsum("n...i,n...i->n", ggT, ggT).clamp(min=0)
             ret[layer.bias] = torch.sqrt(gg)
     return ret
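The norm sampler's 3-D branch relies on the identity that the squared Frobenius norm of the per-sample gradient `sum_t g_t a_t^T` equals the elementwise inner product of the Gram matrices `ggT` and `aaT`, so per-sample norms can be computed without materializing per-sample gradients. A numerical check of that identity, with shapes invented for the sketch:

```python
import torch

n, t, d_out, d_in = 3, 6, 4, 5  # batch, sequence length, output dim, input dim
backprops = torch.randn(n, t, d_out)
activations = torch.randn(n, t, d_in)

# Direct route: materialize the per-sample weight gradients and take norms.
grad_sample = torch.einsum("nti,ntj->nij", backprops, activations)
direct_norm = grad_sample.flatten(1).norm(2, dim=1)

# Ghost route: only the two (n, t, t) Gram matrices are formed; their
# elementwise inner product equals ||grad_sample||_F^2 for each sample.
ggT = torch.einsum("nik,njk->nij", backprops, backprops)
aaT = torch.einsum("nik,njk->nij", activations, activations)
ghost_norm = torch.sqrt(torch.einsum("n...i,n...i->n", ggT, aaT).clamp(min=0))

assert torch.allclose(direct_norm, ghost_norm, atol=1e-4)
```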

opacus/optimizers/adaclipoptimizer.py

Lines changed: 1 addition & 2 deletions

@@ -18,7 +18,6 @@
 from typing import Callable, Optional
 
 import torch
-from opt_einsum import contract
 from torch.optim import Optimizer
 
 from .optimizer import (
@@ -107,7 +106,7 @@ def clip_and_accumulate(self):
         for p in self.params:
             _check_processed_flag(p.grad_sample)
             grad_sample = self._get_flat_grad_sample(p)
-            grad = contract("i,i...", per_sample_clip_factor, grad_sample)
+            grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)
 
             if p.summed_grad is not None:
                 p.summed_grad += grad

opacus/optimizers/ddp_perlayeroptimizer.py

Lines changed: 1 addition & 2 deletions

@@ -18,7 +18,6 @@
 from typing import Callable, List, Optional
 
 import torch
-from opt_einsum import contract
 from torch import nn
 from torch.optim import Optimizer
 
@@ -31,7 +30,7 @@ def _clip_and_accumulate_parameter(p: nn.Parameter, max_grad_norm: float):
     per_sample_norms = p.grad_sample.view(len(p.grad_sample), -1).norm(2, dim=-1)
     per_sample_clip_factor = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(max=1.0)
 
-    grad = contract("i,i...", per_sample_clip_factor, p.grad_sample)
+    grad = torch.einsum("i,i...", per_sample_clip_factor, p.grad_sample)
     if p.summed_grad is not None:
         p.summed_grad += grad
     else:

opacus/optimizers/optimizer.py

Lines changed: 1 addition & 2 deletions

@@ -20,7 +20,6 @@
 
 import torch
 from opacus.optimizers.utils import params
-from opt_einsum.contract import contract
 from torch import nn
 from torch.optim import Optimizer
 
@@ -450,7 +449,7 @@ def clip_and_accumulate(self):
         for p in self.params:
             _check_processed_flag(p.grad_sample)
             grad_sample = self._get_flat_grad_sample(p)
-            grad = contract("i,i...", per_sample_clip_factor, grad_sample)
+            grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)
 
             if p.summed_grad is not None:
                 p.summed_grad += grad
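In `clip_and_accumulate`, the `"i,i..."` contraction weights each per-sample gradient by its clip factor and sums over the batch axis. A sketch of the equivalence, with invented shapes and random clip factors standing in for the real ones:

```python
import torch

n = 8                                   # batch size (illustrative)
grad_sample = torch.randn(n, 16, 32)    # per-sample gradients for one parameter
per_sample_clip_factor = torch.rand(n)  # stand-in clip factors in [0, 1)

# "i,i...": scale each sample's gradient, then sum over the batch axis.
grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)
manual = (per_sample_clip_factor.view(n, 1, 1) * grad_sample).sum(dim=0)
assert torch.allclose(grad, manual, atol=1e-5)
```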

opacus/optimizers/perlayeroptimizer.py

Lines changed: 1 addition & 2 deletions

@@ -18,7 +18,6 @@
 
 import torch
 from opacus.optimizers.utils import params
-from opt_einsum import contract
 from torch.optim import Optimizer
 
 from .optimizer import DPOptimizer, _check_processed_flag, _mark_as_processed
@@ -65,7 +64,7 @@ def clip_and_accumulate(self):
             per_sample_clip_factor = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(
                 max=1.0
             )
-            grad = contract("i,i...", per_sample_clip_factor, grad_sample)
+            grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)
 
             if p.summed_grad is not None:
                 p.summed_grad += grad
