Grouped Conv Bwd Weight Direct Load #3648
base: develop
Conversation
if constexpr(DirectLoad)
{
    return make_naive_tensor_descriptor(
        make_tuple(AK0Number, Number<NPerBlock>{}, AK1Number),
I believe that BK0Number, BK1Number and such are the proper ones here, since it's the B layout.
You are right, thanks
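For readers following the thread, a minimal sketch of what the corrected B-side branch could look like, mirroring the quoted A-side hunk. The stride tuple is an assumption (a packed BK0 × NPerBlock × BK1 layout), since the quoted diff is cut off before the strides; the committed code may differ.

```cpp
// Sketch only, not the committed code: B-side block descriptor for the
// direct load path, using the B-layout constants suggested in the review.
// The strides assume a packed (BK0, NPerBlock, BK1) layout.
if constexpr(DirectLoad)
{
    return make_naive_tensor_descriptor(
        make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
        make_tuple(Number<NPerBlock>{} * BK1Number, BK1Number, Number<1>{}));
}
```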
Pull request overview
This PR adds support for a direct load implementation in grouped convolution backward weight operations, introducing a hardware-optimized memory transfer path for gfx950 devices. The implementation uses the ThreadGroupTensorSliceTransfer_DirectLoad mechanism, with specific handling for the F16 and BF16 data types.
Changes:
- Added direct load instances for F16 and BF16 grouped conv backward weight operations
- Introduced DirectLoad and LdsScalarLoad template parameters throughout the pipeline (a usage sketch follows this list)
- Implemented device-specific validation to restrict direct load to the gfx950 architecture
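To make the DirectLoad / LdsScalarLoad bullet concrete, here is a small, standalone sketch of the general pattern: boolean template parameters threaded into a pipeline selector that validates the combination at compile time, similar in spirit to the selection logic described for blockwise_gemm_pipeline_xdlops_selector.hpp in the table below. Every name in the snippet is an illustrative stand-in, not CK code.

```cpp
#include <iostream>
#include <type_traits>

// Illustrative stand-ins for the two pipeline variants.
struct StandardLdsPipeline { static constexpr const char* name = "standard LDS pipeline"; };
struct DirectLoadPipeline  { static constexpr const char* name = "direct load pipeline";  };

template <bool DirectLoad, bool LdsScalarLoad>
struct PipelineSelector
{
    // LdsScalarLoad only makes sense together with DirectLoad; reject the
    // combination at compile time instead of silently ignoring it.
    static_assert(!LdsScalarLoad || DirectLoad, "LdsScalarLoad requires the direct load path");

    using type = std::conditional_t<DirectLoad, DirectLoadPipeline, StandardLdsPipeline>;
};

int main()
{
    using Selected = PipelineSelector</*DirectLoad=*/true, /*LdsScalarLoad=*/true>::type;
    std::cout << "selected: " << Selected::name << '\n';
    return 0;
}
```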
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 2 comments.
Summary per file:
| File | Description |
|---|---|
| device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load.cpp | Instantiates F16 direct load device operations for grouped conv backward weight |
| device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_direct_load.cpp | Instantiates BF16 direct load device operations for grouped conv backward weight |
| CMakeLists.txt | Registers new F16 and BF16 direct load source files in build system |
| grouped_convolution_backward_weight_xdl.inc | Declares F16 and BF16 direct load instance functions |
| grouped_convolution_backward_weight.hpp | Integrates direct load instances into factory pattern |
| device_grouped_conv_bwd_weight_v3_xdl_instance.hpp | Defines F16 and BF16 direct load instance configurations with true flag |
| gridwise_gemm_xdl_cshuffle_conv_v3.hpp | Adds direct load conditional logic with specialized block descriptors and transfer handling |
| device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | Implements direct load support with gfx950 validation and group merging logic |
| thread_group_tensor_slice_transfer_direct_load.hpp | Removes destination vector dimension constraint for direct load |
| blockwise_gemm_pipeline_xdlops_v1.hpp | Adds LdsScalarLoad parameter to direct load pipeline |
| blockwise_gemm_pipeline_xdlops_selector.hpp | Adds LdsScalarLoad selection logic with validation |
| blockwise_gemm_pipeline_xdlops_base.hpp | Implements scalar load logic for LDS transfers when enabled |
Resolved review threads:
- include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp
- include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp (outdated)
- include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp (outdated)
#if defined(__gfx950__)
            DispatchSplitKHack<GridwiseGemm,
                               AGridDesc_AK0_M_K1,
                               BGridDesc_BK0_N_K1,
                               CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
                               HasMainKBlockLoop,
                               CGlobalMemoryDataOperation,
                               TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
                                        karg.p_b_grid + b_batch_offset + split_k_offset_b,
                                        karg.p_c_grid + e_batch_offset,
                                        p_shared,
                                        karg,
                                        a_grid_desc_ak0_m_ak1,
                                        b_grid_desc_bk0_n_bk1,
                                        c_grid_desc_mblock_mperblock_nblock_nperblock,
                                        k_idx * num_k_per_block,
                                        gridDim.y,
                                        split_k_offset_hack);
#endif
        }
        else
        {
            DispatchSplitKHack<GridwiseGemm,
What's the difference? Both invocations seem to pass an identical parameter set.
I need to disable this on other archs
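To spell out the resolution: the two calls are indeed identical; the only difference is the preprocessor guard, which compiles the direct load branch out entirely when the device code is not built for gfx950. Below is a standalone sketch of the pattern in plain C++; the function names are hypothetical, and the enclosing condition, which is not visible in the quoted hunk, is assumed here to be a DirectLoad-style compile-time flag.

```cpp
#include <cstdio>

// Hypothetical stand-ins for the two dispatch paths.
void generic_split_k_dispatch()    { std::puts("generic split-K dispatch"); }
void gfx950_direct_load_dispatch() { std::puts("gfx950 direct load split-K dispatch"); }

template <bool DirectLoad>
void dispatch()
{
    if constexpr(DirectLoad)
    {
#if defined(__gfx950__)
        // Only compiled when targeting gfx950 device code.
        gfx950_direct_load_dispatch();
#endif
        // On any other architecture this branch is intentionally empty,
        // so the direct load path is never instantiated there.
    }
    else
    {
        generic_split_k_dispatch();
    }
}

int main()
{
    dispatch<false>(); // always available
    dispatch<true>();  // prints nothing unless built for gfx950
    return 0;
}
```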
Proposed changes
Add Grouped Conv Bwd Weight Direct Load implementation and instances
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- clang-format on all changed files

Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered