device_gemm_wmma_cshuffle_v3r1.hpp Source File
device_gemm_wmma_cshuffle_v3r1.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
GemmSpecialization::MKPadding
Definition gemm_specialization.hpp:18
GemmSpecialization::KPadding
Definition gemm_specialization.hpp:16
GemmSpecialization::MNKPadding
Definition gemm_specialization.hpp:20
GemmSpecialization::NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__global__ void kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:40
STL namespace.
Definition ck/stream_config.hpp:10
static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:624
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:273
static __host__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:837
static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:852
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:233
Definition multi_index_transform.hpp:13
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:186
Definition utility/tuple.hpp:117
Definition reduction_operator.hpp:37
Definition device_base.hpp:197
BaseInvoker()=default
Definition device_gemm_wmma_cshuffle_v3r1.hpp:149
Argument(std::array< const void *, 1 > p_a_grid_, std::array< const void *, 1 > p_b_grid_, const ::std::array< const void *, NumDTensor > p_ds_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, std::array< index_t, 1 > StrideA_, std::array< index_t, 1 > StrideB_, const ::std::array< index_t, NumDTensor > stride_ds_, index_t StrideC_, index_t KBatch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition device_gemm_wmma_cshuffle_v3r1.hpp:150
CDataType * p_c_grid
Definition device_gemm_wmma_cshuffle_v3r1.hpp:188
const ::std::array< const void *, NumDTensor > p_ds
Definition device_gemm_wmma_cshuffle_v3r1.hpp:190
CElementwiseOperation c_element_op
Definition device_gemm_wmma_cshuffle_v3r1.hpp:189
::std::array< index_t, NumDTensor > StrideDs
Definition device_gemm_wmma_cshuffle_v3r1.hpp:191
Definition device_gemm_wmma_cshuffle_v3r1.hpp:226
float RunReduce(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_wmma_cshuffle_v3r1.hpp:227
float Run(const Argument &arg_, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_wmma_cshuffle_v3r1.hpp:299
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_wmma_cshuffle_v3r1.hpp:368
Definition device_gemm_wmma_cshuffle_v3r1.hpp:91
ck::tensor_operation::element_wise::PassThrough PassThrough
Definition device_gemm_wmma_cshuffle_v3r1.hpp:94
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_wmma_cshuffle_v3r1.hpp:375
static size_t GetSharedMemoryNumberOfByte()
Definition device_gemm_wmma_cshuffle_v3r1.hpp:412
static constexpr index_t NumDTensor
Definition device_gemm_wmma_cshuffle_v3r1.hpp:92
GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, GemmAccDataType, ReduceDataType, Tuple<>, ReduceDataType, AElementwiseOperation, BElementwiseOperation, PassThrough, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, false, false > GridwiseGemm
Definition device_gemm_wmma_cshuffle_v3r1.hpp:96
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, const ::std::array< const void *, NumDTensor > p_ds, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, const ::std::array< index_t, NumDTensor > stride_ds, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_wmma_cshuffle_v3r1.hpp:417
static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition device_gemm_wmma_cshuffle_v3r1.hpp:405
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_wmma_cshuffle_v3r1.hpp:381
static constexpr auto DsVectorLengthSequence
Definition device_gemm_wmma_cshuffle_v3r1.hpp:197
static auto MakeInvoker()
Definition device_gemm_wmma_cshuffle_v3r1.hpp:450
::std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, ::std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, ::std::array< index_t, NumDTensor > DsStrides, index_t StrideC, index_t KSplit, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_wmma_cshuffle_v3r1.hpp:459
ck::reduce::Add ReduceAdd
Definition device_gemm_wmma_cshuffle_v3r1.hpp:194
CElementwiseOperation OutElementwiseOperation
Definition device_gemm_wmma_cshuffle_v3r1.hpp:195
DeviceReduceThreadWiseMultiD< ReduceDataType, DsDataType, GemmAccDataType, CDataType, 3, 1, ReduceAdd, PassThrough, OutElementwiseOperation, 256, CShuffleBlockTransferScalarPerVector_NPerBlock, 1, 0, CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, decltype(DsVectorLengthSequence)> DeviceReduceInstance
Definition device_gemm_wmma_cshuffle_v3r1.hpp:207
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_wmma_cshuffle_v3r1.hpp:400
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_gemm_wmma_cshuffle_v3r1.hpp:546
::std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_wmma_cshuffle_v3r1.hpp:453
static constexpr index_t GetBlockSize()
Definition device_gemm_wmma_cshuffle_v3r1.hpp:410
::std::string GetTypeString() const override
Definition device_gemm_wmma_cshuffle_v3r1.hpp:492
Definition device_gemm_v2.hpp:57
Definition device_reduce_threadwise_multi_d.hpp:47
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, Rank > inLengths, const std::array< index_t, Rank > inStrides, const std::array< std::array< index_t, NumDstDim >, NumDTensor > DsLengths, const std::array< std::array< index_t, NumDstDim >, NumDTensor > DsStrides, const std::array< index_t, NumDstDim > outLengths, const std::array< index_t, NumDstDim > outStrides, const std::array< int, NumReduceDim > reduceDims, const void *in_dev, const std::array< const void *, NumDTensor > ds_dev, void *out_dev, const InElementwiseOperation in_elementwise_op, const OutElementwiseOperation out_elementwise_op) override
Definition device_reduce_threadwise_multi_d.hpp:363
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340