diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp index 79b818555e8..c508126adb1 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_algorithms.hpp @@ -35,7 +35,7 @@ template concept BwdXdlV3AlgorithmBase = ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesTileTransferParameters3D && SpecifiesGridwiseBwdXdlGemm && SpecifiesBwdWeightConvSpecialization && - SpecifiesBlockGemm; + SpecifiesBlockGemm && SpecifiesNumGroupsToMerge; template concept BwdWmmaAlgorithmBase = diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 147028f9cfb..385691c8a2e 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -53,7 +53,9 @@ template + typename ComputeTypeB, + bool DirectLoad, + index_t NumGroupsToMerge> struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3; } // namespace ck::tensor_operation::device @@ -104,7 +106,9 @@ template + typename ComputeTypeB_, + bool DirectLoad, + index_t NumGroupsToMerge> struct InstanceTraits> + ComputeTypeB_, + DirectLoad, + NumGroupsToMerge>> { static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; @@ -213,6 +219,9 @@ struct InstanceTraits(); // 42. oss << "," << detail::type_name(); // 43. + oss << "," << kDirectLoad; // 44. + oss << "," << kNumGroupsToMerge; // 45. 
oss << ">"; return oss.str(); diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp index a3f4a988ef1..80c6646211f 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp @@ -32,7 +32,8 @@ constexpr auto ALGORITHM = .with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave) .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0) - .with_block_gemm(cku::BlockGemmDesc_v2_intrawave); + .with_block_gemm(cku::BlockGemmDesc_v2_intrawave) + .with_num_conv_groups_to_merge(1); using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index b775505a265..f5b9bdc3b55 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -632,7 +632,8 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 = BwdXdlGemm_, Transfer_<>, ConvSpecializationBwdWeight_, - BlockGemm_>; + BlockGemm_, + GemmBatchOptions_>; using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl = ConvAlgorithmTemplate"; // Test describe() through base class pointer for XDL V3 variant diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index 512a019ec80..ffd728f2593 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -30,7 +30,8 @@ template + bool TransposeC = false, + bool LdsScalarLoadToVgpr = false> struct BlockwiseGemmXdlops_pipeline_base { static constexpr auto I0 = Number<0>{}; @@ -385,7 +386,7 @@ struct BlockwiseGemmXdlops_pipeline_base Sequence<1, 1, 1, KPack>, Sequence<0, 1, 2, 3>, 3, - A_K1, + LdsScalarLoadToVgpr ? 1 : A_K1, A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, Sequence<0, 1, 2, 3>, 3, - B_K1, + LdsScalarLoadToVgpr ? 
1 : B_K1, B_K1>; AThreadCopy a_thread_copy_; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp index 7ddb4e9b749..461ca513f97 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp @@ -32,9 +32,15 @@ template + bool DirectLoad = false, + bool LdsScalarLoadToVgpr = false> constexpr auto BlockGemmPipeline_Selector() { + // Supported for Direct Load and V1 + if constexpr(LdsScalarLoadToVgpr) + { + static_assert(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1); + } if constexpr(DirectLoad) { if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) @@ -58,7 +64,8 @@ constexpr auto BlockGemmPipeline_Selector() NPerXDL, MRepeat, NRepeat, - KPack>{}; + KPack, + LdsScalarLoadToVgpr>{}; } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp index 5604a31091e..ae4504d6baa 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp @@ -758,7 +758,8 @@ template + index_t KPacks, + bool LdsScalarLoadToVgpr = false> struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1 { }; @@ -781,9 +782,9 @@ template + bool LdsScalarLoadToVgpr> struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1 + KPack, + LdsScalarLoadToVgpr> : BlockwiseGemmXdlops_pipeline_base + KPack, + false /*TransposeC*/, + LdsScalarLoadToVgpr> { using Base = BlockwiseGemmXdlops_pipeline_base; + KPack, + false /*TransposeC*/, + LdsScalarLoadToVgpr>; using Base::I0; using Base::KRepeat; using Base::xdlops_gemm; diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp index ad74ee847eb..a31c9101a16 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp @@ -140,10 +140,6 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad "Direct load transfer does not support datatypes conversion. 
Source and " "destination data types must be the same."); - static_assert( - DstVectorDim == nDim - 1, - "Direct load transfer requires the destination vector dimension to be the last one."); - static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim, "When loading more than one element per thread at once, the contiguous " "dimension must be the same between source and destination."); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 175b4625bac..26cf586017a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -82,23 +82,48 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; - DispatchSplitKHack(karg.p_a_grid + a_batch_offset + split_k_offset_a, - karg.p_b_grid + b_batch_offset + split_k_offset_b, - karg.p_c_grid + e_batch_offset, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx * num_k_per_block, - gridDim.y, - split_k_offset_hack); + if constexpr(GridwiseGemm::DirectLoadEnabled) + { +#if defined(__gfx950__) + DispatchSplitKHack(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_hack); +#endif + } + else + { + DispatchSplitKHack(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_hack); + } } #else ignore = karg; @@ -236,7 +261,9 @@ template + typename ComputeTypeB = ComputeTypeA, + bool DirectLoad = false, + index_t NumGroupsToMerge = 1> struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 : public DeviceGroupedConvBwdWeight{}; template ::type = false> @@ -371,6 +398,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; + // Disable vector load = 4. It is not supported for Direct Load. Align to 2 in such case. + static constexpr index_t ABlockTransferSrcScalarPerVectorAligned = + ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8 + ? 4 / sizeof(ADataType) + : ABlockTransferSrcScalarPerVector; + static constexpr index_t BBlockTransferSrcScalarPerVectorAligned = + BBlockTransferSrcScalarPerVector * sizeof(BDataType) == 8 + ? 4 / sizeof(BDataType) + : BBlockTransferSrcScalarPerVector; + template using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3< tensor_layout::gemm::RowMajor, @@ -399,7 +436,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, + DirectLoad ? 
ABlockTransferSrcScalarPerVectorAligned : ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, @@ -407,7 +444,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, + DirectLoad ? BBlockTransferSrcScalarPerVectorAligned : BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, @@ -418,7 +455,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, - ComputeTypeB>; + ComputeTypeB, + DirectLoad>; using GridwiseGemm64 = GridwiseGemmBase; using GridwiseGemm32 = GridwiseGemmBase; @@ -653,15 +691,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 if(split_k_offset_hack_) split_k_stride_b_ /= k_batch_; - // A/B/C Batch Stride - compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; - compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0]; + // A/B/C Batch Stride (multiply by NumGroupsToMerge for group merging) + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0] * NumGroupsToMerge; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0] * NumGroupsToMerge; compute_ptr_offset_of_batch_.BatchStrideC_ = Conv_K_ * Conv_C_ * std::accumulate(begin(filter_spatial_lengths_), end(filter_spatial_lengths_), index_t{1}, - std::multiplies<>{}); + std::multiplies<>{}) * + NumGroupsToMerge; const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1); const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1); @@ -743,7 +782,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 index_t gdx, gdy, gdz; std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( - gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_); + gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumGroupsToMerge); float ave_time = 0; @@ -1367,6 +1406,30 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } #endif + // Direct Load instances are only supported on gfx950 + if constexpr(DirectLoad) + { + if(get_device_name() != "gfx950") + { + return false; + } + } + + // Check that NumGroupsToMerge divides Conv_G evenly + if constexpr(NumGroupsToMerge > 1) + { + if(arg.Conv_G_ % NumGroupsToMerge != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported! 
Conv_G_ % NumGroupsToMerge != 0: Conv_G_=" + << arg.Conv_G_ << ", NumGroupsToMerge=" << NumGroupsToMerge + << std::endl; + } + return false; + } + } + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); const index_t GemmK = @@ -1617,8 +1680,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 auto str = std::stringstream(); // clang-format off - str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3" - << "<" + str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; + + if constexpr(DirectLoad) { + str << "_DirectLoad"; + } + + str << "<" << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index d8cb5f4a8cf..15a5e088035 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -567,6 +567,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 using DsGridDesc_M_N = remove_cvref_t; + // Disable vector load = 4. It is not supported for Direct Load. Align to 2 in such case. static constexpr index_t ABlockTransferSrcScalarPerVectorAligned = ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8 ? 4 / sizeof(ADataType) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index cfbfaf3262a..5289e209fbe 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -12,6 +12,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp" namespace ck { @@ -61,7 +62,8 @@ template + typename ComputeTypeB = ComputeTypeA, + bool DirectLoad = false> struct GridwiseGemm_xdl_cshuffle_conv_v3 : public GridwiseGemm_xdl_cshuffle_base< ALayout, @@ -109,6 +111,10 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 ComputeTypeB, false> // ForceNaiveLayout { + static_assert((is_same_v && + is_same_v) || + !DirectLoad); + using Base = GridwiseGemm_xdl_cshuffle_base< ALayout, BLayout, @@ -164,6 +170,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 using Base::I2; using ThisThreadBlock = typename Base::ThisThreadBlock; + static constexpr bool DirectLoadEnabled = DirectLoad; + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); static constexpr bool is_single_rate_mfma = (((is_same::value || is_same::value) && @@ -353,7 +361,13 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 template __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(DeviceArch) { - if constexpr(is_same_v) + if constexpr(DirectLoad) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{}, I1, Number{})); + } + else if constexpr(is_same_v) { // Force use padded layout on gfx950 to reduce bank conflicts constexpr index_t ABlockLdsExtraM = 1; @@ 
-370,7 +384,13 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 template __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(DeviceArch) { - if constexpr(is_same_v) + if constexpr(DirectLoad) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{}, I1, Number{})); + } + else if constexpr(is_same_v) { constexpr index_t BBlockLdsExtraN = 1; return make_naive_tensor_descriptor( @@ -385,31 +405,36 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType) - using BlockwiseGemmPipe = remove_cvref_t< - decltype(BlockGemmPipeline_Selector< - BlkGemmPipelineVer, - BlkGemmPipeSched, - BlockSize, - ADataType, - BDataType, - ComputeTypeA, - AccDataType, - decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), - decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())), - decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( + // Disable vector load from lds to vgpr for direct load (backward weight store with continuous M + // or N dimension) + static constexpr bool LdsScalarLoadToVgpr = DirectLoad; + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmPipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ADataType, + BDataType, + ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), - decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))), - ABlockTransferSrcScalarPerVector, - BBlockTransferSrcScalarPerVector, - MPerBlock, - NPerBlock, - KPerBlock, - MPerXdl, - NPerXdl, - MXdlPerWave, - NXdlPerWave, - KPack>())>; + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + DirectLoad, + LdsScalarLoadToVgpr>())>; template __device__ static constexpr index_t GetSharedMemoryNumberOfByte(DeviceArch) @@ -539,67 +564,119 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); - // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - a_grid_desc_ak0_m_ak1, - make_multi_index(SplitKOffsetHack ? 
0 : k_id, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + auto get_a_blockwise_copy = [&]() { + if constexpr(DirectLoad) + { + return ThreadGroupTensorSliceTransfer_DirectLoad< + ThisThreadBlock, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 1, + ABlockTransferSrcScalarPerVector>( + a_grid_desc_ak0_m_ak1, + make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0)); + } + else + { + return ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + } + }; // B matrix blockwise copy - auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - b_grid_desc_bk0_n_bk1, - make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + auto get_b_blockwise_copy = [&]() { + if constexpr(DirectLoad) + { + return ThreadGroupTensorSliceTransfer_DirectLoad< + ThisThreadBlock, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 1, + BBlockTransferSrcScalarPerVector>( + b_grid_desc_bk0_n_bk1, + make_multi_index(SplitKOffsetHack ? 
0 : k_id, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); + } + else + { + return ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + BElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + } + }; + + auto a_blockwise_copy = get_a_blockwise_copy(); + auto b_blockwise_copy = get_b_blockwise_copy(); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( @@ -722,67 +799,119 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()); - // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - a_grid_desc_ak0_m_ak1, - make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + auto get_a_blockwise_copy = [&]() { + if constexpr(DirectLoad) + { + return ThreadGroupTensorSliceTransfer_DirectLoad< + ThisThreadBlock, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 1, + ABlockTransferSrcScalarPerVector>( + a_grid_desc_ak0_m_ak1, + make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0)); + } + else + { + return ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(SplitKOffsetHack ? 
0 : k_id, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + } + }; // B matrix blockwise copy - auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - b_grid_desc_bk0_n_bk1, - make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + auto get_b_blockwise_copy = [&]() { + if constexpr(DirectLoad) + { + return ThreadGroupTensorSliceTransfer_DirectLoad< + ThisThreadBlock, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 1, + BBlockTransferSrcScalarPerVector>( + b_grid_desc_bk0_n_bk1, + make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); + } + else + { + return ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + BElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(SplitKOffsetHack ? 
0 : k_id, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + } + }; + + auto a_blockwise_copy = get_a_blockwise_copy(); + auto b_blockwise_copy = get_b_blockwise_copy(); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp index 143d8573336..3a3dc156ec8 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp @@ -101,6 +101,55 @@ using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances = std::tuple // clang-format on >; +template +using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_direct_load_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| Compute| Compute| Direct| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| Data| Data| Load| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| Type| Type| | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | | + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 32, 64, 8, 16, 16, 1, 1, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 32, 64, 8, 16, 16, 1, 1, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true, 2>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, 
PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 64, 64, 8, 16, 16, 1, 2, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 64, 64, 8, 16, 16, 1, 2, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true, 2>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 32, 64, 8, 32, 32, 2, 1, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<8, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Scheduler, PipelineVersion, F16, F16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 64, 8, 32, 32, 1, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Scheduler, PipelineVersion, F16, F16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 64, 8, 32, 32, 1, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Scheduler, PipelineVersion, F16, F16, true, 2>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Scheduler, PipelineVersion, F16, F16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 2, Scheduler, PipelineVersion, F16, F16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 32, 1, 4>, 4, Scheduler, PipelineVersion, F16, F16, true> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_direct_load_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| Compute| Compute| Direct| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | 
XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| Data| Data| Load| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| Type| Type| | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | | + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 32, 64, 8, 16, 16, 1, 1, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, BF16, BF16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 64, 64, 8, 16, 16, 1, 2, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, BF16, BF16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 32, 64, 8, 32, 32, 2, 1, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<8, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Scheduler, PipelineVersion, BF16, BF16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 64, 8, 32, 32, 1, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Scheduler, PipelineVersion, BF16, BF16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Scheduler, PipelineVersion, BF16, BF16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 2, Scheduler, PipelineVersion, BF16, BF16, true>, + DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 32, 1, 4>, 4, Scheduler, PipelineVersion, BF16, BF16, true> + // clang-format on + >; + template >>& instances); +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_direct_load_instances( + std::vector>>& instances); + void 
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instances( std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_direct_load_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load.cpp new file mode 100644 index 00000000000..7035a4cf71f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_direct_load_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck
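Note on the two new template parameters (not part of the patch): the diff gates DirectLoad instances behind a gfx950 device check in IsSupportedArgument, and requires NumGroupsToMerge to divide the convolution group count evenly, launching the grid over Conv_G_ / NumGroupsToMerge batches and scaling the per-group batch strides by NumGroupsToMerge. The standalone C++ sketch below restates those two host-side constraints; the helper names are hypothetical and for illustration only, not CK API.

// Standalone sketch of the host-side constraints introduced by this patch.
// Helper names are illustrative only; they do not exist in Composable Kernel.
#include <cassert>
#include <string>

// DirectLoad instances are rejected unless the device reports gfx950.
bool direct_load_supported(const std::string& device_name)
{
    return device_name == "gfx950";
}

// NumGroupsToMerge must divide the convolution group count G evenly;
// the kernel is then launched over G / NumGroupsToMerge grid batches.
bool num_groups_to_merge_supported(int conv_g, int num_groups_to_merge)
{
    return num_groups_to_merge >= 1 && conv_g % num_groups_to_merge == 0;
}

int main()
{
    assert(direct_load_supported("gfx950"));
    assert(!direct_load_supported("gfx942"));
    assert(num_groups_to_merge_supported(32, 2));  // 32 groups merged in pairs -> 16 grid batches
    assert(!num_groups_to_merge_supported(33, 2)); // 33 % 2 != 0 -> argument rejected
    return 0;
}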