Conversation

@bartekxk
Contributor

Proposed changes

Add Grouped Conv Bwd Weight Direct Load implementation and instances

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to the REGRESSION_TESTS list defined at the top of tests/CMakeLists.txt, if the test takes more than 30 seconds to run
  • I have added inline documentation that enables the maintainers to understand the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

if constexpr(DirectLoad)
{
    return make_naive_tensor_descriptor(
        make_tuple(AK0Number, Number<NPerBlock>{}, AK1Number),
Contributor

I believe that BK0Number, BK1Number and such are the proper ones here, since it's the B layout

Contributor Author

You are right, thanks
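
For context, a minimal sketch of the agreed fix, assuming the B-side constants BK0Number, BK1Number and NPerBlock from the surrounding gridwise GEMM; the stride tuple shown is the usual CK (K0, N, K1) block-descriptor layout and is not copied from the PR:

// Sketch of the agreed fix: build the B-block descriptor from the B-side
// compile-time constants instead of the A-side AK0Number/AK1Number used above.
// I1 is the compile-time constant 1 that CK kernels commonly define.
if constexpr(DirectLoad)
{
    return make_naive_tensor_descriptor(
        make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
        make_tuple(Number<NPerBlock>{} * BK1Number, BK1Number, I1));
}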

@aosewski requested a review from Copilot on January 26, 2026 at 14:47
Copilot AI left a comment

Pull request overview

This PR adds a direct load implementation for grouped convolution backward weight operations, introducing a hardware-optimized memory transfer path for gfx950 devices. The implementation uses the ThreadGroupTensorSliceTransfer_DirectLoad mechanism, with specific handling for the F16 and BF16 data types.

Changes:

  • Added direct load instances for F16 and BF16 grouped conv backward weight operations
  • Introduced DirectLoad and LdsScalarLoad template parameters throughout the pipeline
  • Implemented device-specific validation to restrict direct load to the gfx950 architecture (a sketch of such a check follows this list)
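
As a rough illustration of the last point, CK device operations normally gate unsupported configurations inside IsSupportedArgument; the sketch below assumes a DirectLoad flag on the device op and uses CK's ck::get_device_name() host utility, but it is not code taken from the PR:

// Illustrative sketch only: reject DirectLoad instances on anything but gfx950.
static bool IsSupportedArgument(const Argument& arg)
{
    if constexpr(DirectLoad)
    {
        // Direct (global -> LDS) loads are only exercised on gfx950 here;
        // the exact placement of this check in the PR is an assumption.
        if(ck::get_device_name() != "gfx950")
        {
            return false;
        }
    }
    // ... the usual data-type / vector-load / layout checks on arg would follow ...
    return true;
}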

Reviewed changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 2 comments.

Summary per file

  • device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load.cpp: instantiates the F16 direct load device operations for grouped conv backward weight
  • device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_direct_load.cpp: instantiates the BF16 direct load device operations for grouped conv backward weight
  • CMakeLists.txt: registers the new F16 and BF16 direct load source files in the build system
  • grouped_convolution_backward_weight_xdl.inc: declares the F16 and BF16 direct load instance functions (see the sketch after this list)
  • grouped_convolution_backward_weight.hpp: integrates the direct load instances into the instance factory
  • device_grouped_conv_bwd_weight_v3_xdl_instance.hpp: defines the F16 and BF16 direct load instance configurations with the DirectLoad flag set to true
  • gridwise_gemm_xdl_cshuffle_conv_v3.hpp: adds direct load conditional logic with specialized block descriptors and transfer handling
  • device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp: implements direct load support with gfx950 validation and group merging logic
  • thread_group_tensor_slice_transfer_direct_load.hpp: removes the destination vector dimension constraint for direct load
  • blockwise_gemm_pipeline_xdlops_v1.hpp: adds the LdsScalarLoad parameter to the direct load pipeline
  • blockwise_gemm_pipeline_xdlops_selector.hpp: adds the LdsScalarLoad selection logic with validation
  • blockwise_gemm_pipeline_xdlops_base.hpp: implements the scalar load logic for LDS transfers when enabled
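
For orientation, instance functions declared in CK's .inc headers usually follow the pattern sketched below; the function name and template arguments of the direct load variants are assumptions for illustration and are not taken from the PR (F16, NHWGC, GKYXC, NHWGK and PassThrough are the usual CK aliases provided by the library headers):

#include <memory>
#include <vector>

// Hypothetical declaration sketch mirroring the existing (non direct load)
// grouped conv bwd-weight instance functions in CK.
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,           // spatial dims
                                                           NHWGC,       // input layout
                                                           GKYXC,       // weight layout
                                                           NHWGK,       // output layout
                                                           F16,         // input data type
                                                           F16,         // weight data type
                                                           F16,         // output data type
                                                           PassThrough, // input elementwise op
                                                           PassThrough, // weight elementwise op
                                                           PassThrough  // output elementwise op
                                                           >>>& instances);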


Comment on lines +87 to +109
#if defined(__gfx950__)
    DispatchSplitKHack<GridwiseGemm,
                       AGridDesc_AK0_M_K1,
                       BGridDesc_BK0_N_K1,
                       CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                       HasMainKBlockLoop,
                       CGlobalMemoryDataOperation,
                       TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
                                karg.p_b_grid + b_batch_offset + split_k_offset_b,
                                karg.p_c_grid + e_batch_offset,
                                p_shared,
                                karg,
                                a_grid_desc_ak0_m_ak1,
                                b_grid_desc_bk0_n_bk1,
                                c_grid_desc_mblock_mperblock_nblock_nperblock,
                                k_idx * num_k_per_block,
                                gridDim.y,
                                split_k_offset_hack);
#endif
}
else
{
    DispatchSplitKHack<GridwiseGemm,
Collaborator

What's the difference? Both invocations seem to pass an identical parameter set.

Contributor Author

I need to disable this on other archs
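
In other words, both branches intentionally pass the same arguments; the difference is that the DirectLoad branch is wrapped in #if defined(__gfx950__) so its body is compiled out on every other target. A minimal sketch of that pattern, with launch_kernel as a hypothetical stand-in for the real dispatch:

#include <utility>

// Hypothetical stand-in for the real kernel dispatch; name and signature are
// illustrative only, not from the PR.
template <typename... Args>
__device__ void launch_kernel(Args&&...);

template <bool DirectLoad, typename... Args>
__device__ void dispatch(Args&&... args)
{
    if constexpr(DirectLoad)
    {
#if defined(__gfx950__)
        // Direct-load body: emitted only when compiling for gfx950.
        launch_kernel(std::forward<Args>(args)...);
#endif
        // On any other target this branch has an empty body, so the
        // direct-to-LDS instructions never reach the compiler; the host-side
        // support check keeps such instances from being selected there.
    }
    else
    {
        // Regular LDS-staging path, compiled for every supported target.
        launch_kernel(std::forward<Args>(args)...);
    }
}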
