
Ampere / Gemm Pipelined

View on GitHub: examples/ampere/gemm_pipelined.py

Overview

A100 (sm_80) bf16 GEMM with cp.async + SMEM ring buffer + mma.sync.

Production-leaning Ampere kernel. Uses every Ampere first-class instruction pyptx exposes:

ptx.cp.async_.cg(...)         — async global → SMEM prefetch (16-byte vec)
ptx.cp.async_.commit_group()  — close pending cp.async into a group
ptx.cp.async_.wait_group(N)   — wait until <= N groups remain pending
ptx.mma.sync(shape=(16, 8, 16), ...)        — Ampere tensor-core MMA
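
Taken together these map onto the usual double-buffered prefetch pattern. The sketch below condenses the control flow of the kernel body further down; load_stage and consume_stage are placeholders for the per-thread cp.async issues and the ld.shared + mma.sync work, and it assumes at least two K-iterations:

    # Condensed 2-stage cp.async pipeline (placeholders; assumes n_iters >= 2).
    load_stage(0, 0)                  # prime stage 0 with the first K-slab
    ptx.cp.async_.commit_group()
    load_stage(1, BK)                 # prime stage 1 with the second K-slab
    ptx.cp.async_.commit_group()

    for ki in range(n_iters):
        # Steady state: allow 1 group in flight; last iteration: drain to 0.
        ptx.cp.async_.wait_group(1 if ki + 1 < n_iters else 0)
        ptx.bar.sync(0)               # landed SMEM stage is now visible CTA-wide
        consume_stage(ki & 1)         # ld.shared fragments + 8x mma.sync
        ptx.bar.sync(0)               # all reads retired before the stage is overwritten
        if ki + 2 < n_iters:
            load_stage(ki & 1, (ki + 2) * BK)   # refill the stage just consumed
            ptx.cp.async_.commit_group()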

For the SMEM → register hand-off this file uses per-thread ld.shared.b32 loads rather than ldmatrix. The fragment-layout math is identical to that of the direct-from-global examples/ampere/gemm.py — each lane computes its own m16n8k16 fragment indices and loads a few packed-bf16 pairs. ldmatrix would be more efficient (a single warp-collective instruction with hardware-optimized bank-conflict handling), but the per-thread path is simpler to verify and demonstrates the cp.async + SMEM-staging path on its own.
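
For reference, the per-lane index math for the A fragment, condensed from the main loop below (warp_a_row_base, a_stage_base, and a_fr are defined in the full source):

    gid = lane >> 2       # 0..7 : lane's row within each 8-row group of the tile
    tig = lane & 3        # 0..3 : lane's slot within its row group
    col_lo = tig << 1     # starting k column of this lane's packed bf16 pair
    # Four ld.shared.b32 per lane cover rows {gid, gid+8} x k-halves {0..7, 8..15}:
    for i, (drow, dcol) in enumerate([(0, 0), (8, 0), (0, 8), (8, 8)]):
        a_off = ((warp_a_row_base + gid + drow) * BK + col_lo + dcol) * 2  # byte offset
        ptx.inst.ld.shared.b32(a_fr[i], ptx.addr(a_stage_base + a_off))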

Block tile:  BM x BN = 64 x 64
K-step:      BK = 16
Warps/CTA:   4  (warp w handles M[w*16 : (w+1)*16])
Per warp:    1 (M) × 8 (N) = 8 mma.sync calls per K-iter
SMEM stages: 2  (double-buffered cp.async prefetch)
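
A quick sanity check of the shared-memory budget these parameters imply (the constant names match the module-level definitions in the source):

    BM, BN, BK, STAGES = 64, 64, 16, 2
    A_STAGE_BYTES = BM * BK * 2                              # 2048 B of bf16 A per stage
    B_STAGE_BYTES = BN * BK * 2                              # 2048 B of bf16 B per stage
    SMEM_BYTES = STAGES * (A_STAGE_BYTES + B_STAGE_BYTES)    # 8192 B per CTA
    # 128 threads x one 16-byte cp.async each = 2048 B, exactly one stage of A (or B).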

Inputs:
  A:   (M, K) bf16 row-major
  B_T: (N, K) bf16 row-major
  D:   (M, N) f32 row-major
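
Typical usage mirrors the test harness at the bottom of the file; the shapes below are illustrative and the import path depends on how your checkout exposes the module:

    import jax.numpy as jnp
    # from gemm_pipelined import build_gemm_pipelined   # adjust to your checkout

    M = N = K = 1024                             # must be multiples of BM, BN, BK
    k = build_gemm_pipelined(M, N, K)
    A  = jnp.zeros((M, K), dtype=jnp.bfloat16)   # (M, K) row-major
    BT = jnp.zeros((N, K), dtype=jnp.bfloat16)   # (N, K) row-major
    D  = k(A, BT)                                # (M, N) f32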

Source

Full source
"""A100 (sm_80) bf16 GEMM with cp.async + SMEM ring buffer + mma.sync.

Production-leaning Ampere kernel. Uses every Ampere first-class instruction
pyptx exposes:

    ptx.cp.async_.cg(...)         — async global → SMEM prefetch (16-byte vec)
    ptx.cp.async_.commit_group()  — close pending cp.async into a group
    ptx.cp.async_.wait_group(N)   — wait until <= N groups remain pending
    ptx.mma.sync(shape=(16, 8, 16), ...)        — Ampere tensor-core MMA

For the SMEM → register hand-off this file uses **per-thread `ld.shared.b32`
loads** rather than ``ldmatrix``. The fragment-layout math is identical to
the direct-from-global ``examples/ampere/gemm.py`` — each lane computes its
own m16n8k16 fragment indices and loads a few packed-bf16 pairs. ldmatrix
would be more efficient (single warp-collective instruction, hardware-
optimized bank-conflict handling), but the per-thread path is simpler to
verify and demonstrates the ``cp.async`` + SMEM-staging path on its own.

Block tile:  BM x BN = 64 x 64
K-step:      BK = 16
Warps/CTA:   4    (warp w handles M[w*16 : (w+1)*16])
Per warp:    1 (M) × 8 (N) = 8 ``mma.sync`` calls per K-iter
SMEM stages: 2    (double-buffered ``cp.async`` prefetch)

Inputs:
  A:   (M, K) bf16 row-major
  B_T: (N, K) bf16 row-major
  D:   (M, N) f32 row-major
"""
from __future__ import annotations

import os

os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

import jax
import jax.numpy as jnp
import numpy as np

from pyptx import kernel, ptx, reg, smem, Tile
from pyptx.types import b32, bf16, f32, u32


BM, BN, BK = 64, 64, 16
NUM_WARPS = 4
THREADS = 32 * NUM_WARPS
WM = BM // NUM_WARPS         # 16
N_FRAG_N = BN // 8           # 8

A_STAGE_BYTES = BM * BK * 2  # 2048
B_STAGE_BYTES = BN * BK * 2  # 2048
STAGES = 2

A_SMEM_BASE = 0
B_SMEM_BASE = STAGES * A_STAGE_BYTES
SMEM_BYTES = STAGES * (A_STAGE_BYTES + B_STAGE_BYTES)


def build_gemm_pipelined(M: int, N: int, K: int, *, arch: str = "sm_80"):
    """Build the cp.async + SMEM-staged A100 bf16 GEMM kernel."""
    assert M % BM == 0
    assert N % BN == 0
    assert K % BK == 0
    n_iters = K // BK

    @kernel(
        in_specs=(
            Tile(M, K, bf16),
            Tile(N, K, bf16),
        ),
        out_specs=(Tile(M, N, f32),),
        grid=(N // BN, M // BM, 1),
        block=(THREADS, 1, 1),
        arch=arch,
        smem=SMEM_BYTES,
        extern_smem=True,
    )
    def gemm(A, B_T, D):
        pa, pb, pd = ptx.global_ptrs(A, B_T, D)
        smem_base = smem.base()

        m_base = reg.scalar(u32)
        ptx.inst.mov.u32(m_base, ptx.special.ctaid.y())
        ptx.inst.shl.b32(m_base, m_base, 6)  # * 64
        n_base = reg.scalar(u32)
        ptx.inst.mov.u32(n_base, ptx.special.ctaid.x())
        ptx.inst.shl.b32(n_base, n_base, 6)  # * 64

        tid = reg.scalar(u32)
        ptx.inst.mov.u32(tid, ptx.special.tid.x())
        warp_id = tid >> 5      # 0..3
        lane = tid & 31         # 0..31

        # ----- Per-thread cp.async load layout -----
        # Each thread loads exactly 16 bytes (= 8 bf16) per stage per matrix.
        # 128 threads × 16 bytes = 2048 bytes/stage = matches BM*BK*2 = BN*BK*2.
        # Mapping: thread t loads row t/2, cols (t%2)*8..(t%2)*8+7.
        load_row = tid >> 1                    # 0..63
        col_chunk = tid & 1
        col_start = col_chunk << 3             # 0 or 8
        load_smem_off = (load_row * BK + col_start) * 2  # bytes

        # ----- Accumulator: 8 m16n8 acc tiles × 4 f32 = 32 f32 regs/lane -----
        acc = reg.array(f32, N_FRAG_N * 4)
        zero = reg.scalar(f32, init=0.0)
        for i in range(N_FRAG_N * 4):
            ptx.inst.mov.f32(acc[i], zero)

        def issue_cp_async(s: int, k_idx_reg):
            """Per-thread cp.async for A and B at stage s, K base k_idx_reg."""
            a_smem_dst = smem_base + (A_SMEM_BASE + s * A_STAGE_BYTES) + load_smem_off
            b_smem_dst = smem_base + (B_SMEM_BASE + s * B_STAGE_BYTES) + load_smem_off
            a_global_off = ((m_base + load_row) * K + k_idx_reg + col_start) * 2
            b_global_off = ((n_base + load_row) * K + k_idx_reg + col_start) * 2
            ptx.cp.async_.cg(ptx.addr(a_smem_dst), ptx.addr(pa + a_global_off), 16)
            ptx.cp.async_.cg(ptx.addr(b_smem_dst), ptx.addr(pb + b_global_off), 16)

        # ----- Prologue: prime the pipeline -----
        k_zero = reg.scalar(u32, init=0)
        issue_cp_async(0, k_zero)
        ptx.cp.async_.commit_group()
        if n_iters > 1:
            k_one = reg.scalar(u32, init=BK)
            issue_cp_async(1, k_one)
            ptx.cp.async_.commit_group()

        # ----- m16n8k16 per-lane fragment indices (warp's M slice) -----
        gid = lane >> 2          # 0..7
        tig = lane & 3           # 0..3
        col_lo = tig << 1        # 0,2,4,6 (within each n-frag's 8 cols)
        # Warp's M slice base in SMEM A:
        warp_a_row_base = warp_id << 4   # warp_id * 16 (in SMEM rows)

        # ----- Hoisted fragment registers (reused every K-iter) -----
        # Allocating reg.array inside the K loop creates fresh PTX
        # registers per iter — for K=1024 (64 iters) that's 64*(4+16) =
        # 1280 b32 regs/lane, blowing past spill thresholds in ways that
        # compound subtly. Hoisting keeps the register set bounded.
        a_fr = reg.array(b32, 4)
        b_fr = reg.array(b32, N_FRAG_N * 2)

        # ----- Main K loop -----
        # Pipeline drain: in steady state we wait_group(STAGES-1), keeping
        # one prefetch in flight while the current stage is consumed. In
        # the last STAGES-1 iters no more prefetches are issued, so
        # `pending` decreases each iter and we must lower the threshold —
        # using the steady-state value for the tail would return
        # immediately without actually waiting for the data to land.
        for ki in range(n_iters):
            stage = ki & 1
            tail = max(0, ki - (n_iters - STAGES))
            wait_target = max(0, STAGES - 1 - tail)
            ptx.cp.async_.wait_group(wait_target)
            ptx.bar.sync(0)

            a_stage_base = smem_base + A_SMEM_BASE + stage * A_STAGE_BYTES
            b_stage_base = smem_base + B_SMEM_BASE + stage * B_STAGE_BYTES

            # ---- Load A fragment (4 b32 regs/lane) ----
            for i, (drow, dcol) in enumerate([(0, 0), (8, 0), (0, 8), (8, 8)]):
                row_in_smem = warp_a_row_base + gid + drow
                col_pair = col_lo + dcol
                a_off = (row_in_smem * BK + col_pair) * 2
                ptx.inst.ld.shared.b32(a_fr[i], ptx.addr(a_stage_base + a_off))

            # ---- Load B fragments (8 m16n8 frags, 2 b32 regs each) ----
            for nf in range(N_FRAG_N):
                row_in_smem = (nf << 3) + gid
                for i, dcol in enumerate([0, 8]):
                    col_pair = col_lo + dcol
                    b_off = (row_in_smem * BK + col_pair) * 2
                    ptx.inst.ld.shared.b32(b_fr[nf * 2 + i], ptx.addr(b_stage_base + b_off))

            # ---- Issue 8 mma.sync calls (1 m × 8 n) per warp ----
            for nf in range(N_FRAG_N):
                ptx.mma.sync(
                    shape=(16, 8, 16),
                    dtype_d=f32, dtype_a=bf16, dtype_b=bf16, dtype_c=f32,
                    d=[acc[nf*4], acc[nf*4+1], acc[nf*4+2], acc[nf*4+3]],
                    a=[a_fr[0], a_fr[1], a_fr[2], a_fr[3]],
                    b=[b_fr[nf*2], b_fr[nf*2+1]],
                    c=[acc[nf*4], acc[nf*4+1], acc[nf*4+2], acc[nf*4+3]],
                    a_layout="row", b_layout="col",
                )

            # ---- Sync before issuing the next prefetch ----
            # The prefetch writes to the SAME SMEM stage we just read
            # (STAGES=2 ring buffer). Without a bar.sync here, a fast
            # warp could issue its cp.async while a slow warp is still
            # finishing mma — the cp.async writes can land before the
            # slow warp has retired its ld.shared, corrupting in-flight
            # data delivered to mma. The CTA-wide bar.sync forces all
            # warps to complete mma (and the SMEM reads that fed it)
            # before any warp starts overwriting that SMEM stage.
            ptx.bar.sync(0)

            # ---- Prefetch ki+STAGES if there's K left ----
            if ki + STAGES < n_iters:
                k_next = reg.scalar(u32)
                ptx.inst.mov.u32(k_next, (ki + STAGES) * BK)
                next_stage = (ki + STAGES) & 1
                issue_cp_async(next_stage, k_next)
                ptx.cp.async_.commit_group()

        # ----- Epilogue: each lane stores its 8 m16n8 acc tiles to global D -----
        # m16n8 fragment per-lane: rows {gid, gid+8}, cols {2*tig, 2*tig+1}
        warp_m_global = m_base + (warp_id << 4)
        row_lo = warp_m_global + gid
        row_hi = row_lo + 8
        for nf in range(N_FRAG_N):
            n_global = n_base + (nf << 3)
            d_col_base = n_global + col_lo
            for i, (drow, dcol) in enumerate([(0, 0), (0, 1), (8, 0), (8, 1)]):
                row = row_lo if drow == 0 else row_hi
                col_elem = d_col_base + dcol
                elem_idx = row * N + col_elem
                byte_off = elem_idx * 4
                ptx.inst.st.global_.f32(ptx.addr(pd + byte_off), acc[nf * 4 + i])

        ptx.ret()

    return gemm


# ---------------------------------------------------------------------------
# Reference + test harness
# ---------------------------------------------------------------------------

def gemm_ref(A: jnp.ndarray, B_T: jnp.ndarray) -> jnp.ndarray:
    return jnp.einsum("mk,nk->mn", A.astype(jnp.float32), B_T.astype(jnp.float32))


def _run_jax_case(M: int, N: int, K: int) -> None:
    k = build_gemm_pipelined(M, N, K)
    rng = np.random.default_rng(M * 7919 + N * 31 + K)
    A = jnp.asarray(rng.standard_normal((M, K), dtype=np.float32) * 0.1, dtype=jnp.bfloat16)
    BT = jnp.asarray(rng.standard_normal((N, K), dtype=np.float32) * 0.1, dtype=jnp.bfloat16)

    @jax.jit
    def fn(A, BT):
        return k(A, BT)

    out = np.asarray(fn(A, BT))
    ref = np.asarray(gemm_ref(A, BT))
    diff = float(np.abs(out - ref).max())
    ok = bool(np.allclose(out, ref, atol=1e-2, rtol=1e-2))
    status = "OK  " if ok else "FAIL"
    print(f"[JAX  {status}] M={M:5d} N={N:5d} K={K:5d}  max_abs={diff:.3e}")


def _run_torch_case(M: int, N: int, K: int) -> None:
    import torch
    k = build_gemm_pipelined(M, N, K)
    rng = np.random.default_rng(M * 7919 + N * 31 + K)
    A = torch.tensor(rng.standard_normal((M, K), dtype=np.float32) * 0.1,
                     dtype=torch.bfloat16, device="cuda")
    BT = torch.tensor(rng.standard_normal((N, K), dtype=np.float32) * 0.1,
                      dtype=torch.bfloat16, device="cuda")
    out = k(A, BT)
    torch.cuda.synchronize()
    ref = (A.float() @ BT.float().T)
    diff = float((out - ref).abs().max())
    ok = bool(torch.allclose(out, ref, atol=1e-2, rtol=1e-2))
    status = "OK  " if ok else "FAIL"
    print(f"[Torch{status}] M={M:5d} N={N:5d} K={K:5d}  max_abs={diff:.3e}")


def main() -> None:
    _ = (jnp.ones((4,), dtype=jnp.float32) + 1).block_until_ready()
    for M, N, K in [
        (64, 64, 64),
        (64, 64, 256),
        (128, 128, 128),
        (256, 256, 256),
        (512, 512, 512),
        (1024, 1024, 1024),
        (2048, 2048, 2048),
        (4096, 4096, 4096),
    ]:
        _run_jax_case(M, N, K)
        _run_torch_case(M, N, K)


if __name__ == "__main__":
    main()