
Commit 4823344

iden-kalemaj authored and facebook-github-bot committed
Add multi_gpu test for ghost clipping (#665)
Summary: Pull Request resolved: #665

Modify the existing `multigpu_gradcheck.py` test to check gradient correctness for ghost clipping in a distributed setting.

Reviewed By: HuanyuZhang

Differential Revision: D60840755

fbshipit-source-id: 5162fde94588eec0f6e546afe1a23a370ca4a48c
1 parent f2a591a commit 4823344
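
Note on the correctness check this test performs: the DP runs use noise_multiplier=0 and an effectively infinite clipping threshold (max_grad_norm=1e8), so after a single optimizer step the net1 weights learned with DP-SGD should agree with those from a plain DDP run without DP, up to the torch.allclose tolerance asserted at the end of the diff below. The new ghost clipping path must reproduce the same weights through its rescaled-loss, two-pass backward.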

File tree: 1 file changed (+87 −28 lines)


opacus/tests/multigpu_gradcheck.py

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import itertools
 import os
 import sys
 import unittest
@@ -24,11 +25,15 @@
 import torch.optim as optim
 from opacus import PrivacyEngine
 from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
+from opacus.grad_sample import GradSampleModuleFastGradientClipping
 from opacus.optimizers.ddp_perlayeroptimizer import (
     DistributedPerLayerOptimizer,
     SimpleDistributedPerLayerOptimizer,
 )
 from opacus.optimizers.ddpoptimizer import DistributedDPOptimizer
+from opacus.optimizers.ddpoptimizer_fast_gradient_clipping import (
+    DistributedDPOptimizerFastGradientClipping,
+)
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader, TensorDataset
 from torch.utils.data.distributed import DistributedSampler
@@ -69,6 +74,45 @@ def forward(self, x):
         return self.net2(self.relu(self.net1(x)))


+def run_ghost_clipping_test(
+    model, optimizer, data_loader, batch_size, max_grad_norm, weight, rank
+):
+
+    ddp_model = DPDDP(model)
+    ddp_model = GradSampleModuleFastGradientClipping(
+        ddp_model,
+        max_grad_norm=max_grad_norm,
+        use_ghost_clipping=True,
+    )
+    optimizer = DistributedDPOptimizerFastGradientClipping(
+        optimizer,
+        noise_multiplier=0,
+        max_grad_norm=max_grad_norm,
+        expected_batch_size=batch_size,
+    )
+
+    assert isinstance(optimizer, DistributedDPOptimizerFastGradientClipping)
+
+    loss_fn = nn.CrossEntropyLoss(reduction="none")
+
+    for x, y in data_loader:
+        ddp_model.enable_hooks()
+        outputs = ddp_model(x.to(rank))
+        loss_per_sample = loss_fn(outputs, y)
+        torch.mean(loss_per_sample).backward(retain_graph=True)
+        optimizer.zero_grad()
+        rescaled_loss_per_sample = ddp_model.get_coeff() * loss_per_sample
+        rescaled_loss = torch.sum(rescaled_loss_per_sample)
+        ddp_model.disable_hooks()
+        rescaled_loss.backward()
+        ddp_model.enable_hooks()
+        optimizer.step()
+        break
+
+    weight.copy_(model.net1.weight.data.cpu())
+    cleanup()
+
+
 def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
     torch.manual_seed(world_size)
     batch_size = 32
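
The loop in run_ghost_clipping_test above relies on a simple identity: once the per-sample clipping coefficients are fixed (here obtained from ddp_model.get_coeff() after the first backward pass), the gradient of the coefficient-weighted sum of per-sample losses equals the sum of the individually scaled per-sample gradients, so a second ordinary backward pass yields clipped gradients without ever materializing per-sample gradients. A small self-contained check of that identity, for illustration only (the coefficients below are arbitrary stand-ins, not Opacus output):

import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
x = torch.randn(4, 3)
coeffs = torch.tensor([0.5, 1.0, 0.25, 1.0])  # stand-ins for min(1, C / ||g_i||)

# what the second backward pass computes: gradient of the rescaled, summed loss
losses = (x @ w) ** 2
(coeffs * losses).sum().backward()
g_rescaled = w.grad.clone()

# reference: scale each per-sample gradient individually, then sum
g_ref = torch.zeros_like(w)
for i in range(4):
    g_i = torch.autograd.grad((x[i] @ w) ** 2, w)[0]
    g_ref += coeffs[i] * g_i

assert torch.allclose(g_rescaled, g_ref)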
@@ -79,12 +123,15 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
     model.net1.weight.data.zero_()
     optimizer = optim.SGD(model.parameters(), lr=1)

+    # create dataset
     labels = torch.randn(2 * batch_size, 5).to(rank)
     data = torch.randn(2 * batch_size, 10)
-
     dataset = TensorDataset(data, labels)

-    loss_fn = nn.MSELoss()
+    loss_fn = nn.CrossEntropyLoss()
+
+    max_grad_norm = 1e8
+
     if dp and clipping == "flat":
         ddp_model = DPDDP(model)
     else:
@@ -96,8 +143,15 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
         dataset, num_replicas=world_size, rank=rank, shuffle=False
     )
     data_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
+
+    # use a separate function for ghost clipping since the procedure has a different structure
+    if dp and clipping == "ghost":
+        run_ghost_clipping_test(
+            model, optimizer, data_loader, batch_size, max_grad_norm, weight, rank
+        )
+        return
+
     if dp:
-        max_grad_norm = 1e8
         if clipping == "per_layer":
             max_grad_norm = [max_grad_norm for _ in model.parameters()]
         ddp_model, optimizer, data_loader = privacy_engine.make_private(
@@ -141,33 +195,38 @@ def run_demo(demo_fn, weight, world_size, dp, clipping, grad_sample_mode):

 class GradientComputationTest(unittest.TestCase):
     def test_gradient_correct(self) -> None:
-        # Tests that gradient is the same with DP or with DDP
+        # Tests that gradient is the same with DP or without DDP
         n_gpus = torch.cuda.device_count()
         self.assertTrue(
             n_gpus >= 2, f"Need at least 2 gpus but was provided only {n_gpus}."
         )

-        for clipping in ["flat", "per_layer"]:
-            for grad_sample_mode in ["hooks", "ew"]:
-                weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
-
-                run_demo(
-                    demo_basic,
-                    weight_dp,
-                    2,
-                    dp=True,
-                    clipping=clipping,
-                    grad_sample_mode=grad_sample_mode,
-                )
-                run_demo(
-                    demo_basic,
-                    weight_nodp,
-                    2,
-                    dp=False,
-                    clipping=None,
-                    grad_sample_mode=None,
-                )
-
-                self.assertTrue(
-                    torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
-                )
+        clipping_grad_sample_pairs = list(
+            itertools.product(["flat", "per_layer"], ["hooks", "ew"])
+        )
+        clipping_grad_sample_pairs.append(("ghost", "ghost"))
+
+        for clipping, grad_sample_mode in clipping_grad_sample_pairs:
+
+            weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
+
+            run_demo(
+                demo_basic,
+                weight_dp,
+                2,
+                dp=True,
+                clipping=clipping,
+                grad_sample_mode=grad_sample_mode,
+            )
+            run_demo(
+                demo_basic,
+                weight_nodp,
+                2,
+                dp=False,
+                clipping=None,
+                grad_sample_mode=None,
+            )
+
+            self.assertTrue(
+                torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
+            )
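
run_demo itself is not touched by this commit; only its hunk header appears above. For orientation, a minimal sketch of how such a helper is commonly written with torch.multiprocessing — the body below is an assumption for illustration, not the file's actual implementation:

import torch.multiprocessing as mp


def run_demo(demo_fn, weight, world_size, dp, clipping, grad_sample_mode):
    # Assumed implementation: put the result tensor in shared memory so the
    # child's in-place weight.copy_(...) is visible to the parent process,
    # then spawn one worker per GPU; demo_fn receives the rank as its first argument.
    weight.share_memory_()
    mp.spawn(
        demo_fn,
        args=(weight, world_size, dp, clipping, grad_sample_mode),
        nprocs=world_size,
        join=True,
    )

With at least two GPUs present, the test can be run with the standard unittest runner, e.g. python -m unittest opacus.tests.multigpu_gradcheck.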
