device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp Source File

device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp Source File

Composable Kernel: device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp Source File
device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <functional>
7#include <iostream>
8#include <iterator>
9#include <numeric>
10#include <sstream>
11
14#include "ck/utility/env.hpp"
31#ifdef CK_EXPERIMENTAL_BUILDER
32#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
33#endif
34
35namespace ck {
36namespace tensor_operation {
37namespace device {
38
39namespace {
40
41/*
42 * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM for grouped
 * convolution forward. Each workgroup offsets the A/B/Ds/E base pointers by its
 * convolution-group index (blockIdx.y) and its N-split index (blockIdx.z), then runs the
 * gridwise GEMM for the E tile selected by \p Block2ETileMap.
43 *
44 * \tparam ComputePtrOffsetOfG, ComputePtrOffsetOfN Classes that compute the base pointer
45 * offsets of the A, B, Ds and E matrices given the group index and the N-split index,
 * respectively. For example, ComputePtrOffsetOfStridedBatch computes the offsets of evenly
46 * strided batches, but we can easily extend to other layouts. The returned offset can be
47 * either \p index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to
48 * the 2GB limitation.
49 *
50 * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
51 * returns the 2D index of the tile that it computes. \see
52 * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
53 *
 * \tparam HasMainKBlockLoop Forwarded as a template argument to GridwiseGemm::Run
 * (presumably selects whether the main K-block loop is generated).
 *
 * \tparam isMultiA, isMultiB When either is true, the A/B pointers and the A/B grid
 * descriptors are tuples (one entry per tensor) and every entry is offset individually.
 *
 * \tparam CTranspose Only used on the single-A/B path. When true, the group/N-split offsets of
 * the convolution's A tensor are applied to \p p_bs_grid and the group offset of the
 * convolution's B tensor is applied to \p p_as_grid, i.e. the GEMM operands are swapped.
 *
54 * \note Using \p ComputePtrOffsetOfG / \p ComputePtrOffsetOfN gives us the flexibility that
55 * 2 workgroups can compute 2 tiles from different matrices. Keep in mind that these 2 matrices
 * can share the same grid
56 * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
57 * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
58 * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
59 * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
60 *
61 * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
62 * Together with \p ComputePtrOffsetOfG / \p ComputePtrOffsetOfN, we can reuse GridwiseGemm (and
63 * GridwiseGemm fusion) to realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
64 *
65 */
66template <typename GridwiseGemm,
67 typename AsPointer, // tuples if multi AB, pointers if not
68 typename BsPointer,
69 typename DsPointer,
70 typename EDataType,
71 typename AElementwiseOperation,
72 typename BElementwiseOperation,
73 typename CDEElementwiseOperation,
74 typename AGridDesc_AK0_M_AK1,
75 typename BGridDesc_BK0_N_BK1,
76 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
77 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
78 typename Block2ETileMap,
79 typename ComputePtrOffsetOfG,
80 typename ComputePtrOffsetOfN,
81 bool HasMainKBlockLoop,
82 bool isMultiA,
83 bool isMultiB,
84 bool CTranspose>
85__global__ void
86#if CK_USE_LAUNCH_BOUNDS
88#endif
89 kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
90 AsPointer p_as_grid,
91 BsPointer p_bs_grid,
92 DsPointer p_ds_grid,
93 EDataType* __restrict__ p_e_grid,
94 AElementwiseOperation a_element_op,
95 BElementwiseOperation b_element_op,
96 CDEElementwiseOperation cde_element_op,
97 const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
98 const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
99 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
100 ds_grid_desc_mblock_mperblock_nblock_nperblock,
101 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
102 e_grid_desc_mblock_mperblock_nblock_nperblock_,
103 const Block2ETileMap block_2_ctile_map,
104 const ComputePtrOffsetOfG compute_ptr_offset_of_groups,
105 const ComputePtrOffsetOfN compute_ptr_offset_of_n)
106{
// The real body is only compiled for gfx9 / gfx11 / gfx12 targets; on any other target the
// kernel is empty (see the #else branch below).
107#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
// Emit the body only when GridwiseGemm accepts its compile-time configuration.
108 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
109 {
110 // offset base pointer for each work-group
// blockIdx.y indexes the (merged) convolution group and blockIdx.z the N-split chunk.
// readfirstlane broadcasts the first active lane's value, so the indices (and the offsets
// derived from them) are wave-uniform.
111 const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
112 const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
113
114 const long_index_t e_group_offset =
115 amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
116 const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
117 const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
118
119 const long_index_t e_n_offset =
120 amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
121
// Shared-memory (LDS) workspace handed to the gridwise GEMM.
122 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
123
// D pointers shifted by both the group offset and the N-split offset.
124 DsPointer p_ds_grid_grp;
125
126 static constexpr index_t NumDTensor =
127 DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
128
129 static_for<0, NumDTensor, 1>{}(
130 [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; });
131
// Multi-A/B path: AsPointer / BsPointer are tuples; each tensor is offset individually.
132 if constexpr(isMultiA || isMultiB)
133 {
134 AsPointer p_as_grid_grp;
135 BsPointer p_bs_grid_grp;
136
137 const auto& as_group_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx);
138
139 // compute_ptr_offset_of_n does not carry BatchStrideB, so when isMultiA is false
140 // but isMultiB is true its BatchStrideA_ is a scalar (not a tuple) and the single
141 // A pointer receives that scalar N-split offset in the else branch below.
142 if constexpr(isMultiA)
143 {
144 const auto& as_n_offset = compute_ptr_offset_of_n.GetAsPtrOffset(n_idx);
145
// In the multi-A case the A grid descriptor type is a tuple; its Size() is the A count.
146 static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
147 static_for<0, NumATensor, 1>{}([&](auto i) {
148 p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + as_n_offset[i];
149 });
150 }
151 else
152 {
153 const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
154 static_for<0, 1, 1>{}([&](auto i) {
155 p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + a_n_offset;
156 });
157 }
158
// B tensors only receive the group offset; no N-split offset is applied to B.
159 const auto& bs_group_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx);
160
161 static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
162 static_for<0, NumBTensor, 1>{}(
163 [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_group_offset[i]; });
164
// E is offset by both the group and the N-split before running the tile.
165 GridwiseGemm::template Run<HasMainKBlockLoop>(
166 p_as_grid_grp,
167 p_bs_grid_grp,
168 p_ds_grid_grp,
169 p_e_grid + e_group_offset + e_n_offset,
170 p_shared,
171 a_element_op,
172 b_element_op,
173 cde_element_op,
174 a_grid_desc_k0_m_k1,
175 b_grid_desc_k0_n_k1,
176 ds_grid_desc_mblock_mperblock_nblock_nperblock,
177 e_grid_desc_mblock_mperblock_nblock_nperblock_,
178 block_2_ctile_map);
179 }
180 else
181 {
// Single-A/B path. With CTranspose the operands are swapped: the convolution A-tensor
// offsets (group and N-split) are applied to p_bs_grid and the convolution B-tensor group
// offset is applied to p_as_grid. Without CTranspose only p_as_grid gets an N-split offset.
// NOTE(review): presumably the invoker passes the A/B pointers swapped when CTranspose is
// set -- confirm against the Invoker.
182 const long_index_t b_group_offset =
183 CTranspose
184 ? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx))
185 : amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
186 const long_index_t a_group_offset =
187 CTranspose
188 ? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx))
189 : amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
190 const long_index_t b_n_offset =
191 CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx))
192 : 0;
193 const long_index_t a_n_offset =
194 CTranspose ? 0
195 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
196
// E is written with InMemoryDataOperationEnum::Set.
197 GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
198 p_as_grid + a_group_offset + a_n_offset,
199 p_bs_grid + b_group_offset + b_n_offset,
200 p_ds_grid_grp,
201 p_e_grid + e_group_offset + e_n_offset,
202 p_shared,
203 a_element_op,
204 b_element_op,
205 cde_element_op,
206 a_grid_desc_k0_m_k1,
207 b_grid_desc_k0_n_k1,
208 ds_grid_desc_mblock_mperblock_nblock_nperblock,
209 e_grid_desc_mblock_mperblock_nblock_nperblock_,
210 block_2_ctile_map);
211 }
212 }
213#else
// Unsupported target: reference every parameter to silence unused-parameter warnings.
214 ignore = p_as_grid;
215 ignore = p_bs_grid;
216 ignore = p_ds_grid;
217 ignore = p_e_grid;
218 ignore = a_grid_desc_k0_m_k1;
219 ignore = b_grid_desc_k0_n_k1;
220 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
221 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
222 ignore = a_element_op;
223 ignore = b_element_op;
224 ignore = cde_element_op;
225 ignore = compute_ptr_offset_of_groups;
226 ignore = compute_ptr_offset_of_n;
227 ignore = block_2_ctile_map;
228#endif
229}
230
231} // namespace
#ifndef CK_CODE_GEN_RTC
// Detection alias: names the return type of T::IsTuple(), and is ill-formed when T has no
// IsTuple() member. Combined with is_detected to tell Tuple data types apart from scalar ones.
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
#else
// Same alias for RTC code generation, using ck::declval in place of std::declval.
template <typename T>
using is_tuple = decltype(ck::declval<T&>().IsTuple());
#endif
239
240//
241// @brief Device Convolution operation.
242//
243// Supports:
244// @li Forward convolution with up to 3 spatial dimensions
245// @li Input tensor in GNWC data format
246// @li Weight tensor in GKXC data format
247// @li Output tensor in GNWK data format
248//
249// 1D:
250// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
251// 2D:
252// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
253// 3D:
254// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
255//
256template <index_t NDimSpatial,
257 typename ALayout,
258 typename BLayout,
259 typename DsLayout,
260 typename ELayout,
261 typename ADataType,
262 typename BDataType,
263 typename AccDataType,
264 typename CShuffleDataType,
265 typename DsDataType,
266 typename EDataType,
267 typename AElementwiseOperation,
268 typename BElementwiseOperation,
269 typename CDEElementwiseOperation,
270 ConvolutionForwardSpecialization ConvForwardSpecialization,
271 GemmSpecialization GemmSpec,
272 index_t NumGemmKPrefetchStage,
273 index_t BlockSize,
274 index_t MPerBlock,
275 index_t NPerBlock,
276 index_t KPerBlock,
277 index_t AK1,
278 index_t BK1,
279 index_t MPerXDL,
280 index_t NPerXDL,
281 index_t MXdlPerWave,
282 index_t NXdlPerWave,
283 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
284 typename ABlockTransferThreadClusterArrangeOrder,
285 typename ABlockTransferSrcAccessOrder,
286 index_t ABlockTransferSrcVectorDim,
287 index_t ABlockTransferSrcScalarPerVector,
288 index_t ABlockTransferDstScalarPerVector_AK1,
289 index_t ABlockLdsExtraM,
290 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
291 typename BBlockTransferThreadClusterArrangeOrder,
292 typename BBlockTransferSrcAccessOrder,
293 index_t BBlockTransferSrcVectorDim,
294 index_t BBlockTransferSrcScalarPerVector,
295 index_t BBlockTransferDstScalarPerVector_BK1,
296 index_t BBlockLdsExtraN,
297 index_t CShuffleMXdlPerWavePerShuffle,
298 index_t CShuffleNXdlPerWavePerShuffle,
299 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
300 index_t CDEBlockTransferScalarPerVector_NPerBlock,
301 typename AComputeDataType =
302 decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
303 Number<0>,
304 ADataType>()), // ComputeType is InputType by default (first
305 // in tuple for MultiAB), unpack if tuple was
306 // passed
307 typename BComputeDataType = AComputeDataType,
309 index_t NumGroupsToMerge = 1>
311 : public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
312 ALayout,
313 BLayout,
314 DsLayout,
315 ELayout,
316 ADataType,
317 BDataType,
318 DsDataType,
319 EDataType,
320 AElementwiseOperation,
321 BElementwiseOperation,
322 CDEElementwiseOperation,
323 AComputeDataType,
324 BComputeDataType>
325{
328 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
329 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
330
331 static_assert(NumGroupsToMerge >= 1);
332
335 static constexpr bool isMultiAB = isMultiA || isMultiB;
336
337 // NGCHW is not supported for multiAB
340 !(isMultiA || isMultiB));
341
342 static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
343 static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
344 static constexpr index_t NumDTensor = DsDataType::Size();
345
346 static constexpr bool DoElementwiseBeforeCShuffle =
349
350 static constexpr auto I0 = Number<0>{};
351 static constexpr auto I1 = Number<1>{};
352 static constexpr auto I2 = Number<2>{};
353 static constexpr auto I3 = Number<3>{};
354 static constexpr auto I4 = Number<4>{};
355 static constexpr auto I5 = Number<5>{};
356
357 static constexpr bool isATensorColMajor =
358 (ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) &&
359 (ABlockTransferSrcVectorDim == 1) && (NumGroupsToMerge == 1) &&
362
363 static constexpr bool NeedTransposeKernel =
366
367 static constexpr bool CTranspose = (NeedTransposeKernel == false) && (isMultiAB == false) &&
370
372 ConvForwardSpecialization,
373 true /*SplitN*/,
374 ADataType,
375 EDataType,
376 NumGroupsToMerge,
377 index_t,
378 CTranspose>;
379
381 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
382
383 static constexpr auto conv_ngchw_to_nhwgc_transformer =
385 BLayout,
386 ELayout,
387 NDimSpatial,
388 NPerBlock / ClusterLengthNPerBlock,
389 NPerBlock / ClusterLengthNPerBlock>{};
390
391 static constexpr auto matrix_padder =
392 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
393
394 template <typename ALay>
395 static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
396 {
397 namespace ctc = tensor_layout::convolution;
398 using Layout = std::conditional_t<
400 ctc::NHWGC,
401 std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
402 ctc::NDHWGC,
403 ALay>>;
404
405 const auto in_gemmmraw_gemmkraw_desc =
406 conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
407
408 const auto in_gemmm_gemmk_desc =
409 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
410
411 return in_gemmm_gemmk_desc;
412 }
413
414 template <typename BLay>
415 static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
416 {
417 namespace ctc = tensor_layout::convolution;
418 using Layout = std::conditional_t<
420 ctc::GKYXC,
421 std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
422 ctc::GKZYXC,
423 BLay>>;
424
425 const auto wei_gemmnraw_gemmkraw_desc =
426 conv_to_gemm_transformer.template MakeBDescriptor_N_K<Layout>();
427
428 const auto wei_gemmn_gemmk_desc =
429 matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
430
431 return wei_gemmn_gemmk_desc;
432 }
433
434 template <typename ELay>
435 static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
436 {
437 namespace ctc = tensor_layout::convolution;
438 using Layout = std::conditional_t<
440 ctc::NHWGK,
441 std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>() && NeedTransposeKernel,
442 ctc::NDHWGK,
443 ELay>>;
444
445 const auto out_gemmmraw_gemmnraw_desc =
446 conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
447 if constexpr(CTranspose)
448 {
449 constexpr auto matrix_padder_trans =
450 MatrixPadder<GemmSpec, index_t, index_t, index_t>{NPerBlock, MPerBlock, KPerBlock};
451 return matrix_padder_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
452 }
453 else
454 {
455 return matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
456 }
457 }
458
459 // Shape of Ds and E must be aligned. Strides can be different.
460 // Pass e_g_n_k_wos_lengths for logical broadcast.
461 static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
462 {
463 return generate_tuple(
464 [&](auto i) {
465 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
466
467 return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
468 },
470 }
471
472 // desc for problem definition
482
483 // If we are using multiAB and one of the template data type parameters is not a tuple, wrap it
484 // in a single-element Tuple so that both A and B can be handled as tuples
485 using GemmADataType = std::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
486 using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
487
// Positional template-argument list for the gridwise GEMM used on the multi-A/B path
// (presumably the isMultiA || isMultiB instantiation). The order must match that gridwise
// GEMM's template signature exactly. NXdlPerWave_ is not a class template parameter; it is
// presumably resolved from the enclosing `template <index_t NXdlPerWave_>` alias where the
// macro is expanded. Unlike GridwiseGemmTemplateParameters, this list passes
// InMemoryDataOperationEnum::Set explicitly and does not pass DoElementwiseBeforeCShuffle.
488#define GridwiseGemmMultiABDTemplateParameters \
489 GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
490 EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
491 InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
492 KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, \
493 ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
494 ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
495 ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
496 ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
497 BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
498 BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
499 BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
500 CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
501 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
502 CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
503 BComputeDataType
504
// Positional template-argument list for the single-A/B, non-transposed gridwise GEMM.
// The order must match that gridwise GEMM's template signature exactly. NXdlPerWave_ is
// presumably supplied by the enclosing `template <index_t NXdlPerWave_>` alias where the
// macro is expanded. The trailing DoElementwiseBeforeCShuffle is not part of the multi-A/B list.
505#define GridwiseGemmTemplateParameters \
506 GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
507 EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
508 NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \
509 NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
510 ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
511 ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
512 ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \
513 BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \
514 BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \
515 BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \
516 BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
517 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
518 CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
519 BComputeDataType, DoElementwiseBeforeCShuffle
520
// Same positional list as GridwiseGemmTemplateParameters, but for the CTranspose case. The
// following are swapped relative to the non-transposed list, presumably so that the gridwise
// GEMM computes the transposed product:
//   - the A/B data types and element-wise ops;
//   - MPerBlock/NPerBlock, AK1/BK1, MPerXDL/NPerXDL and MXdlPerWave/NXdlPerWave_;
//   - the whole A/B block-transfer parameter groups.
// NOTE(review): AComputeDataType / BComputeDataType and CShuffleMXdlPerWavePerShuffle /
// CShuffleNXdlPerWavePerShuffle keep their non-transposed positions. That is only equivalent
// when both compute types are the same and the CShuffle M/N values suit the swapped XDL
// counts. Confirm this is intentional.
521#define GridwiseGemmCTransposeTemplateParameters \
522 GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
523 EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \
524 NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \
525 MPerXDL, NXdlPerWave_, MXdlPerWave, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
526 BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
527 BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
528 BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
529 ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
530 ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
531 ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
532 ABlockLdsExtraM, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
533 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
534 CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
535 BComputeDataType, DoElementwiseBeforeCShuffle
536
537 // Use appropriate gridwise gemm
538 template <index_t NXdlPerWave_>
541 template <index_t NXdlPerWave_>
544 template <index_t NXdlPerWave_>
547
549 std::conditional_t<isMultiA || isMultiB,
552 using GridwiseGemm32 = std::conditional_t<isMultiA || isMultiB,
555
557 std::conditional_t<CTranspose,
561 std::conditional_t<CTranspose,
564
565 // If ADataType or BDataType is a tuple, the user has to pass a std::array of pointers.
566 using APointers =
567 std::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;
568 using BPointers =
569 std::conditional_t<isMultiB, std::array<const void*, NumBTensor>&, const void*>;
570 // Use a Tuple for GridPointer in both cases so it can be initialized in the Argument constructor
571 // body (instead of in the initializer list, which a single const pointer would require).
573 decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm64, ADataType > ())>;
575 decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm64, BDataType > ())>;
576
577 // desc for blockwise copy
579 remove_cvref_t<decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(
580 AGridDesc_M_K{}))>;
582 remove_cvref_t<decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(
583 BGridDesc_N_K{}))>;
585 decltype(GridwiseGemmCTranspose64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
586 DsGridDesc_M_N{}))>;
588 decltype(GridwiseGemmCTranspose64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
589 EGridDesc_M_N{}))>;
590
591 // block-to-e-tile map
593 remove_cvref_t<decltype(GridwiseGemmCTranspose64::MakeDefaultBlock2ETileMap(
594 EGridDesc_M_N{}))>;
596
599 .template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
602 .template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
603
606 .template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
609 .template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
610
612
613 // NPerBlock is used for both the first and the second dim, which use
614 // CDEBlockTransferScalarPerVector_NPerBlock for loads and stores during
615 // transposition. CBlockTransferScalarPerVector_NWaveNPerXdl is aligned to
616 // NPerBlock, so this dim is the more flexible choice for the load/store dimension
617 // with that scalar-per-vector.
626 NPerBlock,
627 NPerBlock,
628 NPerBlock / ClusterLengthNPerBlock,
629 NPerBlock / ClusterLengthNPerBlock,
633 I1,
634 I0>;
635
644 NPerBlock,
645 NPerBlock,
646 NPerBlock / ClusterLengthNPerBlock,
647 NPerBlock / ClusterLengthNPerBlock,
651 I0,
652 I1>;
653
662 NPerBlock,
663 NPerBlock,
664 NPerBlock / ClusterLengthNPerBlock,
665 NPerBlock / ClusterLengthNPerBlock,
669 I0,
670 I1>;
671
672 // Argument
673 struct Argument : public BaseArgument
674 {
675 template <typename GridwiseGemm, typename GridwiseGemmCTranspose>
677 {
678 // populate desc for Ds/E
679 if constexpr(isMultiA || isMultiB)
680 {
681 const auto as_grid_desc_ak0_m_ak1 =
682 generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number<NumATensor>{});
683 const auto bs_grid_desc_bk0_n_bk1 =
684 generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number<NumBTensor>{});
685
686 if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
687 bs_grid_desc_bk0_n_bk1,
691 {
693 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
695
697 GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
699 }
700 }
701 else
702 {
703 bool valid = false;
704 if constexpr(CTranspose)
705 {
706 valid = GridwiseGemmCTranspose::CheckValidity(b_grid_desc_n_k_,
711 }
712 else
713 {
714 valid = GridwiseGemmCTranspose::CheckValidity(a_grid_desc_m_k_,
719 }
720 if(valid)
721 {
722 e_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose::
723 MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
724
725 ds_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose::
726 MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n_);
727 }
728 }
729 };
730
732 BPointers p_bs,
733 const std::array<const void*, NumDTensor>& p_ds,
734 void* p_e,
735 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
736 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
737 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
738 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
739 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
740 ds_g_n_k_wos_lengths,
741 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
742 ds_g_n_k_wos_strides,
743 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
744 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
745 const std::array<index_t, NDimSpatial>& conv_filter_strides,
746 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
747 const std::array<index_t, NDimSpatial>& input_left_pads,
748 const std::array<index_t, NDimSpatial>& input_right_pads,
749 const AElementwiseOperation& a_element_op,
750 const BElementwiseOperation& b_element_op,
751 const CDEElementwiseOperation& cde_element_op)
752 : p_as_grid_{},
753 p_bs_grid_{},
754 p_ds_grid_{},
755 p_e_grid_{static_cast<EDataType*>(p_e)},
756 a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
758 ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
759 a_g_n_c_wis_lengths, a_g_n_c_wis_strides)
760 : a_g_n_c_wis_strides},
761 b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
763 ? conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
764 b_g_k_c_xs_lengths, b_g_k_c_xs_strides)
765 : b_g_k_c_xs_strides},
766 ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
767 ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
768 e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
770 ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
771 e_g_n_k_wos_lengths, e_g_n_k_wos_strides)
772 : e_g_n_k_wos_strides},
773 conv_filter_strides_{conv_filter_strides},
774 conv_filter_dilations_{conv_filter_dilations},
775 input_left_pads_{input_left_pads},
776 input_right_pads_{input_right_pads},
797 GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
799 GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
803 GridwiseGemmCTranspose64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
806 a_element_op_{a_element_op},
807 b_element_op_{b_element_op},
808 cde_element_op_{cde_element_op}
809 {
810 // A/B/E Batch Stride
811 if constexpr(isMultiA || isMultiB)
812 {
813 static_for<0, NumATensor, 1>{}([&](auto i) {
814 // Init compute_ptr_offset_of_groups_ for multiple AB
815 compute_ptr_offset_of_groups_.BatchStrideA_(i) =
816 a_g_n_c_wis_strides_[0] * NumGroupsToMerge;
817
818 // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
819 // type is not tuple)
820 using DataType = remove_cvref_t<tuple_element_t<i.value, GemmADataType>>;
821 // It is possible that one of A/B is a single pointer and the other is a tuple.
822 // multiAB is still used in that case, but the single pointer must be cast directly
823 // instead of indexing into a tuple of pointers.
824 if constexpr(isMultiA)
825 {
826 // p_as is tuple
827 p_as_grid_(i) = static_cast<const DataType*>(p_as[i.value]);
828 // compute_ptr_offset_of_n_ does not need BatchStrideB, so
829 // when isMultiA is false but isMultiB is true,
830 // BatchStrideA_ is not a tuple.
831 compute_ptr_offset_of_n_.BatchStrideA_(i) =
833 }
834 else
835 {
836 // if MultiB and not MultiA then p_as is single pointer
837 p_as_grid_(i) = static_cast<const DataType*>(p_as);
838 compute_ptr_offset_of_n_.BatchStrideA_ =
840 }
841 });
842 static_for<0, NumBTensor, 1>{}([&](auto i) {
843 // Init compute_ptr_offset_of_groups_ for multiple AB
844 compute_ptr_offset_of_groups_.BatchStrideB_(i) =
845 b_g_k_c_xs_strides_[0] * NumGroupsToMerge;
846
847 using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
848 // It is possible that one of A/B is a single pointer and the other is a tuple.
849 // multiAB is still used in that case, but the single pointer must be cast directly
850 // instead of indexing into a tuple of pointers.
851 if constexpr(isMultiB)
852 {
853 // p_bs is tuple
854 p_bs_grid_(i) = static_cast<const DataType*>(p_bs[i.value]);
855 }
856 else
857 {
858 // if MultiA and not MultiB then p_bs is single pointer
859 p_bs_grid_(i) = static_cast<const DataType*>(p_bs);
860 }
861 });
862 }
863 else
864 {
865 compute_ptr_offset_of_groups_.BatchStrideA_ =
866 a_g_n_c_wis_strides_[0] * NumGroupsToMerge;
867 compute_ptr_offset_of_groups_.BatchStrideB_ =
868 b_g_k_c_xs_strides_[0] * NumGroupsToMerge;
869 compute_ptr_offset_of_n_.BatchStrideA_ =
871
872 // p_as and p_bs are pointers
873 p_as_grid_(I0) = static_cast<const ADataType*>(p_as);
874 p_bs_grid_(I0) = static_cast<const BDataType*>(p_bs);
875 }
876
877 // populate pointer, batch stride, desc for Ds
878 static_for<0, NumDTensor, 1>{}([&](auto i) {
879 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
880 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
881
882 // D pointer
883 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
884
885 // D batch stride
886 compute_ptr_offset_of_groups_.BatchStrideDs_(i) =
887 ds_g_n_k_wos_strides_[i][0] * NumGroupsToMerge;
888 compute_ptr_offset_of_n_.BatchStrideDs_(i) =
890
891 ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths_,
901
902 // D desc
904 DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
905 });
906 compute_ptr_offset_of_groups_.BatchStrideE_ =
907 e_g_n_k_wos_strides_[0] * NumGroupsToMerge;
909
910 if(get_warp_size() == 64)
911 {
912 if constexpr(NXdlPerWave64 > 0)
913 {
915 }
916 }
917 else
918 {
919 if constexpr(NXdlPerWave32 > 0)
920 {
922 }
923 }
924 if constexpr(NeedTransposeKernel)
925 {
926 // Use not modified base strides
928 conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
929 a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
931 conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
932 a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
933
935 conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
936 b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
938 conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
939 b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
940
942 conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
943 e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
945 conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
946 e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
947
949 a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
951 b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
953 e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
954 }
955 }
956
958 {
959 if constexpr(NeedTransposeKernel)
960 {
962 a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
963 // Align to 128B
964 return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128;
965 }
966 else
967 {
968 return 0;
969 }
970 }
971
973 {
974 if constexpr(NeedTransposeKernel)
975 {
977 b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
978 // Align to 128B
979 return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128;
980 }
981 else
982 {
983 return 0;
984 }
985 }
986
988 {
989 if constexpr(NeedTransposeKernel)
990 {
992 e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
993 return sizeof(EDataType) * e_accum;
994 }
995 else
996 {
997 return 0;
998 }
999 }
1000
1006
1007 void Print() const
1008 {
1009 std::cout << "AComputeDataType: " << get_type_name<AComputeDataType>()
1010 << "; BComputeDataType: " << get_type_name<BComputeDataType>()
1011 << "; EDataType: " << get_type_name<EDataType>() << std::endl;
1012
1013 std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
1014 std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
1016 [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
1017 std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
1018
1019 std::cout << "a grid desc" << a_grid_desc_ak0_m_ak1_ << std::endl;
1020 std::cout << "b grid desc" << b_grid_desc_bk0_n_bk1_ << std::endl;
1021 std::cout << "e grid desc" << e_grid_desc_mblock_mperblock_nblock_nperblock_
1022 << std::endl;
1023 }
1024
1025 // private:
1026 // pointers (tuple if multi AB, pointer if no)
1029 typename GridwiseGemm64::DsGridPointer p_ds_grid_;
1030 EDataType* p_e_grid_;
1031
1032 // for checking IsSupportedArgument()
1033 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
1034 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
1035 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
1036 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
1037 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
1038 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
1039 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
1040 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
1041 std::array<index_t, NDimSpatial> conv_filter_strides_;
1042 std::array<index_t, NDimSpatial> conv_filter_dilations_;
1043 std::array<index_t, NDimSpatial> input_left_pads_;
1044 std::array<index_t, NDimSpatial> input_right_pads_;
1045
1046 // tensor descriptors for problem definiton
1048
1050
1052
1057
1058 // tensor descriptors for block/thread-wise copy
1064
1065 // block-to-e-tile map
1069
1074
1075 // for computing batch offset
1076 ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
1078 ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor> compute_ptr_offset_of_n_;
1079
1080 // element-wise op
1081 AElementwiseOperation a_element_op_;
1082 BElementwiseOperation b_element_op_;
1083 CDEElementwiseOperation cde_element_op_;
1084 };
1085
1086 // Invoker
// Invoker: launches the kernels described by an Argument.
//   Run(arg) - picks the 64- or 32-lane GridwiseGemm instantiations from get_warp_size().
//   RunImp   - optional A/B transpose into arg.p_workspace_, then RunGemm, then an
//              optional E transpose back into arg.p_e_grid_; returns the summed times.
//   RunGemm  - launches kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle.
1087 struct Invoker : public BaseInvoker
1088 {
1090
// Launches the grouped-conv GEMM for `arg` and returns launch_kernel's result for the
// has-main-K-block-loop instantiation selected from the runtime K extent.
1091 template <typename GridwiseGemm, typename GridwiseGemmCTranspose>
1092 float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
1093 {
// Dump the argument's types and descriptors when logging is requested.
1094 if(stream_config.log_level_ > 0)
1095 {
1096 arg.Print();
1097 }
1098
1099 const index_t num_workgroups_per_Conv_N =
1101
// Grid: x = E tiles from the block-to-E-tile map, y = groups after merging
// NumGroupsToMerge of them, z = workgroups per conv N.
1102 const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
1103 const index_t gdy = arg.num_group_ / NumGroupsToMerge;
1104 const index_t gdz = num_workgroups_per_Conv_N;
1105
// GEMM K extent = AK0 * AK1 of the A block-transfer descriptor.
1106 const auto K =
1107 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
1108
// has_main_k_block_loop is an integral_constant<bool, ...>, so the flag reaches the
// kernel as a compile-time template argument.
1109 auto launch_kernel = [&](auto has_main_k_block_loop) {
1110 constexpr bool has_main_loop = has_main_k_block_loop.value;
1111
1112 if constexpr(isMultiA || isMultiB)
1113 {
1114 // Generate tuples with grid descriptors for each A and B
1115 const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
1116 [&](auto) { return arg.a_grid_desc_ak0_m_ak1_; }, Number<NumATensor>{});
1117 const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
1118 [&](auto) { return arg.b_grid_desc_bk0_n_bk1_; }, Number<NumBTensor>{});
1119
1120 const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
1121 GridwiseGemm,
1124 typename GridwiseGemm::DsGridPointer,
1125 EDataType,
1126 AElementwiseOperation,
1127 BElementwiseOperation,
1128 CDEElementwiseOperation,
1129 decltype(as_grid_desc_ak0_m_ak1),
1130 decltype(bs_grid_desc_bk0_n_bk1),
1134 ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
1135 ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
1136 has_main_loop,
1137 isMultiA,
1138 isMultiB,
1139 CTranspose>;
1140
1142 stream_config,
1143 kernel,
1144 dim3(gdx, gdy, gdz),
1145 dim3(BlockSize),
1146 0,
1147 arg.p_as_grid_,
1148 arg.p_bs_grid_,
1149 arg.p_ds_grid_,
1150 arg.p_e_grid_,
1151 arg.a_element_op_,
1152 arg.b_element_op_,
1153 arg.cde_element_op_,
1154 as_grid_desc_ak0_m_ak1,
1155 bs_grid_desc_bk0_n_bk1,
1161 }
1162 else
1163 {
// Single A/B tensor. With NeedTransposeKernel the pointers are redirected into
// arg.p_workspace_, where RunImp's transpose kernel writes the transposed inputs
// and from which the E result is transposed back afterwards.
1164 const ADataType* p_a_grid = arg.p_as_grid_.At(I0);
1165 const BDataType* p_b_grid = arg.p_bs_grid_.At(I0);
1166 EDataType* p_e_grid = arg.p_e_grid_;
1167 if constexpr(NeedTransposeKernel)
1168 {
1171 {
1174 arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
1175 p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
1178 sizeof(EDataType);
1179 }
1182 {
1184 p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
1187 sizeof(EDataType);
1188 }
1189 }
1190
// CTranspose: launch the transposed GEMM with A and B (and their element-wise
// operations) swapped.
1191 if constexpr(CTranspose)
1192 {
1193 const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
1194 GridwiseGemmCTranspose,
1195 const BDataType*,
1196 const ADataType*,
1197 typename GridwiseGemm::DsGridPointer,
1198 EDataType,
1199 BElementwiseOperation,
1200 AElementwiseOperation,
1201 CDEElementwiseOperation,
1207 ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
1208 ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
1209 has_main_loop,
1210 isMultiA,
1211 isMultiB,
1212 CTranspose>;
1214 stream_config,
1215 kernel,
1216 dim3(gdx, gdy, gdz),
1217 dim3(BlockSize),
1218 0,
1219 p_b_grid,
1220 p_a_grid,
1221 arg.p_ds_grid_,
1222 p_e_grid,
1223 arg.b_element_op_,
1224 arg.a_element_op_,
1225 arg.cde_element_op_,
1233 }
1234 else
1235 {
1236 const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
1237 GridwiseGemm,
1238 const ADataType*,
1239 const BDataType*,
1240 typename GridwiseGemm::DsGridPointer,
1241 EDataType,
1242 AElementwiseOperation,
1243 BElementwiseOperation,
1244 CDEElementwiseOperation,
1250 ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
1251 ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
1252 has_main_loop,
1253 isMultiA,
1254 isMultiB,
1255 CTranspose>;
1256
1258 stream_config,
1259 kernel,
1260 dim3(gdx, gdy, gdz),
1261 dim3(BlockSize),
1262 0,
1263 p_a_grid,
1264 p_b_grid,
1265 arg.p_ds_grid_,
1266 p_e_grid,
1267 arg.a_element_op_,
1268 arg.b_element_op_,
1269 arg.cde_element_op_,
1277 }
1278 }
1279 };
1280
// Pick the kernel instantiation matching the runtime K extent.
1281 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
1282 {
1283 return launch_kernel(integral_constant<bool, true>{});
1284 }
1285 else
1286 {
1287 return launch_kernel(integral_constant<bool, false>{});
1288 }
1289 }
1290
// Full forward pass: [transpose A/B into the workspace] -> GEMM -> [transpose E back
// into arg.p_e_grid_]. Returns the sum of the times reported for each launch.
1291 template <typename GridwiseGemm, typename GridwiseGemmCTranspose>
1292 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
1293 {
1294 float avg_time = 0.f;
1295 if constexpr(NeedTransposeKernel)
1296 {
1297 const index_t a_grid_size =
1298 arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
1300 const index_t b_grid_size =
1303 ? arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
1305 : 0; // Don't run transpose B if not needed
1306
// Workspace layout: transposed A at the start, transposed B right after the
// A region whose size (GetWorkspaceATensorSizeBytes) is rounded up to 128 bytes.
1307 ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
1308 BDataType* p_b_out_grid = type_convert<BDataType*>(arg.p_workspace_) +
1309 arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
1310
1324
// One dual-elementwise launch covers both transposes; a_grid_size is passed last,
// presumably so the kernel can split its blocks between the A and B ranges.
1325 avg_time += launch_and_time_kernel(stream_config,
1326 kernel_transpose,
1327 dim3(a_grid_size + b_grid_size),
1329 0,
1334 make_tuple(arg.p_as_grid_.At(I0)),
1335 make_tuple(arg.p_bs_grid_.At(I0)),
1336 make_tuple(p_a_out_grid),
1337 make_tuple(p_b_out_grid),
1341 a_grid_size);
1342 }
1343
1344 avg_time += RunGemm<GridwiseGemm, GridwiseGemmCTranspose>(arg, stream_config);
1345
// Transpose the GEMM result back into the caller's E buffer (arg.p_e_grid_).
1346 if constexpr(NeedTransposeKernel)
1347 {
1348 const index_t grid_size =
1349 arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
1351
1352 const EDataType* p_e_in_grid =
1355 sizeof(EDataType);
1356
1357 EDataType* p_e_out_grid = arg.p_e_grid_;
1358
1366
1367 avg_time += launch_and_time_kernel(stream_config,
1368 kernel_transpose,
1369 dim3(grid_size),
1371 0,
1374 make_tuple(p_e_in_grid),
1375 make_tuple(p_e_out_grid),
1378 }
1379
1380 return avg_time;
1381 }
1382
// Dispatch on the device wave size. NOTE(review): returns 0 without launching anything
// when this instance has no XDL configuration for that wave size (NXdlPerWave64/32 == 0);
// callers presumably filter with IsSupportedArgument first.
1383 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
1384 {
1385 if(get_warp_size() == 64)
1386 {
1387 if constexpr(NXdlPerWave64 > 0)
1388 {
1389 return RunImp<GridwiseGemm64, GridwiseGemmCTranspose64>(arg, stream_config);
1390 }
1391 else
1392 {
1393 return 0;
1394 }
1395 }
1396 else
1397 {
1398 if constexpr(NXdlPerWave32 > 0)
1399 {
1400 return RunImp<GridwiseGemm32, GridwiseGemmCTranspose32>(arg, stream_config);
1401 }
1402 else
1403 {
1404 return 0;
1405 }
1406 }
1407 }
// NOTE(review): the dynamic_cast result is dereferenced unchecked, so a BaseArgument of
// another type is a null dereference; GetWorkSpaceSize throws std::runtime_error instead.
1408 float Run(const BaseArgument* p_arg,
1409 const StreamConfig& stream_config = StreamConfig{}) override
1410 {
1411 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
1412 }
1413 };
1414
// Host-side check that this instance can execute the problem in `arg`.
// Returns false on the first violated constraint: device restrictions, the
// ConvForwardSpecialization, group merging, vector-access divisibility for
// A/B/Ds/E, transpose-kernel requirements and TF32 support. Otherwise the
// gridwise GEMM's CheckValidity() for the current wave size decides.
1415 static bool IsSupportedArgument(const Argument& arg)
1416 {
1417 namespace ctc = tensor_layout::convolution;
1418
// Filter lengths are read as [G, K, C, filter spatial...]; input_spatial_acum is
// the product of the input spatial lengths (a_g_n_c_wis_lengths_ from index 3).
1419 const index_t G = arg.b_g_k_c_xs_lengths_[I0];
1420 const index_t K = arg.b_g_k_c_xs_lengths_[I1];
1421 const index_t C = arg.b_g_k_c_xs_lengths_[I2];
1422 const index_t input_spatial_acum = ck::accumulate_n<index_t>(
1423 arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
1424
1425 // check device
1426 if(get_device_name() == "gfx908")
1427 {
1428 // FIXME: re-enable fp64 when SWDEV-335738 is fixed
1430 {
1431 return false;
1432 }
1433 }
1434
1436 {
1437 return false;
1438 }
1439
1440 // check ConvolutionForwardSpecialization
1441 if constexpr(ConvForwardSpecialization ==
1443 {
1444 // check if it's 1x1, stride=1 conv
1445 for(index_t i = 0; i < NDimSpatial; ++i)
1446 {
1447 const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3];
1448 const index_t ConvStride = arg.conv_filter_strides_[i];
1449 const index_t LeftPad = arg.input_left_pads_[i];
1450 const index_t RightPad = arg.input_right_pads_[i];
1451
1452 if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
1453 {
1454 return false;
1455 }
1456 }
1457 }
1458 else if constexpr(ConvForwardSpecialization ==
1460 {
1461 // check if it's 1x1 conv
1462 for(index_t i = 0; i < NDimSpatial; ++i)
1463 {
1464 const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3];
1465 const index_t LeftPad = arg.input_left_pads_[i];
1466 const index_t RightPad = arg.input_right_pads_[i];
1467
1468 if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
1469 {
1470 return false;
1471 }
1472 }
1473 }
1474 else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3)
1475 {
// Filter3x3: only C == 1 with every filter spatial length equal to 3.
1476 if(C != 1)
1477 {
1478 return false;
1479 }
1480 for(index_t i = 0; i < NDimSpatial; ++i)
1481 {
1482 const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3];
1483
1484 if(filter_spatial_dim != I3)
1485 {
1486 return false;
1487 }
1488 }
1489 }
1490
1491 if constexpr(NumGroupsToMerge > 1)
1492 {
// Merging groups requires C == 1 and G divisible by NumGroupsToMerge
// (plus the additional condition checked below).
1493 if(!(C == 1))
1494 {
1495 return false;
1496 }
1497 if(G % NumGroupsToMerge != 0)
1498 {
1499 return false;
1500 }
1505 {
1506 return false;
1507 }
1508 }
1509
1510 // check vector access of A
1511 // FIXME: layout
1518 {
1519 // Check access per C
1520 if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
1521 {
1522 // If not possible, check access per G
1523 if(!(ABlockTransferSrcVectorDim == 1 && (C == 1 || NumGroupsToMerge == 1) &&
1527 G % ABlockTransferSrcScalarPerVector == 0))
1528 {
1529 return false;
1530 }
1531 }
1532 }
1534 {
// This A layout branch supports neither the transpose kernel nor group merging;
// vectorized A reads must run along dim 1 and divide the input spatial size.
1535 static_assert(NeedTransposeKernel == false);
1536 static_assert(NumGroupsToMerge == 1);
1537
1538 if constexpr(ABlockTransferSrcScalarPerVector != 1)
1539 {
1540 if(ABlockTransferSrcVectorDim != 1)
1541 {
1542 return false;
1543 }
1544 if(input_spatial_acum % ABlockTransferSrcScalarPerVector != 0)
1545 {
1546 return false;
1547 }
1548 }
1549 }
1550 else
1551 {
1552 return false;
1553 }
1554
1555 // check vector access of B
1556 // FIXME: layout
1563
1564 {
// B is read vectorized along C (vector dim 2).
1565 if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
1566 {
1567 return false;
1568 }
1569 }
1570 else
1571 {
1572 return false;
1573 }
// Ds: each D must allow K-wide vector stores and match E's lengths (only G and K
// for a G_K layout); unsupported D layouts clear `valid`, which is checked after
// the transpose-kernel constraints below.
1574 // check vector access of Ds
1575 bool valid = true;
1576
1577 static_for<0, NumDTensor, 1>{}([&](auto i) {
1578 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
1579
1580 // FIXME: layout
1586 {
1587 if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
1588 {
1589 valid = false;
1590 }
1591
1592 if constexpr(is_same_v<DLayout, ctc::G_K>)
1593 {
1594 // G and K must be the same
1595 if(arg.ds_g_n_k_wos_lengths_[i][0] != arg.e_g_n_k_wos_lengths_[0] ||
1596 arg.ds_g_n_k_wos_lengths_[i][2] != arg.e_g_n_k_wos_lengths_[2])
1597 {
1598 valid = false;
1599 }
1600 }
1601 else
1602 {
1603 // E and D must have the same shape
1604 for(index_t d = 0; d < NDimSpatial + 3; d++)
1605 {
1606 if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d])
1607 {
1608 valid = false;
1609 }
1610 }
1611 }
1612 }
1613 else
1614 {
1615 valid = false;
1616 }
1617 });
1618
1619 if constexpr(NeedTransposeKernel)
1620 {
// The transpose path needs G*C, G*K and both spatial products divisible by
// CDEBlockTransferScalarPerVector_NPerBlock, a workspace set via
// SetWorkSpacePointer, and transposed A/E buffers of at most 2 GB.
1621 if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1622 {
1623 return false;
1624 }
1625
1626 if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1627 {
1628 return false;
1629 }
1630
1631 const index_t output_spatial_acum = ck::accumulate_n<index_t>(
1632 arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
1633
1634 if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1635 {
1636 return false;
1637 }
1638
1639 if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1640 {
1641 return false;
1642 }
1643
1644 if(!arg.p_workspace_)
1645 {
1646 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1647 {
1648 std::cout << "Warning: Workspace for "
1649 "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument is not "
1650 "allocated, use SetWorkSpacePointer."
1651 << std::endl;
1652 }
1653 return false;
1654 }
1655
1656 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
1657 if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
1658 arg.e_in_transpose_desc_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
1659 {
1660 return false;
1661 }
1662 }
1663
1664 if(!valid)
1665 {
1666 return false;
1667 }
1668
1669 // check vector access of E
1676 {
// E is stored vectorized along K, or along the output spatial product when
// CTranspose is set.
1677 if(CTranspose == false)
1678 {
1679 if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
1680 {
1681 return false;
1682 }
1683 }
1684 else
1685 {
1686 const index_t output_spatial_acum = ck::accumulate_n<index_t>(
1687 arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
1688
1689 if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1690 {
1691 return false;
1692 }
1693 }
1694 }
1695 else
1696 {
1697 return false;
1698 }
1701 {
// TF32 compute needs device support and identical A/B compute data types.
1702 if(!is_tf32_supported())
1703 {
1704 return false;
1705 }
1707 {
1708 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1709 {
1710 std::cout << "ComputeDataType for A and B should be same while using TF32"
1711 << std::endl;
1712 }
1713 return false;
1714 }
1715 }
1716 // check Gridwise GEMM
// The gridwise GEMM validates the descriptors and the block-to-E-tile map for the
// device's wave size; its answer is the final result.
1717 if(get_warp_size() == 64)
1718 {
1719 if constexpr(NXdlPerWave64 > 0)
1720 {
1721 if constexpr(isMultiA || isMultiB)
1722 {
1723 // Generate tuples with the same descriptors
// (built from the M-K / N-K GEMM descriptors despite the ak0_m_ak1 naming)
1724 const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
1725 [&](auto) { return arg.a_grid_desc_m_k_; }, Number<NumATensor>{});
1726 const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
1727 [&](auto) { return arg.b_grid_desc_n_k_; }, Number<NumBTensor>{});
1728 return GridwiseGemm64::CheckValidity(as_grid_desc_ak0_m_ak1,
1729 bs_grid_desc_bk0_n_bk1,
1730 arg.ds_grid_desc_m_n_,
1731 arg.e_grid_desc_m_n_,
1732 arg.block_2_etile_map_);
1733 }
1734 else
1735 {
1736 if constexpr(CTranspose)
1737 {
1738 return GridwiseGemmCTranspose64::CheckValidity(arg.b_grid_desc_n_k_,
1739 arg.a_grid_desc_m_k_,
1740 arg.ds_grid_desc_m_n_,
1741 arg.e_grid_desc_m_n_,
1742 arg.block_2_etile_map_);
1743 }
1744 else
1745 {
// NOTE(review): this non-transposed branch validates with GridwiseGemmCTranspose64,
// while RunGemm launches GridwiseGemm (GridwiseGemm64) when CTranspose is false.
// Confirm the two types are interchangeable here; otherwise this should be
// GridwiseGemm64::CheckValidity.
1746 return GridwiseGemmCTranspose64::CheckValidity(arg.a_grid_desc_m_k_,
1747 arg.b_grid_desc_n_k_,
1748 arg.ds_grid_desc_m_n_,
1749 arg.e_grid_desc_m_n_,
1750 arg.block_2_etile_map_);
1751 }
1752 }
1753 }
1754 }
1755 else
1756 {
1757
1758 if constexpr(NXdlPerWave32 > 0)
1759 {
1760 if constexpr(isMultiA || isMultiB)
1761 {
1762 // Generate tuples with the same descriptors
// (built from the M-K / N-K GEMM descriptors despite the ak0_m_ak1 naming)
1763 const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
1764 [&](auto) { return arg.a_grid_desc_m_k_; }, Number<NumATensor>{});
1765 const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
1766 [&](auto) { return arg.b_grid_desc_n_k_; }, Number<NumBTensor>{});
1767 return GridwiseGemm32::CheckValidity(as_grid_desc_ak0_m_ak1,
1768 bs_grid_desc_bk0_n_bk1,
1769 arg.ds_grid_desc_m_n_,
1770 arg.e_grid_desc_m_n_,
1771 arg.block_2_etile_map_);
1772 }
1773 else
1774 {
1775 if constexpr(CTranspose)
1776 {
1777 return GridwiseGemmCTranspose32::CheckValidity(arg.b_grid_desc_n_k_,
1778 arg.a_grid_desc_m_k_,
1779 arg.ds_grid_desc_m_n_,
1780 arg.e_grid_desc_m_n_,
1781 arg.block_2_etile_map_);
1782 }
1783 else
1784 {
// NOTE(review): as in the wave64 path, this validates with GridwiseGemmCTranspose32
// although RunGemm launches GridwiseGemm32 when CTranspose is false; confirm.
1785 return GridwiseGemmCTranspose32::CheckValidity(arg.a_grid_desc_m_k_,
1786 arg.b_grid_desc_n_k_,
1787 arg.ds_grid_desc_m_n_,
1788 arg.e_grid_desc_m_n_,
1789 arg.block_2_etile_map_);
1790 }
1791 }
1792 }
1793 }
1794
// Reached when this instance has no XDL configuration for the current wave size.
1795 return false;
1796 }
1797
1798 bool IsSupportedArgument(const BaseArgument* p_arg) override
1799 {
1800 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
1801 }
1802
1803 static auto MakeArgument(
1804 APointers p_as,
1805 BPointers p_bs,
1806 const std::array<const void*, NumDTensor>& p_ds,
1807 void* p_e,
1808 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1809 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1810 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1811 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1812 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
1813 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
1814 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1815 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1816 const std::array<index_t, NDimSpatial>& conv_filter_strides,
1817 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
1818 const std::array<index_t, NDimSpatial>& input_left_pads,
1819 const std::array<index_t, NDimSpatial>& input_right_pads,
1820 const AElementwiseOperation& a_element_op,
1821 const BElementwiseOperation& b_element_op,
1822 const CDEElementwiseOperation& cde_element_op)
1823 {
1824 return Argument{p_as,
1825 p_bs,
1826 p_ds,
1827 p_e,
1828 a_g_n_c_wis_lengths,
1829 a_g_n_c_wis_strides,
1830 b_g_k_c_xs_lengths,
1831 b_g_k_c_xs_strides,
1832 ds_g_n_k_wos_lengths,
1833 ds_g_n_k_wos_strides,
1834 e_g_n_k_wos_lengths,
1835 e_g_n_k_wos_strides,
1836 conv_filter_strides,
1837 conv_filter_dilations,
1838 input_left_pads,
1839 input_right_pads,
1840 a_element_op,
1841 b_element_op,
1842 cde_element_op};
1843 }
1844
1845 static auto
1847 BPointers p_bs,
1848 const std::array<const void*, NumDTensor>& p_ds,
1849 void* p_e,
1850 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1851 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1852 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1853 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1854 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
1855 ds_g_n_k_wos_lengths,
1856 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
1857 ds_g_n_k_wos_strides,
1858 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1859 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1860 const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
1861 const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
1862 const std::array<long_index_t, NDimSpatial>& input_left_pads,
1863 const std::array<long_index_t, NDimSpatial>& input_right_pads,
1864 const AElementwiseOperation& a_element_op,
1865 const BElementwiseOperation& b_element_op,
1866 const CDEElementwiseOperation& cde_element_op)
1867 {
1868 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
1869 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
1870 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
1871 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
1872 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
1873 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
1874 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
1875 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
1876 std::array<index_t, NDimSpatial> conv_filter_strides_i32;
1877 std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
1878 std::array<index_t, NDimSpatial> input_left_pads_i32;
1879 std::array<index_t, NDimSpatial> input_right_pads_i32;
1880
1881 array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
1882 array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
1883 array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
1884 array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
1885 for(index_t d = 0; d < NumDTensor; d++)
1886 {
1887 array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
1888 array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
1889 }
1890 array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
1891 array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
1892 array_convert(conv_filter_strides_i32, conv_filter_strides);
1893 array_convert(conv_filter_dilations_i32, conv_filter_dilations);
1894 array_convert(input_left_pads_i32, input_left_pads);
1895 array_convert(input_right_pads_i32, input_right_pads);
1896
1897 return Argument{p_as,
1898 p_bs,
1899 p_ds,
1900 p_e,
1901 a_g_n_c_wis_lengths_i32,
1902 a_g_n_c_wis_strides_i32,
1903 b_g_k_c_xs_lengths_i32,
1904 b_g_k_c_xs_strides_i32,
1905 ds_g_n_k_wos_lengths_i32,
1906 ds_g_n_k_wos_strides_i32,
1907 e_g_n_k_wos_lengths_i32,
1908 e_g_n_k_wos_strides_i32,
1909 conv_filter_strides_i32,
1910 conv_filter_dilations_i32,
1911 input_left_pads_i32,
1912 input_right_pads_i32,
1913 a_element_op,
1914 b_element_op,
1915 cde_element_op};
1916 }
1917
1918 static auto MakeInvoker() { return Invoker{}; }
1919
1920 std::unique_ptr<BaseArgument> MakeArgumentPointer(
1921 APointers p_as,
1922 BPointers p_bs,
1923 const std::array<const void*, NumDTensor>& p_ds,
1924 void* p_e,
1925 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1926 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1927 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1928 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1929 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
1930 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
1931 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1932 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1933 const std::array<index_t, NDimSpatial>& conv_filter_strides,
1934 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
1935 const std::array<index_t, NDimSpatial>& input_left_pads,
1936 const std::array<index_t, NDimSpatial>& input_right_pads,
1937 const AElementwiseOperation& a_element_op,
1938 const BElementwiseOperation& b_element_op,
1939 const CDEElementwiseOperation& cde_element_op) override
1940 {
1941 return std::make_unique<Argument>(p_as,
1942 p_bs,
1943 p_ds,
1944 p_e,
1945 a_g_n_c_wis_lengths,
1946 a_g_n_c_wis_strides,
1947 b_g_k_c_xs_lengths,
1948 b_g_k_c_xs_strides,
1949 ds_g_n_k_wos_lengths,
1950 ds_g_n_k_wos_strides,
1951 e_g_n_k_wos_lengths,
1952 e_g_n_k_wos_strides,
1953 conv_filter_strides,
1954 conv_filter_dilations,
1955 input_left_pads,
1956 input_right_pads,
1957 a_element_op,
1958 b_element_op,
1959 cde_element_op);
1960 }
1961
1962 std::unique_ptr<BaseArgument>
1964 BPointers p_bs,
1965 const std::array<const void*, NumDTensor>& p_ds,
1966 void* p_e,
1967 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1968 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1969 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1970 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1971 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
1972 ds_g_n_k_wos_lengths,
1973 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
1974 ds_g_n_k_wos_strides,
1975 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1976 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1977 const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
1978 const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
1979 const std::array<long_index_t, NDimSpatial>& input_left_pads,
1980 const std::array<long_index_t, NDimSpatial>& input_right_pads,
1981 const AElementwiseOperation& a_element_op,
1982 const BElementwiseOperation& b_element_op,
1983 const CDEElementwiseOperation& cde_element_op) override
1984 {
1985
1986 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
1987 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
1988 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
1989 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
1990 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
1991 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
1992 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
1993 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
1994 std::array<index_t, NDimSpatial> conv_filter_strides_i32;
1995 std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
1996 std::array<index_t, NDimSpatial> input_left_pads_i32;
1997 std::array<index_t, NDimSpatial> input_right_pads_i32;
1998
1999 array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
2000 array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
2001 array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
2002 array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
2003 for(index_t d = 0; d < NumDTensor; d++)
2004 {
2005 array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
2006 array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
2007 }
2008 array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
2009 array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
2010 array_convert(conv_filter_strides_i32, conv_filter_strides);
2011 array_convert(conv_filter_dilations_i32, conv_filter_dilations);
2012 array_convert(input_left_pads_i32, input_left_pads);
2013 array_convert(input_right_pads_i32, input_right_pads);
2014
2015 return std::make_unique<Argument>(p_as,
2016 p_bs,
2017 p_ds,
2018 p_e,
2019 a_g_n_c_wis_lengths_i32,
2020 a_g_n_c_wis_strides_i32,
2021 b_g_k_c_xs_lengths_i32,
2022 b_g_k_c_xs_strides_i32,
2023 ds_g_n_k_wos_lengths_i32,
2024 ds_g_n_k_wos_strides_i32,
2025 e_g_n_k_wos_lengths_i32,
2026 e_g_n_k_wos_strides_i32,
2027 conv_filter_strides_i32,
2028 conv_filter_dilations_i32,
2029 input_left_pads_i32,
2030 input_right_pads_i32,
2031 a_element_op,
2032 b_element_op,
2033 cde_element_op);
2034 }
2035
2036 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
2037 {
2038 return std::make_unique<Invoker>(Invoker{});
2039 }
2040
2041 std::string GetTypeString() const override
2042 {
2043 auto str = std::stringstream();
2044
2045 // clang-format off
2046 str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
2047 << "<"
2048 << BlockSize << ", "
2049 << MPerBlock << ", "
2050 << NPerBlock << ", "
2051 << KPerBlock << ", "
2052 << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
2053 << MPerXDL << ", "
2054 << NPerXDL << ", "
2055 << MXdlPerWave << ", "
2056 << NXdlPerWave << ", "
2057 << ABlockTransferSrcScalarPerVector << ", "
2058 << BBlockTransferSrcScalarPerVector << ", "
2059 << CDEBlockTransferScalarPerVector_NPerBlock << ", "
2060 << CShuffleMXdlPerWavePerShuffle << ", "
2061 << CShuffleNXdlPerWavePerShuffle << ", "
2062 << NumGroupsToMerge
2063 << ">";
2064 // clang-format on
2065
2066 return str.str();
2067 }
2068
#ifdef CK_EXPERIMENTAL_BUILDER
// Reflection-based description of this instance (experimental builder only).
// Compilation stops with the message below when no instance_traits
// specialization exists for these template parameters.
std::string GetInstanceString() const override
{
    static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
                  "Specialization of instance_traits not found. Please check that a "
                  "specialization exists in file "
                  "ck_tile/builder/reflect/"
                  "instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp "
                  "for the given template parameters.");
    auto description = ck_tile::reflect::instance_string<DeviceOp>();
    return description;
}
#endif
2081
2082 size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
2083 {
2084 auto arg = dynamic_cast<const Argument*>(p_arg);
2085 if(arg)
2086 {
2087 return arg->GetWorkspaceSizeBytes();
2088 }
2089 else
2090 throw std::runtime_error(
2091 "The argument pointer is not an object of "
2092 "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
2093 }
2094
2096 void* p_workspace,
2097 const StreamConfig& = StreamConfig{}) const override
2098 {
2099 auto p_arg_ = dynamic_cast<Argument*>(p_arg);
2100 if(p_arg_)
2101 {
2102 p_arg_->p_workspace_ = p_workspace;
2103 }
2104 else
2105 throw std::runtime_error(
2106 "The argument pointer is not an object of "
2107 "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
2108 }
2109};
2110
2111} // namespace device
2112} // namespace tensor_operation
2113} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition tensor_operation/gpu/device/tensor_layout.hpp:42
Definition convolution_backward_data_specialization.hpp:8
constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK()
Definition device_grouped_conv_utils.hpp:119
constexpr bool is_NGCSpatial_GKSpatial_NGKSpatial()
Definition device_grouped_conv_utils.hpp:135
GemmSpecialization
Definition gemm_specialization.hpp:11
constexpr bool is_NGCHW_GKYXC_NGKHW()
Definition device_grouped_conv_utils.hpp:56
constexpr bool is_NGCDHW_NGKDHW()
Definition device_grouped_conv_utils.hpp:112
constexpr bool is_NGCHW_GKCYX_NGKHW()
Definition device_grouped_conv_utils.hpp:64
decltype(std::declval< T & >().IsTuple()) is_tuple
Definition device_grouped_conv_fwd_multiple_abd.hpp:23
constexpr bool is_NGCDHW_GKZYXC_NGKDHW()
Definition device_grouped_conv_utils.hpp:96
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Filter3x3
Definition convolution_forward_specialization.hpp:20
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization &s)
Definition convolution_forward_specialization.hpp:24
constexpr bool is_NGCDHW_GKCZYX_NGKDHW()
Definition device_grouped_conv_utils.hpp:104
constexpr bool is_NGCHW_NGKHW()
Definition device_grouped_conv_utils.hpp:72
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
const char * get_type_name()
Definition data_type.hpp:468
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
bool is_tf32_supported()
Definition host_utility/device_prop.hpp:132
__host__ __device__ void array_convert(std::array< Y, NumElems > &y, const std::array< X, NumElems > &x)
Definition utility/type_convert.hpp:2466
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
auto accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op) -> decltype(std::accumulate(first, std::next(first, count), init, op))
Definition library/utility/numeric.hpp:11
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__global__ void kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a, const InBGridDescTuple in_grid_desc_tuple_b, const OutAGridDescTuple out_grid_desc_tuple_a, const OutBGridDescTuple out_grid_desc_tuple_b, const InADataTypePointerTuple p_in_global_tuple_a, const InBDataTypePointerTuple p_in_global_tuple_b, const OutADataTypePointerTuple p_out_global_tuple_a, const OutBDataTypePointerTuple p_out_global_tuple_b, const Block2TileMapA block_2_tile_map_a, const Block2TileMapB block_2_tile_map_b, const ElementwiseOperation elementwise_op, const index_t a_grid_size)
Definition gridwise_elementwise_2d.hpp:61
int64_t long_index_t
Definition ck.hpp:300
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition gridwise_elementwise_2d.hpp:29
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Layout wrapper that performs the tensor descriptor logic.
Definition layout.hpp:24
Definition ck/stream_config.hpp:10
Definition block_to_ctile_map.hpp:261
Definition gridwise_elementwise_2d.hpp:278
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:77
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
Definition multi_index_transform.hpp:196
Definition multi_index_transform.hpp:284
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:117
Definition functional2.hpp:33
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:25
Definition transform_conv_ngchw_to_nhwgc.hpp:31
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
virtual std::string GetInstanceString() const
Definition device_base.hpp:230
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:674
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_lengths_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1035
Argument(APointers p_as, BPointers p_bs, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:731
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1036
EGridDesc_M_N e_grid_desc_m_n_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1056
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1040
NGCHWTransposeDescType a_in_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1070
std::array< index_t, NDimSpatial > input_right_pads_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1044
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1060
AElementwiseOperation a_element_op_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1081
ComputePtrOffsetOfStridedBatch< NumATensor, I1, NumDTensor > compute_ptr_offset_of_n_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1078
void Print() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1007
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1038
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_e_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1068
BGridPointer p_bs_grid_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1028
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1034
CDEElementwiseOperation cde_element_op_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1083
std::size_t GetWorkspaceETensorSizeBytes() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:987
NHWGCTransposeDescType e_in_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1071
AGridDesc_M_K a_grid_desc_m_k_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1053
std::size_t GetWorkspaceSizeBytes() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1001
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_lengths_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1033
GKYXCTransposeDescType b_out_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1073
ComputePtrOffsetOfStridedBatch< NumATensor, NumBTensor, NumDTensor > compute_ptr_offset_of_groups_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1077
index_t num_group_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1047
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1059
BElementwiseOperation b_element_op_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1082
std::array< index_t, NDimSpatial > conv_filter_strides_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1041
AGridPointer p_as_grid_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1027
NGCHWTransposeDescType e_out_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1070
std::size_t GetWorkspaceBTensorSizeBytes() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:972
void InitGridDesc()
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:676
GridwiseGemm64::DsGridPointer p_ds_grid_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1029
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1063
Block2ETileMap block_2_etile_map_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1066
GKCYXTransposeDescType b_in_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1072
NHWGCTransposeDescType a_out_transpose_desc_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1071
std::array< index_t, NDimSpatial > input_left_pads_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1043
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1055
BGridDesc_N_K b_grid_desc_n_k_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1054
std::array< index_t, NDimSpatial > conv_filter_dilations_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1042
index_t conv_N_per_block_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1051
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1039
ConvToGemmFwdTransformer conv_to_gemm_transformer_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1049
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1062
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_b_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1068
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1037
EDataType * p_e_grid_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1030
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1067
std::size_t GetWorkspaceATensorSizeBytes() const
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:957
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1088
DeviceOp::Argument Argument
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1089
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1292
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1383
float RunGemm(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1092
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1408
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:325
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1415
BlockToCTileMap_M00_N0_M01Adapt< NPerBlock, NPerBlock > Block2TileMapElementwise
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:595
std::unique_ptr< BaseArgument > MakeArgumentPointer(APointers p_as, BPointers p_bs, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op) override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1963
remove_cvref_t< decltype(MakeBGridDescriptor_N_K< BLayout >(dummy_conv_to_gemm_transformer))> BGridDesc_N_K
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:476
remove_cvref_t< decltype(GridwiseGemmCTranspose64::MakeDefaultBlock2ETileMap( EGridDesc_M_N{}))> Block2ETileMap
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:592
std::conditional_t<!isMultiA &&isMultiB, Tuple< ADataType >, ADataType > GemmADataType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:485
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeNGCHWTransposeDesc< NDimSpatial >({}, {}))> NGCHWTransposeDescType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:597
remove_cvref_t< decltype(GetBGridPointer< isMultiA||isMultiB, GridwiseGemm64, BDataType >())> BGridPointer
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:574
static auto MakeArgument(APointers p_as, BPointers p_bs, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1846
TransformConvFwdToGemm< NDimSpatial, ConvForwardSpecialization, true, ADataType, EDataType, NumGroupsToMerge, index_t, CTranspose > ConvToGemmFwdTransformer
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:371
GridwiseElementwise< Tuple< GKCYXTransposeDescType >, Tuple< GKYXCTransposeDescType >, Tuple< const BDataType * >, Tuple< BDataType * >, Block2TileMapElementwise, element_wise::PassThrough, ElementwiseBlocksize, NPerBlock, NPerBlock, NPerBlock/ClusterLengthNPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 1, 0 >, Sequence< 1 >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, I0, I1 > GridwiseElementwiseWeightTranspose
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:636
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:435
remove_cvref_t< decltype(GetAGridPointer< isMultiA||isMultiB, GridwiseGemm64, ADataType >())> AGridPointer
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:572
std::conditional_t<!isMultiB &&isMultiA, Tuple< BDataType >, BDataType > GemmBDataType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:486
GridwiseGemmMultipleABD_xdl_cshuffle< GridwiseGemmMultiABDTemplateParameters > GridwiseGemmMultipleABDBase
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:539
std::conditional_t< CTranspose, GridwiseGemmMultipleDCTransposeBase< math::max(NXdlPerWave64, 1)>, GridwiseGemm64 > GridwiseGemmCTranspose64
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:556
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:415
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:2082
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:461
remove_cvref_t< decltype(MakeAGridDescriptor_M_K< ALayout >(dummy_conv_to_gemm_transformer))> AGridDesc_M_K
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:474
remove_cvref_t< decltype(GridwiseGemmCTranspose64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:587
std::conditional_t< isMultiA||isMultiB, GridwiseGemmMultipleABDBase< math::max(NXdlPerWave64, 1)>, GridwiseGemmMultipleDBase< math::max(NXdlPerWave64, 1)> > GridwiseGemm64
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:548
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:2036
GridwiseGemmMultipleD_xdl_cshuffle< GridwiseGemmTemplateParameters > GridwiseGemmMultipleDBase
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:542
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle DeviceOp
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:326
static auto MakeInvoker()
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1918
std::conditional_t< isMultiA, std::array< const void *, NumATensor > &, const void * > APointers
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:566
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1( BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:581
std::conditional_t< isMultiA||isMultiB, GridwiseGemmMultipleABDBase< NXdlPerWave32 >, GridwiseGemmMultipleDBase< NXdlPerWave32 > > GridwiseGemm32
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:552
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))> DsGridDesc_M_N
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:478
remove_cvref_t< decltype(MakeEGridDescriptor_M_N< ELayout >(dummy_conv_to_gemm_transformer))> EGridDesc_M_N
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:480
GridwiseElementwise< Tuple< NGCHWTransposeDescType >, Tuple< NHWGCTransposeDescType >, Tuple< const ADataType * >, Tuple< ADataType * >, Block2TileMapElementwise, element_wise::PassThrough, ElementwiseBlocksize, NPerBlock, NPerBlock, NPerBlock/ClusterLengthNPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 1, 0 >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, I1, I0 > GridwiseElementwiseInputTranspose
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:618
static auto MakeArgument(APointers p_as, BPointers p_bs, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1803
GridwiseElementwise< Tuple< NHWGCTransposeDescType >, Tuple< NGCHWTransposeDescType >, Tuple< const EDataType * >, Tuple< EDataType * >, Block2TileMapElementwise, element_wise::PassThrough, ElementwiseBlocksize, NPerBlock, NPerBlock, NPerBlock/ClusterLengthNPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 1, 0 >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, I0, I1 > GridwiseElementwiseOutputTranspose
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:654
std::conditional_t< CTranspose, GridwiseGemmMultipleDCTransposeBase< NXdlPerWave32 >, GridwiseGemm32 > GridwiseGemmCTranspose32
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:560
remove_cvref_t< decltype(GridwiseGemmCTranspose64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:584
std::unique_ptr< BaseArgument > MakeArgumentPointer(APointers p_as, BPointers p_bs, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op) override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1920
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeGKCYXTransposeDesc< NDimSpatial >({}, {}))> GKCYXTransposeDescType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:604
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeGKYXCTransposeDesc< NDimSpatial >({}, {}))> GKYXCTransposeDescType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:607
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1( AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:578
std::string GetTypeString() const override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:2041
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:1798
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:395
void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:2095
std::conditional_t< isMultiB, std::array< const void *, NumBTensor > &, const void * > BPointers
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:568
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeNHWGCTransposeDesc< NDimSpatial >({}, {}))> NHWGCTransposeDescType
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:600
GridwiseGemmMultipleD_xdl_cshuffle< GridwiseGemmCTransposeTemplateParameters > GridwiseGemmMultipleDCTransposeBase
Definition device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp:545
Grouped Convolution Forward.
Definition device_grouped_conv_fwd_multiple_abd.hpp:73
__host__ __device__ constexpr auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw &c_desc_mraw_nraw) const
Definition matrix_padder.hpp:163
Definition matrix_padder.hpp:180
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129