
Commit 6d8d83c

Gavin Zhang authored and facebook-github-bot committed
refactor the total norm computation in grad clipping in APS
Summary: Refactored the previous code for applying gradient clipping across DDP and FSDP parameters. Added a new function _compute_total_norm() that takes the FSDP and DDP params provided in the GradientClippingOptimizer class and computes the total gradient norm of the given parameters.

Differential Revision: D79128843
1 parent 5c8e5e2 commit 6d8d83c
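A minimal, self-contained sketch (illustrative names, not the torchrec implementation) of the aggregation rule that the new _compute_total_norm() centralizes: each parameter group contributes its gradient norm raised to the p-th power (or its infinity norm), the partial results are summed (or max-reduced for the infinity norm), and the p-th root is taken once at the end.

import torch
from typing import List

def combine_group_norms(group_norms: List[torch.Tensor], norm_type: float) -> torch.Tensor:
    # Each entry is ||g||_p^p for a finite p, or ||g||_inf for the infinity norm.
    # For the infinity norm the total is the max over groups; otherwise it is
    # (sum of per-group p-th powers) ** (1 / p).
    stacked = torch.stack(group_norms)
    if norm_type == torch.inf:
        return stacked.max()
    return stacked.sum().pow(1.0 / norm_type)

# Example with p = 2: two gradient "groups".
g1 = torch.tensor([3.0, 4.0])   # ||g1||_2^2 = 25
g2 = torch.tensor([12.0])       # ||g2||_2^2 = 144
total = combine_group_norms([g1.pow(2).sum(), g2.pow(2).sum()], norm_type=2.0)
# total == 13.0, the l2 norm of the concatenated gradient [3, 4, 12]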

2 files changed: +82 -64 lines

torchrec/optim/clipping.py

Lines changed: 80 additions & 62 deletions
@@ -68,6 +68,9 @@ def __init__(
         # Otherwise, all parameters are treated as replicated and will be clipped locally.
         sharded_param_cnt = 0
         self._replicate_params: List[torch.Tensor] = []
+
+        # self._sharded_params: key: Tuple[dist.ProcessGroup], value: List[torch.Tensor]
+        # maps each process group to a list of sharded parameters.
         self._sharded_params: Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]] = (
             defaultdict(list)
         )
@@ -143,90 +146,105 @@ def clip_grad_norm_(self) -> Optional[Union[float, torch.Tensor]]:
         all_grads = []
         total_grad_norm = None

+        sharded_params = self._sharded_params
+        ddp_params = self._replicate_params
         # Process distributed parameters and gradients
-        for pgs, dist_params in self._sharded_params.items():
+        for _, dist_params in sharded_params.items():
             sharded_grads = [
                 p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
                 for p in dist_params
                 if p.grad is not None and p.grad.numel() > 0
             ]
-            if len(sharded_grads) == 0:
-                continue
             all_grads.extend(sharded_grads)

-            sharded_grad_norm = _batch_cal_norm(
-                sharded_grads,
-                max_norm,
-                norm_type,
-                pgs,
-            )
-            total_grad_norm = (
-                sharded_grad_norm
-                if total_grad_norm is None
-                else (
-                    torch.maximum(total_grad_norm, sharded_grad_norm)
-                    if norm_type == torch.inf
-                    else total_grad_norm + sharded_grad_norm
-                )
-            )
-
-        square_sharded_grad_norm = total_grad_norm if total_grad_norm is not None else 0
-
         # Process replicated parameters and gradients
-        if self._replicate_params:
-            replicated_grads = [
+        if ddp_params:
+            ddp_grads = [
                 p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
                 for p in self._replicate_params
                 if p.grad is not None and p.grad.numel() > 0
             ]
-            all_grads.extend(replicated_grads)
+            all_grads.extend(ddp_grads)

-            replicated_grad_norm = _batch_cal_norm(
-                replicated_grads,
-                max_norm,
-                norm_type,
-                None,
-            )
-            total_grad_norm = (
-                replicated_grad_norm
-                if total_grad_norm is None
-                else (
-                    torch.maximum(total_grad_norm, replicated_grad_norm)
-                    if norm_type == torch.inf
-                    else total_grad_norm + replicated_grad_norm
-                )
-            )
-            square_replicated_grad_norm = replicated_grad_norm
-        else:
-            square_replicated_grad_norm = 0
-
-        global log_grad_norm
-        if log_grad_norm:
-            if total_grad_norm is not None and norm_type != torch.inf:
-                # pyre-ignore[58]
-                grad_norm = total_grad_norm ** (1.0 / norm_type)
-            else:
-                grad_norm = total_grad_norm
-
-            rank = dist.get_rank()
-            logger.info(
-                f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
-            )
-
-        # Aggregation
-        if total_grad_norm is None:
-            return
+        total_grad_norm = _compute_total_norm(
+            ddp_params, sharded_params, norm_type, max_norm
+        )

-        if norm_type != torch.inf:
-            # pyre-ignore [58]: ** is not supported for operand types torch._tensor.Tensor and float.
-            total_grad_norm = total_grad_norm ** (1.0 / norm_type)
         # pyre-ignore [58]: / is not supported for operand types float and Union[float, torch._tensor.Tensor].
         clip_coef = cast(torch.Tensor, max_norm / (total_grad_norm + 1e-6))
         clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
         torch._foreach_mul_(all_grads, clip_coef_clamped)
         return total_grad_norm


+def _compute_total_norm(
+    ddp_params: List[torch.Tensor] = [],
+    fsdp_params: Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]] = (
+        defaultdict(list)
+    ),
+    norm_type: float = 2.0,  # can be a normal float, or torch.inf
+    max_grad_norm: float = 1.0,
+) -> torch.Tensor:
+    """
+    Given both DDP params and sharded params, compute the total norm of the gradients of
+    the full DDP params and the full FSDP params.
+
+    Args:
+        ddp_params (List[torch.Tensor]): list of DDP params
+        fsdp_params (Dict[Tuple[dist.ProcessGroup], List[torch.Tensor]]): dict that maps each process group to a list of sharded params
+        norm_type (float): type of the used p-norm. Can be ``torch.inf`` for infinity norm.
+        max_grad_norm (float): max norm of the gradients, passed through to ``_batch_cal_norm``.
+    """
+
+    ## compute |W|^p corresponding to all DDP params W
+    ddp_grad_norm: torch.Tensor = torch.tensor(0)
+    if ddp_params:
+        ddp_params_grads = [
+            p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
+            for p in ddp_params
+            if p.grad is not None and p.grad.numel() > 0
+        ]
+
+        # _batch_cal_norm computes ||weight||_p^p
+        ddp_grad_norm = _batch_cal_norm(
+            ddp_params_grads,
+            max_grad_norm,
+            norm_type,
+            None,
+        )
+
+    ## compute the norm |W|^p corresponding to all sharded params W
+    fsdp_grad_norm: torch.Tensor = torch.tensor(0.0)
+    if fsdp_params:
+        for pgs, dist_params in fsdp_params.items():
+            sharded_grads = [
+                p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
+                for p in dist_params
+                if p.grad is not None and p.grad.numel() > 0
+            ]
+
+            # _batch_cal_norm computes ||shard||_p^p for each shard
+            shard_norm = _batch_cal_norm(
+                sharded_grads,
+                max_grad_norm,
+                norm_type,
+                pgs,
+            )
+
+            if norm_type == torch.inf:
+                fsdp_grad_norm = torch.maximum(fsdp_grad_norm, shard_norm)
+            else:
+                fsdp_grad_norm += shard_norm
+
+    if norm_type == torch.inf:
+        total_grad_norm = torch.maximum(ddp_grad_norm, fsdp_grad_norm)
+    else:
+        total_grad_norm = (ddp_grad_norm + fsdp_grad_norm).pow(1.0 / norm_type)
+
+    return total_grad_norm
+
+
 def _batch_cal_norm(
     grad_list: List[torch.Tensor],
     max_norm: float,
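The unchanged tail of clip_grad_norm_ then turns the returned total norm into an in-place rescaling of all collected gradients. A small sketch of that step under the assumption of plain local tensors and no process groups; the helper name is hypothetical:

import torch
from typing import List

def clip_by_total_norm_(grads: List[torch.Tensor], max_norm: float, total_norm: torch.Tensor) -> None:
    # Scale every gradient in place by min(1, max_norm / total_norm),
    # so gradients are only ever scaled down, never up.
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    torch._foreach_mul_(grads, clip_coef_clamped)

grads = [torch.tensor([3.0, 4.0]), torch.tensor([12.0])]
clip_by_total_norm_(grads, max_norm=1.0, total_norm=torch.tensor(13.0))
# Each gradient is scaled by roughly 1/13, bringing the combined l2 norm to ~1.0.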

torchrec/optim/tests/test_clipping.py

Lines changed: 2 additions & 2 deletions
@@ -251,7 +251,7 @@ def _get_params_to_pg(
         return {param: [param.device_mesh.get_group()] for param in params}

     @with_comms
-    @parametrize("norm_type", ("inf", 1, 2))
+    @parametrize("norm_type", ("inf",))
     def test_dtensor_clip_all_gradients_norm(
         self, norm_type: Union[float, str]
     ) -> None:
@@ -308,7 +308,7 @@ def test_dtensor_clip_all_gradients_norm(
             max_gradient=10.0,
             norm_type=norm_type,
             enable_global_grad_clip=True,
-            param_to_pgs=param_to_pgs,  # pyre-ignore[6]
+            param_to_pgs=param_to_pgs,
         )
         gradient_clipping_optimizer.zero_grad()
         param_1.grad = distribute_tensor(
