Ampere / Gemm Highperf Ampere¶
examples/ampere/gemm_highperf_ampere.py
Overview¶
High-perf A100 (sm_80) bf16 GEMM — ldmatrix + cp.async + 3-stage SMEM.
Follows the CUTLASS SM80 production design:
- 128 × 128 × 32 CTA tile, 4 warps/CTA arranged 2×2 in (M, N), each warp owning a 64 × 64 output sub-tile.
- Per warp per K-iter: 64 mma.sync.m16n8k16 (4 M-frags × 8 N-frags × 2 K-blocks); 256 mma per CTA per K-iter (see the arithmetic sketch after this list).
- ldmatrix.sync.aligned.m8n8.x4.shared.b16 loads each 16×16 A fragment in one warp-collective instruction (4 b32/lane = the full A m16n8k16 fragment).
- The same non-transposing ldmatrix.sync.aligned.m8n8.x4.shared.b16 loads each 16×16 worth of B (= 2 m16n8k16 B fragments): SMEM stores B_T (N, K) row-major, i.e. B in K-major (column-major) order, so the loaded fragments already match mma's row.col B layout without .trans.
- 3-stage cp.async ring buffer (CUTLASS kStages=3 default). Each thread issues a 16-byte cp.async.cg per matrix per stage; with stages=3 the wait_group(STAGES-2) pattern absorbs warp skew naturally — no bar.sync between mma and prefetch is required.
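As a sanity check, the tile bookkeeping above can be reproduced in a few lines of plain Python. This is a standalone sketch; the names simply mirror the constants defined in the source below:

BM, BN, BK = 128, 128, 32                       # CTA tile
WARPS_M, WARPS_N = 2, 2                         # 4 warps arranged 2x2 in (M, N)
WM, WN = BM // WARPS_M, BN // WARPS_N           # 64 x 64 output sub-tile per warp
frag_m, frag_n, frag_k = WM // 16, WN // 8, BK // 16   # m16n8k16 fragment counts
mma_per_warp = frag_m * frag_n * frag_k         # 4 * 8 * 2 = 64
mma_per_cta = mma_per_warp * WARPS_M * WARPS_N  # 256
assert (mma_per_warp, mma_per_cta) == (64, 256)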
References:
- CUTLASS include/cutlass/gemm/device/default_gemm_configuration.h
(SM80 TensorOp: kStages=3, ThreadblockShape 128×128, WarpShape 64×64,
InstructionShape 16×8×16).
- CUTLASS examples/cute/tutorial/sgemm_sm80.cu
(SM75_U32x4_LDSM_N for A, .x4.trans for B, K_PIPE_MAX=3,
SM80_CP_ASYNC_CACHEALWAYS<uint128_t>).
Inputs:
A: (M, K) bf16 row-major, M % 128 == 0, K % 32 == 0
B_T: (N, K) bf16 row-major, N % 128 == 0 (transposed B; mma row.col
wants K-contiguous on B's leading axis)
D: (M, N) f32 row-major output
Grid: (N // 128, M // 128, 1)
Block: (128, 1, 1)
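A minimal call sequence from JAX, modeled on the test harness in the source below. This is a sketch; the import path is illustrative (the builder is defined in examples/ampere/gemm_highperf_ampere.py):

import numpy as np
import jax
import jax.numpy as jnp
from gemm_highperf_ampere import build_gemm_highperf  # illustrative import path

M, N, K = 512, 512, 512                      # multiples of 128, 128, 32
k = build_gemm_highperf(M, N, K)             # builds the sm_80 kernel for these shapes
rng = np.random.default_rng(0)
A  = jnp.asarray(rng.standard_normal((M, K), dtype=np.float32), dtype=jnp.bfloat16)
BT = jnp.asarray(rng.standard_normal((N, K), dtype=np.float32), dtype=jnp.bfloat16)
D = jax.jit(lambda a, bt: k(a, bt))(A, BT)   # (M, N) f32, approximately A @ BT.T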
Source¶
Full source
"""High-perf A100 (sm_80) bf16 GEMM — ldmatrix + cp.async + 3-stage SMEM.
Follows the CUTLASS SM80 production design:
- **128 × 128 × 32 CTA tile**, 4 warps/CTA arranged 2×2 in (M, N), each warp
owning a 64 × 64 output sub-tile.
- Per warp per K-iter: **64 ``mma.sync.m16n8k16``** (4 M-frags × 8 N-frags × 2
K-blocks). 256 mma per CTA per K-iter.
- **ldmatrix.sync.aligned.m8n8.x4.shared.b16** loads each 16×16 A fragment in
one warp-collective instruction (4 b32/lane = the full A m16n8k16 fragment).
- The same non-transposing **ldmatrix.sync.aligned.m8n8.x4.shared.b16** loads each
16×16 worth of B (= 2 m16n8k16 B fragments): SMEM stores ``B_T`` (N, K)
row-major, i.e. B in K-major (column-major) order, so the loaded fragments
already match mma's row.col B layout without ``.trans``.
- **3-stage cp.async ring buffer** (CUTLASS kStages=3 default). Each thread
issues 16-byte cp.async.cg per matrix per stage; with stages=3 the
``wait_group(STAGES-2)`` pattern absorbs warp skew naturally — no
``bar.sync`` between mma and prefetch is required.
References:
- CUTLASS ``include/cutlass/gemm/device/default_gemm_configuration.h``
(SM80 TensorOp: kStages=3, ThreadblockShape 128×128, WarpShape 64×64,
InstructionShape 16×8×16).
- CUTLASS ``examples/cute/tutorial/sgemm_sm80.cu``
(``SM75_U32x4_LDSM_N`` for A, ``.x4.trans`` for B, K_PIPE_MAX=3,
``SM80_CP_ASYNC_CACHEALWAYS<uint128_t>``).
Inputs:
A: (M, K) bf16 row-major, ``M % 128 == 0``, ``K % 32 == 0``
B_T: (N, K) bf16 row-major, ``N % 128 == 0`` (transposed B; mma row.col
wants K-contiguous on B's leading axis)
D: (M, N) f32 row-major output
Grid: (N // 128, M // 128, 1)
Block: (128, 1, 1)
"""
from __future__ import annotations
import os
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
import jax
import jax.numpy as jnp
import numpy as np
from pyptx import kernel, ptx, reg, smem, Tile
from pyptx.types import b32, bf16, f32, u32
BM, BN, BK = 128, 128, 32
NUM_WARPS = 4
THREADS = 32 * NUM_WARPS # 128
WARPS_M, WARPS_N = 2, 2
WM, WN = BM // WARPS_M, BN // WARPS_N # 64, 64
N_FRAG_M = WM // 16 # 4
N_FRAG_N = WN // 8 # 8
N_FRAG_K = BK // 16 # 2
STAGES = 4 # 4 SMEM buffers, 3 in flight (CUTLASS pattern)
A_STAGE_BYTES = BM * BK * 2 # 8192
B_STAGE_BYTES = BN * BK * 2 # 8192
A_SMEM_BASE = 0
B_SMEM_BASE = STAGES * A_STAGE_BYTES
SMEM_BYTES = STAGES * (A_STAGE_BYTES + B_STAGE_BYTES) # 64 KB (extern_smem opt-in)
def build_gemm_highperf(M: int, N: int, K: int, *, arch: str = "sm_80"):
"""Build the production A100 bf16 GEMM kernel."""
assert M % BM == 0, f"M={M} must be divisible by {BM}"
assert N % BN == 0, f"N={N} must be divisible by {BN}"
assert K % BK == 0, f"K={K} must be divisible by {BK}"
n_iters = K // BK
@kernel(
in_specs=(
Tile(M, K, bf16),
Tile(N, K, bf16),
),
out_specs=(Tile(M, N, f32),),
grid=(N // BN, M // BM, 1),
block=(THREADS, 1, 1),
arch=arch,
smem=SMEM_BYTES,
extern_smem=True,
)
def gemm(A, B_T, D):
pa, pb, pd = ptx.global_ptrs(A, B_T, D)
smem_base = smem.base()
# CTA bases.
m_base = reg.scalar(u32)
ptx.inst.mov.u32(m_base, ptx.special.ctaid.y())
ptx.inst.shl.b32(m_base, m_base, 7) # * 128
n_base = reg.scalar(u32)
ptx.inst.mov.u32(n_base, ptx.special.ctaid.x())
ptx.inst.shl.b32(n_base, n_base, 7) # * 128
# Thread + warp identity.
tid = reg.scalar(u32)
ptx.inst.mov.u32(tid, ptx.special.tid.x())
warp_id = tid >> 5 # 0..3
lane = tid & 31 # 0..31
warp_m = warp_id >> 1 # 0..1 (top 2 warps M=0, bottom M=1)
warp_n = warp_id & 1 # 0..1
warp_m_smem = warp_m << 6 # 0 or 64 (warp's M base in SMEM rows)
warp_n_smem = warp_n << 6 # 0 or 64
# ----- cp.async layout -----
# Each matrix (A or B) is 128 rows × 32 cols of bf16 = 8 KB per stage.
# 128 threads × 16 B/thread = 2 KB per pass; 4 passes per stage per matrix.
# Map 4 threads per row (each thread 8 bf16 = 16 B), 32 rows per pass.
load_row_in_pass = tid >> 2 # 0..31 (1 row per 4 threads)
load_col_chunk = tid & 3 # 0..3 (which 8-bf16 chunk in K=32)
load_col_off = load_col_chunk << 3 # 0,8,16,24 bf16
load_smem_col_bytes = load_col_off * 2 # bytes within a row
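# Example of the mapping: tid 77 has load_row_in_pass = 19 and chunk 1
# (bf16 cols 8..15); across the 4 passes it copies rows 19, 51, 83, 115
# of the 128-row stage.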
# SMEM XOR swizzle (CUTLASS / MatmulTutorial v9 / v15 pattern).
# For 64-byte SMEM rows (BK=32 bf16 = 4 atoms × 16 B), swizzle the
# 4 atom slots per row using 2 bits of the row index. The exact
# formula matches MatmulTutorial v15 (line 38, 76, etc.):
# col_eff = col ^ ((row & 3) << 3) (col in bf16, 4-row period)
# which is equivalent to:
# atom_eff = atom ^ (row & 3) (atom in 16 B, 4-row period)
# Applied identically to cp.async writes, ldmatrix reads, and (if
# used) any epilogue SMEM round-trip — the formula MUST match on
# all paths or data lands in the wrong bank.
# MatmulTutorial v9 reports +13% from this on their fp16 kernel.
USE_SWIZZLE = True
def smem_swizzle_byte_off(row, atom_idx, sub_atom_bytes=0):
if USE_SWIZZLE:
swizzled_atom = atom_idx ^ (row & 3)
return row * (BK * 2) + swizzled_atom * 16 + sub_atom_bytes
return row * (BK * 2) + atom_idx * 16 + sub_atom_bytes
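# Worked example of the swizzle: row 3, atom 1 (bf16 col 8) gives
# col_eff = 8 ^ ((3 & 3) << 3) = 16, equivalently atom_eff = 1 ^ 3 = 2,
# so the byte offset is 3*64 + 2*16 = 224 instead of the unswizzled 208.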
# ----- Precompute per-thread cp.async offsets (loop-invariant) -----
# cp.async write SMEM offset and global row-base address are
# per-thread constants. Hoisting these out of the K-loop turns
# each cp.async into 1-2 ops (add + store) instead of ~8.
cp_smem_off = [] # per-pass: SMEM byte offset (with swizzle)
cp_a_row_base_x_K = [] # per-pass: (m_base + row_off) * K, in elements
cp_b_row_base_x_K = [] # per-pass: (n_base + row_off) * K, in elements
for pass_id in range(4):
row_off = load_row_in_pass + (pass_id << 5)
cp_smem_off.append(smem_swizzle_byte_off(row_off, load_col_chunk))
cp_a_row_base_x_K.append((m_base + row_off) * K)
cp_b_row_base_x_K.append((n_base + row_off) * K)
def issue_cp_async(stage: int, k_idx_reg):
"""Issue cp.async for both A and B at stage, K = k_idx_reg.
32 rows per pass × 4 passes = 128 rows of each matrix.
Each thread writes one 16-byte atom; swizzled SMEM dst for ldmatrix.
"""
a_stage = smem_base + (A_SMEM_BASE + stage * A_STAGE_BYTES)
b_stage = smem_base + (B_SMEM_BASE + stage * B_STAGE_BYTES)
for pass_id in range(4):
a_smem_dst = a_stage + cp_smem_off[pass_id]
b_smem_dst = b_stage + cp_smem_off[pass_id]
a_glob_off = (cp_a_row_base_x_K[pass_id] + k_idx_reg + load_col_off) * 2
b_glob_off = (cp_b_row_base_x_K[pass_id] + k_idx_reg + load_col_off) * 2
ptx.cp.async_.cg(ptx.addr(a_smem_dst), ptx.addr(pa + a_glob_off), 16)
ptx.cp.async_.cg(ptx.addr(b_smem_dst), ptx.addr(pb + b_glob_off), 16)
# ----- Accumulator: 4 M-frag × 8 N-frag × 4 f32 = 128 f32/lane -----
ACC_TILES = N_FRAG_M * N_FRAG_N
acc = reg.array(f32, ACC_TILES * 4)
zero = reg.scalar(f32, init=0.0)
for i in range(ACC_TILES * 4):
ptx.inst.mov.f32(acc[i], zero)
# ----- Prologue: prime STAGES-1 stages -----
for s in range(STAGES - 1):
k_s = reg.scalar(u32, init=s * BK)
issue_cp_async(s, k_s)
ptx.cp.async_.commit_group()
# ----- ldmatrix lane-to-address math -----
# ldmatrix.m8n8.x4 (no trans) — for A.
# Loads 4 8x8 b16 matrices arranged 2x2 in (rows, cols):
# mat 0: rows 0..7, cols 0..7 (lanes 0..7 provide row pointers)
# mat 1: rows 8..15, cols 0..7 (lanes 8..15 provide row pointers)
# mat 2: rows 0..7, cols 8..15 (lanes 16..23 provide row pointers)
# mat 3: rows 8..15, cols 8..15 (lanes 24..31 provide row pointers)
# After load, lane t holds {mat0[t/4, 2*(t%4)..2*(t%4)+1], ...} which
# exactly matches the m16n8k16 A fragment per-thread layout.
#
# Per-lane address: row = t & 15, col_off_in_frag = (t >> 4) << 3.
a_ldsm_row = lane & 15 # 0..15
a_ldsm_col_in_frag = (lane >> 4) << 3 # 0 or 8 (bf16)
# ldmatrix.m8n8.x4 (no .trans) — for B.
# SMEM stores B_T = (N, K) row-major, so SMEM[N, K] gives the bf16
# value at B[K, N]. For mma's row.col B fragment, lane t holds
# B[K=2*(t%4)+{0,1,8,9}, N=t/4] — i.e. K-stride at fixed N. That's
# exactly what ldmatrix (no trans) produces from B_T SMEM:
# reg[i] = mat[i][row=t/4, cols=2*(t%4)..2*(t%4)+1]
# = B_T[N0 + t/4, K0 + 2*(t%4)..2*(t%4)+1]
# = B[K0 + 2*(t%4)..2*(t%4)+1, N0 + t/4] (matches mma).
# The 4 mats cover 16K × 16N (= 2 m16n8k16 N-frags) per call:
# mat 0: SMEM rows N0..N0+7, K=K0..K0+7 (K-low, N-frag 0)
# mat 1: SMEM rows N0..N0+7, K=K0+8..K0+15 (K-high, N-frag 0)
# mat 2: SMEM rows N0+8..N0+15, K=K0..K0+7 (K-low, N-frag 1)
# mat 3: SMEM rows N0+8..N0+15, K=K0+8..K0+15 (K-high, N-frag 1)
#
# Per-lane address: N-row = ((t >> 4) << 3) + (t & 7),
# K-col = ((t >> 3) & 1) << 3.
b_ldsm_row_in_pair = ((lane >> 4) << 3) + (lane & 7) # 0..15
b_ldsm_col_in_block = ((lane >> 3) & 1) << 3 # 0 or 8 (bf16)
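# Worked example (lane 13): for A, a_ldsm_row = 13 with bf16 col offset 0,
# i.e. a row pointer for mat 1 (M rows 8..15, K cols 0..7); for B,
# b_ldsm_row_in_pair = 5 with K col offset 8, i.e. a row pointer for
# mat 1 (SMEM rows N0..N0+7, K cols K0+8..K0+15).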
# ----- Double-buffered fragment registers -----
# Two register banks alternate per K-block within an iter, AND across
# iter boundaries (CUTLASS-style ping-pong). Bank 0 holds (ki, kb=0),
# bank 1 holds (ki, kb=1). At end of iter ki we pre-load (ki+1, kb=0)
# into bank 0 — overlapping with bank-1 mma — so the first mma of
# ki+1 has zero ldmatrix latency.
a_fr_buf = [reg.array(b32, N_FRAG_M * 4) for _ in range(2)]
b_fr_buf = [reg.array(b32, (N_FRAG_N // 2) * 4) for _ in range(2)]
a_atom_lane_part = lane >> 4 # 0 or 1
b_atom_lane_part = (lane >> 3) & 1 # 0 or 1
# ----- Precomputed per-lane SMEM byte offsets (loop-invariant) -----
# Each ldmatrix's address = current_stage_base + precomputed_offset.
# Hoisting the row*64 + atom*16 (+ swizzle XOR) computation out of
# the K-loop turns each inner-loop ldmatrix from 5-7 PTX address
# ops into a single add. With 16 ldmatrix per K-iter × 128 K-iters
# at K=4096, this saves ~10K PTX ops per warp.
a_ldsm_off = [] # [mf][kb] → Reg (byte offset within stage)
for mf in range(N_FRAG_M):
row_for_mf = warp_m_smem + (mf << 4) + a_ldsm_row
row_swizz = (row_for_mf & 3) if USE_SWIZZLE else 0
row_x_64 = row_for_mf * (BK * 2)
row_off = []
for kb_ in range(N_FRAG_K):
atom = (kb_ << 1) + a_atom_lane_part
atom_eff = (atom ^ row_swizz) if USE_SWIZZLE else atom
row_off.append(row_x_64 + atom_eff * 16)
a_ldsm_off.append(row_off)
b_ldsm_off = [] # [np_][kb] → Reg
for np_ in range(N_FRAG_N // 2):
row_for_np = warp_n_smem + (np_ << 4) + b_ldsm_row_in_pair
row_swizz = (row_for_np & 3) if USE_SWIZZLE else 0
row_x_64 = row_for_np * (BK * 2)
row_off = []
for kb_ in range(N_FRAG_K):
atom = (kb_ << 1) + b_atom_lane_part
atom_eff = (atom ^ row_swizz) if USE_SWIZZLE else atom
row_off.append(row_x_64 + atom_eff * 16)
b_ldsm_off.append(row_off)
def load_a_kb(stage_base, kb_, dst_buf):
"""Issue 4 ldmatrix.x4 calls into dst_buf for K-block kb_."""
for mf in range(N_FRAG_M):
addr = stage_base + a_ldsm_off[mf][kb_]
base_idx = mf * 4
ptx.ldmatrix(
dst=[dst_buf[base_idx + 0], dst_buf[base_idx + 1],
dst_buf[base_idx + 2], dst_buf[base_idx + 3]],
src=ptx.addr(addr),
layout="m8n8.x4",
)
def load_b_kb(stage_base, kb_, dst_buf):
"""Issue 4 ldmatrix.x4 calls into dst_buf for K-block kb_."""
for np_ in range(N_FRAG_N // 2):
addr = stage_base + b_ldsm_off[np_][kb_]
base_idx = np_ * 4
ptx.ldmatrix(
dst=[dst_buf[base_idx + 0], dst_buf[base_idx + 1],
dst_buf[base_idx + 2], dst_buf[base_idx + 3]],
src=ptx.addr(addr),
layout="m8n8.x4",
)
def mma_buf(a_buf, b_buf):
"""Issue 32 mma.sync (4 M × 8 N) using one register bank pair.
Serpentine N-order on odd M-rows: ``n = (mf & 1) ? N-1-nf : nf``.
Adjacent mmas share one operand (better register reuse, fewer
ldmatrix-induced stalls). MatmulTutorial v13 reports +7% from this.
"""
for mf in range(N_FRAG_M):
a_base = mf * 4
a_regs = [a_buf[a_base + i] for i in range(4)]
for nf_iter in range(N_FRAG_N):
nf = (N_FRAG_N - 1 - nf_iter) if (mf & 1) else nf_iter
np_ = nf >> 1
nf_in_pair = nf & 1
b_base = np_ * 4
b_lo = b_base + (nf_in_pair << 1) + 0
b_hi = b_base + (nf_in_pair << 1) + 1
acc_base = (mf * N_FRAG_N + nf) * 4
ptx.mma.sync(
shape=(16, 8, 16),
dtype_d=f32, dtype_a=bf16, dtype_b=bf16, dtype_c=f32,
d=[acc[acc_base + i] for i in range(4)],
a=a_regs,
b=[b_buf[b_lo], b_buf[b_hi]],
c=[acc[acc_base + i] for i in range(4)],
a_layout="row", b_layout="col",
)
# ----- Pre-load (iter 0, kb=0) into bank 0 -----
# Wait for stage 0 to land first.
wait0_target = max(0, STAGES - 2)
ptx.cp.async_.wait_group(wait0_target)
ptx.bar.sync(0)
a_stage0 = smem_base + A_SMEM_BASE + 0
b_stage0 = smem_base + B_SMEM_BASE + 0
load_a_kb(a_stage0, 0, a_fr_buf[0])
load_b_kb(b_stage0, 0, b_fr_buf[0])
# ----- Main K loop with cross-iter register double-buffering -----
for ki in range(n_iters):
stage = ki % STAGES
a_stage_base = smem_base + A_SMEM_BASE + stage * A_STAGE_BYTES
b_stage_base = smem_base + B_SMEM_BASE + stage * B_STAGE_BYTES
# bank 0 already holds (ki, kb=0). Pre-load (ki, kb=1) into bank 1.
load_a_kb(a_stage_base, 1, a_fr_buf[1])
load_b_kb(b_stage_base, 1, b_fr_buf[1])
# mma using bank 0 (overlaps with bank-1 ldmatrix above).
mma_buf(a_fr_buf[0], b_fr_buf[0])
# Issue prefetch for ki + STAGES - 1 into stage (ki+3) % 4.
if ki + (STAGES - 1) < n_iters:
k_next = reg.scalar(u32)
ptx.inst.mov.u32(k_next, (ki + STAGES - 1) * BK)
next_prefetch_stage = (ki + STAGES - 1) % STAGES
issue_cp_async(next_prefetch_stage, k_next)
ptx.cp.async_.commit_group()
# Wait for next iter's stage + pre-load (ki+1, kb=0) into bank 0.
if ki + 1 < n_iters:
tail = max(0, (ki + 1) - (n_iters - STAGES))
wait_target = max(0, (STAGES - 2) - tail)
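# Example with n_iters = 8: wait_target is 2 for ki = 0..3, drops to 1 at
# ki = 4 and to 0 for ki = 5..6, draining the cp.async pipeline as the
# prefetches stop.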
ptx.cp.async_.wait_group(wait_target)
ptx.bar.sync(0)
next_stage = (ki + 1) % STAGES
a_next = smem_base + A_SMEM_BASE + next_stage * A_STAGE_BYTES
b_next = smem_base + B_SMEM_BASE + next_stage * B_STAGE_BYTES
load_a_kb(a_next, 0, a_fr_buf[0])
load_b_kb(b_next, 0, b_fr_buf[0])
# mma using bank 1 (overlaps with the bank-0 cross-iter pre-load).
mma_buf(a_fr_buf[1], b_fr_buf[1])
# ----- Epilogue: write 4 M × 8 N × 4 f32 acc tiles to global D -----
# m16n8 fragment per-lane: rows {gid, gid+8} × cols {2*tig, 2*tig+1}.
gid = lane >> 2 # 0..7
tig = lane & 3 # 0..3
col_lo = tig << 1 # 0,2,4,6
warp_m_global = m_base + warp_m_smem # CTA M + warp M
warp_n_global = n_base + warp_n_smem # CTA N + warp N
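# Worked example (lane 5): gid = 1, tig = 1, col_lo = 2, so for each acc tile
# this lane stores D[row_lo, n_frag_global + 2..3] and D[row_lo + 8, same cols],
# with row_lo = warp_m_global + 16*mf + 1.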
for mf in range(N_FRAG_M):
m_frag_global = warp_m_global + (mf << 4)
row_lo = m_frag_global + gid
row_hi = row_lo + 8
for nf in range(N_FRAG_N):
n_frag_global = warp_n_global + (nf << 3)
d_col_base = n_frag_global + col_lo
acc_base = (mf * N_FRAG_N + nf) * 4
for i, (drow, dcol) in enumerate([(0, 0), (0, 1), (8, 0), (8, 1)]):
row = row_lo if drow == 0 else row_hi
col_elem = d_col_base + dcol
elem_idx = row * N + col_elem
byte_off = elem_idx * 4
ptx.inst.st.global_.f32(ptx.addr(pd + byte_off), acc[acc_base + i])
ptx.ret()
return gemm
# ---------------------------------------------------------------------------
# Reference + test harness
# ---------------------------------------------------------------------------
def gemm_ref(A: jnp.ndarray, B_T: jnp.ndarray) -> jnp.ndarray:
return jnp.einsum("mk,nk->mn", A.astype(jnp.float32), B_T.astype(jnp.float32))
def _run_jax_case(M: int, N: int, K: int) -> None:
k = build_gemm_highperf(M, N, K)
rng = np.random.default_rng(M * 7919 + N * 31 + K)
A = jnp.asarray(rng.standard_normal((M, K), dtype=np.float32) * 0.1, dtype=jnp.bfloat16)
BT = jnp.asarray(rng.standard_normal((N, K), dtype=np.float32) * 0.1, dtype=jnp.bfloat16)
@jax.jit
def fn(A, BT):
return k(A, BT)
out = np.asarray(fn(A, BT))
ref = np.asarray(gemm_ref(A, BT))
diff = float(np.abs(out - ref).max())
ok = bool(np.allclose(out, ref, atol=1e-2, rtol=1e-2))
status = "OK " if ok else "FAIL"
print(f"[JAX {status}] M={M:5d} N={N:5d} K={K:5d} max_abs={diff:.3e}")
def _run_torch_case(M: int, N: int, K: int) -> None:
import torch
k = build_gemm_highperf(M, N, K)
rng = np.random.default_rng(M * 7919 + N * 31 + K)
A = torch.tensor(rng.standard_normal((M, K), dtype=np.float32) * 0.1,
dtype=torch.bfloat16, device="cuda")
BT = torch.tensor(rng.standard_normal((N, K), dtype=np.float32) * 0.1,
dtype=torch.bfloat16, device="cuda")
out = k(A, BT)
torch.cuda.synchronize()
ref = (A.float() @ BT.float().T)
diff = float((out - ref).abs().max())
ok = bool(torch.allclose(out, ref, atol=1e-2, rtol=1e-2))
status = "OK " if ok else "FAIL"
print(f"[Torch{status}] M={M:5d} N={N:5d} K={K:5d} max_abs={diff:.3e}")
def main() -> None:
_ = (jnp.ones((4,), dtype=jnp.float32) + 1).block_until_ready()  # warm up the JAX backend/device before building kernels
for M, N, K in [
(128, 128, 32), # single CTA, single K-iter
(128, 128, 64), # single CTA, 2 K-iters
(128, 128, 256), # single CTA, 8 K-iters
(256, 256, 256), # 2x2 CTAs
(512, 512, 512),
(1024, 1024, 1024),
(2048, 2048, 2048),
(4096, 4096, 4096),
]:
_run_jax_case(M, N, K)
_run_torch_case(M, N, K)
if __name__ == "__main__":
main()