Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ template <typename T>
concept BwdXdlV3AlgorithmBase =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseBwdXdlGemm<T> && SpecifiesBwdWeightConvSpecialization<T> &&
SpecifiesBlockGemm<T>;
SpecifiesBlockGemm<T> && SpecifiesNumGroupsToMerge<T>;

template <typename T>
concept BwdWmmaAlgorithmBase =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ template <ck::index_t NDimSpatial,
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
typename ComputeTypeB>
typename ComputeTypeB,
bool DirectLoad,
index_t NumGroupsToMerge>
struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3;

} // namespace ck::tensor_operation::device
Expand Down Expand Up @@ -104,7 +106,9 @@ template <ck::index_t NDimSpatial,
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA_,
typename ComputeTypeB_>
typename ComputeTypeB_,
bool DirectLoad,
index_t NumGroupsToMerge>
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<
NDimSpatial,
InLayout_,
Expand Down Expand Up @@ -148,7 +152,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA_,
ComputeTypeB_>>
ComputeTypeB_,
DirectLoad,
NumGroupsToMerge>>
{
static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";

Expand Down Expand Up @@ -213,6 +219,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
using ComputeTypeA = ComputeTypeA_;
using ComputeTypeB = ComputeTypeB_;

static constexpr bool kDirectLoad = DirectLoad;
static constexpr index_t kNumGroupsToMerge = NumGroupsToMerge;

// Static member function to generate instance string
static std::string instance_string()
{
Expand Down Expand Up @@ -274,6 +283,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
oss << "," << detail::pipeline_version_name(kBlkGemmPipelineVer); // 41.
oss << "," << detail::type_name<ComputeTypeA>(); // 42.
oss << "," << detail::type_name<ComputeTypeB>(); // 43.
oss << "," << kDirectLoad; // 44.
oss << "," << kNumGroupsToMerge; // 45.
oss << ">";

return oss.str();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ constexpr auto ALGORITHM =
.with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave)
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
.with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0)
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave);
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave)
.with_num_conv_groups_to_merge(1);

using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
Expand Down
3 changes: 2 additions & 1 deletion experimental/builder/test/impl/conv_algorithm_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,8 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 =
BwdXdlGemm_,
Transfer_<>,
ConvSpecializationBwdWeight_,
BlockGemm_>;
BlockGemm_,
GemmBatchOptions_>;

using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl =
ConvAlgorithmTemplate<ThreadBlock_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ std::string expected_str =
",v1" // BlkGemmPipelineVer
",fp16" // ComputeTypeA
",fp16" // ComputeTypeB
",0" // DirectLoad
",1" // NumGroupsToMerge
">";

// Test describe() through base class pointer for XDL V3 variant
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool TransposeC = false>
bool TransposeC = false,
bool LdsScalarLoadToVgpr = false>
struct BlockwiseGemmXdlops_pipeline_base
{
static constexpr auto I0 = Number<0>{};
Expand Down Expand Up @@ -385,7 +386,7 @@ struct BlockwiseGemmXdlops_pipeline_base
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
LdsScalarLoadToVgpr ? 1 : A_K1,
A_K1>;

using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
Expand All @@ -395,7 +396,7 @@ struct BlockwiseGemmXdlops_pipeline_base
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
LdsScalarLoadToVgpr ? 1 : B_K1,
B_K1>;

AThreadCopy a_thread_copy_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,15 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool DirectLoad = false>
bool DirectLoad = false,
bool LdsScalarLoadToVgpr = false>
constexpr auto BlockGemmPipeline_Selector()
{
// Supported for Direct Load and V1
if constexpr(LdsScalarLoadToVgpr)
{
static_assert(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1);
}
if constexpr(DirectLoad)
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
Expand All @@ -58,7 +64,8 @@ constexpr auto BlockGemmPipeline_Selector()
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
KPack,
LdsScalarLoadToVgpr>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,8 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPacks>
index_t KPacks,
bool LdsScalarLoadToVgpr = false>
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1
{
};
Expand All @@ -781,9 +782,9 @@ template <index_t BlockSize,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack
index_t KPack,
// ,bool TransposeC //disable transposec right now...
>
bool LdsScalarLoadToVgpr>
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
Expand All @@ -803,7 +804,8 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
NPerXDL,
MRepeat,
NRepeat,
KPack>
KPack,
LdsScalarLoadToVgpr>
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
Expand All @@ -822,7 +824,9 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
NPerXDL,
MRepeat,
NRepeat,
KPack>
KPack,
false /*TransposeC*/,
LdsScalarLoadToVgpr>

{
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
Expand All @@ -843,7 +847,9 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
NPerXDL,
MRepeat,
NRepeat,
KPack>;
KPack,
false /*TransposeC*/,
LdsScalarLoadToVgpr>;
using Base::I0;
using Base::KRepeat;
using Base::xdlops_gemm;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,6 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
"Direct load transfer does not support datatypes conversion. Source and "
"destination data types must be the same.");

static_assert(
DstVectorDim == nDim - 1,
"Direct load transfer requires the destination vector dimension to be the last one.");

static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim,
"When loading more than one element per thread at once, the contiguous "
"dimension must be the same between source and destination.");
Expand Down
Loading