Skip to content

Blackwell / Tcgen05 Roundtrip

View on GitHub examples/blackwell/tcgen05_roundtrip.py

Overview

Blackwell tcgen05 TMEM roundtrip diagnostic.

Write a known per-thread pattern to Tensor Memory (TMEM) with tcgen05.st, then load it back with tcgen05.ld and store it to global memory (GMEM). This isolates TMEM addressing and the epilogue mapping from UMMA correctness.

Source

Full source
"""Blackwell tcgen05 TMEM roundtrip diagnostic.

Write a known per-thread pattern to TMEM with tcgen05.st, then load it back
with tcgen05.ld and store to GMEM. This isolates TMEM addressing and epilogue
mapping from UMMA correctness.
"""
from __future__ import annotations

import os

os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

import torch

from pyptx import Tile, kernel, ptx, reg, smem
from pyptx.specs import Layout
from pyptx.types import b32, f32, pred, u32


# Output tile shape: one row per active thread (tid < ROWS), COLS values per row.
# NOTE(review): the kernel derives the TMEM lane as tid & 31, so ROWS > 32
# would make threads alias lanes — keep ROWS <= 32.
ROWS = 32
COLS = 64
# Byte offset in the kernel's shared memory where tcgen05.alloc writes the
# TMEM base address (read back by every thread after the barrier).
TMEM_SLOT_OFF = 0
# Shared-memory size requested for the kernel; only the TMEM address slot is used.
SMEM_BYTES = 16


def build(*, arch: str = "sm_100a"):
    """Build the TMEM roundtrip diagnostic kernel.

    The kernel runs a single CTA of 128 threads. Warp 0 (threads 0-31)
    allocates 512 TMEM columns. Each thread with ``tid < ROWS`` then:

    1. writes the values ``1.0 .. COLS`` to TMEM with ``tcgen05.st``
       (its lane index placed in the lane field of the TMEM address);
    2. reads them back with ``tcgen05.ld``;
    3. stores them to row ``tid`` of the ``ROWS x COLS`` output in global memory.

    Warp 0 finally deallocates the TMEM and relinquishes the allocation permit.
    Warps 1-3 only take part in the barrier.

    Args:
        arch: Target architecture string forwarded to ``@kernel``.

    Returns:
        The kernel object produced by ``@kernel``. ``run_torch`` calls it with a
        ``(1, 1)`` float32 CUDA tensor (the input is unused) and indexes the
        result as the output tile.
    """
    @kernel(
        in_specs=(Tile(1, 1, f32, Layout.ROW),),
        out_specs=(Tile(ROWS, COLS, f32, Layout.ROW),),
        grid=(1, 1, 1),
        block=(128, 1, 1),
        arch=arch,
        smem=SMEM_BYTES,
        extern_smem=True,
    )
    def k(_x, O):
        # Shared-memory slot that tcgen05.alloc writes the TMEM base address into.
        base = smem.base()
        tmem_slot = base + TMEM_SLOT_OFF

        tid = reg.scalar(u32)
        lane = reg.scalar(u32)
        alloc_warp = reg.scalar(pred)
        active_lane = reg.scalar(pred)
        ptx.inst.mov.u32(tid, ptx.special.tid.x())
        ptx.inst.and_.b32(lane, tid, 31)
        ptx.inst.setp.lt.u32(alloc_warp, tid, 32)  # warp 0 owns the TMEM allocation
        ptx.inst.setp.lt.u32(active_lane, tid, ROWS)  # one thread per output row

        with ptx.if_(alloc_warp):
            ptx.tcgen05.alloc(tmem_slot, 512)
        # Wait for warp 0's allocation before any thread reads the TMEM address.
        ptx.bar.sync(0)

        tmem_base = smem.load(b32, ptx.addr(tmem_slot))
        # Place the thread's lane index into bits 16-20 (the lane field) of the
        # TMEM address; the low bits stay at column 0.
        tmem_addr = tmem_base + ((lane << 16) & 0x1F0000)

        src_bits = reg.array(b32, COLS)
        src_vals = reg.array(f32, COLS)
        dst_bits = reg.array(b32, COLS)
        dst_vals = reg.array(f32, COLS)
        # Known pattern: column c holds float(c + 1), reinterpreted as raw b32 bits.
        for col in range(COLS):
            val = reg.scalar(f32, init=float(col + 1))
            ptx.inst.mov.b32(src_vals[col], val)
            ptx.inst.mov.b32(src_bits[col], src_vals[col])

        with ptx.if_(active_lane):
            ptx.tcgen05.st(
                tmem_addr,
                [src_bits[i] for i in range(COLS)],
                shape="32x32b",
                count=64,
                dtype="b32",
            )
            # tcgen05.st is asynchronous: wait for completion before reading back.
            ptx.tcgen05.wait_st()
            ptx.tcgen05.ld(
                [dst_bits[i] for i in range(COLS)],
                tmem_addr,
                shape="32x32b",
                count=64,
                dtype="b32",
            )
            # tcgen05.ld is asynchronous too: its destination registers are not
            # valid until tcgen05.wait::ld, and they are read just below.
            ptx.tcgen05.wait_ld()

        # Row-major output: thread tid writes row tid, COLS f32 values (4 bytes each).
        (po,) = ptx.global_ptrs(O)
        row_off = reg.scalar(u32)
        ptx.inst.mul.lo.u32(row_off, tid, COLS)
        base_ptr = po + ((row_off) << 2)
        with ptx.if_(active_lane):
            for col in range(COLS):
                ptx.inst.mov.b32(dst_vals[col], dst_bits[col])
                ptx.inst.st.global_.f32(ptx.addr(base_ptr, col * 4), dst_vals[col])

        # The allocating warp releases the TMEM it allocated.
        with ptx.if_(alloc_warp):
            ptx.tcgen05.dealloc(tmem_base, 512)
            ptx.tcgen05.relinquish_alloc_permit()
        ptx.ret()

    return k


def run_torch():
    """Run the roundtrip kernel on CUDA, print a corner of the output and check it.

    Every output row is expected to read back ``1.0, 2.0, ..., COLS``, the
    pattern ``build`` writes to TMEM. A mismatch summary is printed rather
    than raised, so the diagnostic output is always shown.

    Returns:
        The output tensor produced by the kernel.
    """
    x = torch.zeros((1, 1), device="cuda", dtype=torch.float32)
    out = build()(x)
    torch.cuda.synchronize()
    print(out[:4, :16].cpu())

    # Small integers are exact in float32, so an exact comparison is valid.
    expected = torch.arange(
        1, COLS + 1, device=out.device, dtype=out.dtype
    ).expand(ROWS, COLS)
    mismatches = int((out != expected).sum().item())
    if mismatches:
        print(f"roundtrip FAILED: {mismatches} of {ROWS * COLS} values differ")
    else:
        print("roundtrip OK")
    return out


if __name__ == "__main__":
    # Requires a CUDA device: run_torch allocates its input on "cuda".
    run_torch()