device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp Source File
device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
Go to the documentation of this file.
141 : public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ Default
Definition convolution_forward_specialization.hpp:16
__global__ void kernel_gemm_xdlops_v2r3_for_conv3d(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const index_t num_batches, const index_t a_batch_stride, const index_t b_batch_stride, const index_t c_batch_stride, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:43
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad(const TensorDescriptor< In... > &in_grid_desc_n_di_hi_wi_c, const TensorDescriptor< Wei... > &wei_k_z_y_x_c_grid_desc, const TensorDescriptor< Out... > &out_n_do_ho_wo_k_grid_desc, const ConvStrides &conv_strides, const ConvDilations &conv_dilations, const InLeftPads &in_left_pads, const InRightPads &in_right_pads, Number< GemmK1Value >)
Definition transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp:28
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdlops_v2r3.hpp:142
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
BaseArgument()=default
BaseInvoker()=default
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:322
index_t a_batch_stride_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:392
InElementwiseOperation in_element_op_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:402
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:398
OutDataType * p_c_grid_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:390
index_t M01_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:400
Argument(const InDataType *p_in, const WeiDataType *p_wei, OutDataType *p_out, const index_t N, const index_t K, const index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, index_t M01, index_t N01, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:323
CGridDesc_M_N c_grid_desc_m_n_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:397
index_t N01_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:401
index_t b_batch_stride_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:393
WeiElementwiseOperation wei_element_op_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:403
OutElementwiseOperation out_element_op_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:404
const InDataType * p_a_grid_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:388
const WeiDataType * p_b_grid_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:389
Block2CTileMap block_2_ctile_map_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:399
index_t c_batch_stride_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:394
index_t num_subbatches_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:391
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:396
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:395
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:409
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:413
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:525
DeviceOp::Argument Argument
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:410
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:143
typename GridwiseGemm::DefaultBlock2CTileMap Block2CTileMap
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:318
DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K DeviceOp
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:144
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:269
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:270
static auto MakeArgument(const InDataType *p_in, const WeiDataType *p_wei, OutDataType *p_out, const index_t N, const index_t K, const index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:556
static bool IsSupportedArgument(const Argument &arg)
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:538
static constexpr auto NXdlPerWave32
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:148
static constexpr auto I3
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:159
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:314
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t N, const index_t K, const index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads)
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:213
InDataType ABDataType
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:154
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:271
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, InDataType, AccDataType, OutDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, Sequence< 1, 0, 2 >, Sequence< 1, 0, 2 >, 2, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, Sequence< 1, 0, 2 >, Sequence< 1, 0, 2 >, 2, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, Sequence< 2, 3, 0, 1, 7, 5, 4, 6 >, 7, CThreadTransferDstScalarPerVector > GridwiseGemmBase
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:274
InDataType ADataType
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:150
static constexpr auto I2
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:158
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:636
std::string GetTypeString() const override
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:641
static constexpr auto I1
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:157
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:313
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in, const void *p_wei, void *p_out, const index_t N, const index_t K, const index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) override
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:597
remove_cvref_t< decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( 1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}))> ABCGridDescs
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:266
static auto MakeInvoker()
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:593
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})) CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:316
WeiDataType BDataType
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:151
OutDataType CDataType
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:152
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:551
static constexpr auto I0
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:156
static index_t GetMaxAllowableSubBatchSize(const index_t N, const index_t K, const index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths)
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:166
static constexpr bool IsValidCompilationParameter()
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:532
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp:147
Definition device_conv_fwd.hpp:25