29template <
typename ALayout,
37 typename GemmAccDataType,
38 typename CShuffleDataType,
39 typename AElementwiseOperation,
40 typename BElementwiseOperation,
41 typename CElementwiseOperation,
53 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
54 typename ABlockTransferThreadClusterArrangeOrder,
55 typename ABlockTransferSrcAccessOrder,
56 index_t ABlockTransferSrcVectorDim,
57 index_t ABlockTransferSrcScalarPerVector,
58 index_t ABlockTransferDstScalarPerVector_AK1,
60 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
61 typename BBlockTransferThreadClusterArrangeOrder,
62 typename BBlockTransferSrcAccessOrder,
63 index_t BBlockTransferSrcVectorDim,
64 index_t BBlockTransferSrcScalarPerVector,
65 index_t BBlockTransferDstScalarPerVector_BK1,
67 index_t CShuffleMXdlPerWavePerShuffle,
68 index_t CShuffleNXdlPerWavePerShuffle,
69 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
70 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
73 typename ReduceDataType = CDataType,
74 typename ComputeTypeA = CDataType,
75 typename ComputeTypeB = ComputeTypeA>
84 AElementwiseOperation,
85 BElementwiseOperation,
86 CElementwiseOperation>
97 template <index_t NXdlPerWave_>
107 AElementwiseOperation,
108 BElementwiseOperation,
121 ABlockTransferThreadClusterLengths_AK0_M_AK1,
122 ABlockTransferThreadClusterArrangeOrder,
123 ABlockTransferSrcAccessOrder,
124 ABlockTransferSrcVectorDim,
125 ABlockTransferSrcScalarPerVector,
126 ABlockTransferDstScalarPerVector_AK1,
129 BBlockTransferThreadClusterLengths_BK0_N_BK1,
130 BBlockTransferThreadClusterArrangeOrder,
131 BBlockTransferSrcAccessOrder,
132 BBlockTransferSrcVectorDim,
133 BBlockTransferSrcScalarPerVector,
134 BBlockTransferDstScalarPerVector_BK1,
137 CShuffleMXdlPerWavePerShuffle,
138 CShuffleNXdlPerWavePerShuffle,
139 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
140 CShuffleBlockTransferScalarPerVector_NPerBlock,
151 const BDataType* p_b_grid_,
152 const std::array<const void*, NumDTensor> p_ds_,
153 CDataType* p_c_grid_,
159 std::array<ck::index_t, NumDTensor> StrideDs_,
164 reinterpret_cast<ReduceDataType*>(p_c_grid_),
178 const std::array<const void*, NumDTensor>
p_ds;
188 if constexpr(std::is_same<CLayout, DLayout>::value)
206 CShuffleBlockTransferScalarPerVector_NPerBlock,
209 CShuffleBlockTransferScalarPerVector_NPerBlock,
210 CShuffleBlockTransferScalarPerVector_NPerBlock,
218 static constexpr index_t NumInDim = 3;
219 static constexpr index_t NumOutDim = 2;
221 std::array<ck::index_t, NumInDim> in_lengths = {arg.KBatch, arg.M, arg.N};
222 std::array<ck::index_t, NumOutDim> out_lengths = {arg.M, arg.N};
224 std::array<ck::index_t, NumInDim> in_strides;
225 std::array<ck::index_t, NumOutDim> out_strides;
226 if constexpr(std::is_same<CLayout, ck::tensor_layout::gemm::RowMajor>::value)
228 in_strides = {arg.M * arg.N, arg.N, 1};
229 out_strides = {arg.N, 1};
233 in_strides = {arg.M * arg.N, 1, arg.M};
234 out_strides = {1, arg.M};
237 std::array<int, 1> reduce_dims{0};
239 std::array<std::array<index_t, NumOutDim>,
NumDTensor> DsLengths;
240 std::array<std::array<index_t, NumOutDim>,
NumDTensor> DsStrides;
242 static_for<0, NumDTensor, 1>{}([&](
auto i) {
243 DsLengths[i] = out_lengths;
246 if constexpr(std::is_same<DLayout, ck::tensor_layout::gemm::RowMajor>::value)
248 DsStrides[i] = {arg.
StrideDs[i], 1};
252 DsStrides[i] = {1, arg.
StrideDs[i]};
271 auto invoker_ptr = reduce.MakeInvokerPointer();
275 if(reduce.IsSupportedArgument(argument_ptr.get()))
277 ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config);
281 throw std::runtime_error(
282 "The runtime parameters seems not supported by the device instance, exiting!");
288 template <
typename Gr
idwiseGemm>
291 auto arg = *
reinterpret_cast<const typename GridwiseGemm::Argument*
>(&arg_);
293 if(!(!(arg.IsReduceAdd() ||
NumDTensor > 0) &&
294 std::is_same<CDataType, ReduceDataType>::value))
296 if(arg.p_workspace_ ==
nullptr)
298 throw std::runtime_error(
"using reduce , but empty workspace!");
301 arg.p_c_grid =
reinterpret_cast<ReduceDataType*
>(arg.p_workspace_);
304 if(stream_config.log_level_ > 0)
309 if(!GridwiseGemm::CheckValidity(arg))
311 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
315 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
319 index_t k_grain = arg.KBatch * KPerBlock;
320 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
322 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
324 const auto Run = [&](
const auto& kernel) {
325 if(stream_config.flush_cache)
329 stream_config.rotating_count,
330 arg.M * arg.K *
sizeof(ADataType),
331 arg.K * arg.N *
sizeof(BDataType));
332 rotating_mem.Print();
334 auto run_flush_cache = [&]() {
353 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
357 constexpr index_t minimum_occupancy =
360 if(has_main_k_block_loop)
376 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::One)
386 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
397 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
399 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
411 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
425 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
439 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
453 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
455 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
467 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
485 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
508 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
544 if(!(!(arg.IsReduceAdd() ||
NumDTensor > 0) &&
545 std::is_same<CDataType, ReduceDataType>::value))
548 ave_time +=
RunReduce(arg_, stream_config);
559 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
595 reinterpret_cast<const typename GridwiseGemm32::Argument&
>(arg));
608 const BDataType* p_b,
609 const std::array<const void*, NumDTensor> p_ds,
616 std::array<ck::index_t, NumDTensor> StrideDs,
619 AElementwiseOperation,
620 BElementwiseOperation,
621 CElementwiseOperation)
623 return Argument{p_a, p_b, p_ds, p_c, M, N, K, StrideA, StrideB, StrideDs, StrideC, KBatch};
631 std::array<const void*, NumDTensor> p_ds,
638 std::array<ck::index_t, NumDTensor> StrideDs,
641 AElementwiseOperation,
642 BElementwiseOperation,
643 CElementwiseOperation)
override
645 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
646 static_cast<const BDataType*
>(p_b),
648 static_cast<CDataType*
>(p_c),
662 return std::make_unique<Invoker>(
Invoker{});
668 auto str = std::stringstream();
670 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
674 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
682 str <<
"DeviceGemmXdlUniversalReduce"
685 << std::string(ALayout::name)[0]
686 << std::string(BLayout::name)[0]
687 << std::string(CLayout::name)[0]
692 << MPerBlock<<
"x"<<NPerBlock<<
"x"<<KPerBlock <<
", "
694 << MPerXDL<<
"x"<<NPerXDL <<
", "
696 << MXdlPerWave<<
"x" << NXdlPerWave<<
", "
698 << ABlockTransferSrcScalarPerVector<<
"x"<<BBlockTransferSrcScalarPerVector<<
", "
699 <<
"BlkGemmPipelineScheduler: "
700 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
701 <<
"BlkGemmPipelineVersion: "
702 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] <<
", "
703 <<
"BlkGemmPipelinePrefetchStages: "
704 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
712 auto arg = *
dynamic_cast<const Argument*
>(p_arg);
714 if(!(!(arg.IsReduceAdd() ||
NumDTensor > 0) &&
715 std::is_same<CDataType, ReduceDataType>::value))
717 std::cout <<
"using workspace" << std::endl;
718 return arg.M * arg.N * arg.KBatch *
sizeof(ReduceDataType);
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
integral_constant< index_t, N > Number
Definition number.hpp:12
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
Definition ck/stream_config.hpp:10
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:247
ck::GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, ReduceDataType, AElementwiseOperation, BElementwiseOperation, PassThrough, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::CheckValidity static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1202
Definition reduction_operator.hpp:37
Definition device_base.hpp:197
Definition device_gemm_xdl_cshuffle_v3r1.hpp:149
const std::array< const void *, NumDTensor > p_ds
Definition device_gemm_xdl_cshuffle_v3r1.hpp:178
std::array< ck::index_t, NumDTensor > StrideDs
Definition device_gemm_xdl_cshuffle_v3r1.hpp:179
Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, const std::array< const void *, NumDTensor > p_ds_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< ck::index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t k_batch_)
Definition device_gemm_xdl_cshuffle_v3r1.hpp:150
Definition device_gemm_xdl_cshuffle_v3r1.hpp:215
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_cshuffle_v3r1.hpp:556
float RunReduce(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_cshuffle_v3r1.hpp:216
float RunImp(const Argument &arg_, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_cshuffle_v3r1.hpp:289
Definition device_gemm_xdl_cshuffle_v3r1.hpp:87
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition device_gemm_xdl_cshuffle_v3r1.hpp:629
static constexpr index_t NumDTensor
Definition device_gemm_xdl_cshuffle_v3r1.hpp:92
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_gemm_xdl_cshuffle_v3r1.hpp:710
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_cshuffle_v3r1.hpp:660
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_cshuffle_v3r1.hpp:90
DeviceReduceThreadWiseMultiD< ReduceDataType, DsDataType, GemmAccDataType, CDataType, 3, 1, ReduceAdd, PassThrough, OutElementwiseOperation, 256, CShuffleBlockTransferScalarPerVector_NPerBlock, 1, 0, CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, decltype(DsVectorLengthSequence)> DeviceReduceInstance
Definition device_gemm_xdl_cshuffle_v3r1.hpp:195
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_cshuffle_v3r1.hpp:563
static constexpr auto DsVectorLengthSequence
Definition device_gemm_xdl_cshuffle_v3r1.hpp:185
ck::reduce::Add ReduceAdd
Definition device_gemm_xdl_cshuffle_v3r1.hpp:182
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_cshuffle_v3r1.hpp:89
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_xdl_cshuffle_v3r1.hpp:569
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_cshuffle_v3r1.hpp:602
GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, ReduceDataType, AElementwiseOperation, BElementwiseOperation, PassThrough, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > GridwiseGemmBase
Definition device_gemm_xdl_cshuffle_v3r1.hpp:98
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, const std::array< const void *, NumDTensor > p_ds, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition device_gemm_xdl_cshuffle_v3r1.hpp:607
static auto MakeInvoker()
Definition device_gemm_xdl_cshuffle_v3r1.hpp:626
CElementwiseOperation OutElementwiseOperation
Definition device_gemm_xdl_cshuffle_v3r1.hpp:183
ck::tensor_operation::element_wise::PassThrough PassThrough
Definition device_gemm_xdl_cshuffle_v3r1.hpp:94
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_cshuffle_v3r1.hpp:146
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_cshuffle_v3r1.hpp:145
std::string GetTypeString() const override
Definition device_gemm_xdl_cshuffle_v3r1.hpp:666
Definition device_gemm_v2.hpp:57
Definition device_reduce_threadwise_multi_d.hpp:47
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, Rank > inLengths, const std::array< index_t, Rank > inStrides, const std::array< std::array< index_t, NumDstDim >, NumDTensor > DsLengths, const std::array< std::array< index_t, NumDstDim >, NumDTensor > DsStrides, const std::array< index_t, NumDstDim > outLengths, const std::array< index_t, NumDstDim > outStrides, const std::array< int, NumReduceDim > reduceDims, const void *in_dev, const std::array< const void *, NumDTensor > ds_dev, void *out_dev, const InElementwiseOperation in_elementwise_op, const OutElementwiseOperation out_elementwise_op) override
Definition device_reduce_threadwise_multi_d.hpp:363
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition flush_cache.hpp:299