How It Works¶
This page is a walkthrough of what happens when you write:
from pyptx import kernel, reg, ptx
from pyptx.types import f32, u32
@kernel(arch="sm_90a")
def tiny():
x = reg.scalar(u32)
ptx.inst.mov.u32(x, ptx.special.tid.x())
ptx.ret()
tiny() # triggers trace + emit + driver JIT on first call
Between "tiny()" on line 8 and "cuLaunchKernel" a few microseconds later, the function body traverses five compiler stages. There is no optimizer between your Python and the PTX beyond one narrow copy-propagation pass (covered below) — one call maps to one instruction by construction. The machinery in this page is what enforces that.
The Five Stages¶
┌───────┐ ┌──────────┐ ┌──────────┐ ┌──────┐ ┌────────┐
Python ─▶│ trace │─▶│ optimize │─▶│ assemble │─▶│ emit │─▶│ driver │─▶ SASS
└───────┘ └──────────┘ └──────────┘ └──────┘ └────────┘
| Stage | Input | Output | Source |
|---|---|---|---|
| trace | Python function body | list[Statement] | pyptx/_trace.py |
| optimize | list[Statement] | list[Statement] | pyptx/ir/optimize.py |
| assemble | statements + decls | ir.Module | pyptx/kernel.py |
| emit | ir.Module | PTX source text | pyptx/emitter/ |
| driver JIT | PTX text | CUfunction | cuModuleLoadData via cuda-python |
And one inverse direction, used by the transpiler and round-trip tests:
Emitter (stage 4) and parser are inverse functions over the IR. The
entire tests/corpus/ — 218+ real PTX files — round-trips
byte-identically through parse → emit.
The rest of this page walks each stage.
Stage 1: Tracing¶
The @kernel decorator doesn't compile anything at decoration time.
It stores the Python function and the arg specs, returns a callable
wrapper, and waits. On the first call, the wrapper:
- Binds symbolic shape dims against the actual tensor shapes.
- Opens a TraceContext.
- Calls the Python function body. Every ptx.* and reg.* call records an IR node into the context.
- Takes the accumulated IR and hands it to stages 2–5.
The whole tracing machinery lives in pyptx/_trace.py — 105 lines.
TraceContext¶
class TraceContext:
def __init__(self, *, ptx_version: tuple[int, int] | None = None) -> None:
self.reg_decls: list[RegDecl] = []
self.var_decls: list[VarDecl] = []
self.statements: list[Statement] = []
self.ptx_version: tuple[int, int] | None = ptx_version
self._label_counter = 0
self._reg_counter: dict[str, int] = {}
self._if_stack: list[tuple[str, str]] = []
self.dyn_smem_offset: int = 0
self.force_dynamic_smem: bool = False
self._scope_depth: int = 0
Three buffers:
- reg_decls — register declarations (.reg .u32 %r<16>;). Hoisted to the top of the function body at emit time.
- var_decls — variable declarations (.shared .align 128 .b8 ...;). Also hoisted.
- statements — the instructions, labels, and inner scope blocks, in emit order.
The split between reg_decls and statements exists because PTX
requires declarations before any instruction that uses them. Rather
than thread this requirement into every DSL call, the tracer just
parks decls in a separate bucket and concatenates at emit time.
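The hoisting idea can be sketched in a few lines. This is a simplified, hypothetical MiniTrace, not the real TraceContext; it only shows the two-bucket trick:

```python
# Minimal sketch of declaration hoisting: decls and statements are
# collected in separate buckets while tracing, then concatenated at
# emit time so every declaration precedes its first use.
class MiniTrace:
    def __init__(self):
        self.reg_decls = []   # hoisted to the top of the body
        self.statements = []  # kept in trace order

    def emit_reg_decl(self, text):
        self.reg_decls.append(text)

    def emit(self, text):
        self.statements.append(text)

    def body_lines(self):
        # decls first, regardless of when they were recorded
        return self.reg_decls + self.statements

ctx = MiniTrace()
ctx.emit("mov.u32 %r1, %tid.x;")       # instruction traced first...
ctx.emit_reg_decl(".reg .u32 %r<2>;")  # ...decl recorded later
assert ctx.body_lines()[0] == ".reg .u32 %r<2>;"
```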
Thread-local activation¶
The context is stored in Python's threading.local:
_local = threading.local()
@contextmanager
def trace_scope(*, ptx_version=None):
ctx = TraceContext(ptx_version=ptx_version)
old = getattr(_local, "ctx", None)
_local.ctx = ctx
try:
yield ctx
finally:
_local.ctx = old
get_ctx() retrieves the current context; if none exists, every
ptx.* call raises with a clear "call this inside a @kernel"
message.
This is why you can't call ptx.inst.mov.u32(...) at module import
time — there's no active trace context, and the call errors
immediately rather than silently producing unused IR.
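The activation pattern looks roughly like this (a self-contained sketch; the context is reduced to a dict and the error type is assumed):

```python
# Sketch of thread-local trace activation: a context manager installs
# the active context, get_ctx() fails fast when none is installed.
import threading
from contextlib import contextmanager

_local = threading.local()

class TraceError(RuntimeError):
    pass

@contextmanager
def trace_scope():
    ctx = {"statements": []}
    old = getattr(_local, "ctx", None)
    _local.ctx = ctx
    try:
        yield ctx
    finally:
        _local.ctx = old  # restore so nested scopes unwind correctly

def get_ctx():
    ctx = getattr(_local, "ctx", None)
    if ctx is None:
        raise TraceError("ptx.* calls must run inside a @kernel body")
    return ctx

with trace_scope() as ctx:
    get_ctx()["statements"].append("ret;")  # records into the active context
assert ctx["statements"] == ["ret;"]
```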
How ptx.inst.mov.u32(...) becomes an Instruction node¶
Each ptx.inst.* call resolves to a dispatch function in pyptx/ptx.py that:

- Pulls the active TraceContext via get_ctx().
- Builds an Instruction(opcode="mov", modifiers=(".u32",), operands=(...)).
- Calls ctx.emit(stmt) to append it to ctx.statements.
The operands tuple is constructed from the Python arguments —
Reg objects become RegisterOperand, Python ints become
ImmediateOperand, ptx.addr(...) calls become AddressOperand,
and so on. The mapping is 1:1 — one Python call records one
Instruction node.
reg.scalar(u32) is slightly different: it records a RegDecl via
ctx.emit_reg_decl(...) and returns a Reg wrapper that knows its
register name. Subsequent uses of that Reg reference the same name.
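The argument-to-operand mapping can be sketched as a small dispatch (simplified stand-in types; the real node set is larger):

```python
# Hedged sketch of the operand mapping: each Python argument becomes
# exactly one operand node. Names mirror the text; fields are assumed.
from dataclasses import dataclass

@dataclass(frozen=True)
class RegisterOperand:
    name: str

@dataclass(frozen=True)
class ImmediateOperand:
    text: str  # stored as raw text so the literal's form is preserved

class Reg:
    """Stand-in for the DSL's register wrapper."""
    def __init__(self, name):
        self.name = name

def to_operand(arg):
    if isinstance(arg, Reg):
        return RegisterOperand(arg.name)
    if isinstance(arg, int):
        return ImmediateOperand(str(arg))
    raise TypeError(f"unsupported operand: {arg!r}")

ops = tuple(to_operand(a) for a in (Reg("%r0"), 42))
assert ops == (RegisterOperand("%r0"), ImmediateOperand("42"))
```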
Scopes: ptx.scope() and ptx.expr()¶
Two special blocks modify tracing behavior:
- with ptx.scope(): increments _scope_depth. While depth > 0, emit_reg_decl routes into the statement list instead of the hoisted reg_decls — so declarations inside the scope stay local to the { ... } block. This is how block-local register allocation works.
- with ptx.expr(): collects all instructions emitted inside the block into a single CompoundExpr node (more on this below). Used by the transpiler's --sugar pass to group temp chains that came from one Python expression.
Control flow primitives¶
Structured control flow is traced through dedicated primitives:

- ptx.if_(pred) / ptx.else_() / (close of the with): emit setp/bra/label triples around the body, using fresh_label("If") / fresh_label("End") for the targets.
- ptx.loop("name", pred=...): emit a labeled backward branch.
- ptx.range_(n): emit an unrolled loop, on the Python side — the body gets traced n times, each with its own register state.
Python for i in range(...): with a constant n is handled at the
Python level, not the PTX level — the loop body is traced n
times, and the IR has no loop construct. This is how for g in range(8):
in an epilogue unrolls to 8 copies of the store sequence.
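A minimal sketch of that unrolling, with a stand-in range_ helper (the real one lives in the tracer):

```python
# Sketch of Python-level unrolling: the loop body is traced n times,
# so the resulting IR holds n copies and no loop construct.
def range_(n, trace_body):
    statements = []
    for i in range(n):
        # each iteration re-traces the body with the concrete index,
        # so per-iteration immediates (offsets, register names) are baked in
        statements.extend(trace_body(i))
    return statements

body = range_(3, lambda i: [f"st.global.b32 [out+{4 * i}], %r{i};"])
assert len(body) == 3
assert body[2] == "st.global.b32 [out+8], %r2;"
```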
Stage 2: Optimize¶
After the trace finishes, the accumulated statements run through
one semantics-preserving pass: copy propagation. Source is in
pyptx/ir/optimize.py, 150 lines.
Why it exists¶
The DSL lets you write an expression's result straight into an array slot — something like acc[5] = x + y.
(Note: subscript assignment goes through RegArray.__setitem__ — that's
the interceptable path. Plain scalar assignment like x = y between two
reg.scalar variables is just Python name binding and emits no PTX.)
Operator overloading (+, *, &, >>) creates a fresh temp
register for the result, and the __setitem__ on the array emits a
mov.b32 acc[5], %fresh_temp. That extra mov is wasted — you
could have just written the expression's result directly into
acc[5].
Copy propagation removes it:
- Scan for mov.bN %dst, %src where %src is a fresh temp.
- Verify %src is defined once and used only by this mov.
- Rename %src → %dst in the definition, delete the mov, delete the .reg declaration for %src.
Result: the PTX is identical to what you'd get from writing
ptx.inst.*(acc[5], ...) directly — no extra register, no extra mov.
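The three steps above can be sketched over a toy instruction encoding. This is not the real pass (which works on IR nodes and also drops the .reg declaration); it only demonstrates the single-def, single-use rewrite:

```python
# Hedged sketch of copy propagation: find `mov dst, src` where src is
# a single-def, single-use temp, retarget the defining instruction to
# dst, and delete the mov. Instructions are (opcode, dst, srcs) tuples.
def copy_propagate(insts):
    out = list(insts)
    for i, (op, dst, srcs) in enumerate(insts):
        if op != "mov":
            continue
        src = srcs[0]
        defs = [j for j, s in enumerate(insts) if s[1] == src]
        uses = [j for j, s in enumerate(insts) if src in s[2]]
        if len(defs) == 1 and uses == [i]:        # fresh temp, used only here
            d = defs[0]
            out[d] = (out[d][0], dst, out[d][2])  # retarget the definition
            out[i] = None                         # delete the mov
    return [s for s in out if s is not None]

before = [("add", "%t0", ("%r1", "%r2")),  # temp from operator overload
          ("mov", "%r5", ("%t0",))]        # __setitem__'s extra mov
assert copy_propagate(before) == [("add", "%r5", ("%r1", "%r2"))]
```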
The only pass¶
Copy propagation is currently the only post-trace pass. There is no:
- Instruction scheduling (order is fixed by Python evaluation).
- Dead code elimination (you're expected to not emit dead code).
- Constant folding beyond what's visible to Python.
- Register allocation (you allocate registers by hand through the DSL).
The bet is that the user is writing PTX and knows what they want; the
compiler shouldn't second-guess. Copy propagation is the narrow
exception because RegArray.__setitem__ genuinely emits an
instruction the user didn't ask for.
Stage 3: Assemble The Module¶
The traced body goes into an ir.Function, which goes into an
ir.Module. This happens in pyptx/kernel.py:
module = Module(
version=Version(8, 4),
target=Target(("sm_90a",)),
address_size=AddressSize(64),
directives=(
# ...any smem var decls...
Function(
is_entry=True,
name="tiny",
params=(...), # built from in_specs/out_specs
body=body_statements,
directives=(
FunctionDirective("maxntid", (128, 1, 1)),
# ...other hints...
),
),
),
)
The @kernel decorator's kwargs (block, grid, smem, arch)
become directives attached to the function. Dynamic SMEM > 48 KB
gets a .extern .shared .align 128 .b8 dyn_smem[]; variable
declaration and flips a bit that the launch shim reads to call
cuFuncSetAttribute(MAX_DYNAMIC_SHARED_SIZE_BYTES, N) before
launching.
The IR¶
Before stage 4, a word on what the IR actually looks like.
Frozen dataclasses, tuple-valued collections¶
Every IR node is a @dataclass(frozen=True). Collections inside
nodes are tuples, not lists. This means:
- Nodes are immutable. You can't mutate a parsed kernel — you rebuild it with dataclasses.replace(...).
- Nodes are hashable by structure. ir1 == ir2 is a deep structural comparison; hash(ir1) works.
- The IR is a value type. Equality is "same shape, same fields," not pointer identity.
The value-type design is load-bearing for round-trip testing: parse
a kernel, emit it, parse the emit, compare the IRs — if they're
structurally equal, the round-trip is lossless. The entire
tests/corpus/ (218+ real PTX files) is validated this way.
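A minimal sketch of the value-type property, with a cut-down Instruction node (real nodes carry more fields):

```python
# Frozen dataclasses with tuple-valued fields compare structurally,
# hash by value, and are rewritten with replace() instead of mutation.
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Instruction:
    opcode: str
    modifiers: tuple
    operands: tuple

a = Instruction("mov", (".b32",), ("%r0", "%r1"))
b = Instruction("mov", (".b32",), ("%r0", "%r1"))
assert a == b and hash(a) == hash(b)  # structural equality, not identity

c = replace(a, opcode="ld")           # rebuild; the original is untouched
assert c.opcode == "ld" and a.opcode == "mov"
```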
The node hierarchy¶
Core statement nodes (things that appear in a function body):
- Instruction — one PTX instruction. opcode="mov", modifiers=(".b32",), operands=(dst, src), optional predicate.
- Label — a branch target.
- RegDecl / VarDecl — declarations.
- PragmaDirective — .pragma "...".
- Comment / BlankLine — preserved for formatting.
- RawLine — the escape hatch. When the parser can't structurally parse a line (very rare — new ISA features, odd formatting), it captures raw text. The emitter emits the text verbatim. The parser never crashes.
- Block — a nested { ... } scope.
- IntrinsicScope — a DSL-only wrapper around instructions emitted by an @ptx.intrinsic decorated function. Rendered as BEGIN/END comments in the emitted PTX so inspection tools can see which high-level call produced which instructions.
And operand nodes (things that appear in Instruction.operands):
- RegisterOperand — %r0.
- ImmediateOperand — 42 / 0xFF / 0d3FF0000000000000. Stored as raw text so float literal precision is preserved exactly.
- LabelOperand — a label used as a branch target.
- VectorOperand — {%r0, %r1, %r2, %r3} (v4 loads/stores).
- AddressOperand — [base] or [base+offset].
- ParenthesizedOperand — (op1, op2, ...) used in call returns.
- NegatedOperand — !%p0 for logical negation.
- PipeOperand — %p0|%p1, the dual predicate output of setp.
FormattingInfo: the round-trip secret¶
Each statement can carry a FormattingInfo with indent,
trailing, blank_lines_before, preceding_comments, and raw_line.
The parser fills these in when it reads source. The emitter uses them exactly when present — same indent, same trailing whitespace, same blank lines before this statement. When absent (trace output), the emitter uses sensible defaults (4-space indent, one statement per line).
This is why parser → emitter is byte-identical on the corpus.
Without FormattingInfo, round-trip would be semantically correct
but not byte-identical (one-vs-four spaces of indent, reordered
whitespace, etc.). The corpus contains real CUTLASS, Triton,
DeepGEMM, TK output — irregular formatting is the rule, not the
exception.
CompoundExpr: ptx.expr() groups¶
Not in the Statement union; handled via a duck-typed instructions
attribute in the emitter. It represents a group of instructions traced
from a single Python expression — say, a chained arithmetic expression
that lowered to several temps.
The emitted PTX is identical — CompoundExpr is cosmetic grouping
only. The transpiler's --sugar pass produces these to re-group
long temp chains that came from one high-level expression in the
original source.
Stage 4: Emit¶
The IR → text path lives in pyptx/emitter/emitter.py, 405 lines.
Structurally it's a visitor over Module → Directive → Function
→ Statement → Operand, with one _emit_* function per node
type.
The top-level entry:
def emit(module: Module) -> str:
if module.raw_source is not None:
return module.raw_source # parsed-from-source shortcut
parts = []
# ...header + directives...
return "\n".join(parts) + "\n"
The raw_source shortcut is a round-trip optimization: if you parsed
a module and didn't modify it, emit just returns the original text.
Only when you construct or modify IR does the emitter actually walk
the tree.
Per-statement emission¶
For an Instruction, the emitter concatenates opcode + modifiers without
spaces (mov.b32), then comma-separates operands, then adds the trailing
;. If FormattingInfo specifies a leading indent or a trailing
comment, those are reproduced.
For a RegDecl, the emitter prints the declaration line (for example
.reg .u32 %r<16>;). For a Label, it's the label name followed by : at
the appropriate indent.
For an IntrinsicScope, the emitter wraps the inner statements in
// BEGIN name(args_repr) and // END name(args_repr) comments.
The enclosed instructions are emitted normally — the comments are
for humans and tooling, not PTX semantics.
Why the emitter is simple¶
PTX is an assembly language — instruction, modifiers, operands,
terminating semicolon. There's no nesting, no expression grammar,
no type propagation. A visitor with one case per IR node type
covers everything. The ~400 lines of emitter.py handle real
CUTLASS and DeepGEMM output with no special cases.
Stage 5: Driver JIT¶
The emitted PTX text goes to NVIDIA's driver via cuModuleLoadData:
# pyptx/jax_support.py
module = cuda.cuModuleLoadData(ptx_bytes)
fn = cuda.cuModuleGetFunction(module, entry_name)
The driver JITs PTX → SASS (NVIDIA's real machine code) at load
time. The result is cached by (ptx_string, arch) so repeat calls
don't retrigger the JIT.
No ptxas required at install. No CUDA toolkit required beyond the
driver. The cuda-python package provides the binding; the driver
itself ships with the GPU.
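The caching described above can be sketched as a plain dict keyed by (ptx_string, arch); load_fn stands in for the real cuModuleLoadData path:

```python
# Sketch of the JIT cache: the expensive driver load runs once per
# unique (PTX text, architecture) pair; repeat calls hit the cache.
_jit_cache = {}

def get_function(ptx_text, arch, load_fn):
    key = (ptx_text, arch)
    if key not in _jit_cache:
        _jit_cache[key] = load_fn(ptx_text)  # driver JIT happens here
    return _jit_cache[key]

calls = []
fake_load = lambda ptx: calls.append(ptx) or f"CUfunction<{len(calls)}>"
f1 = get_function(".visible .entry tiny ...", "sm_90a", fake_load)
f2 = get_function(".visible .entry tiny ...", "sm_90a", fake_load)
assert f1 == f2 and len(calls) == 1  # second call did not re-JIT
```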
Launch is then cuLaunchKernel(fn, grid_x, grid_y, grid_z, block_x,
block_y, block_z, smem_bytes, stream, args, ...). JAX and PyTorch
route through the tiny C++ shim at pyptx/_shim/pyptx_launch.cc so
the call is issued on the correct stream that the framework is
sequencing on.
The Reverse Direction: Parser¶
pyptx/parser/ turns PTX text back into IR. Three modules:
- tokens.py — token types (77 lines).
- lexer.py — source text → stream of tokens (327 lines).
- parser.py — tokens → IR (1246 lines, recursive descent).
The parser is opcode-agnostic. It doesn't know mov from
wgmma.mma_async — it parses the universal structure of every
instruction (an optional @predicate guard, an opcode, dot-separated
modifiers, comma-separated operands, a terminating ;)
and produces an Instruction node with the right fields. This is
why new ISA features (Blackwell tcgen05.*, future Thor instructions)
parse correctly without any parser changes — they're just another
opcode with modifiers.
When the parser hits something it can't structurally parse (unusual
directive, inline asm with quirky escaping), it captures a RawLine
and moves on. The emitter emits RawLine verbatim. The kernel is
still valid IR; you just can't structurally modify that particular
line without reparsing it yourself.
Byte-identical round-trip¶
The test tests/test_roundtrip.py runs:
for path in corpus_files: # 218+ real PTX files
text = path.read_text()
ir = parse(text)
emitted = emit(ir)
assert emitted == text # byte-for-byte
This passes for CUTLASS kernels, Triton output, DeepGEMM, fast.cu,
ThunderKittens examples, and the Mamba-SSM kernels. The combination
of FormattingInfo preservation + raw_source fallback +
RawLine escape hatch is what makes it possible.
The transpiler (pyptx/codegen/) depends on this round-trip
property: it parses PTX into IR, runs rewriting passes (name
demangling, loop raising, expression grouping), and emits executable
Python. Every PTX kernel the transpiler accepts is one that survives
round-trip.
IntrinsicScope and @ptx.intrinsic¶
A small DSL surface detail worth knowing. @ptx.intrinsic wraps a
function that emits multiple instructions; the trace captures those
instructions into an IntrinsicScope:
@ptx.intrinsic
def reduce_sum(reg_in):
# ...a dozen shfl.bfly.sync + add.f32 instructions...
pass
# In a kernel:
ptx.warp.reduce_sum(sum_sq) # emits IntrinsicScope(name="reduce_sum", ...)
In the emitted PTX, this shows up as the reduce_sum instruction
sequence bracketed by // BEGIN reduce_sum(...) and // END
reduce_sum(...) comments.
The comments are for humans reading the PTX. The parser sees them as
comments and discards the intrinsic grouping — a round-trip produces
the same instructions, just without the scope wrapper. That's fine
because IntrinsicScope is a construction-time concept, not a
semantic one.
Spec Validation¶
There's a small companion system that isn't part of the compiler
proper but prevents a class of user errors before trace even runs:
pyptx/spec/. It holds a declarative description of the PTX ISA —
which modifiers combine, what operand types each opcode takes, how
many destinations vs sources — and validates ptx.inst.* calls
against that spec.
When you pass, say, three operands to a mov.u32, the spec validator
catches it at trace time with a message like
"mov.u32 expects 2 operands, got 3," instead of producing broken PTX
that fails at cuModuleLoadData with a harder-to-debug error.
The spec is in pyptx/spec/ptx.py — 930 lines of data. The
validator (validate.py, 660 lines) is called from the ptx.inst.*
dispatch.
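The shape of that check can be sketched with an assumed, cut-down spec table (the real spec also validates modifiers and operand types):

```python
# Hedged sketch of spec-driven validation: a declarative table maps
# opcodes to an expected operand count, checked before any IR is built.
SPEC = {"mov": 2, "add": 3, "ret": 0}  # opcode -> arity (assumed values)

def validate(opcode, operands):
    expected = SPEC.get(opcode)
    if expected is not None and len(operands) != expected:
        raise ValueError(
            f"{opcode} expects {expected} operands, got {len(operands)}")

validate("mov", ("%r0", "%tid.x"))  # ok: two operands
try:
    validate("mov", ("%r0", "%r1", "%r2"))
except ValueError as e:
    assert "expects 2 operands, got 3" in str(e)
```

Failing at trace time keeps the error next to the offending Python line, rather than surfacing later as an opaque cuModuleLoadData failure.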
Why This Design¶
Five design decisions, in order of how much they matter:
1. Frozen-dataclass IR with tuple-valued collections. Makes the IR a value type: hashable, comparable by structure, immutable. Round-trip testing is a == b, not a custom walker. Rewrites use dataclasses.replace — no mutation accidents.
2. FormattingInfo on every statement. This is what makes byte-identical round-trip possible on real-world kernels with idiosyncratic formatting. Cheap — it's a pointer-sized field on each node that most code ignores.
3. Opcode-agnostic parser. New ISA features parse for free. The parser doesn't know about tcgen05.mma or any future instruction; it just parses "opcode.modifiers operands;" and the IR holds the strings. The validator (spec) knows the semantics, but the IR layer is ISA-blind.
4. One Python call = one Instruction node. No lowering, no scheduler, no optimizer between trace and emit (except the one copy-propagation pass that removes setitem movs). The user controls instruction order; the compiler respects it.
5. Driver JIT, not ptxas. No CUDA toolkit at install time. PTX strings go to cuModuleLoadData; the driver produces SASS. Every supported CUDA driver can load every PTX the emitter produces.
The whole compiler — tracer + IR + optimizer + emitter + parser — is under 3000 lines of Python. The IR alone is 350 lines; most of the line count is in the parser (recursive descent over the full PTX grammar) and the ISA spec (data tables). There is no code generator in the conventional sense; emit is a visitor that stringifies already-complete instructions.
What To Read Next¶
- PTX Namespace — reference for every DSL call that appears in stage 1 (trace).
- Transpiler — the parser + emitter combined into a PTX → Python converter.
- Philosophy — the "why" of "one call = one instruction," restated at a higher level.
- pyptx/_trace.py, pyptx/ir/nodes.py, pyptx/emitter/emitter.py — the source is ~900 lines total and readable end-to-end.