Skip to content

L1 Norm Logging for Activations #1362

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,16 @@ def __init__(
else 1
)

self.track_l1 = getattr(neox_args, "log_l1_norm", False)
if self.track_l1 and not hasattr(neox_args, "_l1_norm_cache"):
# one common container hangs off the config object
neox_args._l1_norm_cache = {
"token_embedding": [], # filled once by layer 0 below
"attn": [], # one entry per transformer layer
"mlp": [], # one entry per transformer layer
"lm_head": [], # filled in ParallelLinear
}

if self.num_experts > 1:
from megatron.model.moe import ParallelDroplessMoE

Expand Down Expand Up @@ -1043,6 +1053,12 @@ def _get_bias_dropout(self):
return fn

def forward(self, x, attention_mask, layer_past=None):
# clear l1 norms
if self.track_l1 and self.layer_number == 0:
for key in self.neox_args._l1_norm_cache:
self.neox_args._l1_norm_cache[key].clear()


layer_past = layer_past if layer_past is not None else self.layer_past
bias_dropout_fn = self._get_bias_dropout()

Expand Down Expand Up @@ -1086,6 +1102,12 @@ def forward(self, x, attention_mask, layer_past=None):
attention_output, attention_bias = self.attention(
x1, attention_mask, layer_past=layer_past
)

if self.track_l1:
self.neox_args._l1_norm_cache["attn"].append(
attention_output.detach().abs().sum().item()
)

if self.use_cache:
attention_output, presents = attention_output
self.layer_past = presents
Expand All @@ -1112,6 +1134,11 @@ def forward(self, x, attention_mask, layer_past=None):
else:
output = mlp_output

if self.track_l1:
self.neox_args._l1_norm_cache["mlp"].append(
mlp_output.detach().abs().sum().item()
)

# output = (x + attn(ln(x)) + mlp(ln(x))
output = residual + self.reduce(output)
else:
Expand All @@ -1126,6 +1153,11 @@ def forward(self, x, attention_mask, layer_past=None):
self.input_layernorm(x), attention_mask, layer_past=layer_past
)

if self.track_l1:
self.neox_args._l1_norm_cache["attn"].append(
attention_output.detach().abs().sum().item()
)

if self.use_cache:
attention_output, presents = attention_output
self.layer_past = presents
Expand Down Expand Up @@ -1161,6 +1193,11 @@ def forward(self, x, attention_mask, layer_past=None):
# call signatures of both dense and MoE are the same
mlp_output, mlp_bias = self.mlp(layernorm_output)

if self.track_l1:
self.neox_args._l1_norm_cache["mlp"].append(
mlp_output.detach().abs().sum().item()
)

with torch.enable_grad() if not self.eval else nullcontext():
if mlp_bias == None or (self.num_experts > 1):
# No dropout either
Expand Down
5 changes: 5 additions & 0 deletions megatron/neox_arguments/neox_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,11 @@ class NeoXArgsLogging(NeoXArgsTemplate):
(N.B - this will only work with pp = 0 for now, as we don't have access to the gradients of the model because deepspeed.)
"""

log_l1_norm: bool = False
"""
If set, each forward pass records the L1-norm of token-embedding, every attention block, every MLP block, and the lm-head.
"""

log_optimizer_states: bool = False
"""
Log the frob norm of the optimizer states to wandb / tensorboard (useful for debugging).
Expand Down
Loading