 # See the License for the specific language governing permissions and
 # limitations under the License.

+import itertools
 import os
 import sys
 import unittest
 import torch.optim as optim
 from opacus import PrivacyEngine
 from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
+from opacus.grad_sample import GradSampleModuleFastGradientClipping
 from opacus.optimizers.ddp_perlayeroptimizer import (
     DistributedPerLayerOptimizer,
     SimpleDistributedPerLayerOptimizer,
 )
 from opacus.optimizers.ddpoptimizer import DistributedDPOptimizer
+from opacus.optimizers.ddpoptimizer_fast_gradient_clipping import (
+    DistributedDPOptimizerFastGradientClipping,
+)
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader, TensorDataset
 from torch.utils.data.distributed import DistributedSampler
@@ -69,6 +74,45 @@ def forward(self, x):
         return self.net2(self.relu(self.net1(x)))


+def run_ghost_clipping_test(
+    model, optimizer, data_loader, batch_size, max_grad_norm, weight, rank
+):
+
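+    # wrap the model for distributed training and per-sample gradient clipping,
+    # with the memory-efficient ghost clipping method enabled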
+    ddp_model = DPDDP(model)
+    ddp_model = GradSampleModuleFastGradientClipping(
+        ddp_model,
+        max_grad_norm=max_grad_norm,
+        use_ghost_clipping=True,
+    )
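+    # wrap the plain optimizer with the distributed DP optimizer for fast
+    # gradient clipping; noise_multiplier=0 keeps the update deterministic so
+    # the resulting weights can be compared against the non-private run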
+    optimizer = DistributedDPOptimizerFastGradientClipping(
+        optimizer,
+        noise_multiplier=0,
+        max_grad_norm=max_grad_norm,
+        expected_batch_size=batch_size,
+    )
+
+    assert isinstance(optimizer, DistributedDPOptimizerFastGradientClipping)
+
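+    # per-sample losses (reduction="none") are required so each sample's loss
+    # can be rescaled by its clipping coefficient below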
+    loss_fn = nn.CrossEntropyLoss(reduction="none")
+
+    for x, y in data_loader:
+        ddp_model.enable_hooks()
+        outputs = ddp_model(x.to(rank))
+        loss_per_sample = loss_fn(outputs, y)
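+        # first backward pass (hooks enabled): the grad sample hooks compute
+        # per-sample gradient norms, from which the clipping coefficients are derived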
+        torch.mean(loss_per_sample).backward(retain_graph=True)
+        optimizer.zero_grad()
+        rescaled_loss_per_sample = ddp_model.get_coeff() * loss_per_sample
+        rescaled_loss = torch.sum(rescaled_loss_per_sample)
+        ddp_model.disable_hooks()
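+        # second backward pass (hooks disabled): backpropagating the rescaled
+        # loss accumulates the clipped gradients into .grad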
+        rescaled_loss.backward()
+        ddp_model.enable_hooks()
+        optimizer.step()
+        break
+
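+    # copy the updated first-layer weight back to the caller so the test can
+    # compare it with the non-private run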
+    weight.copy_(model.net1.weight.data.cpu())
+    cleanup()
+
+
 def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
     torch.manual_seed(world_size)
     batch_size = 32
@@ -79,12 +123,15 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
     model.net1.weight.data.zero_()
     optimizer = optim.SGD(model.parameters(), lr=1)

+    # create dataset
     labels = torch.randn(2 * batch_size, 5).to(rank)
     data = torch.randn(2 * batch_size, 10)
-
     dataset = TensorDataset(data, labels)

-    loss_fn = nn.MSELoss()
+    loss_fn = nn.CrossEntropyLoss()
+
+ max_grad_norm = 1e8
134
+
88
135
if dp and clipping == "flat" :
89
136
ddp_model = DPDDP (model )
90
137
else :
@@ -96,8 +143,15 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
96
143
dataset , num_replicas = world_size , rank = rank , shuffle = False
97
144
)
98
145
data_loader = DataLoader (dataset , batch_size = batch_size , sampler = sampler )
146
+
147
+ # use a separate function for ghost clipping since the procedure has a different structure
148
+ if dp and clipping == "ghost" :
149
+ run_ghost_clipping_test (
150
+ model , optimizer , data_loader , batch_size , max_grad_norm , weight , rank
151
+ )
152
+ return
153
+
99
154
if dp :
100
- max_grad_norm = 1e8
101
155
if clipping == "per_layer" :
102
156
max_grad_norm = [max_grad_norm for _ in model .parameters ()]
103
157
ddp_model , optimizer , data_loader = privacy_engine .make_private (
@@ -141,33 +195,38 @@ def run_demo(demo_fn, weight, world_size, dp, clipping, grad_sample_mode):

 class GradientComputationTest(unittest.TestCase):
     def test_gradient_correct(self) -> None:
-        # Tests that gradient is the same with DP or with DDP
+        # Tests that gradient is the same with DP or without DP
         n_gpus = torch.cuda.device_count()
         self.assertTrue(
             n_gpus >= 2, f"Need at least 2 gpus but was provided only {n_gpus}."
         )

-        for clipping in ["flat", "per_layer"]:
-            for grad_sample_mode in ["hooks", "ew"]:
-                weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
-
-                run_demo(
-                    demo_basic,
-                    weight_dp,
-                    2,
-                    dp=True,
-                    clipping=clipping,
-                    grad_sample_mode=grad_sample_mode,
-                )
-                run_demo(
-                    demo_basic,
-                    weight_nodp,
-                    2,
-                    dp=False,
-                    clipping=None,
-                    grad_sample_mode=None,
-                )
-
-                self.assertTrue(
-                    torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
-                )
+        clipping_grad_sample_pairs = list(
+            itertools.product(["flat", "per_layer"], ["hooks", "ew"])
+        )
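+        # ghost clipping is exercised separately with its own grad sample mode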
+        clipping_grad_sample_pairs.append(("ghost", "ghost"))
+
+        for clipping, grad_sample_mode in clipping_grad_sample_pairs:
+
+            weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
+
+            run_demo(
+                demo_basic,
+                weight_dp,
+                2,
+                dp=True,
+                clipping=clipping,
+                grad_sample_mode=grad_sample_mode,
+            )
+            run_demo(
+                demo_basic,
+                weight_nodp,
+                2,
+                dp=False,
+                clipping=None,
+                grad_sample_mode=None,
+            )
+
+            self.assertTrue(
+                torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
+            )