Tiles, Layouts, and TMA¶
Every kernel decorator opens with something like:
@kernel(
    in_specs=(
        Tile.wgmma_a(M, K, bf16, tile_m=BM, tile_k=tile_k),
        Tile.wgmma_b(K, N, bf16, tile_k=tile_k, tile_n=BN),
    ),
    out_specs=(Tile(M, N, f32, Layout.ROW),),
    grid=(N // BN, M // BM, 1),
    block=(128, 1, 1),
    arch="sm_90a",
)
Every argument is load-bearing. This page is a reference for what each
piece controls, when to pick which, and how the Tile / Layout /
tma_box triple composes into a working TMA descriptor + matching
SMEM allocation. Read this before writing a new kernel with different
shapes from the examples.
The Three Roles A Tile Plays¶
A Tile in an @kernel spec does three things at once:
1. Shape contract. Tells the framework what shape of JAX/PyTorch tensor the kernel accepts (or outputs). Symbolic dims like Tile("M", "K", bf16) are bound at call time from the actual tensor shape.
2. dtype. The element type — bf16, f32, b32, etc. Must match the runtime array dtype exactly.
3. TMA descriptor specification. For input tiles, the Layout plus tma_box determine the TMA descriptor the framework builds and hands to the kernel at launch time (via tensor.tma_desc() inside the kernel body).
The output spec only uses (1) and (2) — outputs are written via plain
st.global.*, not TMA (the maintained examples use TMA for inputs
only).
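The shape-contract role is easy to model in plain Python. The sketch below uses a hypothetical bind_dims helper (named for illustration, not framework code) to show the binding rule: concrete dims must match the runtime tensor exactly, symbolic dims bind to whatever size the tensor actually has.

```python
def bind_dims(spec_dims, actual_shape):
    """Resolve a Tile's spec dims against a concrete tensor shape.

    Hypothetical helper -- illustrates the binding rule, not real
    framework code. Ints must match exactly; strings bind to the
    runtime size (and must bind consistently if repeated).
    """
    bound = {}
    for spec, actual in zip(spec_dims, actual_shape):
        if isinstance(spec, int):
            if spec != actual:
                raise ValueError(f"dim mismatch: spec {spec} vs tensor {actual}")
        else:  # symbolic name like "M"
            if bound.setdefault(spec, actual) != actual:
                raise ValueError(f"{spec} bound inconsistently")
    return bound

# Tile("M", "K", bf16) called with a 2048 x 8192 tensor:
print(bind_dims(("M", "K"), (2048, 8192)))  # {'M': 2048, 'K': 8192}
```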
Layout: What Each Value Means¶
class Layout(Enum):
    ROW = "row"                  # row-major (C order), no swizzle
    COL = "col"                  # column-major (Fortran order)
    TMA_128B = "tma_128b"        # TMA 128-byte swizzle
    TMA_64B = "tma_64b"          # TMA 64-byte swizzle
    TMA_32B = "tma_32b"          # TMA 32-byte swizzle
    INTERLEAVED = "interleaved"  # CUTLASS interleaved
The three values you'll use 95% of the time:
- Layout.ROW — the default. Use for outputs. Use for inputs that don't feed a tensor-core instruction (e.g. element-wise kernels like RMS Norm, SwiGLU). No swizzle, plain row-major DRAM walk.
- Layout.TMA_128B — the canonical WGMMA/tcgen05 input swizzle. Use for inputs that feed wgmma.mma_async or tcgen05.mma. The TMA engine writes data into SMEM in a permuted pattern that WGMMA reads back as logical row-major. Required for any bf16/f16 operand with row width ≥ 128 bytes (N ≥ 64 for bf16).
- Layout.TMA_64B / TMA_32B — smaller swizzle variants used when the row is narrower than 128 bytes. Usually picked automatically by Tile.wgmma_a / Tile.wgmma_b — you only reach for these manually when you're building a kernel with an unusual operand shape.
Layout.COL and Layout.INTERLEAVED exist for completeness; the
maintained examples don't use them.
The Swizzle Matching Rule¶
The single most important fact about Layout.TMA_*B:
The TMA swizzle and the SMEM swizzle must be the same.
TMA writes into SMEM using one permutation. WGMMA reads from SMEM using another. These two permutations compose to identity — giving you logical row-major order back — only if they're the same swizzle family. Mismatched swizzles produce garbage output as soon as the result depends on the K-pairing across slices.
Concretely:
# In the @kernel spec:
Tile.wgmma_a(M, K, bf16, tile_m=BM, tile_k=tile_k) # → Layout.TMA_128B
# In the kernel body:
sA = smem.wgmma_tile(bf16, (BM, tile_k), major="K") # also 128B
Tile.wgmma_a picks Layout.TMA_128B when the row is ≥ 128 bytes
(which covers most real cases); smem.wgmma_tile picks its swizzle
based on the same tile shape. They match by construction. If you
build the TMA and SMEM sides by hand without the wgmma_* shortcuts,
you are responsible for matching them — they don't check each
other.
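The composition argument is worth seeing concretely. Below is a toy XOR-swizzle model — the exact bit pattern is illustrative, not the real hardware layout, but the structure is right: each swizzle family XORs the 16-byte-unit index with some low bits of the row. Because XOR is self-inverse, writing and reading with the same family is the identity; mismatched families leave elements permuted.

```python
def swizzle_128b(row, unit):
    # Toy 128B-family swizzle: XOR the 16-byte-unit index with the
    # low 3 bits of the row. (Illustrative bit pattern, not exact HW.)
    return unit ^ (row & 7)

def swizzle_64b(row, unit):
    # Toy 64B-family swizzle: narrower pattern, low 2 bits of the row.
    return unit ^ (row & 3)

# Same family on write and read: XOR is self-inverse, so the two
# permutations compose to identity -- logical row-major comes back.
for row in range(8):
    for unit in range(8):
        assert swizzle_128b(row, swizzle_128b(row, unit)) == unit

# Mismatched families (TMA writes 128B, reader assumes 64B) do NOT
# compose to identity -- half the elements land in the wrong slot here.
bad = [(r, u) for r in range(8) for u in range(8)
       if swizzle_64b(r, swizzle_128b(r, u)) != u]
print(f"{len(bad)} of 64 elements land in the wrong place")  # 32 of 64
```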
When To Use Tile.wgmma_a vs Plain Tile¶
Rule of thumb:
- Tile.wgmma_a(M, K, bf16, tile_m=BM, tile_k=tile_k): for any input that feeds wgmma.mma_async or tcgen05.mma. This picks the right Layout.TMA_*B, sets tma_box=(tile_m, tile_k), and the SMEM swizzle you'll allocate via smem.wgmma_tile(...) will compose correctly. Use wgmma_b(K, N, bf16, tile_k, tile_n) for the B operand.
- Tile(M, N, f32, Layout.ROW): for outputs and for non-MMA inputs (element-wise kernels). No TMA, plain row-major access via ld.global.* / st.global.*.
- Tile(M, K, bf16, Layout.TMA_128B, tma_box=(BM, BK)): the explicit form, used by the Blackwell flagship kernel. Same result as Tile.wgmma_a(M, K, bf16, tile_m=BM, tile_k=BK) but makes the layout choice visible. Use when you want to set the layout deliberately (e.g. Blackwell, where you may want different swizzle patterns than the auto-picked WGMMA one).
tma_box: The Box Shape Per TMA Load¶
The TMA descriptor built by the framework knows two shapes:
- The tensor shape — the full tensor in DRAM (M × K, say).
- The box shape — how much one TMA load brings in per issue.
tma_box=(BM, BK) means: each cp.async.bulk.tensor_2d call with
this descriptor transfers a box of BM × BK elements.
@kernel(
    in_specs=(
        Tile(M, K, bf16, Layout.TMA_128B, tma_box=(BM, BK)),
        ...
    ),
)
def kernel(A, B, D):
    sA = smem.alloc(..., (BM, BK))
    # Each TMA load transfers exactly BM * BK * sizeof(bf16) bytes.
    ptx.cp.async_.bulk.tensor_2d(
        dst=sA[0], src=A.tma_desc(),
        coord=(k_off, m_row_base), mbar=...,
    )
The coordinates (k_off, m_row_base) are the top-left corner of
the box in the source tensor's coordinate system. The TMA engine then
transfers (tma_box[0], tma_box[1]) elements starting from that
corner.
Picking the box:
- Match your SMEM tile. If your SMEM allocation is (BM, BK), your tma_box should be (BM, BK). Mismatched sizes either waste SMEM or read past the end.
- Align to 16-byte boundaries. TMA requires the innermost box dim to span a multiple of 16 bytes (BK * sizeof(dtype) % 16 == 0). For bf16 (2 bytes/elem), that means BK % 8 == 0.
- Smaller isn't always better. Bigger boxes amortize TMA issue overhead and get better DRAM utilization, up to the point where SMEM pressure forces smaller stages. The Blackwell GEMM uses (BM=128, BK=64), which is 16 KB per A tile — comfortable.
When tma_box is None (the default), the TMA descriptor uses the
full tensor shape — one TMA load brings everything. That's what you
want for kernels where the whole tensor fits in SMEM, rare in
practice beyond toy shapes.
2D vs 3D TMA¶
The repo ships two TMA descriptor ranks:
- tma_rank=2 (default): rank-2 descriptor, called via ptx.cp.async_.bulk.tensor_2d(...). Matches the "2D box from a 2D tensor" mental model. Used by every maintained example.
- tma_rank=3: rank-3 descriptor with an explicit minor axis. Used by the high-perf Hopper GEMM example (gemm_highperf_hopper.py). Subtle win: the 3D form lets TMA stage wider blocks with less padding overhead in the epilogue.
For grouped GEMM (G problems), you don't need tma_rank=3 — the
per-group offset math lives in the kernel body (group * M + ...)
and the descriptor stays 2D. 3D TMA is a Hopper performance
optimization, not a batching mechanism.
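The per-group offset math the kernel body does can be sketched in a few lines. The helper name is made up for illustration; the point is that the G problems are stacked along the row axis of one tensor, so only the TMA coordinate changes per group while the descriptor stays 2D.

```python
def group_row_base(group, M_per_group, row_in_group):
    # All G problems are stacked along the row axis, so group g's rows
    # start at g * M. The single 2D TMA descriptor covers the whole
    # stacked (G*M, K) tensor; only the coordinate changes per group.
    return group * M_per_group + row_in_group

# Group 2 of a stacked tensor with M=2048 per group:
# its local row 96 is global row 2 * 2048 + 96.
print(group_row_base(2, 2048, 96))  # 4192
```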
Symbolic vs Concrete Dimensions¶
Shape dims can be ints or strings:
Tile("M", "K", bf16, Layout.ROW) # symbolic, bound at call time
Tile(64, 16, bf16, Layout.ROW) # concrete, fixed at decoration time
Symbolic dims mean one @kernel can handle different input sizes
(one traced program, one cubin cache entry per resolved shape). The
pyptx model is shape-specialized — each concrete combination of
bound symbolic dims produces its own specialized PTX. A call with
M=2048, K=8192 and a call with M=4096, K=8192 each get their own
trace and their own cubin.
The trade-off:
- Concrete dims (Tile(64, 16, bf16)) generate one cubin, small cache, simple debug story.
- Symbolic dims (Tile("M", "K", bf16)) generate one cubin per observed shape, larger cache, more traces over the program's life.
For performance-critical kernels, it's common to make the tile
sizes concrete (BM, BN, BK are ints baked in) and the tensor
sizes symbolic (M, N, K are strings). That's what the build_gemm
pattern does: build_gemm(M=2048, N=8192, K=4096) concretizes every
shape inside the decorator.
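The cache behavior described above can be modeled in a few lines. This toy class is not the real pyptx cache; it only shows why two calls with different bound shapes cost two traces, while repeat shapes are free.

```python
class CubinCache:
    """Toy model of shape-specialized compilation (not the real pyptx cache)."""

    def __init__(self):
        self.cubins = {}
        self.compiles = 0

    def get(self, bound_dims):
        # Key on the resolved symbolic dims: each concrete combination
        # gets its own trace and its own cubin.
        key = tuple(sorted(bound_dims.items()))
        if key not in self.cubins:
            self.compiles += 1  # trace + compile once per new shape
            self.cubins[key] = f"cubin<{key}>"
        return self.cubins[key]

cache = CubinCache()
cache.get({"M": 2048, "K": 8192})  # first shape: compiles
cache.get({"M": 2048, "K": 8192})  # same shape: cache hit
cache.get({"M": 4096, "K": 8192})  # new shape: compiles again
print(cache.compiles)  # 2
```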
Output Tiles¶
Outputs are simpler:
No TMA (outputs go via st.global.*), so no swizzle, no tma_box.
Layout.ROW is the only sensible choice unless you have a specific
reason to write column-major (you usually don't — downstream JAX /
PyTorch consumers expect row-major).
The output contract still matters: if the kernel writes an f32 tensor
and you declare bf16 here, the launch fails the spec check before
anything runs.
Full Example, Piece By Piece¶
The Hopper GEMM decorator:
@kernel(
    in_specs=(
        Tile.wgmma_a(M, K, bf16, tile_m=64, tile_k=16),
        Tile.wgmma_b(K, N, bf16, tile_k=16, tile_n=64),
    ),
    out_specs=(Tile(M, N, f32),),
    grid=(N // 64, M // 64, 1),
    block=(128, 1, 1),
    arch="sm_90a",
)
Reading it:
- A is m×k bf16, WGMMA-ready. Each TMA load brings a 64 × 16 box. SMEM allocation will be (64, 16) with 128B swizzle.
- B is k×n bf16, WGMMA-ready. Each TMA load brings a 16 × 64 box. Same swizzle matching.
- D is m×n f32, plain row-major.
- Grid tiles the output: one CTA per (64, 64) tile of D.
- Block has 128 threads = one Hopper warpgroup (required for WGMMA).
- Arch selects the ISA — sm_90a for Hopper, sm_100a for Blackwell.
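The grid arithmetic can be checked standalone. A minimal sketch, assuming evenly dividing tile sizes (which the integer division in the decorator already requires):

```python
def gemm_grid(M, N, BM, BN):
    # One CTA per (BM, BN) output tile, matching the decorator's
    # grid=(N // BN, M // BM, 1). Tile sizes must divide the output.
    assert M % BM == 0 and N % BN == 0, "tile sizes must divide the output"
    return (N // BN, M // BM, 1)

grid = gemm_grid(M=2048, N=8192, BM=64, BN=64)
print(grid)  # (128, 32, 1)
# Sanity check: the CTAs tile D exactly, with no gaps or overlap.
print(grid[0] * grid[1] * 64 * 64 == 2048 * 8192)  # True
```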
Checklist For A New Kernel's Tile Specs¶
Before you write the kernel body, answer each question:
- What are the tensor shapes? Write them as Tile(..., dtype, layout). Pick symbolic vs concrete per dim.
- Do any inputs feed WGMMA or tcgen05? → use Tile.wgmma_a / Tile.wgmma_b, and allocate SMEM via smem.wgmma_tile(..., major=...). Otherwise Layout.ROW + smem.alloc.
- What's the per-load TMA box? Match your SMEM tile. Check that the innermost dim × dtype size is a multiple of 16 bytes.
- Output dtype match? Kernel writes f32, output spec says f32. No implicit casts.
- Grid shape? How many CTAs cover your output. Usually (N // BN, M // BM, ...).
- Block shape? 128 threads for a single warpgroup (WGMMA). 256 for two warpgroups (rare). 128 for Blackwell 1-SM. Other counts for non-MMA kernels.
What To Read Next¶
- Mbarriers and async sync — the companion page that explains what mbar=... and mbarrier.wait actually do.
- Fragment layouts — how WGMMA and tcgen05.ld scatter the result across lanes in registers.
- Shared Memory — the SMEM side of the TMA → SMEM → WGMMA pipeline.