pyptx.jax_support¶
This page is generated from source docstrings and public symbols.
JAX runtime integration for pyptx.kernel.
This module owns the JAX/XLA execution path for pyptx kernels:
- resolve shapes and template parameters
- trace the kernel body to PTX
- compile PTX to a driver-loadable kernel handle
- register launch metadata with the C++ shim
- build a jax.ffi.ffi_call that launches on XLA's CUDA stream
In other words, this module is the bridge between a traced PTX kernel
and an actual @jax.jit call site.
Important design point:
The C++ shim is intentionally thin. Most of the interesting runtime logic lives here in Python:
- PTX compilation
- launch metadata registration
- TMA descriptor synthesis
- process-local kernel handle bookkeeping
On machines without the full CUDA/JAX runtime stack, the tracing and lowering parts still work. That lets codegen and inspection workflows operate without requiring a live GPU launch environment.
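As a quick orientation, here is a minimal, hedged sketch of how the Python-side pieces documented below fit together. The PTX source, kernel name, and launch geometry are placeholders; the tracing step that produces the PTX happens upstream in pyptx.kernel.

```python
from pyptx import jax_support as js

def prepare_kernel(ptx_source: str, kernel_name: str, arch: str = "sm_90"):
    # Driver-JIT the PTX; returns None on machines without cuda-python
    # or a CUDA driver, in which case only metadata gets registered.
    compiled = js.compile_ptx_to_cubin(ptx_source, arch, kernel_name)
    cu_fn, module = compiled if compiled is not None else (None, None)

    # Record the kernel in the process-local registry; the handle is how
    # the C++ shim identifies this kernel at launch time.
    handle = js.get_cubin_registry().register(
        ptx_source=ptx_source,
        kernel_name=kernel_name,
        cu_function=cu_fn,
        module=module,
        grid=(1, 1, 1),
        block=(128, 1, 1),
    )

    if cu_fn is not None:
        # Tell the shim how to launch this handle.
        js.register_launch_config(handle, cu_fn, grid=(1, 1, 1), block=(128, 1, 1))

    # Make sure the "pyptx_launch" FFI target is registered with JAX;
    # the actual launch is built by call_kernel_via_ffi (see below).
    js.ensure_ffi_registered()
    return handle
```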
Public API¶
- CubinRecord
- CubinRegistry
- get_cubin_registry
- shim_is_available
- shim_load_error
- compile_ptx_to_cubin
- register_launch_config
- add_scalar_param_to_shim
- synthesize_tma_descriptor
- synthesize_tma_descriptor_3d
- add_tma_spec_to_shim
- set_mock_ffi_callback
- ensure_ffi_registered
- call_kernel_via_ffi
CubinRecord¶
- Kind:
class
class CubinRecord(handle: 'int', ptx_source: 'str', kernel_name: 'str', smem_bytes: 'int' = 0, grid: 'tuple[int, int, int]' = (1, 1, 1), block: 'tuple[int, int, int]' = (1, 1, 1), cu_function: 'Optional[int]' = None, module: 'Any' = None, cubin_bytes: 'Optional[bytes]' = None) -> None
A compiled kernel + its launch config.
cu_function is the CUfunction pointer (as an int) returned by
cuModuleGetFunction. It's None on laptop builds where cuda-python
isn't installed or the driver isn't available. module is kept
alive so the function pointer stays valid for the lifetime of the
kernel.
Members¶
smem_bytes¶
- Kind:
attribute
- Value:
0
grid¶
- Kind:
attribute
- Value:
(1, 1, 1)
block¶
- Kind:
attribute
- Value:
(1, 1, 1)
cu_function¶
- Kind:
attribute
No docstring yet.
module¶
- Kind:
attribute
No docstring yet.
cubin_bytes¶
- Kind:
attribute
No docstring yet.
CubinRegistry¶
- Kind:
class
Thread-safe process-local table mapping handle → CubinRecord.
Members¶
register(ptx_source: 'str', kernel_name: 'str', cubin_bytes: 'Optional[bytes]' = None, smem_bytes: 'int' = 0, grid: 'tuple[int, int, int]' = (1, 1, 1), block: 'tuple[int, int, int]' = (1, 1, 1), cu_function: 'Optional[int]' = None, module: 'Any' = None) -> 'int'¶
- Kind:
method
Insert a compiled kernel record and return its process-local handle.
get(handle: 'int') -> 'Optional[CubinRecord]'¶
- Kind:
method
Look up a previously registered kernel handle.
clear() -> 'None'¶
- Kind:
method
Drop all registered kernel records.
get_cubin_registry¶
- Kind:
function
Return the process-local cubin registry singleton.
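A hedged usage sketch of the registry's bookkeeping (the PTX string is a placeholder; on a machine without a CUDA driver, cu_function simply stays None):

```python
from pyptx.jax_support import get_cubin_registry

reg = get_cubin_registry()  # process-local singleton

# Register a placeholder kernel and keep its handle.
handle = reg.register(
    ptx_source="// placeholder PTX",
    kernel_name="vector_add",
    grid=(1, 1, 1),
    block=(128, 1, 1),
)

rec = reg.get(handle)  # Optional[CubinRecord]; unknown handles yield None
print(rec.kernel_name, rec.grid, rec.block, rec.smem_bytes)

reg.clear()  # drop all records, e.g. between tests
```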
shim_is_available¶
- Kind:
function
True if the C++ shim is loaded and ready.
shim_load_error¶
- Kind:
function
Return the last shim-load error, or None if the shim loaded fine.
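These two functions pair naturally for diagnostics; a minimal sketch:

```python
from pyptx.jax_support import shim_is_available, shim_load_error

if not shim_is_available():
    # Report why the C++ shim did not load and stay on the
    # tracing-only code path.
    print("pyptx shim unavailable:", shim_load_error())
```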
compile_ptx_to_cubin¶
- Kind:
function
compile_ptx_to_cubin(ptx_source: 'str', arch: 'str', kernel_name: 'str' = '', dynamic_smem_bytes: 'int' = 0) -> 'Optional[tuple[int, Any]]'
Driver-JIT a PTX string into an executable CUfunction.
Returns (cu_function_ptr, cu_module) on success. The module is
returned so the caller can hold a reference and keep the function
pointer valid for the life of the kernel.
Returns None on laptops or CI machines without cuda-python / a CUDA driver — the caller may still register PTX metadata for tracing tests, but any attempt to launch will fail loudly.
The kernel_name parameter is the PTX entry symbol (e.g.
"vector_add"). If empty, we try to extract it from the
.visible .entry line in the PTX source.
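A hedged sketch. The PTX below is a trivial hand-written placeholder (not pyptx output), and sm_90 is an assumed target architecture:

```python
from pyptx.jax_support import compile_ptx_to_cubin

# Trivial placeholder PTX: a .visible .entry that immediately returns.
PTX = """
.version 8.0
.target sm_90
.address_size 64

.visible .entry vector_add(
    .param .u64 out_ptr
)
{
    ret;
}
"""

# kernel_name is left empty, so the entry symbol is extracted from the
# ".visible .entry" line above.
compiled = compile_ptx_to_cubin(PTX, arch="sm_90")
if compiled is None:
    # No cuda-python / no CUDA driver: tracing-oriented workflows can
    # continue, but any launch attempt will fail.
    cu_fn, module = None, None
else:
    cu_fn, module = compiled  # keep `module` alive alongside cu_fn
```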
register_launch_config¶
- Kind:
function
register_launch_config(handle: 'int', cu_function: 'int', grid: 'tuple[int, int, int]', block: 'tuple[int, int, int]', cluster: 'tuple[int, int, int]' = (1, 1, 1), smem_bytes: 'int' = 0) -> 'None'
Populate the shim's launch registry with a (handle, cu_fn, ...) entry.
Called once per handle, right after compilation. The shim's FFI handler will read this entry at kernel-launch time.
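Continuing the sketch above (reusing PTX, cu_fn, and module from it); the grid and block values are placeholders:

```python
from pyptx.jax_support import get_cubin_registry, register_launch_config

handle = get_cubin_registry().register(
    ptx_source=PTX, kernel_name="vector_add", cu_function=cu_fn, module=module
)

if cu_fn is not None:
    # One call per handle, right after compilation; the shim's FFI
    # handler reads this entry at kernel-launch time.
    register_launch_config(handle, cu_fn, grid=(1, 1, 1), block=(256, 1, 1))
```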
add_scalar_param_to_shim¶
- Kind:
function
Register a scalar raw .param value with the shim's launch config.
synthesize_tma_descriptor¶
- Kind:
function
synthesize_tma_descriptor(shape: 'tuple[int, ...]', dtype, layout, box_shape: 'tuple[int, ...] | None' = None, placeholder_ptr: 'int' = 0) -> 'tuple[Any, int, int]'
Build a 128-byte CUtensorMap for (shape, dtype, layout).
Returns (host_tmap, host_blob_ptr, device_blob_ptr):
- host_tmap is the cuda-python CUtensorMap Python object; keep it
alive for the lifetime of the kernel.
- host_blob_ptr is the raw 128-byte struct address inside the
host_tmap (what cuTensorMapReplaceAddress wants).
- device_blob_ptr is a freshly-allocated 128-byte device buffer,
which the shim uploads the patched host blob into at each launch.
box_shape defaults to a sensible tile for the given swizzle/dtype. placeholder_ptr is the globalAddress stored in the descriptor at creation time; the shim replaces it on each launch.
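A hedged sketch that also registers the descriptor with the shim (add_tma_spec_to_shim, below). The dtype and layout arguments are passed through untouched because their concrete form is not documented on this page; xla_arg_index identifies which kernel argument the descriptor belongs to:

```python
from pyptx.jax_support import synthesize_tma_descriptor, add_tma_spec_to_shim

def attach_tma(handle: int, xla_arg_index: int, shape, dtype, layout):
    host_tmap, host_blob_ptr, device_blob_ptr = synthesize_tma_descriptor(
        shape, dtype, layout  # box_shape and placeholder_ptr keep their defaults
    )
    # Hand both pointers to the shim: it patches the global address in the
    # host blob and uploads it into the device buffer at every launch.
    add_tma_spec_to_shim(handle, xla_arg_index, host_blob_ptr, device_blob_ptr)
    # Return host_tmap so the caller can keep it alive for the kernel's lifetime.
    return host_tmap
```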
synthesize_tma_descriptor_3d¶
- Kind:
function
synthesize_tma_descriptor_3d(height: 'int', width: 'int', dtype, box_major: 'int', box_minor: 'int', *, swizzle_128b: 'bool' = True, padding: 'bool' = False, placeholder_ptr: 'int' = 0) -> 'tuple'
Build a 3D CUtensorMap matching fast.cu's create_tensor_map.
The 3D layout reshapes a (height, width) row-major matrix into
(64_elements, height, width/64) so TMA can handle tiles wider
than 64 bf16 elements (which exceeds the 128B swizzle line).
Args:
- height: number of rows (M for A, N for B, N for C).
- width: number of columns (K for A, K for B, M for C).
- dtype: element type (e.g. bf16).
- box_major: tile rows to load (BM for A, BN for B).
- box_minor: tile columns to load (BK for A/B, BM/consumers for C).
- swizzle_128b: use 128B swizzle (True for A/B, False for C).
- padding: pad the innermost box dim to 72 (True for C store).
- placeholder_ptr: global address (patched at launch time).
Returns:
(host_tmap, host_blob_ptr, device_blob_ptr) — same as
synthesize_tma_descriptor.
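A hedged sketch for the A operand of an (M, K) x (K, N) matmul; the BM/BK tile sizes are illustrative and dtype is expected to be the bf16 element type mentioned above:

```python
from pyptx.jax_support import synthesize_tma_descriptor_3d

def make_a_operand_tmap(M: int, K: int, dtype, BM: int = 64, BK: int = 64):
    host_tmap, host_blob_ptr, device_blob_ptr = synthesize_tma_descriptor_3d(
        height=M,           # rows of A
        width=K,            # columns of A
        dtype=dtype,        # e.g. bf16
        box_major=BM,       # tile rows loaded per TMA copy
        box_minor=BK,       # tile columns loaded per TMA copy
        swizzle_128b=True,  # A/B operands use the 128B swizzle
    )
    # Keep host_tmap alive; register the blob pointers via add_tma_spec_to_shim.
    return host_tmap, host_blob_ptr, device_blob_ptr
```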
add_tma_spec_to_shim¶
- Kind:
function
add_tma_spec_to_shim(handle: 'int', xla_arg_index: 'int', host_blob_ptr: 'int', device_blob_ptr: 'int') -> 'None'
Register a TMA spec with the shim's per-handle launch config.
set_mock_ffi_callback¶
- Kind:
function
Install a mock callback (legacy test hook; pre-shim).
ensure_ffi_registered¶
- Kind:
function
Register the pyptx_launch FFI target with JAX, if not already.
Loads the C++ shim, wraps its PyptxLaunch symbol in a PyCapsule
via jax.ffi.pycapsule, and registers it for the CUDA platform
under the name "pyptx_launch" with typed FFI (api_version=1).
Returns True if registration succeeded. Returns False (rather than raising) on laptops without the shim or without JAX — so tracing tests can still run.
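A minimal sketch of the expected call pattern:

```python
from pyptx.jax_support import ensure_ffi_registered

if not ensure_ffi_registered():
    # No shim or no JAX on this machine: tracing and lowering still work,
    # but launching through XLA will not.
    print("pyptx_launch FFI target not registered; launch path disabled")
```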
call_kernel_via_ffi¶
- Kind:
function
call_kernel_via_ffi(*inputs, cubin_handle: 'int', out_specs: 'Sequence[Tile]', out_shape_env: 'dict[str, int]', grid: 'tuple[int, int, int]', block: 'tuple[int, int, int]', cluster: 'tuple[int, int, int]' = (1, 1, 1), smem_bytes: 'int' = 0) -> 'Any'
Build a jax.ffi.ffi_call for this kernel invocation.
Uses typed FFI (api_version=1 / custom_call_api_version=4). The only
attribute passed to the handler is cubin_handle — grid, block,
and smem are already registered in the shim under that handle.
Returns a JAX array (or tuple of arrays) matching out_specs.
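A hedged wrapper sketch. out_specs is the Sequence[Tile] produced when the kernel was traced (the Tile type is not documented on this page), out_shape_env maps template names to concrete sizes, and the grid/block values are placeholders:

```python
from pyptx.jax_support import call_kernel_via_ffi, ensure_ffi_registered

def launch(x, y, *, handle, out_specs, shape_env):
    # Typically invoked from inside a jax.jit-compiled function; the FFI
    # target must have been registered first.
    assert ensure_ffi_registered(), "pyptx_launch FFI target is unavailable"
    return call_kernel_via_ffi(
        x, y,
        cubin_handle=handle,
        out_specs=out_specs,
        out_shape_env=shape_env,
        grid=(1, 1, 1),
        block=(128, 1, 1),
    )
```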