Skip to content

Hopper / Swiglu

View on GitHub examples/hopper/swiglu.py

Overview

Fused SwiGLU activation, written in pyptx, callable from JAX and PyTorch.

Reaches 2.8 TB/s at M=2048 F=8192 f32 on H100 (94% of HBM3 peak), 1.6x faster than F.silu(g) * u eager. Bandwidth-bound; saturates HBM3 on large shapes.

h[i, j] = silu(gate[i, j]) * up[i, j]

Vectorized with ld.global.v4.f32 / st.global.v4.f32 so every memory transaction moves 16 bytes. For F >= 4*block we also process multiple items per thread in a single v4 load, which gives us enough memory-level parallelism to saturate HBM3 on H100.

Source

Full source
"""Fused SwiGLU activation, written in pyptx, callable from JAX and PyTorch.

Reaches **2.8 TB/s** at M=2048 F=8192 f32 on H100 (94% of HBM3 peak),
**1.6x** faster than ``F.silu(g) * u`` eager. Bandwidth-bound; saturates
HBM3 on large shapes.

``h[i, j] = silu(gate[i, j]) * up[i, j]``

Vectorized with ``ld.global.v4.f32`` / ``st.global.v4.f32`` so every memory
transaction moves 16 bytes. For F >= 4*block we also process multiple
items per thread in a single v4 load, which gives us enough memory-level
parallelism to saturate HBM3 on H100.
"""
from __future__ import annotations

import math

import jax
import jax.numpy as jnp
import numpy as np

from pyptx import kernel, reg, ptx, Tile
from pyptx.types import f32, u32


BLOCK_CANDIDATES = (1024, 512, 256, 128, 64, 32)
LOG2E = 1.4426950408889634


def _pick_block(f: int) -> int:
    # Prefer blocks that admit v4 vectorization with many items/thread.
    best = None
    for block in BLOCK_CANDIDATES:
        if f % (block * 4) == 0 and block >= 128:
            ipt = f // block
            if best is None or ipt > best[1]:
                best = (block, ipt)
    if best is not None:
        return best[0]
    for block in BLOCK_CANDIDATES:
        if f % block == 0:
            return block
    raise AssertionError(
        f"F={f} must be divisible by one of {BLOCK_CANDIDATES}"
    )


