Skip to content

Ampere / Softmax

View on GitHub examples/ampere/softmax.py

Overview

Ampere row-wise softmax example using the maintained pyptx kernel path.

Run python examples/ampere/softmax.py to execute both a jax.jit path and a PyTorch eager path on sm_80 (A100).

The kernel itself is the maintained examples/hopper/softmax.py kernel: v4 loads, two-pass max + sum reductions, and an ex2(fma(x, log2e, -m·log2e)) fold for the exp. The only arch-specific part is the sm_80 PTX target.

Source

Full source
"""Ampere row-wise softmax example using the maintained pyptx kernel path.

Run ``python examples/ampere/softmax.py`` to execute both a ``jax.jit``
path and a PyTorch eager path on ``sm_80`` (A100).

The kernel itself is the maintained ``examples/hopper/softmax.py`` —
v4 loads, two-pass max + sum reductions, ``ex2(fma(x, log2e, -m·log2e))``
fold for the exp. The only arch-specific bit is the ``sm_80`` PTX
target.
"""
from __future__ import annotations

import os

os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

import jax
import jax.numpy as jnp
import numpy as np

try:
    from pyptx.examples.hopper.softmax import build_softmax as _build_softmax
    from pyptx.examples.hopper.softmax import softmax_ref
except ImportError:
    from examples.hopper.softmax import build_softmax as _build_softmax
    from examples.hopper.softmax import softmax_ref


def build_softmax(B: int, N: int):
    """Build the softmax kernel for the given ``B`` and ``N``, targeting ``sm_80``.

    Forwards to the shared builder imported as ``_build_softmax`` and pins
    the ``arch`` keyword to ``"sm_80"``; the builder's result is returned
    unchanged.
    """
    target_arch = "sm_80"
    kernel = _build_softmax(B, N, arch=target_arch)
    return kernel


def _run_jax_case(B: int, N: int) -> None:
    """Check the kernel under ``jax.jit`` for one ``(B, N)`` input shape.

    Seeds NumPy's global RNG from the shape, feeds random float32 data
    through the jitted kernel, compares the result with ``softmax_ref``
    and prints a single OK/FAIL line with the maximum absolute difference.
    """
    kernel = build_softmax(B, N)

    # Shape-derived seed: the same (B, N) always gets the same input data.
    np.random.seed(B * 7919 + N)
    host_input = np.random.randn(B, N).astype(np.float32)
    device_input = jnp.asarray(host_input)

    jitted = jax.jit(lambda arr: kernel(arr))

    actual = np.asarray(jitted(device_input))
    expected = np.asarray(softmax_ref(device_input))

    max_abs_err = float(np.abs(actual - expected).max())
    if np.allclose(actual, expected, atol=1e-4, rtol=1e-3):
        status = "OK  "
    else:
        status = "FAIL"
    print(f"[JAX  {status}] B={B:4d} N={N:5d}  max_abs={max_abs_err:.3e}")


def _run_torch_case(B: int, N: int) -> None:
    """Check the kernel in PyTorch eager mode for one ``(B, N)`` input shape.

    Seeds NumPy's global RNG from the shape, runs the kernel on a random
    float32 CUDA tensor, compares the result with ``torch.softmax`` along
    the last dimension and prints a single OK/FAIL line with the maximum
    absolute difference.
    """
    # Local import: torch is only used by this code path.
    import torch

    kernel = build_softmax(B, N)

    # Shape-derived seed: the same (B, N) always gets the same input data.
    np.random.seed(B * 7919 + N)
    host_input = np.random.randn(B, N).astype(np.float32)
    cuda_input = torch.tensor(host_input, device="cuda")

    actual = kernel(cuda_input)
    # Wait for outstanding GPU work to finish before comparing results.
    torch.cuda.synchronize()
    expected = torch.softmax(cuda_input, dim=-1)

    max_abs_err = float((actual - expected).abs().max())
    if torch.allclose(actual, expected, atol=1e-4, rtol=1e-3):
        status = "OK  "
    else:
        status = "FAIL"
    print(f"[Torch{status}] B={B:4d} N={N:5d}  max_abs={max_abs_err:.3e}")


def main() -> None:
    """Run the JAX and PyTorch checks over a fixed list of ``(B, N)`` shapes."""
    # Run one tiny JAX computation to completion before any kernel work
    # (presumably to initialise the JAX backend up front -- TODO confirm).
    warmup = jnp.ones((4,), dtype=jnp.float32) + 1
    warmup.block_until_ready()

    shapes = ((4, 64), (16, 512), (32, 1024), (128, 2048), (256, 4096))
    for rows, cols in shapes:
        _run_jax_case(rows, cols)
        _run_torch_case(rows, cols)


# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()