device_grouped_conv_bwd_weight_wmma_cshuffle.hpp Source File
device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
ConvolutionBackwardWeightSpecialization
Definition convolution_backward_weight_specialization.hpp:13
@ Filter1x1Stride1Pad0
Definition convolution_backward_weight_specialization.hpp:15
constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
Definition device_grouped_conv_utils.hpp:88
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
Definition device_grouped_conv_utils.hpp:80
std::string getConvBackwardWeightSpecializationString(const ConvolutionBackwardWeightSpecialization &s)
Definition convolution_backward_weight_specialization.hpp:21
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck > = integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck > = integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__global__ void kernel_grouped_conv_multiple_d_wmma_cshuffle(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const index_t batch_count, const AGridDesc_AK0_M_AK1 a_grid_desc, const BGridDesc_BK0_N_BK1 b_grid_desc, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_, const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:40
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation > = integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:326
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:809
__host__ static __device__ constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N_ &e_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:819
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &e_grid_desc_m_n, index_t, index_t)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:850
__host__ static __device__ constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N_ &ds_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:840
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_K0_M_K1 &a_grid_desc, const BGridDesc_K0_N_K1 &b_grid_desc, const CGridDesc_M_N &e_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:608
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:186
Definition utility/integral_constant.hpp:20
Definition device_base.hpp:197
BaseArgument()=default
BaseInvoker()=default
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:454
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:557
OutElementwiseOperation a_element_op_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:564
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:562
const index_t k_batch_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:579
InElementwiseOperation b_element_op_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:565
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:555
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:554
std::array< index_t, NDimSpatial > output_spatial_lengths_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:575
const std::array< index_t, NDimSpatial > & conv_filter_strides_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:576
CGridDesc_M_N c_grid_desc_m_n_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:556
std::array< index_t, NDimSpatial > filter_spatial_lengths_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:574
std::array< index_t, NDimSpatial > input_spatial_lengths_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:573
const index_t Conv_K_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:571
CDataType * p_c_grid_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:553
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle::Argument::block_2_ctile_map_
Block2CTileMap block_2_ctile_map_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:559
const std::array< index_t, NDimSpatial > & input_right_pads_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:578
const BDataType * p_b_grid_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:552
const ADataType * p_a_grid_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:551
const index_t Conv_C_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:572
Argument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, index_t split_k)
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:455
const index_t Conv_G_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:569
const index_t Conv_N_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:570
WeiElementwiseOperation c_element_op_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:566
const std::array< index_t, NDimSpatial > & input_left_pads_
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:577
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:584
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:679
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:603
void Print(const Argument &arg)
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:587
DeviceOp::Argument Argument
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:585
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:79
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:383
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:382
OutElementwiseOperation AElementwiseOperation
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:86
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:692
static constexpr auto make_wei_grid_desc(const index_t K, const index_t Z, const index_t Y, const index_t X, const index_t C, const std::array< index_t, NDimSpatial+3 > &weights_strides)
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:148
static auto MakeArgument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, const index_t split_k)
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:760
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_grid, void *p_wei_grid, const void *p_out_grid, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, const index_t split_k) override
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:800
static constexpr auto make_in_grid_desc(const index_t N, const index_t Di, const index_t Hi, const index_t Wi, const index_t C, const std::array< index_t, NDimSpatial+3 > &input_strides)
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:120
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:686
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( CGridDesc_M_N{})) CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:446
GridwiseGemmMultipleD_Wmma< ADataType, BDataType, AccDataType, CShuffleDataType, Tuple<>, CDataType, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, Tuple<>, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, MPerBlock, NPerBlock, KPerBlock, MPerWMMA, NPerWMMA, K1, MRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, true, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, true, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, NumGemmKPrefetchStage, LoopSched, PipelineVer > GridwiseGemm
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:388
DeviceGroupedConvBwdWeight_Wmma_CShuffle DeviceOp
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:80
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap( CGridDesc_M_N{}, I1, I1)) Block2CTileMap
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:450
static constexpr index_t KPerBlock
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:101
static constexpr auto I3
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:96
InDataType BDataType
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:83
InElementwiseOperation BElementwiseOperation
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:87
static constexpr auto I4
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:97
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:754
static constexpr auto I0
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:93
AccDataType CShuffleDataType
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:386
InDataType ABDataType
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:91
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:837
static constexpr auto I5
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:98
static constexpr auto I2
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:95
static auto MakeInvoker()
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:797
static constexpr auto make_out_grid_desc(const index_t N, const index_t Do, const index_t Ho, const index_t Wo, const index_t K, const std::array< index_t, NDimSpatial+3 > &output_strides)
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:105
WeiElementwiseOperation CElementwiseOperation
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:88
static constexpr auto GemmK1Number
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:100
static constexpr auto I1
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:94
decltype(GetABCGridDesc< NDimSpatial >()) ABCGridDescs
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:380
static auto GetABCGridDesc()
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:359
std::string GetTypeString() const override
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:842
WeiDataType CDataType
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:84
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t N, const index_t K, const index_t C, const std::array< index_t, NDimSpatial > &input_spatial_lengths, const std::array< index_t, NDimSpatial > &filter_spatial_lengths, const std::array< index_t, NDimSpatial > &output_spatial_lengths, const std::array< index_t, NDimSpatial+3 > &input_strides, const std::array< index_t, NDimSpatial+3 > &weights_strides, const std::array< index_t, NDimSpatial+3 > &output_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads)
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:162
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(Tuple<>{})) DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:443
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:384
OutDataType ADataType
Definition device_grouped_conv_bwd_weight_wmma_cshuffle.hpp:82
Definition device_grouped_conv_bwd_weight.hpp:29