def build_fused_silu_mul(
    M: int,
    F: int,
    *,
    rows_per_cta: int = 1,
    arch: str = "sm_90a",
):
    """Build a fused silu+mul kernel for inputs shaped ``(M, F)``.

    ``rows_per_cta`` lets one CTA process multiple rows to amortize
    launch overhead for small M.
    """
    block = _pick_block(F)
    items_per_thread = F // block
    use_v4 = items_per_thread % 4 == 0
    v4_iters = items_per_thread // 4 if use_v4 else 0
    assert items_per_thread >= 1
    if M % rows_per_cta != 0:
        rows_per_cta = 1
    version = (8, 7) if arch.startswith("sm_100") else None

    @kernel(
        in_specs=(
            Tile(M, F, f32),   # gate
            Tile(M, F, f32),   # up
        ),
        out_specs=(Tile(M, F, f32),),  # out
        grid=(M // rows_per_cta, 1, 1),
        block=(block, 1, 1),
        arch=arch,
        version=version,
    )
    def fused_silu_mul(gate, up, out):
        pg, pu, po = ptx.global_ptrs(gate, up, out)

        row_base = reg.scalar(u32)
        ptx.inst.mov.u32(row_base, ptx.special.ctaid.x())
        if rows_per_cta > 1:
            ptx.inst.mul.lo.u32(row_base, row_base, rows_per_cta)
        row_byte_off = row_base * (F * 4)
        pg += row_byte_off
        pu += row_byte_off
        po += row_byte_off

        tid = reg.scalar(u32)
        ptx.inst.mov.u32(tid, ptx.special.tid.x())

        neg_log2e = reg.scalar(f32, init=-LOG2E)
        one = reg.scalar(f32, init=1.0)

        def process_row(row_bytes):
            if use_v4:
                elem_base = tid << 2
                for j in range(v4_iters):
                    idx = elem_base if j == 0 else elem_base + (j * block * 4)
                    off = idx * 4
                    if row_bytes:
                        off = off + row_bytes
                    g_vals = [reg.scalar(f32) for _ in range(4)]
                    ptx.inst.ld.global_.v4.f32(g_vals, ptx.addr(pg + off))
                    u_vals = [reg.scalar(f32) for _ in range(4)]
                    ptx.inst.ld.global_.v4.f32(u_vals, ptx.addr(pu + off))
                    out_vals = []
                    for sub in range(4):
                        neg_g = reg.scalar(f32)
                        ptx.inst.mul.f32(neg_g, g_vals[sub], neg_log2e)
                        exp_neg = reg.scalar(f32)
                        ptx.inst.ex2.approx.f32(exp_neg, neg_g)
                        denom = reg.scalar(f32)
                        ptx.inst.add.f32(denom, one, exp_neg)
                        sigm = reg.scalar(f32)
                        ptx.inst.rcp.approx.f32(sigm, denom)
                        silu_g = reg.scalar(f32)
                        ptx.inst.mul.f32(silu_g, g_vals[sub], sigm)
                        out_val = reg.scalar(f32)
                        ptx.inst.mul.f32(out_val, silu_g, u_vals[sub])
                        out_vals.append(out_val)
                    ptx.inst.st.global_.v4.f32(ptx.addr(po + off), out_vals)
            else:
                for i in range(items_per_thread):
                    idx = reg.scalar(u32)
                    ptx.inst.add.u32(idx, tid, i * block)
                    off = idx * 4
                    if row_bytes:
                        off = off + row_bytes
                    g_val = reg.scalar(f32)
                    ptx.inst.ld.global_.f32(g_val, ptx.addr(pg + off))
                    u_val = reg.scalar(f32)
                    ptx.inst.ld.global_.f32(u_val, ptx.addr(pu + off))
                    neg_g_log2 = reg.scalar(f32)
                    ptx.inst.mul.f32(neg_g_log2, g_val, neg_log2e)
                    exp_neg = reg.scalar(f32)
                    ptx.inst.ex2.approx.f32(exp_neg, neg_g_log2)
                    denom = reg.scalar(f32)
                    ptx.inst.add.f32(denom, one, exp_neg)
                    sigm = reg.scalar(f32)
                    ptx.inst.rcp.approx.f32(sigm, denom)
                    silu_g = reg.scalar(f32)
                    ptx.inst.mul.f32(silu_g, g_val, sigm)
                    out_val = reg.scalar(f32)
                    ptx.inst.mul.f32(out_val, silu_g, u_val)
                    ptx.inst.st.global_.f32(ptx.addr(po + off), out_val)

        for r in range(rows_per_cta):
            process_row(r * F * 4 if rows_per_cta > 1 else 0)

        ptx.ret()

    return fused_silu_mul


# ---------------------------------------------------------------------------
# JAX reference + test harness
# ---------------------------------------------------------------------------

def fused_silu_mul_ref(gate, up):
    return jax.nn.silu(gate) * up


def _run_jax_case(M: int, F: int) -> None:
    k = build_fused_silu_mul(M, F)
    np.random.seed(M * 65537 + F)
    g = jnp.asarray(np.random.randn(M, F).astype(np.float32) * 0.5)
    u = jnp.asarray(np.random.randn(M, F).astype(np.float32) * 0.5)

    @jax.jit
    def fn(g, u):
        return k(g, u)

    out = np.asarray(fn(g, u))
    ref = np.asarray(fused_silu_mul_ref(g, u))
    diff = float(np.abs(out - ref).max())
    ok = bool(np.allclose(out, ref, atol=1e-4, rtol=1e-3))
    status = "OK  " if ok else "FAIL"
    print(f"[JAX  {status}] M={M:4d} F={F:5d}  max_abs={diff:.3e}")


def _run_torch_case(M: int, F: int) -> None:
    import torch

    k = build_fused_silu_mul(M, F)
    np.random.seed(M * 65537 + F)
    g = torch.tensor(np.random.randn(M, F).astype(np.float32) * 0.5, device="cuda")
    u = torch.tensor(np.random.randn(M, F).astype(np.float32) * 0.5, device="cuda")

    out = k(g, u)
    torch.cuda.synchronize()
    ref = torch.nn.functional.silu(g) * u
    diff = float((out - ref).abs().max())
    ok = bool(torch.allclose(out, ref, atol=1e-4, rtol=1e-3))
    status = "OK  " if ok else "FAIL"
    print(f"[Torch{status}] M={M:4d} F={F:5d}  max_abs={diff:.3e}")


def main() -> None:
    _ = (jnp.ones((4,), dtype=jnp.float32) + 1).block_until_ready()

    for M, F in [
        (4, 512),
        (16, 512),
        (32, 1024),
        (128, 2048),
        (256, 4096),
        (512, 8192),
    ]:
        _run_jax_case(M, F)
        _run_torch_case(M, F)


if __name__ == "__main__":
    main()