
Commit aea78b3

HuanyuZhang authored and facebook-github-bot committed
Fix the initialization function of "GradSampleModuleFastGradientClipping" (#675)
Summary:
Pull Request resolved: #675

``GradSampleModuleFastGradientClipping`` did not correctly accept ``strict`` and ``force_functorch`` in its initialization function: both arguments were taken but never forwarded to the parent class. This fix forwards them so that their values can actually be changed.

Reviewed By: iden-kalemaj

Differential Revision: D62676700

fbshipit-source-id: 6df643fb5e9ea47fe91490eeb01c32bd4ed8d743
1 parent a246aa6 commit aea78b3
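In user terms, the bug meant both flags were accepted and then silently dropped before reaching ``GradSampleModule.__init__``. A minimal sketch of the symptom (toy model; the import path and keyword defaults are inferred from the file layout and diff below, not guaranteed):

    import torch.nn as nn
    from opacus.grad_sample import GradSampleModuleFastGradientClipping

    model = nn.Linear(4, 2)

    # Before this commit, the last two keyword arguments were accepted
    # but never forwarded to GradSampleModule.__init__, so the parent
    # defaults silently applied instead.
    wrapped = GradSampleModuleFastGradientClipping(
        model,
        batch_first=True,
        loss_reduction="mean",
        max_grad_norm=1.0,
        strict=False,          # ignored before the fix
        force_functorch=True,  # ignored before the fix
    )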

File tree

2 files changed: +7 −6 lines

opacus/grad_sample/grad_sample_module.py

Lines changed: 4 additions & 5 deletions

@@ -105,10 +105,9 @@ def __init__(
                 ``[K, batch_size, ...]``
             loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
                 is a sum or a mean operation. Can take values "sum" or "mean"
-            strict: If set to ``True``, the input module will be validated to check that
-                ``GradSampleModule`` has grad sampler functions for all submodules of
-                the input module (i.e. if it knows how to calculate per sample gradients)
-                for all model parameters. If set to ``False``, per sample gradients will
+            strict: If set to ``True``, the input module will be validated to make sure that none of its submodules includes buffers,
+                which is not currently supported by Opacus.
+                If set to ``False``, per sample gradients will
                 be computed on "best effort" basis - they will be available where
                 possible and set to None otherwise. This is not recommended, because
                 some unsupported modules (e.g. BatchNorm) affect other parameters and
@@ -120,7 +119,7 @@ def __init__(
         Raises:
             NotImplementedError
                 If ``strict`` is set to ``True`` and module ``m`` (or any of its
-                submodules) doesn't have a registered grad sampler function.
+                submodules) includes a buffer.
         """
         super().__init__(
             m,
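Note that this docstring change redefines what ``strict`` guards: buffers in submodules rather than missing grad sampler functions. A sketch of the documented behavior, assuming ``nn.BatchNorm1d`` (which registers ``running_mean``/``running_var`` buffers) as the offending submodule:

    import torch.nn as nn
    from opacus.grad_sample import GradSampleModule

    # nn.BatchNorm1d registers buffers, which strict=True validation
    # rejects per the updated docstring.
    model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))
    try:
        GradSampleModule(model, strict=True)
    except NotImplementedError as err:
        print(f"Rejected as documented: {err}")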

opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py

Lines changed: 3 additions & 1 deletion

@@ -107,13 +107,15 @@ def __init__(
         Raises:
             NotImplementedError
                 If ``strict`` is set to ``True`` and module ``m`` (or any of its
-                submodules) doesn't have a registered grad sampler function.
+                submodules) includes a buffer.
         """

         super().__init__(
             m,
             batch_first=batch_first,
             loss_reduction=loss_reduction,
+            strict=strict,
+            force_functorch=force_functorch,
         )
         self.trainable_parameters = [p for _, p in trainable_parameters(self._module)]
         self.max_grad_norm = max_grad_norm
