RMS Norm¶
This page is a walkthrough of examples/hopper/rms_norm.py — the
simplest kernel in the repo and the right first read after
First Kernel.
It computes `Y[b, i] = X[b, i] * W[i] / sqrt(mean(X[b, :]**2) + eps)` —
one row per CTA, B CTAs total. Reaches 2.6 TB/s at B=2048, N=8192
on H100 (88% of HBM3 peak), 3.9× faster than PyTorch eager.
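A minimal NumPy reference for the same math (the `eps` default below is an assumption, not taken from the repo):

import numpy as np

def rms_norm_ref(X, W, eps=1e-5):
    # per-row rstd = 1/sqrt(mean(x^2) + eps), then scale by the shared weights
    rstd = 1.0 / np.sqrt((X * X).mean(axis=-1, keepdims=True) + eps)
    return X * rstd * W  # W is (N,), broadcast across all B rows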
This kernel is bandwidth-bound, so the interesting work is in the memory
access pattern — not in a tensor-core intrinsic. It's a good stress
test of whether the DSL stays honest when the hot path is ld.global.v4
instead of wgmma.mma_async.
What The Kernel Computes Per Row¶
1. Load the `N`-wide row of `X`, strided across `block` threads.
2. Each thread accumulates its slice's sum-of-squares into a scalar.
3. A warp-level butterfly reduction turns those into one partial per warp.
4. Thread 0 reduces across the per-warp partials and broadcasts via SMEM.
5. Compute `rstd = 1/sqrt(mean + eps)`.
6. Multiply the per-thread slice (still in registers) by `rstd * W[i]`, write `Y`.
Steps 1 and 6 are the two bandwidth passes. Everything else is arithmetic and synchronization.
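The bandwidth claim is easy to sanity-check. A back-of-envelope for the benchmark shape (my arithmetic, not from the repo): `X` is read once, `Y` is written once, and `W` is negligible next to either.

B, N = 2048, 8192
x_read = B * N * 4               # pass 1 streams X once: ~67 MB
y_write = B * N * 4              # pass 2 streams Y once: ~67 MB
print((x_read + y_write) / 1e6)  # ~134 MB of unavoidable HBM traffic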
Step 1: Pick The Block Size From N¶
_pick_block(n) prefers blocks that leave at least 4 f32 items per
thread and where that item count is divisible by 4:
best = None
for block in (512, 256, 128, 64, 32):
    ipt = n // block  # f32 items per thread
    if n % block == 0 and ipt >= 4 and ipt % 4 == 0 and block >= 128:
        if best is None or ipt > best[1]:  # keep the biggest ipt
            best = (block, ipt)
The "divisible by 4" constraint exists so every memory transaction can
be a ld.global.v4.f32 — one 16-byte load that feeds four accumulator
FMAs. That's the memory-level parallelism knob: more outstanding v4
loads per thread → more DRAM requests in flight → better HBM
utilization.
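Worked through for the benchmark row width (my numbers, just replaying the loop above):

n = 8192
for block in (512, 256, 128, 64, 32):
    ipt = n // block
    ok = n % block == 0 and ipt >= 4 and ipt % 4 == 0 and block >= 128
    print(block, ipt, ok)  # 512:16 ok, 256:32 ok, 128:64 ok; 64 and 32 rejected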
Step 2: Launch Config¶
@kernel(
in_specs=(Tile(B, N, f32), Tile(N, f32)),
out_specs=(Tile(B, N, f32),),
grid=(B, 1, 1),
block=(block, 1, 1),
arch=arch,
)
def rms_norm(X, W, Y):
One CTA per batch row. Inside the CTA, `block` threads cooperate on
that row. No cross-CTA communication.
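A sketch of how the decorator's numbers fall out for the benchmark shape (names reuse the snippet above; `_pick_block` is the helper from step 1):

B, N = 2048, 8192
block = _pick_block(N)   # whichever candidate wins the scan in step 1
num_warps = block // 32  # warps that will stage partials in SMEM
grid = (B, 1, 1)         # ctaid.x is the batch row; CTAs never communicate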
Step 3: Prologue — Pointers, Row Offset, Per-Warp Bookkeeping¶
partials = smem.alloc(f32, (num_warps, 1))
stats = smem.alloc(f32, (1, 1))
px, pw, py = ptx.global_ptrs(X, W, Y)
tid = reg.scalar(u32); ptx.inst.mov.u32(tid, ptx.special.tid.x())
row = reg.scalar(u32); ptx.inst.mov.u32(row, ptx.special.ctaid.x())
row_byte_off = row * (N * 4)
px += row_byte_off
py += row_byte_off
lane = tid & (WARP_SIZE - 1)
warp_id = tid >> 5
Three things to notice:
- `ptx.global_ptrs(X, W, Y)` is one call that unpacks three parameter-pointer prologues. Without it you'd write three near-identical `ld.param.u64` sequences by hand.
- `row * (N * 4)` uses Python operator overloading on `Reg`. The DSL emits `mul.wide.u32` under the covers — exactly what you'd write by hand.
- `W` has no row offset — it's shared across all rows of the batch.
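The lane/warp bookkeeping at the end is plain bit arithmetic; worked for one thread (my example):

tid = 70
lane = tid & 31     # 70 & 0b11111 = 6: position within the warp
warp_id = tid >> 5  # 70 // 32 = 2: third warp in the CTA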
Step 4: Pass 1 — Load And Square¶
With use_v4 true (the common case), each thread executes
v4_iters = items_per_thread // 4 vectorized loads:
sum_sq = reg.scalar(f32, init=0.0)
x_vals = reg.array(f32, items_per_thread)
elem_base = tid << 2  # first element of this thread's 4-wide slice
for j in range(v4_iters):
idx = elem_base if j == 0 else elem_base + (j * block * 4)
ptr = px + idx * 4
ptx.inst.ld.global_.v4.f32(
[x_vals[j*4], x_vals[j*4+1], x_vals[j*4+2], x_vals[j*4+3]],
ptx.addr(ptr),
)
for sub in range(4):
ptx.inst.fma.rn.f32(sum_sq, x_vals[j*4+sub], x_vals[j*4+sub], sum_sq)
Two design choices worth reading twice:
- `x_vals = reg.array(f32, items_per_thread)` is preserved across passes. The per-thread slice is held in registers through the whole kernel, so pass 2 doesn't need to re-load from global. That turns a 2-pass-over-HBM algorithm into a 1.5-pass — the weights load again, the inputs don't.
- `fma.rn.f32(sum_sq, x, x, sum_sq)` is one PTX instruction doing `sum_sq = x*x + sum_sq`, rn-rounded. Writing `sum_sq = sum_sq + x*x` in the DSL would emit two instructions (mul + add) — not the same. When you care about the instruction count, reach for `inst.fma.rn`.
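The loop's address math is easier to see with numbers plugged in (my example, assuming `block = 512`, so one sweep of the CTA covers `4 * block = 2048` elements):

block, tid = 512, 7
elem_base = tid << 2                 # thread 7 owns elements 28..31 first
for j in range(4):                   # v4_iters = 4 when items_per_thread = 16
    idx = elem_base + j * block * 4  # 28, 2076, 4124, 6172
    # each ld.global.v4.f32 covers elements idx..idx+3; adjacent threads
    # touch adjacent 16 B chunks, so the warp's loads stay coalesced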
Step 5: Warp Reduce, Then Block Reduce¶
ptx.warp.reduce_sum(sum_sq)
with ptx.if_(lane == 0):
partials[warp_id, 0] = sum_sq
ptx.bar.sync(0)
with ptx.if_(tid == 0):
block_sum = reg.scalar(f32, init=0.0)
for i in range(num_warps):
ptx.inst.add.f32(block_sum, block_sum, partials[i, 0])
stats[0, 0] = block_sum
ptx.bar.sync(0)
ptx.inst.mov.f32(sum_sq, stats[0, 0])
Three hops:
- `warp.reduce_sum` is a butterfly `shfl.bfly.sync` reduction — no hand-rolled helper, but the DSL still emits the same five PTX instructions you'd write manually.
- Per-warp partials go to SMEM. Lane 0 of each warp writes its sum. `bar.sync(0)` ensures all warps are past this point before the next stage reads.
- Thread 0 collapses the partials. The final result is parked in `stats[0,0]` for broadcast. The second `bar.sync(0)` makes the value visible to the whole CTA. Every thread re-reads it.
This is intentionally the two-phase pattern. For small num_warps,
the second phase is a trivial linear scan — not worth another warp
butterfly.
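For reference, the butterfly that `warp.reduce_sum` expands to, simulated in plain Python over a 32-lane warp (each step corresponds to one `shfl.bfly.sync` plus an `add.f32`):

def butterfly_reduce(vals):  # vals: one float per lane, len(vals) == 32
    for offset in (16, 8, 4, 2, 1):
        vals = [v + vals[i ^ offset] for i, v in enumerate(vals)]
    return vals  # after five steps every lane holds the full sum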
Step 6: Compute rstd Once Per Thread¶
mean_sq = reg.scalar(f32)
inv_n = reg.scalar(f32, init=1.0 / N)
ptx.inst.mul.f32(mean_sq, sum_sq, inv_n)
eps_reg = reg.scalar(f32, init=eps)
ptx.inst.add.f32(mean_sq, mean_sq, eps_reg)
rstd = reg.scalar(f32)
ptx.inst.rsqrt.approx.f32(rstd, mean_sq)
rsqrt.approx.f32 is the single-instruction reciprocal square root.
N is baked in at trace time as 1.0 / N so the kernel doesn't do a
division on the hot path.
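What the three instructions compute, as a plain-Python reference (the function name here is mine):

import math

def rstd_ref(sum_sq, n, eps):
    # mul.f32 by the baked-in 1/n, add.f32 eps, rsqrt.approx.f32
    return 1.0 / math.sqrt(sum_sq / n + eps)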
Step 7: Pass 2 — Load W, Multiply, Store¶
for j in range(v4_iters):
idx = elem_base if j == 0 else elem_base + (j * block * 4)
off = idx * 4
w_vals = [reg.scalar(f32) for _ in range(4)]
ptx.inst.ld.global_.v4.f32(w_vals, ptx.addr(pw + off))
y_vals = []
for sub in range(4):
y_val = reg.scalar(f32)
ptx.inst.mul.f32(y_val, x_vals[j*4+sub], rstd)
ptx.inst.mul.f32(y_val, y_val, w_vals[sub])
y_vals.append(y_val)
ptx.inst.st.global_.v4.f32(ptx.addr(py + off), y_vals)
- `x_vals` is still in registers from pass 1 — no reload.
- `W` is read fresh (it's `N`-wide, not `B*N`, so it's tiny and cached).
- The store is also v4. Every DRAM transaction in this kernel is 16 B.
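"Tiny and cached" is easy to quantify (my arithmetic; H100 carries a 50 MB L2):

N = 8192
w_bytes = N * 4  # 32 KB for the whole weight vector; after the first row
                 # touches it, the other 2047 rows hit L2 rather than HBM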
Why This Kernel Matters For The DSL¶
RMS norm is a good DSL test because nothing here benefits from a tensor-core intrinsic:
- `ld.global.v4.f32` and `st.global.v4.f32` with register lists must map cleanly to the PTX form.
- `fma.rn.f32` has to be spellable as one instruction, not lowered from Python `*` and `+`.
- `rsqrt.approx.f32`, `reduce_sum`, and `bar.sync(0)` need to be first-class, not escape hatches.
- `smem.alloc(f32, (num_warps, 1))` must work for tiny staging, not just WGMMA tiles.
If any of those fell back to a lower-level helper, this kernel would read like a hand-rolled assembler instead of a DSL. That it doesn't is the point.
What To Read Next¶
- SwiGLU — same memory pattern plus a fast-path `silu = x * sigmoid(x)` built from `ex2.approx` + `rcp.approx`
- Hopper GEMM — when the hot path is `wgmma.mma_async` instead of `ld.global.v4`
- PTX Namespace — every helper used above, as a reference page