From 20b056ded06b7ea0e86cedd27b8c0fd27ed45c89 Mon Sep 17 00:00:00 2001 From: Bartlomiej Kocot Date: Mon, 26 Jan 2026 10:21:39 +0000 Subject: [PATCH 1/8] Grouped Conv Bwd Weight Direct Load --- .../blockwise_gemm_pipeline_xdlops_base.hpp | 7 +- ...lockwise_gemm_pipeline_xdlops_selector.hpp | 11 +- .../blockwise_gemm_pipeline_xdlops_v1.hpp | 18 +- ...roup_tensor_slice_transfer_direct_load.hpp | 4 - ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 37 +- .../gridwise_gemm_xdl_cshuffle_conv_v3.hpp | 372 ++++++++++++------ ...rouped_conv_bwd_weight_v3_xdl_instance.hpp | 46 +++ .../grouped_convolution_backward_weight.hpp | 6 + ...rouped_convolution_backward_weight_xdl.inc | 24 ++ ...xdl_nhwgc_gkyxc_nhwgk_bf16_direct_load.cpp | 40 ++ ..._xdl_nhwgc_gkyxc_nhwgk_f16_direct_load.cpp | 40 ++ 11 files changed, 462 insertions(+), 143 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_direct_load.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load.cpp diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index 512a019ec80..3dc5d87630d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -30,7 +30,8 @@ template + bool TransposeC = false, + bool LdsScalarLoad = false> struct BlockwiseGemmXdlops_pipeline_base { static constexpr auto I0 = Number<0>{}; @@ -385,7 +386,7 @@ struct BlockwiseGemmXdlops_pipeline_base Sequence<1, 1, 1, KPack>, Sequence<0, 1, 2, 3>, 3, - A_K1, + LdsScalarLoad ? 1 : A_K1, A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, Sequence<0, 1, 2, 3>, 3, - B_K1, + LdsScalarLoad ? 1 : B_K1, B_K1>; AThreadCopy a_thread_copy_; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp index 7ddb4e9b749..9f797e70814 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp @@ -32,9 +32,15 @@ template + bool DirectLoad = false, + bool LdsScalarLoad = false> constexpr auto BlockGemmPipeline_Selector() { + // Supported for Direct Load and V1 + if constexpr(LdsScalarLoad) + { + static_assert(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1); + } if constexpr(DirectLoad) { if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) @@ -58,7 +64,8 @@ constexpr auto BlockGemmPipeline_Selector() NPerXDL, MRepeat, NRepeat, - KPack>{}; + KPack, + LdsScalarLoad>{}; } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp index 5604a31091e..cfe9fe5e107 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp @@ -758,7 +758,8 @@ template + index_t KPacks, + bool LdsScalarLoad = false> struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1 { }; @@ -781,9 +782,9 @@ template + bool LdsScalarLoad> struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1 + KPack, + LdsScalarLoad> : BlockwiseGemmXdlops_pipeline_base + KPack, + false /*TransposeC*/, + LdsScalarLoad> { using Base = BlockwiseGemmXdlops_pipeline_base; + KPack, + false /*TransposeC*/, + LdsScalarLoad>; using Base::I0; using Base::KRepeat; using Base::xdlops_gemm; diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp index ad74ee847eb..a31c9101a16 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp @@ -140,10 +140,6 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad "Direct load transfer does not support datatypes conversion. Source and " "destination data types must be the same."); - static_assert( - DstVectorDim == nDim - 1, - "Direct load transfer requires the destination vector dimension to be the last one."); - static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim, "When loading more than one element per thread at once, the contiguous " "dimension must be the same between source and destination."); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 2121be00d17..cd07a108ef9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -236,7 +236,8 @@ template + typename ComputeTypeB = ComputeTypeA, + bool DirectLoad = false> struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 : public DeviceGroupedConvBwdWeight; using CGridDesc_M_N = remove_cvref_t; + static constexpr index_t ABlockTransferSrcScalarPerVectorAligned = + ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8 + ? 4 / sizeof(ADataType) + : ABlockTransferSrcScalarPerVector; + static constexpr index_t BBlockTransferSrcScalarPerVectorAligned = + BBlockTransferSrcScalarPerVector * sizeof(BDataType) == 8 + ? 4 / sizeof(BDataType) + : BBlockTransferSrcScalarPerVector; + template using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3< tensor_layout::gemm::RowMajor, @@ -399,7 +409,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, + DirectLoad ? ABlockTransferSrcScalarPerVectorAligned : ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, @@ -407,7 +417,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, + DirectLoad ? BBlockTransferSrcScalarPerVectorAligned : BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, @@ -418,7 +428,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, - ComputeTypeB>; + ComputeTypeB, + DirectLoad>; using GridwiseGemm64 = GridwiseGemmBase; using GridwiseGemm32 = GridwiseGemmBase; @@ -1367,6 +1378,15 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } #endif + // check device + if constexpr(DirectLoad) + { + if(get_device_name() != "gfx950") + { + return false; + } + } + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); const index_t GemmK = @@ -1617,8 +1637,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 auto str = std::stringstream(); // clang-format off - str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3" - << "<" + str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; + + if constexpr(DirectLoad) { + str << "_DirectLoad"; + } + + str << "<" << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index 8188c42ca5e..74d3cc9c866 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -13,6 +13,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" namespace ck { @@ -61,9 +62,14 @@ template + typename ComputeTypeB = ComputeTypeA, + bool DirectLoad = false> struct GridwiseGemm_xdl_cshuffle_conv_v3 { + static_assert((is_same_v && + is_same_v) || + !DirectLoad); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -79,6 +85,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 static constexpr auto AK1Number = Number{}; static constexpr auto BK1Number = Number{}; + static constexpr bool DirectLoadEnabled = DirectLoad; + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && @@ -279,7 +287,14 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 constexpr index_t ABlockLdsExtraM = ABlockLdsExtraMCustom; #endif // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM) + if constexpr(DirectLoad) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{}, I1, Number{})); + } + // A matrix in LDS memory, dst of blockwise copy + else if constexpr(ABlockLdsExtraM) { return make_naive_tensor_descriptor( make_tuple(AK0Number, Number{}, AK1Number), @@ -425,7 +440,13 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 constexpr index_t BBlockLdsExtraN = BBlockLdsExtraNCustom; #endif // B matrix in LDS memory, dst of blockwise copy - if constexpr(BBlockLdsExtraN) + if constexpr(DirectLoad) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{}, I1, Number{})); + } + else if constexpr(BBlockLdsExtraN) { return make_naive_tensor_descriptor( make_tuple(BK0Number, Number{}, BK1Number), @@ -573,6 +594,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) + static constexpr bool LdsScalarLoad = DirectLoad; using BlockwiseGemmPipe = remove_cvref_t())>; + KPack, + DirectLoad, + LdsScalarLoad>())>; __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { @@ -724,67 +748,119 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - a_grid_desc_ak0_m_ak1, - make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + auto get_a_blockwise_copy = [&]() { + if constexpr(DirectLoad) + { + return ThreadGroupTensorSliceTransfer_DirectLoad< + ThisThreadBlock, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 1, + ABlockTransferSrcScalarPerVector>( + a_grid_desc_ak0_m_ak1, + make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0)); + } + else + { + return ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + } + }; // B matrix blockwise copy - auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - b_grid_desc_bk0_n_bk1, - make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + auto get_b_blockwise_copy = [&]() { + if constexpr(DirectLoad) + { + return ThreadGroupTensorSliceTransfer_DirectLoad< + ThisThreadBlock, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 1, + BBlockTransferSrcScalarPerVector>( + b_grid_desc_bk0_n_bk1, + make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); + } + else + { + return ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + BElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + } + }; + + auto a_blockwise_copy = get_a_blockwise_copy(); + auto b_blockwise_copy = get_b_blockwise_copy(); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( @@ -1091,67 +1167,119 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - a_grid_desc_ak0_m_ak1, - make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + auto get_a_blockwise_copy = [&]() { + if constexpr(DirectLoad) + { + return ThreadGroupTensorSliceTransfer_DirectLoad< + ThisThreadBlock, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 1, + ABlockTransferSrcScalarPerVector>( + a_grid_desc_ak0_m_ak1, + make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0)); + } + else + { + return ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + } + }; // B matrix blockwise copy - auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - b_grid_desc_bk0_n_bk1, - make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + auto get_b_blockwise_copy = [&]() { + if constexpr(DirectLoad) + { + return ThreadGroupTensorSliceTransfer_DirectLoad< + ThisThreadBlock, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 1, + BBlockTransferSrcScalarPerVector>( + b_grid_desc_bk0_n_bk1, + make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); + } + else + { + return ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + BElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + } + }; + + auto a_blockwise_copy = get_a_blockwise_copy(); + auto b_blockwise_copy = get_b_blockwise_copy(); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp index 143d8573336..86b7e03af6a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp @@ -101,6 +101,52 @@ using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances = std::tuple // clang-format on >; +template +using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_direct_load_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| Compute| Compute| Direct| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| Data| Data| Load| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| Type| Type| | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | | + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 32, 64, 8, 16, 16, 1, 1, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 64, 64, 8, 16, 16, 1, 2, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 32, 64, 8, 32, 32, 2, 1, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<8, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Scheduler, PipelineVersion, F16, F16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 64, 8, 32, 32, 1, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Scheduler, PipelineVersion, F16, F16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Scheduler, PipelineVersion, F16, F16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 2, Scheduler, PipelineVersion, F16, F16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 32, 1, 4>, 4, Scheduler, PipelineVersion, F16, F16, true> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_direct_load_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| Compute| Compute| Direct| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| Data| Data| Load| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| Type| Type| | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | | + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 32, 64, 8, 16, 16, 1, 1, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, BF16, BF16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 64, 64, 8, 16, 16, 1, 2, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, BF16, BF16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 32, 64, 8, 32, 32, 2, 1, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<8, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Scheduler, PipelineVersion, BF16, BF16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 64, 8, 32, 32, 1, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Scheduler, PipelineVersion, BF16, BF16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Scheduler, PipelineVersion, BF16, BF16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 2, Scheduler, PipelineVersion, BF16, BF16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 32, 1, 4>, 4, Scheduler, PipelineVersion, BF16, BF16, true> + // clang-format on + >; + template >>& instances); +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_direct_load_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instances( std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_direct_load_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load.cpp new file mode 100644 index 00000000000..7035a4cf71f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_direct_load_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 25cb0283ed01458d46b789d169ebeec5867bd3e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Mon, 26 Jan 2026 13:16:15 +0100 Subject: [PATCH 2/8] Update gridwise_gemm_xdl_cshuffle_conv_v3.hpp --- .../gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index 74d3cc9c866..cd557632dcc 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -443,8 +443,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 if constexpr(DirectLoad) { return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(Number{}, I1, Number{})); + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{}, I1, Number{})); } else if constexpr(BBlockLdsExtraN) { From a9db3100b8c2a30b6a7266034951253092531ea5 Mon Sep 17 00:00:00 2001 From: "Graner, Johannes" Date: Thu, 22 Jan 2026 09:17:14 -0500 Subject: [PATCH 3/8] Implement group merging for bwd_weight and add instances --- ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 31 ++++++++++++++----- ...rouped_conv_bwd_weight_v3_xdl_instance.hpp | 3 ++ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index cd07a108ef9..cea62ef281a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -237,7 +237,8 @@ template + bool DirectLoad = false, + index_t NumGroupsToMerge = 1> struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 : public DeviceGroupedConvBwdWeight{}; template ::type = false> @@ -664,15 +665,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 if(split_k_offset_hack_) split_k_stride_b_ /= k_batch_; - // A/B/C Batch Stride - compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; - compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0]; + // A/B/C Batch Stride (multiply by NumGroupsToMerge for group merging) + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0] * NumGroupsToMerge; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0] * NumGroupsToMerge; compute_ptr_offset_of_batch_.BatchStrideC_ = Conv_K_ * Conv_C_ * std::accumulate(begin(filter_spatial_lengths_), end(filter_spatial_lengths_), index_t{1}, - std::multiplies<>{}); + std::multiplies<>{}) * + NumGroupsToMerge; const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1); const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1); @@ -754,7 +756,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 index_t gdx, gdy, gdz; std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( - gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_); + gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumGroupsToMerge); float ave_time = 0; @@ -1387,6 +1389,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } } + // Check that NumGroupsToMerge divides Conv_G evenly + if constexpr(NumGroupsToMerge > 1) + { + if(arg.Conv_G_ % NumGroupsToMerge != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported! Conv_G_ % NumGroupsToMerge != 0: Conv_G_=" + << arg.Conv_G_ << ", NumGroupsToMerge=" << NumGroupsToMerge + << std::endl; + } + return false; + } + } + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); const index_t GemmK = diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp index 86b7e03af6a..3a3dc156ec8 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp @@ -115,9 +115,12 @@ using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_direct_load_instances //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| Type| Type| | //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | | DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 32, 64, 8, 16, 16, 1, 1, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 32, 64, 8, 16, 16, 1, 1, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true, 2>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 64, 64, 8, 16, 16, 1, 2, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 64, 64, 8, 16, 16, 1, 2, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true, 2>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 32, 64, 8, 32, 32, 2, 1, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<8, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Scheduler, PipelineVersion, F16, F16, true>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 64, 8, 32, 32, 1, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Scheduler, PipelineVersion, F16, F16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 64, 8, 32, 32, 1, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Scheduler, PipelineVersion, F16, F16, true, 2>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Scheduler, PipelineVersion, F16, F16, true>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 2, Scheduler, PipelineVersion, F16, F16, true>, DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 32, 1, 4>, 4, Scheduler, PipelineVersion, F16, F16, true> From b51d7aec7e04ac83f1071e1833b26447b034a1b8 Mon Sep 17 00:00:00 2001 From: "Graner, Johannes" Date: Mon, 26 Jan 2026 08:38:47 -0500 Subject: [PATCH 4/8] Link direct load instances --- .../gpu/grouped_conv2d_bwd_weight/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt index ec9e7da3911..268835d5bfd 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -20,6 +20,8 @@ set(GROUPED_CONV2D_BWD_WEIGHT xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instance.cpp + xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_direct_load.cpp + xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instance.cpp xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp From b4d3fda7a1010fe653e132b13d56eb7c3ece265d Mon Sep 17 00:00:00 2001 From: Bartlomiej Kocot Date: Tue, 27 Jan 2026 11:35:16 +0000 Subject: [PATCH 5/8] builder fixes --- .../ck_tile/builder/factory/conv_algorithms.hpp | 2 +- ..._grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 17 ++++++++++++++--- ...test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp | 3 ++- .../builder/test/impl/conv_algorithm_types.hpp | 3 ++- ...stance_string_bwd_weight_grp_conv_xdl_v3.cpp | 2 ++ 5 files changed, 21 insertions(+), 6 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 79b818555e8..c508126adb1 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -35,7 +35,7 @@ template concept BwdXdlV3AlgorithmBase = ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && SpecifiesGridwiseBwdXdlGemm && SpecifiesBwdWeightConvSpecialization && - SpecifiesBlockGemm; + SpecifiesBlockGemm && SpecifiesNumGroupsToMerge; template concept BwdWmmaAlgorithmBase = diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 147028f9cfb..385691c8a2e 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -53,7 +53,9 @@ template + typename ComputeTypeB, + bool DirectLoad, + index_t NumGroupsToMerge> struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3; } // namespace ck::tensor_operation::device @@ -104,7 +106,9 @@ template + typename ComputeTypeB_, + bool DirectLoad, + index_t NumGroupsToMerge> struct InstanceTraits> + ComputeTypeB_, + DirectLoad, + NumGroupsToMerge>> { static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; @@ -213,6 +219,9 @@ struct InstanceTraits(); // 42. oss << "," << detail::type_name(); // 43. + oss << "," << kDirectLoad; // 44. + oss << "," << kNumGroupsToMerge; // 45. oss << ">"; return oss.str(); diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp index a3f4a988ef1..80c6646211f 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp @@ -32,7 +32,8 @@ constexpr auto ALGORITHM = .with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave) .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0) - .with_block_gemm(cku::BlockGemmDesc_v2_intrawave); + .with_block_gemm(cku::BlockGemmDesc_v2_intrawave) + .with_num_conv_groups_to_merge(1); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index b775505a265..f5b9bdc3b55 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -632,7 +632,8 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 = BwdXdlGemm_, Transfer_<>, ConvSpecializationBwdWeight_, - BlockGemm_>; + BlockGemm_, + GemmBatchOptions_>; using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl = ConvAlgorithmTemplate"; // Test describe() through base class pointer for XDL V3 variant From 82c17910b8cf21cc5b295b54689436f9527cea74 Mon Sep 17 00:00:00 2001 From: Bartlomiej Kocot Date: Tue, 27 Jan 2026 07:21:54 -0500 Subject: [PATCH 6/8] fix --- ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 59 +++++++++++++------ 1 file changed, 42 insertions(+), 17 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index cea62ef281a..75e1504add2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -82,23 +82,48 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - DispatchSplitKHack(karg.p_a_grid + a_batch_offset + split_k_offset_a, - karg.p_b_grid + b_batch_offset + split_k_offset_b, - karg.p_c_grid + e_batch_offset, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx * num_k_per_block, - gridDim.y, - split_k_offset_hack); + if constexpr(GridwiseGemm::DirectLoadEnabled) + { +#if defined(__gfx950__) + DispatchSplitKHack(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_hack); +#endif + } + else + { + DispatchSplitKHack(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_hack); + } } #else ignore = karg; From ad99f5a8f79757b7380cebf0f2df2e36a0ab080a Mon Sep 17 00:00:00 2001 From: Bartlomiej Kocot Date: Tue, 27 Jan 2026 10:32:21 -0500 Subject: [PATCH 7/8] fixes --- .../gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp | 8 ++++---- .../block/blockwise_gemm_pipeline_xdlops_selector.hpp | 8 ++++---- .../gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp | 10 +++++----- .../device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 1 + ...e_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 1 + .../gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp | 4 ++-- 6 files changed, 17 insertions(+), 15 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index 3dc5d87630d..ffd728f2593 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -30,8 +30,8 @@ template + bool TransposeC = false, + bool LdsScalarLoadToVgpr = false> struct BlockwiseGemmXdlops_pipeline_base { static constexpr auto I0 = Number<0>{}; @@ -386,7 +386,7 @@ struct BlockwiseGemmXdlops_pipeline_base Sequence<1, 1, 1, KPack>, Sequence<0, 1, 2, 3>, 3, - LdsScalarLoad ? 1 : A_K1, + LdsScalarLoadToVgpr ? 1 : A_K1, A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, Sequence<0, 1, 2, 3>, 3, - LdsScalarLoad ? 1 : B_K1, + LdsScalarLoadToVgpr ? 1 : B_K1, B_K1>; AThreadCopy a_thread_copy_; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp index 9f797e70814..461ca513f97 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp @@ -32,12 +32,12 @@ template + bool DirectLoad = false, + bool LdsScalarLoadToVgpr = false> constexpr auto BlockGemmPipeline_Selector() { // Supported for Direct Load and V1 - if constexpr(LdsScalarLoad) + if constexpr(LdsScalarLoadToVgpr) { static_assert(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1); } @@ -65,7 +65,7 @@ constexpr auto BlockGemmPipeline_Selector() MRepeat, NRepeat, KPack, - LdsScalarLoad>{}; + LdsScalarLoadToVgpr>{}; } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp index cfe9fe5e107..ae4504d6baa 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp @@ -759,7 +759,7 @@ template + bool LdsScalarLoadToVgpr = false> struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1 { }; @@ -784,7 +784,7 @@ template + bool LdsScalarLoadToVgpr> struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1 + LdsScalarLoadToVgpr> : BlockwiseGemmXdlops_pipeline_base + LdsScalarLoadToVgpr> { using Base = BlockwiseGemmXdlops_pipeline_base; + LdsScalarLoadToVgpr>; using Base::I0; using Base::KRepeat; using Base::xdlops_gemm; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 75e1504add2..1c6ba50eaad 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -398,6 +398,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; + // Disable vector load = 4. It is not supported for Direct Load. Align to 2 in such case. static constexpr index_t ABlockTransferSrcScalarPerVectorAligned = ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8 ? 4 / sizeof(ADataType) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index f94311bc353..19a7228f78a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -567,6 +567,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 using DsGridDesc_M_N = remove_cvref_t; + // Disable vector load = 4. It is not supported for Direct Load. Align to 2 in such case. static constexpr index_t ABlockTransferSrcScalarPerVectorAligned = ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8 ? 4 / sizeof(ADataType) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index cd557632dcc..07a2f571119 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -594,7 +594,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) - static constexpr bool LdsScalarLoad = DirectLoad; + static constexpr bool LdsScalarLoadToVgpr = DirectLoad; using BlockwiseGemmPipe = remove_cvref_t())>; + LdsScalarLoadToVgpr>())>; __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { From 19d96f3dd6b543253ebf270ae64fb3c5c92e3b5f Mon Sep 17 00:00:00 2001 From: Bartlomiej Kocot Date: Tue, 27 Jan 2026 19:24:51 -0500 Subject: [PATCH 8/8] fix --- .../gridwise_gemm_xdl_cshuffle_conv_v3.hpp | 61 ++++++++++--------- 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index c472d031e5a..5289e209fbe 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -114,7 +114,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 static_assert((is_same_v && is_same_v) || !DirectLoad); - + using Base = GridwiseGemm_xdl_cshuffle_base< ALayout, BLayout, @@ -366,7 +366,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 return make_naive_tensor_descriptor( make_tuple(AK0Number, Number{}, AK1Number), make_tuple(Number{}, I1, Number{})); - } else if constexpr(is_same_v) + } + else if constexpr(is_same_v) { // Force use padded layout on gfx950 to reduce bank conflicts constexpr index_t ABlockLdsExtraM = 1; @@ -388,7 +389,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 return make_naive_tensor_descriptor( make_tuple(BK0Number, Number{}, BK1Number), make_tuple(Number{}, I1, Number{})); - } else if constexpr(is_same_v) + } + else if constexpr(is_same_v) { constexpr index_t BBlockLdsExtraN = 1; return make_naive_tensor_descriptor( @@ -403,33 +405,36 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) - using BlockwiseGemmPipe = remove_cvref_t< - decltype(BlockGemmPipeline_Selector< - BlkGemmPipelineVer, - BlkGemmPipeSched, - BlockSize, - ADataType, - BDataType, - ComputeTypeA, - AccDataType, - decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), - decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())), - decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( + // Disable vector load from lds to vgpr for direct load (backward weight store with continous M + // or N dimension) + static constexpr bool LdsScalarLoadToVgpr = DirectLoad; + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmPipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ADataType, + BDataType, + ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), - decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))), - ABlockTransferSrcScalarPerVector, - BBlockTransferSrcScalarPerVector, - MPerBlock, - NPerBlock, - KPerBlock, - MPerXdl, - NPerXdl, - MXdlPerWave, - NXdlPerWave, - KPack, - DirectLoad, - LdsScalarLoadToVgpr>())>; + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + DirectLoad, + LdsScalarLoadToVgpr>())>; template __device__ static constexpr index_t GetSharedMemoryNumberOfByte(DeviceArch)