Commit b7025a7
Improve bwd kernel for fused_add_rms_norm (#1187)
## Summary
This PR was generated using the `liger-kernel-perf` skill introduced in #1185.
Restructure the backward kernel of `fused_add_rms_norm` to reduce peak
register liveness, dropping register count from **115 → 95 per thread**
on H100. This eliminates a performance cliff at large hidden sizes and
yields up to **70% backward speedup** at H=16384.
## Root Cause Analysis
NCU profiling on H100 revealed the backward kernel was severely
underutilizing the GPU:
| Metric | Forward Kernel | Backward Kernel |
|--------|---------------|-----------------|
| Registers/Thread | 40 | **115** |
| Theoretical Occupancy | 75% | **25%** |
| Achieved Occupancy | 67.87% | **12.49%** |
| Blocks/SM (register-limited) | 6 | **2** |
| Waves/SM | 5.17 | **0.50** |
| Memory Throughput (% peak) | 78.24% | 60.88% |
| Compute Throughput (% peak) | 27.64% | 19.87% |
The H100 has 65,536 registers per SM. With 256 threads/block (8 warps x
32 threads), each thread using 115 registers consumes 115 x 256 = 29,440
registers per block. This means only `floor(65536 / 29440) = 2` blocks
can run concurrently per SM, yielding just 12.5% occupancy. With so few
active warps, the GPU cannot overlap memory accesses with computation —
stalling 77.5% of scheduler slots.
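The register-limit arithmetic above can be checked directly. A minimal sketch (the 65,536 registers/SM and 2,048 threads/SM figures are published H100 SM limits; the helper name is ours):

```python
# Back-of-envelope register-limited occupancy check for an H100 SM.
REGS_PER_SM = 65_536        # H100 register file per SM
MAX_THREADS_PER_SM = 2_048  # H100 thread limit per SM
THREADS_PER_BLOCK = 256     # 8 warps x 32 threads, as in this kernel

def register_limited_blocks(regs_per_thread: int) -> int:
    """Blocks per SM allowed by the register file alone."""
    return REGS_PER_SM // (regs_per_thread * THREADS_PER_BLOCK)

# Backward kernel before the change: 115 regs/thread -> 2 blocks/SM.
print(register_limited_blocks(115))  # 2

# Forward kernel: 40 regs/thread -> 6 blocks/SM.
print(register_limited_blocks(40))   # 6

# Theoretical occupancy of the backward kernel: 512 / 2048 threads = 25%.
occ = register_limited_blocks(115) * THREADS_PER_BLOCK / MAX_THREADS_PER_SM
print(f"{occ:.0%}")  # 25%
```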
### Why so many registers?
The backward kernel processes one row per loop iteration, maintaining
several BLOCK_SIZE-wide vectors simultaneously in registers. With
BLOCK_SIZE=4096 and 256 threads, each thread handles 16 elements per
vector, requiring 16 registers per vector. The original code had up to
**~8 vectors live simultaneously** at peak:
| Vector | Lifetime | Purpose |
|--------|----------|---------|
| `W_row` | Persistent (loop-invariant) | Weight + offset |
| `dW_row` | Persistent (accumulator) | Weight gradient |
| `col_offsets` | Persistent | Column indices |
| `mask` | Persistent | Bounds mask |
| `dY_row` | Per-iteration | Upstream gradient |
| `X_row` | Per-iteration | Saved residual (fp32) |
| `m` | Per-iteration | `dY * W` intermediate |
| `dX_row` | Per-iteration | Output gradient |
| `dS_out_row` | Per-iteration (conditional) | Residual stream gradient |
At 16 registers per vector x 8 vectors = 128 registers, plus scalars, pointers, and loop variables, this is consistent with the observed 115 registers per thread (the compiler eliminates some through reuse).
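The per-thread arithmetic above in one place (values taken from this kernel's launch config):

```python
# How many registers a single BLOCK_SIZE-wide vector costs per thread.
BLOCK_SIZE = 4096          # elements processed per row
THREADS_PER_BLOCK = 256    # 8 warps x 32 threads

elems_per_thread = BLOCK_SIZE // THREADS_PER_BLOCK  # each thread owns 16 elements
LIVE_VECTORS = 8           # peak simultaneous vectors in the original kernel

upper_bound = elems_per_thread * LIVE_VECTORS       # 128-register upper bound
print(elems_per_thread, upper_bound)  # 16 128
```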
## Changes
### 1. Reorder: compute dW before dX
**Before:** The original code computed `m -> dX -> dW -> store dX`. This
meant `dY_row`, `m`, `X_row`, `dX_row`, and `dW_row` were all live
simultaneously during the dW accumulation.
**After:** By computing `dW` first (`dW -> m -> dX -> store dX`), the
compiler can reuse `dY_row`'s registers when computing `m = dY_row *
W_row`, since `dY_row` is no longer needed independently after the dW
step. This is safe because both `dW` and `m` read `dY_row` — they don't
modify it (except in GEMMA mode where the cast happens after dW).
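The reordered per-row computation, as a NumPy sketch (illustration only, not the actual Triton kernel; the function name and the plain-RMSNorm `dW` formula without GEMMA's offset are assumptions):

```python
import numpy as np

def bwd_row_reordered(dY_row, X_row, W_row, rstd_row, dW_row, n_cols):
    """Sketch of the reordered backward pass for one row.

    dW is accumulated first, so dY_row has no independent use afterward
    and the compiler may alias m onto dY_row's registers.
    """
    dW_row += dY_row * (rstd_row * X_row)    # dW first: last standalone read of dY_row
    m = dY_row * W_row                       # may reuse dY_row's registers
    dot = np.sum(m * X_row)                  # scalar reduction
    c = -(1.0 / n_cols) * rstd_row**3 * dot  # scalar correction factor
    dX_row = rstd_row * m + c * X_row        # element-wise combine of two vectors
    return dX_row, dW_row
```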
### 2. Factor the dX formula with a precomputed scalar
**Before:**
```python
dX_row = rstd_row * m
dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)
```
This expression keeps `m`, `X_row`, and the intermediate `rstd * (... *
X_row)` vector all live while computing the compound expression.
**After:**
```python
dot = tl.sum(m * X_row, axis=0) # scalar reduction
c = -(1.0 / n_cols) * rstd_row * rstd_row * rstd_row * dot # scalar
dX_row = rstd_row * m + c * X_row # single fused expression
```
By precomputing `c` as a scalar, the final `dX_row` expression is a
simple element-wise combination of two vectors (`m` and `X_row`) scaled
by scalars. The compiler doesn't need to keep intermediate vector
results alive — it can compute each element of `dX_row` independently,
allowing better register reuse.
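The two factorings are algebraically identical, since `rstd * (-(1/n) * rstd * rstd * dot * X) = c * X` with `c = -(1/n) * rstd^3 * dot`. A quick NumPy check on random data (shapes and values are arbitrary) confirms it:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1024
m = rng.standard_normal(n)     # stands in for dY_row * W_row
X = rng.standard_normal(n)     # stands in for the saved residual
rstd = 0.37                    # stands in for the per-row rstd scalar

# Before: compound expression keeping intermediate vectors live.
dX_before = rstd * m + rstd * (-(1.0 / n) * rstd * rstd * np.sum(m * X) * X)

# After: scalar hoisted out of the vector expression.
dot = np.sum(m * X)
c = -(1.0 / n) * rstd**3 * dot
dX_after = rstd * m + c * X

print(np.allclose(dX_before, dX_after))  # True
```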
### 3. Defer dS_out load
**Before:** `dS_out_row` was loaded and added to `dX_row` inside the
compound correction expression, contributing to peak liveness.
**After:** `dS_out_row` is loaded only after the `dX_row = rstd * m + c
* X` computation is complete. By this point, `m` registers are freed, so
`dS_out_row` reuses them rather than competing for register space.
### 4. Add `num_stages=2` for Hopper software pipelining
**Rationale:** The current code never sets `num_stages`, so Triton
defaults to `num_stages=1` (no software pipelining). On H100 (Hopper
architecture), setting `num_stages >= 2` enables the compiler to overlap
memory loads for the next iteration with computation for the current
iteration using Hopper's async copy engine (TMA).
A parameter sweep confirmed `num_stages=2` is safe across all hidden
sizes with no regressions. Higher values (3-4) showed no additional
benefit for this kernel because the backward kernel's low occupancy
limits the pipeline depth that can be effectively utilized.
`num_stages=2` was chosen as the conservative optimal: it enables basic
pipelining without increasing register pressure (higher stages require
more pipeline buffer registers).
The `num_stages` value is passed through `ctx` from forward to backward,
following the same pattern as `BLOCK_SIZE` and `num_warps`.
### Combined effect
After these changes, the peak register liveness drops to ~5-6 vectors:
| Vector | Still live? |
|--------|------------|
| `W_row` | Yes (persistent) |
| `dW_row` | Yes (persistent) |
| `col_offsets` | Yes (persistent) |
| `mask` | Yes (persistent) |
| `dY_row` / `m` | **Shared** (compiler reuses registers) |
| `X_row` / `dX_row` | **Shared** (compiler reuses after dot product) |
| `dS_out_row` | Loaded late, **reuses freed `m` registers** |
NCU confirmed registers dropped from **115 → 95 per thread** (17%
reduction). While this doesn't cross the 85-register threshold needed
for 3 blocks/SM, the reduced pressure gives the compiler more freedom
for instruction scheduling and ILP (instruction-level parallelism),
which explains the dramatic speedup at H=16384, where the original
kernel spilled registers heavily.
## Benchmark Results
**NVIDIA H100 80GB HBM3, M=2048, dtype=float32**
### Speed
| Mode | H | Before (ms) | After (ms) | Improvement |
|------|---|------------|-----------|-------------|
| Backward | 2048 | 0.124 | 0.096 | **23% faster** |
| Backward | 4096 | 0.124 | 0.101 | **18% faster** |
| Backward | 8192 | 0.153 | 0.154 | ~same |
| Backward | 16384 | 1.006 | 0.300 | **70% faster (3.35x)** |
| Backward | 32768 | 1.390 | 1.119 | **20% faster** |
| Full | 16384 | 1.186 | 0.482 | **59% faster** |
| Full | 32768 | 1.819 | 1.550 | **15% faster** |
| Forward | all sizes | - | - | unchanged |
### Memory
No change at any size (~30% savings vs HuggingFace maintained).
### Plots
**Backward Speed** (before / after / HuggingFace):
<!-- Paste backward speed plot here -->
**Full Pass Speed**:
<!-- Paste full pass speed plot here -->
**Forward Speed** (unchanged):
<!-- Paste forward speed plot here -->
**Memory** (unchanged):
<!-- Paste memory plot here -->
## Test plan
- [x] Full test suite passed: `python -m pytest
test/transformers/test_fused_add_rms_norm.py -xvs` (40/40 tests, all
casting modes + dtypes + shapes)
- [x] Checkstyle passed (ruff)
- [x] No memory regression at any hidden size
- [x] No forward pass regression
- [x] Benchmarked on NVIDIA H100 80GB HBM3
- [x] NCU profiling confirmed register reduction (115 → 95)
---------
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>