diff --git a/skyrl-tx/benchmarks/bench_ragged_dot.py b/skyrl-tx/benchmarks/bench_ragged_dot.py new file mode 100644 index 000000000..261a5b33c --- /dev/null +++ b/skyrl-tx/benchmarks/bench_ragged_dot.py @@ -0,0 +1,260 @@ +"""Benchmark ragged_dot CUTLASS kernel with Qwen3-30B-A3B MoE and LoRA shapes.""" + +import argparse +import time + +import jax +import jax.numpy as jnp +from jax import lax + +from tx.ffi import ragged_dot_ffi, ragged_dot_ffi_available + + +# Preset configurations for different workloads +PRESETS = { + "moe": { + "description": "MoE expert layer (Qwen3-30B-A3B)", + "num_tokens": 8192, + "num_groups": 128, # num_experts + "k_dim": 2048, # hidden_size + "n_dim": 768, # intermediate_size + }, + "lora": { + "description": "LoRA adapter layer", + "num_tokens": 8192, + "num_groups": 32, # max_lora_adapters + "k_dim": 8, # lora_rank + "n_dim": 4096, # output features + }, + "lora-moe": { + "description": "LoRA on MoE experts (combined groups)", + "num_tokens": 8192, + "num_groups": 1024, # num_experts * max_lora_adapters (128 * 8, capped at kernel limit) + "k_dim": 8, # lora_rank + "n_dim": 768, # intermediate_size + }, +} + + +def generate_group_sizes(num_tokens: int, num_groups: int, key: jax.Array) -> jax.Array: + """Generate random group sizes that sum to num_tokens.""" + # Random assignment of tokens to groups + assignments = jax.random.randint(key, (num_tokens,), 0, num_groups) + return jnp.bincount(assignments, length=num_groups).astype(jnp.int32) + + +def benchmark_forward( + num_tokens: int, + k_dim: int, + n_dim: int, + num_groups: int, + num_warmup: int = 5, + num_iters: int = 20, + use_ffi: bool = True, +): + """Benchmark forward pass: lhs[M, K] @ rhs[G, K, N] -> out[M, N].""" + key = jax.random.PRNGKey(42) + k1, k2, k3 = jax.random.split(key, 3) + + lhs = jax.random.normal(k1, (num_tokens, k_dim), dtype=jnp.bfloat16) + rhs = jax.random.normal(k2, (num_groups, k_dim, n_dim), dtype=jnp.bfloat16) + group_sizes = generate_group_sizes(num_tokens, num_groups, k3) + group_offset = jnp.array([0], dtype=jnp.int32) + + if use_ffi: + fn = lambda: ragged_dot_ffi(lhs, rhs, group_sizes, group_offset) + else: + fn = lambda: lax.ragged_dot(lhs, rhs, group_sizes) + + # Warmup + for _ in range(num_warmup): + out = fn() + out.block_until_ready() + + # Benchmark + start = time.perf_counter() + for _ in range(num_iters): + out = fn() + out.block_until_ready() + elapsed = time.perf_counter() - start + + # FLOPs: 2 * M * K * N (matmul FLOPs) + flops = 2 * num_tokens * k_dim * n_dim + tflops = (flops * num_iters / elapsed) / 1e12 + + return elapsed / num_iters, tflops + + +def benchmark_backward( + num_tokens: int, + k_dim: int, + n_dim: int, + num_groups: int, + num_warmup: int = 5, + num_iters: int = 20, + use_ffi: bool = True, +): + """Benchmark backward pass through ragged_dot.""" + key = jax.random.PRNGKey(42) + k1, k2, k3 = jax.random.split(key, 3) + + lhs = jax.random.normal(k1, (num_tokens, k_dim), dtype=jnp.bfloat16) + rhs = jax.random.normal(k2, (num_groups, k_dim, n_dim), dtype=jnp.bfloat16) + group_sizes = generate_group_sizes(num_tokens, num_groups, k3) + group_offset = jnp.array([0], dtype=jnp.int32) + + if use_ffi: + def forward(lhs, rhs): + return ragged_dot_ffi(lhs, rhs, group_sizes, group_offset).sum() + else: + def forward(lhs, rhs): + return lax.ragged_dot(lhs, rhs, group_sizes).sum() + + grad_fn = jax.grad(forward, argnums=(0, 1)) + + # Warmup + for _ in range(num_warmup): + d_lhs, d_rhs = grad_fn(lhs, rhs) + d_lhs.block_until_ready() + d_rhs.block_until_ready() + + # 
Benchmark + start = time.perf_counter() + for _ in range(num_iters): + d_lhs, d_rhs = grad_fn(lhs, rhs) + d_lhs.block_until_ready() + d_rhs.block_until_ready() + elapsed = time.perf_counter() - start + + # Backward FLOPs: d_lhs = grad @ rhs.T (2*M*N*K) + d_rhs = lhs.T @ grad (2*K*M*N) + # Total: 4 * M * K * N + flops = 4 * num_tokens * k_dim * n_dim + tflops = (flops * num_iters / elapsed) / 1e12 + + return elapsed / num_iters, tflops + + +def run_benchmark_suite( + num_tokens: int, + k_dim: int, + n_dim: int, + num_groups: int, + num_warmup: int, + num_iters: int, + run_forward: bool, + run_backward: bool, +): + """Run the benchmark suite with the given configuration.""" + if run_forward: + print("Forward Pass (lhs[M,K] @ rhs[G,K,N] -> out[M,N])") + print("-" * 60) + + if ragged_dot_ffi_available(): + ffi_time, ffi_tflops = benchmark_forward( + num_tokens, k_dim, n_dim, num_groups, num_warmup, num_iters, use_ffi=True + ) + print(f" CUTLASS FFI: {ffi_time*1000:8.3f} ms {ffi_tflops:8.2f} TFLOPS") + + jax_time, jax_tflops = benchmark_forward( + num_tokens, k_dim, n_dim, num_groups, num_warmup, num_iters, use_ffi=False + ) + print(f" JAX ragged: {jax_time*1000:8.3f} ms {jax_tflops:8.2f} TFLOPS") + + if ragged_dot_ffi_available(): + print(f" Speedup: {jax_time/ffi_time:.2f}x") + print() + + if run_backward: + print("Backward Pass (grad wrt lhs and rhs)") + print("-" * 60) + + if ragged_dot_ffi_available(): + ffi_time, ffi_tflops = benchmark_backward( + num_tokens, k_dim, n_dim, num_groups, num_warmup, num_iters, use_ffi=True + ) + print(f" CUTLASS FFI: {ffi_time*1000:8.3f} ms {ffi_tflops:8.2f} TFLOPS") + + jax_time, jax_tflops = benchmark_backward( + num_tokens, k_dim, n_dim, num_groups, num_warmup, num_iters, use_ffi=False + ) + print(f" JAX ragged: {jax_time*1000:8.3f} ms {jax_tflops:8.2f} TFLOPS") + + if ragged_dot_ffi_available(): + print(f" Speedup: {jax_time/ffi_time:.2f}x") + print() + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark ragged_dot CUTLASS kernel") + parser.add_argument("--preset", choices=list(PRESETS.keys()), help="Use a preset configuration (moe, lora, lora-moe)") + parser.add_argument("--all-presets", action="store_true", help="Run all preset configurations") + parser.add_argument("--num-tokens", type=int, default=8192, help="Number of tokens (M)") + parser.add_argument("--num-groups", type=int, default=128, help="Number of groups (G) - experts or adapters") + parser.add_argument("--k-dim", type=int, default=2048, help="K dimension - hidden_size (MoE) or lora_rank (LoRA)") + parser.add_argument("--n-dim", type=int, default=768, help="N dimension - output features") + parser.add_argument("--num-warmup", type=int, default=5, help="Warmup iterations") + parser.add_argument("--num-iters", type=int, default=20, help="Benchmark iterations") + parser.add_argument("--backward-only", action="store_true", help="Only benchmark backward pass") + parser.add_argument("--forward-only", action="store_true", help="Only benchmark forward pass") + args = parser.parse_args() + + print("Ragged Dot Benchmark") + print("=" * 60) + print(f"CUTLASS FFI available: {ragged_dot_ffi_available()}") + print(f"JAX backend: {jax.default_backend()}") + print(f"Devices: {jax.device_count()}") + print() + + run_forward = not args.backward_only + run_backward = not args.forward_only + + if args.all_presets: + # Run all presets + for preset_name, preset in PRESETS.items(): + print("=" * 60) + print(f"Preset: {preset_name} - {preset['description']}") + print("=" * 60) + 
print(f"Config:") + print(f" num_tokens (M): {preset['num_tokens']}") + print(f" num_groups (G): {preset['num_groups']}") + print(f" k_dim (K): {preset['k_dim']}") + print(f" n_dim (N): {preset['n_dim']}") + print(f" warmup/iters: {args.num_warmup}/{args.num_iters}") + print() + run_benchmark_suite( + preset['num_tokens'], preset['k_dim'], preset['n_dim'], preset['num_groups'], + args.num_warmup, args.num_iters, run_forward, run_backward + ) + elif args.preset: + # Use a specific preset + preset = PRESETS[args.preset] + print(f"Preset: {args.preset} - {preset['description']}") + print() + print(f"Config:") + print(f" num_tokens (M): {preset['num_tokens']}") + print(f" num_groups (G): {preset['num_groups']}") + print(f" k_dim (K): {preset['k_dim']}") + print(f" n_dim (N): {preset['n_dim']}") + print(f" warmup/iters: {args.num_warmup}/{args.num_iters}") + print() + run_benchmark_suite( + preset['num_tokens'], preset['k_dim'], preset['n_dim'], preset['num_groups'], + args.num_warmup, args.num_iters, run_forward, run_backward + ) + else: + # Use custom config from args + print(f"Config:") + print(f" num_tokens (M): {args.num_tokens}") + print(f" num_groups (G): {args.num_groups}") + print(f" k_dim (K): {args.k_dim}") + print(f" n_dim (N): {args.n_dim}") + print(f" warmup/iters: {args.num_warmup}/{args.num_iters}") + print() + run_benchmark_suite( + args.num_tokens, args.k_dim, args.n_dim, args.num_groups, + args.num_warmup, args.num_iters, run_forward, run_backward + ) + + +if __name__ == "__main__": + main() diff --git a/skyrl-tx/benchmarks/sweep_tile_sizes.py b/skyrl-tx/benchmarks/sweep_tile_sizes.py new file mode 100644 index 000000000..62114029e --- /dev/null +++ b/skyrl-tx/benchmarks/sweep_tile_sizes.py @@ -0,0 +1,280 @@ +"""Sweep tile sizes and kernel parameters for ragged_dot CUTLASS kernel optimization.""" + +import subprocess +import sys +import re +from pathlib import Path +from itertools import product + +CUDA_FILE = Path(__file__).parent.parent / "tx/ffi/ragged_dot_ffi.cu" +SO_FILE = Path(__file__).parent.parent / "tx/ffi/libragged_dot_ffi.so" + +# Tile configurations to test: (M, N, K) +# Constraints: dimensions should be powers of 2 or multiples that work with SM90 +# Note: Cooperative schedule requires M >= 128 +TILE_CONFIGS = [ + (64, 256, 64), # best from previous sweep + (128, 128, 64), + (64, 128, 64), + (128, 256, 64), + (128, 128, 128), +] + +# Smaller tile configs that may work better for LoRA (small K dimension) +TILE_CONFIGS_LORA = [ + (128, 128, 32), # smaller K tile + (128, 256, 32), + (128, 64, 32), + (64, 128, 32), + (64, 64, 32), + (128, 128, 64), # reference +] + +# Cluster shapes to test: (M, N, K) +CLUSTER_CONFIGS = [ + (1, 1, 1), # current + (2, 1, 1), + (1, 2, 1), + (2, 2, 1), +] + +# Kernel schedules to test +SCHEDULE_CONFIGS = [ + ("KernelPtrArrayTmaWarpSpecializedPingpong", "PtrArrayTmaWarpSpecializedPingpong"), # current + ("KernelPtrArrayTmaWarpSpecializedCooperative", "PtrArrayTmaWarpSpecializedCooperative"), +] + +# Workload presets for benchmarking +WORKLOAD_PRESETS = { + "moe": { + "description": "MoE expert layer (Qwen3-30B-A3B)", + "args": ["--preset", "moe"], + }, + "lora": { + "description": "LoRA adapter layer (rank=8)", + "args": ["--preset", "lora"], + }, + "lora-moe": { + "description": "LoRA on MoE experts", + "args": ["--preset", "lora-moe"], + }, +} + + +def set_tile_shape(m: int, n: int, k: int) -> None: + """Update the TileShape in the CUDA file.""" + content = CUDA_FILE.read_text() + + # Replace the TileShape definition + pattern = 
r'using TileShape = cute::Shape;' + replacement = f'using TileShape = cute::Shape;' + + new_content = re.sub(pattern, replacement, content) + CUDA_FILE.write_text(new_content) + + +def set_cluster_shape(m: int, n: int, k: int) -> None: + """Update the ClusterShape in the CUDA file.""" + content = CUDA_FILE.read_text() + + pattern = r'using ClusterShape = cute::Shape;' + replacement = f'using ClusterShape = cute::Shape;' + + new_content = re.sub(pattern, replacement, content) + CUDA_FILE.write_text(new_content) + + +def set_schedule(kernel_schedule: str, epilogue_schedule: str) -> None: + """Update the KernelSchedule and EpilogueSchedule in the CUDA file.""" + content = CUDA_FILE.read_text() + + # Replace KernelSchedule + pattern = r'using KernelSchedule = cutlass::gemm::\w+;' + replacement = f'using KernelSchedule = cutlass::gemm::{kernel_schedule};' + content = re.sub(pattern, replacement, content) + + # Replace EpilogueSchedule + pattern = r'using EpilogueSchedule = cutlass::epilogue::\w+;' + replacement = f'using EpilogueSchedule = cutlass::epilogue::{epilogue_schedule};' + content = re.sub(pattern, replacement, content) + + CUDA_FILE.write_text(content) + + +def set_kernel_config(tile: tuple, cluster: tuple, schedule: tuple) -> str: + """Set all kernel configuration parameters. Returns a config description string.""" + set_tile_shape(*tile) + set_cluster_shape(*cluster) + set_schedule(*schedule) + + schedule_name = schedule[0].replace("KernelPtrArrayTmaWarpSpecialized", "") + return f"Tile{tile} Cluster{cluster} {schedule_name}" + + +def rebuild_kernel() -> bool: + """Rebuild the CUTLASS kernel.""" + # Remove old .so to force rebuild + if SO_FILE.exists(): + SO_FILE.unlink() + + # Rebuild using uv with hatchling + result = subprocess.run( + ["uv", "run", "--with", "hatchling", "python", "-c", + "from tx.ffi.build import build_ragged_dot; build_ragged_dot()"], + capture_output=True, + text=True, + cwd=CUDA_FILE.parent.parent.parent, + ) + + if result.returncode != 0: + return False + + return SO_FILE.exists() + + +def run_benchmark(workload: str = "moe", num_tokens: int = 8192) -> dict | None: + """Run the benchmark and parse results.""" + # Build command with workload preset or custom args + cmd = [sys.executable, "benchmarks/bench_ragged_dot.py", "--num-warmup", "3", "--num-iters", "10"] + + if workload in WORKLOAD_PRESETS: + cmd.extend(WORKLOAD_PRESETS[workload]["args"]) + else: + # Legacy: custom num_tokens only + cmd.extend(["--num-tokens", str(num_tokens)]) + + # Need to run in a fresh process to reload the .so + result = subprocess.run( + cmd, + capture_output=True, + text=True, + cwd=CUDA_FILE.parent.parent.parent, + ) + + output = result.stdout + if result.returncode != 0: + print(f"Benchmark failed: {result.stderr}") + return None + + # Parse CUTLASS FFI results + results = {} + + # Forward pass + match = re.search(r'CUTLASS FFI:\s+([\d.]+)\s+ms\s+([\d.]+)\s+TFLOPS', output) + if match: + results['fwd_ms'] = float(match.group(1)) + results['fwd_tflops'] = float(match.group(2)) + + # Find backward pass (second occurrence) + matches = list(re.finditer(r'CUTLASS FFI:\s+([\d.]+)\s+ms\s+([\d.]+)\s+TFLOPS', output)) + if len(matches) >= 2: + results['bwd_ms'] = float(matches[1].group(1)) + results['bwd_tflops'] = float(matches[1].group(2)) + + return results if results else None + + +def main(): + import argparse + parser = argparse.ArgumentParser(description="Sweep kernel parameters for CUTLASS ragged_dot") + parser.add_argument("--workload", choices=list(WORKLOAD_PRESETS.keys()), 
default="moe", + help="Workload preset to benchmark (moe, lora, lora-moe)") + parser.add_argument("--all-workloads", action="store_true", help="Sweep all workloads") + parser.add_argument("--sweep-tiles", action="store_true", help="Sweep tile shapes") + parser.add_argument("--sweep-clusters", action="store_true", help="Sweep cluster shapes") + parser.add_argument("--sweep-schedules", action="store_true", help="Sweep kernel schedules") + parser.add_argument("--sweep-all", action="store_true", help="Sweep all parameters") + parser.add_argument("--lora-tiles", action="store_true", help="Use LoRA-optimized tile configs (smaller K)") + parser.add_argument("--dry-run", action="store_true", help="Only show configs, don't run") + args = parser.parse_args() + + # Default to sweeping tiles only if nothing specified + if not (args.sweep_tiles or args.sweep_clusters or args.sweep_schedules or args.sweep_all): + args.sweep_tiles = True + + if args.sweep_all: + args.sweep_tiles = args.sweep_clusters = args.sweep_schedules = True + + print("CUTLASS Kernel Parameter Sweep") + print("=" * 80) + + workloads = list(WORKLOAD_PRESETS.keys()) if args.all_workloads else [args.workload] + + # Build parameter combinations + base_tiles = TILE_CONFIGS_LORA if args.lora_tiles else TILE_CONFIGS + tiles = base_tiles if args.sweep_tiles else [base_tiles[0]] + clusters = CLUSTER_CONFIGS if args.sweep_clusters else [CLUSTER_CONFIGS[0]] + schedules = SCHEDULE_CONFIGS if args.sweep_schedules else [SCHEDULE_CONFIGS[0]] + + configs = list(product(tiles, clusters, schedules)) + print(f"Testing {len(configs)} configurations across {len(workloads)} workload(s)") + print() + + if args.dry_run: + print("Configurations to test:") + for tile, cluster, schedule in configs: + schedule_name = schedule[0].replace("KernelPtrArrayTmaWarpSpecialized", "") + print(f" Tile{tile} Cluster{cluster} {schedule_name}") + print() + print("Workloads to test:") + for w in workloads: + print(f" {w}: {WORKLOAD_PRESETS[w]['description']}") + return + + # Store results: {workload: [(config_str, bench_results), ...]} + all_results: dict[str, list] = {w: [] for w in workloads} + + for tile, cluster, schedule in configs: + print(f"\n{'='*80}") + config_str = set_kernel_config(tile, cluster, schedule) + print(f"Testing: {config_str}") + print("-" * 80) + + if not rebuild_kernel(): + print(" FAILED to build") + for w in workloads: + all_results[w].append((config_str, None)) + continue + + for workload in workloads: + print(f"\n Workload: {workload} ({WORKLOAD_PRESETS[workload]['description']})") + bench_results = run_benchmark(workload) + if bench_results: + print(f" Forward: {bench_results.get('fwd_ms', 'N/A'):>8.3f} ms {bench_results.get('fwd_tflops', 'N/A'):>8.2f} TFLOPS") + print(f" Backward: {bench_results.get('bwd_ms', 'N/A'):>8.3f} ms {bench_results.get('bwd_tflops', 'N/A'):>8.2f} TFLOPS") + all_results[workload].append((config_str, bench_results)) + else: + print(" FAILED to benchmark") + all_results[workload].append((config_str, None)) + + # Summary per workload + for workload in workloads: + results = all_results[workload] + print(f"\n{'='*80}") + print(f"SUMMARY: {workload} ({WORKLOAD_PRESETS[workload]['description']})") + print("=" * 80) + print(f"{'Configuration':<50} {'Fwd (ms)':>10} {'Fwd TF':>8} {'Bwd (ms)':>10} {'Bwd TF':>8}") + print("-" * 80) + + for config_str, res in results: + if res: + fwd_ms = f"{res.get('fwd_ms', 0):.3f}" + fwd_tf = f"{res.get('fwd_tflops', 0):.1f}" + bwd_ms = f"{res.get('bwd_ms', 0):.3f}" + bwd_tf = 
f"{res.get('bwd_tflops', 0):.1f}" + print(f"{config_str:<50} {fwd_ms:>10} {fwd_tf:>8} {bwd_ms:>10} {bwd_tf:>8}") + else: + print(f"{config_str:<50} {'FAILED':>10}") + + # Find best + valid_results = [(cfg, r) for cfg, r in results if r and 'fwd_tflops' in r] + if valid_results: + best_fwd = max(valid_results, key=lambda x: x[1]['fwd_tflops']) + best_bwd = max(valid_results, key=lambda x: x[1].get('bwd_tflops', 0)) + print() + print(f"Best forward: {best_fwd[0]} - {best_fwd[1]['fwd_tflops']:.2f} TFLOPS") + print(f"Best backward: {best_bwd[0]} - {best_bwd[1].get('bwd_tflops', 0):.2f} TFLOPS") + + +if __name__ == "__main__": + main() diff --git a/skyrl-tx/pyproject.toml b/skyrl-tx/pyproject.toml index 587d9c19e..8b48edd25 100644 --- a/skyrl-tx/pyproject.toml +++ b/skyrl-tx/pyproject.toml @@ -1,6 +1,6 @@ [build-system] -requires = ["setuptools"] -build-backend = "setuptools.build_meta" +requires = ["hatchling", "jaxlib"] +build-backend = "hatchling.build" [project] name = "skyrl-tx" @@ -83,11 +83,14 @@ dev = [ "alembic", ] -[tool.setuptools] -include-package-data = true +[tool.hatch.version] +path = "tx/__init__.py" -[tool.setuptools.dynamic] -version = {attr = "tx.__version__"} +[tool.hatch.build.targets.wheel] +packages = ["tx"] + +[tool.hatch.build.hooks.custom] +path = "tx/ffi/build.py" [project.scripts] tx = "tx.run.main:app" diff --git a/skyrl-tx/tx/ffi/README.md b/skyrl-tx/tx/ffi/README.md new file mode 100644 index 000000000..c19665d12 --- /dev/null +++ b/skyrl-tx/tx/ffi/README.md @@ -0,0 +1,20 @@ +# CUDA FFI Extensions + +## Building + +The CUDA extension is built automatically when creating a wheel: + +```bash +uv build +``` + +### Environment Variables + +- `CUTLASS_DIR` - Path to CUTLASS checkout (optional, clones automatically if not set) +- `NVCC_BIN` - Path to nvcc (default: `nvcc`) +- `NVCC_ARCH` - CUDA architecture (default: `90a` for H100) + +### Notes + +- Requires CUDA nvcc with C++17 support +- The FFI kernel expects bfloat16 inputs/outputs and int32 group metadata diff --git a/skyrl-tx/tx/ffi/__init__.py b/skyrl-tx/tx/ffi/__init__.py new file mode 100644 index 000000000..3d2794504 --- /dev/null +++ b/skyrl-tx/tx/ffi/__init__.py @@ -0,0 +1,4 @@ +from tx.ffi.ragged_dot_ffi import is_available as ragged_dot_ffi_available +from tx.ffi.ragged_dot_ffi import ragged_dot as ragged_dot_ffi + +__all__ = ["ragged_dot_ffi", "ragged_dot_ffi_available"] diff --git a/skyrl-tx/tx/ffi/build.py b/skyrl-tx/tx/ffi/build.py new file mode 100644 index 000000000..54794f2a0 --- /dev/null +++ b/skyrl-tx/tx/ffi/build.py @@ -0,0 +1,75 @@ +import os +import subprocess +import sys +import tempfile +from pathlib import Path + +CUTLASS_REPO = "https://github.com/NVIDIA/cutlass.git" +CUTLASS_TAG = "v4.3.5" + + +def get_cutlass_dir(tmpdir): + if cutlass_dir := os.environ.get("CUTLASS_DIR"): + return Path(cutlass_dir) + + cutlass_dir = Path(tmpdir) / "cutlass" + print(f"Cloning CUTLASS {CUTLASS_TAG}...") + subprocess.run( + ["git", "clone", "--depth=1", f"--branch={CUTLASS_TAG}", CUTLASS_REPO, str(cutlass_dir)], + check=True, + ) + return cutlass_dir + + +def build_ragged_dot(): + try: + import jaxlib + jax_include_dir = Path(jaxlib.__file__).parent / "include" + except ImportError: + print("jaxlib not installed, skipping ragged_dot_ffi build", file=sys.stderr) + return + + nvcc_bin = os.environ.get("NVCC_BIN", "nvcc") + nvcc_arch = os.environ.get("NVCC_ARCH", "90a") + + try: + subprocess.run([nvcc_bin, "--version"], check=True, capture_output=True) + except (subprocess.CalledProcessError, 
FileNotFoundError): + print(f"nvcc not found at {nvcc_bin}, skipping ragged_dot_ffi build", file=sys.stderr) + return + + ffi_dir = Path(__file__).parent + source_file = ffi_dir / "ragged_dot_ffi.cu" + output_file = ffi_dir / "libragged_dot_ffi.so" + + with tempfile.TemporaryDirectory() as tmpdir: + cutlass_dir = get_cutlass_dir(tmpdir) + + cmd = [ + nvcc_bin, "-O3", "-std=c++17", f"-arch=sm_{nvcc_arch}", + "--expt-relaxed-constexpr", + "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", + "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", + "-DCUTLASS_ENABLE_GDC_FOR_SM90=1", + "-shared", "-Xcompiler", "-fPIC", + f"-I{jax_include_dir}", + f"-I{cutlass_dir}/include", + f"-I{cutlass_dir}/tools/util/include/", + str(source_file), "-o", str(output_file), + "-lcuda", + ] + + print(f"Building {output_file}...") + subprocess.run(cmd, check=True) + print(f"Built {output_file}") + + +from hatchling.builders.hooks.plugin.interface import BuildHookInterface + + +class CudaBuildHook(BuildHookInterface): + PLUGIN_NAME = "cuda_build" + + def initialize(self, version, build_data): + if self.target_name in ("wheel", "editable"): + build_ragged_dot() diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.cu b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu new file mode 100644 index 000000000..5cee55423 --- /dev/null +++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.cu @@ -0,0 +1,331 @@ +#include +#include + +#include +#include + +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#if !defined(CUTLASS_MAJOR) || CUTLASS_MAJOR < 3 +#error "This kernel requires CUTLASS >= 3.x (SM90 grouped GEMM)." +#endif + +namespace ffi = xla::ffi; + +// Cache device properties because cudaGetDeviceProperties is slow. 
+static std::optional get_sm_count() { + static std::vector device_props = [] { + int num_gpus = 0; + if (cudaGetDeviceCount(&num_gpus) != cudaSuccess || num_gpus <= 0) { + return std::vector{}; + } + std::vector props(num_gpus); + for (int i = 0; i < num_gpus; ++i) { + cudaGetDeviceProperties(&props[i], i); + } + return props; + }(); + + int device = 0; + if (cudaGetDevice(&device) != cudaSuccess || device < 0 || device >= device_props.size()) { + return {}; + } + return device_props[device].multiProcessorCount; +} + +using DtypeA = cutlass::bfloat16_t; +using DtypeB = cutlass::bfloat16_t; +using DtypeOutput = cutlass::bfloat16_t; +using DtypeAccum = float; +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::RowMajor; +using LayoutOutput = cutlass::layout::RowMajor; +constexpr int AlignmentA = 8; +constexpr int AlignmentB = 8; +constexpr int AlignmentOutput = 8; + +using ArchTag = cutlass::arch::Sm90; +using OperatorClass = cutlass::arch::OpClassTensorOp; +using TileShape = cute::Shape; +using ClusterShape = cute::Shape; +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; +using ProblemShape = cutlass::gemm::GroupProblemShape>; + +using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + DtypeAccum, DtypeAccum, void, LayoutOutput*, AlignmentOutput, + DtypeOutput, LayoutOutput*, AlignmentOutput, EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; + +using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, DtypeA, LayoutA*, AlignmentA, + DtypeB, LayoutB*, AlignmentB, DtypeAccum, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal; +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +using ProblemShapeType = ProblemShape::UnderlyingProblemShape; + +// Backward kernel for d_rhs: computes lhs.T @ grad per group +// Uses ColumnMajor for A to interpret row-major lhs as transposed +using LayoutA_Bwd = cutlass::layout::ColumnMajor; + +using CollectiveMainloop_Bwd = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, DtypeA, LayoutA_Bwd*, AlignmentA, + DtypeB, LayoutB*, AlignmentB, DtypeAccum, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + +using GemmKernel_Bwd = cutlass::gemm::kernel::GemmUniversal; +using Gemm_Bwd = cutlass::gemm::device::GemmUniversalAdapter; + +// 128-byte alignment for TMA requirements on SM90 +constexpr size_t kTmaAlignment = 128; + +template +static T* carve_aligned(char*& p, size_t count) { + T* out = reinterpret_cast((reinterpret_cast(p) + kTmaAlignment - 1) & ~uintptr_t(kTmaAlignment - 1)); + p = reinterpret_cast(out + count); + return out; +} + +template +struct GroupedGemmData { + using StrideA = typename GemmT::GemmKernel::InternalStrideA; + using StrideB = typename GemmT::GemmKernel::InternalStrideB; + using StrideOutput = typename GemmT::GemmKernel::InternalStrideD; + + const DtypeA** A_ptrs; + const DtypeB** B_ptrs; + DtypeOutput** out_ptrs; + StrideA* stride_A; + 
StrideB* stride_B; + StrideOutput* stride_output; + ProblemShapeType* problem_sizes; + + static std::optional Allocate(ffi::ScratchAllocator& scratch, size_t g) { + size_t bytes = 7 * kTmaAlignment + + sizeof(const DtypeA*) * g + sizeof(const DtypeB*) * g + sizeof(DtypeOutput*) * g + + sizeof(StrideA) * g + sizeof(StrideB) * g + sizeof(StrideOutput) * g + + sizeof(ProblemShapeType) * g; + auto slab_or = scratch.Allocate(bytes); + if (!slab_or.has_value()) { + return std::nullopt; + } + GroupedGemmData data; + char* p = reinterpret_cast(slab_or.value()); + data.A_ptrs = carve_aligned(p, g); + data.B_ptrs = carve_aligned(p, g); + data.out_ptrs = carve_aligned(p, g); + data.stride_A = carve_aligned(p, g); + data.stride_B = carve_aligned(p, g); + data.stride_output = carve_aligned(p, g); + data.problem_sizes = carve_aligned(p, g); + return data; + } + + typename GemmT::Arguments MakeArgs(int32_t g_local) const { + typename GemmT::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {g_local, problem_sizes, nullptr}, + {A_ptrs, stride_A, B_ptrs, stride_B}, + {{}, nullptr, stride_output, out_ptrs, stride_output}}; + args.epilogue.thread.alpha = 1.0f; + args.epilogue.thread.dAlpha = {cute::_0{}, cute::_0{}, 0}; + return args; + } +}; + +enum class GemmDir { Fwd, Bwd }; + +// Unified prepare kernel for forward and backward passes +// Fwd: A[M,K] @ B[G,K,N] -> out[M,N] (ragged by group) +// Bwd: A.T[K,M] @ B[M,N] -> out[G,K,N] (lhs transposed via ColumnMajor layout) +// group_offsets_cumsum has length G+1 with cumsum[0]=0 (include_initial=True) +template +__global__ void prepare_grouped_gemm( + const DtypeA* A, const DtypeB* B, DtypeOutput* out, + const int32_t* group_offsets_cumsum, const int32_t* group_offset_ptr, + int32_t k, int32_t n, GroupedGemmData data) { + using Data = GroupedGemmData; + int32_t tid = threadIdx.x; + int32_t global = group_offset_ptr[0] + tid; + int32_t start = group_offsets_cumsum[global]; + int32_t m = group_offsets_cumsum[global + 1] - start; + + data.A_ptrs[tid] = A + static_cast(start) * k; + if constexpr (Dir == GemmDir::Fwd) { + data.B_ptrs[tid] = B + static_cast(tid) * n * k; + data.out_ptrs[tid] = out + static_cast(start) * n; + data.problem_sizes[tid] = ProblemShapeType(m, n, k); + data.stride_A[tid] = cutlass::make_cute_packed_stride(typename Data::StrideA{}, {k, k, 1}); + data.stride_output[tid] = cutlass::make_cute_packed_stride(typename Data::StrideOutput{}, {n, n, 1}); + } else { + data.B_ptrs[tid] = B + static_cast(start) * n; + data.out_ptrs[tid] = out + static_cast(tid) * k * n; + data.problem_sizes[tid] = ProblemShapeType(k, n, m); + data.stride_A[tid] = cutlass::make_cute_packed_stride(typename Data::StrideA{}, {k, k, 1}); + data.stride_output[tid] = cutlass::make_cute_packed_stride(typename Data::StrideOutput{}, {n, n, 1}); + } + data.stride_B[tid] = cutlass::make_cute_packed_stride(typename Data::StrideB{}, {n, n, 1}); +} + +template +ffi::Error ExecuteGroupedGemm( + cudaStream_t stream, ffi::ScratchAllocator& scratch, + const DtypeA* A, const DtypeB* B, DtypeOutput* out, + const int32_t* group_offsets_cumsum, const int32_t* group_offset, + int32_t g_local, int32_t k, int32_t n) { + if (g_local > 1024) return ffi::Error::InvalidArgument("group_count must be <= 1024."); + + auto data = GroupedGemmData::Allocate(scratch, g_local); + if (!data) return ffi::Error::Internal("Failed to allocate grouped GEMM slab."); + + prepare_grouped_gemm<<<1, g_local, 0, stream>>>( + A, B, out, group_offsets_cumsum, group_offset, k, n, *data); + + GemmT gemm; + 
auto args = data->MakeArgs(g_local); + auto sm_count = get_sm_count(); + if (!sm_count) return ffi::Error::Internal("Failed to get SM count."); + args.hw_info.sm_count = *sm_count; + + if (gemm.can_implement(args) != cutlass::Status::kSuccess) { + return ffi::Error::Internal("cutlass cannot implement grouped gemm."); + } + + void* workspace = nullptr; + if (size_t workspace_size = GemmT::get_workspace_size(args)) { + auto workspace_or = scratch.Allocate(workspace_size); + if (!workspace_or.has_value()) { + return ffi::Error::Internal("Failed to allocate CUTLASS workspace."); + } + workspace = workspace_or.value(); + } + + if (gemm(args, workspace, stream) != cutlass::Status::kSuccess) { + return ffi::Error::Internal("cutlass grouped gemm failed."); + } + + return ffi::Error::Success(); +} + +ffi::Error RaggedDotCudaImpl( + cudaStream_t stream, + ffi::ScratchAllocator scratch, + ffi::Buffer lhs, + ffi::Buffer rhs, + ffi::Buffer group_offset, + ffi::Buffer group_offsets_cumsum, + ffi::ResultBuffer out) { + auto lhs_dims = lhs.dimensions(); + auto rhs_dims = rhs.dimensions(); + + if (lhs_dims.size() != 2 || rhs_dims.size() != 3 || group_offset.dimensions().size() != 1) { + return ffi::Error::InvalidArgument("Unexpected ragged_dot dimensions."); + } + if (lhs_dims[1] != rhs_dims[1]) { + return ffi::Error::InvalidArgument("lhs/rhs K dimension mismatch."); + } + + int32_t m = static_cast(lhs_dims[0]); + int32_t k = static_cast(lhs_dims[1]); + int32_t g_local = static_cast(rhs_dims[0]); + int32_t n = static_cast(rhs_dims[2]); + + if (cudaMemsetAsync(out->typed_data(), 0, static_cast(m) * n * sizeof(DtypeOutput), stream) != cudaSuccess) { + return ffi::Error::Internal("Failed to zero output."); + } + + return ExecuteGroupedGemm( + stream, scratch, + reinterpret_cast(lhs.typed_data()), + reinterpret_cast(rhs.typed_data()), + reinterpret_cast(out->typed_data()), + group_offsets_cumsum.typed_data(), group_offset.typed_data(), g_local, k, n); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + RaggedDotCuda, + RaggedDotCudaImpl, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Arg>() // lhs + .Arg>() // rhs + .Arg>() // group_offset + .Arg>() // group_offsets_cumsum + .Ret>()); // out + +// Backward pass for d_rhs: computes lhs.T @ grad per group -> d_rhs[G, K, N] +ffi::Error RaggedDotBwdCudaImpl( + cudaStream_t stream, + ffi::ScratchAllocator scratch, + ffi::Buffer lhs, + ffi::Buffer grad, + ffi::Buffer group_offset, + ffi::Buffer group_offsets_cumsum, + ffi::ResultBuffer d_rhs) { + auto lhs_dims = lhs.dimensions(); + auto grad_dims = grad.dimensions(); + auto d_rhs_dims = d_rhs->dimensions(); + + if (lhs_dims.size() != 2 || grad_dims.size() != 2 || d_rhs_dims.size() != 3) { + return ffi::Error::InvalidArgument("Unexpected ragged_dot_bwd dimensions."); + } + if (lhs_dims[0] != grad_dims[0]) { + return ffi::Error::InvalidArgument("lhs/grad M dimension mismatch."); + } + if (d_rhs_dims[1] != lhs_dims[1] || d_rhs_dims[2] != grad_dims[1]) { + return ffi::Error::InvalidArgument("d_rhs shape must be [G, K, N]."); + } + + int32_t k = static_cast(lhs_dims[1]); + int32_t n = static_cast(grad_dims[1]); + int32_t g_local = static_cast(d_rhs_dims[0]); + + if (cudaMemsetAsync(d_rhs->typed_data(), 0, static_cast(g_local) * k * n * sizeof(DtypeOutput), stream) != cudaSuccess) { + return ffi::Error::Internal("Failed to zero d_rhs output."); + } + + return ExecuteGroupedGemm( + stream, scratch, + reinterpret_cast(lhs.typed_data()), + reinterpret_cast(grad.typed_data()), + reinterpret_cast(d_rhs->typed_data()), + 
      group_offsets_cumsum.typed_data(), group_offset.typed_data(), g_local, k, n);
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    RaggedDotBwdCuda,
+    RaggedDotBwdCudaImpl,
+    ffi::Ffi::Bind()
+        .Ctx>()
+        .Ctx()
+        .Arg>()  // lhs
+        .Arg>()  // grad
+        .Arg>()  // group_offset
+        .Arg>()  // group_offsets_cumsum
+        .Ret>());  // d_rhs
diff --git a/skyrl-tx/tx/ffi/ragged_dot_ffi.py b/skyrl-tx/tx/ffi/ragged_dot_ffi.py
new file mode 100644
index 000000000..054a4bb71
--- /dev/null
+++ b/skyrl-tx/tx/ffi/ragged_dot_ffi.py
@@ -0,0 +1,99 @@
+from __future__ import annotations
+
+import ctypes
+import functools
+from pathlib import Path
+
+import jax
+import jax.numpy as jnp
+
+_LIB_PATH = Path(__file__).resolve().parent / "libragged_dot_ffi.so"
+
+
+@functools.lru_cache(maxsize=1)
+def _ensure_registered() -> bool:
+    if not _LIB_PATH.exists():
+        return False
+    try:
+        lib = ctypes.cdll.LoadLibrary(str(_LIB_PATH))
+        jax.ffi.register_ffi_target("ragged_dot_cuda", jax.ffi.pycapsule(lib.RaggedDotCuda), platform="CUDA")
+        jax.ffi.register_ffi_target("ragged_dot_bwd_cuda", jax.ffi.pycapsule(lib.RaggedDotBwdCuda), platform="CUDA")
+        return True
+    except Exception:
+        return False
+
+
+def is_available() -> bool:
+    return _ensure_registered()
+
+
+def _ragged_dot_ffi_call(
+    lhs: jax.Array,
+    rhs: jax.Array,
+    group_offset: jax.Array,
+    group_offsets_cumsum: jax.Array,
+) -> jax.Array:
+    if not _ensure_registered():
+        raise RuntimeError("ragged_dot_ffi is not available. Build and load the shared library first.")
+
+    out = jax.ShapeDtypeStruct((lhs.shape[0], rhs.shape[2]), lhs.dtype)
+    call = jax.ffi.ffi_call("ragged_dot_cuda", out, vmap_method=None)
+    return call(lhs, rhs, group_offset, group_offsets_cumsum)
+
+
+def _ragged_dot_bwd_ffi_call(
+    lhs: jax.Array,
+    grad: jax.Array,
+    group_offset: jax.Array,
+    group_offsets_cumsum: jax.Array,
+    g_local: int,
+) -> jax.Array:
+    """Backward pass for d_rhs: computes lhs.T @ grad per group -> [G, K, N]."""
+    if not _ensure_registered():
+        raise RuntimeError("ragged_dot_ffi is not available. Build and load the shared library first.")
Build and load the shared library first.") + + k = lhs.shape[1] + n = grad.shape[1] + + out = jax.ShapeDtypeStruct((g_local, k, n), lhs.dtype) + call = jax.ffi.ffi_call("ragged_dot_bwd_cuda", out, vmap_method=None) + return call(lhs, grad, group_offset, group_offsets_cumsum) + + +@jax.custom_vjp +def ragged_dot( + lhs: jax.Array, + rhs: jax.Array, + group_sizes: jax.Array, + group_offset: jax.Array, +) -> jax.Array: + cumsum = jnp.cumulative_sum(group_sizes, include_initial=True).astype(jnp.int32) + return _ragged_dot_ffi_call(lhs, rhs, group_offset, cumsum) + + +def _ragged_dot_fwd( + lhs: jax.Array, + rhs: jax.Array, + group_sizes: jax.Array, + group_offset: jax.Array, +): + cumsum = jnp.cumulative_sum(group_sizes, include_initial=True).astype(jnp.int32) + y = _ragged_dot_ffi_call(lhs, rhs, group_offset, cumsum) + return y, (lhs, rhs, group_offset, cumsum) + + +def _ragged_dot_bwd(res, g): + lhs, rhs, group_offset, cumsum = res + g_local = rhs.shape[0] + + # d_lhs: g @ rhs.T with ragged grouping + rhs_t = jnp.swapaxes(rhs, 1, 2) + d_lhs = _ragged_dot_ffi_call(g, rhs_t, group_offset, cumsum) + + # d_rhs: lhs.T @ g accumulated per group + d_rhs = _ragged_dot_bwd_ffi_call(lhs, g, group_offset, cumsum, g_local) + + return d_lhs, d_rhs, None, None + + +ragged_dot.defvjp(_ragged_dot_fwd, _ragged_dot_bwd) diff --git a/skyrl-tx/tx/layers/util.py b/skyrl-tx/tx/layers/util.py index 0030c604d..f489d792a 100644 --- a/skyrl-tx/tx/layers/util.py +++ b/skyrl-tx/tx/layers/util.py @@ -4,6 +4,14 @@ from jax import numpy as jnp from jax.sharding import get_abstract_mesh, PartitionSpec +try: + from tx.ffi import ragged_dot_ffi, ragged_dot_ffi_available +except Exception: # pragma: no cover - optional GPU extension + ragged_dot_ffi = None + + def ragged_dot_ffi_available() -> bool: + return False + def ragged_dot( lhs: jax.Array, @@ -27,6 +35,23 @@ def ragged_dot( preferred_element_type=preferred_element_type, ) + # CUTLASS kernel requires k and n dimensions divisible by 8 + k = lhs.shape[-1] + n = rhs.shape[-1] + cutlass_alignment = 8 + + if ( + ragged_dot_ffi_available() + and jax.default_backend() == "gpu" + and lhs.dtype == jnp.bfloat16 + and rhs.dtype == jnp.bfloat16 + and group_sizes.dtype == jnp.int32 + and group_offset.dtype == jnp.int32 + and k % cutlass_alignment == 0 + and n % cutlass_alignment == 0 + ): + return ragged_dot_ffi(lhs, rhs, group_sizes, group_offset) + assert group_offset.shape == (1,), "group_offset must have shape (1,)" offset = group_offset[0] m = lhs.shape[0] @@ -104,7 +129,13 @@ def shard_map_ep(module: nnx.Module, func, *args): ) in_specs = (state_specs,) + (PartitionSpec(),) * len(args) - @jax.shard_map(mesh=get_abstract_mesh(), in_specs=in_specs, out_specs=PartitionSpec(), axis_names={"ep"}) + @jax.shard_map( + mesh=get_abstract_mesh(), + in_specs=in_specs, + out_specs=PartitionSpec(), + axis_names={"ep"}, + check_vma=False, + ) def _body(state, *fn_args): module_shard = nnx.merge(graphdef, state) return func(module_shard, *fn_args) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py index 99ae33327..5fe480400 100644 --- a/skyrl-tx/tx/tinker/backends/jax.py +++ b/skyrl-tx/tx/tinker/backends/jax.py @@ -60,7 +60,7 @@ class JaxBackendConfig(BaseModel, extra="forbid"): """Configuration specific to the JAX backend.""" max_lora_adapters: int = Field(default=32, description="Maximum number of LoRA adapters") - max_lora_rank: int = Field(default=32, description="Maximum LoRA rank") + max_lora_rank: int = Field(default=8, description="Maximum LoRA 
rank") tensor_parallel_size: int = Field(default=1, description="Tensor parallelism degree to use for the model") expert_parallel_size: int = Field(default=1, description="Expert parallelism degree for MoE layers") fully_sharded_data_parallel_size: int = Field( @@ -171,6 +171,12 @@ def __init__(self, base_model: str, config: JaxBackendConfig): gradient_checkpointing=config.gradient_checkpointing, ) + if config.max_lora_rank % 8 != 0: + logger.warning( + f"[bold yellow]max_lora_rank={config.max_lora_rank} is not divisible by 8. " + "This could lead to degraded performance.[/bold yellow]" + ) + model_class = get_model_class(self.model_config) # Create model and load weights