From 8343bca286f2903ca408462e241b5d4b20ecd312 Mon Sep 17 00:00:00 2001
From: Quentin Anthony
Date: Mon, 9 Jun 2025 15:43:25 -0700
Subject: [PATCH] add some l1 norm statements

---
 megatron/model/transformer.py        | 37 ++++++++++++++++++++++++++++
 megatron/neox_arguments/neox_args.py |  5 ++++
 2 files changed, 42 insertions(+)

diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 514ccba26..945686aaa 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -945,6 +945,16 @@ def __init__(
             else 1
         )
 
+        self.track_l1 = getattr(neox_args, "log_l1_norm", False)
+        if self.track_l1 and not hasattr(neox_args, "_l1_norm_cache"):
+            # a single shared container hangs off the config object, so every layer appends to the same lists
+            neox_args._l1_norm_cache = {
+                "token_embedding": [],  # filled once by layer 0 below
+                "attn": [],  # one entry per transformer layer
+                "mlp": [],  # one entry per transformer layer
+                "lm_head": [],  # filled in ParallelLinear
+            }
+
         if self.num_experts > 1:
             from megatron.model.moe import ParallelDroplessMoE
 
@@ -1043,6 +1053,12 @@ def _get_bias_dropout(self):
         return fn
 
     def forward(self, x, attention_mask, layer_past=None):
+        # layer 0 runs first, so clear the cached L1 norms at the start of each forward pass
+        if self.track_l1 and self.layer_number == 0:
+            for key in self.neox_args._l1_norm_cache:
+                self.neox_args._l1_norm_cache[key].clear()
+
+
         layer_past = layer_past if layer_past is not None else self.layer_past
         bias_dropout_fn = self._get_bias_dropout()
 
@@ -1086,6 +1102,12 @@ def forward(self, x, attention_mask, layer_past=None):
             attention_output, attention_bias = self.attention(
                 x1, attention_mask, layer_past=layer_past
             )
+
+            if self.track_l1:
+                self.neox_args._l1_norm_cache["attn"].append(
+                    attention_output.detach().abs().sum().item()
+                )
+
             if self.use_cache:
                 attention_output, presents = attention_output
                 self.layer_past = presents
@@ -1112,6 +1134,11 @@ def forward(self, x, attention_mask, layer_past=None):
             else:
                 output = mlp_output
 
+            if self.track_l1:
+                self.neox_args._l1_norm_cache["mlp"].append(
+                    mlp_output.detach().abs().sum().item()
+                )
+
             # output = (x + attn(ln(x)) + mlp(ln(x))
             output = residual + self.reduce(output)
         else:
@@ -1126,6 +1153,11 @@ def forward(self, x, attention_mask, layer_past=None):
                 self.input_layernorm(x), attention_mask, layer_past=layer_past
             )
 
+            if self.track_l1:
+                self.neox_args._l1_norm_cache["attn"].append(
+                    attention_output.detach().abs().sum().item()
+                )
+
             if self.use_cache:
                 attention_output, presents = attention_output
                 self.layer_past = presents
@@ -1161,6 +1193,11 @@ def forward(self, x, attention_mask, layer_past=None):
             # call signatures of both dense and MoE are the same
             mlp_output, mlp_bias = self.mlp(layernorm_output)
 
+            if self.track_l1:
+                self.neox_args._l1_norm_cache["mlp"].append(
+                    mlp_output.detach().abs().sum().item()
+                )
+
             with torch.enable_grad() if not self.eval else nullcontext():
                 if mlp_bias == None or (self.num_experts > 1):
                     # No dropout either
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index 2929a3036..c9f216de1 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -804,6 +804,11 @@ class NeoXArgsLogging(NeoXArgsTemplate):
     (N.B - this will only work with pp = 0 for now, as we don't have access to the gradients of the model because deepspeed.)
     """
 
+    log_l1_norm: bool = False
+    """
+    If set, each forward pass records the L1 norm of the token embedding, of every attention block, of every MLP block, and of the LM head.
+    """
+
     log_optimizer_states: bool = False
     """
     Log the frob norm of the optimizer states to wandb / tensorboard (useful for debugging).