pyptx.torch_support¶
This page is generated from source docstrings and public symbols.
PyTorch runtime integration for pyptx.kernel.
This module is the Torch-side counterpart to pyptx.jax_support.
The PTX compilation and launch-record machinery is shared; this module
focuses on the Torch-specific boundary:
- detect torch.Tensor inputs
- collect device pointers and the active CUDA stream
- allocate output tensors
- launch through the raw shim entry point
- expose a torch.compile-compatible custom-op wrapper
The same C++ shim backs both frameworks:
- PyptxLaunch is used by the JAX/XLA FFI path
- pyptx_shim_launch_raw is used by the Torch ctypes path
Current scope:
- eager mode works
- torch.compile works through torch.library.custom_op plus a fake/meta implementation
- inputs are expected to be contiguous CUDA tensors
- backward/autograd via differentiable_kernel
Public API¶
- is_torch_tensor
- any_torch_tensors
- call_kernel_via_torch
- extract_input_shapes
- get_or_register_torch_op
- call_kernel_via_torch_compile
- differentiable_kernel
is_torch_tensor¶
- Kind: function
Return True iff obj is a PyTorch tensor. Returns False for
non-tensor inputs and on machines where torch isn't installed.
any_torch_tensors¶
- Kind: function
True if ANY of inputs is a torch.Tensor. Used by the
dispatch logic in Kernel.__call__ to decide between the
JAX path and the PyTorch path.
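Both predicates can be sketched in a few lines. This is an illustrative implementation, not necessarily the module's actual code; the key property is that it never imports torch itself, so it stays cheap and works where torch isn't installed.

```python
import sys

def is_torch_tensor(obj):
    # Only consult torch if something else has already imported it;
    # on machines without torch this always returns False.
    torch = sys.modules.get("torch")
    return torch is not None and isinstance(obj, torch.Tensor)

def any_torch_tensors(inputs):
    # Kernel.__call__ uses a check like this to route between
    # the JAX path and the PyTorch path.
    return any(is_torch_tensor(x) for x in inputs)
```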
call_kernel_via_torch¶
- Kind: function
call_kernel_via_torch(*inputs, cubin_handle: 'int', out_specs: 'Sequence[Tile]', out_shape_env: 'dict[str, int]', grid: 'tuple[int, int, int]', block: 'tuple[int, int, int]', cluster: 'tuple[int, int, int]' = (1, 1, 1), smem_bytes: 'int' = 0) -> 'Any'
Launch a pyptx kernel with PyTorch tensor inputs.
The shim's launch config is already registered under cubin_handle
(via register_launch_config during tracing). This function only:
- Allocates output tensors with the right shape / dtype on the same CUDA device as the first input.
- Builds a void** array of device pointers in inputs-then-outputs order, matching what the shim's FFI path builds for JAX.
- Calls pyptx_shim_launch_raw(handle, stream, ptrs, n).
- Returns the output tensor(s).
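The pointer-packing and raw launch can be sketched with ctypes. Here `shim` is assumed to be a `ctypes.CDLL` handle to the C++ shim, and `launch_raw` is a hypothetical helper; only the entry-point name and argument order come from the docs above.

```python
import ctypes

def build_ptr_array(tensors):
    # Pack device pointers, inputs then outputs, into a C void** array,
    # mirroring the layout the shim's JAX FFI path builds.
    ptrs = [t.data_ptr() for t in tensors]
    return (ctypes.c_void_p * len(ptrs))(*ptrs)

def launch_raw(shim, handle, inputs, outputs):
    # Hypothetical wrapper around pyptx_shim_launch_raw(handle, stream, ptrs, n).
    import torch
    # Launch on the CUDA stream that is current for the first input's device.
    stream = torch.cuda.current_stream(inputs[0].device).cuda_stream
    ptrs = build_ptr_array(list(inputs) + list(outputs))
    shim.pyptx_shim_launch_raw(
        ctypes.c_int64(handle),
        ctypes.c_void_p(stream),
        ptrs,
        ctypes.c_int(len(ptrs)),
    )
```

Allocating the outputs first and appending them to the pointer array is what lets the kernel write its results directly into PyTorch-owned memory.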
extract_input_shapes¶
- Kind: function
Return the concrete shape of each input tensor.
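A plausible one-line sketch (the real helper may differ), using plain tuples so shapes are hashable and easy to compare:

```python
def extract_input_shapes(inputs):
    # Concrete shape of each input tensor, as plain tuples.
    return [tuple(t.shape) for t in inputs]
```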
get_or_register_torch_op¶
- Kind: function
get_or_register_torch_op(cubin_handle: 'int', out_specs: 'Sequence[Tile]', out_shape_env: 'dict[str, int]', grid: 'tuple[int, int, int]', block: 'tuple[int, int, int]', cluster: 'tuple[int, int, int]', smem_bytes: 'int')
Return a callable that takes (*input_tensors) and returns
output tensor(s). The callable is a torch.library.custom_op
that survives torch.compile / Dynamo tracing.
First call with a given cubin_handle registers the op;
subsequent calls reuse it.
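The register-once behavior amounts to a handle-keyed cache. In this sketch, `register_fn` is a hypothetical stand-in for the actual torch.library.custom_op registration (which also installs a fake/meta implementation so Dynamo can trace shapes); the caching pattern itself is the point.

```python
_OP_CACHE = {}

def get_or_register_torch_op(cubin_handle, register_fn):
    # First call for a given handle registers the custom op;
    # later calls return the cached callable, so torch.compile
    # always sees the same op object for the same kernel.
    op = _OP_CACHE.get(cubin_handle)
    if op is None:
        op = register_fn(cubin_handle)
        _OP_CACHE[cubin_handle] = op
    return op
```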
call_kernel_via_torch_compile¶
- Kind: function
call_kernel_via_torch_compile(*inputs, cubin_handle: 'int', out_specs: 'Sequence[Tile]', out_shape_env: 'dict[str, int]', grid: 'tuple[int, int, int]', block: 'tuple[int, int, int]', cluster: 'tuple[int, int, int]' = (1, 1, 1), smem_bytes: 'int' = 0) -> 'Any'
torch.compile-compatible launch path.
Wraps call_kernel_via_torch inside a registered
torch.library.custom_op so Dynamo can trace through it.
Returns a single tensor if there's one output, else a tuple.
differentiable_kernel¶
- Kind: function
differentiable_kernel(forward_kernel, backward_kernel, *, save_for_backward: 'Sequence[int] | None' = None, num_grad_inputs: 'int | None' = None)
Wrap a forward + backward pyptx kernel pair for torch.autograd.
Usage::

    from pyptx.torch_support import differentiable_kernel

    fwd = build_my_forward(M, N)
    bwd = build_my_backward(M, N)
    my_op = differentiable_kernel(
        fwd, bwd,
        save_for_backward=[0, 1],  # save inputs 0 and 1
    )

    # Now supports autograd:
    x = torch.randn(M, N, device="cuda", requires_grad=True)
    w = torch.randn(N, device="cuda", requires_grad=True)
    out = my_op(x, w)
    out.sum().backward()
    print(x.grad, w.grad)
Args:
    forward_kernel: A pyptx Kernel for the forward pass.
    backward_kernel: A pyptx Kernel for the backward pass.
        Called with (*saved_tensors, *grad_outputs) and must
        return one gradient per input.
    save_for_backward: Indices of forward inputs to save for the
        backward pass. Defaults to saving all inputs.
    num_grad_inputs: Number of inputs that need gradients.
        Defaults to the number of forward inputs.
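The real wrapper builds on torch.autograd.Function, but the saved-tensors / grad-outputs contract can be illustrated framework-free. In this sketch, `make_differentiable` is a hypothetical name, and plain callables stand in for pyptx kernels; only the calling convention (backward gets `*saved, *grad_outputs` and returns one gradient per input) comes from the docs above.

```python
def make_differentiable(forward_kernel, backward_kernel,
                        save_for_backward=None):
    # Pure-Python sketch of the autograd contract, not the real
    # torch.autograd.Function-based implementation.
    def forward(*inputs):
        idx = (save_for_backward if save_for_backward is not None
               else range(len(inputs)))
        saved = [inputs[i] for i in idx]  # defaults to saving all inputs
        out = forward_kernel(*inputs)

        def backward(*grad_outputs):
            # Invoked as backward_kernel(*saved_tensors, *grad_outputs);
            # must return one gradient per forward input.
            return backward_kernel(*saved, *grad_outputs)

        return out, backward
    return forward
```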