From 12cc124dcd03760bec48d23c1a4dac1c9a695266 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sat, 17 Jan 2026 18:04:22 -0800 Subject: [PATCH 001/106] add cutlass ragged dot --- skyrl-tx/tx/ffi/README.md | 14 +++ skyrl-tx/tx/ffi/__init__.py | 4 + skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh | 44 +++++++ skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 150 ++++++++++++++++++++++++ skyrl-tx/tx/ffi/ragged_dot_ffi.py | 82 +++++++++++++ skyrl-tx/tx/layers/util.py | 18 +++ 6 files changed, 312 insertions(+) create mode 100644 skyrl-tx/tx/ffi/README.md create mode 100644 skyrl-tx/tx/ffi/__init__.py create mode 100644 skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh create mode 100644 skyrl-tx/tx/ffi/ragged_dot_ffi.cu create mode 100644 skyrl-tx/tx/ffi/ragged_dot_ffi.py diff --git a/skyrl-tx/tx/ffi/README.md b/skyrl-tx/tx/ffi/README.md new file mode 100644 index 000000000..624963060 --- /dev/null +++ b/skyrl-tx/tx/ffi/README.md @@ -0,0 +1,14 @@ +Build (Linux + CUDA) + +1) Build the shared library: +``` +export CUTLASS_DIR=/path/to/cutlass +tx/ffi/build_ragged_dot_ffi.sh +``` + +2) Make the shared library discoverable: +- Copy `tx/ffi/_build/libragged_dot_ffi.so` to `tx/ffi/libragged_dot_ffi.so`, or +- Set `TX_RAGGED_DOT_FFI_PATH=/path/to/libragged_dot_ffi.so`. + +Notes: +- The FFI kernel expects bfloat16 inputs/outputs and int32 group metadata. diff --git a/skyrl-tx/tx/ffi/__init__.py b/skyrl-tx/tx/ffi/__init__.py new file mode 100644 index 000000000..3d2794504 --- /dev/null +++ b/skyrl-tx/tx/ffi/__init__.py @@ -0,0 +1,4 @@ +from tx.ffi.ragged_dot_ffi import is_available as ragged_dot_ffi_available +from tx.ffi.ragged_dot_ffi import ragged_dot as ragged_dot_ffi + +__all__ = ["ragged_dot_ffi", "ragged_dot_ffi_available"] diff --git a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh new file mode 100644 index 000000000..28c357a86 --- /dev/null +++ b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +if [[ -z "${CUTLASS_DIR:-}" ]]; then + echo "CUTLASS_DIR is not set. Point it to your CUTLASS checkout." >&2 + exit 1 +fi + +if [[ ! -d "${CUTLASS_DIR}" ]]; then + echo "CUTLASS_DIR does not exist: ${CUTLASS_DIR}" >&2 + exit 1 +fi + +PYTHON_BIN="${PYTHON_BIN:-python}" +JAX_INCLUDE_DIR="$("${PYTHON_BIN}" - <<'PY' +import os +import jaxlib +import jaxlib.xla_extension as xe +print(os.path.join(os.path.dirname(xe.__file__), "include")) +PY +)" + +NVCC_BIN="${NVCC_BIN:-nvcc}" +if ! command -v "${NVCC_BIN}" >/dev/null 2>&1; then + echo "nvcc not found. Set NVCC_BIN or ensure CUDA is on PATH." 
>&2 + exit 1 +fi + +OUT_DIR="${SCRIPT_DIR}/_build" +mkdir -p "${OUT_DIR}" + +"${NVCC_BIN}" \ + -O3 \ + -std=c++14 \ + -shared \ + -Xcompiler -fPIC \ + -I"${JAX_INCLUDE_DIR}" \ + -I"${CUTLASS_DIR}/include" \ + "${SCRIPT_DIR}/ragged_dot_ffi.cu" \ + -o "${OUT_DIR}/libragged_dot_ffi.so" + +echo "Built ${OUT_DIR}/libragged_dot_ffi.so" diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu new file mode 100644 index 000000000..79b0eb4c1 --- /dev/null +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -0,0 +1,150 @@ +#include +#include + +#include + +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" + +#include +#include +#include +#include + +namespace ffi = xla::ffi; + +using Dtype = cutlass::bfloat16_t; +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::RowMajor; +using LayoutC = cutlass::layout::RowMajor; +using Accum = float; +using Gemm = cutlass::gemm::device::Gemm; + +static ffi::Error CudaError(const char* message) { + return ffi::Error::Internal(message); +} + +ffi::Error RaggedDotCuda( + cudaStream_t stream, + ffi::Buffer lhs, + ffi::Buffer rhs, + ffi::Buffer group_sizes, + ffi::Buffer group_offset, + ffi::ResultBuffer out) { + auto lhs_dims = lhs.dimensions(); + auto rhs_dims = rhs.dimensions(); + auto group_sizes_dims = group_sizes.dimensions(); + auto group_offset_dims = group_offset.dimensions(); + + if (lhs_dims.size() != 2 || rhs_dims.size() != 3 || group_sizes_dims.size() != 1 || + group_offset_dims.size() != 1) { + return ffi::Error::InvalidArgument("Unexpected ragged_dot dimensions."); + } + + int64_t m64 = lhs_dims[0]; + int64_t k64 = lhs_dims[1]; + int64_t g_local64 = rhs_dims[0]; + int64_t rhs_k64 = rhs_dims[1]; + int64_t n64 = rhs_dims[2]; + int64_t g64 = group_sizes_dims[0]; + + if (k64 != rhs_k64) { + return ffi::Error::InvalidArgument("lhs/rhs K dimension mismatch."); + } + + if (m64 < 0 || k64 < 0 || n64 < 0 || g64 < 0 || g_local64 < 0) { + return ffi::Error::InvalidArgument("Invalid dimensions."); + } + + int32_t m = static_cast(m64); + int32_t k = static_cast(k64); + int32_t n = static_cast(n64); + int32_t g = static_cast(g64); + int32_t g_local = static_cast(g_local64); + + int32_t offset = 0; + cudaError_t err = cudaMemcpyAsync( + &offset, group_offset.typed_data(), sizeof(int32_t), cudaMemcpyDeviceToHost, stream); + if (err != cudaSuccess) { + return CudaError("Failed to copy group_offset."); + } + + std::vector sizes(static_cast(g)); + if (g > 0) { + err = cudaMemcpyAsync( + sizes.data(), + group_sizes.typed_data(), + static_cast(g) * sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream); + if (err != cudaSuccess) { + return CudaError("Failed to copy group_sizes."); + } + } + + err = cudaStreamSynchronize(stream); + if (err != cudaSuccess) { + return CudaError("Failed to synchronize stream."); + } + + if (offset < 0 || offset > g || offset + g_local > g) { + return ffi::Error::InvalidArgument("group_offset out of range."); + } + + std::vector offsets(static_cast(g) + 1); + offsets[0] = 0; + for (int32_t i = 0; i < g; ++i) { + offsets[static_cast(i) + 1] = offsets[static_cast(i)] + sizes[i]; + } + + err = cudaMemsetAsync(out->typed_data(), 0, static_cast(m) * n * sizeof(Dtype), stream); + if (err != cudaSuccess) { + return CudaError("Failed to zero output."); + } + + if (g_local == 0 || m == 0 || n == 0 || k == 0) { + return ffi::Error::Success(); + } + + Gemm gemm; + for (int32_t gi = offset; gi < offset + g_local; ++gi) { + int32_t rows = sizes[gi]; + if (rows == 0) { + continue; + } + + int32_t local = gi 
- offset; + const Dtype* A = reinterpret_cast(lhs.typed_data()) + + static_cast(offsets[gi]) * k; + const Dtype* B = reinterpret_cast(rhs.typed_data()) + + static_cast(local) * k * n; + Dtype* C = reinterpret_cast(out->typed_data()) + + static_cast(offsets[gi]) * n; + + Gemm::Arguments args( + {rows, n, k}, + {A, k}, + {B, n}, + {C, n}, + {C, n}, + {1.0f, 0.0f}); + + cutlass::Status status = gemm(args, stream); + if (status != cutlass::Status::kSuccess) { + return ffi::Error::Internal("cutlass gemm failed."); + } + } + + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + RaggedDotCuda, + RaggedDotCuda, + ffi::Ffi::Bind() + .Ctx>() + .Arg>() // lhs + .Arg>() // rhs + .Arg>() // group_sizes + .Arg>() // group_offset + .Ret>()); // out diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.py b/skyrl-tx/tx/ffi/ragged_dot_ffi.py new file mode 100644 index 000000000..bc096d9c2 --- /dev/null +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import ctypes +import os +from pathlib import Path + +import jax +import jax.numpy as jnp + +try: # JAX >= 0.8 + from jax import ffi as jax_ffi +except Exception: # pragma: no cover - older JAX fallback + from jax.experimental import ffi as jax_ffi + + +_REGISTERED = False +_LOAD_ERROR: Exception | None = None + + +def _find_library() -> Path | None: + env_path = os.environ.get("TX_RAGGED_DOT_FFI_PATH") + if env_path: + path = Path(env_path) + return path if path.exists() else None + + here = Path(__file__).resolve().parent + for name in ("libragged_dot_ffi.so", "ragged_dot_ffi.so"): + candidate = here / name + if candidate.exists(): + return candidate + return None + + +def _ensure_registered() -> bool: + global _REGISTERED, _LOAD_ERROR + if _REGISTERED: + return True + if _LOAD_ERROR is not None: + return False + + lib_path = _find_library() + if lib_path is None: + _LOAD_ERROR = FileNotFoundError("ragged_dot_ffi shared library not found.") + return False + + try: + lib = ctypes.cdll.LoadLibrary(str(lib_path)) + jax_ffi.register_ffi_target( + "ragged_dot_cuda", + jax_ffi.pycapsule(lib.RaggedDotCuda), + platform="CUDA", + ) + _REGISTERED = True + return True + except Exception as exc: # pragma: no cover - load/registration failures + _LOAD_ERROR = exc + return False + + +def is_available() -> bool: + return _ensure_registered() + + +def ragged_dot( + lhs: jax.Array, + rhs: jax.Array, + group_sizes: jax.Array, + group_offset: jax.Array, +) -> jax.Array: + if not _ensure_registered(): + raise RuntimeError("ragged_dot_ffi is not available. 
Build and load the shared library first.") + + if lhs.dtype != jnp.bfloat16 or rhs.dtype != jnp.bfloat16: + raise NotImplementedError("ragged_dot_ffi supports bfloat16 only.") + if group_sizes.dtype != jnp.int32 or group_offset.dtype != jnp.int32: + raise NotImplementedError("ragged_dot_ffi expects int32 group_sizes and group_offset.") + if group_offset.shape != (1,): + raise ValueError("group_offset must have shape (1,).") + + out = jax.ShapeDtypeStruct((lhs.shape[0], rhs.shape[2]), lhs.dtype) + call = jax_ffi.ffi_call("ragged_dot_cuda", out, vmap_method="disabled") + return call(lhs, rhs, group_sizes, group_offset) diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py index 0030c604d..7629242a8 100644 --- a/skyrl-tx/tx/layers/util.py +++ b/skyrl-tx/tx/layers/util.py @@ -4,6 +4,14 @@ from jax import numpy as jnp from jax.sharding import get_abstract_mesh, PartitionSpec +try: + from tx.ffi import ragged_dot_ffi, ragged_dot_ffi_available +except Exception: # pragma: no cover - optional GPU extension + ragged_dot_ffi = None + + def ragged_dot_ffi_available() -> bool: + return False + def ragged_dot( lhs: jax.Array, @@ -27,6 +35,16 @@ def ragged_dot( preferred_element_type=preferred_element_type, ) + if ( + ragged_dot_ffi_available() + and jax.default_backend() == "gpu" + and lhs.dtype == jnp.bfloat16 + and rhs.dtype == jnp.bfloat16 + and group_sizes.dtype == jnp.int32 + and group_offset.dtype == jnp.int32 + ): + return ragged_dot_ffi(lhs, rhs, group_sizes, group_offset) + assert group_offset.shape == (1,), "group_offset must have shape (1,)" offset = group_offset[0] m = lhs.shape[0] From 70dce5f553cc0a330a96cfcf9bc0de0e192e075b Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 05:08:31 +0000 Subject: [PATCH 002/106] update --- skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh index 28c357a86..38cdbc476 100644 --- a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh +++ b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh @@ -13,12 +13,10 @@ if [[ ! 
-d "${CUTLASS_DIR}" ]]; then exit 1 fi -PYTHON_BIN="${PYTHON_BIN:-python}" -JAX_INCLUDE_DIR="$("${PYTHON_BIN}" - <<'PY' +JAX_INCLUDE_DIR="$(uv run --extra gpu - <<'PY' import os import jaxlib -import jaxlib.xla_extension as xe -print(os.path.join(os.path.dirname(xe.__file__), "include")) +print(os.path.join(os.path.dirname(jaxlib.__file__), "include")) PY )" From d70e010341d0f37c43c469fea04acce09ac12492 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sat, 17 Jan 2026 21:11:16 -0800 Subject: [PATCH 003/106] update --- skyrl-tx/tx/ffi/README.md | 2 +- skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/tx/ffi/README.md b/skyrl-tx/tx/ffi/README.md index 624963060..cd7efc877 100644 --- a/skyrl-tx/tx/ffi/README.md +++ b/skyrl-tx/tx/ffi/README.md @@ -1,6 +1,6 @@ Build (Linux + CUDA) -1) Build the shared library: +1) Build the shared library (requires CUDA nvcc with C++17 support): ``` export CUTLASS_DIR=/path/to/cutlass tx/ffi/build_ragged_dot_ffi.sh diff --git a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh index 38cdbc476..68486d934 100644 --- a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh +++ b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh @@ -31,7 +31,7 @@ mkdir -p "${OUT_DIR}" "${NVCC_BIN}" \ -O3 \ - -std=c++14 \ + -std=c++17 \ -shared \ -Xcompiler -fPIC \ -I"${JAX_INCLUDE_DIR}" \ From ac85d10ff6d7175a1ba8cac7a4f723aa455d1aeb Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sat, 17 Jan 2026 21:12:20 -0800 Subject: [PATCH 004/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 79b0eb4c1..5150daf23 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -24,7 +24,7 @@ static ffi::Error CudaError(const char* message) { return ffi::Error::Internal(message); } -ffi::Error RaggedDotCuda( +ffi::Error RaggedDotCudaImpl( cudaStream_t stream, ffi::Buffer lhs, ffi::Buffer rhs, @@ -140,7 +140,7 @@ ffi::Error RaggedDotCuda( XLA_FFI_DEFINE_HANDLER_SYMBOL( RaggedDotCuda, - RaggedDotCuda, + RaggedDotCudaImpl, ffi::Ffi::Bind() .Ctx>() .Arg>() // lhs From 106e4ae4686a75a5b2d6984de42820145794b3bb Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sat, 17 Jan 2026 21:26:11 -0800 Subject: [PATCH 005/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.py b/skyrl-tx/tx/ffi/ragged_dot_ffi.py index bc096d9c2..15e5910be 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.py +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.py @@ -78,5 +78,5 @@ def ragged_dot( raise ValueError("group_offset must have shape (1,).") out = jax.ShapeDtypeStruct((lhs.shape[0], rhs.shape[2]), lhs.dtype) - call = jax_ffi.ffi_call("ragged_dot_cuda", out, vmap_method="disabled") + call = jax_ffi.ffi_call("ragged_dot_cuda", out, vmap_method=None) return call(lhs, rhs, group_sizes, group_offset) From 94eb6251805716438ba585727cc82db2c3ef93e2 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sat, 17 Jan 2026 21:31:04 -0800 Subject: [PATCH 006/106] add backward --- skyrl-tx/tx/ffi/ragged_dot_ffi.py | 64 ++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.py b/skyrl-tx/tx/ffi/ragged_dot_ffi.py index 15e5910be..e1a434996 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.py +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.py @@ -5,6 +5,7 @@ from 
pathlib import Path import jax +from jax import lax import jax.numpy as jnp try: # JAX >= 0.8 @@ -61,7 +62,34 @@ def is_available() -> bool: return _ensure_registered() -def ragged_dot( +def _ragged_dot_ref( + lhs: jax.Array, + rhs: jax.Array, + group_sizes: jax.Array, + group_offset: jax.Array, +) -> jax.Array: + if group_offset.shape != (1,): + raise ValueError("group_offset must have shape (1,).") + + offset = group_offset[0] + m = lhs.shape[0] + g_local = rhs.shape[0] + + cumsum = jnp.cumulative_sum(group_sizes, include_initial=True) + shard_start = cumsum[offset] + shard_end = cumsum[offset + g_local] + + token_idx = jnp.arange(m) + valid_mask = (token_idx >= shard_start) & (token_idx < shard_end) + + local_group_sizes = lax.dynamic_slice_in_dim(group_sizes, offset, g_local, axis=0) + adjusted_group_sizes = local_group_sizes.at[0].add(shard_start).at[-1].add(m - shard_end) + + result = lax.ragged_dot(lhs, rhs, adjusted_group_sizes) + return jnp.where(valid_mask[:, None], result, 0) + + +def _ragged_dot_ffi_call( lhs: jax.Array, rhs: jax.Array, group_sizes: jax.Array, @@ -80,3 +108,37 @@ def ragged_dot( out = jax.ShapeDtypeStruct((lhs.shape[0], rhs.shape[2]), lhs.dtype) call = jax_ffi.ffi_call("ragged_dot_cuda", out, vmap_method=None) return call(lhs, rhs, group_sizes, group_offset) + + +@jax.custom_vjp +def ragged_dot( + lhs: jax.Array, + rhs: jax.Array, + group_sizes: jax.Array, + group_offset: jax.Array, +) -> jax.Array: + return _ragged_dot_ffi_call(lhs, rhs, group_sizes, group_offset) + + +def _ragged_dot_fwd( + lhs: jax.Array, + rhs: jax.Array, + group_sizes: jax.Array, + group_offset: jax.Array, +): + y = _ragged_dot_ffi_call(lhs, rhs, group_sizes, group_offset) + return y, (lhs, rhs, group_sizes, group_offset) + + +def _ragged_dot_bwd(res, g): + lhs, rhs, group_sizes, group_offset = res + + def _ref_lhs_rhs(lhs_, rhs_): + return _ragged_dot_ref(lhs_, rhs_, group_sizes, group_offset) + + (_, pullback) = jax.vjp(_ref_lhs_rhs, lhs, rhs) + d_lhs, d_rhs = pullback(g) + return d_lhs, d_rhs, None, None + + +ragged_dot.defvjp(_ragged_dot_fwd, _ragged_dot_bwd) From cf80c97724f5ab69285e0c23cd839670be1f36c3 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sat, 17 Jan 2026 21:39:28 -0800 Subject: [PATCH 007/106] update --- skyrl-tx/tx/layers/util.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py index 7629242a8..2ae28b702 100644 --- a/skyrl-tx/tx/layers/util.py +++ b/skyrl-tx/tx/layers/util.py @@ -122,7 +122,13 @@ def shard_map_ep(module: nnx.Module, func, *args): ) in_specs = (state_specs,) + (PartitionSpec(),) * len(args) - @jax.shard_map(mesh=get_abstract_mesh(), in_specs=in_specs, out_specs=PartitionSpec(), axis_names={"ep"}) + @jax.shard_map( + mesh=get_abstract_mesh(), + in_specs=in_specs, + out_specs=PartitionSpec(), + axis_names={"ep"}, + check_vma=False, + ) def _body(state, *fn_args): module_shard = nnx.merge(graphdef, state) return func(module_shard, *fn_args) From 7f1fe1d8796742f969152f39203243e1abf7c7ac Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 00:30:30 -0800 Subject: [PATCH 008/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 65 ++++++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 5150daf23..bc266d20f 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -96,6 +96,9 @@ ffi::Error RaggedDotCudaImpl( for (int32_t i 
= 0; i < g; ++i) { offsets[static_cast(i) + 1] = offsets[static_cast(i)] + sizes[i]; } + if (offsets[static_cast(g)] != m) { + return ffi::Error::InvalidArgument("group_sizes sum does not match lhs rows."); + } err = cudaMemsetAsync(out->typed_data(), 0, static_cast(m) * n * sizeof(Dtype), stream); if (err != cudaSuccess) { @@ -107,6 +110,43 @@ ffi::Error RaggedDotCudaImpl( } Gemm gemm; + size_t max_workspace = 0; + for (int32_t gi = offset; gi < offset + g_local; ++gi) { + int32_t rows = sizes[gi]; + if (rows == 0) { + continue; + } + + int32_t local = gi - offset; + const Dtype* A = reinterpret_cast(lhs.typed_data()) + + static_cast(offsets[gi]) * k; + const Dtype* B = reinterpret_cast(rhs.typed_data()) + + static_cast(local) * k * n; + Dtype* C = reinterpret_cast(out->typed_data()) + + static_cast(offsets[gi]) * n; + + Gemm::Arguments args( + {rows, n, k}, + {A, k}, + {B, n}, + {C, n}, + {C, n}, + {1.0f, 0.0f}); + + size_t workspace_size = Gemm::get_workspace_size(args); + if (workspace_size > max_workspace) { + max_workspace = workspace_size; + } + } + + void* workspace = nullptr; + if (max_workspace > 0) { + err = cudaMalloc(&workspace, max_workspace); + if (err != cudaSuccess) { + return CudaError("Failed to allocate CUTLASS workspace."); + } + } + for (int32_t gi = offset; gi < offset + g_local; ++gi) { int32_t rows = sizes[gi]; if (rows == 0) { @@ -129,12 +169,35 @@ ffi::Error RaggedDotCudaImpl( {C, n}, {1.0f, 0.0f}); - cutlass::Status status = gemm(args, stream); + cutlass::Status status = gemm.can_implement(args); if (status != cutlass::Status::kSuccess) { + if (workspace != nullptr) { + cudaFree(workspace); + } + return ffi::Error::Internal("cutlass cannot implement."); + } + + status = gemm.initialize(args, workspace); + if (status != cutlass::Status::kSuccess) { + if (workspace != nullptr) { + cudaFree(workspace); + } + return ffi::Error::Internal("cutlass cannot initialize."); + } + + status = gemm.run(stream); + if (status != cutlass::Status::kSuccess) { + if (workspace != nullptr) { + cudaFree(workspace); + } return ffi::Error::Internal("cutlass gemm failed."); } } + if (workspace != nullptr) { + cudaFree(workspace); + } + return ffi::Error::Success(); } From 1d9df092164088d2f3d237396bf93bef1761dce9 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 01:14:02 -0800 Subject: [PATCH 009/106] use grouped gemm --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 237 ++++++++++++++++++++---------- 1 file changed, 162 insertions(+), 75 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index bc266d20f..8ddfdf641 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -1,13 +1,14 @@ #include #include +#include #include #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" #include -#include +#include #include #include @@ -18,12 +19,59 @@ using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::RowMajor; using Accum = float; -using Gemm = cutlass::gemm::device::Gemm; +using Gemm = cutlass::gemm::device::GemmGrouped< + Dtype, + LayoutA, + Dtype, + LayoutB, + Dtype, + LayoutC, + Accum>; static ffi::Error CudaError(const char* message) { return ffi::Error::Internal(message); } +using Strides = std::array; + +__global__ void prepare_grouped_gemm_data( + const Dtype* A, + const Dtype* B, + Dtype* output, + const int32_t* offs, + int32_t group_count, + int32_t k, + int32_t n, + int64_t lda, + int64_t ldb_group, + int64_t ldout, + const 
Strides tensor_ShapeA, + Dtype** A_ptrs, + Dtype** B_ptrs, + Dtype** output_ptrs, + cutlass::gemm::GemmCoord* problem_sizes) { + int32_t tid = threadIdx.x; + if (tid >= group_count) { + return; + } + + int32_t start = tid == 0 ? 0 : offs[tid - 1]; + int32_t end = offs[tid]; + int32_t m = end - start; + if (m < 0) { + return; + } + + if (end > tensor_ShapeA[0]) { + return; + } + + A_ptrs[tid] = const_cast(A) + static_cast(start) * lda; + B_ptrs[tid] = const_cast(B) + static_cast(tid) * ldb_group; + output_ptrs[tid] = output + static_cast(start) * ldout; + problem_sizes[tid] = cutlass::gemm::GemmCoord(m, n, k); +} + ffi::Error RaggedDotCudaImpl( cudaStream_t stream, ffi::Buffer lhs, @@ -109,94 +157,133 @@ ffi::Error RaggedDotCudaImpl( return ffi::Error::Success(); } - Gemm gemm; - size_t max_workspace = 0; - for (int32_t gi = offset; gi < offset + g_local; ++gi) { - int32_t rows = sizes[gi]; - if (rows == 0) { - continue; - } + if (g_local > 1024) { + return ffi::Error::InvalidArgument("group_count must be <= 1024."); + } - int32_t local = gi - offset; - const Dtype* A = reinterpret_cast(lhs.typed_data()) + - static_cast(offsets[gi]) * k; - const Dtype* B = reinterpret_cast(rhs.typed_data()) + - static_cast(local) * k * n; - Dtype* C = reinterpret_cast(out->typed_data()) + - static_cast(offsets[gi]) * n; - - Gemm::Arguments args( - {rows, n, k}, - {A, k}, - {B, n}, - {C, n}, - {C, n}, - {1.0f, 0.0f}); - - size_t workspace_size = Gemm::get_workspace_size(args); - if (workspace_size > max_workspace) { - max_workspace = workspace_size; - } + int32_t shard_start = offsets[static_cast(offset)]; + + std::vector local_offs(static_cast(g_local)); + int32_t running = 0; + for (int32_t i = 0; i < g_local; ++i) { + running += sizes[offset + i]; + local_offs[static_cast(i)] = running; } - void* workspace = nullptr; - if (max_workspace > 0) { - err = cudaMalloc(&workspace, max_workspace); - if (err != cudaSuccess) { - return CudaError("Failed to allocate CUTLASS workspace."); - } + const Dtype* A_base = reinterpret_cast(lhs.typed_data()) + + static_cast(shard_start) * k; + const Dtype* B_base = reinterpret_cast(rhs.typed_data()); + Dtype* out_base = reinterpret_cast(out->typed_data()) + + static_cast(shard_start) * n; + + int64_t lda = k; + int64_t ldb_group = static_cast(k) * n; + int64_t ldout = n; + + Strides shapeA = {m, k, 1}; + + auto align_up = [](size_t value, size_t alignment) -> size_t { + return (value + alignment - 1) & ~(alignment - 1); + }; + + size_t bytes = 0; + bytes = align_up(bytes, 16); + size_t offs_offset = bytes; + bytes += sizeof(int32_t) * static_cast(g_local); + bytes = align_up(bytes, 16); + size_t A_ptrs_offset = bytes; + bytes += sizeof(Dtype*) * static_cast(g_local); + bytes = align_up(bytes, 16); + size_t B_ptrs_offset = bytes; + bytes += sizeof(Dtype*) * static_cast(g_local); + bytes = align_up(bytes, 16); + size_t out_ptrs_offset = bytes; + bytes += sizeof(Dtype*) * static_cast(g_local); + bytes = align_up(bytes, 16); + size_t problem_sizes_offset = bytes; + bytes += sizeof(cutlass::gemm::GemmCoord) * static_cast(g_local); + + void* slab = nullptr; + err = cudaMalloc(&slab, bytes); + if (err != cudaSuccess) { + return CudaError("Failed to allocate grouped GEMM slab."); } - for (int32_t gi = offset; gi < offset + g_local; ++gi) { - int32_t rows = sizes[gi]; - if (rows == 0) { - continue; - } + char* base = reinterpret_cast(slab); + int32_t* d_offs = reinterpret_cast(base + offs_offset); + Dtype** d_A_ptrs = reinterpret_cast(base + A_ptrs_offset); + Dtype** d_B_ptrs = 
reinterpret_cast(base + B_ptrs_offset); + Dtype** d_out_ptrs = reinterpret_cast(base + out_ptrs_offset); + cutlass::gemm::GemmCoord* d_problem_sizes = + reinterpret_cast(base + problem_sizes_offset); + + err = cudaMemcpyAsync( + d_offs, + local_offs.data(), + sizeof(int32_t) * g_local, + cudaMemcpyHostToDevice, + stream); + if (err != cudaSuccess) { + cudaFree(slab); + return CudaError("Failed to copy offs."); + } - int32_t local = gi - offset; - const Dtype* A = reinterpret_cast(lhs.typed_data()) + - static_cast(offsets[gi]) * k; - const Dtype* B = reinterpret_cast(rhs.typed_data()) + - static_cast(local) * k * n; - Dtype* C = reinterpret_cast(out->typed_data()) + - static_cast(offsets[gi]) * n; - - Gemm::Arguments args( - {rows, n, k}, - {A, k}, - {B, n}, - {C, n}, - {C, n}, - {1.0f, 0.0f}); - - cutlass::Status status = gemm.can_implement(args); - if (status != cutlass::Status::kSuccess) { - if (workspace != nullptr) { - cudaFree(workspace); - } - return ffi::Error::Internal("cutlass cannot implement."); - } + prepare_grouped_gemm_data<<<1, g_local, 0, stream>>>( + A_base, + B_base, + out_base, + d_offs, + g_local, + k, + n, + lda, + ldb_group, + ldout, + shapeA, + d_A_ptrs, + d_B_ptrs, + d_out_ptrs, + d_problem_sizes); - status = gemm.initialize(args, workspace); - if (status != cutlass::Status::kSuccess) { - if (workspace != nullptr) { - cudaFree(workspace); - } - return ffi::Error::Internal("cutlass cannot initialize."); + Gemm gemm; + typename Gemm::Arguments args( + d_problem_sizes, + g_local, + {d_A_ptrs, lda}, + {d_B_ptrs, n}, + {d_out_ptrs, ldout}, + {d_out_ptrs, ldout}, + {1.0f, 0.0f}); + + cutlass::Status status = gemm.can_implement(args); + if (status != cutlass::Status::kSuccess) { + cudaFree(slab); + return ffi::Error::Internal("cutlass cannot implement grouped gemm."); + } + + size_t workspace_size = gemm.get_workspace_size(args); + void* workspace = nullptr; + if (workspace_size > 0) { + err = cudaMalloc(&workspace, workspace_size); + if (err != cudaSuccess) { + cudaFree(slab); + return CudaError("Failed to allocate CUTLASS workspace."); } + } - status = gemm.run(stream); - if (status != cutlass::Status::kSuccess) { - if (workspace != nullptr) { - cudaFree(workspace); - } - return ffi::Error::Internal("cutlass gemm failed."); + status = gemm(args, workspace, stream); + if (status != cutlass::Status::kSuccess) { + if (workspace != nullptr) { + cudaFree(workspace); } + cudaFree(slab); + return ffi::Error::Internal("cutlass grouped gemm failed."); } if (workspace != nullptr) { cudaFree(workspace); } + cudaFree(slab); return ffi::Error::Success(); } From 527c1a0850e532e8355220e9915967198074052f Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 01:20:23 -0800 Subject: [PATCH 010/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 201 ++++++++++++++++++++++-------- 1 file changed, 146 insertions(+), 55 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 8ddfdf641..54e3c23ca 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -8,25 +8,84 @@ #include "xla/ffi/api/ffi.h" #include -#include #include #include +#include + +#include +#include +#include +#include +#include +#include +#include namespace ffi = xla::ffi; -using Dtype = cutlass::bfloat16_t; +using DtypeA = cutlass::bfloat16_t; +using DtypeB = cutlass::bfloat16_t; +using DtypeOutput = cutlass::bfloat16_t; +using DtypeAccum = float; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::RowMajor; 
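
For reference, the grouped GEMM being introduced in these patches computes the same ragged matmul as the per-group GEMM loop it replaces: each block of `group_sizes[i]` consecutive rows of `lhs` is multiplied by its own `rhs` slice, and rows outside the local shard stay zero. A minimal NumPy sketch of that contract (illustrative only; the function and variable names below are not part of the patch):

```
import numpy as np

def ragged_dot_reference(lhs, rhs, group_sizes, group_offset):
    # lhs: (m, k), rhs: (g_local, k, n), group_sizes: (g,), group_offset: (1,)
    m, _ = lhs.shape
    g_local, _, n = rhs.shape
    out = np.zeros((m, n), dtype=lhs.dtype)
    starts = np.concatenate(([0], np.cumsum(group_sizes)))
    for i in range(g_local):                  # one GEMM per local group/expert
        gi = int(group_offset[0]) + i         # global group index for this shard
        rows = slice(starts[gi], starts[gi + 1])
        out[rows] = lhs[rows] @ rhs[i]        # rows outside the shard stay zero
    return out
```

The CUDA path builds the same per-group problem list on the device and hands it to a single grouped kernel instead of looping over groups on the host.
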
-using LayoutC = cutlass::layout::RowMajor; -using Accum = float; -using Gemm = cutlass::gemm::device::GemmGrouped< - Dtype, - LayoutA, - Dtype, - LayoutB, - Dtype, - LayoutC, - Accum>; +using LayoutOutput = cutlass::layout::RowMajor; +constexpr int AlignmentA = 1; +constexpr int AlignmentB = 1; +constexpr int AlignmentOutput = 1; + +using ArchTag = cutlass::arch::Sm90; +using OperatorClass = cutlass::arch::OpClassTensorOp; +using TileShape = cute::Shape; +using ClusterShape = cute::Shape; +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; +using ProblemShape = cutlass::gemm::GroupProblemShape< + cute::Shape>; + +using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + DtypeAccum, + DtypeAccum, + void, + LayoutOutput*, + AlignmentOutput, + DtypeOutput, + LayoutOutput*, + AlignmentOutput, + EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; + +using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + DtypeA, + LayoutA*, + AlignmentA, + DtypeB, + LayoutB*, + AlignmentB, + DtypeAccum, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +using StrideA = typename Gemm::GemmKernel::InternalStrideA; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideOutput = typename Gemm::GemmKernel::InternalStrideD; static ffi::Error CudaError(const char* message) { return ffi::Error::Internal(message); @@ -35,21 +94,24 @@ static ffi::Error CudaError(const char* message) { using Strides = std::array; __global__ void prepare_grouped_gemm_data( - const Dtype* A, - const Dtype* B, - Dtype* output, + const DtypeA* A, + const DtypeB* B, + DtypeOutput* output, const int32_t* offs, int32_t group_count, int32_t k, int32_t n, - int64_t lda, - int64_t ldb_group, - int64_t ldout, + const Strides tensor_StrideA, + const Strides tensor_StrideB, + const Strides tensor_StrideOutput, const Strides tensor_ShapeA, - Dtype** A_ptrs, - Dtype** B_ptrs, - Dtype** output_ptrs, - cutlass::gemm::GemmCoord* problem_sizes) { + DtypeA** A_ptrs, + DtypeB** B_ptrs, + DtypeOutput** output_ptrs, + StrideA* stride_A, + StrideB* stride_B, + StrideOutput* stride_output, + ProblemShape::UnderlyingProblemShape* problem_sizes) { int32_t tid = threadIdx.x; if (tid >= group_count) { return; @@ -66,10 +128,18 @@ __global__ void prepare_grouped_gemm_data( return; } - A_ptrs[tid] = const_cast(A) + static_cast(start) * lda; - B_ptrs[tid] = const_cast(B) + static_cast(tid) * ldb_group; - output_ptrs[tid] = output + static_cast(start) * ldout; - problem_sizes[tid] = cutlass::gemm::GemmCoord(m, n, k); + int64_t lda = tensor_StrideA[0]; + int64_t ldb = tensor_StrideB[1]; + int64_t ldoutput = tensor_StrideOutput[0]; + + A_ptrs[tid] = const_cast(A) + static_cast(start) * lda; + B_ptrs[tid] = const_cast(B) + static_cast(tid) * tensor_StrideB[0]; + output_ptrs[tid] = output + static_cast(start) * ldoutput; + problem_sizes[tid] = ProblemShape::UnderlyingProblemShape(m, n, k); + + stride_A[tid] = 
cutlass::make_cute_packed_stride(StrideA{}, {lda, lda, 1}); + stride_B[tid] = cutlass::make_cute_packed_stride(StrideB{}, {ldb, ldb, 1}); + stride_output[tid] = cutlass::make_cute_packed_stride(StrideOutput{}, {m, ldoutput, 1}); } ffi::Error RaggedDotCudaImpl( @@ -148,7 +218,8 @@ ffi::Error RaggedDotCudaImpl( return ffi::Error::InvalidArgument("group_sizes sum does not match lhs rows."); } - err = cudaMemsetAsync(out->typed_data(), 0, static_cast(m) * n * sizeof(Dtype), stream); + err = cudaMemsetAsync( + out->typed_data(), 0, static_cast(m) * n * sizeof(DtypeOutput), stream); if (err != cudaSuccess) { return CudaError("Failed to zero output."); } @@ -170,16 +241,15 @@ ffi::Error RaggedDotCudaImpl( local_offs[static_cast(i)] = running; } - const Dtype* A_base = reinterpret_cast(lhs.typed_data()) + + const DtypeA* A_base = reinterpret_cast(lhs.typed_data()) + static_cast(shard_start) * k; - const Dtype* B_base = reinterpret_cast(rhs.typed_data()); - Dtype* out_base = reinterpret_cast(out->typed_data()) + + const DtypeB* B_base = reinterpret_cast(rhs.typed_data()); + DtypeOutput* out_base = reinterpret_cast(out->typed_data()) + static_cast(shard_start) * n; - int64_t lda = k; - int64_t ldb_group = static_cast(k) * n; - int64_t ldout = n; - + Strides strideA = {k, 1, 1}; + Strides strideB = {static_cast(k) * n, n, 1}; + Strides strideOut = {n, 1, 1}; Strides shapeA = {m, k, 1}; auto align_up = [](size_t value, size_t alignment) -> size_t { @@ -192,16 +262,25 @@ ffi::Error RaggedDotCudaImpl( bytes += sizeof(int32_t) * static_cast(g_local); bytes = align_up(bytes, 16); size_t A_ptrs_offset = bytes; - bytes += sizeof(Dtype*) * static_cast(g_local); + bytes += sizeof(DtypeA*) * static_cast(g_local); bytes = align_up(bytes, 16); size_t B_ptrs_offset = bytes; - bytes += sizeof(Dtype*) * static_cast(g_local); + bytes += sizeof(DtypeB*) * static_cast(g_local); bytes = align_up(bytes, 16); size_t out_ptrs_offset = bytes; - bytes += sizeof(Dtype*) * static_cast(g_local); + bytes += sizeof(DtypeOutput*) * static_cast(g_local); + bytes = align_up(bytes, 16); + size_t stride_A_offset = bytes; + bytes += sizeof(StrideA) * static_cast(g_local); + bytes = align_up(bytes, 16); + size_t stride_B_offset = bytes; + bytes += sizeof(StrideB) * static_cast(g_local); + bytes = align_up(bytes, 16); + size_t stride_output_offset = bytes; + bytes += sizeof(StrideOutput) * static_cast(g_local); bytes = align_up(bytes, 16); size_t problem_sizes_offset = bytes; - bytes += sizeof(cutlass::gemm::GemmCoord) * static_cast(g_local); + bytes += sizeof(ProblemShape::UnderlyingProblemShape) * static_cast(g_local); void* slab = nullptr; err = cudaMalloc(&slab, bytes); @@ -211,11 +290,14 @@ ffi::Error RaggedDotCudaImpl( char* base = reinterpret_cast(slab); int32_t* d_offs = reinterpret_cast(base + offs_offset); - Dtype** d_A_ptrs = reinterpret_cast(base + A_ptrs_offset); - Dtype** d_B_ptrs = reinterpret_cast(base + B_ptrs_offset); - Dtype** d_out_ptrs = reinterpret_cast(base + out_ptrs_offset); - cutlass::gemm::GemmCoord* d_problem_sizes = - reinterpret_cast(base + problem_sizes_offset); + DtypeA** d_A_ptrs = reinterpret_cast(base + A_ptrs_offset); + DtypeB** d_B_ptrs = reinterpret_cast(base + B_ptrs_offset); + DtypeOutput** d_out_ptrs = reinterpret_cast(base + out_ptrs_offset); + StrideA* d_stride_A = reinterpret_cast(base + stride_A_offset); + StrideB* d_stride_B = reinterpret_cast(base + stride_B_offset); + StrideOutput* d_stride_output = reinterpret_cast(base + stride_output_offset); + ProblemShape::UnderlyingProblemShape* 
d_problem_sizes = + reinterpret_cast(base + problem_sizes_offset); err = cudaMemcpyAsync( d_offs, @@ -236,24 +318,27 @@ ffi::Error RaggedDotCudaImpl( g_local, k, n, - lda, - ldb_group, - ldout, + strideA, + strideB, + strideOut, shapeA, d_A_ptrs, d_B_ptrs, d_out_ptrs, + d_stride_A, + d_stride_B, + d_stride_output, d_problem_sizes); Gemm gemm; - typename Gemm::Arguments args( - d_problem_sizes, - g_local, - {d_A_ptrs, lda}, - {d_B_ptrs, n}, - {d_out_ptrs, ldout}, - {d_out_ptrs, ldout}, - {1.0f, 0.0f}); + typename Gemm::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {g_local, d_problem_sizes, nullptr}, + {(const DtypeA**)d_A_ptrs, d_stride_A, (const DtypeB**)d_B_ptrs, d_stride_B}, + {{}, nullptr, d_stride_output, d_out_ptrs, d_stride_output}}; + + args.epilogue.thread.alpha = 1.0f; + args.epilogue.thread.dAlpha = {cute::_0{}, cute::_0{}, 0}; cutlass::Status status = gemm.can_implement(args); if (status != cutlass::Status::kSuccess) { @@ -261,7 +346,13 @@ ffi::Error RaggedDotCudaImpl( return ffi::Error::Internal("cutlass cannot implement grouped gemm."); } - size_t workspace_size = gemm.get_workspace_size(args); + int device = 0; + cudaDeviceProp props; + if (cudaGetDevice(&device) == cudaSuccess && cudaGetDeviceProperties(&props, device) == cudaSuccess) { + args.hw_info.sm_count = props.multiProcessorCount; + } + + size_t workspace_size = Gemm::get_workspace_size(args); void* workspace = nullptr; if (workspace_size > 0) { err = cudaMalloc(&workspace, workspace_size); From 656756ff324163a638b9551c0ebb509f65a38857 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 09:22:46 +0000 Subject: [PATCH 011/106] update --- skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh index 68486d934..e6c12b596 100644 --- a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh +++ b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh @@ -36,6 +36,7 @@ mkdir -p "${OUT_DIR}" -Xcompiler -fPIC \ -I"${JAX_INCLUDE_DIR}" \ -I"${CUTLASS_DIR}/include" \ + -I"${CUTLASS_DIR}/tools/util/include/" \ "${SCRIPT_DIR}/ragged_dot_ffi.cu" \ -o "${OUT_DIR}/libragged_dot_ffi.so" From aee36c783c01760c2a3311fea564ea8e01cfe108 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 01:26:23 -0800 Subject: [PATCH 012/106] update --- skyrl-tx/tx/ffi/README.md | 1 + skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh | 2 ++ skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 9 ++++++--- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/skyrl-tx/tx/ffi/README.md b/skyrl-tx/tx/ffi/README.md index cd7efc877..681f34f9b 100644 --- a/skyrl-tx/tx/ffi/README.md +++ b/skyrl-tx/tx/ffi/README.md @@ -3,6 +3,7 @@ Build (Linux + CUDA) 1) Build the shared library (requires CUDA nvcc with C++17 support): ``` export CUTLASS_DIR=/path/to/cutlass +export NVCC_ARCH=sm_90a # for H100, adjust if needed tx/ffi/build_ragged_dot_ffi.sh ``` diff --git a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh index 68486d934..b590e284e 100644 --- a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh +++ b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh @@ -21,6 +21,7 @@ PY )" NVCC_BIN="${NVCC_BIN:-nvcc}" +NVCC_ARCH="${NVCC_ARCH:-sm_90a}" if ! command -v "${NVCC_BIN}" >/dev/null 2>&1; then echo "nvcc not found. Set NVCC_BIN or ensure CUDA is on PATH." 
>&2 exit 1 @@ -32,6 +33,7 @@ mkdir -p "${OUT_DIR}" "${NVCC_BIN}" \ -O3 \ -std=c++17 \ + -arch="${NVCC_ARCH}" \ -shared \ -Xcompiler -fPIC \ -I"${JAX_INCLUDE_DIR}" \ diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 54e3c23ca..37d4c3f21 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -8,11 +8,14 @@ #include "xla/ffi/api/ffi.h" #include +#include #include +#include #include #include #include +#include #include #include #include @@ -29,9 +32,9 @@ using DtypeAccum = float; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::RowMajor; using LayoutOutput = cutlass::layout::RowMajor; -constexpr int AlignmentA = 1; -constexpr int AlignmentB = 1; -constexpr int AlignmentOutput = 1; +constexpr int AlignmentA = 8; +constexpr int AlignmentB = 8; +constexpr int AlignmentOutput = 8; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; From 6e9ead96a03ef7d50ea0b8d11de30f8d7db0793f Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 01:35:07 -0800 Subject: [PATCH 013/106] update --- skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh | 1 + skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 14 ++++++++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh index fd40c6a1a..4feb2d730 100644 --- a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh +++ b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh @@ -34,6 +34,7 @@ mkdir -p "${OUT_DIR}" -O3 \ -std=c++17 \ -arch="${NVCC_ARCH}" \ + --expt-relaxed-constexpr \ -shared \ -Xcompiler -fPIC \ -I"${JAX_INCLUDE_DIR}" \ diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 37d4c3f21..209bf422b 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -8,10 +8,11 @@ #include "xla/ffi/api/ffi.h" #include -#include +#include #include -#include #include +#include +#include #include #include @@ -23,6 +24,10 @@ #include #include +#if !defined(CUTLASS_MAJOR) || CUTLASS_MAJOR < 3 +#error "This kernel requires CUTLASS >= 3.x (SM90 grouped GEMM)." 
+#endif + namespace ffi = xla::ffi; using DtypeA = cutlass::bfloat16_t; @@ -89,6 +94,7 @@ using Gemm = cutlass::gemm::device::GemmUniversalAdapter; using StrideA = typename Gemm::GemmKernel::InternalStrideA; using StrideB = typename Gemm::GemmKernel::InternalStrideB; using StrideOutput = typename Gemm::GemmKernel::InternalStrideD; +using ProblemShapeType = ProblemShape::UnderlyingProblemShape; static ffi::Error CudaError(const char* message) { return ffi::Error::Internal(message); @@ -114,7 +120,7 @@ __global__ void prepare_grouped_gemm_data( StrideA* stride_A, StrideB* stride_B, StrideOutput* stride_output, - ProblemShape::UnderlyingProblemShape* problem_sizes) { + ProblemShapeType* problem_sizes) { int32_t tid = threadIdx.x; if (tid >= group_count) { return; @@ -138,7 +144,7 @@ __global__ void prepare_grouped_gemm_data( A_ptrs[tid] = const_cast(A) + static_cast(start) * lda; B_ptrs[tid] = const_cast(B) + static_cast(tid) * tensor_StrideB[0]; output_ptrs[tid] = output + static_cast(start) * ldoutput; - problem_sizes[tid] = ProblemShape::UnderlyingProblemShape(m, n, k); + problem_sizes[tid] = ProblemShapeType(m, n, k); stride_A[tid] = cutlass::make_cute_packed_stride(StrideA{}, {lda, lda, 1}); stride_B[tid] = cutlass::make_cute_packed_stride(StrideB{}, {ldb, ldb, 1}); From 63914e7d769c3ca6389cbf272095878cac5eeb88 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 01:51:48 -0800 Subject: [PATCH 014/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 94 ++++++++----------------------- 1 file changed, 23 insertions(+), 71 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 209bf422b..bed420f26 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -106,7 +106,9 @@ __global__ void prepare_grouped_gemm_data( const DtypeA* A, const DtypeB* B, DtypeOutput* output, - const int32_t* offs, + const int32_t* group_sizes, + const int32_t* group_offset, + int32_t num_groups, int32_t group_count, int32_t k, int32_t n, @@ -126,14 +128,22 @@ __global__ void prepare_grouped_gemm_data( return; } - int32_t start = tid == 0 ? 
0 : offs[tid - 1]; - int32_t end = offs[tid]; - int32_t m = end - start; + int32_t offset = group_offset[0]; + int32_t global = offset + tid; + if (global < 0 || global >= num_groups) { + return; + } + + int32_t start = 0; + for (int32_t i = 0; i < global; ++i) { + start += group_sizes[i]; + } + int32_t m = group_sizes[global]; if (m < 0) { return; } - if (end > tensor_ShapeA[0]) { + if (start + m > tensor_ShapeA[0]) { return; } @@ -189,43 +199,9 @@ ffi::Error RaggedDotCudaImpl( int32_t g = static_cast(g64); int32_t g_local = static_cast(g_local64); - int32_t offset = 0; - cudaError_t err = cudaMemcpyAsync( - &offset, group_offset.typed_data(), sizeof(int32_t), cudaMemcpyDeviceToHost, stream); - if (err != cudaSuccess) { - return CudaError("Failed to copy group_offset."); - } - - std::vector sizes(static_cast(g)); - if (g > 0) { - err = cudaMemcpyAsync( - sizes.data(), - group_sizes.typed_data(), - static_cast(g) * sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream); - if (err != cudaSuccess) { - return CudaError("Failed to copy group_sizes."); - } - } - - err = cudaStreamSynchronize(stream); - if (err != cudaSuccess) { - return CudaError("Failed to synchronize stream."); - } - - if (offset < 0 || offset > g || offset + g_local > g) { - return ffi::Error::InvalidArgument("group_offset out of range."); - } - - std::vector offsets(static_cast(g) + 1); - offsets[0] = 0; - for (int32_t i = 0; i < g; ++i) { - offsets[static_cast(i) + 1] = offsets[static_cast(i)] + sizes[i]; - } - if (offsets[static_cast(g)] != m) { - return ffi::Error::InvalidArgument("group_sizes sum does not match lhs rows."); - } + const int32_t* group_sizes_ptr = group_sizes.typed_data(); + const int32_t* group_offset_ptr = group_offset.typed_data(); + cudaError_t err; err = cudaMemsetAsync( out->typed_data(), 0, static_cast(m) * n * sizeof(DtypeOutput), stream); @@ -241,20 +217,9 @@ ffi::Error RaggedDotCudaImpl( return ffi::Error::InvalidArgument("group_count must be <= 1024."); } - int32_t shard_start = offsets[static_cast(offset)]; - - std::vector local_offs(static_cast(g_local)); - int32_t running = 0; - for (int32_t i = 0; i < g_local; ++i) { - running += sizes[offset + i]; - local_offs[static_cast(i)] = running; - } - - const DtypeA* A_base = reinterpret_cast(lhs.typed_data()) + - static_cast(shard_start) * k; + const DtypeA* A_base = reinterpret_cast(lhs.typed_data()); const DtypeB* B_base = reinterpret_cast(rhs.typed_data()); - DtypeOutput* out_base = reinterpret_cast(out->typed_data()) + - static_cast(shard_start) * n; + DtypeOutput* out_base = reinterpret_cast(out->typed_data()); Strides strideA = {k, 1, 1}; Strides strideB = {static_cast(k) * n, n, 1}; @@ -267,9 +232,6 @@ ffi::Error RaggedDotCudaImpl( size_t bytes = 0; bytes = align_up(bytes, 16); - size_t offs_offset = bytes; - bytes += sizeof(int32_t) * static_cast(g_local); - bytes = align_up(bytes, 16); size_t A_ptrs_offset = bytes; bytes += sizeof(DtypeA*) * static_cast(g_local); bytes = align_up(bytes, 16); @@ -298,7 +260,6 @@ ffi::Error RaggedDotCudaImpl( } char* base = reinterpret_cast(slab); - int32_t* d_offs = reinterpret_cast(base + offs_offset); DtypeA** d_A_ptrs = reinterpret_cast(base + A_ptrs_offset); DtypeB** d_B_ptrs = reinterpret_cast(base + B_ptrs_offset); DtypeOutput** d_out_ptrs = reinterpret_cast(base + out_ptrs_offset); @@ -308,22 +269,13 @@ ffi::Error RaggedDotCudaImpl( ProblemShape::UnderlyingProblemShape* d_problem_sizes = reinterpret_cast(base + problem_sizes_offset); - err = cudaMemcpyAsync( - d_offs, - local_offs.data(), - 
sizeof(int32_t) * g_local, - cudaMemcpyHostToDevice, - stream); - if (err != cudaSuccess) { - cudaFree(slab); - return CudaError("Failed to copy offs."); - } - prepare_grouped_gemm_data<<<1, g_local, 0, stream>>>( A_base, B_base, out_base, - d_offs, + group_sizes_ptr, + group_offset_ptr, + g, g_local, k, n, From f92d00dedd36040151d942b46fb9b79f46416e3e Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 14:27:15 -0800 Subject: [PATCH 015/106] update --- skyrl-tx/tx/layers/util.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py index 2ae28b702..f489d792a 100644 --- a/skyrl-tx/tx/layers/util.py +++ b/skyrl-tx/tx/layers/util.py @@ -35,6 +35,11 @@ def ragged_dot( preferred_element_type=preferred_element_type, ) + # CUTLASS kernel requires k and n dimensions divisible by 8 + k = lhs.shape[-1] + n = rhs.shape[-1] + cutlass_alignment = 8 + if ( ragged_dot_ffi_available() and jax.default_backend() == "gpu" @@ -42,6 +47,8 @@ def ragged_dot( and rhs.dtype == jnp.bfloat16 and group_sizes.dtype == jnp.int32 and group_offset.dtype == jnp.int32 + and k % cutlass_alignment == 0 + and n % cutlass_alignment == 0 ): return ragged_dot_ffi(lhs, rhs, group_sizes, group_offset) From ad0bfeea63435f4bce435f8ad27f97ced94070d8 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 14:47:13 -0800 Subject: [PATCH 016/106] fix --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index bed420f26..368a7be71 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -254,7 +254,7 @@ ffi::Error RaggedDotCudaImpl( bytes += sizeof(ProblemShape::UnderlyingProblemShape) * static_cast(g_local); void* slab = nullptr; - err = cudaMalloc(&slab, bytes); + err = cudaMallocAsync(&slab, bytes, stream); if (err != cudaSuccess) { return CudaError("Failed to allocate grouped GEMM slab."); } @@ -303,7 +303,7 @@ ffi::Error RaggedDotCudaImpl( cutlass::Status status = gemm.can_implement(args); if (status != cutlass::Status::kSuccess) { - cudaFree(slab); + cudaFreeAsync(slab, stream); return ffi::Error::Internal("cutlass cannot implement grouped gemm."); } @@ -316,9 +316,9 @@ ffi::Error RaggedDotCudaImpl( size_t workspace_size = Gemm::get_workspace_size(args); void* workspace = nullptr; if (workspace_size > 0) { - err = cudaMalloc(&workspace, workspace_size); + err = cudaMallocAsync(&workspace, workspace_size, stream); if (err != cudaSuccess) { - cudaFree(slab); + cudaFreeAsync(slab, stream); return CudaError("Failed to allocate CUTLASS workspace."); } } @@ -326,16 +326,16 @@ ffi::Error RaggedDotCudaImpl( status = gemm(args, workspace, stream); if (status != cutlass::Status::kSuccess) { if (workspace != nullptr) { - cudaFree(workspace); + cudaFreeAsync(workspace, stream); } - cudaFree(slab); + cudaFreeAsync(slab, stream); return ffi::Error::Internal("cutlass grouped gemm failed."); } if (workspace != nullptr) { - cudaFree(workspace); + cudaFreeAsync(workspace, stream); } - cudaFree(slab); + cudaFreeAsync(slab, stream); return ffi::Error::Success(); } From 70cba867bf55879ab97bc603435d8fa8b54bcd6e Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 15:10:57 -0800 Subject: [PATCH 017/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu 
b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 368a7be71..7814896d1 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -163,6 +163,7 @@ __global__ void prepare_grouped_gemm_data( ffi::Error RaggedDotCudaImpl( cudaStream_t stream, + ffi::ScratchAllocator scratch, ffi::Buffer lhs, ffi::Buffer rhs, ffi::Buffer group_sizes, @@ -253,11 +254,11 @@ ffi::Error RaggedDotCudaImpl( size_t problem_sizes_offset = bytes; bytes += sizeof(ProblemShape::UnderlyingProblemShape) * static_cast(g_local); - void* slab = nullptr; - err = cudaMallocAsync(&slab, bytes, stream); - if (err != cudaSuccess) { - return CudaError("Failed to allocate grouped GEMM slab."); + auto slab_or = scratch.Allocate(bytes); + if (!slab_or.has_value()) { + return ffi::Error::Internal("Failed to allocate grouped GEMM slab from scratch."); } + void* slab = slab_or.value(); char* base = reinterpret_cast(slab); DtypeA** d_A_ptrs = reinterpret_cast(base + A_ptrs_offset); @@ -303,7 +304,6 @@ ffi::Error RaggedDotCudaImpl( cutlass::Status status = gemm.can_implement(args); if (status != cutlass::Status::kSuccess) { - cudaFreeAsync(slab, stream); return ffi::Error::Internal("cutlass cannot implement grouped gemm."); } @@ -316,27 +316,18 @@ ffi::Error RaggedDotCudaImpl( size_t workspace_size = Gemm::get_workspace_size(args); void* workspace = nullptr; if (workspace_size > 0) { - err = cudaMallocAsync(&workspace, workspace_size, stream); - if (err != cudaSuccess) { - cudaFreeAsync(slab, stream); - return CudaError("Failed to allocate CUTLASS workspace."); + auto workspace_or = scratch.Allocate(workspace_size); + if (!workspace_or.has_value()) { + return ffi::Error::Internal("Failed to allocate CUTLASS workspace from scratch."); } + workspace = workspace_or.value(); } status = gemm(args, workspace, stream); if (status != cutlass::Status::kSuccess) { - if (workspace != nullptr) { - cudaFreeAsync(workspace, stream); - } - cudaFreeAsync(slab, stream); return ffi::Error::Internal("cutlass grouped gemm failed."); } - if (workspace != nullptr) { - cudaFreeAsync(workspace, stream); - } - cudaFreeAsync(slab, stream); - return ffi::Error::Success(); } @@ -345,6 +336,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( RaggedDotCudaImpl, ffi::Ffi::Bind() .Ctx>() + .Ctx() .Arg>() // lhs .Arg>() // rhs .Arg>() // group_sizes From 3f4dd251c81ca4a67e4a56dce38694ff748f65ba Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 15:40:24 -0800 Subject: [PATCH 018/106] optimize --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 55 ++++++++++++++----------------- skyrl-tx/tx/ffi/ragged_dot_ffi.py | 5 ++- 2 files changed, 29 insertions(+), 31 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 7814896d1..64f4d64da 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -107,15 +107,16 @@ __global__ void prepare_grouped_gemm_data( const DtypeB* B, DtypeOutput* output, const int32_t* group_sizes, - const int32_t* group_offset, + const int32_t* group_offsets_cumsum, // Precomputed cumsum of group_sizes + int32_t first_group_idx, int32_t num_groups, int32_t group_count, int32_t k, int32_t n, - const Strides tensor_StrideA, - const Strides tensor_StrideB, - const Strides tensor_StrideOutput, - const Strides tensor_ShapeA, + int64_t lda, + int64_t ldb, + int64_t ldoutput, + int32_t total_rows, DtypeA** A_ptrs, DtypeB** B_ptrs, DtypeOutput** output_ptrs, @@ -128,31 +129,20 @@ __global__ void prepare_grouped_gemm_data( return; } - int32_t offset = group_offset[0]; - 
int32_t global = offset + tid; + int32_t global = first_group_idx + tid; if (global < 0 || global >= num_groups) { return; } - int32_t start = 0; - for (int32_t i = 0; i < global; ++i) { - start += group_sizes[i]; - } + // Use precomputed cumsum: start = cumsum[global-1] if global > 0, else 0 + int32_t start = (global > 0) ? group_offsets_cumsum[global - 1] : 0; int32_t m = group_sizes[global]; - if (m < 0) { + if (m < 0 || start + m > total_rows) { return; } - if (start + m > tensor_ShapeA[0]) { - return; - } - - int64_t lda = tensor_StrideA[0]; - int64_t ldb = tensor_StrideB[1]; - int64_t ldoutput = tensor_StrideOutput[0]; - A_ptrs[tid] = const_cast(A) + static_cast(start) * lda; - B_ptrs[tid] = const_cast(B) + static_cast(tid) * tensor_StrideB[0]; + B_ptrs[tid] = const_cast(B) + static_cast(tid) * ldb * k; output_ptrs[tid] = output + static_cast(start) * ldoutput; problem_sizes[tid] = ProblemShapeType(m, n, k); @@ -168,6 +158,7 @@ ffi::Error RaggedDotCudaImpl( ffi::Buffer rhs, ffi::Buffer group_sizes, ffi::Buffer group_offset, + ffi::Buffer group_offsets_cumsum, ffi::ResultBuffer out) { auto lhs_dims = lhs.dimensions(); auto rhs_dims = rhs.dimensions(); @@ -221,11 +212,13 @@ ffi::Error RaggedDotCudaImpl( const DtypeA* A_base = reinterpret_cast(lhs.typed_data()); const DtypeB* B_base = reinterpret_cast(rhs.typed_data()); DtypeOutput* out_base = reinterpret_cast(out->typed_data()); + const int32_t* group_offsets_cumsum_ptr = group_offsets_cumsum.typed_data(); + int32_t first_group_idx = group_offset_ptr[0]; - Strides strideA = {k, 1, 1}; - Strides strideB = {static_cast(k) * n, n, 1}; - Strides strideOut = {n, 1, 1}; - Strides shapeA = {m, k, 1}; + // Strides for row-major layout + int64_t lda = k; + int64_t ldb = n; + int64_t ldoutput = n; auto align_up = [](size_t value, size_t alignment) -> size_t { return (value + alignment - 1) & ~(alignment - 1); @@ -275,15 +268,16 @@ ffi::Error RaggedDotCudaImpl( B_base, out_base, group_sizes_ptr, - group_offset_ptr, + group_offsets_cumsum_ptr, + first_group_idx, g, g_local, k, n, - strideA, - strideB, - strideOut, - shapeA, + lda, + ldb, + ldoutput, + m, d_A_ptrs, d_B_ptrs, d_out_ptrs, @@ -341,4 +335,5 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg>() // rhs .Arg>() // group_sizes .Arg>() // group_offset + .Arg>() // group_offsets_cumsum .Ret>()); // out diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.py b/skyrl-tx/tx/ffi/ragged_dot_ffi.py index e1a434996..513ef7b64 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.py +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.py @@ -105,9 +105,12 @@ def _ragged_dot_ffi_call( if group_offset.shape != (1,): raise ValueError("group_offset must have shape (1,).") + # Precompute cumulative sum to avoid O(n²) loop in CUDA kernel + group_offsets = jnp.cumsum(group_sizes, dtype=jnp.int32) + out = jax.ShapeDtypeStruct((lhs.shape[0], rhs.shape[2]), lhs.dtype) call = jax_ffi.ffi_call("ragged_dot_cuda", out, vmap_method=None) - return call(lhs, rhs, group_sizes, group_offset) + return call(lhs, rhs, group_sizes, group_offset, group_offsets) @jax.custom_vjp From 3f6669d0a42193884fdd7ce5ef0938f9397a53a5 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 16:10:27 -0800 Subject: [PATCH 019/106] fixes --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 64f4d64da..144892b97 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -108,7 +108,7 @@ __global__ void 
prepare_grouped_gemm_data( DtypeOutput* output, const int32_t* group_sizes, const int32_t* group_offsets_cumsum, // Precomputed cumsum of group_sizes - int32_t first_group_idx, + const int32_t* group_offset_ptr, // Device pointer to first group index int32_t num_groups, int32_t group_count, int32_t k, @@ -129,6 +129,7 @@ __global__ void prepare_grouped_gemm_data( return; } + int32_t first_group_idx = group_offset_ptr[0]; int32_t global = first_group_idx + tid; if (global < 0 || global >= num_groups) { return; @@ -213,7 +214,6 @@ ffi::Error RaggedDotCudaImpl( const DtypeB* B_base = reinterpret_cast(rhs.typed_data()); DtypeOutput* out_base = reinterpret_cast(out->typed_data()); const int32_t* group_offsets_cumsum_ptr = group_offsets_cumsum.typed_data(); - int32_t first_group_idx = group_offset_ptr[0]; // Strides for row-major layout int64_t lda = k; @@ -269,7 +269,7 @@ ffi::Error RaggedDotCudaImpl( out_base, group_sizes_ptr, group_offsets_cumsum_ptr, - first_group_idx, + group_offset_ptr, g, g_local, k, From 7b22f8620c771edc509591cb9da147b5f46747e2 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 16:26:36 -0800 Subject: [PATCH 020/106] optimize --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 33 ++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 144892b97..19e40bc36 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -30,6 +30,23 @@ namespace ffi = xla::ffi; +// Cache SM count per device to avoid repeated cudaGetDeviceProperties calls +static int g_sm_count[16] = {0}; + +static int get_sm_count() { + int device = 0; + if (cudaGetDevice(&device) != cudaSuccess || device < 0 || device >= 16) { + return 0; + } + if (g_sm_count[device] == 0) { + cudaDeviceProp props; + if (cudaGetDeviceProperties(&props, device) == cudaSuccess) { + g_sm_count[device] = props.multiProcessorCount; + } + } + return g_sm_count[device]; +} + using DtypeA = cutlass::bfloat16_t; using DtypeB = cutlass::bfloat16_t; using DtypeOutput = cutlass::bfloat16_t; @@ -124,13 +141,19 @@ __global__ void prepare_grouped_gemm_data( StrideB* stride_B, StrideOutput* stride_output, ProblemShapeType* problem_sizes) { + // Use shared memory to broadcast first_group_idx to all threads + __shared__ int32_t s_first_group_idx; + if (threadIdx.x == 0) { + s_first_group_idx = group_offset_ptr[0]; + } + __syncthreads(); + int32_t tid = threadIdx.x; if (tid >= group_count) { return; } - int32_t first_group_idx = group_offset_ptr[0]; - int32_t global = first_group_idx + tid; + int32_t global = s_first_group_idx + tid; if (global < 0 || global >= num_groups) { return; } @@ -301,11 +324,7 @@ ffi::Error RaggedDotCudaImpl( return ffi::Error::Internal("cutlass cannot implement grouped gemm."); } - int device = 0; - cudaDeviceProp props; - if (cudaGetDevice(&device) == cudaSuccess && cudaGetDeviceProperties(&props, device) == cudaSuccess) { - args.hw_info.sm_count = props.multiProcessorCount; - } + args.hw_info.sm_count = get_sm_count(); size_t workspace_size = Gemm::get_workspace_size(args); void* workspace = nullptr; From accff8e2e83b43652c46ad6cc46e13298e2e688f Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 16:38:48 -0800 Subject: [PATCH 021/106] try to use clusters --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 
19e40bc36..a2d959695 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -61,7 +61,8 @@ constexpr int AlignmentOutput = 8; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; using TileShape = cute::Shape; -using ClusterShape = cute::Shape; +// Use 2x1x1 cluster on H100 for better L2 cache utilization +using ClusterShape = cute::Shape; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using ProblemShape = cutlass::gemm::GroupProblemShape< From f1fb36cef3c163a26dcbfe4401f80f730ea0bd4e Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 16:50:56 -0800 Subject: [PATCH 022/106] update schedule --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index a2d959695..c1ce952a3 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -63,8 +63,9 @@ using OperatorClass = cutlass::arch::OpClassTensorOp; using TileShape = cute::Shape; // Use 2x1x1 cluster on H100 for better L2 cache utilization using ClusterShape = cute::Shape; -using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; -using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; +// Cooperative schedule for better work distribution across thread blocks +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; using ProblemShape = cutlass::gemm::GroupProblemShape< cute::Shape>; From b1c48f4cdbc56ef9ce1885b6f4bfed6e571b1295 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 17:06:33 -0800 Subject: [PATCH 023/106] try tile size --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index c1ce952a3..80128c0ce 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -60,7 +60,9 @@ constexpr int AlignmentOutput = 8; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; -using TileShape = cute::Shape; +// Tuned for Qwen3-30B-A3B MoE: small M per group, K=768/2048, N=768/2048/lora_rank +// Smaller M tile (64) handles small groups better, K=64 for memory bandwidth +using TileShape = cute::Shape; // Use 2x1x1 cluster on H100 for better L2 cache utilization using ClusterShape = cute::Shape; // Cooperative schedule for better work distribution across thread blocks From 046a033aa3d548df6d64b10e859d4a7677b051fc Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 17:08:28 -0800 Subject: [PATCH 024/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 80128c0ce..0843a2a4b 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -63,11 +63,10 @@ using OperatorClass = cutlass::arch::OpClassTensorOp; // Tuned for Qwen3-30B-A3B MoE: small M per group, K=768/2048, N=768/2048/lora_rank // Smaller M tile (64) handles small groups better, K=64 for memory bandwidth using TileShape = cute::Shape; -// Use 2x1x1 cluster on H100 for better L2 cache utilization 
-using ClusterShape = cute::Shape; -// Cooperative schedule for better work distribution across thread blocks -using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; -using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; +// Use 1x1x1 cluster with pingpong schedule (cooperative requires M tile >= 128) +using ClusterShape = cute::Shape; +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using ProblemShape = cutlass::gemm::GroupProblemShape< cute::Shape>; From 5b14a8a011699cea20f4b4e982dc53b2bf27b474 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 17:26:28 -0800 Subject: [PATCH 025/106] optimize --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 217 ++++++++++++++++-------------- 1 file changed, 119 insertions(+), 98 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 0843a2a4b..8a2b6ccf1 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -30,7 +30,7 @@ namespace ffi = xla::ffi; -// Cache SM count per device to avoid repeated cudaGetDeviceProperties calls +// Cache SM count per device static int g_sm_count[16] = {0}; static int get_sm_count() { @@ -60,75 +60,81 @@ constexpr int AlignmentOutput = 8; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; -// Tuned for Qwen3-30B-A3B MoE: small M per group, K=768/2048, N=768/2048/lora_rank -// Smaller M tile (64) handles small groups better, K=64 for memory bandwidth -using TileShape = cute::Shape; -// Use 1x1x1 cluster with pingpong schedule (cooperative requires M tile >= 128) -using ClusterShape = cute::Shape; -using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; -using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; -using ProblemShape = cutlass::gemm::GroupProblemShape< - cute::Shape>; - -using CollectiveEpilogue = +using ProblemShape = cutlass::gemm::GroupProblemShape>; +using ProblemShapeType = ProblemShape::UnderlyingProblemShape; + +// ============================================================================ +// Kernel variant 1: Up/Gate projection (K=2048, N=768) - 64x128x64 tile +// ============================================================================ +using TileShape_UpGate = cute::Shape; +using ClusterShape_UpGate = cute::Shape; +using KernelSchedule_UpGate = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; +using EpilogueSchedule_UpGate = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + +using CollectiveEpilogue_UpGate = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - TileShape, - ClusterShape, + ArchTag, OperatorClass, TileShape_UpGate, ClusterShape_UpGate, cutlass::epilogue::collective::EpilogueTileAuto, - DtypeAccum, - DtypeAccum, - void, - LayoutOutput*, - AlignmentOutput, - DtypeOutput, - LayoutOutput*, - AlignmentOutput, - EpilogueSchedule, + DtypeAccum, DtypeAccum, void, LayoutOutput*, AlignmentOutput, + DtypeOutput, LayoutOutput*, AlignmentOutput, EpilogueSchedule_UpGate, cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; -using CollectiveMainloop = +using CollectiveMainloop_UpGate = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, - OperatorClass, - DtypeA, - LayoutA*, - AlignmentA, - DtypeB, - LayoutB*, - AlignmentB, - DtypeAccum, - TileShape, - ClusterShape, + ArchTag, 
OperatorClass, DtypeA, LayoutA*, AlignmentA, + DtypeB, LayoutB*, AlignmentB, DtypeAccum, TileShape_UpGate, ClusterShape_UpGate, cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, - KernelSchedule>::CollectiveOp; - -using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - ProblemShape, - CollectiveMainloop, - CollectiveEpilogue>; - -using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -using StrideA = typename Gemm::GemmKernel::InternalStrideA; -using StrideB = typename Gemm::GemmKernel::InternalStrideB; -using StrideOutput = typename Gemm::GemmKernel::InternalStrideD; -using ProblemShapeType = ProblemShape::UnderlyingProblemShape; + static_cast(sizeof(typename CollectiveEpilogue_UpGate::SharedStorage))>, + KernelSchedule_UpGate>::CollectiveOp; + +using Gemm_UpGate = cutlass::gemm::device::GemmUniversalAdapter< + cutlass::gemm::kernel::GemmUniversal>; + +// ============================================================================ +// Kernel variant 2: Down projection (K=768, N=2048) - 64x256x64 tile for large N +// ============================================================================ +using TileShape_Down = cute::Shape; +using ClusterShape_Down = cute::Shape; +using KernelSchedule_Down = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; +using EpilogueSchedule_Down = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + +using CollectiveEpilogue_Down = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape_Down, ClusterShape_Down, + cutlass::epilogue::collective::EpilogueTileAuto, + DtypeAccum, DtypeAccum, void, LayoutOutput*, AlignmentOutput, + DtypeOutput, LayoutOutput*, AlignmentOutput, EpilogueSchedule_Down, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; + +using CollectiveMainloop_Down = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, DtypeA, LayoutA*, AlignmentA, + DtypeB, LayoutB*, AlignmentB, DtypeAccum, TileShape_Down, ClusterShape_Down, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue_Down::SharedStorage))>, + KernelSchedule_Down>::CollectiveOp; + +using Gemm_Down = cutlass::gemm::device::GemmUniversalAdapter< + cutlass::gemm::kernel::GemmUniversal>; + +// ============================================================================ +// Common types +// ============================================================================ +using StrideA = typename Gemm_UpGate::GemmKernel::InternalStrideA; +using StrideB = typename Gemm_UpGate::GemmKernel::InternalStrideB; +using StrideOutput = typename Gemm_UpGate::GemmKernel::InternalStrideD; static ffi::Error CudaError(const char* message) { return ffi::Error::Internal(message); } -using Strides = std::array; - __global__ void prepare_grouped_gemm_data( const DtypeA* A, const DtypeB* B, DtypeOutput* output, const int32_t* group_sizes, - const int32_t* group_offsets_cumsum, // Precomputed cumsum of group_sizes - const int32_t* group_offset_ptr, // Device pointer to first group index + const int32_t* group_offsets_cumsum, + const int32_t* group_offset_ptr, int32_t num_groups, int32_t group_count, int32_t k, @@ -144,7 +150,6 @@ __global__ void prepare_grouped_gemm_data( StrideB* stride_B, StrideOutput* stride_output, ProblemShapeType* problem_sizes) { - // Use shared memory to broadcast first_group_idx to all threads __shared__ int32_t s_first_group_idx; if (threadIdx.x == 0) { s_first_group_idx = 
group_offset_ptr[0]; @@ -161,7 +166,6 @@ __global__ void prepare_grouped_gemm_data( return; } - // Use precomputed cumsum: start = cumsum[global-1] if global > 0, else 0 int32_t start = (global > 0) ? group_offsets_cumsum[global - 1] : 0; int32_t m = group_sizes[global]; if (m < 0 || start + m > total_rows) { @@ -178,6 +182,43 @@ __global__ void prepare_grouped_gemm_data( stride_output[tid] = cutlass::make_cute_packed_stride(StrideOutput{}, {m, ldoutput, 1}); } +template +cutlass::Status run_gemm( + int32_t g_local, + ProblemShapeType* d_problem_sizes, + const DtypeA** d_A_ptrs, StrideA* d_stride_A, + const DtypeB** d_B_ptrs, StrideB* d_stride_B, + StrideOutput* d_stride_output, DtypeOutput** d_out_ptrs, + void* workspace, cudaStream_t stream) { + + GemmType gemm; + typename GemmType::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {g_local, d_problem_sizes, nullptr}, + {d_A_ptrs, d_stride_A, d_B_ptrs, d_stride_B}, + {{}, nullptr, d_stride_output, d_out_ptrs, d_stride_output}}; + + args.epilogue.thread.alpha = 1.0f; + args.epilogue.thread.dAlpha = {cute::_0{}, cute::_0{}, 0}; + args.hw_info.sm_count = get_sm_count(); + + cutlass::Status status = gemm.can_implement(args); + if (status != cutlass::Status::kSuccess) { + return status; + } + return gemm(args, workspace, stream); +} + +template +size_t get_workspace_size(int32_t g_local, ProblemShapeType* d_problem_sizes) { + typename GemmType::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {g_local, d_problem_sizes, nullptr}, + {nullptr, nullptr, nullptr, nullptr}, + {{}, nullptr, nullptr, nullptr, nullptr}}; + return GemmType::get_workspace_size(args); +} + ffi::Error RaggedDotCudaImpl( cudaStream_t stream, ffi::ScratchAllocator scratch, @@ -241,7 +282,6 @@ ffi::Error RaggedDotCudaImpl( DtypeOutput* out_base = reinterpret_cast(out->typed_data()); const int32_t* group_offsets_cumsum_ptr = group_offsets_cumsum.typed_data(); - // Strides for row-major layout int64_t lda = k; int64_t ldb = n; int64_t ldoutput = n; @@ -271,7 +311,7 @@ ffi::Error RaggedDotCudaImpl( bytes += sizeof(StrideOutput) * static_cast(g_local); bytes = align_up(bytes, 16); size_t problem_sizes_offset = bytes; - bytes += sizeof(ProblemShape::UnderlyingProblemShape) * static_cast(g_local); + bytes += sizeof(ProblemShapeType) * static_cast(g_local); auto slab_or = scratch.Allocate(bytes); if (!slab_or.has_value()) { @@ -286,50 +326,22 @@ ffi::Error RaggedDotCudaImpl( StrideA* d_stride_A = reinterpret_cast(base + stride_A_offset); StrideB* d_stride_B = reinterpret_cast(base + stride_B_offset); StrideOutput* d_stride_output = reinterpret_cast(base + stride_output_offset); - ProblemShape::UnderlyingProblemShape* d_problem_sizes = - reinterpret_cast(base + problem_sizes_offset); + ProblemShapeType* d_problem_sizes = reinterpret_cast(base + problem_sizes_offset); prepare_grouped_gemm_data<<<1, g_local, 0, stream>>>( - A_base, - B_base, - out_base, - group_sizes_ptr, - group_offsets_cumsum_ptr, - group_offset_ptr, - g, - g_local, - k, - n, - lda, - ldb, - ldoutput, - m, - d_A_ptrs, - d_B_ptrs, - d_out_ptrs, - d_stride_A, - d_stride_B, - d_stride_output, - d_problem_sizes); - - Gemm gemm; - typename Gemm::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGrouped, - {g_local, d_problem_sizes, nullptr}, - {(const DtypeA**)d_A_ptrs, d_stride_A, (const DtypeB**)d_B_ptrs, d_stride_B}, - {{}, nullptr, d_stride_output, d_out_ptrs, d_stride_output}}; + A_base, B_base, out_base, + group_sizes_ptr, group_offsets_cumsum_ptr, group_offset_ptr, + g, g_local, k, 
n, lda, ldb, ldoutput, m, + d_A_ptrs, d_B_ptrs, d_out_ptrs, + d_stride_A, d_stride_B, d_stride_output, d_problem_sizes); - args.epilogue.thread.alpha = 1.0f; - args.epilogue.thread.dAlpha = {cute::_0{}, cute::_0{}, 0}; + // Select kernel: N >= 1024 uses LargeN tile (down projection), otherwise UpGate tile + bool use_large_n = (n >= 1024); - cutlass::Status status = gemm.can_implement(args); - if (status != cutlass::Status::kSuccess) { - return ffi::Error::Internal("cutlass cannot implement grouped gemm."); - } + size_t workspace_size = use_large_n + ? get_workspace_size(g_local, d_problem_sizes) + : get_workspace_size(g_local, d_problem_sizes); - args.hw_info.sm_count = get_sm_count(); - - size_t workspace_size = Gemm::get_workspace_size(args); void* workspace = nullptr; if (workspace_size > 0) { auto workspace_or = scratch.Allocate(workspace_size); @@ -339,7 +351,16 @@ ffi::Error RaggedDotCudaImpl( workspace = workspace_or.value(); } - status = gemm(args, workspace, stream); + cutlass::Status status = use_large_n + ? run_gemm(g_local, d_problem_sizes, + (const DtypeA**)d_A_ptrs, d_stride_A, + (const DtypeB**)d_B_ptrs, d_stride_B, + d_stride_output, d_out_ptrs, workspace, stream) + : run_gemm(g_local, d_problem_sizes, + (const DtypeA**)d_A_ptrs, d_stride_A, + (const DtypeB**)d_B_ptrs, d_stride_B, + d_stride_output, d_out_ptrs, workspace, stream); + if (status != cutlass::Status::kSuccess) { return ffi::Error::Internal("cutlass grouped gemm failed."); } From 23a74e544f1c4db4dac97634c9110ca34d7e5341 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 17:38:52 -0800 Subject: [PATCH 026/106] optimize --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 8a2b6ccf1..f5b5fba78 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -64,9 +64,10 @@ using ProblemShape = cutlass::gemm::GroupProblemShape; +using TileShape_UpGate = cute::Shape; using ClusterShape_UpGate = cute::Shape; using KernelSchedule_UpGate = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; using EpilogueSchedule_UpGate = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; @@ -91,9 +92,10 @@ using Gemm_UpGate = cutlass::gemm::device::GemmUniversalAdapter< cutlass::gemm::kernel::GemmUniversal>; // ============================================================================ -// Kernel variant 2: Down projection (K=768, N=2048) - 64x256x64 tile for large N +// Kernel variant 2: Down projection (K=768, N=2048) - 64x256x128 tile +// Larger K tile (128) + larger N tile (256) for large N output // ============================================================================ -using TileShape_Down = cute::Shape; +using TileShape_Down = cute::Shape; using ClusterShape_Down = cute::Shape; using KernelSchedule_Down = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; using EpilogueSchedule_Down = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; From 4c86409a9296eab08a6b8475786819653a0b8ee5 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 17:52:30 -0800 Subject: [PATCH 027/106] simplify --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 141 +++++++----------------------- 1 file changed, 34 insertions(+), 107 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index f5b5fba78..9fe65ebd8 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -60,71 
+60,34 @@ constexpr int AlignmentOutput = 8; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; +using TileShape = cute::Shape; +using ClusterShape = cute::Shape; +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using ProblemShape = cutlass::gemm::GroupProblemShape>; -using ProblemShapeType = ProblemShape::UnderlyingProblemShape; - -// ============================================================================ -// Kernel variant 1: Up/Gate projection (K=2048, N=768) - 64x128x128 tile -// Larger K tile (128) for better compute intensity with large K -// ============================================================================ -using TileShape_UpGate = cute::Shape; -using ClusterShape_UpGate = cute::Shape; -using KernelSchedule_UpGate = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; -using EpilogueSchedule_UpGate = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; -using CollectiveEpilogue_UpGate = +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShape_UpGate, ClusterShape_UpGate, + ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, DtypeAccum, DtypeAccum, void, LayoutOutput*, AlignmentOutput, - DtypeOutput, LayoutOutput*, AlignmentOutput, EpilogueSchedule_UpGate, + DtypeOutput, LayoutOutput*, AlignmentOutput, EpilogueSchedule, cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; -using CollectiveMainloop_UpGate = +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, DtypeA, LayoutA*, AlignmentA, - DtypeB, LayoutB*, AlignmentB, DtypeAccum, TileShape_UpGate, ClusterShape_UpGate, + DtypeB, LayoutB*, AlignmentB, DtypeAccum, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogue_UpGate::SharedStorage))>, - KernelSchedule_UpGate>::CollectiveOp; - -using Gemm_UpGate = cutlass::gemm::device::GemmUniversalAdapter< - cutlass::gemm::kernel::GemmUniversal>; - -// ============================================================================ -// Kernel variant 2: Down projection (K=768, N=2048) - 64x256x128 tile -// Larger K tile (128) + larger N tile (256) for large N output -// ============================================================================ -using TileShape_Down = cute::Shape; -using ClusterShape_Down = cute::Shape; -using KernelSchedule_Down = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; -using EpilogueSchedule_Down = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; - -using CollectiveEpilogue_Down = - typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShape_Down, ClusterShape_Down, - cutlass::epilogue::collective::EpilogueTileAuto, - DtypeAccum, DtypeAccum, void, LayoutOutput*, AlignmentOutput, - DtypeOutput, LayoutOutput*, AlignmentOutput, EpilogueSchedule_Down, - cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; - -using CollectiveMainloop_Down = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, DtypeA, LayoutA*, AlignmentA, - DtypeB, LayoutB*, AlignmentB, DtypeAccum, TileShape_Down, ClusterShape_Down, - cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogue_Down::SharedStorage))>, - KernelSchedule_Down>::CollectiveOp; - -using 
Gemm_Down = cutlass::gemm::device::GemmUniversalAdapter< - cutlass::gemm::kernel::GemmUniversal>; - -// ============================================================================ -// Common types -// ============================================================================ -using StrideA = typename Gemm_UpGate::GemmKernel::InternalStrideA; -using StrideB = typename Gemm_UpGate::GemmKernel::InternalStrideB; -using StrideOutput = typename Gemm_UpGate::GemmKernel::InternalStrideD; + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal; +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +using StrideA = typename Gemm::GemmKernel::InternalStrideA; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideOutput = typename Gemm::GemmKernel::InternalStrideD; +using ProblemShapeType = ProblemShape::UnderlyingProblemShape; static ffi::Error CudaError(const char* message) { return ffi::Error::Internal(message); @@ -184,43 +147,6 @@ __global__ void prepare_grouped_gemm_data( stride_output[tid] = cutlass::make_cute_packed_stride(StrideOutput{}, {m, ldoutput, 1}); } -template -cutlass::Status run_gemm( - int32_t g_local, - ProblemShapeType* d_problem_sizes, - const DtypeA** d_A_ptrs, StrideA* d_stride_A, - const DtypeB** d_B_ptrs, StrideB* d_stride_B, - StrideOutput* d_stride_output, DtypeOutput** d_out_ptrs, - void* workspace, cudaStream_t stream) { - - GemmType gemm; - typename GemmType::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGrouped, - {g_local, d_problem_sizes, nullptr}, - {d_A_ptrs, d_stride_A, d_B_ptrs, d_stride_B}, - {{}, nullptr, d_stride_output, d_out_ptrs, d_stride_output}}; - - args.epilogue.thread.alpha = 1.0f; - args.epilogue.thread.dAlpha = {cute::_0{}, cute::_0{}, 0}; - args.hw_info.sm_count = get_sm_count(); - - cutlass::Status status = gemm.can_implement(args); - if (status != cutlass::Status::kSuccess) { - return status; - } - return gemm(args, workspace, stream); -} - -template -size_t get_workspace_size(int32_t g_local, ProblemShapeType* d_problem_sizes) { - typename GemmType::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGrouped, - {g_local, d_problem_sizes, nullptr}, - {nullptr, nullptr, nullptr, nullptr}, - {{}, nullptr, nullptr, nullptr, nullptr}}; - return GemmType::get_workspace_size(args); -} - ffi::Error RaggedDotCudaImpl( cudaStream_t stream, ffi::ScratchAllocator scratch, @@ -337,13 +263,23 @@ ffi::Error RaggedDotCudaImpl( d_A_ptrs, d_B_ptrs, d_out_ptrs, d_stride_A, d_stride_B, d_stride_output, d_problem_sizes); - // Select kernel: N >= 1024 uses LargeN tile (down projection), otherwise UpGate tile - bool use_large_n = (n >= 1024); + Gemm gemm; + typename Gemm::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {g_local, d_problem_sizes, nullptr}, + {(const DtypeA**)d_A_ptrs, d_stride_A, (const DtypeB**)d_B_ptrs, d_stride_B}, + {{}, nullptr, d_stride_output, d_out_ptrs, d_stride_output}}; + + args.epilogue.thread.alpha = 1.0f; + args.epilogue.thread.dAlpha = {cute::_0{}, cute::_0{}, 0}; + args.hw_info.sm_count = get_sm_count(); - size_t workspace_size = use_large_n - ? 
get_workspace_size(g_local, d_problem_sizes) - : get_workspace_size(g_local, d_problem_sizes); + cutlass::Status status = gemm.can_implement(args); + if (status != cutlass::Status::kSuccess) { + return ffi::Error::Internal("cutlass cannot implement grouped gemm."); + } + size_t workspace_size = Gemm::get_workspace_size(args); void* workspace = nullptr; if (workspace_size > 0) { auto workspace_or = scratch.Allocate(workspace_size); @@ -353,16 +289,7 @@ ffi::Error RaggedDotCudaImpl( workspace = workspace_or.value(); } - cutlass::Status status = use_large_n - ? run_gemm(g_local, d_problem_sizes, - (const DtypeA**)d_A_ptrs, d_stride_A, - (const DtypeB**)d_B_ptrs, d_stride_B, - d_stride_output, d_out_ptrs, workspace, stream) - : run_gemm(g_local, d_problem_sizes, - (const DtypeA**)d_A_ptrs, d_stride_A, - (const DtypeB**)d_B_ptrs, d_stride_B, - d_stride_output, d_out_ptrs, workspace, stream); - + status = gemm(args, workspace, stream); if (status != cutlass::Status::kSuccess) { return ffi::Error::Internal("cutlass grouped gemm failed."); } From 2dcce2069be8eea13771a94dd76942c8a8902efc Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 18:04:14 -0800 Subject: [PATCH 028/106] simplify --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 103 +++++++----------------------- 1 file changed, 22 insertions(+), 81 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 9fe65ebd8..92df85878 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -2,7 +2,6 @@ #include #include -#include #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" @@ -11,13 +10,9 @@ #include #include #include -#include -#include #include #include -#include -#include #include #include #include @@ -100,14 +95,8 @@ __global__ void prepare_grouped_gemm_data( const int32_t* group_sizes, const int32_t* group_offsets_cumsum, const int32_t* group_offset_ptr, - int32_t num_groups, - int32_t group_count, int32_t k, int32_t n, - int64_t lda, - int64_t ldb, - int64_t ldoutput, - int32_t total_rows, DtypeA** A_ptrs, DtypeB** B_ptrs, DtypeOutput** output_ptrs, @@ -115,36 +104,20 @@ __global__ void prepare_grouped_gemm_data( StrideB* stride_B, StrideOutput* stride_output, ProblemShapeType* problem_sizes) { - __shared__ int32_t s_first_group_idx; - if (threadIdx.x == 0) { - s_first_group_idx = group_offset_ptr[0]; - } - __syncthreads(); - int32_t tid = threadIdx.x; - if (tid >= group_count) { - return; - } - - int32_t global = s_first_group_idx + tid; - if (global < 0 || global >= num_groups) { - return; - } + int32_t global = group_offset_ptr[0] + tid; int32_t start = (global > 0) ? 
group_offsets_cumsum[global - 1] : 0; int32_t m = group_sizes[global]; - if (m < 0 || start + m > total_rows) { - return; - } - A_ptrs[tid] = const_cast(A) + static_cast(start) * lda; - B_ptrs[tid] = const_cast(B) + static_cast(tid) * ldb * k; - output_ptrs[tid] = output + static_cast(start) * ldoutput; + A_ptrs[tid] = const_cast(A) + static_cast(start) * k; + B_ptrs[tid] = const_cast(B) + static_cast(tid) * n * k; + output_ptrs[tid] = output + static_cast(start) * n; problem_sizes[tid] = ProblemShapeType(m, n, k); - stride_A[tid] = cutlass::make_cute_packed_stride(StrideA{}, {lda, lda, 1}); - stride_B[tid] = cutlass::make_cute_packed_stride(StrideB{}, {ldb, ldb, 1}); - stride_output[tid] = cutlass::make_cute_packed_stride(StrideOutput{}, {m, ldoutput, 1}); + stride_A[tid] = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1}); + stride_B[tid] = cutlass::make_cute_packed_stride(StrideB{}, {k, n, 1}); + stride_output[tid] = cutlass::make_cute_packed_stride(StrideOutput{}, {m, n, 1}); } ffi::Error RaggedDotCudaImpl( @@ -171,20 +144,14 @@ ffi::Error RaggedDotCudaImpl( int64_t g_local64 = rhs_dims[0]; int64_t rhs_k64 = rhs_dims[1]; int64_t n64 = rhs_dims[2]; - int64_t g64 = group_sizes_dims[0]; if (k64 != rhs_k64) { return ffi::Error::InvalidArgument("lhs/rhs K dimension mismatch."); } - if (m64 < 0 || k64 < 0 || n64 < 0 || g64 < 0 || g_local64 < 0) { - return ffi::Error::InvalidArgument("Invalid dimensions."); - } - int32_t m = static_cast(m64); int32_t k = static_cast(k64); int32_t n = static_cast(n64); - int32_t g = static_cast(g64); int32_t g_local = static_cast(g_local64); const int32_t* group_sizes_ptr = group_sizes.typed_data(); @@ -210,56 +177,30 @@ ffi::Error RaggedDotCudaImpl( DtypeOutput* out_base = reinterpret_cast(out->typed_data()); const int32_t* group_offsets_cumsum_ptr = group_offsets_cumsum.typed_data(); - int64_t lda = k; - int64_t ldb = n; - int64_t ldoutput = n; - - auto align_up = [](size_t value, size_t alignment) -> size_t { - return (value + alignment - 1) & ~(alignment - 1); - }; - - size_t bytes = 0; - bytes = align_up(bytes, 16); - size_t A_ptrs_offset = bytes; - bytes += sizeof(DtypeA*) * static_cast(g_local); - bytes = align_up(bytes, 16); - size_t B_ptrs_offset = bytes; - bytes += sizeof(DtypeB*) * static_cast(g_local); - bytes = align_up(bytes, 16); - size_t out_ptrs_offset = bytes; - bytes += sizeof(DtypeOutput*) * static_cast(g_local); - bytes = align_up(bytes, 16); - size_t stride_A_offset = bytes; - bytes += sizeof(StrideA) * static_cast(g_local); - bytes = align_up(bytes, 16); - size_t stride_B_offset = bytes; - bytes += sizeof(StrideB) * static_cast(g_local); - bytes = align_up(bytes, 16); - size_t stride_output_offset = bytes; - bytes += sizeof(StrideOutput) * static_cast(g_local); - bytes = align_up(bytes, 16); - size_t problem_sizes_offset = bytes; - bytes += sizeof(ProblemShapeType) * static_cast(g_local); + size_t gl = static_cast(g_local); + size_t bytes = 7 * 16 + // alignment padding + sizeof(DtypeA*) * gl + sizeof(DtypeB*) * gl + sizeof(DtypeOutput*) * gl + + sizeof(StrideA) * gl + sizeof(StrideB) * gl + sizeof(StrideOutput) * gl + + sizeof(ProblemShapeType) * gl; auto slab_or = scratch.Allocate(bytes); if (!slab_or.has_value()) { return ffi::Error::Internal("Failed to allocate grouped GEMM slab from scratch."); } - void* slab = slab_or.value(); - char* base = reinterpret_cast(slab); - DtypeA** d_A_ptrs = reinterpret_cast(base + A_ptrs_offset); - DtypeB** d_B_ptrs = reinterpret_cast(base + B_ptrs_offset); - DtypeOutput** d_out_ptrs = 
reinterpret_cast(base + out_ptrs_offset); - StrideA* d_stride_A = reinterpret_cast(base + stride_A_offset); - StrideB* d_stride_B = reinterpret_cast(base + stride_B_offset); - StrideOutput* d_stride_output = reinterpret_cast(base + stride_output_offset); - ProblemShapeType* d_problem_sizes = reinterpret_cast(base + problem_sizes_offset); + auto align16 = [](char*& p) { p = reinterpret_cast((reinterpret_cast(p) + 15) & ~15); }; + char* p = reinterpret_cast(slab_or.value()); + align16(p); DtypeA** d_A_ptrs = reinterpret_cast(p); p += sizeof(DtypeA*) * gl; + align16(p); DtypeB** d_B_ptrs = reinterpret_cast(p); p += sizeof(DtypeB*) * gl; + align16(p); DtypeOutput** d_out_ptrs = reinterpret_cast(p); p += sizeof(DtypeOutput*) * gl; + align16(p); StrideA* d_stride_A = reinterpret_cast(p); p += sizeof(StrideA) * gl; + align16(p); StrideB* d_stride_B = reinterpret_cast(p); p += sizeof(StrideB) * gl; + align16(p); StrideOutput* d_stride_output = reinterpret_cast(p); p += sizeof(StrideOutput) * gl; + align16(p); ProblemShapeType* d_problem_sizes = reinterpret_cast(p); prepare_grouped_gemm_data<<<1, g_local, 0, stream>>>( A_base, B_base, out_base, - group_sizes_ptr, group_offsets_cumsum_ptr, group_offset_ptr, - g, g_local, k, n, lda, ldb, ldoutput, m, + group_sizes_ptr, group_offsets_cumsum_ptr, group_offset_ptr, k, n, d_A_ptrs, d_B_ptrs, d_out_ptrs, d_stride_A, d_stride_B, d_stride_output, d_problem_sizes); From e731efe21a602e9dbf41340908241ff712d57eed Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 18:13:13 -0800 Subject: [PATCH 029/106] add lto --- skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh index 4feb2d730..6806df003 100644 --- a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh +++ b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh @@ -35,6 +35,7 @@ mkdir -p "${OUT_DIR}" -std=c++17 \ -arch="${NVCC_ARCH}" \ --expt-relaxed-constexpr \ + -dlto \ -shared \ -Xcompiler -fPIC \ -I"${JAX_INCLUDE_DIR}" \ From 8ce60ceefe721eda3d8dc2a6fd17eb4055b538f5 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 18:19:02 -0800 Subject: [PATCH 030/106] fix --- skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh index 6806df003..01d2bfcf1 100644 --- a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh +++ b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh @@ -21,7 +21,7 @@ PY )" NVCC_BIN="${NVCC_BIN:-nvcc}" -NVCC_ARCH="${NVCC_ARCH:-sm_90a}" +NVCC_ARCH="${NVCC_ARCH:-90a}" if ! command -v "${NVCC_BIN}" >/dev/null 2>&1; then echo "nvcc not found. Set NVCC_BIN or ensure CUDA is on PATH." 
>&2 exit 1 @@ -33,7 +33,7 @@ mkdir -p "${OUT_DIR}" "${NVCC_BIN}" \ -O3 \ -std=c++17 \ - -arch="${NVCC_ARCH}" \ + -gencode=arch=compute_${NVCC_ARCH},code=sm_${NVCC_ARCH} \ --expt-relaxed-constexpr \ -dlto \ -shared \ From 2a005e297ec71f734c8159242299221deb8658df Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 18:19:43 -0800 Subject: [PATCH 031/106] update --- skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh index 01d2bfcf1..116fd102e 100644 --- a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh +++ b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh @@ -33,9 +33,9 @@ mkdir -p "${OUT_DIR}" "${NVCC_BIN}" \ -O3 \ -std=c++17 \ + -gencode=arch=compute_${NVCC_ARCH},code=lto_${NVCC_ARCH} \ -gencode=arch=compute_${NVCC_ARCH},code=sm_${NVCC_ARCH} \ --expt-relaxed-constexpr \ - -dlto \ -shared \ -Xcompiler -fPIC \ -I"${JAX_INCLUDE_DIR}" \ From 083c15021f05480ab150fa21dee30b351c508479 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 18:22:43 -0800 Subject: [PATCH 032/106] update --- skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh index 116fd102e..0a502f963 100644 --- a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh +++ b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh @@ -33,8 +33,7 @@ mkdir -p "${OUT_DIR}" "${NVCC_BIN}" \ -O3 \ -std=c++17 \ - -gencode=arch=compute_${NVCC_ARCH},code=lto_${NVCC_ARCH} \ - -gencode=arch=compute_${NVCC_ARCH},code=sm_${NVCC_ARCH} \ + -arch=sm_${NVCC_ARCH} \ --expt-relaxed-constexpr \ -shared \ -Xcompiler -fPIC \ From 7370e3940025da3cac45b663c2fc0435a6b51d03 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 18:25:57 -0800 Subject: [PATCH 033/106] add flags --- skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh index 0a502f963..5079c6fd8 100644 --- a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh +++ b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh @@ -35,6 +35,9 @@ mkdir -p "${OUT_DIR}" -std=c++17 \ -arch=sm_${NVCC_ARCH} \ --expt-relaxed-constexpr \ + -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1 \ + -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED \ + -DCUTLASS_ENABLE_GDC_FOR_SM90=1 \ -shared \ -Xcompiler -fPIC \ -I"${JAX_INCLUDE_DIR}" \ From f2f7e6c8aeda778d9e3db683f4c55df9d17572d1 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 18:48:19 -0800 Subject: [PATCH 034/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 92df85878..d776ad9c5 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -115,8 +115,8 @@ __global__ void prepare_grouped_gemm_data( output_ptrs[tid] = output + static_cast(start) * n; problem_sizes[tid] = ProblemShapeType(m, n, k); - stride_A[tid] = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1}); - stride_B[tid] = cutlass::make_cute_packed_stride(StrideB{}, {k, n, 1}); + stride_A[tid] = cutlass::make_cute_packed_stride(StrideA{}, {k, k, 1}); + stride_B[tid] = cutlass::make_cute_packed_stride(StrideB{}, {n, n, 1}); stride_output[tid] = cutlass::make_cute_packed_stride(StrideOutput{}, {m, n, 1}); } From bbb6004122dd6194735a79a51a75aa1f0c5f9069 Mon Sep 17 00:00:00 
2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 19:03:27 -0800 Subject: [PATCH 035/106] replace backward --- skyrl-tx/tx/ffi/ragged_dot_ffi.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.py b/skyrl-tx/tx/ffi/ragged_dot_ffi.py index 513ef7b64..44f1df4c4 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.py +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.py @@ -136,11 +136,19 @@ def _ragged_dot_fwd( def _ragged_dot_bwd(res, g): lhs, rhs, group_sizes, group_offset = res - def _ref_lhs_rhs(lhs_, rhs_): - return _ragged_dot_ref(lhs_, rhs_, group_sizes, group_offset) + # d_lhs: g @ rhs.T with ragged grouping - same structure as forward + # g: [M, N], rhs: [G, K, N] -> rhs.T: [G, N, K], d_lhs: [M, K] + rhs_t = jnp.swapaxes(rhs, 1, 2) # [G, N, K] + d_lhs = _ragged_dot_ffi_call(g, rhs_t, group_sizes, group_offset) + + # d_rhs: lhs.T @ g accumulated per group -> [G, K, N] + # This has a different structure, use JAX's autodiff + def _ref_rhs_only(rhs_): + return _ragged_dot_ref(lhs, rhs_, group_sizes, group_offset) + + (_, pullback) = jax.vjp(_ref_rhs_only, rhs) + (d_rhs,) = pullback(g) - (_, pullback) = jax.vjp(_ref_lhs_rhs, lhs, rhs) - d_lhs, d_rhs = pullback(g) return d_lhs, d_rhs, None, None From 20bb38225931dba56344c006102c06fdfd9caf83 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 19:15:52 -0800 Subject: [PATCH 036/106] add proper backward --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 198 ++++++++++++++++++++++++++++++ skyrl-tx/tx/ffi/ragged_dot_ffi.py | 41 ++++++- 2 files changed, 233 insertions(+), 6 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index d776ad9c5..5451a819b 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -84,6 +84,32 @@ using StrideB = typename Gemm::GemmKernel::InternalStrideB; using StrideOutput = typename Gemm::GemmKernel::InternalStrideD; using ProblemShapeType = ProblemShape::UnderlyingProblemShape; +// Backward kernel for d_rhs: computes lhs.T @ grad per group +// Uses ColumnMajor for A to interpret row-major lhs as transposed +using LayoutA_Bwd = cutlass::layout::ColumnMajor; + +using CollectiveEpilogue_Bwd = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + DtypeAccum, DtypeAccum, void, LayoutOutput*, AlignmentOutput, + DtypeOutput, LayoutOutput*, AlignmentOutput, EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; + +using CollectiveMainloop_Bwd = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, DtypeA, LayoutA_Bwd*, AlignmentA, + DtypeB, LayoutB*, AlignmentB, DtypeAccum, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue_Bwd::SharedStorage))>, + KernelSchedule>::CollectiveOp; + +using GemmKernel_Bwd = cutlass::gemm::kernel::GemmUniversal; +using Gemm_Bwd = cutlass::gemm::device::GemmUniversalAdapter; +using StrideA_Bwd = typename Gemm_Bwd::GemmKernel::InternalStrideA; +using StrideB_Bwd = typename Gemm_Bwd::GemmKernel::InternalStrideB; +using StrideOutput_Bwd = typename Gemm_Bwd::GemmKernel::InternalStrideD; + static ffi::Error CudaError(const char* message) { return ffi::Error::Internal(message); } @@ -120,6 +146,46 @@ __global__ void prepare_grouped_gemm_data( stride_output[tid] = cutlass::make_cute_packed_stride(StrideOutput{}, {m, n, 1}); } +// Backward 
kernel for d_rhs: lhs.T @ grad -> d_rhs[G, K, N] +// Memory pattern: A=lhs (ragged), B=grad (ragged), output=d_rhs (per-group) +__global__ void prepare_grouped_gemm_bwd_data( + const DtypeA* lhs, // [M_total, K] row-major (ragged) + const DtypeB* grad, // [M_total, N] row-major (ragged) + DtypeOutput* d_rhs, // [G, K, N] row-major (per-group) + const int32_t* group_sizes, + const int32_t* group_offsets_cumsum, + const int32_t* group_offset_ptr, + int32_t k, + int32_t n, + DtypeA** A_ptrs, + DtypeB** B_ptrs, + DtypeOutput** output_ptrs, + StrideA_Bwd* stride_A, + StrideB_Bwd* stride_B, + StrideOutput_Bwd* stride_output, + ProblemShapeType* problem_sizes) { + int32_t tid = threadIdx.x; + int32_t global = group_offset_ptr[0] + tid; + + int32_t start = (global > 0) ? group_offsets_cumsum[global - 1] : 0; + int32_t m = group_sizes[global]; + + // A = lhs slice, viewed as ColumnMajor [K, M_g] (transposed) + A_ptrs[tid] = const_cast(lhs) + static_cast(start) * k; + // B = grad slice [M_g, N] + B_ptrs[tid] = const_cast(grad) + static_cast(start) * n; + // Output = d_rhs[tid] with shape [K, N] + output_ptrs[tid] = d_rhs + static_cast(tid) * k * n; + + // GEMM: [K, M_g] @ [M_g, N] = [K, N] + problem_sizes[tid] = ProblemShapeType(k, n, m); + + // ColumnMajor stride for A (lhs viewed as transposed) + stride_A[tid] = cutlass::make_cute_packed_stride(StrideA_Bwd{}, {m, m, 1}); + stride_B[tid] = cutlass::make_cute_packed_stride(StrideB_Bwd{}, {n, n, 1}); + stride_output[tid] = cutlass::make_cute_packed_stride(StrideOutput_Bwd{}, {k, n, 1}); +} + ffi::Error RaggedDotCudaImpl( cudaStream_t stream, ffi::ScratchAllocator scratch, @@ -250,3 +316,135 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg>() // group_offset .Arg>() // group_offsets_cumsum .Ret>()); // out + +// Backward pass for d_rhs: computes lhs.T @ grad per group -> d_rhs[G, K, N] +ffi::Error RaggedDotBwdCudaImpl( + cudaStream_t stream, + ffi::ScratchAllocator scratch, + ffi::Buffer lhs, + ffi::Buffer grad, + ffi::Buffer group_sizes, + ffi::Buffer group_offset, + ffi::Buffer group_offsets_cumsum, + ffi::ResultBuffer d_rhs) { + auto lhs_dims = lhs.dimensions(); + auto grad_dims = grad.dimensions(); + auto d_rhs_dims = d_rhs->dimensions(); + + if (lhs_dims.size() != 2 || grad_dims.size() != 2 || d_rhs_dims.size() != 3) { + return ffi::Error::InvalidArgument("Unexpected ragged_dot_bwd dimensions."); + } + + int64_t m64 = lhs_dims[0]; + int64_t k64 = lhs_dims[1]; + int64_t grad_m64 = grad_dims[0]; + int64_t n64 = grad_dims[1]; + int64_t g_local64 = d_rhs_dims[0]; + + if (m64 != grad_m64) { + return ffi::Error::InvalidArgument("lhs/grad M dimension mismatch."); + } + if (d_rhs_dims[1] != k64 || d_rhs_dims[2] != n64) { + return ffi::Error::InvalidArgument("d_rhs shape must be [G, K, N]."); + } + + int32_t m = static_cast(m64); + int32_t k = static_cast(k64); + int32_t n = static_cast(n64); + int32_t g_local = static_cast(g_local64); + + cudaError_t err; + err = cudaMemsetAsync( + d_rhs->typed_data(), 0, static_cast(g_local) * k * n * sizeof(DtypeOutput), stream); + if (err != cudaSuccess) { + return CudaError("Failed to zero d_rhs output."); + } + + if (g_local == 0 || m == 0 || n == 0 || k == 0) { + return ffi::Error::Success(); + } + + if (g_local > 1024) { + return ffi::Error::InvalidArgument("group_count must be <= 1024."); + } + + const DtypeA* lhs_base = reinterpret_cast(lhs.typed_data()); + const DtypeB* grad_base = reinterpret_cast(grad.typed_data()); + DtypeOutput* d_rhs_base = reinterpret_cast(d_rhs->typed_data()); + const int32_t* group_sizes_ptr = 
group_sizes.typed_data(); + const int32_t* group_offset_ptr = group_offset.typed_data(); + const int32_t* group_offsets_cumsum_ptr = group_offsets_cumsum.typed_data(); + + size_t gl = static_cast(g_local); + size_t bytes = 7 * 16 + + sizeof(DtypeA*) * gl + sizeof(DtypeB*) * gl + sizeof(DtypeOutput*) * gl + + sizeof(StrideA_Bwd) * gl + sizeof(StrideB_Bwd) * gl + sizeof(StrideOutput_Bwd) * gl + + sizeof(ProblemShapeType) * gl; + + auto slab_or = scratch.Allocate(bytes); + if (!slab_or.has_value()) { + return ffi::Error::Internal("Failed to allocate grouped GEMM bwd slab from scratch."); + } + + auto align16 = [](char*& p) { p = reinterpret_cast((reinterpret_cast(p) + 15) & ~15); }; + char* p = reinterpret_cast(slab_or.value()); + align16(p); DtypeA** d_A_ptrs = reinterpret_cast(p); p += sizeof(DtypeA*) * gl; + align16(p); DtypeB** d_B_ptrs = reinterpret_cast(p); p += sizeof(DtypeB*) * gl; + align16(p); DtypeOutput** d_out_ptrs = reinterpret_cast(p); p += sizeof(DtypeOutput*) * gl; + align16(p); StrideA_Bwd* d_stride_A = reinterpret_cast(p); p += sizeof(StrideA_Bwd) * gl; + align16(p); StrideB_Bwd* d_stride_B = reinterpret_cast(p); p += sizeof(StrideB_Bwd) * gl; + align16(p); StrideOutput_Bwd* d_stride_output = reinterpret_cast(p); p += sizeof(StrideOutput_Bwd) * gl; + align16(p); ProblemShapeType* d_problem_sizes = reinterpret_cast(p); + + prepare_grouped_gemm_bwd_data<<<1, g_local, 0, stream>>>( + lhs_base, grad_base, d_rhs_base, + group_sizes_ptr, group_offsets_cumsum_ptr, group_offset_ptr, k, n, + d_A_ptrs, d_B_ptrs, d_out_ptrs, + d_stride_A, d_stride_B, d_stride_output, d_problem_sizes); + + Gemm_Bwd gemm; + typename Gemm_Bwd::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {g_local, d_problem_sizes, nullptr}, + {(const DtypeA**)d_A_ptrs, d_stride_A, (const DtypeB**)d_B_ptrs, d_stride_B}, + {{}, nullptr, d_stride_output, d_out_ptrs, d_stride_output}}; + + args.epilogue.thread.alpha = 1.0f; + args.epilogue.thread.dAlpha = {cute::_0{}, cute::_0{}, 0}; + args.hw_info.sm_count = get_sm_count(); + + cutlass::Status status = gemm.can_implement(args); + if (status != cutlass::Status::kSuccess) { + return ffi::Error::Internal("cutlass cannot implement grouped gemm bwd."); + } + + size_t workspace_size = Gemm_Bwd::get_workspace_size(args); + void* workspace = nullptr; + if (workspace_size > 0) { + auto workspace_or = scratch.Allocate(workspace_size); + if (!workspace_or.has_value()) { + return ffi::Error::Internal("Failed to allocate CUTLASS bwd workspace from scratch."); + } + workspace = workspace_or.value(); + } + + status = gemm(args, workspace, stream); + if (status != cutlass::Status::kSuccess) { + return ffi::Error::Internal("cutlass grouped gemm bwd failed."); + } + + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + RaggedDotBwdCuda, + RaggedDotBwdCudaImpl, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Arg>() // lhs + .Arg>() // grad + .Arg>() // group_sizes + .Arg>() // group_offset + .Arg>() // group_offsets_cumsum + .Ret>()); // d_rhs diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.py b/skyrl-tx/tx/ffi/ragged_dot_ffi.py index 44f1df4c4..644453e67 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.py +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.py @@ -51,6 +51,11 @@ def _ensure_registered() -> bool: jax_ffi.pycapsule(lib.RaggedDotCuda), platform="CUDA", ) + jax_ffi.register_ffi_target( + "ragged_dot_bwd_cuda", + jax_ffi.pycapsule(lib.RaggedDotBwdCuda), + platform="CUDA", + ) _REGISTERED = True return True except Exception as exc: # pragma: no cover - load/registration 
failures @@ -113,6 +118,34 @@ def _ragged_dot_ffi_call( return call(lhs, rhs, group_sizes, group_offset, group_offsets) +def _ragged_dot_bwd_ffi_call( + lhs: jax.Array, + grad: jax.Array, + group_sizes: jax.Array, + group_offset: jax.Array, + g_local: int, +) -> jax.Array: + """Backward pass for d_rhs: computes lhs.T @ grad per group -> [G, K, N].""" + if not _ensure_registered(): + raise RuntimeError("ragged_dot_ffi is not available. Build and load the shared library first.") + + if lhs.dtype != jnp.bfloat16 or grad.dtype != jnp.bfloat16: + raise NotImplementedError("ragged_dot_bwd_ffi supports bfloat16 only.") + if group_sizes.dtype != jnp.int32 or group_offset.dtype != jnp.int32: + raise NotImplementedError("ragged_dot_bwd_ffi expects int32 group_sizes and group_offset.") + if group_offset.shape != (1,): + raise ValueError("group_offset must have shape (1,).") + + k = lhs.shape[1] + n = grad.shape[1] + + group_offsets = jnp.cumsum(group_sizes, dtype=jnp.int32) + + out = jax.ShapeDtypeStruct((g_local, k, n), lhs.dtype) + call = jax_ffi.ffi_call("ragged_dot_bwd_cuda", out, vmap_method=None) + return call(lhs, grad, group_sizes, group_offset, group_offsets) + + @jax.custom_vjp def ragged_dot( lhs: jax.Array, @@ -142,12 +175,8 @@ def _ragged_dot_bwd(res, g): d_lhs = _ragged_dot_ffi_call(g, rhs_t, group_sizes, group_offset) # d_rhs: lhs.T @ g accumulated per group -> [G, K, N] - # This has a different structure, use JAX's autodiff - def _ref_rhs_only(rhs_): - return _ragged_dot_ref(lhs, rhs_, group_sizes, group_offset) - - (_, pullback) = jax.vjp(_ref_rhs_only, rhs) - (d_rhs,) = pullback(g) + g_local = rhs.shape[0] + d_rhs = _ragged_dot_bwd_ffi_call(lhs, g, group_sizes, group_offset, g_local) return d_lhs, d_rhs, None, None From 0b51d0b73b3f22ce349a96bb5637955a33097c8c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 19:29:30 -0800 Subject: [PATCH 037/106] optimize --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 5451a819b..e93295861 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -49,14 +49,14 @@ using DtypeAccum = float; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::RowMajor; using LayoutOutput = cutlass::layout::RowMajor; -constexpr int AlignmentA = 8; -constexpr int AlignmentB = 8; -constexpr int AlignmentOutput = 8; +constexpr int AlignmentA = 16; +constexpr int AlignmentB = 16; +constexpr int AlignmentOutput = 16; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; using TileShape = cute::Shape; -using ClusterShape = cute::Shape; +using ClusterShape = cute::Shape; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using ProblemShape = cutlass::gemm::GroupProblemShape>; From 80a276fffc0b47b47dff4c53fb19c794c30b7bdb Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 20:20:28 -0800 Subject: [PATCH 038/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index e93295861..564c8c145 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -57,8 +57,8 @@ using ArchTag = cutlass::arch::Sm90; using OperatorClass = 
cutlass::arch::OpClassTensorOp; using TileShape = cute::Shape; using ClusterShape = cute::Shape; -using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; -using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; using ProblemShape = cutlass::gemm::GroupProblemShape>; using CollectiveEpilogue = From fdfff3960b565b5057445c39734af581531dc415 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 20:23:11 -0800 Subject: [PATCH 039/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 564c8c145..d2422a0bf 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -55,7 +55,7 @@ constexpr int AlignmentOutput = 16; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; -using TileShape = cute::Shape; +using TileShape = cute::Shape; using ClusterShape = cute::Shape; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; From 3923e99433de81cec10393664ca17d1fdba31fa9 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 20:42:52 -0800 Subject: [PATCH 040/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index d2422a0bf..5e8d40f0d 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -55,7 +55,7 @@ constexpr int AlignmentOutput = 16; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; -using TileShape = cute::Shape; +using TileShape = cute::Shape; using ClusterShape = cute::Shape; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; From 0acef25c07c0c374cda456acd179d51add27599d Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 21:11:07 -0800 Subject: [PATCH 041/106] optimize --- skyrl-tx/tx/ffi/ragged_dot_ffi.py | 39 ++++++++++--------------------- 1 file changed, 12 insertions(+), 27 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.py b/skyrl-tx/tx/ffi/ragged_dot_ffi.py index 644453e67..f40a9c561 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.py +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.py @@ -99,23 +99,14 @@ def _ragged_dot_ffi_call( rhs: jax.Array, group_sizes: jax.Array, group_offset: jax.Array, + group_offsets_cumsum: jax.Array, ) -> jax.Array: if not _ensure_registered(): raise RuntimeError("ragged_dot_ffi is not available. 
Build and load the shared library first.") - if lhs.dtype != jnp.bfloat16 or rhs.dtype != jnp.bfloat16: - raise NotImplementedError("ragged_dot_ffi supports bfloat16 only.") - if group_sizes.dtype != jnp.int32 or group_offset.dtype != jnp.int32: - raise NotImplementedError("ragged_dot_ffi expects int32 group_sizes and group_offset.") - if group_offset.shape != (1,): - raise ValueError("group_offset must have shape (1,).") - - # Precompute cumulative sum to avoid O(n²) loop in CUDA kernel - group_offsets = jnp.cumsum(group_sizes, dtype=jnp.int32) - out = jax.ShapeDtypeStruct((lhs.shape[0], rhs.shape[2]), lhs.dtype) call = jax_ffi.ffi_call("ragged_dot_cuda", out, vmap_method=None) - return call(lhs, rhs, group_sizes, group_offset, group_offsets) + return call(lhs, rhs, group_sizes, group_offset, group_offsets_cumsum) def _ragged_dot_bwd_ffi_call( @@ -123,27 +114,19 @@ def _ragged_dot_bwd_ffi_call( grad: jax.Array, group_sizes: jax.Array, group_offset: jax.Array, + group_offsets_cumsum: jax.Array, g_local: int, ) -> jax.Array: """Backward pass for d_rhs: computes lhs.T @ grad per group -> [G, K, N].""" if not _ensure_registered(): raise RuntimeError("ragged_dot_ffi is not available. Build and load the shared library first.") - if lhs.dtype != jnp.bfloat16 or grad.dtype != jnp.bfloat16: - raise NotImplementedError("ragged_dot_bwd_ffi supports bfloat16 only.") - if group_sizes.dtype != jnp.int32 or group_offset.dtype != jnp.int32: - raise NotImplementedError("ragged_dot_bwd_ffi expects int32 group_sizes and group_offset.") - if group_offset.shape != (1,): - raise ValueError("group_offset must have shape (1,).") - k = lhs.shape[1] n = grad.shape[1] - group_offsets = jnp.cumsum(group_sizes, dtype=jnp.int32) - out = jax.ShapeDtypeStruct((g_local, k, n), lhs.dtype) call = jax_ffi.ffi_call("ragged_dot_bwd_cuda", out, vmap_method=None) - return call(lhs, grad, group_sizes, group_offset, group_offsets) + return call(lhs, grad, group_sizes, group_offset, group_offsets_cumsum) @jax.custom_vjp @@ -153,7 +136,8 @@ def ragged_dot( group_sizes: jax.Array, group_offset: jax.Array, ) -> jax.Array: - return _ragged_dot_ffi_call(lhs, rhs, group_sizes, group_offset) + group_offsets_cumsum = jnp.cumsum(group_sizes, dtype=jnp.int32) + return _ragged_dot_ffi_call(lhs, rhs, group_sizes, group_offset, group_offsets_cumsum) def _ragged_dot_fwd( @@ -162,21 +146,22 @@ def _ragged_dot_fwd( group_sizes: jax.Array, group_offset: jax.Array, ): - y = _ragged_dot_ffi_call(lhs, rhs, group_sizes, group_offset) - return y, (lhs, rhs, group_sizes, group_offset) + group_offsets_cumsum = jnp.cumsum(group_sizes, dtype=jnp.int32) + y = _ragged_dot_ffi_call(lhs, rhs, group_sizes, group_offset, group_offsets_cumsum) + return y, (lhs, rhs, group_sizes, group_offset, group_offsets_cumsum) def _ragged_dot_bwd(res, g): - lhs, rhs, group_sizes, group_offset = res + lhs, rhs, group_sizes, group_offset, group_offsets_cumsum = res # d_lhs: g @ rhs.T with ragged grouping - same structure as forward # g: [M, N], rhs: [G, K, N] -> rhs.T: [G, N, K], d_lhs: [M, K] rhs_t = jnp.swapaxes(rhs, 1, 2) # [G, N, K] - d_lhs = _ragged_dot_ffi_call(g, rhs_t, group_sizes, group_offset) + d_lhs = _ragged_dot_ffi_call(g, rhs_t, group_sizes, group_offset, group_offsets_cumsum) # d_rhs: lhs.T @ g accumulated per group -> [G, K, N] g_local = rhs.shape[0] - d_rhs = _ragged_dot_bwd_ffi_call(lhs, g, group_sizes, group_offset, g_local) + d_rhs = _ragged_dot_bwd_ffi_call(lhs, g, group_sizes, group_offset, group_offsets_cumsum, g_local) return d_lhs, d_rhs, None, None 
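Note on the custom VJP being built up here: `group_sizes` and `group_offset` are treated as non-differentiable (the rule returns `None` for them), `d_lhs` reuses the forward ragged structure (`grad @ rhs.T` per group), and `d_rhs` is the per-group accumulation `lhs_g.T @ grad_g -> [G, K, N]` that `ragged_dot_bwd_cuda` computes. A minimal pure-JAX sketch of these semantics is shown below; it assumes the single-shard case (`group_offset == [0]`, with `rhs` holding all `G` groups) and the helper names are illustrative, not part of the patch series.

```python
# Pure-JAX reference for the single-shard case (group_offset == [0], rhs holds
# all G groups). Helper names here are illustrative, not part of the patches.
import jax
import jax.numpy as jnp


def ragged_dot_ref(lhs, rhs, group_sizes):
    """lhs: [M, K], rhs: [G, K, N], group_sizes: [G] with sum(group_sizes) == M."""
    # Row i of lhs belongs to the group whose cumulative-size boundary it falls under.
    group_ids = jnp.searchsorted(jnp.cumsum(group_sizes), jnp.arange(lhs.shape[0]), side="right")
    return jnp.einsum("mk,mkn->mn", lhs, rhs[group_ids])


def ragged_dot_bwd_ref(lhs, rhs, grad, group_sizes):
    """Reference cotangents matching _ragged_dot_bwd: returns (d_lhs, d_rhs)."""
    group_ids = jnp.searchsorted(jnp.cumsum(group_sizes), jnp.arange(lhs.shape[0]), side="right")
    # d_lhs keeps the forward's ragged structure: grad @ rhs.T, row by row.
    d_lhs = jnp.einsum("mn,mkn->mk", grad, rhs[group_ids])
    # d_rhs accumulates lhs_g.T @ grad_g per group -> [G, K, N].
    one_hot = jax.nn.one_hot(group_ids, group_sizes.shape[0], dtype=lhs.dtype)
    d_rhs = jnp.einsum("mg,mk,mn->gkn", one_hot, lhs, grad)
    return d_lhs, d_rhs
```

Comparing the FFI path against these references (for example via `jax.vjp`) gives a quick sanity check when the tile shapes and kernel schedules are swapped in the later patches.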
From cc26ba290a54103457ac0e4ded697406df1e6f63 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 21:20:10 -0800 Subject: [PATCH 042/106] simplify --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 20 +++++--------------- skyrl-tx/tx/ffi/ragged_dot_ffi.py | 18 ++++++++---------- 2 files changed, 13 insertions(+), 25 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 5e8d40f0d..5e4b7be10 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -118,7 +118,6 @@ __global__ void prepare_grouped_gemm_data( const DtypeA* A, const DtypeB* B, DtypeOutput* output, - const int32_t* group_sizes, const int32_t* group_offsets_cumsum, const int32_t* group_offset_ptr, int32_t k, @@ -134,7 +133,7 @@ __global__ void prepare_grouped_gemm_data( int32_t global = group_offset_ptr[0] + tid; int32_t start = (global > 0) ? group_offsets_cumsum[global - 1] : 0; - int32_t m = group_sizes[global]; + int32_t m = group_offsets_cumsum[global] - start; A_ptrs[tid] = const_cast(A) + static_cast(start) * k; B_ptrs[tid] = const_cast(B) + static_cast(tid) * n * k; @@ -152,7 +151,6 @@ __global__ void prepare_grouped_gemm_bwd_data( const DtypeA* lhs, // [M_total, K] row-major (ragged) const DtypeB* grad, // [M_total, N] row-major (ragged) DtypeOutput* d_rhs, // [G, K, N] row-major (per-group) - const int32_t* group_sizes, const int32_t* group_offsets_cumsum, const int32_t* group_offset_ptr, int32_t k, @@ -168,7 +166,7 @@ __global__ void prepare_grouped_gemm_bwd_data( int32_t global = group_offset_ptr[0] + tid; int32_t start = (global > 0) ? group_offsets_cumsum[global - 1] : 0; - int32_t m = group_sizes[global]; + int32_t m = group_offsets_cumsum[global] - start; // A = lhs slice, viewed as ColumnMajor [K, M_g] (transposed) A_ptrs[tid] = const_cast(lhs) + static_cast(start) * k; @@ -191,17 +189,14 @@ ffi::Error RaggedDotCudaImpl( ffi::ScratchAllocator scratch, ffi::Buffer lhs, ffi::Buffer rhs, - ffi::Buffer group_sizes, ffi::Buffer group_offset, ffi::Buffer group_offsets_cumsum, ffi::ResultBuffer out) { auto lhs_dims = lhs.dimensions(); auto rhs_dims = rhs.dimensions(); - auto group_sizes_dims = group_sizes.dimensions(); auto group_offset_dims = group_offset.dimensions(); - if (lhs_dims.size() != 2 || rhs_dims.size() != 3 || group_sizes_dims.size() != 1 || - group_offset_dims.size() != 1) { + if (lhs_dims.size() != 2 || rhs_dims.size() != 3 || group_offset_dims.size() != 1) { return ffi::Error::InvalidArgument("Unexpected ragged_dot dimensions."); } @@ -220,7 +215,6 @@ ffi::Error RaggedDotCudaImpl( int32_t n = static_cast(n64); int32_t g_local = static_cast(g_local64); - const int32_t* group_sizes_ptr = group_sizes.typed_data(); const int32_t* group_offset_ptr = group_offset.typed_data(); cudaError_t err; @@ -266,7 +260,7 @@ ffi::Error RaggedDotCudaImpl( prepare_grouped_gemm_data<<<1, g_local, 0, stream>>>( A_base, B_base, out_base, - group_sizes_ptr, group_offsets_cumsum_ptr, group_offset_ptr, k, n, + group_offsets_cumsum_ptr, group_offset_ptr, k, n, d_A_ptrs, d_B_ptrs, d_out_ptrs, d_stride_A, d_stride_B, d_stride_output, d_problem_sizes); @@ -312,7 +306,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Ctx() .Arg>() // lhs .Arg>() // rhs - .Arg>() // group_sizes .Arg>() // group_offset .Arg>() // group_offsets_cumsum .Ret>()); // out @@ -323,7 +316,6 @@ ffi::Error RaggedDotBwdCudaImpl( ffi::ScratchAllocator scratch, ffi::Buffer lhs, ffi::Buffer grad, - ffi::Buffer group_sizes, ffi::Buffer group_offset, ffi::Buffer group_offsets_cumsum, 
ffi::ResultBuffer d_rhs) { @@ -371,7 +363,6 @@ ffi::Error RaggedDotBwdCudaImpl( const DtypeA* lhs_base = reinterpret_cast(lhs.typed_data()); const DtypeB* grad_base = reinterpret_cast(grad.typed_data()); DtypeOutput* d_rhs_base = reinterpret_cast(d_rhs->typed_data()); - const int32_t* group_sizes_ptr = group_sizes.typed_data(); const int32_t* group_offset_ptr = group_offset.typed_data(); const int32_t* group_offsets_cumsum_ptr = group_offsets_cumsum.typed_data(); @@ -398,7 +389,7 @@ ffi::Error RaggedDotBwdCudaImpl( prepare_grouped_gemm_bwd_data<<<1, g_local, 0, stream>>>( lhs_base, grad_base, d_rhs_base, - group_sizes_ptr, group_offsets_cumsum_ptr, group_offset_ptr, k, n, + group_offsets_cumsum_ptr, group_offset_ptr, k, n, d_A_ptrs, d_B_ptrs, d_out_ptrs, d_stride_A, d_stride_B, d_stride_output, d_problem_sizes); @@ -444,7 +435,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Ctx() .Arg>() // lhs .Arg>() // grad - .Arg>() // group_sizes .Arg>() // group_offset .Arg>() // group_offsets_cumsum .Ret>()); // d_rhs diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.py b/skyrl-tx/tx/ffi/ragged_dot_ffi.py index f40a9c561..9d5cd505f 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.py +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.py @@ -97,7 +97,6 @@ def _ragged_dot_ref( def _ragged_dot_ffi_call( lhs: jax.Array, rhs: jax.Array, - group_sizes: jax.Array, group_offset: jax.Array, group_offsets_cumsum: jax.Array, ) -> jax.Array: @@ -106,13 +105,12 @@ def _ragged_dot_ffi_call( out = jax.ShapeDtypeStruct((lhs.shape[0], rhs.shape[2]), lhs.dtype) call = jax_ffi.ffi_call("ragged_dot_cuda", out, vmap_method=None) - return call(lhs, rhs, group_sizes, group_offset, group_offsets_cumsum) + return call(lhs, rhs, group_offset, group_offsets_cumsum) def _ragged_dot_bwd_ffi_call( lhs: jax.Array, grad: jax.Array, - group_sizes: jax.Array, group_offset: jax.Array, group_offsets_cumsum: jax.Array, g_local: int, @@ -126,7 +124,7 @@ def _ragged_dot_bwd_ffi_call( out = jax.ShapeDtypeStruct((g_local, k, n), lhs.dtype) call = jax_ffi.ffi_call("ragged_dot_bwd_cuda", out, vmap_method=None) - return call(lhs, grad, group_sizes, group_offset, group_offsets_cumsum) + return call(lhs, grad, group_offset, group_offsets_cumsum) @jax.custom_vjp @@ -137,7 +135,7 @@ def ragged_dot( group_offset: jax.Array, ) -> jax.Array: group_offsets_cumsum = jnp.cumsum(group_sizes, dtype=jnp.int32) - return _ragged_dot_ffi_call(lhs, rhs, group_sizes, group_offset, group_offsets_cumsum) + return _ragged_dot_ffi_call(lhs, rhs, group_offset, group_offsets_cumsum) def _ragged_dot_fwd( @@ -147,21 +145,21 @@ def _ragged_dot_fwd( group_offset: jax.Array, ): group_offsets_cumsum = jnp.cumsum(group_sizes, dtype=jnp.int32) - y = _ragged_dot_ffi_call(lhs, rhs, group_sizes, group_offset, group_offsets_cumsum) - return y, (lhs, rhs, group_sizes, group_offset, group_offsets_cumsum) + y = _ragged_dot_ffi_call(lhs, rhs, group_offset, group_offsets_cumsum) + return y, (lhs, rhs, group_offset, group_offsets_cumsum) def _ragged_dot_bwd(res, g): - lhs, rhs, group_sizes, group_offset, group_offsets_cumsum = res + lhs, rhs, group_offset, group_offsets_cumsum = res # d_lhs: g @ rhs.T with ragged grouping - same structure as forward # g: [M, N], rhs: [G, K, N] -> rhs.T: [G, N, K], d_lhs: [M, K] rhs_t = jnp.swapaxes(rhs, 1, 2) # [G, N, K] - d_lhs = _ragged_dot_ffi_call(g, rhs_t, group_sizes, group_offset, group_offsets_cumsum) + d_lhs = _ragged_dot_ffi_call(g, rhs_t, group_offset, group_offsets_cumsum) # d_rhs: lhs.T @ g accumulated per group -> [G, K, N] g_local = rhs.shape[0] - d_rhs = 
_ragged_dot_bwd_ffi_call(lhs, g, group_sizes, group_offset, group_offsets_cumsum, g_local) + d_rhs = _ragged_dot_bwd_ffi_call(lhs, g, group_offset, group_offsets_cumsum, g_local) return d_lhs, d_rhs, None, None From 9b72ae7d19a2a9b9905707554ec55f240fc27292 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 21:21:18 -0800 Subject: [PATCH 043/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 5e4b7be10..4dfd1195e 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -49,9 +49,9 @@ using DtypeAccum = float; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::RowMajor; using LayoutOutput = cutlass::layout::RowMajor; -constexpr int AlignmentA = 16; -constexpr int AlignmentB = 16; -constexpr int AlignmentOutput = 16; +constexpr int AlignmentA = 8; +constexpr int AlignmentB = 8; +constexpr int AlignmentOutput = 8; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; From 1f1e728f7bedf091dc59b39c3139bfa846bb4e93 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 21:35:26 -0800 Subject: [PATCH 044/106] try without caching --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 4dfd1195e..dc89a4f36 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -25,21 +25,16 @@ namespace ffi = xla::ffi; -// Cache SM count per device -static int g_sm_count[16] = {0}; - static int get_sm_count() { int device = 0; - if (cudaGetDevice(&device) != cudaSuccess || device < 0 || device >= 16) { + if (cudaGetDevice(&device) != cudaSuccess) { return 0; } - if (g_sm_count[device] == 0) { - cudaDeviceProp props; - if (cudaGetDeviceProperties(&props, device) == cudaSuccess) { - g_sm_count[device] = props.multiProcessorCount; - } + cudaDeviceProp props; + if (cudaGetDeviceProperties(&props, device) != cudaSuccess) { + return 0; } - return g_sm_count[device]; + return props.multiProcessorCount; } using DtypeA = cutlass::bfloat16_t; From 1717190dfa172a1ceb2d42a0e011a1fb84efdb92 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 21:42:32 -0800 Subject: [PATCH 045/106] clean up --- skyrl-tx/tx/ffi/ragged_dot_ffi.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.py b/skyrl-tx/tx/ffi/ragged_dot_ffi.py index 9d5cd505f..293d85b21 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.py +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.py @@ -5,7 +5,6 @@ from pathlib import Path import jax -from jax import lax import jax.numpy as jnp try: # JAX >= 0.8 @@ -67,33 +66,6 @@ def is_available() -> bool: return _ensure_registered() -def _ragged_dot_ref( - lhs: jax.Array, - rhs: jax.Array, - group_sizes: jax.Array, - group_offset: jax.Array, -) -> jax.Array: - if group_offset.shape != (1,): - raise ValueError("group_offset must have shape (1,).") - - offset = group_offset[0] - m = lhs.shape[0] - g_local = rhs.shape[0] - - cumsum = jnp.cumulative_sum(group_sizes, include_initial=True) - shard_start = cumsum[offset] - shard_end = cumsum[offset + g_local] - - token_idx = jnp.arange(m) - valid_mask = (token_idx >= shard_start) & (token_idx < shard_end) - - local_group_sizes = lax.dynamic_slice_in_dim(group_sizes, offset, g_local, axis=0) 
- adjusted_group_sizes = local_group_sizes.at[0].add(shard_start).at[-1].add(m - shard_end) - - result = lax.ragged_dot(lhs, rhs, adjusted_group_sizes) - return jnp.where(valid_mask[:, None], result, 0) - - def _ragged_dot_ffi_call( lhs: jax.Array, rhs: jax.Array, From 3199c0db2792284a82cf44264a6a67b09f60829a Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 22:06:19 -0800 Subject: [PATCH 046/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index dc89a4f36..17663b44d 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -2,6 +2,7 @@ #include #include +#include #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" @@ -25,14 +26,19 @@ namespace ffi = xla::ffi; +static std::vector g_device_props; + static int get_sm_count() { int device = 0; - if (cudaGetDevice(&device) != cudaSuccess) { + if (cudaGetDevice(&device) != cudaSuccess || device < 0) { return 0; } - cudaDeviceProp props; - if (cudaGetDeviceProperties(&props, device) != cudaSuccess) { - return 0; + if (static_cast(device) >= g_device_props.size()) { + g_device_props.resize(device + 1); + } + cudaDeviceProp& props = g_device_props[device]; + if (!props.multiProcessorCount) { + cudaGetDeviceProperties(&props, device); } return props.multiProcessorCount; } From 1071a5a2955e3ebeed2529ee3031c0ccb6807d0d Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 22:15:17 -0800 Subject: [PATCH 047/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 17663b44d..e2a602e6d 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -111,9 +111,6 @@ using StrideA_Bwd = typename Gemm_Bwd::GemmKernel::InternalStrideA; using StrideB_Bwd = typename Gemm_Bwd::GemmKernel::InternalStrideB; using StrideOutput_Bwd = typename Gemm_Bwd::GemmKernel::InternalStrideD; -static ffi::Error CudaError(const char* message) { - return ffi::Error::Internal(message); -} __global__ void prepare_grouped_gemm_data( const DtypeA* A, @@ -222,7 +219,7 @@ ffi::Error RaggedDotCudaImpl( err = cudaMemsetAsync( out->typed_data(), 0, static_cast(m) * n * sizeof(DtypeOutput), stream); if (err != cudaSuccess) { - return CudaError("Failed to zero output."); + return ffi::Error::Internal("Failed to zero output."); } if (g_local == 0 || m == 0 || n == 0 || k == 0) { @@ -350,7 +347,7 @@ ffi::Error RaggedDotBwdCudaImpl( err = cudaMemsetAsync( d_rhs->typed_data(), 0, static_cast(g_local) * k * n * sizeof(DtypeOutput), stream); if (err != cudaSuccess) { - return CudaError("Failed to zero d_rhs output."); + return ffi::Error::Internal("Failed to zero d_rhs output."); } if (g_local == 0 || m == 0 || n == 0 || k == 0) { From 688c96226205a7924a1c849553b3599991df3c47 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 22:21:30 -0800 Subject: [PATCH 048/106] simplify --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 45 ++++++++++--------------------- 1 file changed, 14 insertions(+), 31 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index e2a602e6d..e0ea4cae2 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -197,26 +197,16 @@ ffi::Error RaggedDotCudaImpl( if (lhs_dims.size() != 2 || rhs_dims.size() != 3 
|| group_offset_dims.size() != 1) { return ffi::Error::InvalidArgument("Unexpected ragged_dot dimensions."); } - - int64_t m64 = lhs_dims[0]; - int64_t k64 = lhs_dims[1]; - int64_t g_local64 = rhs_dims[0]; - int64_t rhs_k64 = rhs_dims[1]; - int64_t n64 = rhs_dims[2]; - - if (k64 != rhs_k64) { + if (lhs_dims[1] != rhs_dims[1]) { return ffi::Error::InvalidArgument("lhs/rhs K dimension mismatch."); } - int32_t m = static_cast(m64); - int32_t k = static_cast(k64); - int32_t n = static_cast(n64); - int32_t g_local = static_cast(g_local64); - - const int32_t* group_offset_ptr = group_offset.typed_data(); - cudaError_t err; + int32_t m = static_cast(lhs_dims[0]); + int32_t k = static_cast(lhs_dims[1]); + int32_t g_local = static_cast(rhs_dims[0]); + int32_t n = static_cast(rhs_dims[2]); - err = cudaMemsetAsync( + cudaError_t err = cudaMemsetAsync( out->typed_data(), 0, static_cast(m) * n * sizeof(DtypeOutput), stream); if (err != cudaSuccess) { return ffi::Error::Internal("Failed to zero output."); @@ -233,6 +223,7 @@ ffi::Error RaggedDotCudaImpl( const DtypeA* A_base = reinterpret_cast(lhs.typed_data()); const DtypeB* B_base = reinterpret_cast(rhs.typed_data()); DtypeOutput* out_base = reinterpret_cast(out->typed_data()); + const int32_t* group_offset_ptr = group_offset.typed_data(); const int32_t* group_offsets_cumsum_ptr = group_offsets_cumsum.typed_data(); size_t gl = static_cast(g_local); @@ -324,27 +315,19 @@ ffi::Error RaggedDotBwdCudaImpl( if (lhs_dims.size() != 2 || grad_dims.size() != 2 || d_rhs_dims.size() != 3) { return ffi::Error::InvalidArgument("Unexpected ragged_dot_bwd dimensions."); } - - int64_t m64 = lhs_dims[0]; - int64_t k64 = lhs_dims[1]; - int64_t grad_m64 = grad_dims[0]; - int64_t n64 = grad_dims[1]; - int64_t g_local64 = d_rhs_dims[0]; - - if (m64 != grad_m64) { + if (lhs_dims[0] != grad_dims[0]) { return ffi::Error::InvalidArgument("lhs/grad M dimension mismatch."); } - if (d_rhs_dims[1] != k64 || d_rhs_dims[2] != n64) { + if (d_rhs_dims[1] != lhs_dims[1] || d_rhs_dims[2] != grad_dims[1]) { return ffi::Error::InvalidArgument("d_rhs shape must be [G, K, N]."); } - int32_t m = static_cast(m64); - int32_t k = static_cast(k64); - int32_t n = static_cast(n64); - int32_t g_local = static_cast(g_local64); + int32_t m = static_cast(lhs_dims[0]); + int32_t k = static_cast(lhs_dims[1]); + int32_t n = static_cast(grad_dims[1]); + int32_t g_local = static_cast(d_rhs_dims[0]); - cudaError_t err; - err = cudaMemsetAsync( + cudaError_t err = cudaMemsetAsync( d_rhs->typed_data(), 0, static_cast(g_local) * k * n * sizeof(DtypeOutput), stream); if (err != cudaSuccess) { return ffi::Error::Internal("Failed to zero d_rhs output."); From de0b140f56b7e50cfbd1618fe280e556910e8d2c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 22:29:07 -0800 Subject: [PATCH 049/106] simplify --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index e0ea4cae2..6844ef223 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -223,8 +223,6 @@ ffi::Error RaggedDotCudaImpl( const DtypeA* A_base = reinterpret_cast(lhs.typed_data()); const DtypeB* B_base = reinterpret_cast(rhs.typed_data()); DtypeOutput* out_base = reinterpret_cast(out->typed_data()); - const int32_t* group_offset_ptr = group_offset.typed_data(); - const int32_t* group_offsets_cumsum_ptr = group_offsets_cumsum.typed_data(); size_t gl = static_cast(g_local); size_t 
bytes = 7 * 16 + // alignment padding @@ -249,7 +247,7 @@ ffi::Error RaggedDotCudaImpl( prepare_grouped_gemm_data<<<1, g_local, 0, stream>>>( A_base, B_base, out_base, - group_offsets_cumsum_ptr, group_offset_ptr, k, n, + group_offsets_cumsum.typed_data(), group_offset.typed_data(), k, n, d_A_ptrs, d_B_ptrs, d_out_ptrs, d_stride_A, d_stride_B, d_stride_output, d_problem_sizes); @@ -344,8 +342,6 @@ ffi::Error RaggedDotBwdCudaImpl( const DtypeA* lhs_base = reinterpret_cast(lhs.typed_data()); const DtypeB* grad_base = reinterpret_cast(grad.typed_data()); DtypeOutput* d_rhs_base = reinterpret_cast(d_rhs->typed_data()); - const int32_t* group_offset_ptr = group_offset.typed_data(); - const int32_t* group_offsets_cumsum_ptr = group_offsets_cumsum.typed_data(); size_t gl = static_cast(g_local); size_t bytes = 7 * 16 + @@ -370,7 +366,7 @@ ffi::Error RaggedDotBwdCudaImpl( prepare_grouped_gemm_bwd_data<<<1, g_local, 0, stream>>>( lhs_base, grad_base, d_rhs_base, - group_offsets_cumsum_ptr, group_offset_ptr, k, n, + group_offsets_cumsum.typed_data(), group_offset.typed_data(), k, n, d_A_ptrs, d_B_ptrs, d_out_ptrs, d_stride_A, d_stride_B, d_stride_output, d_problem_sizes); From 9586e57a8ddeb070f8892c6145f6ee0f82c566cf Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 22:30:38 -0800 Subject: [PATCH 050/106] simplify --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 6844ef223..d7e3921cb 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -262,8 +262,7 @@ ffi::Error RaggedDotCudaImpl( args.epilogue.thread.dAlpha = {cute::_0{}, cute::_0{}, 0}; args.hw_info.sm_count = get_sm_count(); - cutlass::Status status = gemm.can_implement(args); - if (status != cutlass::Status::kSuccess) { + if (gemm.can_implement(args) != cutlass::Status::kSuccess) { return ffi::Error::Internal("cutlass cannot implement grouped gemm."); } @@ -277,8 +276,7 @@ ffi::Error RaggedDotCudaImpl( workspace = workspace_or.value(); } - status = gemm(args, workspace, stream); - if (status != cutlass::Status::kSuccess) { + if (gemm(args, workspace, stream) != cutlass::Status::kSuccess) { return ffi::Error::Internal("cutlass grouped gemm failed."); } @@ -381,8 +379,7 @@ ffi::Error RaggedDotBwdCudaImpl( args.epilogue.thread.dAlpha = {cute::_0{}, cute::_0{}, 0}; args.hw_info.sm_count = get_sm_count(); - cutlass::Status status = gemm.can_implement(args); - if (status != cutlass::Status::kSuccess) { + if (gemm.can_implement(args) != cutlass::Status::kSuccess) { return ffi::Error::Internal("cutlass cannot implement grouped gemm bwd."); } @@ -396,8 +393,7 @@ ffi::Error RaggedDotBwdCudaImpl( workspace = workspace_or.value(); } - status = gemm(args, workspace, stream); - if (status != cutlass::Status::kSuccess) { + if (gemm(args, workspace, stream) != cutlass::Status::kSuccess) { return ffi::Error::Internal("cutlass grouped gemm bwd failed."); } From 5a96b6828d2c49203858f8389793eb3ffa4037e2 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 22:33:46 -0800 Subject: [PATCH 051/106] simplify --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index d7e3921cb..ef5b037ed 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -220,10 +220,6 @@ ffi::Error RaggedDotCudaImpl( 
return ffi::Error::InvalidArgument("group_count must be <= 1024."); } - const DtypeA* A_base = reinterpret_cast(lhs.typed_data()); - const DtypeB* B_base = reinterpret_cast(rhs.typed_data()); - DtypeOutput* out_base = reinterpret_cast(out->typed_data()); - size_t gl = static_cast(g_local); size_t bytes = 7 * 16 + // alignment padding sizeof(DtypeA*) * gl + sizeof(DtypeB*) * gl + sizeof(DtypeOutput*) * gl + @@ -246,7 +242,9 @@ ffi::Error RaggedDotCudaImpl( align16(p); ProblemShapeType* d_problem_sizes = reinterpret_cast(p); prepare_grouped_gemm_data<<<1, g_local, 0, stream>>>( - A_base, B_base, out_base, + reinterpret_cast(lhs.typed_data()), + reinterpret_cast(rhs.typed_data()), + reinterpret_cast(out->typed_data()), group_offsets_cumsum.typed_data(), group_offset.typed_data(), k, n, d_A_ptrs, d_B_ptrs, d_out_ptrs, d_stride_A, d_stride_B, d_stride_output, d_problem_sizes); @@ -337,10 +335,6 @@ ffi::Error RaggedDotBwdCudaImpl( return ffi::Error::InvalidArgument("group_count must be <= 1024."); } - const DtypeA* lhs_base = reinterpret_cast(lhs.typed_data()); - const DtypeB* grad_base = reinterpret_cast(grad.typed_data()); - DtypeOutput* d_rhs_base = reinterpret_cast(d_rhs->typed_data()); - size_t gl = static_cast(g_local); size_t bytes = 7 * 16 + sizeof(DtypeA*) * gl + sizeof(DtypeB*) * gl + sizeof(DtypeOutput*) * gl + @@ -363,7 +357,9 @@ ffi::Error RaggedDotBwdCudaImpl( align16(p); ProblemShapeType* d_problem_sizes = reinterpret_cast(p); prepare_grouped_gemm_bwd_data<<<1, g_local, 0, stream>>>( - lhs_base, grad_base, d_rhs_base, + reinterpret_cast(lhs.typed_data()), + reinterpret_cast(grad.typed_data()), + reinterpret_cast(d_rhs->typed_data()), group_offsets_cumsum.typed_data(), group_offset.typed_data(), k, n, d_A_ptrs, d_B_ptrs, d_out_ptrs, d_stride_A, d_stride_B, d_stride_output, d_problem_sizes); From 11fa3c4710c1f5903ba0c638fea482ecb111e9fe Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 22:49:20 -0800 Subject: [PATCH 052/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 56 +++++++++++++++++-------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index ef5b037ed..6909944dc 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -111,6 +111,12 @@ using StrideA_Bwd = typename Gemm_Bwd::GemmKernel::InternalStrideA; using StrideB_Bwd = typename Gemm_Bwd::GemmKernel::InternalStrideB; using StrideOutput_Bwd = typename Gemm_Bwd::GemmKernel::InternalStrideD; +template +static T* carve_aligned(char*& p, size_t count) { + T* out = reinterpret_cast((reinterpret_cast(p) + 15) & ~uintptr_t(15)); + p = reinterpret_cast(out + count); + return out; +} __global__ void prepare_grouped_gemm_data( const DtypeA* A, @@ -120,8 +126,8 @@ __global__ void prepare_grouped_gemm_data( const int32_t* group_offset_ptr, int32_t k, int32_t n, - DtypeA** A_ptrs, - DtypeB** B_ptrs, + const DtypeA** A_ptrs, + const DtypeB** B_ptrs, DtypeOutput** output_ptrs, StrideA* stride_A, StrideB* stride_B, @@ -133,8 +139,8 @@ __global__ void prepare_grouped_gemm_data( int32_t start = (global > 0) ? 
group_offsets_cumsum[global - 1] : 0; int32_t m = group_offsets_cumsum[global] - start; - A_ptrs[tid] = const_cast(A) + static_cast(start) * k; - B_ptrs[tid] = const_cast(B) + static_cast(tid) * n * k; + A_ptrs[tid] = A + static_cast(start) * k; + B_ptrs[tid] = B + static_cast(tid) * n * k; output_ptrs[tid] = output + static_cast(start) * n; problem_sizes[tid] = ProblemShapeType(m, n, k); @@ -153,8 +159,8 @@ __global__ void prepare_grouped_gemm_bwd_data( const int32_t* group_offset_ptr, int32_t k, int32_t n, - DtypeA** A_ptrs, - DtypeB** B_ptrs, + const DtypeA** A_ptrs, + const DtypeB** B_ptrs, DtypeOutput** output_ptrs, StrideA_Bwd* stride_A, StrideB_Bwd* stride_B, @@ -167,9 +173,9 @@ __global__ void prepare_grouped_gemm_bwd_data( int32_t m = group_offsets_cumsum[global] - start; // A = lhs slice, viewed as ColumnMajor [K, M_g] (transposed) - A_ptrs[tid] = const_cast(lhs) + static_cast(start) * k; + A_ptrs[tid] = lhs + static_cast(start) * k; // B = grad slice [M_g, N] - B_ptrs[tid] = const_cast(grad) + static_cast(start) * n; + B_ptrs[tid] = grad + static_cast(start) * n; // Output = d_rhs[tid] with shape [K, N] output_ptrs[tid] = d_rhs + static_cast(tid) * k * n; @@ -222,7 +228,7 @@ ffi::Error RaggedDotCudaImpl( size_t gl = static_cast(g_local); size_t bytes = 7 * 16 + // alignment padding - sizeof(DtypeA*) * gl + sizeof(DtypeB*) * gl + sizeof(DtypeOutput*) * gl + + sizeof(const DtypeA*) * gl + sizeof(const DtypeB*) * gl + sizeof(DtypeOutput*) * gl + sizeof(StrideA) * gl + sizeof(StrideB) * gl + sizeof(StrideOutput) * gl + sizeof(ProblemShapeType) * gl; @@ -231,15 +237,14 @@ ffi::Error RaggedDotCudaImpl( return ffi::Error::Internal("Failed to allocate grouped GEMM slab from scratch."); } - auto align16 = [](char*& p) { p = reinterpret_cast((reinterpret_cast(p) + 15) & ~15); }; char* p = reinterpret_cast(slab_or.value()); - align16(p); DtypeA** d_A_ptrs = reinterpret_cast(p); p += sizeof(DtypeA*) * gl; - align16(p); DtypeB** d_B_ptrs = reinterpret_cast(p); p += sizeof(DtypeB*) * gl; - align16(p); DtypeOutput** d_out_ptrs = reinterpret_cast(p); p += sizeof(DtypeOutput*) * gl; - align16(p); StrideA* d_stride_A = reinterpret_cast(p); p += sizeof(StrideA) * gl; - align16(p); StrideB* d_stride_B = reinterpret_cast(p); p += sizeof(StrideB) * gl; - align16(p); StrideOutput* d_stride_output = reinterpret_cast(p); p += sizeof(StrideOutput) * gl; - align16(p); ProblemShapeType* d_problem_sizes = reinterpret_cast(p); + auto d_A_ptrs = carve_aligned(p, gl); + auto d_B_ptrs = carve_aligned(p, gl); + auto d_out_ptrs = carve_aligned(p, gl); + auto d_stride_A = carve_aligned(p, gl); + auto d_stride_B = carve_aligned(p, gl); + auto d_stride_output = carve_aligned(p, gl); + auto d_problem_sizes = carve_aligned(p, gl); prepare_grouped_gemm_data<<<1, g_local, 0, stream>>>( reinterpret_cast(lhs.typed_data()), @@ -337,7 +342,7 @@ ffi::Error RaggedDotBwdCudaImpl( size_t gl = static_cast(g_local); size_t bytes = 7 * 16 + - sizeof(DtypeA*) * gl + sizeof(DtypeB*) * gl + sizeof(DtypeOutput*) * gl + + sizeof(const DtypeA*) * gl + sizeof(const DtypeB*) * gl + sizeof(DtypeOutput*) * gl + sizeof(StrideA_Bwd) * gl + sizeof(StrideB_Bwd) * gl + sizeof(StrideOutput_Bwd) * gl + sizeof(ProblemShapeType) * gl; @@ -346,15 +351,14 @@ ffi::Error RaggedDotBwdCudaImpl( return ffi::Error::Internal("Failed to allocate grouped GEMM bwd slab from scratch."); } - auto align16 = [](char*& p) { p = reinterpret_cast((reinterpret_cast(p) + 15) & ~15); }; char* p = reinterpret_cast(slab_or.value()); - align16(p); DtypeA** d_A_ptrs = 
reinterpret_cast(p); p += sizeof(DtypeA*) * gl; - align16(p); DtypeB** d_B_ptrs = reinterpret_cast(p); p += sizeof(DtypeB*) * gl; - align16(p); DtypeOutput** d_out_ptrs = reinterpret_cast(p); p += sizeof(DtypeOutput*) * gl; - align16(p); StrideA_Bwd* d_stride_A = reinterpret_cast(p); p += sizeof(StrideA_Bwd) * gl; - align16(p); StrideB_Bwd* d_stride_B = reinterpret_cast(p); p += sizeof(StrideB_Bwd) * gl; - align16(p); StrideOutput_Bwd* d_stride_output = reinterpret_cast(p); p += sizeof(StrideOutput_Bwd) * gl; - align16(p); ProblemShapeType* d_problem_sizes = reinterpret_cast(p); + auto d_A_ptrs = carve_aligned(p, gl); + auto d_B_ptrs = carve_aligned(p, gl); + auto d_out_ptrs = carve_aligned(p, gl); + auto d_stride_A = carve_aligned(p, gl); + auto d_stride_B = carve_aligned(p, gl); + auto d_stride_output = carve_aligned(p, gl); + auto d_problem_sizes = carve_aligned(p, gl); prepare_grouped_gemm_bwd_data<<<1, g_local, 0, stream>>>( reinterpret_cast(lhs.typed_data()), From 66989064e78c0447a8f354e3bc13c453561e2108 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 22:55:36 -0800 Subject: [PATCH 053/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 6909944dc..d6688afa1 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -269,9 +269,8 @@ ffi::Error RaggedDotCudaImpl( return ffi::Error::Internal("cutlass cannot implement grouped gemm."); } - size_t workspace_size = Gemm::get_workspace_size(args); void* workspace = nullptr; - if (workspace_size > 0) { + if (size_t workspace_size = Gemm::get_workspace_size(args)) { auto workspace_or = scratch.Allocate(workspace_size); if (!workspace_or.has_value()) { return ffi::Error::Internal("Failed to allocate CUTLASS workspace from scratch."); @@ -383,9 +382,8 @@ ffi::Error RaggedDotBwdCudaImpl( return ffi::Error::Internal("cutlass cannot implement grouped gemm bwd."); } - size_t workspace_size = Gemm_Bwd::get_workspace_size(args); void* workspace = nullptr; - if (workspace_size > 0) { + if (size_t workspace_size = Gemm_Bwd::get_workspace_size(args)) { auto workspace_or = scratch.Allocate(workspace_size); if (!workspace_or.has_value()) { return ffi::Error::Internal("Failed to allocate CUTLASS bwd workspace from scratch."); From 5f9c41372149453da514cb344552ddd494a83239 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 23:07:01 -0800 Subject: [PATCH 054/106] cleanup --- skyrl-tx/tx/ffi/ragged_dot_ffi.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.py b/skyrl-tx/tx/ffi/ragged_dot_ffi.py index 293d85b21..6fa6699a4 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.py +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.py @@ -7,11 +7,6 @@ import jax import jax.numpy as jnp -try: # JAX >= 0.8 - from jax import ffi as jax_ffi -except Exception: # pragma: no cover - older JAX fallback - from jax.experimental import ffi as jax_ffi - _REGISTERED = False _LOAD_ERROR: Exception | None = None @@ -45,14 +40,14 @@ def _ensure_registered() -> bool: try: lib = ctypes.cdll.LoadLibrary(str(lib_path)) - jax_ffi.register_ffi_target( + jax.ffi.register_ffi_target( "ragged_dot_cuda", - jax_ffi.pycapsule(lib.RaggedDotCuda), + jax.ffi.pycapsule(lib.RaggedDotCuda), platform="CUDA", ) - jax_ffi.register_ffi_target( + jax.ffi.register_ffi_target( "ragged_dot_bwd_cuda", - jax_ffi.pycapsule(lib.RaggedDotBwdCuda), + 
jax.ffi.pycapsule(lib.RaggedDotBwdCuda), platform="CUDA", ) _REGISTERED = True @@ -76,7 +71,7 @@ def _ragged_dot_ffi_call( raise RuntimeError("ragged_dot_ffi is not available. Build and load the shared library first.") out = jax.ShapeDtypeStruct((lhs.shape[0], rhs.shape[2]), lhs.dtype) - call = jax_ffi.ffi_call("ragged_dot_cuda", out, vmap_method=None) + call = jax.ffi.ffi_call("ragged_dot_cuda", out, vmap_method=None) return call(lhs, rhs, group_offset, group_offsets_cumsum) @@ -95,7 +90,7 @@ def _ragged_dot_bwd_ffi_call( n = grad.shape[1] out = jax.ShapeDtypeStruct((g_local, k, n), lhs.dtype) - call = jax_ffi.ffi_call("ragged_dot_bwd_cuda", out, vmap_method=None) + call = jax.ffi.ffi_call("ragged_dot_bwd_cuda", out, vmap_method=None) return call(lhs, grad, group_offset, group_offsets_cumsum) From e23037d3e3621d2850930333cdd520f4d88bdc8d Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 18 Jan 2026 23:10:41 -0800 Subject: [PATCH 055/106] simplify --- skyrl-tx/tx/ffi/ragged_dot_ffi.py | 50 +++++++------------------------ 1 file changed, 11 insertions(+), 39 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.py b/skyrl-tx/tx/ffi/ragged_dot_ffi.py index 6fa6699a4..6ad3b6e4e 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.py +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.py @@ -1,6 +1,7 @@ from __future__ import annotations import ctypes +import functools import os from pathlib import Path @@ -8,52 +9,23 @@ import jax.numpy as jnp -_REGISTERED = False -_LOAD_ERROR: Exception | None = None - - -def _find_library() -> Path | None: - env_path = os.environ.get("TX_RAGGED_DOT_FFI_PATH") - if env_path: - path = Path(env_path) - return path if path.exists() else None - - here = Path(__file__).resolve().parent - for name in ("libragged_dot_ffi.so", "ragged_dot_ffi.so"): - candidate = here / name - if candidate.exists(): - return candidate - return None - - +@functools.lru_cache(maxsize=1) def _ensure_registered() -> bool: - global _REGISTERED, _LOAD_ERROR - if _REGISTERED: - return True - if _LOAD_ERROR is not None: - return False + if env_path := os.environ.get("TX_RAGGED_DOT_FFI_PATH"): + lib_path = Path(env_path) + else: + here = Path(__file__).resolve().parent + lib_path = next((p for p in [here / "libragged_dot_ffi.so", here / "ragged_dot_ffi.so"] if p.exists()), None) - lib_path = _find_library() - if lib_path is None: - _LOAD_ERROR = FileNotFoundError("ragged_dot_ffi shared library not found.") + if not lib_path or not lib_path.exists(): return False try: lib = ctypes.cdll.LoadLibrary(str(lib_path)) - jax.ffi.register_ffi_target( - "ragged_dot_cuda", - jax.ffi.pycapsule(lib.RaggedDotCuda), - platform="CUDA", - ) - jax.ffi.register_ffi_target( - "ragged_dot_bwd_cuda", - jax.ffi.pycapsule(lib.RaggedDotBwdCuda), - platform="CUDA", - ) - _REGISTERED = True + jax.ffi.register_ffi_target("ragged_dot_cuda", jax.ffi.pycapsule(lib.RaggedDotCuda), platform="CUDA") + jax.ffi.register_ffi_target("ragged_dot_bwd_cuda", jax.ffi.pycapsule(lib.RaggedDotBwdCuda), platform="CUDA") return True - except Exception as exc: # pragma: no cover - load/registration failures - _LOAD_ERROR = exc + except Exception: return False From 56968bac0f56dee0735bace46dd43e87e4c40577 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 00:13:33 -0800 Subject: [PATCH 056/106] simplify --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 211 +++++++++++++----------------- 1 file changed, 88 insertions(+), 123 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index d6688afa1..4fc8339e7 
100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -1,7 +1,7 @@ #include #include -#include +#include #include #include "xla/ffi/api/c_api.h" @@ -89,23 +89,15 @@ using ProblemShapeType = ProblemShape::UnderlyingProblemShape; // Uses ColumnMajor for A to interpret row-major lhs as transposed using LayoutA_Bwd = cutlass::layout::ColumnMajor; -using CollectiveEpilogue_Bwd = - typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - DtypeAccum, DtypeAccum, void, LayoutOutput*, AlignmentOutput, - DtypeOutput, LayoutOutput*, AlignmentOutput, EpilogueSchedule, - cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; - using CollectiveMainloop_Bwd = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, DtypeA, LayoutA_Bwd*, AlignmentA, DtypeB, LayoutB*, AlignmentB, DtypeAccum, TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogue_Bwd::SharedStorage))>, + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, KernelSchedule>::CollectiveOp; -using GemmKernel_Bwd = cutlass::gemm::kernel::GemmUniversal; +using GemmKernel_Bwd = cutlass::gemm::kernel::GemmUniversal; using Gemm_Bwd = cutlass::gemm::device::GemmUniversalAdapter; using StrideA_Bwd = typename Gemm_Bwd::GemmKernel::InternalStrideA; using StrideB_Bwd = typename Gemm_Bwd::GemmKernel::InternalStrideB; @@ -118,6 +110,68 @@ static T* carve_aligned(char*& p, size_t count) { return out; } +template +struct GroupedGemmData { + using StrideA = typename GemmT::GemmKernel::InternalStrideA; + using StrideB = typename GemmT::GemmKernel::InternalStrideB; + using StrideOutput = typename GemmT::GemmKernel::InternalStrideD; + + const DtypeA** A_ptrs; + const DtypeB** B_ptrs; + DtypeOutput** out_ptrs; + StrideA* stride_A; + StrideB* stride_B; + StrideOutput* stride_output; + ProblemShapeType* problem_sizes; + + static std::optional Allocate(ffi::ScratchAllocator& scratch, size_t g) { + size_t bytes = 7 * 16 + + sizeof(const DtypeA*) * g + sizeof(const DtypeB*) * g + sizeof(DtypeOutput*) * g + + sizeof(StrideA) * g + sizeof(StrideB) * g + sizeof(StrideOutput) * g + + sizeof(ProblemShapeType) * g; + auto slab_or = scratch.Allocate(bytes); + if (!slab_or.has_value()) { + return std::nullopt; + } + GroupedGemmData data; + char* p = reinterpret_cast(slab_or.value()); + data.A_ptrs = carve_aligned(p, g); + data.B_ptrs = carve_aligned(p, g); + data.out_ptrs = carve_aligned(p, g); + data.stride_A = carve_aligned(p, g); + data.stride_B = carve_aligned(p, g); + data.stride_output = carve_aligned(p, g); + data.problem_sizes = carve_aligned(p, g); + return data; + } +}; + +template +ffi::Error RunGroupedGemm(cudaStream_t stream, ffi::ScratchAllocator& scratch, + typename GemmT::Arguments& args) { + GemmT gemm; + args.hw_info.sm_count = get_sm_count(); + + if (gemm.can_implement(args) != cutlass::Status::kSuccess) { + return ffi::Error::Internal("cutlass cannot implement grouped gemm."); + } + + void* workspace = nullptr; + if (size_t workspace_size = GemmT::get_workspace_size(args)) { + auto workspace_or = scratch.Allocate(workspace_size); + if (!workspace_or.has_value()) { + return ffi::Error::Internal("Failed to allocate CUTLASS workspace."); + } + workspace = workspace_or.value(); + } + + if (gemm(args, workspace, stream) != cutlass::Status::kSuccess) { + return ffi::Error::Internal("cutlass grouped gemm failed."); 
+ } + + return ffi::Error::Success(); +} + __global__ void prepare_grouped_gemm_data( const DtypeA* A, const DtypeB* B, @@ -198,9 +252,8 @@ ffi::Error RaggedDotCudaImpl( ffi::ResultBuffer out) { auto lhs_dims = lhs.dimensions(); auto rhs_dims = rhs.dimensions(); - auto group_offset_dims = group_offset.dimensions(); - if (lhs_dims.size() != 2 || rhs_dims.size() != 3 || group_offset_dims.size() != 1) { + if (lhs_dims.size() != 2 || rhs_dims.size() != 3 || group_offset.dimensions().size() != 1) { return ffi::Error::InvalidArgument("Unexpected ragged_dot dimensions."); } if (lhs_dims[1] != rhs_dims[1]) { @@ -212,77 +265,33 @@ ffi::Error RaggedDotCudaImpl( int32_t g_local = static_cast(rhs_dims[0]); int32_t n = static_cast(rhs_dims[2]); - cudaError_t err = cudaMemsetAsync( - out->typed_data(), 0, static_cast(m) * n * sizeof(DtypeOutput), stream); - if (err != cudaSuccess) { + if (cudaMemsetAsync(out->typed_data(), 0, static_cast(m) * n * sizeof(DtypeOutput), stream) != cudaSuccess) { return ffi::Error::Internal("Failed to zero output."); } - if (g_local == 0 || m == 0 || n == 0 || k == 0) { - return ffi::Error::Success(); - } + if (g_local == 0 || m == 0 || n == 0 || k == 0) return ffi::Error::Success(); + if (g_local > 1024) return ffi::Error::InvalidArgument("group_count must be <= 1024."); - if (g_local > 1024) { - return ffi::Error::InvalidArgument("group_count must be <= 1024."); - } - - size_t gl = static_cast(g_local); - size_t bytes = 7 * 16 + // alignment padding - sizeof(const DtypeA*) * gl + sizeof(const DtypeB*) * gl + sizeof(DtypeOutput*) * gl + - sizeof(StrideA) * gl + sizeof(StrideB) * gl + sizeof(StrideOutput) * gl + - sizeof(ProblemShapeType) * gl; - - auto slab_or = scratch.Allocate(bytes); - if (!slab_or.has_value()) { - return ffi::Error::Internal("Failed to allocate grouped GEMM slab from scratch."); - } - - char* p = reinterpret_cast(slab_or.value()); - auto d_A_ptrs = carve_aligned(p, gl); - auto d_B_ptrs = carve_aligned(p, gl); - auto d_out_ptrs = carve_aligned(p, gl); - auto d_stride_A = carve_aligned(p, gl); - auto d_stride_B = carve_aligned(p, gl); - auto d_stride_output = carve_aligned(p, gl); - auto d_problem_sizes = carve_aligned(p, gl); + auto data = GroupedGemmData::Allocate(scratch, g_local); + if (!data) return ffi::Error::Internal("Failed to allocate grouped GEMM slab."); prepare_grouped_gemm_data<<<1, g_local, 0, stream>>>( reinterpret_cast(lhs.typed_data()), reinterpret_cast(rhs.typed_data()), reinterpret_cast(out->typed_data()), group_offsets_cumsum.typed_data(), group_offset.typed_data(), k, n, - d_A_ptrs, d_B_ptrs, d_out_ptrs, - d_stride_A, d_stride_B, d_stride_output, d_problem_sizes); + data->A_ptrs, data->B_ptrs, data->out_ptrs, + data->stride_A, data->stride_B, data->stride_output, data->problem_sizes); - Gemm gemm; typename Gemm::Arguments args{ cutlass::gemm::GemmUniversalMode::kGrouped, - {g_local, d_problem_sizes, nullptr}, - {(const DtypeA**)d_A_ptrs, d_stride_A, (const DtypeB**)d_B_ptrs, d_stride_B}, - {{}, nullptr, d_stride_output, d_out_ptrs, d_stride_output}}; - + {g_local, data->problem_sizes, nullptr}, + {data->A_ptrs, data->stride_A, data->B_ptrs, data->stride_B}, + {{}, nullptr, data->stride_output, data->out_ptrs, data->stride_output}}; args.epilogue.thread.alpha = 1.0f; args.epilogue.thread.dAlpha = {cute::_0{}, cute::_0{}, 0}; - args.hw_info.sm_count = get_sm_count(); - if (gemm.can_implement(args) != cutlass::Status::kSuccess) { - return ffi::Error::Internal("cutlass cannot implement grouped gemm."); - } - - void* workspace = nullptr; 
- if (size_t workspace_size = Gemm::get_workspace_size(args)) { - auto workspace_or = scratch.Allocate(workspace_size); - if (!workspace_or.has_value()) { - return ffi::Error::Internal("Failed to allocate CUTLASS workspace from scratch."); - } - workspace = workspace_or.value(); - } - - if (gemm(args, workspace, stream) != cutlass::Status::kSuccess) { - return ffi::Error::Internal("cutlass grouped gemm failed."); - } - - return ffi::Error::Success(); + return RunGroupedGemm(stream, scratch, args); } XLA_FFI_DEFINE_HANDLER_SYMBOL( @@ -325,77 +334,33 @@ ffi::Error RaggedDotBwdCudaImpl( int32_t n = static_cast(grad_dims[1]); int32_t g_local = static_cast(d_rhs_dims[0]); - cudaError_t err = cudaMemsetAsync( - d_rhs->typed_data(), 0, static_cast(g_local) * k * n * sizeof(DtypeOutput), stream); - if (err != cudaSuccess) { + if (cudaMemsetAsync(d_rhs->typed_data(), 0, static_cast(g_local) * k * n * sizeof(DtypeOutput), stream) != cudaSuccess) { return ffi::Error::Internal("Failed to zero d_rhs output."); } - if (g_local == 0 || m == 0 || n == 0 || k == 0) { - return ffi::Error::Success(); - } - - if (g_local > 1024) { - return ffi::Error::InvalidArgument("group_count must be <= 1024."); - } - - size_t gl = static_cast(g_local); - size_t bytes = 7 * 16 + - sizeof(const DtypeA*) * gl + sizeof(const DtypeB*) * gl + sizeof(DtypeOutput*) * gl + - sizeof(StrideA_Bwd) * gl + sizeof(StrideB_Bwd) * gl + sizeof(StrideOutput_Bwd) * gl + - sizeof(ProblemShapeType) * gl; - - auto slab_or = scratch.Allocate(bytes); - if (!slab_or.has_value()) { - return ffi::Error::Internal("Failed to allocate grouped GEMM bwd slab from scratch."); - } + if (g_local == 0 || m == 0 || n == 0 || k == 0) return ffi::Error::Success(); + if (g_local > 1024) return ffi::Error::InvalidArgument("group_count must be <= 1024."); - char* p = reinterpret_cast(slab_or.value()); - auto d_A_ptrs = carve_aligned(p, gl); - auto d_B_ptrs = carve_aligned(p, gl); - auto d_out_ptrs = carve_aligned(p, gl); - auto d_stride_A = carve_aligned(p, gl); - auto d_stride_B = carve_aligned(p, gl); - auto d_stride_output = carve_aligned(p, gl); - auto d_problem_sizes = carve_aligned(p, gl); + auto data = GroupedGemmData::Allocate(scratch, g_local); + if (!data) return ffi::Error::Internal("Failed to allocate grouped GEMM slab."); prepare_grouped_gemm_bwd_data<<<1, g_local, 0, stream>>>( reinterpret_cast(lhs.typed_data()), reinterpret_cast(grad.typed_data()), reinterpret_cast(d_rhs->typed_data()), group_offsets_cumsum.typed_data(), group_offset.typed_data(), k, n, - d_A_ptrs, d_B_ptrs, d_out_ptrs, - d_stride_A, d_stride_B, d_stride_output, d_problem_sizes); + data->A_ptrs, data->B_ptrs, data->out_ptrs, + data->stride_A, data->stride_B, data->stride_output, data->problem_sizes); - Gemm_Bwd gemm; typename Gemm_Bwd::Arguments args{ cutlass::gemm::GemmUniversalMode::kGrouped, - {g_local, d_problem_sizes, nullptr}, - {(const DtypeA**)d_A_ptrs, d_stride_A, (const DtypeB**)d_B_ptrs, d_stride_B}, - {{}, nullptr, d_stride_output, d_out_ptrs, d_stride_output}}; - + {g_local, data->problem_sizes, nullptr}, + {data->A_ptrs, data->stride_A, data->B_ptrs, data->stride_B}, + {{}, nullptr, data->stride_output, data->out_ptrs, data->stride_output}}; args.epilogue.thread.alpha = 1.0f; args.epilogue.thread.dAlpha = {cute::_0{}, cute::_0{}, 0}; - args.hw_info.sm_count = get_sm_count(); - if (gemm.can_implement(args) != cutlass::Status::kSuccess) { - return ffi::Error::Internal("cutlass cannot implement grouped gemm bwd."); - } - - void* workspace = nullptr; - if (size_t 
workspace_size = Gemm_Bwd::get_workspace_size(args)) { - auto workspace_or = scratch.Allocate(workspace_size); - if (!workspace_or.has_value()) { - return ffi::Error::Internal("Failed to allocate CUTLASS bwd workspace from scratch."); - } - workspace = workspace_or.value(); - } - - if (gemm(args, workspace, stream) != cutlass::Status::kSuccess) { - return ffi::Error::Internal("cutlass grouped gemm bwd failed."); - } - - return ffi::Error::Success(); + return RunGroupedGemm(stream, scratch, args); } XLA_FFI_DEFINE_HANDLER_SYMBOL( From 0a184e4ac43f54259e8e8edda4e9374fcfe9ee3c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 00:21:12 -0800 Subject: [PATCH 057/106] unify code --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 113 +++++++++++------------------- 1 file changed, 42 insertions(+), 71 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 4fc8339e7..68f482b04 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -99,9 +99,6 @@ using CollectiveMainloop_Bwd = using GemmKernel_Bwd = cutlass::gemm::kernel::GemmUniversal; using Gemm_Bwd = cutlass::gemm::device::GemmUniversalAdapter; -using StrideA_Bwd = typename Gemm_Bwd::GemmKernel::InternalStrideA; -using StrideB_Bwd = typename Gemm_Bwd::GemmKernel::InternalStrideB; -using StrideOutput_Bwd = typename Gemm_Bwd::GemmKernel::InternalStrideD; template static T* carve_aligned(char*& p, size_t count) { @@ -144,6 +141,17 @@ struct GroupedGemmData { data.problem_sizes = carve_aligned(p, g); return data; } + + typename GemmT::Arguments MakeArgs(int32_t g_local) const { + typename GemmT::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {g_local, problem_sizes, nullptr}, + {A_ptrs, stride_A, B_ptrs, stride_B}, + {{}, nullptr, stride_output, out_ptrs, stride_output}}; + args.epilogue.thread.alpha = 1.0f; + args.epilogue.thread.dAlpha = {cute::_0{}, cute::_0{}, 0}; + return args; + } }; template @@ -172,74 +180,55 @@ ffi::Error RunGroupedGemm(cudaStream_t stream, ffi::ScratchAllocator& scratch, return ffi::Error::Success(); } -__global__ void prepare_grouped_gemm_data( - const DtypeA* A, - const DtypeB* B, - DtypeOutput* output, - const int32_t* group_offsets_cumsum, - const int32_t* group_offset_ptr, - int32_t k, - int32_t n, - const DtypeA** A_ptrs, - const DtypeB** B_ptrs, - DtypeOutput** output_ptrs, - StrideA* stride_A, - StrideB* stride_B, - StrideOutput* stride_output, - ProblemShapeType* problem_sizes) { +template +__global__ void prepare_grouped_gemm_fwd( + const DtypeA* A, const DtypeB* B, DtypeOutput* output, + const int32_t* group_offsets_cumsum, const int32_t* group_offset_ptr, + int32_t k, int32_t n, GroupedGemmData data) { + using Data = GroupedGemmData; int32_t tid = threadIdx.x; int32_t global = group_offset_ptr[0] + tid; - int32_t start = (global > 0) ? 
group_offsets_cumsum[global - 1] : 0; int32_t m = group_offsets_cumsum[global] - start; - A_ptrs[tid] = A + static_cast(start) * k; - B_ptrs[tid] = B + static_cast(tid) * n * k; - output_ptrs[tid] = output + static_cast(start) * n; - problem_sizes[tid] = ProblemShapeType(m, n, k); - - stride_A[tid] = cutlass::make_cute_packed_stride(StrideA{}, {k, k, 1}); - stride_B[tid] = cutlass::make_cute_packed_stride(StrideB{}, {n, n, 1}); - stride_output[tid] = cutlass::make_cute_packed_stride(StrideOutput{}, {m, n, 1}); + data.A_ptrs[tid] = A + static_cast(start) * k; + data.B_ptrs[tid] = B + static_cast(tid) * n * k; + data.out_ptrs[tid] = output + static_cast(start) * n; + data.problem_sizes[tid] = ProblemShapeType(m, n, k); + data.stride_A[tid] = cutlass::make_cute_packed_stride(typename Data::StrideA{}, {k, k, 1}); + data.stride_B[tid] = cutlass::make_cute_packed_stride(typename Data::StrideB{}, {n, n, 1}); + data.stride_output[tid] = cutlass::make_cute_packed_stride(typename Data::StrideOutput{}, {m, n, 1}); } // Backward kernel for d_rhs: lhs.T @ grad -> d_rhs[G, K, N] // Memory pattern: A=lhs (ragged), B=grad (ragged), output=d_rhs (per-group) -__global__ void prepare_grouped_gemm_bwd_data( +template +__global__ void prepare_grouped_gemm_bwd( const DtypeA* lhs, // [M_total, K] row-major (ragged) const DtypeB* grad, // [M_total, N] row-major (ragged) DtypeOutput* d_rhs, // [G, K, N] row-major (per-group) - const int32_t* group_offsets_cumsum, - const int32_t* group_offset_ptr, - int32_t k, - int32_t n, - const DtypeA** A_ptrs, - const DtypeB** B_ptrs, - DtypeOutput** output_ptrs, - StrideA_Bwd* stride_A, - StrideB_Bwd* stride_B, - StrideOutput_Bwd* stride_output, - ProblemShapeType* problem_sizes) { + const int32_t* group_offsets_cumsum, const int32_t* group_offset_ptr, + int32_t k, int32_t n, GroupedGemmData data) { + using Data = GroupedGemmData; int32_t tid = threadIdx.x; int32_t global = group_offset_ptr[0] + tid; - int32_t start = (global > 0) ? 
group_offsets_cumsum[global - 1] : 0; int32_t m = group_offsets_cumsum[global] - start; // A = lhs slice, viewed as ColumnMajor [K, M_g] (transposed) - A_ptrs[tid] = lhs + static_cast(start) * k; + data.A_ptrs[tid] = lhs + static_cast(start) * k; // B = grad slice [M_g, N] - B_ptrs[tid] = grad + static_cast(start) * n; + data.B_ptrs[tid] = grad + static_cast(start) * n; // Output = d_rhs[tid] with shape [K, N] - output_ptrs[tid] = d_rhs + static_cast(tid) * k * n; + data.out_ptrs[tid] = d_rhs + static_cast(tid) * k * n; // GEMM: [K, M_g] @ [M_g, N] = [K, N] - problem_sizes[tid] = ProblemShapeType(k, n, m); + data.problem_sizes[tid] = ProblemShapeType(k, n, m); // ColumnMajor stride for A (lhs viewed as transposed) - stride_A[tid] = cutlass::make_cute_packed_stride(StrideA_Bwd{}, {m, m, 1}); - stride_B[tid] = cutlass::make_cute_packed_stride(StrideB_Bwd{}, {n, n, 1}); - stride_output[tid] = cutlass::make_cute_packed_stride(StrideOutput_Bwd{}, {k, n, 1}); + data.stride_A[tid] = cutlass::make_cute_packed_stride(typename Data::StrideA{}, {m, m, 1}); + data.stride_B[tid] = cutlass::make_cute_packed_stride(typename Data::StrideB{}, {n, n, 1}); + data.stride_output[tid] = cutlass::make_cute_packed_stride(typename Data::StrideOutput{}, {k, n, 1}); } ffi::Error RaggedDotCudaImpl( @@ -275,22 +264,13 @@ ffi::Error RaggedDotCudaImpl( auto data = GroupedGemmData::Allocate(scratch, g_local); if (!data) return ffi::Error::Internal("Failed to allocate grouped GEMM slab."); - prepare_grouped_gemm_data<<<1, g_local, 0, stream>>>( + prepare_grouped_gemm_fwd<<<1, g_local, 0, stream>>>( reinterpret_cast(lhs.typed_data()), reinterpret_cast(rhs.typed_data()), reinterpret_cast(out->typed_data()), - group_offsets_cumsum.typed_data(), group_offset.typed_data(), k, n, - data->A_ptrs, data->B_ptrs, data->out_ptrs, - data->stride_A, data->stride_B, data->stride_output, data->problem_sizes); - - typename Gemm::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGrouped, - {g_local, data->problem_sizes, nullptr}, - {data->A_ptrs, data->stride_A, data->B_ptrs, data->stride_B}, - {{}, nullptr, data->stride_output, data->out_ptrs, data->stride_output}}; - args.epilogue.thread.alpha = 1.0f; - args.epilogue.thread.dAlpha = {cute::_0{}, cute::_0{}, 0}; + group_offsets_cumsum.typed_data(), group_offset.typed_data(), k, n, *data); + auto args = data->MakeArgs(g_local); return RunGroupedGemm(stream, scratch, args); } @@ -344,22 +324,13 @@ ffi::Error RaggedDotBwdCudaImpl( auto data = GroupedGemmData::Allocate(scratch, g_local); if (!data) return ffi::Error::Internal("Failed to allocate grouped GEMM slab."); - prepare_grouped_gemm_bwd_data<<<1, g_local, 0, stream>>>( + prepare_grouped_gemm_bwd<<<1, g_local, 0, stream>>>( reinterpret_cast(lhs.typed_data()), reinterpret_cast(grad.typed_data()), reinterpret_cast(d_rhs->typed_data()), - group_offsets_cumsum.typed_data(), group_offset.typed_data(), k, n, - data->A_ptrs, data->B_ptrs, data->out_ptrs, - data->stride_A, data->stride_B, data->stride_output, data->problem_sizes); - - typename Gemm_Bwd::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGrouped, - {g_local, data->problem_sizes, nullptr}, - {data->A_ptrs, data->stride_A, data->B_ptrs, data->stride_B}, - {{}, nullptr, data->stride_output, data->out_ptrs, data->stride_output}}; - args.epilogue.thread.alpha = 1.0f; - args.epilogue.thread.dAlpha = {cute::_0{}, cute::_0{}, 0}; + group_offsets_cumsum.typed_data(), group_offset.typed_data(), k, n, *data); + auto args = data->MakeArgs(g_local); return RunGroupedGemm(stream, 
scratch, args); } From 4292fc8e50483057b77c3b94d7860fb4b423e459 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 00:33:56 -0800 Subject: [PATCH 058/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 67 +++++++++++-------------------- 1 file changed, 23 insertions(+), 44 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 68f482b04..98f3a9123 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -80,9 +80,6 @@ using CollectiveMainloop = using GemmKernel = cutlass::gemm::kernel::GemmUniversal; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -using StrideA = typename Gemm::GemmKernel::InternalStrideA; -using StrideB = typename Gemm::GemmKernel::InternalStrideB; -using StrideOutput = typename Gemm::GemmKernel::InternalStrideD; using ProblemShapeType = ProblemShape::UnderlyingProblemShape; // Backward kernel for d_rhs: computes lhs.T @ grad per group @@ -180,33 +177,14 @@ ffi::Error RunGroupedGemm(cudaStream_t stream, ffi::ScratchAllocator& scratch, return ffi::Error::Success(); } -template -__global__ void prepare_grouped_gemm_fwd( - const DtypeA* A, const DtypeB* B, DtypeOutput* output, - const int32_t* group_offsets_cumsum, const int32_t* group_offset_ptr, - int32_t k, int32_t n, GroupedGemmData data) { - using Data = GroupedGemmData; - int32_t tid = threadIdx.x; - int32_t global = group_offset_ptr[0] + tid; - int32_t start = (global > 0) ? group_offsets_cumsum[global - 1] : 0; - int32_t m = group_offsets_cumsum[global] - start; +enum class GemmDir { Fwd, Bwd }; - data.A_ptrs[tid] = A + static_cast(start) * k; - data.B_ptrs[tid] = B + static_cast(tid) * n * k; - data.out_ptrs[tid] = output + static_cast(start) * n; - data.problem_sizes[tid] = ProblemShapeType(m, n, k); - data.stride_A[tid] = cutlass::make_cute_packed_stride(typename Data::StrideA{}, {k, k, 1}); - data.stride_B[tid] = cutlass::make_cute_packed_stride(typename Data::StrideB{}, {n, n, 1}); - data.stride_output[tid] = cutlass::make_cute_packed_stride(typename Data::StrideOutput{}, {m, n, 1}); -} - -// Backward kernel for d_rhs: lhs.T @ grad -> d_rhs[G, K, N] -// Memory pattern: A=lhs (ragged), B=grad (ragged), output=d_rhs (per-group) -template -__global__ void prepare_grouped_gemm_bwd( - const DtypeA* lhs, // [M_total, K] row-major (ragged) - const DtypeB* grad, // [M_total, N] row-major (ragged) - DtypeOutput* d_rhs, // [G, K, N] row-major (per-group) +// Unified prepare kernel for forward and backward passes +// Fwd: A[M,K] @ B[G,K,N] -> out[M,N] (ragged by group) +// Bwd: A.T[K,M] @ B[M,N] -> out[G,K,N] (lhs transposed via ColumnMajor layout) +template +__global__ void prepare_grouped_gemm( + const DtypeA* A, const DtypeB* B, DtypeOutput* out, const int32_t* group_offsets_cumsum, const int32_t* group_offset_ptr, int32_t k, int32_t n, GroupedGemmData data) { using Data = GroupedGemmData; @@ -215,20 +193,21 @@ __global__ void prepare_grouped_gemm_bwd( int32_t start = (global > 0) ? 
group_offsets_cumsum[global - 1] : 0; int32_t m = group_offsets_cumsum[global] - start; - // A = lhs slice, viewed as ColumnMajor [K, M_g] (transposed) - data.A_ptrs[tid] = lhs + static_cast(start) * k; - // B = grad slice [M_g, N] - data.B_ptrs[tid] = grad + static_cast(start) * n; - // Output = d_rhs[tid] with shape [K, N] - data.out_ptrs[tid] = d_rhs + static_cast(tid) * k * n; - - // GEMM: [K, M_g] @ [M_g, N] = [K, N] - data.problem_sizes[tid] = ProblemShapeType(k, n, m); - - // ColumnMajor stride for A (lhs viewed as transposed) - data.stride_A[tid] = cutlass::make_cute_packed_stride(typename Data::StrideA{}, {m, m, 1}); + data.A_ptrs[tid] = A + static_cast(start) * k; + if constexpr (Dir == GemmDir::Fwd) { + data.B_ptrs[tid] = B + static_cast(tid) * n * k; + data.out_ptrs[tid] = out + static_cast(start) * n; + data.problem_sizes[tid] = ProblemShapeType(m, n, k); + data.stride_A[tid] = cutlass::make_cute_packed_stride(typename Data::StrideA{}, {k, k, 1}); + data.stride_output[tid] = cutlass::make_cute_packed_stride(typename Data::StrideOutput{}, {m, n, 1}); + } else { + data.B_ptrs[tid] = B + static_cast(start) * n; + data.out_ptrs[tid] = out + static_cast(tid) * k * n; + data.problem_sizes[tid] = ProblemShapeType(k, n, m); + data.stride_A[tid] = cutlass::make_cute_packed_stride(typename Data::StrideA{}, {m, m, 1}); + data.stride_output[tid] = cutlass::make_cute_packed_stride(typename Data::StrideOutput{}, {k, n, 1}); + } data.stride_B[tid] = cutlass::make_cute_packed_stride(typename Data::StrideB{}, {n, n, 1}); - data.stride_output[tid] = cutlass::make_cute_packed_stride(typename Data::StrideOutput{}, {k, n, 1}); } ffi::Error RaggedDotCudaImpl( @@ -264,7 +243,7 @@ ffi::Error RaggedDotCudaImpl( auto data = GroupedGemmData::Allocate(scratch, g_local); if (!data) return ffi::Error::Internal("Failed to allocate grouped GEMM slab."); - prepare_grouped_gemm_fwd<<<1, g_local, 0, stream>>>( + prepare_grouped_gemm<<<1, g_local, 0, stream>>>( reinterpret_cast(lhs.typed_data()), reinterpret_cast(rhs.typed_data()), reinterpret_cast(out->typed_data()), @@ -324,7 +303,7 @@ ffi::Error RaggedDotBwdCudaImpl( auto data = GroupedGemmData::Allocate(scratch, g_local); if (!data) return ffi::Error::Internal("Failed to allocate grouped GEMM slab."); - prepare_grouped_gemm_bwd<<<1, g_local, 0, stream>>>( + prepare_grouped_gemm<<<1, g_local, 0, stream>>>( reinterpret_cast(lhs.typed_data()), reinterpret_cast(grad.typed_data()), reinterpret_cast(d_rhs->typed_data()), From 0392cacdba1ef68cfbf36d284562aa06039eee12 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 00:37:49 -0800 Subject: [PATCH 059/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 88 +++++++++++++++---------------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 98f3a9123..799a3c766 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -151,32 +151,6 @@ struct GroupedGemmData { } }; -template -ffi::Error RunGroupedGemm(cudaStream_t stream, ffi::ScratchAllocator& scratch, - typename GemmT::Arguments& args) { - GemmT gemm; - args.hw_info.sm_count = get_sm_count(); - - if (gemm.can_implement(args) != cutlass::Status::kSuccess) { - return ffi::Error::Internal("cutlass cannot implement grouped gemm."); - } - - void* workspace = nullptr; - if (size_t workspace_size = GemmT::get_workspace_size(args)) { - auto workspace_or = scratch.Allocate(workspace_size); - if (!workspace_or.has_value()) { - return 
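
As a reference for the unified prepare kernel above, the per-group math it sets up can be written out in a few lines of NumPy: row ranges come from the running sum of `group_sizes`, the forward pass does `lhs[s:e] @ rhs[g]`, and the rhs backward pass does `lhs[s:e].T @ grad[s:e]`. This is only a sketch of the intended semantics; the helper names are illustrative and it is not the CUTLASS code path.

```python
import numpy as np

def ragged_dot_reference(lhs, rhs, group_sizes):
    """Forward: out[s:e] = lhs[s:e] @ rhs[g] for each group's row range [s, e)."""
    M, K = lhs.shape
    G, K2, N = rhs.shape
    assert K == K2 and int(group_sizes.sum()) == M
    starts = np.concatenate(([0], np.cumsum(group_sizes)))  # length G + 1
    out = np.zeros((M, N), dtype=np.float32)
    for g in range(G):
        s, e = starts[g], starts[g + 1]
        out[s:e] = lhs[s:e] @ rhs[g]
    return out

def ragged_dot_bwd_rhs_reference(lhs, grad, group_sizes):
    """Backward for rhs: d_rhs[g] = lhs[s:e].T @ grad[s:e], one [K, N] block per group."""
    K = lhs.shape[1]
    N = grad.shape[1]
    G = group_sizes.shape[0]
    starts = np.concatenate(([0], np.cumsum(group_sizes)))
    d_rhs = np.zeros((G, K, N), dtype=np.float32)
    for g in range(G):
        s, e = starts[g], starts[g + 1]
        d_rhs[g] = lhs[s:e].T @ grad[s:e]
    return d_rhs
```
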
ffi::Error::Internal("Failed to allocate CUTLASS workspace."); - } - workspace = workspace_or.value(); - } - - if (gemm(args, workspace, stream) != cutlass::Status::kSuccess) { - return ffi::Error::Internal("cutlass grouped gemm failed."); - } - - return ffi::Error::Success(); -} - enum class GemmDir { Fwd, Bwd }; // Unified prepare kernel for forward and backward passes @@ -210,6 +184,44 @@ __global__ void prepare_grouped_gemm( data.stride_B[tid] = cutlass::make_cute_packed_stride(typename Data::StrideB{}, {n, n, 1}); } +template +ffi::Error ExecuteGroupedGemm( + cudaStream_t stream, ffi::ScratchAllocator& scratch, + const DtypeA* A, const DtypeB* B, DtypeOutput* out, + const int32_t* group_offsets_cumsum, const int32_t* group_offset, + int32_t g_local, int32_t k, int32_t n) { + if (g_local > 1024) return ffi::Error::InvalidArgument("group_count must be <= 1024."); + + auto data = GroupedGemmData::Allocate(scratch, g_local); + if (!data) return ffi::Error::Internal("Failed to allocate grouped GEMM slab."); + + prepare_grouped_gemm<<<1, g_local, 0, stream>>>( + A, B, out, group_offsets_cumsum, group_offset, k, n, *data); + + GemmT gemm; + auto args = data->MakeArgs(g_local); + args.hw_info.sm_count = get_sm_count(); + + if (gemm.can_implement(args) != cutlass::Status::kSuccess) { + return ffi::Error::Internal("cutlass cannot implement grouped gemm."); + } + + void* workspace = nullptr; + if (size_t workspace_size = GemmT::get_workspace_size(args)) { + auto workspace_or = scratch.Allocate(workspace_size); + if (!workspace_or.has_value()) { + return ffi::Error::Internal("Failed to allocate CUTLASS workspace."); + } + workspace = workspace_or.value(); + } + + if (gemm(args, workspace, stream) != cutlass::Status::kSuccess) { + return ffi::Error::Internal("cutlass grouped gemm failed."); + } + + return ffi::Error::Success(); +} + ffi::Error RaggedDotCudaImpl( cudaStream_t stream, ffi::ScratchAllocator scratch, @@ -238,19 +250,13 @@ ffi::Error RaggedDotCudaImpl( } if (g_local == 0 || m == 0 || n == 0 || k == 0) return ffi::Error::Success(); - if (g_local > 1024) return ffi::Error::InvalidArgument("group_count must be <= 1024."); - - auto data = GroupedGemmData::Allocate(scratch, g_local); - if (!data) return ffi::Error::Internal("Failed to allocate grouped GEMM slab."); - prepare_grouped_gemm<<<1, g_local, 0, stream>>>( + return ExecuteGroupedGemm( + stream, scratch, reinterpret_cast(lhs.typed_data()), reinterpret_cast(rhs.typed_data()), reinterpret_cast(out->typed_data()), - group_offsets_cumsum.typed_data(), group_offset.typed_data(), k, n, *data); - - auto args = data->MakeArgs(g_local); - return RunGroupedGemm(stream, scratch, args); + group_offsets_cumsum.typed_data(), group_offset.typed_data(), g_local, k, n); } XLA_FFI_DEFINE_HANDLER_SYMBOL( @@ -298,19 +304,13 @@ ffi::Error RaggedDotBwdCudaImpl( } if (g_local == 0 || m == 0 || n == 0 || k == 0) return ffi::Error::Success(); - if (g_local > 1024) return ffi::Error::InvalidArgument("group_count must be <= 1024."); - auto data = GroupedGemmData::Allocate(scratch, g_local); - if (!data) return ffi::Error::Internal("Failed to allocate grouped GEMM slab."); - - prepare_grouped_gemm<<<1, g_local, 0, stream>>>( + return ExecuteGroupedGemm( + stream, scratch, reinterpret_cast(lhs.typed_data()), reinterpret_cast(grad.typed_data()), reinterpret_cast(d_rhs->typed_data()), - group_offsets_cumsum.typed_data(), group_offset.typed_data(), k, n, *data); - - auto args = data->MakeArgs(g_local); - return RunGroupedGemm(stream, scratch, args); + 
group_offsets_cumsum.typed_data(), group_offset.typed_data(), g_local, k, n); } XLA_FFI_DEFINE_HANDLER_SYMBOL( From eb5e004b36715a705104686d8dc7ba0cda105405 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 00:43:03 -0800 Subject: [PATCH 060/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 5 ----- 1 file changed, 5 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 799a3c766..4d6de7450 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -249,8 +249,6 @@ ffi::Error RaggedDotCudaImpl( return ffi::Error::Internal("Failed to zero output."); } - if (g_local == 0 || m == 0 || n == 0 || k == 0) return ffi::Error::Success(); - return ExecuteGroupedGemm( stream, scratch, reinterpret_cast(lhs.typed_data()), @@ -294,7 +292,6 @@ ffi::Error RaggedDotBwdCudaImpl( return ffi::Error::InvalidArgument("d_rhs shape must be [G, K, N]."); } - int32_t m = static_cast(lhs_dims[0]); int32_t k = static_cast(lhs_dims[1]); int32_t n = static_cast(grad_dims[1]); int32_t g_local = static_cast(d_rhs_dims[0]); @@ -303,8 +300,6 @@ ffi::Error RaggedDotBwdCudaImpl( return ffi::Error::Internal("Failed to zero d_rhs output."); } - if (g_local == 0 || m == 0 || n == 0 || k == 0) return ffi::Error::Success(); - return ExecuteGroupedGemm( stream, scratch, reinterpret_cast(lhs.typed_data()), From 1b19afadec155457d53192c4401d3662a7a5bb96 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 01:18:27 -0800 Subject: [PATCH 061/106] update build --- skyrl-tx/pyproject.toml | 15 +++-- skyrl-tx/tx/ffi/README.md | 31 ++++++---- skyrl-tx/tx/ffi/build.py | 76 +++++++++++++++++++++++++ skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh | 49 ---------------- 4 files changed, 106 insertions(+), 65 deletions(-) create mode 100644 skyrl-tx/tx/ffi/build.py delete mode 100644 skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh diff --git a/skyrl-tx/pyproject.toml b/skyrl-tx/pyproject.toml index 8ec58d243..d5a5eea50 100644 --- a/skyrl-tx/pyproject.toml +++ b/skyrl-tx/pyproject.toml @@ -1,6 +1,6 @@ [build-system] -requires = ["setuptools"] -build-backend = "setuptools.build_meta" +requires = ["hatchling"] +build-backend = "hatchling.build" [project] name = "skyrl-tx" @@ -67,11 +67,14 @@ dev = [ "alembic", ] -[tool.setuptools] -include-package-data = true +[tool.hatch.version] +path = "tx/__init__.py" -[tool.setuptools.dynamic] -version = {attr = "tx.__version__"} +[tool.hatch.build.hooks.custom] +path = "tx/ffi/build.py" [project.scripts] tx = "tx.run.main:app" + +[tool.uv.scripts] +build-ffi = "python -c 'from tx.ffi.build import build_ragged_dot; build_ragged_dot()'" diff --git a/skyrl-tx/tx/ffi/README.md b/skyrl-tx/tx/ffi/README.md index 681f34f9b..6d42445e0 100644 --- a/skyrl-tx/tx/ffi/README.md +++ b/skyrl-tx/tx/ffi/README.md @@ -1,15 +1,26 @@ -Build (Linux + CUDA) +# CUDA FFI Extensions -1) Build the shared library (requires CUDA nvcc with C++17 support): +## Building + +The CUDA extension is built automatically when creating a wheel: + +```bash +uv build ``` -export CUTLASS_DIR=/path/to/cutlass -export NVCC_ARCH=sm_90a # for H100, adjust if needed -tx/ffi/build_ragged_dot_ffi.sh + +For development, build manually: + +```bash +uv run build-ffi ``` -2) Make the shared library discoverable: -- Copy `tx/ffi/_build/libragged_dot_ffi.so` to `tx/ffi/libragged_dot_ffi.so`, or -- Set `TX_RAGGED_DOT_FFI_PATH=/path/to/libragged_dot_ffi.so`. 
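
To make the build and dtype notes concrete, here is a minimal usage sketch of the wrapper once `libragged_dot_ffi.so` is importable. Shapes and values are illustrative only and assume a CUDA device; the call mirrors how the benchmark added later in this series drives the kernel (bfloat16 tensors, int32 group metadata).

```python
import jax
import jax.numpy as jnp

from tx.ffi import ragged_dot_ffi, ragged_dot_ffi_available

assert ragged_dot_ffi_available(), "shared library not found"

M, K, N, G = 32, 64, 64, 2  # illustrative sizes; K and N kept alignment-friendly
key = jax.random.PRNGKey(0)
lhs = jax.random.normal(key, (M, K), dtype=jnp.bfloat16)     # tokens
rhs = jax.random.normal(key, (G, K, N), dtype=jnp.bfloat16)  # per-expert weights
group_sizes = jnp.array([20, 12], dtype=jnp.int32)           # sums to M
group_offset = jnp.array([0], dtype=jnp.int32)               # first local expert

out = ragged_dot_ffi(lhs, rhs, group_sizes, group_offset)    # [M, N], bfloat16
print(out.shape, out.dtype)
```
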
+### Environment Variables + +- `CUTLASS_DIR` - Path to CUTLASS checkout (optional, clones automatically if not set) +- `NVCC_BIN` - Path to nvcc (default: `nvcc`) +- `NVCC_ARCH` - CUDA architecture (default: `90a` for H100) + +### Notes -Notes: -- The FFI kernel expects bfloat16 inputs/outputs and int32 group metadata. +- Requires CUDA nvcc with C++17 support +- The FFI kernel expects bfloat16 inputs/outputs and int32 group metadata diff --git a/skyrl-tx/tx/ffi/build.py b/skyrl-tx/tx/ffi/build.py new file mode 100644 index 000000000..07af3fe5d --- /dev/null +++ b/skyrl-tx/tx/ffi/build.py @@ -0,0 +1,76 @@ +import os +import subprocess +import sys +import tempfile +from pathlib import Path + +CUTLASS_REPO = "https://github.com/NVIDIA/cutlass.git" +CUTLASS_TAG = "v4.3.5" + + +def get_cutlass_dir(tmpdir): + if cutlass_dir := os.environ.get("CUTLASS_DIR"): + return Path(cutlass_dir) + + cutlass_dir = Path(tmpdir) / "cutlass" + print(f"Cloning CUTLASS {CUTLASS_TAG}...") + subprocess.run( + ["git", "clone", "--depth=1", f"--branch={CUTLASS_TAG}", CUTLASS_REPO, str(cutlass_dir)], + check=True, + ) + return cutlass_dir + + +def build_ragged_dot(): + try: + import jaxlib + jax_include_dir = Path(jaxlib.__file__).parent / "include" + except ImportError: + print("jaxlib not installed, skipping ragged_dot_ffi build", file=sys.stderr) + return + + nvcc_bin = os.environ.get("NVCC_BIN", "nvcc") + nvcc_arch = os.environ.get("NVCC_ARCH", "90a") + + try: + subprocess.run([nvcc_bin, "--version"], check=True, capture_output=True) + except (subprocess.CalledProcessError, FileNotFoundError): + print(f"nvcc not found at {nvcc_bin}, skipping ragged_dot_ffi build", file=sys.stderr) + return + + ffi_dir = Path(__file__).parent + source_file = ffi_dir / "ragged_dot_ffi.cu" + output_file = ffi_dir / "libragged_dot_ffi.so" + + with tempfile.TemporaryDirectory() as tmpdir: + cutlass_dir = get_cutlass_dir(tmpdir) + + cmd = [ + nvcc_bin, "-O3", "-std=c++17", f"-arch=sm_{nvcc_arch}", + "--expt-relaxed-constexpr", + "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", + "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", + "-DCUTLASS_ENABLE_GDC_FOR_SM90=1", + "-shared", "-Xcompiler", "-fPIC", + f"-I{jax_include_dir}", + f"-I{cutlass_dir}/include", + f"-I{cutlass_dir}/tools/util/include/", + str(source_file), "-o", str(output_file), + ] + + print(f"Building {output_file}...") + subprocess.run(cmd, check=True) + print(f"Built {output_file}") + + +try: + from hatchling.builders.hooks.plugin.interface import BuildHookInterface + + class CudaBuildHook(BuildHookInterface): + PLUGIN_NAME = "cuda_build" + + def initialize(self, version, build_data): + if self.target_name == "wheel": + build_ragged_dot() +except ImportError: + pass diff --git a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh b/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh deleted file mode 100644 index 5079c6fd8..000000000 --- a/skyrl-tx/tx/ffi/build_ragged_dot_ffi.sh +++ /dev/null @@ -1,49 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" - -if [[ -z "${CUTLASS_DIR:-}" ]]; then - echo "CUTLASS_DIR is not set. Point it to your CUTLASS checkout." >&2 - exit 1 -fi - -if [[ ! -d "${CUTLASS_DIR}" ]]; then - echo "CUTLASS_DIR does not exist: ${CUTLASS_DIR}" >&2 - exit 1 -fi - -JAX_INCLUDE_DIR="$(uv run --extra gpu - <<'PY' -import os -import jaxlib -print(os.path.join(os.path.dirname(jaxlib.__file__), "include")) -PY -)" - -NVCC_BIN="${NVCC_BIN:-nvcc}" -NVCC_ARCH="${NVCC_ARCH:-90a}" -if ! 
command -v "${NVCC_BIN}" >/dev/null 2>&1; then - echo "nvcc not found. Set NVCC_BIN or ensure CUDA is on PATH." >&2 - exit 1 -fi - -OUT_DIR="${SCRIPT_DIR}/_build" -mkdir -p "${OUT_DIR}" - -"${NVCC_BIN}" \ - -O3 \ - -std=c++17 \ - -arch=sm_${NVCC_ARCH} \ - --expt-relaxed-constexpr \ - -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1 \ - -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED \ - -DCUTLASS_ENABLE_GDC_FOR_SM90=1 \ - -shared \ - -Xcompiler -fPIC \ - -I"${JAX_INCLUDE_DIR}" \ - -I"${CUTLASS_DIR}/include" \ - -I"${CUTLASS_DIR}/tools/util/include/" \ - "${SCRIPT_DIR}/ragged_dot_ffi.cu" \ - -o "${OUT_DIR}/libragged_dot_ffi.so" - -echo "Built ${OUT_DIR}/libragged_dot_ffi.so" From 7aa2fe797b08aaa0fb9d8315a1fcb3392eafc020 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 01:23:58 -0800 Subject: [PATCH 062/106] update --- skyrl-tx/pyproject.toml | 3 --- skyrl-tx/tx/ffi/README.md | 6 ------ skyrl-tx/tx/ffi/build.py | 16 +++++++--------- 3 files changed, 7 insertions(+), 18 deletions(-) diff --git a/skyrl-tx/pyproject.toml b/skyrl-tx/pyproject.toml index d5a5eea50..a623bd5db 100644 --- a/skyrl-tx/pyproject.toml +++ b/skyrl-tx/pyproject.toml @@ -75,6 +75,3 @@ path = "tx/ffi/build.py" [project.scripts] tx = "tx.run.main:app" - -[tool.uv.scripts] -build-ffi = "python -c 'from tx.ffi.build import build_ragged_dot; build_ragged_dot()'" diff --git a/skyrl-tx/tx/ffi/README.md b/skyrl-tx/tx/ffi/README.md index 6d42445e0..c19665d12 100644 --- a/skyrl-tx/tx/ffi/README.md +++ b/skyrl-tx/tx/ffi/README.md @@ -8,12 +8,6 @@ The CUDA extension is built automatically when creating a wheel: uv build ``` -For development, build manually: - -```bash -uv run build-ffi -``` - ### Environment Variables - `CUTLASS_DIR` - Path to CUTLASS checkout (optional, clones automatically if not set) diff --git a/skyrl-tx/tx/ffi/build.py b/skyrl-tx/tx/ffi/build.py index 07af3fe5d..00685c9ee 100644 --- a/skyrl-tx/tx/ffi/build.py +++ b/skyrl-tx/tx/ffi/build.py @@ -63,14 +63,12 @@ def build_ragged_dot(): print(f"Built {output_file}") -try: - from hatchling.builders.hooks.plugin.interface import BuildHookInterface +from hatchling.builders.hooks.plugin.interface import BuildHookInterface - class CudaBuildHook(BuildHookInterface): - PLUGIN_NAME = "cuda_build" - def initialize(self, version, build_data): - if self.target_name == "wheel": - build_ragged_dot() -except ImportError: - pass +class CudaBuildHook(BuildHookInterface): + PLUGIN_NAME = "cuda_build" + + def initialize(self, version, build_data): + if self.target_name == "wheel": + build_ragged_dot() From be2129504b68ef0b9d87e949f733f2068f1bc93c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 01:24:52 -0800 Subject: [PATCH 063/106] update --- skyrl-tx/pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/skyrl-tx/pyproject.toml b/skyrl-tx/pyproject.toml index a623bd5db..151aec429 100644 --- a/skyrl-tx/pyproject.toml +++ b/skyrl-tx/pyproject.toml @@ -70,6 +70,9 @@ dev = [ [tool.hatch.version] path = "tx/__init__.py" +[tool.hatch.build.targets.wheel] +packages = ["tx"] + [tool.hatch.build.hooks.custom] path = "tx/ffi/build.py" From 5b67b4c06a4f7d7b4754bd806a397d9a0efa226d Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 01:26:09 -0800 Subject: [PATCH 064/106] update --- skyrl-tx/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-tx/pyproject.toml b/skyrl-tx/pyproject.toml index 151aec429..88713d8c1 100644 --- a/skyrl-tx/pyproject.toml +++ b/skyrl-tx/pyproject.toml @@ -1,5 +1,5 @@ [build-system] 
-requires = ["hatchling"] +requires = ["hatchling", "jaxlib"] build-backend = "hatchling.build" [project] From a29c3efdba1561039ef9ca606bbca855ca8e6317 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 01:29:27 -0800 Subject: [PATCH 065/106] simplify --- skyrl-tx/tx/ffi/ragged_dot_ffi.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.py b/skyrl-tx/tx/ffi/ragged_dot_ffi.py index 6ad3b6e4e..16f8e88f7 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.py +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.py @@ -2,26 +2,20 @@ import ctypes import functools -import os from pathlib import Path import jax import jax.numpy as jnp +_LIB_PATH = Path(__file__).resolve().parent / "libragged_dot_ffi.so" + @functools.lru_cache(maxsize=1) def _ensure_registered() -> bool: - if env_path := os.environ.get("TX_RAGGED_DOT_FFI_PATH"): - lib_path = Path(env_path) - else: - here = Path(__file__).resolve().parent - lib_path = next((p for p in [here / "libragged_dot_ffi.so", here / "ragged_dot_ffi.so"] if p.exists()), None) - - if not lib_path or not lib_path.exists(): + if not _LIB_PATH.exists(): return False - try: - lib = ctypes.cdll.LoadLibrary(str(lib_path)) + lib = ctypes.cdll.LoadLibrary(str(_LIB_PATH)) jax.ffi.register_ffi_target("ragged_dot_cuda", jax.ffi.pycapsule(lib.RaggedDotCuda), platform="CUDA") jax.ffi.register_ffi_target("ragged_dot_bwd_cuda", jax.ffi.pycapsule(lib.RaggedDotBwdCuda), platform="CUDA") return True From ce866805a03279db4e9694cc73a9f1523035d4dc Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 01:52:24 -0800 Subject: [PATCH 066/106] update --- skyrl-tx/tx/ffi/build.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-tx/tx/ffi/build.py b/skyrl-tx/tx/ffi/build.py index 00685c9ee..d188544c5 100644 --- a/skyrl-tx/tx/ffi/build.py +++ b/skyrl-tx/tx/ffi/build.py @@ -70,5 +70,5 @@ class CudaBuildHook(BuildHookInterface): PLUGIN_NAME = "cuda_build" def initialize(self, version, build_data): - if self.target_name == "wheel": + if self.target_name in ("wheel", "editable"): build_ragged_dot() From 8220f6e4bb95fbe67cca5f07f2fe44766a1afc53 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 02:00:07 -0800 Subject: [PATCH 067/106] update --- skyrl-tx/tx/ffi/build.py | 1 + 1 file changed, 1 insertion(+) diff --git a/skyrl-tx/tx/ffi/build.py b/skyrl-tx/tx/ffi/build.py index d188544c5..54794f2a0 100644 --- a/skyrl-tx/tx/ffi/build.py +++ b/skyrl-tx/tx/ffi/build.py @@ -56,6 +56,7 @@ def build_ragged_dot(): f"-I{cutlass_dir}/include", f"-I{cutlass_dir}/tools/util/include/", str(source_file), "-o", str(output_file), + "-lcuda", ] print(f"Building {output_file}...") From 86e3ad9c8e1eb964490a80e83157ff301db683a7 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 02:33:36 -0800 Subject: [PATCH 068/106] optimize --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 4d6de7450..6b3189465 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -57,9 +57,9 @@ constexpr int AlignmentOutput = 8; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; using TileShape = cute::Shape; -using ClusterShape = cute::Shape; -using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; -using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; 
+using ClusterShape = cute::Shape; +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using ProblemShape = cutlass::gemm::GroupProblemShape>; using CollectiveEpilogue = From d16a8254a963941a2b05b02318748f0bb9763128 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 02:54:19 -0800 Subject: [PATCH 069/106] try masking --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 13 +++--------- skyrl-tx/tx/ffi/ragged_dot_ffi.py | 35 +++++++++++++++++++------------ 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 6b3189465..75fb7b89f 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -156,6 +156,7 @@ enum class GemmDir { Fwd, Bwd }; // Unified prepare kernel for forward and backward passes // Fwd: A[M,K] @ B[G,K,N] -> out[M,N] (ragged by group) // Bwd: A.T[K,M] @ B[M,N] -> out[G,K,N] (lhs transposed via ColumnMajor layout) +// group_offsets_cumsum has length G+1 with cumsum[0]=0 (include_initial=True) template __global__ void prepare_grouped_gemm( const DtypeA* A, const DtypeB* B, DtypeOutput* out, @@ -164,8 +165,8 @@ __global__ void prepare_grouped_gemm( using Data = GroupedGemmData; int32_t tid = threadIdx.x; int32_t global = group_offset_ptr[0] + tid; - int32_t start = (global > 0) ? group_offsets_cumsum[global - 1] : 0; - int32_t m = group_offsets_cumsum[global] - start; + int32_t start = group_offsets_cumsum[global]; + int32_t m = group_offsets_cumsum[global + 1] - start; data.A_ptrs[tid] = A + static_cast(start) * k; if constexpr (Dir == GemmDir::Fwd) { @@ -245,10 +246,6 @@ ffi::Error RaggedDotCudaImpl( int32_t g_local = static_cast(rhs_dims[0]); int32_t n = static_cast(rhs_dims[2]); - if (cudaMemsetAsync(out->typed_data(), 0, static_cast(m) * n * sizeof(DtypeOutput), stream) != cudaSuccess) { - return ffi::Error::Internal("Failed to zero output."); - } - return ExecuteGroupedGemm( stream, scratch, reinterpret_cast(lhs.typed_data()), @@ -296,10 +293,6 @@ ffi::Error RaggedDotBwdCudaImpl( int32_t n = static_cast(grad_dims[1]); int32_t g_local = static_cast(d_rhs_dims[0]); - if (cudaMemsetAsync(d_rhs->typed_data(), 0, static_cast(g_local) * k * n * sizeof(DtypeOutput), stream) != cudaSuccess) { - return ffi::Error::Internal("Failed to zero d_rhs output."); - } - return ExecuteGroupedGemm( stream, scratch, reinterpret_cast(lhs.typed_data()), diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.py b/skyrl-tx/tx/ffi/ragged_dot_ffi.py index 16f8e88f7..947948517 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.py +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.py @@ -60,6 +60,14 @@ def _ragged_dot_bwd_ffi_call( return call(lhs, grad, group_offset, group_offsets_cumsum) +def _apply_mask(result: jax.Array, cumsum: jax.Array, group_offset: jax.Array, g_local: int) -> jax.Array: + """Zero out tokens outside the local group range [offset, offset+g_local).""" + offset = group_offset[0] + token_idx = jnp.arange(result.shape[0]) + valid = (token_idx >= cumsum[offset]) & (token_idx < cumsum[offset + g_local]) + return jnp.where(valid[:, None], result, 0) + + @jax.custom_vjp def ragged_dot( lhs: jax.Array, @@ -67,8 +75,9 @@ def ragged_dot( group_sizes: jax.Array, group_offset: jax.Array, ) -> jax.Array: - group_offsets_cumsum = jnp.cumsum(group_sizes, dtype=jnp.int32) - return _ragged_dot_ffi_call(lhs, rhs, group_offset, group_offsets_cumsum) + cumsum = jnp.cumulative_sum(group_sizes, 
include_initial=True).astype(jnp.int32) + result = _ragged_dot_ffi_call(lhs, rhs, group_offset, cumsum) + return _apply_mask(result, cumsum, group_offset, rhs.shape[0]) def _ragged_dot_fwd( @@ -77,22 +86,22 @@ def _ragged_dot_fwd( group_sizes: jax.Array, group_offset: jax.Array, ): - group_offsets_cumsum = jnp.cumsum(group_sizes, dtype=jnp.int32) - y = _ragged_dot_ffi_call(lhs, rhs, group_offset, group_offsets_cumsum) - return y, (lhs, rhs, group_offset, group_offsets_cumsum) + cumsum = jnp.cumulative_sum(group_sizes, include_initial=True).astype(jnp.int32) + result = _ragged_dot_ffi_call(lhs, rhs, group_offset, cumsum) + y = _apply_mask(result, cumsum, group_offset, rhs.shape[0]) + return y, (lhs, rhs, group_offset, cumsum) def _ragged_dot_bwd(res, g): - lhs, rhs, group_offset, group_offsets_cumsum = res + lhs, rhs, group_offset, cumsum = res + g_local = rhs.shape[0] - # d_lhs: g @ rhs.T with ragged grouping - same structure as forward - # g: [M, N], rhs: [G, K, N] -> rhs.T: [G, N, K], d_lhs: [M, K] - rhs_t = jnp.swapaxes(rhs, 1, 2) # [G, N, K] - d_lhs = _ragged_dot_ffi_call(g, rhs_t, group_offset, group_offsets_cumsum) + # d_lhs: g @ rhs.T with ragged grouping + rhs_t = jnp.swapaxes(rhs, 1, 2) + d_lhs = _apply_mask(_ragged_dot_ffi_call(g, rhs_t, group_offset, cumsum), cumsum, group_offset, g_local) - # d_rhs: lhs.T @ g accumulated per group -> [G, K, N] - g_local = rhs.shape[0] - d_rhs = _ragged_dot_bwd_ffi_call(lhs, g, group_offset, group_offsets_cumsum, g_local) + # d_rhs: lhs.T @ g accumulated per group + d_rhs = _ragged_dot_bwd_ffi_call(lhs, g, group_offset, cumsum, g_local) return d_lhs, d_rhs, None, None From e346551080a834ee7fbecf58de0e7aa1090823c0 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 03:09:49 -0800 Subject: [PATCH 070/106] revert --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 8 ++++++++ skyrl-tx/tx/ffi/ragged_dot_ffi.py | 16 +++------------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 75fb7b89f..5e7af8dab 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -246,6 +246,10 @@ ffi::Error RaggedDotCudaImpl( int32_t g_local = static_cast(rhs_dims[0]); int32_t n = static_cast(rhs_dims[2]); + if (cudaMemsetAsync(out->typed_data(), 0, static_cast(m) * n * sizeof(DtypeOutput), stream) != cudaSuccess) { + return ffi::Error::Internal("Failed to zero output."); + } + return ExecuteGroupedGemm( stream, scratch, reinterpret_cast(lhs.typed_data()), @@ -293,6 +297,10 @@ ffi::Error RaggedDotBwdCudaImpl( int32_t n = static_cast(grad_dims[1]); int32_t g_local = static_cast(d_rhs_dims[0]); + if (cudaMemsetAsync(d_rhs->typed_data(), 0, static_cast(g_local) * k * n * sizeof(DtypeOutput), stream) != cudaSuccess) { + return ffi::Error::Internal("Failed to zero d_rhs output."); + } + return ExecuteGroupedGemm( stream, scratch, reinterpret_cast(lhs.typed_data()), diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.py b/skyrl-tx/tx/ffi/ragged_dot_ffi.py index 947948517..054a4bb71 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.py +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.py @@ -60,14 +60,6 @@ def _ragged_dot_bwd_ffi_call( return call(lhs, grad, group_offset, group_offsets_cumsum) -def _apply_mask(result: jax.Array, cumsum: jax.Array, group_offset: jax.Array, g_local: int) -> jax.Array: - """Zero out tokens outside the local group range [offset, offset+g_local).""" - offset = group_offset[0] - token_idx = jnp.arange(result.shape[0]) - valid = (token_idx >= 
cumsum[offset]) & (token_idx < cumsum[offset + g_local]) - return jnp.where(valid[:, None], result, 0) - - @jax.custom_vjp def ragged_dot( lhs: jax.Array, @@ -76,8 +68,7 @@ def ragged_dot( group_offset: jax.Array, ) -> jax.Array: cumsum = jnp.cumulative_sum(group_sizes, include_initial=True).astype(jnp.int32) - result = _ragged_dot_ffi_call(lhs, rhs, group_offset, cumsum) - return _apply_mask(result, cumsum, group_offset, rhs.shape[0]) + return _ragged_dot_ffi_call(lhs, rhs, group_offset, cumsum) def _ragged_dot_fwd( @@ -87,8 +78,7 @@ def _ragged_dot_fwd( group_offset: jax.Array, ): cumsum = jnp.cumulative_sum(group_sizes, include_initial=True).astype(jnp.int32) - result = _ragged_dot_ffi_call(lhs, rhs, group_offset, cumsum) - y = _apply_mask(result, cumsum, group_offset, rhs.shape[0]) + y = _ragged_dot_ffi_call(lhs, rhs, group_offset, cumsum) return y, (lhs, rhs, group_offset, cumsum) @@ -98,7 +88,7 @@ def _ragged_dot_bwd(res, g): # d_lhs: g @ rhs.T with ragged grouping rhs_t = jnp.swapaxes(rhs, 1, 2) - d_lhs = _apply_mask(_ragged_dot_ffi_call(g, rhs_t, group_offset, cumsum), cumsum, group_offset, g_local) + d_lhs = _ragged_dot_ffi_call(g, rhs_t, group_offset, cumsum) # d_rhs: lhs.T @ g accumulated per group d_rhs = _ragged_dot_bwd_ffi_call(lhs, g, group_offset, cumsum, g_local) From b758b2b23e468cb779d2cd34b63632959a209c10 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 03:21:00 -0800 Subject: [PATCH 071/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 5e7af8dab..39bfc6d41 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -56,7 +56,7 @@ constexpr int AlignmentOutput = 8; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; -using TileShape = cute::Shape; +using TileShape = cute::Shape; using ClusterShape = cute::Shape; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; From 2e2fff072a694785266afd96f52c0c6dd30e4148 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 03:30:05 -0800 Subject: [PATCH 072/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 39bfc6d41..6075776c0 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -56,7 +56,7 @@ constexpr int AlignmentOutput = 8; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; -using TileShape = cute::Shape; +using TileShape = cute::Shape; using ClusterShape = cute::Shape; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; From a2a390d3c7054dd31ae67221fa35822dc6b8e21e Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 14:03:19 -0800 Subject: [PATCH 073/106] add benchmark script --- skyrl-tx/benchmarks/bench_ragged_dot.py | 184 ++++++++++++++++++++++++ 1 file changed, 184 insertions(+) create mode 100644 skyrl-tx/benchmarks/bench_ragged_dot.py diff --git a/skyrl-tx/benchmarks/bench_ragged_dot.py b/skyrl-tx/benchmarks/bench_ragged_dot.py new file mode 100644 index 000000000..e747f18ec --- /dev/null +++ b/skyrl-tx/benchmarks/bench_ragged_dot.py @@ 
-0,0 +1,184 @@ +"""Benchmark ragged_dot CUTLASS kernel with Qwen3-30B-A3B MoE shapes.""" + +import argparse +import time + +import jax +import jax.numpy as jnp +from jax import lax + +from tx.ffi import ragged_dot_ffi, ragged_dot_ffi_available + + +def generate_group_sizes(num_tokens: int, num_experts: int, key: jax.Array) -> jax.Array: + """Generate random group sizes that sum to num_tokens.""" + # Random assignment of tokens to experts + assignments = jax.random.randint(key, (num_tokens,), 0, num_experts) + return jnp.bincount(assignments, length=num_experts).astype(jnp.int32) + + +def benchmark_forward( + num_tokens: int, + hidden_size: int, + intermediate_size: int, + num_experts: int, + num_warmup: int = 5, + num_iters: int = 20, + use_ffi: bool = True, +): + """Benchmark forward pass: lhs[M, K] @ rhs[G, K, N] -> out[M, N].""" + key = jax.random.PRNGKey(42) + k1, k2, k3 = jax.random.split(key, 3) + + lhs = jax.random.normal(k1, (num_tokens, hidden_size), dtype=jnp.bfloat16) + rhs = jax.random.normal(k2, (num_experts, hidden_size, intermediate_size), dtype=jnp.bfloat16) + group_sizes = generate_group_sizes(num_tokens, num_experts, k3) + group_offset = jnp.array([0], dtype=jnp.int32) + + if use_ffi: + fn = lambda: ragged_dot_ffi(lhs, rhs, group_sizes, group_offset) + else: + fn = lambda: lax.ragged_dot(lhs, rhs, group_sizes) + + # Warmup + for _ in range(num_warmup): + out = fn() + out.block_until_ready() + + # Benchmark + start = time.perf_counter() + for _ in range(num_iters): + out = fn() + out.block_until_ready() + elapsed = time.perf_counter() - start + + # FLOPs: 2 * M * K * N (matmul FLOPs) + flops = 2 * num_tokens * hidden_size * intermediate_size + tflops = (flops * num_iters / elapsed) / 1e12 + + return elapsed / num_iters, tflops + + +def benchmark_backward( + num_tokens: int, + hidden_size: int, + intermediate_size: int, + num_experts: int, + num_warmup: int = 5, + num_iters: int = 20, + use_ffi: bool = True, +): + """Benchmark backward pass through ragged_dot.""" + key = jax.random.PRNGKey(42) + k1, k2, k3 = jax.random.split(key, 3) + + lhs = jax.random.normal(k1, (num_tokens, hidden_size), dtype=jnp.bfloat16) + rhs = jax.random.normal(k2, (num_experts, hidden_size, intermediate_size), dtype=jnp.bfloat16) + group_sizes = generate_group_sizes(num_tokens, num_experts, k3) + group_offset = jnp.array([0], dtype=jnp.int32) + + if use_ffi: + def forward(lhs, rhs): + return ragged_dot_ffi(lhs, rhs, group_sizes, group_offset).sum() + else: + def forward(lhs, rhs): + return lax.ragged_dot(lhs, rhs, group_sizes).sum() + + grad_fn = jax.grad(forward, argnums=(0, 1)) + + # Warmup + for _ in range(num_warmup): + d_lhs, d_rhs = grad_fn(lhs, rhs) + d_lhs.block_until_ready() + d_rhs.block_until_ready() + + # Benchmark + start = time.perf_counter() + for _ in range(num_iters): + d_lhs, d_rhs = grad_fn(lhs, rhs) + d_lhs.block_until_ready() + d_rhs.block_until_ready() + elapsed = time.perf_counter() - start + + # Backward FLOPs: d_lhs = grad @ rhs.T (2*M*N*K) + d_rhs = lhs.T @ grad (2*K*M*N) + # Total: 4 * M * K * N + flops = 4 * num_tokens * hidden_size * intermediate_size + tflops = (flops * num_iters / elapsed) / 1e12 + + return elapsed / num_iters, tflops + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark ragged_dot CUTLASS kernel") + parser.add_argument("--num-tokens", type=int, default=8192, help="Number of tokens (M)") + parser.add_argument("--num-experts", type=int, default=128, help="Number of experts (G)") + parser.add_argument("--hidden-size", type=int, 
default=2048, help="Hidden size (K)") + parser.add_argument("--intermediate-size", type=int, default=768, help="MoE intermediate size (N)") + parser.add_argument("--num-warmup", type=int, default=5, help="Warmup iterations") + parser.add_argument("--num-iters", type=int, default=20, help="Benchmark iterations") + parser.add_argument("--backward-only", action="store_true", help="Only benchmark backward pass") + parser.add_argument("--forward-only", action="store_true", help="Only benchmark forward pass") + args = parser.parse_args() + + print("Ragged Dot Benchmark (Qwen3-30B-A3B MoE shapes)") + print("=" * 60) + print(f"CUTLASS FFI available: {ragged_dot_ffi_available()}") + print(f"JAX backend: {jax.default_backend()}") + print(f"Devices: {jax.device_count()}") + print() + print(f"Config:") + print(f" num_tokens (M): {args.num_tokens}") + print(f" num_experts (G): {args.num_experts}") + print(f" hidden_size (K): {args.hidden_size}") + print(f" intermediate_size (N): {args.intermediate_size}") + print(f" warmup/iters: {args.num_warmup}/{args.num_iters}") + print() + + run_forward = not args.backward_only + run_backward = not args.forward_only + + if run_forward: + print("Forward Pass (lhs[M,K] @ rhs[G,K,N] -> out[M,N])") + print("-" * 60) + + if ragged_dot_ffi_available(): + ffi_time, ffi_tflops = benchmark_forward( + args.num_tokens, args.hidden_size, args.intermediate_size, + args.num_experts, args.num_warmup, args.num_iters, use_ffi=True + ) + print(f" CUTLASS FFI: {ffi_time*1000:8.3f} ms {ffi_tflops:8.2f} TFLOPS") + + jax_time, jax_tflops = benchmark_forward( + args.num_tokens, args.hidden_size, args.intermediate_size, + args.num_experts, args.num_warmup, args.num_iters, use_ffi=False + ) + print(f" JAX ragged: {jax_time*1000:8.3f} ms {jax_tflops:8.2f} TFLOPS") + + if ragged_dot_ffi_available(): + print(f" Speedup: {jax_time/ffi_time:.2f}x") + print() + + if run_backward: + print("Backward Pass (grad wrt lhs and rhs)") + print("-" * 60) + + if ragged_dot_ffi_available(): + ffi_time, ffi_tflops = benchmark_backward( + args.num_tokens, args.hidden_size, args.intermediate_size, + args.num_experts, args.num_warmup, args.num_iters, use_ffi=True + ) + print(f" CUTLASS FFI: {ffi_time*1000:8.3f} ms {ffi_tflops:8.2f} TFLOPS") + + jax_time, jax_tflops = benchmark_backward( + args.num_tokens, args.hidden_size, args.intermediate_size, + args.num_experts, args.num_warmup, args.num_iters, use_ffi=False + ) + print(f" JAX ragged: {jax_time*1000:8.3f} ms {jax_tflops:8.2f} TFLOPS") + + if ragged_dot_ffi_available(): + print(f" Speedup: {jax_time/ffi_time:.2f}x") + print() + + +if __name__ == "__main__": + main() From 729d816818e38932196d64c7fa2a4824a5bd2e4b Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 14:07:19 -0800 Subject: [PATCH 074/106] add tuning script --- skyrl-tx/benchmarks/sweep_tile_sizes.py | 165 ++++++++++++++++++++++++ 1 file changed, 165 insertions(+) create mode 100644 skyrl-tx/benchmarks/sweep_tile_sizes.py diff --git a/skyrl-tx/benchmarks/sweep_tile_sizes.py b/skyrl-tx/benchmarks/sweep_tile_sizes.py new file mode 100644 index 000000000..d091aa623 --- /dev/null +++ b/skyrl-tx/benchmarks/sweep_tile_sizes.py @@ -0,0 +1,165 @@ +"""Sweep tile sizes for ragged_dot CUTLASS kernel optimization.""" + +import subprocess +import sys +import re +from pathlib import Path + +CUDA_FILE = Path(__file__).parent.parent / "tx/ffi/ragged_dot_ffi.cu" +SO_FILE = Path(__file__).parent.parent / "tx/ffi/libragged_dot_ffi.so" + +# Tile configurations to test: (M, N, K) +# Constraints: 
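
For context on the TFLOPS figures the benchmark prints, the FLOP accounting above works out as follows for the default Qwen3-30B-A3B shapes; this is plain arithmetic, independent of the kernel.

```python
M, K, N = 8192, 2048, 768      # num_tokens, hidden_size, intermediate_size defaults

fwd_flops = 2 * M * K * N      # one multiply-add per (m, k, n) triple
bwd_flops = 4 * M * K * N      # d_lhs (2*M*N*K) plus d_rhs (2*K*M*N)

print(fwd_flops / 1e9)         # ~25.8 GFLOP per forward call
print(bwd_flops / 1e9)         # ~51.5 GFLOP per backward call

# A forward pass taking 1 ms at these shapes corresponds to ~25.8 TFLOPS.
print(fwd_flops / 1e-3 / 1e12)
```
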
dimensions should be powers of 2 or multiples that work with SM90 +TILE_CONFIGS = [ + (128, 256, 64), # current + (128, 128, 64), + (64, 128, 64), + (64, 256, 64), + (256, 128, 64), + (128, 64, 128), + (64, 64, 128), + (128, 128, 128), +] + + +def set_tile_shape(m: int, n: int, k: int) -> None: + """Update the TileShape in the CUDA file.""" + content = CUDA_FILE.read_text() + + # Replace the TileShape definition + pattern = r'using TileShape = cute::Shape;' + replacement = f'using TileShape = cute::Shape;' + + new_content = re.sub(pattern, replacement, content) + CUDA_FILE.write_text(new_content) + print(f"Set TileShape to ({m}, {n}, {k})") + + +def rebuild_kernel() -> bool: + """Rebuild the CUTLASS kernel.""" + # Remove old .so to force rebuild + if SO_FILE.exists(): + SO_FILE.unlink() + + # Rebuild using the build script directly + result = subprocess.run( + [sys.executable, "-c", "from tx.ffi.build import build_ragged_dot; build_ragged_dot()"], + capture_output=True, + text=True, + cwd=CUDA_FILE.parent.parent.parent, + ) + + if result.returncode != 0: + print(f"Build failed: {result.stderr}") + return False + + return SO_FILE.exists() + + +def run_benchmark(num_tokens: int = 8192) -> dict | None: + """Run the benchmark and parse results.""" + # Need to run in a fresh process to reload the .so + result = subprocess.run( + [sys.executable, "benchmarks/bench_ragged_dot.py", + "--num-tokens", str(num_tokens), + "--num-warmup", "3", + "--num-iters", "10"], + capture_output=True, + text=True, + cwd=CUDA_FILE.parent.parent.parent, + ) + + output = result.stdout + if result.returncode != 0: + print(f"Benchmark failed: {result.stderr}") + return None + + # Parse CUTLASS FFI results + results = {} + + # Forward pass + match = re.search(r'CUTLASS FFI:\s+([\d.]+)\s+ms\s+([\d.]+)\s+TFLOPS', output) + if match: + results['fwd_ms'] = float(match.group(1)) + results['fwd_tflops'] = float(match.group(2)) + + # Find backward pass (second occurrence) + matches = list(re.finditer(r'CUTLASS FFI:\s+([\d.]+)\s+ms\s+([\d.]+)\s+TFLOPS', output)) + if len(matches) >= 2: + results['bwd_ms'] = float(matches[1].group(1)) + results['bwd_tflops'] = float(matches[1].group(2)) + + return results if results else None + + +def main(): + import argparse + parser = argparse.ArgumentParser(description="Sweep tile sizes for CUTLASS kernel") + parser.add_argument("--num-tokens", type=int, default=8192, help="Number of tokens") + parser.add_argument("--dry-run", action="store_true", help="Only show configs, don't run") + args = parser.parse_args() + + print("CUTLASS Tile Size Sweep") + print("=" * 70) + print(f"Qwen3-30B-A3B shapes: K=2048, N=768, M={args.num_tokens}") + print() + + if args.dry_run: + print("Tile configurations to test:") + for m, n, k in TILE_CONFIGS: + print(f" ({m}, {n}, {k})") + return + + results = [] + + for m, n, k in TILE_CONFIGS: + print(f"\n{'='*70}") + print(f"Testing TileShape({m}, {n}, {k})") + print("-" * 70) + + set_tile_shape(m, n, k) + + if not rebuild_kernel(): + print(" FAILED to build") + results.append((m, n, k, None)) + continue + + bench_results = run_benchmark(args.num_tokens) + if bench_results: + print(f" Forward: {bench_results.get('fwd_ms', 'N/A'):>8.3f} ms {bench_results.get('fwd_tflops', 'N/A'):>8.2f} TFLOPS") + print(f" Backward: {bench_results.get('bwd_ms', 'N/A'):>8.3f} ms {bench_results.get('bwd_tflops', 'N/A'):>8.2f} TFLOPS") + results.append((m, n, k, bench_results)) + else: + print(" FAILED to benchmark") + results.append((m, n, k, None)) + + # Summary + 
print(f"\n{'='*70}") + print("SUMMARY") + print("=" * 70) + print(f"{'TileShape':<20} {'Fwd (ms)':>10} {'Fwd TFLOPS':>12} {'Bwd (ms)':>10} {'Bwd TFLOPS':>12}") + print("-" * 70) + + for m, n, k, res in results: + tile_str = f"({m}, {n}, {k})" + if res: + fwd_ms = f"{res.get('fwd_ms', 0):.3f}" + fwd_tf = f"{res.get('fwd_tflops', 0):.2f}" + bwd_ms = f"{res.get('bwd_ms', 0):.3f}" + bwd_tf = f"{res.get('bwd_tflops', 0):.2f}" + print(f"{tile_str:<20} {fwd_ms:>10} {fwd_tf:>12} {bwd_ms:>10} {bwd_tf:>12}") + else: + print(f"{tile_str:<20} {'FAILED':>10}") + + # Find best + valid_results = [(m, n, k, r) for m, n, k, r in results if r and 'fwd_tflops' in r] + if valid_results: + best_fwd = max(valid_results, key=lambda x: x[3]['fwd_tflops']) + best_bwd = max(valid_results, key=lambda x: x[3].get('bwd_tflops', 0)) + print() + print(f"Best forward: ({best_fwd[0]}, {best_fwd[1]}, {best_fwd[2]}) - {best_fwd[3]['fwd_tflops']:.2f} TFLOPS") + print(f"Best backward: ({best_bwd[0]}, {best_bwd[1]}, {best_bwd[2]}) - {best_bwd[3].get('bwd_tflops', 0):.2f} TFLOPS") + + +if __name__ == "__main__": + main() From 5c1bed3c93cb7729b0ecfeefda91ed8f926096ed Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 14:09:45 -0800 Subject: [PATCH 075/106] fix --- skyrl-tx/benchmarks/sweep_tile_sizes.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/benchmarks/sweep_tile_sizes.py b/skyrl-tx/benchmarks/sweep_tile_sizes.py index d091aa623..bc28bdc43 100644 --- a/skyrl-tx/benchmarks/sweep_tile_sizes.py +++ b/skyrl-tx/benchmarks/sweep_tile_sizes.py @@ -41,9 +41,10 @@ def rebuild_kernel() -> bool: if SO_FILE.exists(): SO_FILE.unlink() - # Rebuild using the build script directly + # Rebuild using uv with hatchling result = subprocess.run( - [sys.executable, "-c", "from tx.ffi.build import build_ragged_dot; build_ragged_dot()"], + ["uv", "run", "--with", "hatchling", "-c", + "from tx.ffi.build import build_ragged_dot; build_ragged_dot()"], capture_output=True, text=True, cwd=CUDA_FILE.parent.parent.parent, From 5cf35215d9ba41f7d359bfef89fc7897c87cf1db Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 14:10:24 -0800 Subject: [PATCH 076/106] fix --- skyrl-tx/benchmarks/sweep_tile_sizes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-tx/benchmarks/sweep_tile_sizes.py b/skyrl-tx/benchmarks/sweep_tile_sizes.py index bc28bdc43..b81f38e89 100644 --- a/skyrl-tx/benchmarks/sweep_tile_sizes.py +++ b/skyrl-tx/benchmarks/sweep_tile_sizes.py @@ -43,7 +43,7 @@ def rebuild_kernel() -> bool: # Rebuild using uv with hatchling result = subprocess.run( - ["uv", "run", "--with", "hatchling", "-c", + ["uv", "run", "--with", "hatchling", "python", "-c", "from tx.ffi.build import build_ragged_dot; build_ragged_dot()"], capture_output=True, text=True, From 2fb40bc26ac410d998dac4800f32bb3f8ca4a52f Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 14:25:21 -0800 Subject: [PATCH 077/106] update tile size --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 6075776c0..dd6286208 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -56,7 +56,7 @@ constexpr int AlignmentOutput = 8; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; -using TileShape = cute::Shape; +using TileShape = cute::Shape; using ClusterShape = cute::Shape; using KernelSchedule = 
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; From c23c9e742402dd3e6cba9232ed4b176530fadfda Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 14:30:48 -0800 Subject: [PATCH 078/106] fine grained sweeping (works before) --- skyrl-tx/benchmarks/sweep_tile_sizes.py | 80 +++++++++++++++---------- 1 file changed, 47 insertions(+), 33 deletions(-) diff --git a/skyrl-tx/benchmarks/sweep_tile_sizes.py b/skyrl-tx/benchmarks/sweep_tile_sizes.py index b81f38e89..fae374108 100644 --- a/skyrl-tx/benchmarks/sweep_tile_sizes.py +++ b/skyrl-tx/benchmarks/sweep_tile_sizes.py @@ -11,7 +11,7 @@ # Tile configurations to test: (M, N, K) # Constraints: dimensions should be powers of 2 or multiples that work with SM90 TILE_CONFIGS = [ - (128, 256, 64), # current + (128, 256, 64), (128, 128, 64), (64, 128, 64), (64, 256, 64), @@ -21,6 +21,12 @@ (128, 128, 128), ] +# Qwen3-30B-A3B MoE shapes +MOE_SHAPES = { + "gate_up": {"hidden_size": 2048, "intermediate_size": 768}, # K=2048, N=768 + "down": {"hidden_size": 768, "intermediate_size": 2048}, # K=768, N=2048 +} + def set_tile_shape(m: int, n: int, k: int) -> None: """Update the TileShape in the CUDA file.""" @@ -57,14 +63,17 @@ def rebuild_kernel() -> bool: return SO_FILE.exists() -def run_benchmark(num_tokens: int = 8192) -> dict | None: +def run_benchmark(num_tokens: int, hidden_size: int, intermediate_size: int) -> dict | None: """Run the benchmark and parse results.""" # Need to run in a fresh process to reload the .so result = subprocess.run( [sys.executable, "benchmarks/bench_ragged_dot.py", "--num-tokens", str(num_tokens), + "--hidden-size", str(hidden_size), + "--intermediate-size", str(intermediate_size), "--num-warmup", "3", - "--num-iters", "10"], + "--num-iters", "10", + "--forward-only"], capture_output=True, text=True, cwd=CUDA_FILE.parent.parent.parent, @@ -84,12 +93,6 @@ def run_benchmark(num_tokens: int = 8192) -> dict | None: results['fwd_ms'] = float(match.group(1)) results['fwd_tflops'] = float(match.group(2)) - # Find backward pass (second occurrence) - matches = list(re.finditer(r'CUTLASS FFI:\s+([\d.]+)\s+ms\s+([\d.]+)\s+TFLOPS', output)) - if len(matches) >= 2: - results['bwd_ms'] = float(matches[1].group(1)) - results['bwd_tflops'] = float(matches[1].group(2)) - return results if results else None @@ -102,7 +105,9 @@ def main(): print("CUTLASS Tile Size Sweep") print("=" * 70) - print(f"Qwen3-30B-A3B shapes: K=2048, N=768, M={args.num_tokens}") + print(f"Qwen3-30B-A3B MoE shapes, M={args.num_tokens}") + print(f" gate_up: K=2048, N=768") + print(f" down: K=768, N=2048") print() if args.dry_run: @@ -122,44 +127,53 @@ def main(): if not rebuild_kernel(): print(" FAILED to build") - results.append((m, n, k, None)) + results.append((m, n, k, None, None)) continue - bench_results = run_benchmark(args.num_tokens) - if bench_results: - print(f" Forward: {bench_results.get('fwd_ms', 'N/A'):>8.3f} ms {bench_results.get('fwd_tflops', 'N/A'):>8.2f} TFLOPS") - print(f" Backward: {bench_results.get('bwd_ms', 'N/A'):>8.3f} ms {bench_results.get('bwd_tflops', 'N/A'):>8.2f} TFLOPS") - results.append((m, n, k, bench_results)) - else: - print(" FAILED to benchmark") - results.append((m, n, k, None)) + tile_results = {} + for shape_name, shape_cfg in MOE_SHAPES.items(): + bench_result = run_benchmark( + args.num_tokens, + shape_cfg["hidden_size"], + shape_cfg["intermediate_size"], + ) + tile_results[shape_name] = bench_result + if bench_result: + 
print(f" {shape_name:>8}: {bench_result['fwd_ms']:>8.3f} ms {bench_result['fwd_tflops']:>8.2f} TFLOPS") + else: + print(f" {shape_name:>8}: FAILED") + + results.append((m, n, k, tile_results.get("gate_up"), tile_results.get("down"))) # Summary print(f"\n{'='*70}") print("SUMMARY") print("=" * 70) - print(f"{'TileShape':<20} {'Fwd (ms)':>10} {'Fwd TFLOPS':>12} {'Bwd (ms)':>10} {'Bwd TFLOPS':>12}") + print(f"{'TileShape':<20} {'gate_up ms':>12} {'TFLOPS':>8} {'down ms':>12} {'TFLOPS':>8} {'Total ms':>10}") print("-" * 70) - for m, n, k, res in results: + for m, n, k, gate_up, down in results: tile_str = f"({m}, {n}, {k})" - if res: - fwd_ms = f"{res.get('fwd_ms', 0):.3f}" - fwd_tf = f"{res.get('fwd_tflops', 0):.2f}" - bwd_ms = f"{res.get('bwd_ms', 0):.3f}" - bwd_tf = f"{res.get('bwd_tflops', 0):.2f}" - print(f"{tile_str:<20} {fwd_ms:>10} {fwd_tf:>12} {bwd_ms:>10} {bwd_tf:>12}") + if gate_up and down: + gu_ms = gate_up['fwd_ms'] + gu_tf = gate_up['fwd_tflops'] + dn_ms = down['fwd_ms'] + dn_tf = down['fwd_tflops'] + total_ms = gu_ms + dn_ms + print(f"{tile_str:<20} {gu_ms:>12.3f} {gu_tf:>8.2f} {dn_ms:>12.3f} {dn_tf:>8.2f} {total_ms:>10.3f}") else: - print(f"{tile_str:<20} {'FAILED':>10}") + print(f"{tile_str:<20} {'FAILED':>12}") # Find best - valid_results = [(m, n, k, r) for m, n, k, r in results if r and 'fwd_tflops' in r] + valid_results = [(m, n, k, gu, dn) for m, n, k, gu, dn in results if gu and dn] if valid_results: - best_fwd = max(valid_results, key=lambda x: x[3]['fwd_tflops']) - best_bwd = max(valid_results, key=lambda x: x[3].get('bwd_tflops', 0)) + best_gate_up = max(valid_results, key=lambda x: x[3]['fwd_tflops']) + best_down = max(valid_results, key=lambda x: x[4]['fwd_tflops']) + best_total = min(valid_results, key=lambda x: x[3]['fwd_ms'] + x[4]['fwd_ms']) print() - print(f"Best forward: ({best_fwd[0]}, {best_fwd[1]}, {best_fwd[2]}) - {best_fwd[3]['fwd_tflops']:.2f} TFLOPS") - print(f"Best backward: ({best_bwd[0]}, {best_bwd[1]}, {best_bwd[2]}) - {best_bwd[3].get('bwd_tflops', 0):.2f} TFLOPS") + print(f"Best gate_up: ({best_gate_up[0]}, {best_gate_up[1]}, {best_gate_up[2]}) - {best_gate_up[3]['fwd_tflops']:.2f} TFLOPS") + print(f"Best down: ({best_down[0]}, {best_down[1]}, {best_down[2]}) - {best_down[4]['fwd_tflops']:.2f} TFLOPS") + print(f"Best total: ({best_total[0]}, {best_total[1]}, {best_total[2]}) - {best_total[3]['fwd_ms'] + best_total[4]['fwd_ms']:.3f} ms") if __name__ == "__main__": From 2b1c3d6200bea5045e3cf3c41483069b2fc00e85 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 14:57:14 -0800 Subject: [PATCH 079/106] optimize --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 97 ++++++++++++++++++++----------- 1 file changed, 62 insertions(+), 35 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index dd6286208..72e89edb7 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -50,52 +50,62 @@ using DtypeAccum = float; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::RowMajor; using LayoutOutput = cutlass::layout::RowMajor; +using LayoutA_Bwd = cutlass::layout::ColumnMajor; constexpr int AlignmentA = 8; constexpr int AlignmentB = 8; constexpr int AlignmentOutput = 8; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; -using TileShape = cute::Shape; using ClusterShape = cute::Shape; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; using EpilogueSchedule = 
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using ProblemShape = cutlass::gemm::GroupProblemShape>; - -using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - DtypeAccum, DtypeAccum, void, LayoutOutput*, AlignmentOutput, - DtypeOutput, LayoutOutput*, AlignmentOutput, EpilogueSchedule, - cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; - -using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, DtypeA, LayoutA*, AlignmentA, - DtypeB, LayoutB*, AlignmentB, DtypeAccum, TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, - KernelSchedule>::CollectiveOp; - -using GemmKernel = cutlass::gemm::kernel::GemmUniversal; -using Gemm = cutlass::gemm::device::GemmUniversalAdapter; using ProblemShapeType = ProblemShape::UnderlyingProblemShape; -// Backward kernel for d_rhs: computes lhs.T @ grad per group -// Uses ColumnMajor for A to interpret row-major lhs as transposed -using LayoutA_Bwd = cutlass::layout::ColumnMajor; - -using CollectiveMainloop_Bwd = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, DtypeA, LayoutA_Bwd*, AlignmentA, - DtypeB, LayoutB*, AlignmentB, DtypeAccum, TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, - KernelSchedule>::CollectiveOp; +// Tile configurations optimized for different shapes: +// - SmallN (64, 64, 128): Best for K > N (e.g., gate_up: K=2048, N=768) +// - LargeN (64, 256, 64): Best for N > K (e.g., down: K=768, N=2048) +using TileShape_SmallN = cute::Shape; +using TileShape_LargeN = cute::Shape; + +// Forward kernel types for SmallN tile +template +struct GemmTypes { + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + DtypeAccum, DtypeAccum, void, LayoutOutput*, AlignmentOutput, + DtypeOutput, LayoutOutput*, AlignmentOutput, EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, DtypeA, LayoutA*, AlignmentA, + DtypeB, LayoutB*, AlignmentB, DtypeAccum, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using CollectiveMainloop_Bwd = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, DtypeA, LayoutA_Bwd*, AlignmentA, + DtypeB, LayoutB*, AlignmentB, DtypeAccum, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using GemmKernel_Bwd = cutlass::gemm::kernel::GemmUniversal; + using Gemm_Bwd = cutlass::gemm::device::GemmUniversalAdapter; +}; -using GemmKernel_Bwd = cutlass::gemm::kernel::GemmUniversal; -using Gemm_Bwd = cutlass::gemm::device::GemmUniversalAdapter; +// Instantiate kernel types for both tile configurations +using SmallN = GemmTypes; +using 
LargeN = GemmTypes; template static T* carve_aligned(char*& p, size_t count) { @@ -223,6 +233,23 @@ ffi::Error ExecuteGroupedGemm( return ffi::Error::Success(); } +// Dispatch to the best kernel based on K and N dimensions +template +ffi::Error DispatchGroupedGemm( + cudaStream_t stream, ffi::ScratchAllocator& scratch, + const DtypeA* A, const DtypeB* B, DtypeOutput* out, + const int32_t* group_offsets_cumsum, const int32_t* group_offset, + int32_t g_local, int32_t k, int32_t n) { + // Use LargeN tile (64, 256, 64) when N > K, otherwise SmallN tile (64, 64, 128) + if (n > k) { + return ExecuteGroupedGemm( + stream, scratch, A, B, out, group_offsets_cumsum, group_offset, g_local, k, n); + } else { + return ExecuteGroupedGemm( + stream, scratch, A, B, out, group_offsets_cumsum, group_offset, g_local, k, n); + } +} + ffi::Error RaggedDotCudaImpl( cudaStream_t stream, ffi::ScratchAllocator scratch, @@ -250,7 +277,7 @@ ffi::Error RaggedDotCudaImpl( return ffi::Error::Internal("Failed to zero output."); } - return ExecuteGroupedGemm( + return DispatchGroupedGemm( stream, scratch, reinterpret_cast(lhs.typed_data()), reinterpret_cast(rhs.typed_data()), @@ -301,7 +328,7 @@ ffi::Error RaggedDotBwdCudaImpl( return ffi::Error::Internal("Failed to zero d_rhs output."); } - return ExecuteGroupedGemm( + return DispatchGroupedGemm( stream, scratch, reinterpret_cast(lhs.typed_data()), reinterpret_cast(grad.typed_data()), From 22829588822ebcd1faea7cbcadd2faa2cadec35b Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 15:14:57 -0800 Subject: [PATCH 080/106] Revert to state at 2fb40bc2 (update tile size) Reverts changes from commits c23c9e74 and 2b1c3d62. Co-Authored-By: Claude Opus 4.5 --- skyrl-tx/benchmarks/sweep_tile_sizes.py | 80 +++++++++----------- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 97 +++++++++---------------- 2 files changed, 68 insertions(+), 109 deletions(-) diff --git a/skyrl-tx/benchmarks/sweep_tile_sizes.py b/skyrl-tx/benchmarks/sweep_tile_sizes.py index fae374108..b81f38e89 100644 --- a/skyrl-tx/benchmarks/sweep_tile_sizes.py +++ b/skyrl-tx/benchmarks/sweep_tile_sizes.py @@ -11,7 +11,7 @@ # Tile configurations to test: (M, N, K) # Constraints: dimensions should be powers of 2 or multiples that work with SM90 TILE_CONFIGS = [ - (128, 256, 64), + (128, 256, 64), # current (128, 128, 64), (64, 128, 64), (64, 256, 64), @@ -21,12 +21,6 @@ (128, 128, 128), ] -# Qwen3-30B-A3B MoE shapes -MOE_SHAPES = { - "gate_up": {"hidden_size": 2048, "intermediate_size": 768}, # K=2048, N=768 - "down": {"hidden_size": 768, "intermediate_size": 2048}, # K=768, N=2048 -} - def set_tile_shape(m: int, n: int, k: int) -> None: """Update the TileShape in the CUDA file.""" @@ -63,17 +57,14 @@ def rebuild_kernel() -> bool: return SO_FILE.exists() -def run_benchmark(num_tokens: int, hidden_size: int, intermediate_size: int) -> dict | None: +def run_benchmark(num_tokens: int = 8192) -> dict | None: """Run the benchmark and parse results.""" # Need to run in a fresh process to reload the .so result = subprocess.run( [sys.executable, "benchmarks/bench_ragged_dot.py", "--num-tokens", str(num_tokens), - "--hidden-size", str(hidden_size), - "--intermediate-size", str(intermediate_size), "--num-warmup", "3", - "--num-iters", "10", - "--forward-only"], + "--num-iters", "10"], capture_output=True, text=True, cwd=CUDA_FILE.parent.parent.parent, @@ -93,6 +84,12 @@ def run_benchmark(num_tokens: int, hidden_size: int, intermediate_size: int) -> results['fwd_ms'] = float(match.group(1)) results['fwd_tflops'] = 
float(match.group(2)) + # Find backward pass (second occurrence) + matches = list(re.finditer(r'CUTLASS FFI:\s+([\d.]+)\s+ms\s+([\d.]+)\s+TFLOPS', output)) + if len(matches) >= 2: + results['bwd_ms'] = float(matches[1].group(1)) + results['bwd_tflops'] = float(matches[1].group(2)) + return results if results else None @@ -105,9 +102,7 @@ def main(): print("CUTLASS Tile Size Sweep") print("=" * 70) - print(f"Qwen3-30B-A3B MoE shapes, M={args.num_tokens}") - print(f" gate_up: K=2048, N=768") - print(f" down: K=768, N=2048") + print(f"Qwen3-30B-A3B shapes: K=2048, N=768, M={args.num_tokens}") print() if args.dry_run: @@ -127,53 +122,44 @@ def main(): if not rebuild_kernel(): print(" FAILED to build") - results.append((m, n, k, None, None)) + results.append((m, n, k, None)) continue - tile_results = {} - for shape_name, shape_cfg in MOE_SHAPES.items(): - bench_result = run_benchmark( - args.num_tokens, - shape_cfg["hidden_size"], - shape_cfg["intermediate_size"], - ) - tile_results[shape_name] = bench_result - if bench_result: - print(f" {shape_name:>8}: {bench_result['fwd_ms']:>8.3f} ms {bench_result['fwd_tflops']:>8.2f} TFLOPS") - else: - print(f" {shape_name:>8}: FAILED") - - results.append((m, n, k, tile_results.get("gate_up"), tile_results.get("down"))) + bench_results = run_benchmark(args.num_tokens) + if bench_results: + print(f" Forward: {bench_results.get('fwd_ms', 'N/A'):>8.3f} ms {bench_results.get('fwd_tflops', 'N/A'):>8.2f} TFLOPS") + print(f" Backward: {bench_results.get('bwd_ms', 'N/A'):>8.3f} ms {bench_results.get('bwd_tflops', 'N/A'):>8.2f} TFLOPS") + results.append((m, n, k, bench_results)) + else: + print(" FAILED to benchmark") + results.append((m, n, k, None)) # Summary print(f"\n{'='*70}") print("SUMMARY") print("=" * 70) - print(f"{'TileShape':<20} {'gate_up ms':>12} {'TFLOPS':>8} {'down ms':>12} {'TFLOPS':>8} {'Total ms':>10}") + print(f"{'TileShape':<20} {'Fwd (ms)':>10} {'Fwd TFLOPS':>12} {'Bwd (ms)':>10} {'Bwd TFLOPS':>12}") print("-" * 70) - for m, n, k, gate_up, down in results: + for m, n, k, res in results: tile_str = f"({m}, {n}, {k})" - if gate_up and down: - gu_ms = gate_up['fwd_ms'] - gu_tf = gate_up['fwd_tflops'] - dn_ms = down['fwd_ms'] - dn_tf = down['fwd_tflops'] - total_ms = gu_ms + dn_ms - print(f"{tile_str:<20} {gu_ms:>12.3f} {gu_tf:>8.2f} {dn_ms:>12.3f} {dn_tf:>8.2f} {total_ms:>10.3f}") + if res: + fwd_ms = f"{res.get('fwd_ms', 0):.3f}" + fwd_tf = f"{res.get('fwd_tflops', 0):.2f}" + bwd_ms = f"{res.get('bwd_ms', 0):.3f}" + bwd_tf = f"{res.get('bwd_tflops', 0):.2f}" + print(f"{tile_str:<20} {fwd_ms:>10} {fwd_tf:>12} {bwd_ms:>10} {bwd_tf:>12}") else: - print(f"{tile_str:<20} {'FAILED':>12}") + print(f"{tile_str:<20} {'FAILED':>10}") # Find best - valid_results = [(m, n, k, gu, dn) for m, n, k, gu, dn in results if gu and dn] + valid_results = [(m, n, k, r) for m, n, k, r in results if r and 'fwd_tflops' in r] if valid_results: - best_gate_up = max(valid_results, key=lambda x: x[3]['fwd_tflops']) - best_down = max(valid_results, key=lambda x: x[4]['fwd_tflops']) - best_total = min(valid_results, key=lambda x: x[3]['fwd_ms'] + x[4]['fwd_ms']) + best_fwd = max(valid_results, key=lambda x: x[3]['fwd_tflops']) + best_bwd = max(valid_results, key=lambda x: x[3].get('bwd_tflops', 0)) print() - print(f"Best gate_up: ({best_gate_up[0]}, {best_gate_up[1]}, {best_gate_up[2]}) - {best_gate_up[3]['fwd_tflops']:.2f} TFLOPS") - print(f"Best down: ({best_down[0]}, {best_down[1]}, {best_down[2]}) - {best_down[4]['fwd_tflops']:.2f} TFLOPS") - print(f"Best total: 
({best_total[0]}, {best_total[1]}, {best_total[2]}) - {best_total[3]['fwd_ms'] + best_total[4]['fwd_ms']:.3f} ms") + print(f"Best forward: ({best_fwd[0]}, {best_fwd[1]}, {best_fwd[2]}) - {best_fwd[3]['fwd_tflops']:.2f} TFLOPS") + print(f"Best backward: ({best_bwd[0]}, {best_bwd[1]}, {best_bwd[2]}) - {best_bwd[3].get('bwd_tflops', 0):.2f} TFLOPS") if __name__ == "__main__": diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 72e89edb7..dd6286208 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -50,62 +50,52 @@ using DtypeAccum = float; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::RowMajor; using LayoutOutput = cutlass::layout::RowMajor; -using LayoutA_Bwd = cutlass::layout::ColumnMajor; constexpr int AlignmentA = 8; constexpr int AlignmentB = 8; constexpr int AlignmentOutput = 8; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; +using TileShape = cute::Shape; using ClusterShape = cute::Shape; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using ProblemShape = cutlass::gemm::GroupProblemShape>; + +using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + DtypeAccum, DtypeAccum, void, LayoutOutput*, AlignmentOutput, + DtypeOutput, LayoutOutput*, AlignmentOutput, EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; + +using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, DtypeA, LayoutA*, AlignmentA, + DtypeB, LayoutB*, AlignmentB, DtypeAccum, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal; +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; using ProblemShapeType = ProblemShape::UnderlyingProblemShape; -// Tile configurations optimized for different shapes: -// - SmallN (64, 64, 128): Best for K > N (e.g., gate_up: K=2048, N=768) -// - LargeN (64, 256, 64): Best for N > K (e.g., down: K=768, N=2048) -using TileShape_SmallN = cute::Shape; -using TileShape_LargeN = cute::Shape; - -// Forward kernel types for SmallN tile -template -struct GemmTypes { - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - DtypeAccum, DtypeAccum, void, LayoutOutput*, AlignmentOutput, - DtypeOutput, LayoutOutput*, AlignmentOutput, EpilogueSchedule, - cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, DtypeA, LayoutA*, AlignmentA, - DtypeB, LayoutB*, AlignmentB, DtypeAccum, TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, - KernelSchedule>::CollectiveOp; - - using CollectiveMainloop_Bwd = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, DtypeA, LayoutA_Bwd*, AlignmentA, - DtypeB, LayoutB*, AlignmentB, DtypeAccum, TileShape, ClusterShape, - 
cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, - KernelSchedule>::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - using GemmKernel_Bwd = cutlass::gemm::kernel::GemmUniversal; - using Gemm_Bwd = cutlass::gemm::device::GemmUniversalAdapter; -}; +// Backward kernel for d_rhs: computes lhs.T @ grad per group +// Uses ColumnMajor for A to interpret row-major lhs as transposed +using LayoutA_Bwd = cutlass::layout::ColumnMajor; -// Instantiate kernel types for both tile configurations -using SmallN = GemmTypes; -using LargeN = GemmTypes; +using CollectiveMainloop_Bwd = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, DtypeA, LayoutA_Bwd*, AlignmentA, + DtypeB, LayoutB*, AlignmentB, DtypeAccum, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + +using GemmKernel_Bwd = cutlass::gemm::kernel::GemmUniversal; +using Gemm_Bwd = cutlass::gemm::device::GemmUniversalAdapter; template static T* carve_aligned(char*& p, size_t count) { @@ -233,23 +223,6 @@ ffi::Error ExecuteGroupedGemm( return ffi::Error::Success(); } -// Dispatch to the best kernel based on K and N dimensions -template -ffi::Error DispatchGroupedGemm( - cudaStream_t stream, ffi::ScratchAllocator& scratch, - const DtypeA* A, const DtypeB* B, DtypeOutput* out, - const int32_t* group_offsets_cumsum, const int32_t* group_offset, - int32_t g_local, int32_t k, int32_t n) { - // Use LargeN tile (64, 256, 64) when N > K, otherwise SmallN tile (64, 64, 128) - if (n > k) { - return ExecuteGroupedGemm( - stream, scratch, A, B, out, group_offsets_cumsum, group_offset, g_local, k, n); - } else { - return ExecuteGroupedGemm( - stream, scratch, A, B, out, group_offsets_cumsum, group_offset, g_local, k, n); - } -} - ffi::Error RaggedDotCudaImpl( cudaStream_t stream, ffi::ScratchAllocator scratch, @@ -277,7 +250,7 @@ ffi::Error RaggedDotCudaImpl( return ffi::Error::Internal("Failed to zero output."); } - return DispatchGroupedGemm( + return ExecuteGroupedGemm( stream, scratch, reinterpret_cast(lhs.typed_data()), reinterpret_cast(rhs.typed_data()), @@ -328,7 +301,7 @@ ffi::Error RaggedDotBwdCudaImpl( return ffi::Error::Internal("Failed to zero d_rhs output."); } - return DispatchGroupedGemm( + return ExecuteGroupedGemm( stream, scratch, reinterpret_cast(lhs.typed_data()), reinterpret_cast(grad.typed_data()), From f4cfed296c662af47dc486ccebbba7b574441563 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 15:26:45 -0800 Subject: [PATCH 081/106] use kernel for everything --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 85 ++++++++++++++++++++++++++----- skyrl-tx/tx/layers/util.py | 8 +-- 2 files changed, 74 insertions(+), 19 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index dd6286208..50ce38fdc 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -97,6 +97,43 @@ using CollectiveMainloop_Bwd = using GemmKernel_Bwd = cutlass::gemm::kernel::GemmUniversal; using Gemm_Bwd = cutlass::gemm::device::GemmUniversalAdapter; +// Unaligned kernel variants (alignment=1) for small k/n dimensions (e.g., LoRA rank=1) +// Uses cooperative kernel schedule which supports alignment=1 +constexpr int AlignmentUnaligned = 1; +using TileShapeUnaligned = cute::Shape; 
+using KernelScheduleUnaligned = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; +using EpilogueScheduleUnaligned = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + +using CollectiveEpilogueUnaligned = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShapeUnaligned, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + DtypeAccum, DtypeAccum, void, LayoutOutput*, AlignmentUnaligned, + DtypeOutput, LayoutOutput*, AlignmentUnaligned, EpilogueScheduleUnaligned, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; + +using CollectiveMainloopUnaligned = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, DtypeA, LayoutA*, AlignmentUnaligned, + DtypeB, LayoutB*, AlignmentUnaligned, DtypeAccum, TileShapeUnaligned, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogueUnaligned::SharedStorage))>, + KernelScheduleUnaligned>::CollectiveOp; + +using GemmKernelUnaligned = cutlass::gemm::kernel::GemmUniversal; +using GemmUnaligned = cutlass::gemm::device::GemmUniversalAdapter; + +using CollectiveMainloopUnaligned_Bwd = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, DtypeA, LayoutA_Bwd*, AlignmentUnaligned, + DtypeB, LayoutB*, AlignmentUnaligned, DtypeAccum, TileShapeUnaligned, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogueUnaligned::SharedStorage))>, + KernelScheduleUnaligned>::CollectiveOp; + +using GemmKernelUnaligned_Bwd = cutlass::gemm::kernel::GemmUniversal; +using GemmUnaligned_Bwd = cutlass::gemm::device::GemmUniversalAdapter; + template static T* carve_aligned(char*& p, size_t count) { T* out = reinterpret_cast((reinterpret_cast(p) + 15) & ~uintptr_t(15)); @@ -223,6 +260,10 @@ ffi::Error ExecuteGroupedGemm( return ffi::Error::Success(); } +static bool is_aligned(int32_t k, int32_t n) { + return (k % AlignmentA == 0) && (n % AlignmentB == 0); +} + ffi::Error RaggedDotCudaImpl( cudaStream_t stream, ffi::ScratchAllocator scratch, @@ -250,12 +291,21 @@ ffi::Error RaggedDotCudaImpl( return ffi::Error::Internal("Failed to zero output."); } - return ExecuteGroupedGemm( - stream, scratch, - reinterpret_cast(lhs.typed_data()), - reinterpret_cast(rhs.typed_data()), - reinterpret_cast(out->typed_data()), - group_offsets_cumsum.typed_data(), group_offset.typed_data(), g_local, k, n); + if (is_aligned(k, n)) { + return ExecuteGroupedGemm( + stream, scratch, + reinterpret_cast(lhs.typed_data()), + reinterpret_cast(rhs.typed_data()), + reinterpret_cast(out->typed_data()), + group_offsets_cumsum.typed_data(), group_offset.typed_data(), g_local, k, n); + } else { + return ExecuteGroupedGemm( + stream, scratch, + reinterpret_cast(lhs.typed_data()), + reinterpret_cast(rhs.typed_data()), + reinterpret_cast(out->typed_data()), + group_offsets_cumsum.typed_data(), group_offset.typed_data(), g_local, k, n); + } } XLA_FFI_DEFINE_HANDLER_SYMBOL( @@ -301,12 +351,23 @@ ffi::Error RaggedDotBwdCudaImpl( return ffi::Error::Internal("Failed to zero d_rhs output."); } - return ExecuteGroupedGemm( - stream, scratch, - reinterpret_cast(lhs.typed_data()), - reinterpret_cast(grad.typed_data()), - reinterpret_cast(d_rhs->typed_data()), - group_offsets_cumsum.typed_data(), group_offset.typed_data(), g_local, k, n); + // For backward pass: A is [M, K] viewed as [K, M] via ColumnMajor, B is [M, N] + // So we check alignment of K (output M dimension) and 
N (output N dimension) + if (is_aligned(k, n)) { + return ExecuteGroupedGemm( + stream, scratch, + reinterpret_cast(lhs.typed_data()), + reinterpret_cast(grad.typed_data()), + reinterpret_cast(d_rhs->typed_data()), + group_offsets_cumsum.typed_data(), group_offset.typed_data(), g_local, k, n); + } else { + return ExecuteGroupedGemm( + stream, scratch, + reinterpret_cast(lhs.typed_data()), + reinterpret_cast(grad.typed_data()), + reinterpret_cast(d_rhs->typed_data()), + group_offsets_cumsum.typed_data(), group_offset.typed_data(), g_local, k, n); + } } XLA_FFI_DEFINE_HANDLER_SYMBOL( diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py index f489d792a..a8abddda0 100644 --- a/skyrl-tx/tx/layers/util.py +++ b/skyrl-tx/tx/layers/util.py @@ -35,11 +35,7 @@ def ragged_dot( preferred_element_type=preferred_element_type, ) - # CUTLASS kernel requires k and n dimensions divisible by 8 - k = lhs.shape[-1] - n = rhs.shape[-1] - cutlass_alignment = 8 - + # Use CUTLASS kernel when available (supports any k/n dimensions including k=1 for LoRA rank 1) if ( ragged_dot_ffi_available() and jax.default_backend() == "gpu" @@ -47,8 +43,6 @@ def ragged_dot( and rhs.dtype == jnp.bfloat16 and group_sizes.dtype == jnp.int32 and group_offset.dtype == jnp.int32 - and k % cutlass_alignment == 0 - and n % cutlass_alignment == 0 ): return ragged_dot_ffi(lhs, rhs, group_sizes, group_offset) From 3feab965584ef5e7c446439f1e0420336940ee32 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 15:33:24 -0800 Subject: [PATCH 082/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 64 +++++++++++++++---------------- skyrl-tx/tx/layers/util.py | 6 ++- 2 files changed, 36 insertions(+), 34 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 50ce38fdc..cc981100a 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -97,42 +97,40 @@ using CollectiveMainloop_Bwd = using GemmKernel_Bwd = cutlass::gemm::kernel::GemmUniversal; using Gemm_Bwd = cutlass::gemm::device::GemmUniversalAdapter; -// Unaligned kernel variants (alignment=1) for small k/n dimensions (e.g., LoRA rank=1) -// Uses cooperative kernel schedule which supports alignment=1 -constexpr int AlignmentUnaligned = 1; -using TileShapeUnaligned = cute::Shape; -using KernelScheduleUnaligned = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; -using EpilogueScheduleUnaligned = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; - -using CollectiveEpilogueUnaligned = +// Low-rank (LR) kernel variant for unaligned K dimension (e.g., LoRA rank=1) +// K alignment can be 1, but N (output) alignment must still be 8 for TMA +constexpr int AlignmentLR = 1; +using TileShapeLR = cute::Shape; + +using CollectiveEpilogueLR = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShapeUnaligned, ClusterShape, + ArchTag, OperatorClass, TileShapeLR, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, - DtypeAccum, DtypeAccum, void, LayoutOutput*, AlignmentUnaligned, - DtypeOutput, LayoutOutput*, AlignmentUnaligned, EpilogueScheduleUnaligned, + DtypeAccum, DtypeAccum, void, LayoutOutput*, AlignmentOutput, + DtypeOutput, LayoutOutput*, AlignmentOutput, EpilogueSchedule, cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; -using CollectiveMainloopUnaligned = +using CollectiveMainloopLR = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, DtypeA, LayoutA*, AlignmentUnaligned, - DtypeB, 
LayoutB*, AlignmentUnaligned, DtypeAccum, TileShapeUnaligned, ClusterShape, + ArchTag, OperatorClass, DtypeA, LayoutA*, AlignmentLR, + DtypeB, LayoutB*, AlignmentLR, DtypeAccum, TileShapeLR, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogueUnaligned::SharedStorage))>, - KernelScheduleUnaligned>::CollectiveOp; + static_cast(sizeof(typename CollectiveEpilogueLR::SharedStorage))>, + KernelSchedule>::CollectiveOp; -using GemmKernelUnaligned = cutlass::gemm::kernel::GemmUniversal; -using GemmUnaligned = cutlass::gemm::device::GemmUniversalAdapter; +using GemmKernelLR = cutlass::gemm::kernel::GemmUniversal; +using GemmLR = cutlass::gemm::device::GemmUniversalAdapter; -using CollectiveMainloopUnaligned_Bwd = +using CollectiveMainloopLR_Bwd = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, DtypeA, LayoutA_Bwd*, AlignmentUnaligned, - DtypeB, LayoutB*, AlignmentUnaligned, DtypeAccum, TileShapeUnaligned, ClusterShape, + ArchTag, OperatorClass, DtypeA, LayoutA_Bwd*, AlignmentLR, + DtypeB, LayoutB*, AlignmentLR, DtypeAccum, TileShapeLR, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogueUnaligned::SharedStorage))>, - KernelScheduleUnaligned>::CollectiveOp; + static_cast(sizeof(typename CollectiveEpilogueLR::SharedStorage))>, + KernelSchedule>::CollectiveOp; -using GemmKernelUnaligned_Bwd = cutlass::gemm::kernel::GemmUniversal; -using GemmUnaligned_Bwd = cutlass::gemm::device::GemmUniversalAdapter; +using GemmKernelLR_Bwd = cutlass::gemm::kernel::GemmUniversal; +using GemmLR_Bwd = cutlass::gemm::device::GemmUniversalAdapter; template static T* carve_aligned(char*& p, size_t count) { @@ -260,8 +258,8 @@ ffi::Error ExecuteGroupedGemm( return ffi::Error::Success(); } -static bool is_aligned(int32_t k, int32_t n) { - return (k % AlignmentA == 0) && (n % AlignmentB == 0); +static bool is_k_aligned(int32_t k) { + return k % AlignmentA == 0; } ffi::Error RaggedDotCudaImpl( @@ -291,7 +289,7 @@ ffi::Error RaggedDotCudaImpl( return ffi::Error::Internal("Failed to zero output."); } - if (is_aligned(k, n)) { + if (is_k_aligned(k)) { return ExecuteGroupedGemm( stream, scratch, reinterpret_cast(lhs.typed_data()), @@ -299,7 +297,7 @@ ffi::Error RaggedDotCudaImpl( reinterpret_cast(out->typed_data()), group_offsets_cumsum.typed_data(), group_offset.typed_data(), g_local, k, n); } else { - return ExecuteGroupedGemm( + return ExecuteGroupedGemm( stream, scratch, reinterpret_cast(lhs.typed_data()), reinterpret_cast(rhs.typed_data()), @@ -351,9 +349,9 @@ ffi::Error RaggedDotBwdCudaImpl( return ffi::Error::Internal("Failed to zero d_rhs output."); } - // For backward pass: A is [M, K] viewed as [K, M] via ColumnMajor, B is [M, N] - // So we check alignment of K (output M dimension) and N (output N dimension) - if (is_aligned(k, n)) { + // For backward pass: computes [K, N] = [K, M] @ [M, N] where M is the contraction dim + // K alignment affects the output, use LR kernel when K is not aligned + if (is_k_aligned(k)) { return ExecuteGroupedGemm( stream, scratch, reinterpret_cast(lhs.typed_data()), @@ -361,7 +359,7 @@ ffi::Error RaggedDotBwdCudaImpl( reinterpret_cast(d_rhs->typed_data()), group_offsets_cumsum.typed_data(), group_offset.typed_data(), g_local, k, n); } else { - return ExecuteGroupedGemm( + return ExecuteGroupedGemm( stream, scratch, reinterpret_cast(lhs.typed_data()), reinterpret_cast(grad.typed_data()), diff --git a/skyrl-tx/tx/layers/util.py 
b/skyrl-tx/tx/layers/util.py index a8abddda0..e9be7e961 100644 --- a/skyrl-tx/tx/layers/util.py +++ b/skyrl-tx/tx/layers/util.py @@ -35,7 +35,10 @@ def ragged_dot( preferred_element_type=preferred_element_type, ) - # Use CUTLASS kernel when available (supports any k/n dimensions including k=1 for LoRA rank 1) + # Use CUTLASS kernel when available + # K dimension can be any value (LR kernel handles k=1 for LoRA), but N must be aligned for TMA + n = rhs.shape[-1] + cutlass_alignment = 8 if ( ragged_dot_ffi_available() and jax.default_backend() == "gpu" @@ -43,6 +46,7 @@ def ragged_dot( and rhs.dtype == jnp.bfloat16 and group_sizes.dtype == jnp.int32 and group_offset.dtype == jnp.int32 + and n % cutlass_alignment == 0 ): return ragged_dot_ffi(lhs, rhs, group_sizes, group_offset) From f03a1e00ca231bba76f237c6fea29d26eee3b20c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 15:37:23 -0800 Subject: [PATCH 083/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 22 ++++++++++++---------- skyrl-tx/tx/layers/util.py | 6 +----- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index cc981100a..22b4f24d6 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -98,16 +98,18 @@ using GemmKernel_Bwd = cutlass::gemm::kernel::GemmUniversal; // Low-rank (LR) kernel variant for unaligned K dimension (e.g., LoRA rank=1) -// K alignment can be 1, but N (output) alignment must still be 8 for TMA +// Uses CpAsync (non-TMA) schedule which has no alignment requirements constexpr int AlignmentLR = 1; using TileShapeLR = cute::Shape; +using KernelScheduleLR = cutlass::gemm::KernelPtrArrayCpAsyncWarpSpecializedCooperative; +using EpilogueScheduleLR = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; using CollectiveEpilogueLR = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShapeLR, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, - DtypeAccum, DtypeAccum, void, LayoutOutput*, AlignmentOutput, - DtypeOutput, LayoutOutput*, AlignmentOutput, EpilogueSchedule, + DtypeAccum, DtypeAccum, void, LayoutOutput*, AlignmentLR, + DtypeOutput, LayoutOutput*, AlignmentLR, EpilogueScheduleLR, cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; using CollectiveMainloopLR = @@ -116,7 +118,7 @@ using CollectiveMainloopLR = DtypeB, LayoutB*, AlignmentLR, DtypeAccum, TileShapeLR, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout< static_cast(sizeof(typename CollectiveEpilogueLR::SharedStorage))>, - KernelSchedule>::CollectiveOp; + KernelScheduleLR>::CollectiveOp; using GemmKernelLR = cutlass::gemm::kernel::GemmUniversal; using GemmLR = cutlass::gemm::device::GemmUniversalAdapter; @@ -127,7 +129,7 @@ using CollectiveMainloopLR_Bwd = DtypeB, LayoutB*, AlignmentLR, DtypeAccum, TileShapeLR, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout< static_cast(sizeof(typename CollectiveEpilogueLR::SharedStorage))>, - KernelSchedule>::CollectiveOp; + KernelScheduleLR>::CollectiveOp; using GemmKernelLR_Bwd = cutlass::gemm::kernel::GemmUniversal; using GemmLR_Bwd = cutlass::gemm::device::GemmUniversalAdapter; @@ -258,8 +260,8 @@ ffi::Error ExecuteGroupedGemm( return ffi::Error::Success(); } -static bool is_k_aligned(int32_t k) { - return k % AlignmentA == 0; +static bool is_aligned(int32_t k, int32_t n) { + return (k % AlignmentA == 0) && (n % AlignmentB == 0); } ffi::Error RaggedDotCudaImpl( @@ -289,7 +291,7 @@ ffi::Error 
RaggedDotCudaImpl( return ffi::Error::Internal("Failed to zero output."); } - if (is_k_aligned(k)) { + if (is_aligned(k, n)) { return ExecuteGroupedGemm( stream, scratch, reinterpret_cast(lhs.typed_data()), @@ -350,8 +352,8 @@ ffi::Error RaggedDotBwdCudaImpl( } // For backward pass: computes [K, N] = [K, M] @ [M, N] where M is the contraction dim - // K alignment affects the output, use LR kernel when K is not aligned - if (is_k_aligned(k)) { + // Use LR kernel when K or N is not aligned + if (is_aligned(k, n)) { return ExecuteGroupedGemm( stream, scratch, reinterpret_cast(lhs.typed_data()), diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py index e9be7e961..8f1e5ac6f 100644 --- a/skyrl-tx/tx/layers/util.py +++ b/skyrl-tx/tx/layers/util.py @@ -35,10 +35,7 @@ def ragged_dot( preferred_element_type=preferred_element_type, ) - # Use CUTLASS kernel when available - # K dimension can be any value (LR kernel handles k=1 for LoRA), but N must be aligned for TMA - n = rhs.shape[-1] - cutlass_alignment = 8 + # Use CUTLASS kernel when available (LR kernel handles any k/n including k=1 for LoRA rank=1) if ( ragged_dot_ffi_available() and jax.default_backend() == "gpu" @@ -46,7 +43,6 @@ def ragged_dot( and rhs.dtype == jnp.bfloat16 and group_sizes.dtype == jnp.int32 and group_offset.dtype == jnp.int32 - and n % cutlass_alignment == 0 ): return ragged_dot_ffi(lhs, rhs, group_sizes, group_offset) From 89b452fe1418f1b0c0022f210fc17fc378b4826b Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 15:41:16 -0800 Subject: [PATCH 084/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 22b4f24d6..1dadc1aff 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -97,11 +97,11 @@ using CollectiveMainloop_Bwd = using GemmKernel_Bwd = cutlass::gemm::kernel::GemmUniversal; using Gemm_Bwd = cutlass::gemm::device::GemmUniversalAdapter; -// Low-rank (LR) kernel variant for unaligned K dimension (e.g., LoRA rank=1) -// Uses CpAsync (non-TMA) schedule which has no alignment requirements +// Low-rank (LR) kernel variant for unaligned K/N dimensions (e.g., LoRA rank=1) +// Uses KernelPtrArrayMultistage which doesn't have TMA alignment requirements constexpr int AlignmentLR = 1; -using TileShapeLR = cute::Shape; -using KernelScheduleLR = cutlass::gemm::KernelPtrArrayCpAsyncWarpSpecializedCooperative; +using TileShapeLR = cute::Shape; +using KernelScheduleLR = cutlass::gemm::KernelPtrArrayMultistage; using EpilogueScheduleLR = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; using CollectiveEpilogueLR = From 2b997fa992159ed85f66038c2c785a4842ed6a9b Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 15:49:54 -0800 Subject: [PATCH 085/106] revert --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 85 +++++-------------------------- skyrl-tx/tx/layers/util.py | 8 ++- 2 files changed, 19 insertions(+), 74 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 1dadc1aff..dd6286208 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -97,43 +97,6 @@ using CollectiveMainloop_Bwd = using GemmKernel_Bwd = cutlass::gemm::kernel::GemmUniversal; using Gemm_Bwd = cutlass::gemm::device::GemmUniversalAdapter; -// Low-rank (LR) kernel variant for unaligned K/N dimensions (e.g., LoRA rank=1) -// Uses KernelPtrArrayMultistage which doesn't 
have TMA alignment requirements -constexpr int AlignmentLR = 1; -using TileShapeLR = cute::Shape; -using KernelScheduleLR = cutlass::gemm::KernelPtrArrayMultistage; -using EpilogueScheduleLR = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; - -using CollectiveEpilogueLR = - typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShapeLR, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - DtypeAccum, DtypeAccum, void, LayoutOutput*, AlignmentLR, - DtypeOutput, LayoutOutput*, AlignmentLR, EpilogueScheduleLR, - cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; - -using CollectiveMainloopLR = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, DtypeA, LayoutA*, AlignmentLR, - DtypeB, LayoutB*, AlignmentLR, DtypeAccum, TileShapeLR, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogueLR::SharedStorage))>, - KernelScheduleLR>::CollectiveOp; - -using GemmKernelLR = cutlass::gemm::kernel::GemmUniversal; -using GemmLR = cutlass::gemm::device::GemmUniversalAdapter; - -using CollectiveMainloopLR_Bwd = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, DtypeA, LayoutA_Bwd*, AlignmentLR, - DtypeB, LayoutB*, AlignmentLR, DtypeAccum, TileShapeLR, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogueLR::SharedStorage))>, - KernelScheduleLR>::CollectiveOp; - -using GemmKernelLR_Bwd = cutlass::gemm::kernel::GemmUniversal; -using GemmLR_Bwd = cutlass::gemm::device::GemmUniversalAdapter; - template static T* carve_aligned(char*& p, size_t count) { T* out = reinterpret_cast((reinterpret_cast(p) + 15) & ~uintptr_t(15)); @@ -260,10 +223,6 @@ ffi::Error ExecuteGroupedGemm( return ffi::Error::Success(); } -static bool is_aligned(int32_t k, int32_t n) { - return (k % AlignmentA == 0) && (n % AlignmentB == 0); -} - ffi::Error RaggedDotCudaImpl( cudaStream_t stream, ffi::ScratchAllocator scratch, @@ -291,21 +250,12 @@ ffi::Error RaggedDotCudaImpl( return ffi::Error::Internal("Failed to zero output."); } - if (is_aligned(k, n)) { - return ExecuteGroupedGemm( - stream, scratch, - reinterpret_cast(lhs.typed_data()), - reinterpret_cast(rhs.typed_data()), - reinterpret_cast(out->typed_data()), - group_offsets_cumsum.typed_data(), group_offset.typed_data(), g_local, k, n); - } else { - return ExecuteGroupedGemm( - stream, scratch, - reinterpret_cast(lhs.typed_data()), - reinterpret_cast(rhs.typed_data()), - reinterpret_cast(out->typed_data()), - group_offsets_cumsum.typed_data(), group_offset.typed_data(), g_local, k, n); - } + return ExecuteGroupedGemm( + stream, scratch, + reinterpret_cast(lhs.typed_data()), + reinterpret_cast(rhs.typed_data()), + reinterpret_cast(out->typed_data()), + group_offsets_cumsum.typed_data(), group_offset.typed_data(), g_local, k, n); } XLA_FFI_DEFINE_HANDLER_SYMBOL( @@ -351,23 +301,12 @@ ffi::Error RaggedDotBwdCudaImpl( return ffi::Error::Internal("Failed to zero d_rhs output."); } - // For backward pass: computes [K, N] = [K, M] @ [M, N] where M is the contraction dim - // Use LR kernel when K or N is not aligned - if (is_aligned(k, n)) { - return ExecuteGroupedGemm( - stream, scratch, - reinterpret_cast(lhs.typed_data()), - reinterpret_cast(grad.typed_data()), - reinterpret_cast(d_rhs->typed_data()), - group_offsets_cumsum.typed_data(), group_offset.typed_data(), g_local, k, n); - } else { - return ExecuteGroupedGemm( - stream, scratch, - 
reinterpret_cast(lhs.typed_data()), - reinterpret_cast(grad.typed_data()), - reinterpret_cast(d_rhs->typed_data()), - group_offsets_cumsum.typed_data(), group_offset.typed_data(), g_local, k, n); - } + return ExecuteGroupedGemm( + stream, scratch, + reinterpret_cast(lhs.typed_data()), + reinterpret_cast(grad.typed_data()), + reinterpret_cast(d_rhs->typed_data()), + group_offsets_cumsum.typed_data(), group_offset.typed_data(), g_local, k, n); } XLA_FFI_DEFINE_HANDLER_SYMBOL( diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py index 8f1e5ac6f..f489d792a 100644 --- a/skyrl-tx/tx/layers/util.py +++ b/skyrl-tx/tx/layers/util.py @@ -35,7 +35,11 @@ def ragged_dot( preferred_element_type=preferred_element_type, ) - # Use CUTLASS kernel when available (LR kernel handles any k/n including k=1 for LoRA rank=1) + # CUTLASS kernel requires k and n dimensions divisible by 8 + k = lhs.shape[-1] + n = rhs.shape[-1] + cutlass_alignment = 8 + if ( ragged_dot_ffi_available() and jax.default_backend() == "gpu" @@ -43,6 +47,8 @@ def ragged_dot( and rhs.dtype == jnp.bfloat16 and group_sizes.dtype == jnp.int32 and group_offset.dtype == jnp.int32 + and k % cutlass_alignment == 0 + and n % cutlass_alignment == 0 ): return ragged_dot_ffi(lhs, rhs, group_sizes, group_offset) From ddf1ab7250cc717c5d1e5f2f545d382c2e263d33 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 16:46:07 -0800 Subject: [PATCH 086/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index dd6286208..cc18baabe 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -173,16 +173,17 @@ __global__ void prepare_grouped_gemm( data.B_ptrs[tid] = B + static_cast(tid) * n * k; data.out_ptrs[tid] = out + static_cast(start) * n; data.problem_sizes[tid] = ProblemShapeType(m, n, k); - data.stride_A[tid] = cutlass::make_cute_packed_stride(typename Data::StrideA{}, {k, k, 1}); + data.stride_A[tid] = cutlass::make_cute_packed_stride(typename Data::StrideA{}, {m, k, 1}); + data.stride_B[tid] = cutlass::make_cute_packed_stride(typename Data::StrideB{}, {k, n, 1}); data.stride_output[tid] = cutlass::make_cute_packed_stride(typename Data::StrideOutput{}, {m, n, 1}); } else { data.B_ptrs[tid] = B + static_cast(start) * n; data.out_ptrs[tid] = out + static_cast(tid) * k * n; data.problem_sizes[tid] = ProblemShapeType(k, n, m); - data.stride_A[tid] = cutlass::make_cute_packed_stride(typename Data::StrideA{}, {m, m, 1}); + data.stride_A[tid] = cutlass::make_cute_packed_stride(typename Data::StrideA{}, {k, m, 1}); + data.stride_B[tid] = cutlass::make_cute_packed_stride(typename Data::StrideB{}, {m, n, 1}); data.stride_output[tid] = cutlass::make_cute_packed_stride(typename Data::StrideOutput{}, {k, n, 1}); } - data.stride_B[tid] = cutlass::make_cute_packed_stride(typename Data::StrideB{}, {n, n, 1}); } template From fc8c75be06c0eba473bb3ec3e70fa82f60922eb3 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 16:59:35 -0800 Subject: [PATCH 087/106] fix --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index cc18baabe..e9959d190 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -173,17 +173,16 @@ __global__ void prepare_grouped_gemm( data.B_ptrs[tid] = B + static_cast(tid) * n * 
k; data.out_ptrs[tid] = out + static_cast(start) * n; data.problem_sizes[tid] = ProblemShapeType(m, n, k); - data.stride_A[tid] = cutlass::make_cute_packed_stride(typename Data::StrideA{}, {m, k, 1}); - data.stride_B[tid] = cutlass::make_cute_packed_stride(typename Data::StrideB{}, {k, n, 1}); + data.stride_A[tid] = cutlass::make_cute_packed_stride(typename Data::StrideA{}, {k, k, 1}); data.stride_output[tid] = cutlass::make_cute_packed_stride(typename Data::StrideOutput{}, {m, n, 1}); } else { data.B_ptrs[tid] = B + static_cast(start) * n; data.out_ptrs[tid] = out + static_cast(tid) * k * n; data.problem_sizes[tid] = ProblemShapeType(k, n, m); - data.stride_A[tid] = cutlass::make_cute_packed_stride(typename Data::StrideA{}, {k, m, 1}); - data.stride_B[tid] = cutlass::make_cute_packed_stride(typename Data::StrideB{}, {m, n, 1}); + data.stride_A[tid] = cutlass::make_cute_packed_stride(typename Data::StrideA{}, {k, k, 1}); data.stride_output[tid] = cutlass::make_cute_packed_stride(typename Data::StrideOutput{}, {k, n, 1}); } + data.stride_B[tid] = cutlass::make_cute_packed_stride(typename Data::StrideB{}, {n, n, 1}); } template From b0e14f352523de6289ba1dc6b30dc1b908079a17 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 17:57:07 -0800 Subject: [PATCH 088/106] update --- skyrl-tx/tx/tinker/backends/jax.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 720b760eb..4160f4dc9 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -60,7 +60,7 @@ class JaxBackendConfig(BaseModel, extra="forbid"): """Configuration specific to the JAX backend.""" max_lora_adapters: int = Field(default=32, description="Maximum number of LoRA adapters") - max_lora_rank: int = Field(default=32, description="Maximum LoRA rank") + max_lora_rank: int = Field(default=8, description="Maximum LoRA rank") tensor_parallel_size: int = Field(default=1, description="Tensor parallelism degree to use for the model") expert_parallel_size: int = Field(default=1, description="Expert parallelism degree for MoE layers") fully_sharded_data_parallel_size: int = Field( @@ -165,6 +165,12 @@ def __init__(self, base_model: str, config: JaxBackendConfig): shard_attention_heads=config.shard_attention_heads, ) + if config.max_lora_rank % 8 != 0: + logger.warning( + f"[bold yellow]max_lora_rank={config.max_lora_rank} is not divisible by 8. " + "This could lead to degraded performance for MoE models.[/bold yellow]" + ) + model_class = get_model_class(self.model_config) # Create model and load weights From 79b0d44cef6cee7f40f3e8c54cf0e16760b485b5 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 17:57:33 -0800 Subject: [PATCH 089/106] update --- skyrl-tx/tx/tinker/backends/jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 4160f4dc9..aecfa25e1 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -168,7 +168,7 @@ def __init__(self, base_model: str, config: JaxBackendConfig): if config.max_lora_rank % 8 != 0: logger.warning( f"[bold yellow]max_lora_rank={config.max_lora_rank} is not divisible by 8. 
" - "This could lead to degraded performance for MoE models.[/bold yellow]" + "This could lead to degraded performance.[/bold yellow]" ) model_class = get_model_class(self.model_config) From 1a7485fbe4c513809ea51be360ebfff7eed01bfb Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 18:06:47 -0800 Subject: [PATCH 090/106] add tests for lora --- skyrl-tx/benchmarks/bench_ragged_dot.py | 188 +++++++++++++++++------- skyrl-tx/benchmarks/sweep_tile_sizes.py | 125 ++++++++++------ 2 files changed, 212 insertions(+), 101 deletions(-) diff --git a/skyrl-tx/benchmarks/bench_ragged_dot.py b/skyrl-tx/benchmarks/bench_ragged_dot.py index e747f18ec..37d74ddb9 100644 --- a/skyrl-tx/benchmarks/bench_ragged_dot.py +++ b/skyrl-tx/benchmarks/bench_ragged_dot.py @@ -1,4 +1,4 @@ -"""Benchmark ragged_dot CUTLASS kernel with Qwen3-30B-A3B MoE shapes.""" +"""Benchmark ragged_dot CUTLASS kernel with Qwen3-30B-A3B MoE and LoRA shapes.""" import argparse import time @@ -10,18 +10,44 @@ from tx.ffi import ragged_dot_ffi, ragged_dot_ffi_available -def generate_group_sizes(num_tokens: int, num_experts: int, key: jax.Array) -> jax.Array: +# Preset configurations for different workloads +PRESETS = { + "moe": { + "description": "MoE expert layer (Qwen3-30B-A3B)", + "num_tokens": 8192, + "num_groups": 128, # num_experts + "k_dim": 2048, # hidden_size + "n_dim": 768, # intermediate_size + }, + "lora": { + "description": "LoRA adapter layer", + "num_tokens": 8192, + "num_groups": 32, # max_lora_adapters + "k_dim": 8, # lora_rank + "n_dim": 4096, # output features + }, + "lora-moe": { + "description": "LoRA on MoE experts (combined groups)", + "num_tokens": 8192, + "num_groups": 4096, # num_experts * max_lora_adapters (128 * 32) + "k_dim": 8, # lora_rank + "n_dim": 768, # intermediate_size + }, +} + + +def generate_group_sizes(num_tokens: int, num_groups: int, key: jax.Array) -> jax.Array: """Generate random group sizes that sum to num_tokens.""" - # Random assignment of tokens to experts - assignments = jax.random.randint(key, (num_tokens,), 0, num_experts) - return jnp.bincount(assignments, length=num_experts).astype(jnp.int32) + # Random assignment of tokens to groups + assignments = jax.random.randint(key, (num_tokens,), 0, num_groups) + return jnp.bincount(assignments, length=num_groups).astype(jnp.int32) def benchmark_forward( num_tokens: int, - hidden_size: int, - intermediate_size: int, - num_experts: int, + k_dim: int, + n_dim: int, + num_groups: int, num_warmup: int = 5, num_iters: int = 20, use_ffi: bool = True, @@ -30,9 +56,9 @@ def benchmark_forward( key = jax.random.PRNGKey(42) k1, k2, k3 = jax.random.split(key, 3) - lhs = jax.random.normal(k1, (num_tokens, hidden_size), dtype=jnp.bfloat16) - rhs = jax.random.normal(k2, (num_experts, hidden_size, intermediate_size), dtype=jnp.bfloat16) - group_sizes = generate_group_sizes(num_tokens, num_experts, k3) + lhs = jax.random.normal(k1, (num_tokens, k_dim), dtype=jnp.bfloat16) + rhs = jax.random.normal(k2, (num_groups, k_dim, n_dim), dtype=jnp.bfloat16) + group_sizes = generate_group_sizes(num_tokens, num_groups, k3) group_offset = jnp.array([0], dtype=jnp.int32) if use_ffi: @@ -53,7 +79,7 @@ def benchmark_forward( elapsed = time.perf_counter() - start # FLOPs: 2 * M * K * N (matmul FLOPs) - flops = 2 * num_tokens * hidden_size * intermediate_size + flops = 2 * num_tokens * k_dim * n_dim tflops = (flops * num_iters / elapsed) / 1e12 return elapsed / num_iters, tflops @@ -61,9 +87,9 @@ def benchmark_forward( def benchmark_backward( num_tokens: int, - 
hidden_size: int, - intermediate_size: int, - num_experts: int, + k_dim: int, + n_dim: int, + num_groups: int, num_warmup: int = 5, num_iters: int = 20, use_ffi: bool = True, @@ -72,9 +98,9 @@ def benchmark_backward( key = jax.random.PRNGKey(42) k1, k2, k3 = jax.random.split(key, 3) - lhs = jax.random.normal(k1, (num_tokens, hidden_size), dtype=jnp.bfloat16) - rhs = jax.random.normal(k2, (num_experts, hidden_size, intermediate_size), dtype=jnp.bfloat16) - group_sizes = generate_group_sizes(num_tokens, num_experts, k3) + lhs = jax.random.normal(k1, (num_tokens, k_dim), dtype=jnp.bfloat16) + rhs = jax.random.normal(k2, (num_groups, k_dim, n_dim), dtype=jnp.bfloat16) + group_sizes = generate_group_sizes(num_tokens, num_groups, k3) group_offset = jnp.array([0], dtype=jnp.int32) if use_ffi: @@ -102,55 +128,35 @@ def forward(lhs, rhs): # Backward FLOPs: d_lhs = grad @ rhs.T (2*M*N*K) + d_rhs = lhs.T @ grad (2*K*M*N) # Total: 4 * M * K * N - flops = 4 * num_tokens * hidden_size * intermediate_size + flops = 4 * num_tokens * k_dim * n_dim tflops = (flops * num_iters / elapsed) / 1e12 return elapsed / num_iters, tflops -def main(): - parser = argparse.ArgumentParser(description="Benchmark ragged_dot CUTLASS kernel") - parser.add_argument("--num-tokens", type=int, default=8192, help="Number of tokens (M)") - parser.add_argument("--num-experts", type=int, default=128, help="Number of experts (G)") - parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden size (K)") - parser.add_argument("--intermediate-size", type=int, default=768, help="MoE intermediate size (N)") - parser.add_argument("--num-warmup", type=int, default=5, help="Warmup iterations") - parser.add_argument("--num-iters", type=int, default=20, help="Benchmark iterations") - parser.add_argument("--backward-only", action="store_true", help="Only benchmark backward pass") - parser.add_argument("--forward-only", action="store_true", help="Only benchmark forward pass") - args = parser.parse_args() - - print("Ragged Dot Benchmark (Qwen3-30B-A3B MoE shapes)") - print("=" * 60) - print(f"CUTLASS FFI available: {ragged_dot_ffi_available()}") - print(f"JAX backend: {jax.default_backend()}") - print(f"Devices: {jax.device_count()}") - print() - print(f"Config:") - print(f" num_tokens (M): {args.num_tokens}") - print(f" num_experts (G): {args.num_experts}") - print(f" hidden_size (K): {args.hidden_size}") - print(f" intermediate_size (N): {args.intermediate_size}") - print(f" warmup/iters: {args.num_warmup}/{args.num_iters}") - print() - - run_forward = not args.backward_only - run_backward = not args.forward_only - +def run_benchmark_suite( + num_tokens: int, + k_dim: int, + n_dim: int, + num_groups: int, + num_warmup: int, + num_iters: int, + run_forward: bool, + run_backward: bool, +): + """Run the benchmark suite with the given configuration.""" if run_forward: print("Forward Pass (lhs[M,K] @ rhs[G,K,N] -> out[M,N])") print("-" * 60) if ragged_dot_ffi_available(): ffi_time, ffi_tflops = benchmark_forward( - args.num_tokens, args.hidden_size, args.intermediate_size, - args.num_experts, args.num_warmup, args.num_iters, use_ffi=True + num_tokens, k_dim, n_dim, num_groups, num_warmup, num_iters, use_ffi=True ) print(f" CUTLASS FFI: {ffi_time*1000:8.3f} ms {ffi_tflops:8.2f} TFLOPS") jax_time, jax_tflops = benchmark_forward( - args.num_tokens, args.hidden_size, args.intermediate_size, - args.num_experts, args.num_warmup, args.num_iters, use_ffi=False + num_tokens, k_dim, n_dim, num_groups, num_warmup, num_iters, use_ffi=False ) 
print(f" JAX ragged: {jax_time*1000:8.3f} ms {jax_tflops:8.2f} TFLOPS") @@ -164,14 +170,12 @@ def main(): if ragged_dot_ffi_available(): ffi_time, ffi_tflops = benchmark_backward( - args.num_tokens, args.hidden_size, args.intermediate_size, - args.num_experts, args.num_warmup, args.num_iters, use_ffi=True + num_tokens, k_dim, n_dim, num_groups, num_warmup, num_iters, use_ffi=True ) print(f" CUTLASS FFI: {ffi_time*1000:8.3f} ms {ffi_tflops:8.2f} TFLOPS") jax_time, jax_tflops = benchmark_backward( - args.num_tokens, args.hidden_size, args.intermediate_size, - args.num_experts, args.num_warmup, args.num_iters, use_ffi=False + num_tokens, k_dim, n_dim, num_groups, num_warmup, num_iters, use_ffi=False ) print(f" JAX ragged: {jax_time*1000:8.3f} ms {jax_tflops:8.2f} TFLOPS") @@ -180,5 +184,77 @@ def main(): print() +def main(): + parser = argparse.ArgumentParser(description="Benchmark ragged_dot CUTLASS kernel") + parser.add_argument("--preset", choices=list(PRESETS.keys()), help="Use a preset configuration (moe, lora, lora-moe)") + parser.add_argument("--all-presets", action="store_true", help="Run all preset configurations") + parser.add_argument("--num-tokens", type=int, default=8192, help="Number of tokens (M)") + parser.add_argument("--num-groups", type=int, default=128, help="Number of groups (G) - experts or adapters") + parser.add_argument("--k-dim", type=int, default=2048, help="K dimension - hidden_size (MoE) or lora_rank (LoRA)") + parser.add_argument("--n-dim", type=int, default=768, help="N dimension - output features") + parser.add_argument("--num-warmup", type=int, default=5, help="Warmup iterations") + parser.add_argument("--num-iters", type=int, default=20, help="Benchmark iterations") + parser.add_argument("--backward-only", action="store_true", help="Only benchmark backward pass") + parser.add_argument("--forward-only", action="store_true", help="Only benchmark forward pass") + args = parser.parse_args() + + print("Ragged Dot Benchmark") + print("=" * 60) + print(f"CUTLASS FFI available: {ragged_dot_ffi_available()}") + print(f"JAX backend: {jax.default_backend()}") + print(f"Devices: {jax.device_count()}") + print() + + run_forward = not args.backward_only + run_backward = not args.forward_only + + if args.all_presets: + # Run all presets + for preset_name, preset in PRESETS.items(): + print("=" * 60) + print(f"Preset: {preset_name} - {preset['description']}") + print("=" * 60) + print(f"Config:") + print(f" num_tokens (M): {preset['num_tokens']}") + print(f" num_groups (G): {preset['num_groups']}") + print(f" k_dim (K): {preset['k_dim']}") + print(f" n_dim (N): {preset['n_dim']}") + print(f" warmup/iters: {args.num_warmup}/{args.num_iters}") + print() + run_benchmark_suite( + preset['num_tokens'], preset['k_dim'], preset['n_dim'], preset['num_groups'], + args.num_warmup, args.num_iters, run_forward, run_backward + ) + elif args.preset: + # Use a specific preset + preset = PRESETS[args.preset] + print(f"Preset: {args.preset} - {preset['description']}") + print() + print(f"Config:") + print(f" num_tokens (M): {preset['num_tokens']}") + print(f" num_groups (G): {preset['num_groups']}") + print(f" k_dim (K): {preset['k_dim']}") + print(f" n_dim (N): {preset['n_dim']}") + print(f" warmup/iters: {args.num_warmup}/{args.num_iters}") + print() + run_benchmark_suite( + preset['num_tokens'], preset['k_dim'], preset['n_dim'], preset['num_groups'], + args.num_warmup, args.num_iters, run_forward, run_backward + ) + else: + # Use custom config from args + print(f"Config:") + print(f" 
num_tokens (M): {args.num_tokens}") + print(f" num_groups (G): {args.num_groups}") + print(f" k_dim (K): {args.k_dim}") + print(f" n_dim (N): {args.n_dim}") + print(f" warmup/iters: {args.num_warmup}/{args.num_iters}") + print() + run_benchmark_suite( + args.num_tokens, args.k_dim, args.n_dim, args.num_groups, + args.num_warmup, args.num_iters, run_forward, run_backward + ) + + if __name__ == "__main__": main() diff --git a/skyrl-tx/benchmarks/sweep_tile_sizes.py b/skyrl-tx/benchmarks/sweep_tile_sizes.py index b81f38e89..c40534a11 100644 --- a/skyrl-tx/benchmarks/sweep_tile_sizes.py +++ b/skyrl-tx/benchmarks/sweep_tile_sizes.py @@ -21,6 +21,22 @@ (128, 128, 128), ] +# Workload presets for benchmarking +WORKLOAD_PRESETS = { + "moe": { + "description": "MoE expert layer (Qwen3-30B-A3B)", + "args": ["--preset", "moe"], + }, + "lora": { + "description": "LoRA adapter layer (rank=8)", + "args": ["--preset", "lora"], + }, + "lora-moe": { + "description": "LoRA on MoE experts", + "args": ["--preset", "lora-moe"], + }, +} + def set_tile_shape(m: int, n: int, k: int) -> None: """Update the TileShape in the CUDA file.""" @@ -57,14 +73,20 @@ def rebuild_kernel() -> bool: return SO_FILE.exists() -def run_benchmark(num_tokens: int = 8192) -> dict | None: +def run_benchmark(workload: str = "moe", num_tokens: int = 8192) -> dict | None: """Run the benchmark and parse results.""" + # Build command with workload preset or custom args + cmd = [sys.executable, "benchmarks/bench_ragged_dot.py", "--num-warmup", "3", "--num-iters", "10"] + + if workload in WORKLOAD_PRESETS: + cmd.extend(WORKLOAD_PRESETS[workload]["args"]) + else: + # Legacy: custom num_tokens only + cmd.extend(["--num-tokens", str(num_tokens)]) + # Need to run in a fresh process to reload the .so result = subprocess.run( - [sys.executable, "benchmarks/bench_ragged_dot.py", - "--num-tokens", str(num_tokens), - "--num-warmup", "3", - "--num-iters", "10"], + cmd, capture_output=True, text=True, cwd=CUDA_FILE.parent.parent.parent, @@ -96,22 +118,30 @@ def run_benchmark(num_tokens: int = 8192) -> dict | None: def main(): import argparse parser = argparse.ArgumentParser(description="Sweep tile sizes for CUTLASS kernel") - parser.add_argument("--num-tokens", type=int, default=8192, help="Number of tokens") + parser.add_argument("--workload", choices=list(WORKLOAD_PRESETS.keys()), default="moe", + help="Workload preset to benchmark (moe, lora, lora-moe)") + parser.add_argument("--all-workloads", action="store_true", help="Sweep all workloads") + parser.add_argument("--num-tokens", type=int, default=8192, help="Number of tokens (for custom workload)") parser.add_argument("--dry-run", action="store_true", help="Only show configs, don't run") args = parser.parse_args() print("CUTLASS Tile Size Sweep") print("=" * 70) - print(f"Qwen3-30B-A3B shapes: K=2048, N=768, M={args.num_tokens}") - print() + + workloads = list(WORKLOAD_PRESETS.keys()) if args.all_workloads else [args.workload] if args.dry_run: print("Tile configurations to test:") for m, n, k in TILE_CONFIGS: print(f" ({m}, {n}, {k})") + print() + print("Workloads to test:") + for w in workloads: + print(f" {w}: {WORKLOAD_PRESETS[w]['description']}") return - results = [] + # Store results per workload + all_results: dict[str, list] = {w: [] for w in workloads} for m, n, k in TILE_CONFIGS: print(f"\n{'='*70}") @@ -122,44 +152,49 @@ def main(): if not rebuild_kernel(): print(" FAILED to build") - results.append((m, n, k, None)) + for w in workloads: + all_results[w].append((m, n, k, None)) continue - 
bench_results = run_benchmark(args.num_tokens) - if bench_results: - print(f" Forward: {bench_results.get('fwd_ms', 'N/A'):>8.3f} ms {bench_results.get('fwd_tflops', 'N/A'):>8.2f} TFLOPS") - print(f" Backward: {bench_results.get('bwd_ms', 'N/A'):>8.3f} ms {bench_results.get('bwd_tflops', 'N/A'):>8.2f} TFLOPS") - results.append((m, n, k, bench_results)) - else: - print(" FAILED to benchmark") - results.append((m, n, k, None)) - - # Summary - print(f"\n{'='*70}") - print("SUMMARY") - print("=" * 70) - print(f"{'TileShape':<20} {'Fwd (ms)':>10} {'Fwd TFLOPS':>12} {'Bwd (ms)':>10} {'Bwd TFLOPS':>12}") - print("-" * 70) - - for m, n, k, res in results: - tile_str = f"({m}, {n}, {k})" - if res: - fwd_ms = f"{res.get('fwd_ms', 0):.3f}" - fwd_tf = f"{res.get('fwd_tflops', 0):.2f}" - bwd_ms = f"{res.get('bwd_ms', 0):.3f}" - bwd_tf = f"{res.get('bwd_tflops', 0):.2f}" - print(f"{tile_str:<20} {fwd_ms:>10} {fwd_tf:>12} {bwd_ms:>10} {bwd_tf:>12}") - else: - print(f"{tile_str:<20} {'FAILED':>10}") - - # Find best - valid_results = [(m, n, k, r) for m, n, k, r in results if r and 'fwd_tflops' in r] - if valid_results: - best_fwd = max(valid_results, key=lambda x: x[3]['fwd_tflops']) - best_bwd = max(valid_results, key=lambda x: x[3].get('bwd_tflops', 0)) - print() - print(f"Best forward: ({best_fwd[0]}, {best_fwd[1]}, {best_fwd[2]}) - {best_fwd[3]['fwd_tflops']:.2f} TFLOPS") - print(f"Best backward: ({best_bwd[0]}, {best_bwd[1]}, {best_bwd[2]}) - {best_bwd[3].get('bwd_tflops', 0):.2f} TFLOPS") + for workload in workloads: + print(f"\n Workload: {workload} ({WORKLOAD_PRESETS[workload]['description']})") + bench_results = run_benchmark(workload) + if bench_results: + print(f" Forward: {bench_results.get('fwd_ms', 'N/A'):>8.3f} ms {bench_results.get('fwd_tflops', 'N/A'):>8.2f} TFLOPS") + print(f" Backward: {bench_results.get('bwd_ms', 'N/A'):>8.3f} ms {bench_results.get('bwd_tflops', 'N/A'):>8.2f} TFLOPS") + all_results[workload].append((m, n, k, bench_results)) + else: + print(" FAILED to benchmark") + all_results[workload].append((m, n, k, None)) + + # Summary per workload + for workload in workloads: + results = all_results[workload] + print(f"\n{'='*70}") + print(f"SUMMARY: {workload} ({WORKLOAD_PRESETS[workload]['description']})") + print("=" * 70) + print(f"{'TileShape':<20} {'Fwd (ms)':>10} {'Fwd TFLOPS':>12} {'Bwd (ms)':>10} {'Bwd TFLOPS':>12}") + print("-" * 70) + + for m, n, k, res in results: + tile_str = f"({m}, {n}, {k})" + if res: + fwd_ms = f"{res.get('fwd_ms', 0):.3f}" + fwd_tf = f"{res.get('fwd_tflops', 0):.2f}" + bwd_ms = f"{res.get('bwd_ms', 0):.3f}" + bwd_tf = f"{res.get('bwd_tflops', 0):.2f}" + print(f"{tile_str:<20} {fwd_ms:>10} {fwd_tf:>12} {bwd_ms:>10} {bwd_tf:>12}") + else: + print(f"{tile_str:<20} {'FAILED':>10}") + + # Find best + valid_results = [(m, n, k, r) for m, n, k, r in results if r and 'fwd_tflops' in r] + if valid_results: + best_fwd = max(valid_results, key=lambda x: x[3]['fwd_tflops']) + best_bwd = max(valid_results, key=lambda x: x[3].get('bwd_tflops', 0)) + print() + print(f"Best forward: ({best_fwd[0]}, {best_fwd[1]}, {best_fwd[2]}) - {best_fwd[3]['fwd_tflops']:.2f} TFLOPS") + print(f"Best backward: ({best_bwd[0]}, {best_bwd[1]}, {best_bwd[2]}) - {best_bwd[3].get('bwd_tflops', 0):.2f} TFLOPS") if __name__ == "__main__": From 8fa5c2e90efad1ecb3f17be50c0219fe943acce4 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 18:24:25 -0800 Subject: [PATCH 091/106] update --- skyrl-tx/benchmarks/bench_ragged_dot.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/skyrl-tx/benchmarks/bench_ragged_dot.py b/skyrl-tx/benchmarks/bench_ragged_dot.py index 37d74ddb9..261a5b33c 100644 --- a/skyrl-tx/benchmarks/bench_ragged_dot.py +++ b/skyrl-tx/benchmarks/bench_ragged_dot.py @@ -29,7 +29,7 @@ "lora-moe": { "description": "LoRA on MoE experts (combined groups)", "num_tokens": 8192, - "num_groups": 4096, # num_experts * max_lora_adapters (128 * 32) + "num_groups": 1024, # num_experts * max_lora_adapters (128 * 8, capped at kernel limit) "k_dim": 8, # lora_rank "n_dim": 768, # intermediate_size }, From f6a6a92c19542f47212010272474973b9a8e0aae Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 18:42:41 -0800 Subject: [PATCH 092/106] update --- skyrl-tx/benchmarks/sweep_tile_sizes.py | 145 +++++++++++++++++------- 1 file changed, 106 insertions(+), 39 deletions(-) diff --git a/skyrl-tx/benchmarks/sweep_tile_sizes.py b/skyrl-tx/benchmarks/sweep_tile_sizes.py index c40534a11..ae3e60443 100644 --- a/skyrl-tx/benchmarks/sweep_tile_sizes.py +++ b/skyrl-tx/benchmarks/sweep_tile_sizes.py @@ -1,9 +1,10 @@ -"""Sweep tile sizes for ragged_dot CUTLASS kernel optimization.""" +"""Sweep tile sizes and kernel parameters for ragged_dot CUTLASS kernel optimization.""" import subprocess import sys import re from pathlib import Path +from itertools import product CUDA_FILE = Path(__file__).parent.parent / "tx/ffi/ragged_dot_ffi.cu" SO_FILE = Path(__file__).parent.parent / "tx/ffi/libragged_dot_ffi.so" @@ -11,16 +12,27 @@ # Tile configurations to test: (M, N, K) # Constraints: dimensions should be powers of 2 or multiples that work with SM90 TILE_CONFIGS = [ - (128, 256, 64), # current + (64, 256, 64), # best from previous sweep (128, 128, 64), (64, 128, 64), - (64, 256, 64), - (256, 128, 64), - (128, 64, 128), - (64, 64, 128), + (128, 256, 64), (128, 128, 128), ] +# Cluster shapes to test: (M, N, K) +CLUSTER_CONFIGS = [ + (1, 1, 1), # current + (2, 1, 1), + (1, 2, 1), + (2, 2, 1), +] + +# Kernel schedules to test +SCHEDULE_CONFIGS = [ + ("KernelPtrArrayTmaWarpSpecializedPingpong", "PtrArrayTmaWarpSpecializedPingpong"), # current + ("KernelPtrArrayTmaWarpSpecializedCooperative", "PtrArrayTmaWarpSpecializedCooperative"), +] + # Workload presets for benchmarking WORKLOAD_PRESETS = { "moe": { @@ -48,7 +60,44 @@ def set_tile_shape(m: int, n: int, k: int) -> None: new_content = re.sub(pattern, replacement, content) CUDA_FILE.write_text(new_content) - print(f"Set TileShape to ({m}, {n}, {k})") + + +def set_cluster_shape(m: int, n: int, k: int) -> None: + """Update the ClusterShape in the CUDA file.""" + content = CUDA_FILE.read_text() + + pattern = r'using ClusterShape = cute::Shape;' + replacement = f'using ClusterShape = cute::Shape;' + + new_content = re.sub(pattern, replacement, content) + CUDA_FILE.write_text(new_content) + + +def set_schedule(kernel_schedule: str, epilogue_schedule: str) -> None: + """Update the KernelSchedule and EpilogueSchedule in the CUDA file.""" + content = CUDA_FILE.read_text() + + # Replace KernelSchedule + pattern = r'using KernelSchedule = cutlass::gemm::\w+;' + replacement = f'using KernelSchedule = cutlass::gemm::{kernel_schedule};' + content = re.sub(pattern, replacement, content) + + # Replace EpilogueSchedule + pattern = r'using EpilogueSchedule = cutlass::epilogue::\w+;' + replacement = f'using EpilogueSchedule = cutlass::epilogue::{epilogue_schedule};' + content = re.sub(pattern, replacement, content) + + CUDA_FILE.write_text(content) + + +def set_kernel_config(tile: tuple, cluster: tuple, schedule: 
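One point worth keeping in mind when reading the TFLOPS columns: the forward FLOP count of a ragged dot depends only on M, K and N, since every token row multiplies one K x N matrix from its group; the group count G only changes how fragmented the work is, not the total. A quick sketch with the preset shapes, assuming the benchmark uses the conventional 2*M*K*N count:

```
def fwd_gflops(num_tokens: int, k_dim: int, n_dim: int) -> float:
    """Forward FLOPs of one ragged_dot call, in GFLOPs (2*M*K*N convention)."""
    return 2 * num_tokens * k_dim * n_dim / 1e9

print(fwd_gflops(8192, 2048, 768))  # moe preset shapes: ~25.8 GFLOP per call
print(fwd_gflops(8192, 8, 768))     # lora-moe preset shapes: ~0.10 GFLOP per call
```

The lora-moe shape does well over two orders of magnitude less arithmetic, so its timings are likely dominated by launch and group-metadata overhead rather than tensor-core throughput.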
tuple) -> str: + """Set all kernel configuration parameters. Returns a config description string.""" + set_tile_shape(*tile) + set_cluster_shape(*cluster) + set_schedule(*schedule) + + schedule_name = schedule[0].replace("KernelPtrArrayTmaWarpSpecialized", "") + return f"Tile{tile} Cluster{cluster} {schedule_name}" def rebuild_kernel() -> bool: @@ -117,43 +166,62 @@ def run_benchmark(workload: str = "moe", num_tokens: int = 8192) -> dict | None: def main(): import argparse - parser = argparse.ArgumentParser(description="Sweep tile sizes for CUTLASS kernel") + parser = argparse.ArgumentParser(description="Sweep kernel parameters for CUTLASS ragged_dot") parser.add_argument("--workload", choices=list(WORKLOAD_PRESETS.keys()), default="moe", help="Workload preset to benchmark (moe, lora, lora-moe)") parser.add_argument("--all-workloads", action="store_true", help="Sweep all workloads") - parser.add_argument("--num-tokens", type=int, default=8192, help="Number of tokens (for custom workload)") + parser.add_argument("--sweep-tiles", action="store_true", help="Sweep tile shapes") + parser.add_argument("--sweep-clusters", action="store_true", help="Sweep cluster shapes") + parser.add_argument("--sweep-schedules", action="store_true", help="Sweep kernel schedules") + parser.add_argument("--sweep-all", action="store_true", help="Sweep all parameters") parser.add_argument("--dry-run", action="store_true", help="Only show configs, don't run") args = parser.parse_args() - print("CUTLASS Tile Size Sweep") - print("=" * 70) + # Default to sweeping tiles only if nothing specified + if not (args.sweep_tiles or args.sweep_clusters or args.sweep_schedules or args.sweep_all): + args.sweep_tiles = True + + if args.sweep_all: + args.sweep_tiles = args.sweep_clusters = args.sweep_schedules = True + + print("CUTLASS Kernel Parameter Sweep") + print("=" * 80) workloads = list(WORKLOAD_PRESETS.keys()) if args.all_workloads else [args.workload] + # Build parameter combinations + tiles = TILE_CONFIGS if args.sweep_tiles else [TILE_CONFIGS[0]] + clusters = CLUSTER_CONFIGS if args.sweep_clusters else [CLUSTER_CONFIGS[0]] + schedules = SCHEDULE_CONFIGS if args.sweep_schedules else [SCHEDULE_CONFIGS[0]] + + configs = list(product(tiles, clusters, schedules)) + print(f"Testing {len(configs)} configurations across {len(workloads)} workload(s)") + print() + if args.dry_run: - print("Tile configurations to test:") - for m, n, k in TILE_CONFIGS: - print(f" ({m}, {n}, {k})") + print("Configurations to test:") + for tile, cluster, schedule in configs: + schedule_name = schedule[0].replace("KernelPtrArrayTmaWarpSpecialized", "") + print(f" Tile{tile} Cluster{cluster} {schedule_name}") print() print("Workloads to test:") for w in workloads: print(f" {w}: {WORKLOAD_PRESETS[w]['description']}") return - # Store results per workload + # Store results: {workload: [(config_str, bench_results), ...]} all_results: dict[str, list] = {w: [] for w in workloads} - for m, n, k in TILE_CONFIGS: - print(f"\n{'='*70}") - print(f"Testing TileShape({m}, {n}, {k})") - print("-" * 70) - - set_tile_shape(m, n, k) + for tile, cluster, schedule in configs: + print(f"\n{'='*80}") + config_str = set_kernel_config(tile, cluster, schedule) + print(f"Testing: {config_str}") + print("-" * 80) if not rebuild_kernel(): print(" FAILED to build") for w in workloads: - all_results[w].append((m, n, k, None)) + all_results[w].append((config_str, None)) continue for workload in workloads: @@ -162,39 +230,38 @@ def main(): if bench_results: print(f" Forward: 
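Since every configuration in the product below triggers a full nvcc rebuild, the sweep size is worth knowing up front. A small sketch of the enumeration with the lists defined in this patch:

```
from itertools import product

TILE_CONFIGS = [(64, 256, 64), (128, 128, 64), (64, 128, 64),
                (128, 256, 64), (128, 128, 128)]
CLUSTER_CONFIGS = [(1, 1, 1), (2, 1, 1), (1, 2, 1), (2, 2, 1)]
SCHEDULE_NAMES = ["Pingpong", "Cooperative"]  # abbreviated schedule pairs

configs = list(product(TILE_CONFIGS, CLUSTER_CONFIGS, SCHEDULE_NAMES))
print(len(configs))  # 5 * 4 * 2 = 40 rebuild-and-benchmark cycles with --sweep-all
```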
{bench_results.get('fwd_ms', 'N/A'):>8.3f} ms {bench_results.get('fwd_tflops', 'N/A'):>8.2f} TFLOPS") print(f" Backward: {bench_results.get('bwd_ms', 'N/A'):>8.3f} ms {bench_results.get('bwd_tflops', 'N/A'):>8.2f} TFLOPS") - all_results[workload].append((m, n, k, bench_results)) + all_results[workload].append((config_str, bench_results)) else: print(" FAILED to benchmark") - all_results[workload].append((m, n, k, None)) + all_results[workload].append((config_str, None)) # Summary per workload for workload in workloads: results = all_results[workload] - print(f"\n{'='*70}") + print(f"\n{'='*80}") print(f"SUMMARY: {workload} ({WORKLOAD_PRESETS[workload]['description']})") - print("=" * 70) - print(f"{'TileShape':<20} {'Fwd (ms)':>10} {'Fwd TFLOPS':>12} {'Bwd (ms)':>10} {'Bwd TFLOPS':>12}") - print("-" * 70) + print("=" * 80) + print(f"{'Configuration':<50} {'Fwd (ms)':>10} {'Fwd TF':>8} {'Bwd (ms)':>10} {'Bwd TF':>8}") + print("-" * 80) - for m, n, k, res in results: - tile_str = f"({m}, {n}, {k})" + for config_str, res in results: if res: fwd_ms = f"{res.get('fwd_ms', 0):.3f}" - fwd_tf = f"{res.get('fwd_tflops', 0):.2f}" + fwd_tf = f"{res.get('fwd_tflops', 0):.1f}" bwd_ms = f"{res.get('bwd_ms', 0):.3f}" - bwd_tf = f"{res.get('bwd_tflops', 0):.2f}" - print(f"{tile_str:<20} {fwd_ms:>10} {fwd_tf:>12} {bwd_ms:>10} {bwd_tf:>12}") + bwd_tf = f"{res.get('bwd_tflops', 0):.1f}" + print(f"{config_str:<50} {fwd_ms:>10} {fwd_tf:>8} {bwd_ms:>10} {bwd_tf:>8}") else: - print(f"{tile_str:<20} {'FAILED':>10}") + print(f"{config_str:<50} {'FAILED':>10}") # Find best - valid_results = [(m, n, k, r) for m, n, k, r in results if r and 'fwd_tflops' in r] + valid_results = [(cfg, r) for cfg, r in results if r and 'fwd_tflops' in r] if valid_results: - best_fwd = max(valid_results, key=lambda x: x[3]['fwd_tflops']) - best_bwd = max(valid_results, key=lambda x: x[3].get('bwd_tflops', 0)) + best_fwd = max(valid_results, key=lambda x: x[1]['fwd_tflops']) + best_bwd = max(valid_results, key=lambda x: x[1].get('bwd_tflops', 0)) print() - print(f"Best forward: ({best_fwd[0]}, {best_fwd[1]}, {best_fwd[2]}) - {best_fwd[3]['fwd_tflops']:.2f} TFLOPS") - print(f"Best backward: ({best_bwd[0]}, {best_bwd[1]}, {best_bwd[2]}) - {best_bwd[3].get('bwd_tflops', 0):.2f} TFLOPS") + print(f"Best forward: {best_fwd[0]} - {best_fwd[1]['fwd_tflops']:.2f} TFLOPS") + print(f"Best backward: {best_bwd[0]} - {best_bwd[1].get('bwd_tflops', 0):.2f} TFLOPS") if __name__ == "__main__": From 2db8415aa7e3a9e354361232e1d9afa96bce72fb Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 19:33:04 -0800 Subject: [PATCH 093/106] update --- skyrl-tx/benchmarks/sweep_tile_sizes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/skyrl-tx/benchmarks/sweep_tile_sizes.py b/skyrl-tx/benchmarks/sweep_tile_sizes.py index ae3e60443..1e2e9fc4f 100644 --- a/skyrl-tx/benchmarks/sweep_tile_sizes.py +++ b/skyrl-tx/benchmarks/sweep_tile_sizes.py @@ -116,7 +116,6 @@ def rebuild_kernel() -> bool: ) if result.returncode != 0: - print(f"Build failed: {result.stderr}") return False return SO_FILE.exists() From ae19ae9b736d77ed6f8c878a38a339dbc19dbd88 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 20:01:41 -0800 Subject: [PATCH 094/106] update --- skyrl-tx/benchmarks/sweep_tile_sizes.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/skyrl-tx/benchmarks/sweep_tile_sizes.py b/skyrl-tx/benchmarks/sweep_tile_sizes.py index 1e2e9fc4f..62114029e 100644 --- a/skyrl-tx/benchmarks/sweep_tile_sizes.py +++ 
b/skyrl-tx/benchmarks/sweep_tile_sizes.py @@ -11,6 +11,7 @@ # Tile configurations to test: (M, N, K) # Constraints: dimensions should be powers of 2 or multiples that work with SM90 +# Note: Cooperative schedule requires M >= 128 TILE_CONFIGS = [ (64, 256, 64), # best from previous sweep (128, 128, 64), @@ -19,6 +20,16 @@ (128, 128, 128), ] +# Smaller tile configs that may work better for LoRA (small K dimension) +TILE_CONFIGS_LORA = [ + (128, 128, 32), # smaller K tile + (128, 256, 32), + (128, 64, 32), + (64, 128, 32), + (64, 64, 32), + (128, 128, 64), # reference +] + # Cluster shapes to test: (M, N, K) CLUSTER_CONFIGS = [ (1, 1, 1), # current @@ -173,6 +184,7 @@ def main(): parser.add_argument("--sweep-clusters", action="store_true", help="Sweep cluster shapes") parser.add_argument("--sweep-schedules", action="store_true", help="Sweep kernel schedules") parser.add_argument("--sweep-all", action="store_true", help="Sweep all parameters") + parser.add_argument("--lora-tiles", action="store_true", help="Use LoRA-optimized tile configs (smaller K)") parser.add_argument("--dry-run", action="store_true", help="Only show configs, don't run") args = parser.parse_args() @@ -189,7 +201,8 @@ def main(): workloads = list(WORKLOAD_PRESETS.keys()) if args.all_workloads else [args.workload] # Build parameter combinations - tiles = TILE_CONFIGS if args.sweep_tiles else [TILE_CONFIGS[0]] + base_tiles = TILE_CONFIGS_LORA if args.lora_tiles else TILE_CONFIGS + tiles = base_tiles if args.sweep_tiles else [base_tiles[0]] clusters = CLUSTER_CONFIGS if args.sweep_clusters else [CLUSTER_CONFIGS[0]] schedules = SCHEDULE_CONFIGS if args.sweep_schedules else [SCHEDULE_CONFIGS[0]] From 988cc062a84ed57a675b73bc4a4187f920bd9201 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 21:42:12 -0800 Subject: [PATCH 095/106] update tiles --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index e9959d190..742610ab4 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -56,7 +56,7 @@ constexpr int AlignmentOutput = 8; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; -using TileShape = cute::Shape; +using TileShape = cute::Shape; using ClusterShape = cute::Shape; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; From 01677444ca6098bc488dd97ed5174fd2380a43aa Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 22:18:45 -0800 Subject: [PATCH 096/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 742610ab4..2e455d226 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -174,13 +174,13 @@ __global__ void prepare_grouped_gemm( data.out_ptrs[tid] = out + static_cast(start) * n; data.problem_sizes[tid] = ProblemShapeType(m, n, k); data.stride_A[tid] = cutlass::make_cute_packed_stride(typename Data::StrideA{}, {k, k, 1}); - data.stride_output[tid] = cutlass::make_cute_packed_stride(typename Data::StrideOutput{}, {m, n, 1}); + data.stride_output[tid] = cutlass::make_cute_packed_stride(typename Data::StrideOutput{}, {n, n, 1}); } else { data.B_ptrs[tid] = B + static_cast(start) * n; data.out_ptrs[tid] = out + 
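The rationale for TILE_CONFIGS_LORA above is that the LoRA workloads have K equal to the LoRA rank (8), so a 64-wide K tile is mostly predicated-off work in the single mainloop iteration it needs. A rough utilization sketch (a simplification; the real cost also depends on the schedule and epilogue):

```
def k_tile_utilization(problem_k: int, tile_k: int) -> float:
    """Fraction of a K tile that carries real data when the problem K is small."""
    return min(problem_k, tile_k) / tile_k

print(k_tile_utilization(8, 64))     # 0.125 for LoRA shapes with the default tiles
print(k_tile_utilization(8, 32))     # 0.25 with the smaller K tiles tried here
print(k_tile_utilization(2048, 64))  # 1.0 for the MoE shapes
```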
static_cast(tid) * k * n; data.problem_sizes[tid] = ProblemShapeType(k, n, m); data.stride_A[tid] = cutlass::make_cute_packed_stride(typename Data::StrideA{}, {k, k, 1}); - data.stride_output[tid] = cutlass::make_cute_packed_stride(typename Data::StrideOutput{}, {k, n, 1}); + data.stride_output[tid] = cutlass::make_cute_packed_stride(typename Data::StrideOutput{}, {n, n, 1}); } data.stride_B[tid] = cutlass::make_cute_packed_stride(typename Data::StrideB{}, {n, n, 1}); } From 3dfcc14ae6b24bbc6b9cb1f36c67852121ced0f7 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 22:32:16 -0800 Subject: [PATCH 097/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 2e455d226..f5fed34da 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -56,7 +56,7 @@ constexpr int AlignmentOutput = 8; using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; -using TileShape = cute::Shape; +using TileShape = cute::Shape; using ClusterShape = cute::Shape; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; From 50357b42cd3b869fd78e0891b52b3278f33d0fd7 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 19 Jan 2026 23:36:33 -0800 Subject: [PATCH 098/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index f5fed34da..71bd95ad5 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -225,7 +225,7 @@ ffi::Error ExecuteGroupedGemm( ffi::Error RaggedDotCudaImpl( cudaStream_t stream, - ffi::ScratchAllocator scratch, + ffi::ScratchAllocator& scratch, ffi::Buffer lhs, ffi::Buffer rhs, ffi::Buffer group_offset, @@ -273,7 +273,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( // Backward pass for d_rhs: computes lhs.T @ grad per group -> d_rhs[G, K, N] ffi::Error RaggedDotBwdCudaImpl( cudaStream_t stream, - ffi::ScratchAllocator scratch, + ffi::ScratchAllocator& scratch, ffi::Buffer lhs, ffi::Buffer grad, ffi::Buffer group_offset, From 58956a4cf1f0c886eed151d28751b3dc21a8ff87 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 20 Jan 2026 00:00:15 -0800 Subject: [PATCH 099/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 71bd95ad5..4a91da816 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -119,7 +119,8 @@ struct GroupedGemmData { ProblemShapeType* problem_sizes; static std::optional Allocate(ffi::ScratchAllocator& scratch, size_t g) { - size_t bytes = 7 * 16 + + // 128-byte alignment per array for TMA requirements on SM90 + size_t bytes = 7 * 128 + sizeof(const DtypeA*) * g + sizeof(const DtypeB*) * g + sizeof(DtypeOutput*) * g + sizeof(StrideA) * g + sizeof(StrideB) * g + sizeof(StrideOutput) * g + sizeof(ProblemShapeType) * g; @@ -225,7 +226,7 @@ ffi::Error ExecuteGroupedGemm( ffi::Error RaggedDotCudaImpl( cudaStream_t stream, - ffi::ScratchAllocator& scratch, + ffi::ScratchAllocator scratch, ffi::Buffer lhs, ffi::Buffer rhs, ffi::Buffer group_offset, @@ -273,7 +274,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( // Backward pass for d_rhs: computes lhs.T @ 
grad per group -> d_rhs[G, K, N] ffi::Error RaggedDotBwdCudaImpl( cudaStream_t stream, - ffi::ScratchAllocator& scratch, + ffi::ScratchAllocator scratch, ffi::Buffer lhs, ffi::Buffer grad, ffi::Buffer group_offset, From 447305d23e90408acf8070d92609a1146f6481eb Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 20 Jan 2026 00:51:19 -0800 Subject: [PATCH 100/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 4a91da816..170c9d316 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -97,9 +97,12 @@ using CollectiveMainloop_Bwd = using GemmKernel_Bwd = cutlass::gemm::kernel::GemmUniversal; using Gemm_Bwd = cutlass::gemm::device::GemmUniversalAdapter; +// 128-byte alignment for TMA requirements on SM90 +constexpr size_t kTmaAlignment = 128; + template static T* carve_aligned(char*& p, size_t count) { - T* out = reinterpret_cast((reinterpret_cast(p) + 15) & ~uintptr_t(15)); + T* out = reinterpret_cast((reinterpret_cast(p) + kTmaAlignment - 1) & ~uintptr_t(kTmaAlignment - 1)); p = reinterpret_cast(out + count); return out; } @@ -119,8 +122,7 @@ struct GroupedGemmData { ProblemShapeType* problem_sizes; static std::optional Allocate(ffi::ScratchAllocator& scratch, size_t g) { - // 128-byte alignment per array for TMA requirements on SM90 - size_t bytes = 7 * 128 + + size_t bytes = 7 * kTmaAlignment + sizeof(const DtypeA*) * g + sizeof(const DtypeB*) * g + sizeof(DtypeOutput*) * g + sizeof(StrideA) * g + sizeof(StrideB) * g + sizeof(StrideOutput) * g + sizeof(ProblemShapeType) * g; From c1f1ab1593c9607f756256cafe2093db5fdcfa03 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 20 Jan 2026 01:30:55 -0800 Subject: [PATCH 101/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 170c9d316..5497376de 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -1,6 +1,7 @@ #include #include +#include #include #include @@ -26,20 +27,15 @@ namespace ffi = xla::ffi; -static std::vector g_device_props; - static int get_sm_count() { - int device = 0; - if (cudaGetDevice(&device) != cudaSuccess || device < 0) { - return 0; - } - if (static_cast(device) >= g_device_props.size()) { - g_device_props.resize(device + 1); - } - cudaDeviceProp& props = g_device_props[device]; - if (!props.multiProcessorCount) { - cudaGetDeviceProperties(&props, device); - } + static std::once_flag flag; + static cudaDeviceProp props; + std::call_once(flag, [] { + int device = 0; + if (cudaGetDevice(&device) == cudaSuccess && device >= 0) { + cudaGetDeviceProperties(&props, device); + } + }); return props.multiProcessorCount; } From fa814f1b42e4a5f78fb9031512730772e8d2a00c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 20 Jan 2026 01:33:06 -0800 Subject: [PATCH 102/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 5497376de..b53eeb33e 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -28,15 +28,18 @@ namespace ffi = xla::ffi; static int get_sm_count() { - static std::once_flag flag; - static cudaDeviceProp props; - std::call_once(flag, [] { - 
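For context on what the backward handler added above has to produce, a minimal eager JAX reference for the per-group d_rhs product (illustrative only; the FFI path computes the same quantity with one grouped GEMM instead of a Python loop):

```
import numpy as np
import jax.numpy as jnp

def ragged_dot_drhs_reference(lhs, grad, group_sizes):
    """d_rhs[g] = lhs_g.T @ grad_g over each group's rows, giving shape [G, K, N]."""
    group_sizes = np.asarray(group_sizes)
    starts = np.concatenate([[0], np.cumsum(group_sizes)[:-1]])
    return jnp.stack([lhs[s:s + n].T @ grad[s:s + n]
                      for s, n in zip(starts, group_sizes)])
```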
int device = 0; - if (cudaGetDevice(&device) == cudaSuccess && device >= 0) { - cudaGetDeviceProperties(&props, device); - } + constexpr int kMaxDevices = 16; + static std::once_flag flags[kMaxDevices]; + static cudaDeviceProp props[kMaxDevices]; + + int device = 0; + if (cudaGetDevice(&device) != cudaSuccess || device < 0 || device >= kMaxDevices) { + return 0; + } + std::call_once(flags[device], [device] { + cudaGetDeviceProperties(&props[device], device); }); - return props.multiProcessorCount; + return props[device].multiProcessorCount; } using DtypeA = cutlass::bfloat16_t; From f72b28592eba7d5deb7351b5355bf2ebaa25a5c9 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 20 Jan 2026 01:37:23 -0800 Subject: [PATCH 103/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 36 ++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index b53eeb33e..5fec75da6 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -27,19 +28,38 @@ namespace ffi = xla::ffi; -static int get_sm_count() { - constexpr int kMaxDevices = 16; - static std::once_flag flags[kMaxDevices]; - static cudaDeviceProp props[kMaxDevices]; +namespace { + +int g_num_gpus = -1; +std::deque g_device_flags; +std::vector g_device_props; + +void initDeviceVectors() { + static bool init [[maybe_unused]] = [] { + if (cudaGetDeviceCount(&g_num_gpus) != cudaSuccess || g_num_gpus <= 0) { + g_num_gpus = 0; + } + g_device_flags.resize(g_num_gpus); + g_device_props.resize(g_num_gpus); + return true; + }(); +} +} // namespace + +static int get_sm_count() { int device = 0; - if (cudaGetDevice(&device) != cudaSuccess || device < 0 || device >= kMaxDevices) { + if (cudaGetDevice(&device) != cudaSuccess || device < 0) { + return 0; + } + initDeviceVectors(); + if (device >= g_num_gpus) { return 0; } - std::call_once(flags[device], [device] { - cudaGetDeviceProperties(&props[device], device); + std::call_once(g_device_flags[device], [device] { + cudaGetDeviceProperties(&g_device_props[device], device); }); - return props[device].multiProcessorCount; + return g_device_props[device].multiProcessorCount; } using DtypeA = cutlass::bfloat16_t; From 2ef0d7e8099f1994c98390c8fec13e94659c6266 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 20 Jan 2026 02:19:41 -0800 Subject: [PATCH 104/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 41 +++++++++---------------------- 1 file changed, 11 insertions(+), 30 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 5fec75da6..caba7ad73 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -1,8 +1,7 @@ #include #include -#include -#include +#include #include #include @@ -28,38 +27,20 @@ namespace ffi = xla::ffi; -namespace { - -int g_num_gpus = -1; -std::deque g_device_flags; -std::vector g_device_props; - -void initDeviceVectors() { - static bool init [[maybe_unused]] = [] { - if (cudaGetDeviceCount(&g_num_gpus) != cudaSuccess || g_num_gpus <= 0) { - g_num_gpus = 0; +static int get_sm_count() { + static std::vector device_props = [] { + int num_gpus = 0; + assert(cudaGetDeviceCount(&num_gpus) == cudaSuccess && num_gpus > 0); + std::vector props(num_gpus); + for (int i = 0; i < num_gpus; ++i) { + cudaGetDeviceProperties(&props[i], i); } - g_device_flags.resize(g_num_gpus); - 
g_device_props.resize(g_num_gpus); - return true; + return props; }(); -} -} // namespace - -static int get_sm_count() { int device = 0; - if (cudaGetDevice(&device) != cudaSuccess || device < 0) { - return 0; - } - initDeviceVectors(); - if (device >= g_num_gpus) { - return 0; - } - std::call_once(g_device_flags[device], [device] { - cudaGetDeviceProperties(&g_device_props[device], device); - }); - return g_device_props[device].multiProcessorCount; + assert(cudaGetDevice(&device) == cudaSuccess && device >= 0 && device < device_props.size()); + return device_props[device].multiProcessorCount; } using DtypeA = cutlass::bfloat16_t; From 59be86f35a476a482e2969d28c08c2abbef98277 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 20 Jan 2026 02:30:26 -0800 Subject: [PATCH 105/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index caba7ad73..641d779df 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -1,7 +1,6 @@ #include #include -#include #include #include @@ -27,10 +26,12 @@ namespace ffi = xla::ffi; -static int get_sm_count() { +static std::optional get_sm_count() { static std::vector device_props = [] { int num_gpus = 0; - assert(cudaGetDeviceCount(&num_gpus) == cudaSuccess && num_gpus > 0); + if (cudaGetDeviceCount(&num_gpus) != cudaSuccess || num_gpus <= 0) { + return std::vector{}; + } std::vector props(num_gpus); for (int i = 0; i < num_gpus; ++i) { cudaGetDeviceProperties(&props[i], i); @@ -39,7 +40,9 @@ static int get_sm_count() { }(); int device = 0; - assert(cudaGetDevice(&device) == cudaSuccess && device >= 0 && device < device_props.size()); + if (cudaGetDevice(&device) != cudaSuccess || device < 0 || device >= device_props.size()) { + return {}; + } return device_props[device].multiProcessorCount; } @@ -204,7 +207,9 @@ ffi::Error ExecuteGroupedGemm( GemmT gemm; auto args = data->MakeArgs(g_local); - args.hw_info.sm_count = get_sm_count(); + auto sm_count = get_sm_count(); + if (!sm_count) return ffi::Error::Internal("Failed to get SM count."); + args.hw_info.sm_count = *sm_count; if (gemm.can_implement(args) != cutlass::Status::kSuccess) { return ffi::Error::Internal("cutlass cannot implement grouped gemm."); From ce1eba9b136ce6af0f7be943558a36627caef263 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 20 Jan 2026 17:17:56 -0800 Subject: [PATCH 106/106] update --- skyrl-tx/tx/ffi/ragged_dot_ffi.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu index 641d779df..5cee55423 100644 --- a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -26,6 +26,7 @@ namespace ffi = xla::ffi; +// Cache device properties because cudaGetDeviceProperties is slow. static std::optional get_sm_count() { static std::vector device_props = [] { int num_gpus = 0;
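Finally, a hedged sketch of how the FFI path could be validated against stock JAX once the shared library is built. The exact Python-side signature lives in tx/ffi/ragged_dot_ffi.py; the call below assumes it mirrors jax.lax.ragged_dot's (lhs, rhs, group_sizes) arguments and that is_available takes no arguments, so treat it as illustrative rather than definitive:

```
import jax
import jax.numpy as jnp
from tx.ffi import ragged_dot_ffi, ragged_dot_ffi_available

M, K, N, G = 1024, 2048, 768, 8
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
lhs = jax.random.normal(k1, (M, K), dtype=jnp.bfloat16)
rhs = jax.random.normal(k2, (G, K, N), dtype=jnp.bfloat16)
group_sizes = jnp.full((G,), M // G, dtype=jnp.int32)  # bf16 data, int32 metadata

if ragged_dot_ffi_available():
    out_ffi = ragged_dot_ffi(lhs, rhs, group_sizes)      # assumed signature
    out_ref = jax.lax.ragged_dot(lhs, rhs, group_sizes)  # reference path
    err = jnp.abs(out_ffi.astype(jnp.float32) - out_ref.astype(jnp.float32)).max()
    print("max abs error:", err)
```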