SwiGLU¶
This page walks through examples/hopper/swiglu.py — a fused SwiGLU
activation kernel that reaches 2.8 TB/s at M=2048, F=8192 f32 on
H100 (94% of HBM3 peak), 1.6× faster than
F.silu(gate) * up eager.
It computes:
per element. Three tensors in, one out. Bandwidth-bound — the
interesting work is (a) saturating HBM3 with v4 loads, (b) computing
sigmoid without a div.f32 on the hot path.
Read RMS Norm first if you haven't. SwiGLU uses the
same v4 memory pattern; the new surface here is the math path for
silu.
What Makes SwiGLU Different From A Plain Map¶
SwiGLU is a three-input, one-output elementwise kernel. Compared to a plain unary map, that changes two things:
- Three DRAM streams to interleave. Two loads (gate, up), one store (out).
- No reductions, no barriers. Each thread is independent, so the kernel can afford to have many items per thread to hide latency.
The memory pattern is the same as RMS Norm pass 2 but doubled on the input side.
Step 1: Pick The Block Size For v4¶
for block in (1024, 512, 256, 128, 64, 32):
if f % (block * 4) == 0 and block >= 128:
ipt = f // block
if best is None or ipt > best[1]:
best = (block, ipt)
Same logic as RMS Norm: prefer the config with the most items per
thread such that every load can still be a ld.global.v4.f32. For
F=8192, block=1024 gives ipt=8 which is 2 v4 loads per thread per
row — good memory-level parallelism.
Step 2: Launch Config¶
@kernel(
in_specs=(Tile(M, F, f32), Tile(M, F, f32)),
out_specs=(Tile(M, F, f32),),
grid=(M // rows_per_cta, 1, 1),
block=(block, 1, 1),
arch=arch,
)
def fused_silu_mul(gate, up, out):
One CTA per row by default. rows_per_cta > 1 lets a single CTA walk
multiple rows, which is useful when M is small enough that launch
overhead dominates (e.g. M=4).
Step 3: Prologue — Three Pointers, One Row Offset¶
pg, pu, po = ptx.global_ptrs(gate, up, out)
row_base = reg.scalar(u32)
ptx.inst.mov.u32(row_base, ptx.special.ctaid.x())
if rows_per_cta > 1:
ptx.inst.mul.lo.u32(row_base, row_base, rows_per_cta)
row_byte_off = row_base * (F * 4)
pg += row_byte_off
pu += row_byte_off
po += row_byte_off
All three input tensors share the same row layout, so one
row_byte_off computed once advances all three pointers. The
rows_per_cta branch is resolved at trace time — it's Python if,
not ptx.if_ — so the compiled PTX contains no runtime check.
Step 4: The Fast silu Path¶
Here's the arithmetic core, one v4 chunk:
g_vals = [reg.scalar(f32) for _ in range(4)]
ptx.inst.ld.global_.v4.f32(g_vals, ptx.addr(pg + off))
u_vals = [reg.scalar(f32) for _ in range(4)]
ptx.inst.ld.global_.v4.f32(u_vals, ptx.addr(pu + off))
out_vals = []
for sub in range(4):
neg_g = reg.scalar(f32)
ptx.inst.mul.f32(neg_g, g_vals[sub], neg_log2e) # -g * log2(e)
exp_neg = reg.scalar(f32)
ptx.inst.ex2.approx.f32(exp_neg, neg_g) # exp(-g)
denom = reg.scalar(f32)
ptx.inst.add.f32(denom, one, exp_neg) # 1 + exp(-g)
sigm = reg.scalar(f32)
ptx.inst.rcp.approx.f32(sigm, denom) # 1 / (1 + exp(-g))
silu_g = reg.scalar(f32)
ptx.inst.mul.f32(silu_g, g_vals[sub], sigm) # g * sigmoid(g)
out_val = reg.scalar(f32)
ptx.inst.mul.f32(out_val, silu_g, u_vals[sub]) # silu(g) * u
out_vals.append(out_val)
ptx.inst.st.global_.v4.f32(ptx.addr(po + off), out_vals)
Four places where we're being specific about what we want:
ex2.approx.f32instead ofex.approx.f32. Hopper hasex2(base-2 exponential) as a hardware intrinsic;ex(natural) would need a multiply-by-log2(e)anyway. Fold that constant into the input — pre-computeneg_log2e = -1.4426950408889634as a compile-time register, multiply once.rcp.approx.f32instead ofdiv.f32.div.f32is slow and typically implemented asrcp+mul+ Newton iteration. We don't need the extra precision — approx is good for activations — so skip the Newton step.- The whole
siluis 5 instructions: mul, ex2, add, rcp, mul. Plus one more mul to combine withup. That's 6 fp32 instructions per element. The v4 load/store pair is 2 instructions per 4 elements = 0.5 memory instructions per element. Bandwidth bound: ~0.083 memory instructions per fp32 op. - Register list arguments.
g_valsis a Python list of 4Regobjects, passed directly told.global_.v4.f32. The DSL knows how to spell PTX's{r0, r1, r2, r3}register-vector syntax.
Step 5: Non-v4 Fallback¶
When items_per_thread isn't divisible by 4, the kernel has a scalar
fallback that does exactly the same arithmetic per element:
for i in range(items_per_thread):
idx = reg.scalar(u32)
ptx.inst.add.u32(idx, tid, i * block)
off = idx * 4
# ...single-element ld.global.f32, scalar silu, scalar st.global.f32
This path is a ~4× bandwidth regression because each transaction
moves 4 bytes instead of 16. It exists only for shapes that don't
admit v4 — F not divisible by block * 4. In practice the block
picker keeps this branch cold.
Step 6: Multi-Row Traversal¶
r * F * 4 is the byte offset added on top of each pointer for the
r-th row this CTA is handling. The Python for unrolls at trace
time — if rows_per_cta=4, the emitted kernel has 4 copies of the v4
load/silu/store sequence in a row, no runtime loop counter.
This is a deliberate knob for small M. For M=4 (or similar
debug-sized batches), you don't want to launch 4 CTAs — that's 4×
launch overhead. Setting rows_per_cta=4 collapses it to 1 CTA doing
all 4 rows, with the added benefit that all 4 row-prologue adds (row
offsets) are bulk-computed together.
Why This Kernel Matters For The DSL¶
SwiGLU is the second-simplest kernel in the repo for a reason:
- It's a no-reduction, no-sync elementwise path — the cleanest possible DSL stress test for memory + math.
- It uses three different HBM streams concurrently. If the DSL serialized the loads (e.g. by inserting hidden syncs), you'd see it in the TB/s number — you don't, so it doesn't.
ex2.approx+rcp.approx+fmais the standard toolkit for building transcendentals at approximate precision. Making all three first-class means you can spellsilu,gelu,tanh,softpluswithout ever reaching for inline asm.ptx.global_ptrs(gate, up, out)scales linearly to more tensors — try writing four or five pointer prologues by hand and the DSL's value shows up immediately.
At 2.8 TB/s on H100, this kernel is within 2 percentage points of memory-bandwidth peak. Everything the DSL does on top of raw PTX costs effectively zero runtime — that's the claim this kernel validates.
What To Read Next¶
- Layer Norm — same shape, same v4 pattern, but with a mean pass and a variance pass. Shows the two-moment version of RMS Norm.
- Hopper GEMM — tensor-core hot path instead of memory hot path.
- Grouped GEMM — batched multiply, warp-specialized K loop.