
Commit b7025a7

Improve bwd kernel for fused_add_rms_norm (#1187)
## Summary

This has been generated using the `liger-kernel-perf` skill introduced in #1185.

Restructure the backward kernel of `fused_add_rms_norm` to reduce peak register liveness, dropping the register count from **115 → 95 per thread** on H100. This eliminates a performance cliff at large hidden sizes and yields up to a **70% backward speedup** at H=16384.

## Root Cause Analysis

NCU profiling on H100 revealed that the backward kernel was severely underutilizing the GPU:

| Metric | Forward Kernel | Backward Kernel |
|--------|----------------|-----------------|
| Registers/Thread | 40 | **115** |
| Theoretical Occupancy | 75% | **25%** |
| Achieved Occupancy | 67.87% | **12.49%** |
| Blocks/SM (register-limited) | 6 | **2** |
| Waves/SM | 5.17 | **0.50** |
| Memory Throughput (% peak) | 78.24% | 60.88% |
| Compute Throughput (% peak) | 27.64% | 19.87% |

The H100 has 65,536 registers per SM. With 256 threads/block (8 warps x 32 threads), each thread using 115 registers consumes 115 x 256 = 29,440 registers per block. This means only `floor(65536 / 29440) = 2` blocks can run concurrently per SM, capping theoretical occupancy at 25% (12.49% achieved). With so few active warps, the GPU cannot overlap memory accesses with computation, stalling 77.5% of scheduler slots.

### Why so many registers?

The backward kernel processes one row per loop iteration, maintaining several BLOCK_SIZE-wide vectors simultaneously in registers. With BLOCK_SIZE=4096 and 256 threads, each thread handles 16 elements per vector, requiring 16 registers per vector.
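The occupancy arithmetic above can be reproduced with a small back-of-the-envelope script. This is a simplified model (it ignores per-warp register allocation granularity and other occupancy limiters such as shared memory); the SM limits are NVIDIA's published Hopper figures, and the helper name is made up for illustration:

```python
# H100 (Hopper) per-SM limits, from NVIDIA's published specs.
REGISTERS_PER_SM = 65_536
MAX_THREADS_PER_SM = 2_048

def register_limited_occupancy(regs_per_thread: int, threads_per_block: int = 256):
    """Toy model: how many blocks fit per SM when registers are the
    limiting resource, and the resulting theoretical occupancy as a
    fraction of the maximum resident threads."""
    regs_per_block = regs_per_thread * threads_per_block
    blocks_per_sm = REGISTERS_PER_SM // regs_per_block
    occupancy = blocks_per_sm * threads_per_block / MAX_THREADS_PER_SM
    return blocks_per_sm, occupancy

# Backward kernel before the change: 115 regs/thread -> 2 blocks/SM, 25%.
print(register_limited_occupancy(115))  # (2, 0.25)
# Forward kernel: 40 regs/thread -> 6 blocks/SM, 75%.
print(register_limited_occupancy(40))   # (6, 0.75)
```

Both results match the profiled table: 2 register-limited blocks/SM for the backward kernel versus 6 for the forward kernel.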
The original code had up to **~8 vectors live simultaneously** at peak:

| Vector | Lifetime | Purpose |
|--------|----------|---------|
| `W_row` | Persistent (loop-invariant) | Weight + offset |
| `dW_row` | Persistent (accumulator) | Weight gradient |
| `col_offsets` | Persistent | Column indices |
| `mask` | Persistent | Bounds mask |
| `dY_row` | Per-iteration | Upstream gradient |
| `X_row` | Per-iteration | Saved residual (fp32) |
| `m` | Per-iteration | `dY * W` intermediate |
| `dX_row` | Per-iteration | Output gradient |
| `dS_out_row` | Per-iteration (conditional) | Residual stream gradient |

At 16 registers per vector x 8 vectors = 128 registers, plus scalars/pointers/loop variables, this is consistent with the observed 115 registers (the compiler optimizes some away).

## Changes

### 1. Reorder: compute dW before dX

**Before:** The original code computed `m -> dX -> dW -> store dX`. This meant `dY_row`, `m`, `X_row`, `dX_row`, and `dW_row` were all live simultaneously during the dW accumulation.

**After:** By computing `dW` first (`dW -> m -> dX -> store dX`), the compiler can reuse `dY_row`'s registers when computing `m = dY_row * W_row`, since `dY_row` is no longer needed independently after the dW step. This is safe because both the `dW` accumulation and `m` only read `dY_row`; they don't modify it (except in GEMMA mode, where the cast happens after dW).

### 2. Factor the dX formula with a precomputed scalar

**Before:**

```python
dX_row = rstd_row * m
dX_row += (rstd_row) * (
    -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
)
```

This expression keeps `m`, `X_row`, and the intermediate `rstd * (... * X_row)` vector all live while the compound expression is computed.
**After:**

```python
dot = tl.sum(m * X_row, axis=0)                             # scalar reduction
c = -(1.0 / n_cols) * rstd_row * rstd_row * rstd_row * dot  # scalar
dX_row = rstd_row * m + c * X_row                           # single fused expression
```

By precomputing `c` as a scalar, the final `dX_row` expression is a simple element-wise combination of two vectors (`m` and `X_row`) scaled by scalars. The compiler doesn't need to keep intermediate vector results alive; it can compute each element of `dX_row` independently, allowing better register reuse.

### 3. Defer dS_out load

**Before:** `dS_out_row` was loaded and added to `dX_row` inside the compound correction expression, contributing to peak liveness.

**After:** `dS_out_row` is loaded only after the `dX_row = rstd * m + c * X` computation is complete. By this point, the registers holding `m` have been freed, so `dS_out_row` reuses them rather than competing for register space.

### 4. Add `num_stages=2` for Hopper software pipelining

**Rationale:** The current code never sets `num_stages`, so Triton defaults to `num_stages=1` (no software pipelining). On H100 (Hopper architecture), setting `num_stages >= 2` lets the compiler overlap memory loads for the next iteration with computation for the current iteration using Hopper's async copy engine (TMA).

A parameter sweep confirmed `num_stages=2` is safe across all hidden sizes with no regressions. Higher values (3-4) showed no additional benefit for this kernel because the backward kernel's low occupancy limits the pipeline depth that can be effectively utilized. `num_stages=2` was chosen as the conservative optimum: it enables basic pipelining without increasing register pressure (higher stage counts require more pipeline buffer registers).

The `num_stages` value is passed through `ctx` from forward to backward, following the same pattern as `BLOCK_SIZE` and `num_warps`.

### Combined effect

After these changes, peak register liveness drops to ~5-6 vectors:

| Vector | Still live? |
|--------|-------------|
| `W_row` | Yes (persistent) |
| `dW_row` | Yes (persistent) |
| `col_offsets` | Yes (persistent) |
| `mask` | Yes (persistent) |
| `dY_row` / `m` | **Shared** (compiler reuses registers) |
| `X_row` / `dX_row` | **Shared** (compiler reuses after the dot product) |
| `dS_out_row` | Loaded late, **reuses freed `m` registers** |

NCU confirmed that registers dropped from **115 → 95 per thread** (a 17% reduction). While this doesn't cross the 85-register threshold needed for 3 blocks/SM, the reduced pressure gives the compiler more freedom for instruction scheduling and instruction-level parallelism (ILP), which explains the dramatic speedup, especially at H=16384, where the original had extreme register spilling.

## Benchmark Results

**NVIDIA H100 80GB HBM3, M=2048, dtype=float32**

### Speed

| Mode | H | Before (ms) | After (ms) | Improvement |
|------|---|-------------|------------|-------------|
| Backward | 2048 | 0.124 | 0.096 | **23% faster** |
| Backward | 4096 | 0.124 | 0.101 | **18% faster** |
| Backward | 8192 | 0.153 | 0.154 | ~same |
| Backward | 16384 | 1.006 | 0.300 | **70% faster (3.35x)** |
| Backward | 32768 | 1.390 | 1.119 | **20% faster** |
| Full | 16384 | 1.186 | 0.482 | **59% faster** |
| Full | 32768 | 1.819 | 1.550 | **15% faster** |
| Forward | all sizes | - | - | unchanged |

### Memory

No change at any size (the ~30% savings vs. HuggingFace is maintained).
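The algebraic factoring in change 2 above can be sanity-checked in plain Python. This is a toy model, not the kernel itself: Python lists stand in for Triton vectors, the built-in `sum` stands in for `tl.sum`, and the input values are arbitrary:

```python
import random

random.seed(0)
n_cols = 8
rstd = 0.37                                          # stand-in for rstd_row
m = [random.uniform(-1, 1) for _ in range(n_cols)]   # dY * W intermediate
X = [random.uniform(-1, 1) for _ in range(n_cols)]   # saved residual row

# Original compound expression:
#   dX = rstd*m + rstd * (-(1/n_cols) * rstd * rstd * sum(m*X) * X)
dot = sum(mi * xi for mi, xi in zip(m, X))
dX_before = [rstd * mi + rstd * (-(1.0 / n_cols) * rstd * rstd * dot * xi)
             for mi, xi in zip(m, X)]

# Factored form: hoist the scalar c = -(1/n_cols) * rstd^3 * dot,
# leaving a single fused element-wise combination of m and X.
c = -(1.0 / n_cols) * rstd ** 3 * dot
dX_after = [rstd * mi + c * xi for mi, xi in zip(m, X)]

assert all(abs(a - b) < 1e-12 for a, b in zip(dX_before, dX_after))
print("factored dX matches the original expression")
```

The two forms are mathematically identical; the win is purely in how the compiler can schedule registers for the fused version.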
### Plots

**Backward Speed** (before / after / HuggingFace):

<!-- Paste backward speed plot here -->

**Full Pass Speed**:

<!-- Paste full pass speed plot here -->

**Forward Speed** (unchanged):

<!-- Paste forward speed plot here -->

**Memory** (unchanged):

<!-- Paste memory plot here -->

## Test plan

- [x] Full test suite passed: `python -m pytest test/transformers/test_fused_add_rms_norm.py -xvs` (40/40 tests, all casting modes + dtypes + shapes)
- [x] Checkstyle passed (ruff)
- [x] No memory regression at any hidden size
- [x] No forward pass regression
- [x] Benchmarked on NVIDIA H100 80GB HBM3
- [x] NCU profiling confirmed the register reduction (115 → 95)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2bf7117 commit b7025a7

2 files changed (+102, -93 lines)

