
Conversation

@AviralGoelAMD
Collaborator

This PR implements split-K support for BQuantGrouped mode across all data types (fp8, bf8, fp8i4, bf8i4), with proper handling of packed data types.

Key Changes

1. Split-K Offset Calculation for Packed Types

  • Advances B pointers by (KRead / BPackedSize) bytes for pk_int4_t types, which pack two elements per byte

2. Validation Constraints

  • Added check: KRead % BPackedSize == 0 to ensure split-K batches align on byte boundaries
  • Ensures KRead % QuantGroupSize::kK == 0 for proper scale application
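The two checks above can be sketched as a standalone predicate. This is illustrative only, not the kernel's actual code: `kQuantGroupK` stands in for `QuantGroupSize::kK` and `kBPackedSize` for `BPackedSize` (2 int4 elements per byte).

```cpp
// Hypothetical stand-ins for the kernel's compile-time parameters.
constexpr int kQuantGroupK = 128; // QuantGroupSize::kK
constexpr int kBPackedSize = 2;   // elements per byte for pk_int4_t

// KRead is the K-extent each split-K batch reads.
// Both constraints must hold for the split-K configuration to be valid.
bool is_valid_split_k(int KRead)
{
    return (KRead % kQuantGroupK == 0)  // scales align with group boundaries
        && (KRead % kBPackedSize == 0); // batches start on whole bytes
}
```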

ASCII Visualization: pk_int4_t Split-K Alignment

GEMM with BQuant: C = A × B_quantized × BQ_scales

K dimension (example: K=512, QuantGroupSize::kK=128, split_k=4):
┌────────────────┬────────────────┬────────────────┬────────────────┐
│    Batch 0     │    Batch 1     │    Batch 2     │    Batch 3     │
│   KRead=128    │   KRead=128    │   KRead=128    │   KRead=128    │
└────────────────┴────────────────┴────────────────┴────────────────┘
0              128              256              384              512

Quantization Groups (kK=128):
┌────────────────┬────────────────┬────────────────┬────────────────┐
│    Group 0     │    Group 1     │    Group 2     │    Group 3     │
│  BQ_scale[0]   │  BQ_scale[1]   │  BQ_scale[2]   │  BQ_scale[3]   │
└────────────────┴────────────────┴────────────────┴────────────────┘
0              128              256              384              512

────────────────────────────────────────────────────────────────────
CONSTRAINT 1: KRead % QuantGroupSize::kK == 0
────────────────────────────────────────────────────────────────────
Each split-K batch applies scales based on *local* K index, not global.
If KRead doesn't align with group boundaries, wrong scales are applied.

✓ VALID (KRead=128, kK=128):
  Batch 0: K[0:128]     → uses local_k ∈ [0,128)   → BQ[0]  ✓
  Batch 1: K[128:256]   → uses local_k ∈ [0,128)   → BQ[1]  ✓

✗ INVALID (KRead=192, kK=128):
  Batch 0: K[0:192]     → uses local_k ∈ [0,192)   → BQ[0], BQ[1]  ✓
  Batch 1: K[192:384]   → uses local_k ∈ [0,192)   → BQ[0], BQ[1]  ✗
                          (should use BQ[1], BQ[2] but pipeline doesn't know!)
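The mismatch above can be demonstrated with two illustrative functions (names are hypothetical, not from the kernel): one computes the scale group a global K index actually needs, the other mimics what a split-K batch derives from its local index after the BQ pointer has been offset to the batch's starting group.

```cpp
// Scale group a global K index belongs to.
int scale_from_global(int global_k, int group_k)
{
    return global_k / group_k;
}

// Scale group a split-K batch computes: the pointer offset picks the
// batch's starting group, then local_k indexes from there. This is only
// correct when each batch starts exactly on a group boundary.
int scale_from_local(int batch, int KRead, int local_k, int group_k)
{
    return (batch * KRead) / group_k + local_k / group_k;
}
```

With KRead=128 and kK=128, batch 1's local index 0 maps to the same group as global index 128; with KRead=192, batch 1's local index 64 (global index 256) lands in the wrong group, matching the invalid case above.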

────────────────────────────────────────────────────────────────────
CONSTRAINT 2: KRead % BPackedSize == 0  (for pk_int4_t types)
────────────────────────────────────────────────────────────────────
pk_int4_t packs 2 elements per byte. Split-K must advance by whole bytes.

B matrix bytes:   [  byte 0  ][  byte 1  ][  byte 2  ][  byte 3  ]...
B elements:       [ K0 │ K1  ][ K2 │ K3  ][ K4 │ K5  ][ K6 │ K7  ]...

✓ VALID (KRead=4, BPackedSize=2):
  Batch 0: K[0:4]   → byte offset = 4/2 = 2 bytes  ✓
  Batch 1: K[4:8]   → byte offset = 4/2 = 2 bytes  ✓

✗ INVALID (KRead=3, BPackedSize=2):
  Batch 0: K[0:3]   → byte offset = 3/2 = 1.5 bytes  ✗ (fractional!)
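The byte-offset rule can be sketched as follows (a minimal illustration with assumed names, not the kernel's actual offset code): the pointer for a batch must advance by a whole number of bytes, so a KRead that is not a multiple of the packed size has to be rejected.

```cpp
// Byte offset a split-K batch advances the packed B pointer by.
// Returns -1 when KRead does not land on a whole-byte boundary,
// i.e. the split-K configuration must be rejected.
int batch_byte_offset(int batch, int KRead, int packed_size)
{
    if (KRead % packed_size != 0)
        return -1; // fractional byte offset: invalid configuration
    return batch * (KRead / packed_size);
}
```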

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to the REGRESSION_TESTS list defined at the top of tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which helps the maintainers understand the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered.

Contributor

Copilot AI left a comment

Pull request overview

This PR implements split-K support for BQuantGrouped quantized GEMM operations, enabling parallel processing of the K dimension across multiple workgroups. The implementation handles packed data types (fp8i4, bf8i4) and enforces alignment constraints to ensure correct quantization scale application.

Changes:

  • Adds split-K offset calculation for packed data types with proper byte-boundary alignment
  • Implements BQ (quantization scale) pointer offsetting for split-K batches
  • Adds validation to ensure split-K batch sizes align with quantization groups and packed element boundaries
  • Updates example configurations from Prefill (128x128 tiles) to Decode (16x64 tiles) for better split-K performance

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.

File descriptions:
  • include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp: core kernel changes (split-K offset calculation, BQ pointer offsetting, tensor view creation, and validation logic)
  • example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc: removes premature split-K rejection, delegates validation to the kernel
  • example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_*.cpp: changes config from Prefill to Decode for the fp8, bf8, fp8i4, and bf8i4 variants

Comment on lines +216 to +217
// Split-K validation is handled by Kernel::IsSupportedArgument
// Split-K is only supported for BQuantGrouped without preshuffle

Copilot AI Jan 26, 2026

The PR adds split-K support for BQuantGrouped mode but does not add any tests that exercise this functionality with k_batch > 1. The existing tests in test/ck_tile/gemm_block_scale only test with k_batch=1 (the default). Tests should be added to verify split-K correctness for various configurations, especially with packed types (pk_int4_t) and different quantization group sizes to ensure the alignment constraints work correctly.

@AviralGoelAMD AviralGoelAMD marked this pull request as draft January 26, 2026 23:26
@AviralGoelAMD AviralGoelAMD self-assigned this Jan 27, 2026
@ThomasNing
Copy link
Contributor

@AviralGoelAMD LGTM, please solve the CI

@AviralGoelAMD AviralGoelAMD marked this pull request as ready for review January 27, 2026 19:30