device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp Source File

device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp Source File

Composable Kernel: device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp Source File
device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
11#include "ck/utility/env.hpp"
27
29
30namespace ck {
31namespace tensor_operation {
32namespace device {
33
34namespace {
35
36/*
 * \brief Grouped-GEMM kernel of the grouped convolution backward-data operation: a thin wrapper
 * around GridwiseGemm::Run that realizes BatchedGEMM / GroupedGEMM.
 *
 * Work-group mapping (each index is read through __builtin_amdgcn_readfirstlane):
 *  - blockIdx.x selects a tile among all GEMMs described by \p gemm_kernel_args. The owning GEMM is
 *    found with a binary search over the first \p gemms_count entries, using their half-open
 *    [BlockStart_, BlockEnd_) ranges. That GEMM's descriptors and block_2_ctile_map_ are then
 *    passed to GridwiseGemm::Run.
 *  - blockIdx.y is the convolution group index. \p compute_ptr_offset_of_batch turns it into the
 *    A, B, Ds and E base-pointer offsets.
 *  - blockIdx.z packs the split-N chunk and the split-K slice as n_idx * KBatch + k_idx.
 *    \p compute_ptr_offset_of_n turns n_idx into the A/B and E offsets. KBatch and k_idx are
 *    forwarded to GridwiseGemm::Run.
 *
 * \tparam ComputePtrOffsetOfBatch Class that computes the base-pointer offsets of the A, B, Ds and E
 * matrices for a given group. For example, ComputePtrOffsetOfStridedBatch computes the offsets of
 * evenly strided batches; other layouts can be supported the same way. The offsets are held in
 * long_index_t here, so an implementation that returns long_index_t is not subject to the 2GB
 * index_t limitation.
 *
 * \tparam MaxGroupedGemmGroupsNum Compile-time capacity of \p gemm_kernel_args. The GEMM arguments
 * are passed by value in a std::array, so no data has to be copied to a workspace before the launch,
 * even though the number of GEMMs (\p gemms_count) is a runtime value. If there are more GEMMs than
 * MaxGroupedGemmGroupsNum, the kernel is launched in a loop.
 *
 * \tparam OutElementOp InMemoryDataOperationEnum forwarded to GridwiseGemm::Run. The kernel body is
 * compiled only if GridwiseGemm::IsValidCompilationParameter<OutElementOp>() holds.
 *
 * \tparam HasMainKBlockLoopInAllGemm, NoMainKBlockLoopInAllGemm If either flag is true, all GEMMs
 * use one GridwiseGemm::Run instantiation with HasMainKBlockLoop = HasMainKBlockLoopInAllGemm.
 * Otherwise, the instantiation is chosen per GEMM at runtime from GemmArgs::HasMainKBlockLoop_.
 *
 * \tparam CTranspose When true, the roles of A and B are swapped:
 *  - the pointer passed as A is offset with GetBPtrOffset, and vice versa;
 *  - the split-N offset is applied to B instead of A.
 *
 * \note Together with \p ComputePtrOffsetOfBatch, the per-GEMM block_2_ctile_map_ (a customizable
 * mapping between a work-group and the E tile it computes) lets GridwiseGemm, and its fusions, be
 * reused to realize BatchedGEMM and GroupedGEMM. Two work-groups may compute tiles of different
 * matrices. Those matrices may share one grid descriptor (BatchedGEMM) or use their own
 * (GroupedGEMM).
 *
 * \note The body is compiled only for gfx9, gfx11 and gfx12 targets. On other targets the arguments
 * are ignored and the kernel does nothing.
 */
66template <typename GridwiseGemm,
67 typename ABDataType,
68 typename DsPointer,
69 typename EDataType,
70 index_t MaxGroupedGemmGroupsNum,
71 typename GemmArgs,
72 typename AElementwiseOp,
73 typename BElementwiseOp,
74 typename CDEElementwiseOp,
75 typename ComputePtrOffsetOfBatch,
76 typename ComputePtrOffsetOfN,
77 InMemoryDataOperationEnum OutElementOp,
78 bool HasMainKBlockLoopInAllGemm,
79 bool NoMainKBlockLoopInAllGemm,
80 bool CTranspose>
81__global__ void
82#if CK_USE_LAUNCH_BOUNDS
// NOTE(review): no launch-bounds attribute appears between this #if and #endif in this listing.
// The source presumably places a __launch_bounds__(...) specifier here - confirm.
84#endif
85 kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle(
86 const ABDataType* __restrict__ p_a_grid,
87 const ABDataType* __restrict__ p_b_grid,
88 DsPointer p_ds_grid,
89 EDataType* __restrict__ p_e_grid,
90 const std::array<GemmArgs, MaxGroupedGemmGroupsNum> gemm_kernel_args,
91 const index_t gemms_count,
92 const AElementwiseOp a_element_op,
93 const BElementwiseOp b_element_op,
94 const CDEElementwiseOp cde_element_op,
95 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
96 const ComputePtrOffsetOfN compute_ptr_offset_of_n,
97 const index_t KBatch)
98{
99#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
100 if constexpr(GridwiseGemm::template IsValidCompilationParameter<OutElementOp>())
101 {
102 // offset base pointer for each work-group
// blockIdx.x: tile id across the GEMMs in gemm_kernel_args; blockIdx.y: convolution group;
// blockIdx.z: n_idx * KBatch + k_idx (split-N chunk and split-K slice).
103 const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
104 const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
105 const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch);
106 const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch);
107
// Per-group offsets. With CTranspose the A/B roles are swapped, so p_a_grid takes the B offset
// and p_b_grid takes the A offset.
108 const long_index_t a_batch_offset =
109 CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))
110 : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
111 const long_index_t b_batch_offset =
112 CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))
113 : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
114 const long_index_t e_batch_offset =
115 amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
116
117 const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
118
// Split-N offset, computed with GetAPtrOffset. It is applied to p_a_grid normally and to p_b_grid
// when CTranspose (matching the A/B swap above). E always receives its split-N offset.
119 const long_index_t a_n_offset =
120 CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
121 const long_index_t b_n_offset =
122 CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0;
123
124 const long_index_t e_n_offset =
125 amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
126
// Shared-memory scratch for GridwiseGemm::Run, sized by GridwiseGemm.
127 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
128
129 DsPointer p_ds_grid_grp;
130
131 static constexpr index_t NumDTensor = DsPointer::Size();
132
// Only the per-group offset is applied to the D pointers; no split-N offset is added.
// NOTE(review): confirm the Ds are addressed correctly when blockIdx.z covers more than one
// N chunk.
133 static_for<0, NumDTensor, 1>{}(
134 [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
135
// Binary search for the GEMM whose half-open [BlockStart_, BlockEnd_) range contains
// block_args_id. Assumes the ranges are sorted and contiguous (presumably guaranteed by the host).
// NOTE(review): group_id = (left + right) / 2 always lies in [left, right], so `left <= right`
// never becomes false. The loop terminates only if some range contains block_args_id.
136 index_t left = 0;
137 index_t right = gemms_count;
138 index_t group_id = index_t((left + right) / 2);
139 while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ &&
140 block_args_id < gemm_kernel_args[group_id].BlockEnd_)) &&
141 left <= right)
142 {
143 if(block_args_id < gemm_kernel_args[group_id].BlockStart_)
144 {
145 right = group_id;
146 }
147 else
148 {
149 left = group_id;
150 }
151 group_id = index_t((left + right) / 2);
152 }
153
// Uniform K-loop case: one Run instantiation serves every GEMM.
154 if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm)
155 {
156 GridwiseGemm::template Run<HasMainKBlockLoopInAllGemm, OutElementOp>(
157 p_a_grid + a_batch_offset + a_n_offset,
158 p_b_grid + b_batch_offset + b_n_offset,
159 p_ds_grid_grp,
160 p_e_grid + e_batch_offset + e_n_offset,
161 p_shared,
162 a_element_op,
163 b_element_op,
164 cde_element_op,
165 gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
166 gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
167 gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
168 gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
169 gemm_kernel_args[group_id].block_2_ctile_map_,
170 KBatch,
171 k_idx);
172 }
173 else
174 {
// Mixed case: pick the Run instantiation for this GEMM at runtime.
175 if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
176 {
177 GridwiseGemm::template Run<true, OutElementOp>(
178 p_a_grid + a_batch_offset + a_n_offset,
179 p_b_grid + b_batch_offset + b_n_offset,
180 p_ds_grid_grp,
181 p_e_grid + e_batch_offset + e_n_offset,
182 p_shared,
183 a_element_op,
184 b_element_op,
185 cde_element_op,
186 gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
187 gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
188 gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
189 gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
190 gemm_kernel_args[group_id].block_2_ctile_map_,
191 KBatch,
192 k_idx);
193 }
194 else
195 {
196 GridwiseGemm::template Run<false, OutElementOp>(
197 p_a_grid + a_batch_offset + a_n_offset,
198 p_b_grid + b_batch_offset + b_n_offset,
199 p_ds_grid_grp,
200 p_e_grid + e_batch_offset + e_n_offset,
201 p_shared,
202 a_element_op,
203 b_element_op,
204 cde_element_op,
205 gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
206 gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
207 gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
208 gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
209 gemm_kernel_args[group_id].block_2_ctile_map_,
210 KBatch,
211 k_idx);
212 }
213 }
214 }
215#else
// Unsupported target: every argument is only assigned to `ignore`, so the kernel does nothing.
216 ignore = p_a_grid;
217 ignore = p_b_grid;
218 ignore = p_ds_grid;
219 ignore = p_e_grid;
220 ignore = gemm_kernel_args;
221 ignore = gemms_count;
222 ignore = a_element_op;
223 ignore = b_element_op;
224 ignore = cde_element_op;
225 ignore = compute_ptr_offset_of_batch;
226 ignore = compute_ptr_offset_of_n;
227 ignore = KBatch;
228#endif
229}
230
231} // namespace
232
233// Conv backward data multiple D:
234// input : output image A: [G, N, K, Ho, Wo]
235// input : weight B: [G, K, C, Y, X],
236// input : D0, D1, ... : [G, N, C, Hi, Wi] (same shape as E)
237// output : input image E: [G, N, C, Hi, Wi]
238// C = a_op(A) * b_op(B)
239// E = cde_op(C, D0, D1, ...)
240template <index_t NDimSpatial,
241 typename ALayout, // output image
242 typename BLayout, // weight
243 typename DsLayout, // bias
244 typename ELayout, // input image
245 typename ADataType, // output image
246 typename BDataType, // weight
247 typename AccDataType,
248 typename CShuffleDataType,
249 typename DsDataType, // bias
250 typename EDataType, // input image
251 typename AElementwiseOp, // output image
252 typename BElementwiseOp, // weight
253 typename CDEElementwiseOp, // C, bias, and input image
254 ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
255 bool DoPadGemmM,
256 bool DoPadGemmN,
257 index_t NumGemmKPrefetchStage,
258 index_t BlockSize,
259 index_t MPerBlock,
260 index_t NPerBlock,
261 index_t KPerBlock,
262 index_t AK1,
263 index_t BK1,
264 index_t MPerXDL,
265 index_t NPerXDL,
266 index_t MXdlPerWave,
267 index_t NXdlPerWave,
268 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
269 typename ABlockTransferThreadClusterArrangeOrder,
270 typename ABlockTransferSrcAccessOrder,
271 index_t ABlockTransferSrcVectorDim,
272 index_t ABlockTransferSrcScalarPerVector,
273 index_t ABlockTransferDstScalarPerVector_AK1,
274 index_t ABlockLdsExtraM,
275 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
276 typename BBlockTransferThreadClusterArrangeOrder,
277 typename BBlockTransferSrcAccessOrder,
278 index_t BBlockTransferSrcVectorDim,
279 index_t BBlockTransferSrcScalarPerVector,
280 index_t BBlockTransferDstScalarPerVector_BK1,
281 index_t BBlockLdsExtraN,
282 index_t CShuffleMXdlPerWavePerShuffle,
283 index_t CShuffleNXdlPerWavePerShuffle,
284 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
285 index_t CDEBlockTransferScalarPerVector_NPerBlock,
287 typename AComputeType = ADataType,
288 typename BComputeType = AComputeType,
289 index_t MaxTransposeTransferInScalarPerVector = 1,
290 index_t MaxTransposeTransferOutScalarPerVector = 1>
292 : public DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
293 ALayout, // output image
294 BLayout, // weight
295 DsLayout, // bias
296 ELayout, // input image
297 ADataType, // output image
298 BDataType, // weight
299 DsDataType, // bias
300 EDataType, // input image
301 AElementwiseOp,
302 BElementwiseOp,
303 CDEElementwiseOp,
304 AComputeType,
305 BComputeType>
306{
307 // TODO: Extend support for more spatial dimensions.
308 static_assert(NDimSpatial == 2 || NDimSpatial == 3,
309 "wrong! only implemented for 2D and 3D now");
310
311 // MaxGroupedGemmGroupsNum sets, at compile time, how many GEMM argument sets the kernel receives
312 // by value. This avoids copying the arguments to a workspace before the kernel launch, even though
313 // the number of GEMMs is only known at runtime. If there are more GEMMs than
314 // MaxGroupedGemmGroupsNum, the kernel is launched in a loop.
316 ConvBackwardDataSpecialization ==
318 ? 1
319 : 32;
320
323 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
324 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
325
326 static constexpr index_t NumDTensor = DsDataType::Size();
328 static constexpr bool IsSplitKSupported =
329 (CDEBlockTransferScalarPerVector_NPerBlock % 2 == 0 || sizeof(EDataType) % 4 == 0) &&
331
332 // TODO: Add support for different A and B data types.
333 using ABDataType = ADataType;
334
335 static constexpr auto I0 = Number<0>{};
336 static constexpr auto I1 = Number<1>{};
337 static constexpr auto I2 = Number<2>{};
338 static constexpr auto I3 = Number<3>{};
339
340 static constexpr bool isATensorColMajor =
341 (ConvBackwardDataSpecialization ==
343 (ABlockTransferSrcVectorDim == 1) &&
346
347 static constexpr bool NeedTransposeKernel =
350
351 static constexpr bool CTranspose =
354
355 using ALayoutAfterTranspose = std::conditional_t<
358 std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
360 ALayout>>;
361 using BLayoutAfterTranspose = std::conditional_t<
364 std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ELayout, BLayout, ALayout>() &&
367 BLayout>>;
368 using ELayoutAfterTranspose = std::conditional_t<
371 std::conditional_t<is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
373 ELayout>>;
374
376 ConvBackwardDataSpecialization,
377 AK1,
378 BK1,
379 MPerBlock,
380 NPerBlock,
381 KPerBlock,
382 DoPadGemmM,
383 DoPadGemmN,
387 true, /*SplitConvN*/
389 EDataType,
390 1,
391 index_t,
392 CTranspose>;
393
394 static auto
396 {
397 const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1();
398
399 const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1();
400
401 const auto ds_grid_desc_m_n = generate_tuple(
402 [&](auto i) {
403 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
404 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
405 using ConvToGemmBwdDataTransformD =
407 ConvBackwardDataSpecialization,
408 AK1,
409 BK1,
410 MPerBlock,
411 NPerBlock,
412 KPerBlock,
413 DoPadGemmM,
414 DoPadGemmN,
416 BLayout,
417 DLayout,
418 true, /*SplitConvN*/
420 DDataType,
421 1, /*index_t NumGroupsToMerge = 1,*/
422 index_t, /* typename IndexType = */
423 CTranspose>;
424 return ConvToGemmBwdDataTransformD{}.MakeCDescriptor_M_N();
425 },
427
428 const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
429 if constexpr(CTranspose)
430 {
431 return make_tuple(
432 b_grid_desc_bk0_n_bk1, a_grid_desc_ak0_m_ak1, ds_grid_desc_m_n, e_grid_desc_m_n);
433 }
434 else
435 {
436 return make_tuple(
437 a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
438 }
439 }
440
441// GridwiseGemm
442#define GridwiseGemmMultiDTemplateParams \
443 ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
444 AElementwiseOp, BElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \
445 MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, \
446 ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
447 ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
448 ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
449 ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
450 BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
451 BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
452 BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
453 CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
454 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
455 CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
456
457#define GridwiseGemmCTransposeTemplateParameters \
458 ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
459 BElementwiseOp, AElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \
460 NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, MPerXDL, NXdlPerWave_, MXdlPerWave, \
461 BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \
462 BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \
463 BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \
464 BBlockLdsExtraN, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
465 ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
466 ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
467 ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \
468 CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
469 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
470 CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
471
472 template <index_t NXdlPerWave_>
474 template <index_t NXdlPerWave_>
479
481 std::conditional_t<CTranspose,
485 std::conditional_t<CTranspose, GridwiseGemmCTransposeBase<NXdlPerWave32>, GridwiseGemm32>;
486
487 template <typename EGridDesc_M_N>
488 static auto
490 {
491 return GridwiseGemmCTranspose64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
492 e_grid_desc_m_n);
493 }
494
495 template <typename Desc_K0_M_K1>
496 static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1)
497 {
498 const auto grid_desc_m_k = transform_tensor_descriptor(
499 desc_k0_m_k1,
500 make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)),
502 make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))),
505
506 return grid_desc_m_k;
507 }
508
509 // desc
512
517
520
522 decltype(GridwiseGemmCTranspose64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
523 DsGridDesc_M_N{}));
526
527 // block-to-e-tile map
529 decltype(GridwiseGemmCTranspose64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}));
530
532
533 struct GemmArgs
534 {
535 GemmArgs() = default;
536 GemmArgs(AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
537 BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
539 ds_grid_desc_mblock_mperblock_nblock_nperblock,
541 e_grid_desc_mblock_mperblock_nblock_nperblock,
542 GroupedGemmBlock2ETileMap block_2_ctile_map,
543 index_t BlockStart,
544 index_t BlockEnd,
545 bool HasMainKBlockLoop)
546 : a_grid_desc_ak0_m_ak1_(a_grid_desc_ak0_m_ak1),
547 b_grid_desc_bk0_n_bk1_(b_grid_desc_bk0_n_bk1),
548
550 ds_grid_desc_mblock_mperblock_nblock_nperblock),
551
553 e_grid_desc_mblock_mperblock_nblock_nperblock),
554
555 // block-to-e-tile map
556 block_2_ctile_map_(block_2_ctile_map),
557 BlockStart_(BlockStart),
558 BlockEnd_(BlockEnd),
559 HasMainKBlockLoop_(HasMainKBlockLoop)
560
561 {
562 }
563 // tensor descriptors for block/thread-wise copy
569
570 // block-to-e-tile map
574 };
577
579 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
581 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
582
583 static constexpr auto conv_ngchw_to_nhwgc_transformer =
585 BLayout,
586 ALayout,
587 NDimSpatial,
588 NPerBlock / ClusterLengthNPerBlock,
589 MPerBlock / ClusterLengthMPerBlock>{};
590
592 std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferInScalarPerVector);
594 std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferOutScalarPerVector);
595
598 .template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
601 .template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
604 .template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
607 .template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
608
610
619 NPerBlock,
620 MPerBlock,
621 NPerBlock / ClusterLengthNPerBlock,
622 MPerBlock / ClusterLengthMPerBlock,
626 I1,
627 I0>;
628
637 MPerBlock,
638 NPerBlock,
639 MPerBlock / ClusterLengthMPerBlock,
640 NPerBlock / ClusterLengthNPerBlock,
644 I0,
645 I1>;
646
655 NPerBlock,
656 MPerBlock,
657 NPerBlock / ClusterLengthNPerBlock,
658 MPerBlock / ClusterLengthMPerBlock,
662 I0,
663 I1>;
664 // Argument
665 struct Argument : public BaseArgument
666 {
667 Argument(const void* p_a, // output image
668 const void* p_b, // weight
669 const std::array<const void*, NumDTensor>& p_ds, // bias
670 void* p_e, // input image
671 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths,
672 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
673 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
674 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
675 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
676 ds_g_n_c_wis_lengths,
677 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
678 ds_g_n_c_wis_strides,
679 const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths,
680 const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides,
681 const std::array<index_t, NDimSpatial>& conv_filter_strides,
682 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
683 const std::array<index_t, NDimSpatial>& input_left_pads,
684 const std::array<index_t, NDimSpatial>& input_right_pads,
685 const AElementwiseOp& a_element_op,
686 const BElementwiseOp& b_element_op,
687 const CDEElementwiseOp& cde_element_op,
688 ck::index_t split_k = 1)
689 : p_a_grid_{static_cast<const ADataType*>(p_a)},
690 p_b_grid_{static_cast<const BDataType*>(p_b)},
691 p_ds_grid_{},
692 p_e_grid_{static_cast<EDataType*>(p_e)},
693 num_group_{a_g_n_k_wos_lengths[0]},
694 a_element_op_{a_element_op},
695 b_element_op_{b_element_op},
696 cde_element_op_{cde_element_op},
697 a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
698 b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
699 e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
700 conv_filter_strides_{conv_filter_strides},
701 input_left_pads_{input_left_pads},
702 input_right_pads_{input_right_pads},
703 k_batch_{split_k}
704 {
705 bool image_covered_dilation = true;
706 bool image_covered_strides = true;
707 for(index_t d = 0; d < NDimSpatial; d++)
708 {
709 // If neither the dilation nor the stride is 1, some input positions may be left uncovered
710 image_covered_dilation &=
711 conv_filter_dilations[d] == 1 || conv_filter_strides[d] == 1;
712 // If the stride is larger than the filter window size, some input positions are left uncovered
713 image_covered_strides &= conv_filter_strides[d] <= b_g_k_c_xs_lengths[d + I3];
714 }
715 bool if_d_is_output_mem = false;
716 const void* out_mem_void = static_cast<const void*>(p_e);
717 static_for<0, NumDTensor, 1>{}([&](auto i) {
718 if(p_ds[i] == out_mem_void)
719 {
720 if_d_is_output_mem = true;
721 }
722 });
723
724 bwd_needs_zero_out = k_batch_ > 1 || !image_covered_dilation || !image_covered_strides;
725
726 // Temporary workaround until the conditions above are proven or fixed; overrides the value above.
727 bwd_needs_zero_out = !if_d_is_output_mem;
730 e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
731 sizeof(EDataType);
732
733 std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
735 a_g_n_k_wos_lengths, a_g_n_k_wos_strides)
736 : a_g_n_k_wos_strides;
737 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_transposed =
739 b_g_k_c_xs_lengths, b_g_k_c_xs_strides)
740 : b_g_k_c_xs_strides;
741 std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_transposed =
743 e_g_n_c_wis_lengths, e_g_n_c_wis_strides)
744 : e_g_n_c_wis_strides;
745
746 // populate Ds pointer
747 static_for<0, NumDTensor, 1>{}([&](auto i) {
748 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
749
750 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
751 });
752
753 static_for<0, NumDTensor, 1>{}([&](auto i) {
754 compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0];
755 });
756
757 static constexpr auto NonSpatialDimsNum = Number<3>{};
758
759 static constexpr auto DIdx = Number<NonSpatialDimsNum>{};
760 static constexpr auto HIdx =
762 static constexpr auto WIdx = NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{}
764
765 static constexpr auto ZIdx = Number<NonSpatialDimsNum>{};
766 static constexpr auto YIdx =
768 static constexpr auto XIdx = NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{}
770
771 // problem definition
772 const index_t Z = b_g_k_c_xs_lengths[ZIdx];
773 const index_t Y = b_g_k_c_xs_lengths[YIdx];
774 const index_t X = b_g_k_c_xs_lengths[XIdx];
775
776 const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
777 const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
778 const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
779
780 const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
781 const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
782 const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
783
784 const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
785 const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
786 const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
787
788 const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1;
789 const auto YTilde = ConvStrideH / GcdStrideDilationH;
790 const auto XTilde = ConvStrideW / GcdStrideDilationW;
791
792 index_t grid_size = 0;
793 // Allocate place for sets of gemms
794 gemm_kernel_args_.resize(
795 math::integer_divide_ceil(ZTilde * YTilde * XTilde, MaxGroupedGemmGroupsNum));
796
797 for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
798 {
799 for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
800 {
801 for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
802 {
803 // check slice is valid
804 const auto ZDotSlice =
805 NDimSpatial == 3 ? math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1;
806 const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
807 const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
808
809 if(YDotSlice * XDotSlice * ZDotSlice <= 0)
810 {
811 continue;
812 }
813
814 std::array<index_t, NDimSpatial> tildes;
815 if constexpr(NDimSpatial == 2)
816 {
817 tildes = {i_ytilde, i_xtilde};
818 }
819 else if constexpr(NDimSpatial == 3)
820 {
821 tildes = {i_ztilde, i_ytilde, i_xtilde};
822 }
823 else
824 {
825 throw std::runtime_error("wrong! only implemented for 2D and 3D now");
826 }
827
828 ConvToGemmBwdDataTransform conv_to_gemm_transform_{
829 a_g_n_k_wos_lengths,
830 a_g_n_k_wos_strides_transposed,
831 b_g_k_c_xs_lengths,
832 b_g_k_c_xs_strides_transposed,
833 e_g_n_c_wis_lengths,
834 e_g_n_c_wis_strides_transposed,
835 conv_filter_strides,
836 conv_filter_dilations,
837 input_left_pads,
838 input_right_pads,
839 tildes,
840 k_batch_};
841
842 conv_N_per_block_ = conv_to_gemm_transform_.N_;
843
844 const auto a_grid_desc_ak0_m_ak1 = [&]() {
845 if constexpr(CTranspose)
846 {
847 return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
848 }
849 else
850 {
851 return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
852 }
853 }();
854
855 const auto b_grid_desc_bk0_n_bk1 = [&]() {
856 if constexpr(CTranspose)
857 {
858 return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
859 }
860 else
861 {
862 return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
863 }
864 }();
865 DsGridDesc_M_N ds_grid_desc_m_n;
866
867 // populate Ds desc
868 static_for<0, NumDTensor, 1>{}([&](auto i) {
869 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
870 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
871 using ConvToGemmBwdDataTransformD =
873 ConvBackwardDataSpecialization,
874 AK1,
875 BK1,
876 MPerBlock,
877 NPerBlock,
878 KPerBlock,
879 DoPadGemmM,
880 DoPadGemmN,
883 DLayout,
884 true, /*SplitConvN*/
886 DDataType,
887 1,
888 index_t,
889 CTranspose>;
890 ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{
891 a_g_n_k_wos_lengths,
892 a_g_n_k_wos_strides_transposed,
893 b_g_k_c_xs_lengths,
894 b_g_k_c_xs_strides_transposed,
895 ds_g_n_c_wis_lengths[i],
896 ds_g_n_c_wis_strides[i],
897 conv_filter_strides,
898 conv_filter_dilations,
899 input_left_pads,
900 input_right_pads,
901 tildes};
902
903 ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N();
904 });
905
906 const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N();
907
908 // desc for problem definition
909 const auto a_grid_desc_m_k =
910 transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1);
911 const auto b_grid_desc_n_k =
912 transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1);
913
914 a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k);
915 b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k);
916 ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
917 e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n);
918
919 const index_t grid_size_grp = Block2ETileMap::CalculateGridSize(
920 e_grid_desc_m_n.GetLength(I0), e_grid_desc_m_n.GetLength(I1));
921
922 const index_t BlockStart = grid_size;
923 const index_t BlockEnd = grid_size + grid_size_grp;
924
925 grid_size += grid_size_grp;
926
927 // block-to-e-tile map
928 const auto block_2_etile_map =
929 GroupedGemmBlock2ETileMap(Block2ETileMap(e_grid_desc_m_n.GetLength(I0),
930 e_grid_desc_m_n.GetLength(I1)),
931 BlockStart);
932
933 const auto GemmK = a_grid_desc_m_k.GetLength(I1);
934 const bool HasMainKBlockLoop =
935 GridwiseGemmCTranspose64::CalculateHasMainKBlockLoop(GemmK, k_batch_);
936
940 GemmArgs{a_grid_desc_ak0_m_ak1,
941 b_grid_desc_bk0_n_bk1,
942 GridwiseGemmCTranspose64::
943 MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
944 ds_grid_desc_m_n),
946 e_grid_desc_m_n),
947 block_2_etile_map,
948 BlockStart,
949 BlockEnd,
950 HasMainKBlockLoop};
951 gemms_count_++;
953 {
954 gemms_grid_size_.push_back(grid_size);
955 grid_size = 0;
956 }
957 }
958 }
959 }
960 gemm_kernel_args_.resize(
962 gemms_grid_size_.push_back(grid_size);
963
964 // A/B/Ds/E Batch Stride
965 compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
966 compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides_transposed[0];
967 compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides_transposed[0];
968
969 compute_ptr_offset_of_n_.BatchStrideA_ =
970 a_g_n_k_wos_strides_transposed[1] * conv_N_per_block_;
971 compute_ptr_offset_of_n_.BatchStrideE_ =
972 e_g_n_c_wis_strides_transposed[1] * conv_N_per_block_;
973
975
976 if constexpr(NeedTransposeKernel)
977 {
978 // Use not modified base strides
980 conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
981 a_g_n_k_wos_lengths, a_g_n_k_wos_strides, num_workgroups_per_Conv_N_);
983 conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
984 a_g_n_k_wos_lengths, a_g_n_k_wos_strides, num_workgroups_per_Conv_N_);
985
987 conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
988 b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
990 conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
991 b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
992
994 conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
995 e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_);
997 conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
998 e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_);
999
1001 a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
1003 b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
1005 e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
1006
1008 a_g_n_k_wos_strides[1] * conv_N_per_block_;
1010 e_g_n_c_wis_strides[1] * conv_N_per_block_;
1011 }
1012 }
1013
1015 {
1016 if constexpr(NeedTransposeKernel)
1017 {
1019 a_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
1020 // Align to 128B
1021 return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128;
1022 }
1023 else
1024 {
1025 return 0;
1026 }
1027 }
1028
1030 {
1031 if constexpr(NeedTransposeKernel)
1032 {
1034 b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
1035 // Align to 128B
1036 return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128;
1037 }
1038 else
1039 {
1040 return 0;
1041 }
1042 }
1043
1045 {
1046 if constexpr(NeedTransposeKernel)
1047 {
1049 e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
1050 return sizeof(EDataType) * e_accum;
1051 }
1052 else
1053 {
1054 return 0;
1055 }
1056 }
1057
1063
1064 void Print() const
1065 {
1066 for(std::size_t i = 0; i < a_grid_desc_m_k_container_.size(); i++)
1067 {
1068 std::cout << "a_grid_desc_m_ak_container_" << a_grid_desc_m_k_container_[i]
1069 << std::endl;
1070
1071 std::cout << "b_grid_desc_n_bk_container_" << b_grid_desc_n_k_container_[i]
1072 << std::endl;
1073
1074 static_for<0, NumDTensor, 1>{}([&](auto j) {
1075 std::cout << "ds_grid_desc_mblock_mperblock_nblock_nperblock_container_"
1076 << ds_grid_desc_m_n_container_[i][j] << std::endl;
1077 });
1078
1079 std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_"
1080 << e_grid_desc_m_n_container_[i] << std::endl;
1081 }
1082 }
1083
1084 // pointers
1085 const ADataType* p_a_grid_;
1086 const BDataType* p_b_grid_;
1088 EDataType* p_e_grid_;
1089
1090 // tensor descriptor for problem definition
1093 std::vector<AGridDesc_M_K> a_grid_desc_m_k_container_;
1094 std::vector<BGridDesc_N_K> b_grid_desc_n_k_container_;
1095 std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_;
1096 std::vector<EGridDesc_M_N> e_grid_desc_m_n_container_;
1097
1098 // block-to-e-tile map
1102
1107
1108 // for computing batch offset
1109 ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
1110 ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
1111 ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_workspace_n_;
1112
1113 // element-wise op
1114 AElementwiseOp a_element_op_;
1115 BElementwiseOp b_element_op_;
1116 CDEElementwiseOp cde_element_op_;
1117
1118 std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_;
1119 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
1120 std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_lengths_;
1121 std::array<index_t, NDimSpatial> conv_filter_strides_;
1122 std::array<index_t, NDimSpatial> input_left_pads_;
1123 std::array<index_t, NDimSpatial> input_right_pads_;
1124
1127 std::vector<index_t> gemms_grid_size_;
1129 std::vector<std::array<GemmArgs, MaxGroupedGemmGroupsNum>> gemm_kernel_args_;
1130
1133 };
1134
1135 // Invoker
1136 struct Invoker : public BaseInvoker
1137 {
1139
1140 template <typename GridwiseGemm,
1141 typename GridwiseGemmCTranspose,
1142 InMemoryDataOperationEnum ElementOp>
1143 float RunMultiDGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
1144 {
1145 float ave_time = 0;
1146
1147 const index_t gdy = arg.num_group_;
1148 const index_t gdz = arg.num_workgroups_per_Conv_N_ * arg.k_batch_;
1149
1150 const ADataType* p_a_grid = arg.p_a_grid_;
1151 const BDataType* p_b_grid = arg.p_b_grid_;
1152 EDataType* p_e_grid = arg.p_e_grid_;
1153 if constexpr(NeedTransposeKernel)
1154 {
1157 {
1159 p_e_grid =
1162 sizeof(EDataType);
1163 }
1164
1167 {
1169 arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
1170 }
1171 }
1172 for(std::size_t gemm_set_id = 0; gemm_set_id < arg.gemm_kernel_args_.size();
1173 gemm_set_id++)
1174 {
1175 const index_t gdx = arg.gemms_grid_size_[gemm_set_id];
1176 const index_t gemms_count_for_set =
1177 gemm_set_id == arg.gemm_kernel_args_.size() - 1
1178 ? arg.gemms_count_ - MaxGroupedGemmGroupsNum * gemm_set_id
1180 const std::array<GemmArgs, MaxGroupedGemmGroupsNum>& gemm_kernel_args =
1181 arg.gemm_kernel_args_[gemm_set_id];
1182
1183 const auto clear_workspace = [&]() {
1184 if(arg.bwd_needs_zero_out && gemm_set_id == 0)
1185 {
1186 hip_check_error(hipMemsetAsync(
1187 p_e_grid, 0, arg.e_space_size_bytes, stream_config.stream_id_));
1188 }
1189 };
1190
1191 bool has_loop_in_all_gemm = true;
1192 bool no_loop_in_all_gemm = true;
1193 for(auto i = 0; i < gemms_count_for_set; i++)
1194 {
1195 has_loop_in_all_gemm &= gemm_kernel_args[i].HasMainKBlockLoop_;
1196 no_loop_in_all_gemm &= !gemm_kernel_args[i].HasMainKBlockLoop_;
1197 }
1198
1199 auto launch_kernel = [&](auto has_main_k_block_loop, auto no_main_k_block_loop) {
1200 constexpr bool has_main_loop = has_main_k_block_loop.value;
1201 constexpr bool no_main_loop = no_main_k_block_loop.value;
1202 if constexpr(CTranspose)
1203 {
1204 const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
1205 GridwiseGemmCTranspose,
1206 ADataType, // TODO: distiguish A/B datatype
1207 typename GridwiseGemm::DsGridPointer,
1208 EDataType,
1210 GemmArgs,
1211 BElementwiseOp,
1212 AElementwiseOp,
1213 CDEElementwiseOp,
1214 ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
1215 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1216 ElementOp,
1217 has_main_loop,
1218 no_main_loop,
1219 CTranspose>;
1220
1222 stream_config,
1223 clear_workspace,
1224 kernel,
1225 dim3(gdx, gdy, gdz),
1226 dim3(BlockSize),
1227 0,
1228 p_b_grid,
1229 p_a_grid,
1230 arg.p_ds_grid_,
1231 p_e_grid,
1232 gemm_kernel_args,
1233 gemms_count_for_set,
1234 arg.b_element_op_,
1235 arg.a_element_op_,
1236 arg.cde_element_op_,
1239 arg.k_batch_);
1240 }
1241 else
1242 {
1243 const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
1244 GridwiseGemm,
1245 ADataType, // TODO: distiguish A/B datatype
1246 typename GridwiseGemm::DsGridPointer,
1247 EDataType,
1249 GemmArgs,
1250 AElementwiseOp,
1251 BElementwiseOp,
1252 CDEElementwiseOp,
1253 ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
1254 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1255 ElementOp,
1256 has_main_loop,
1257 no_main_loop,
1258 CTranspose>;
1259
1261 stream_config,
1262 clear_workspace,
1263 kernel,
1264 dim3(gdx, gdy, gdz),
1265 dim3(BlockSize),
1266 0,
1267 p_a_grid,
1268 p_b_grid,
1269 arg.p_ds_grid_,
1270 p_e_grid,
1271 gemm_kernel_args,
1272 gemms_count_for_set,
1273 arg.a_element_op_,
1274 arg.b_element_op_,
1275 arg.cde_element_op_,
1278 arg.k_batch_);
1279 }
1280 };
1281 if(has_loop_in_all_gemm)
1282 {
1283 ave_time += launch_kernel(integral_constant<bool, true>{},
1284 integral_constant<bool, false>{});
1285 }
1286 else if(no_loop_in_all_gemm)
1287 {
1288 ave_time += launch_kernel(integral_constant<bool, false>{},
1289 integral_constant<bool, true>{});
1290 }
1291 else
1292 {
1293 ave_time += launch_kernel(integral_constant<bool, false>{},
1294 integral_constant<bool, false>{});
1295 }
1296 }
1297
1298 return ave_time;
1299 }
1300
1301 template <typename GridwiseGemm, typename GridwiseGemmCTranspose>
1302 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
1303 {
1304 float ave_time = 0;
1305
1306 if(stream_config.log_level_ > 0)
1307 {
1308 arg.Print();
1309 }
1310
1311 // Transpose from NGKHW to NHWGK
1312 if constexpr(NeedTransposeKernel)
1313 {
1314 EDataType* p_e_in_grid =
1317 sizeof(EDataType);
1318
1319 const auto clear_workspace = [&]() {
1320 hip_check_error(hipMemsetAsync(p_e_in_grid,
1321 0,
1323 stream_config.stream_id_));
1324 };
1325
1326 const index_t a_grid_size =
1327 arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
1330 const index_t b_grid_size =
1333 ? arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
1335 : 0; // Dont run transpose B if not needed
1336
1337 ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
1338 BDataType* p_b_out_grid = type_convert<BDataType*>(arg.p_workspace_) +
1339 arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
1340
1341 auto kernel_transpose =
1355 I1,
1356 I1,
1357 I1,
1358 I1>;
1359
1361 stream_config,
1362 clear_workspace,
1363 kernel_transpose,
1364 dim3(a_grid_size + b_grid_size),
1366 0,
1371 make_tuple(arg.p_a_grid_),
1372 make_tuple(arg.p_b_grid_),
1373 make_tuple(p_a_out_grid),
1374 make_tuple(p_b_out_grid),
1378 a_grid_size,
1380 I1, // B is not splited per N
1381 std::array<index_t, I1>{
1382 static_cast<index_t>(arg.compute_ptr_offset_of_workspace_n_.BatchStrideA_)},
1383 std::array<index_t, I1>{0},
1384 std::array<index_t, I1>{
1385 static_cast<index_t>(arg.compute_ptr_offset_of_n_.BatchStrideA_)},
1386 std::array<index_t, I1>{0});
1387 }
1388 if(arg.k_batch_ > 1)
1389 {
1390 if constexpr(IsSplitKSupported)
1391 {
1392 ave_time +=
1393 RunMultiDGemm<GridwiseGemm,
1394 GridwiseGemmCTranspose,
1395 InMemoryDataOperationEnum::AtomicAdd>(arg, stream_config);
1396 }
1397 }
1398 else
1399 {
1400 ave_time += RunMultiDGemm<GridwiseGemm,
1401 GridwiseGemmCTranspose,
1402 InMemoryDataOperationEnum::Set>(arg, stream_config);
1403 }
1404
1405 // Transpose from NHWGC to NGCHW
1406 if constexpr(NeedTransposeKernel)
1407 {
1408 const index_t grid_size =
1409 arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
1412
1413 const EDataType* p_e_in_grid =
1416 sizeof(EDataType);
1417
1418 EDataType* p_e_out_grid = arg.p_e_grid_;
1419
1420 auto kernel_transpose =
1422 ck::Tuple<NHWGCTransposeDescType>,
1423 ck::Tuple<NGCHWTransposeDescType>,
1424 ck::Tuple<const EDataType*>,
1425 ck::Tuple<EDataType*>,
1427 element_wise::PassThrough,
1428 I1,
1429 I1>;
1430
1431 ave_time += launch_and_time_kernel(
1432 stream_config,
1433 kernel_transpose,
1434 dim3(grid_size),
1436 0,
1439 make_tuple(p_e_in_grid),
1440 make_tuple(p_e_out_grid),
1442 element_wise::PassThrough{},
1444 std::array<index_t, I1>{
1445 static_cast<index_t>(arg.compute_ptr_offset_of_n_.BatchStrideE_)},
1446 std::array<index_t, I1>{static_cast<index_t>(
1447 arg.compute_ptr_offset_of_workspace_n_.BatchStrideE_)});
1448 }
1449
1450 return ave_time;
1451 }
1452 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
1453 {
1454 if(get_warp_size() == 64)
1455 {
1456 if constexpr(NXdlPerWave64 > 0)
1457 {
1458 return RunImp<GridwiseGemm64, GridwiseGemmCTranspose64>(arg, stream_config);
1459 }
1460 else
1461 {
1462 return 0;
1463 }
1464 }
1465 else
1466 {
1467 if constexpr(NXdlPerWave32 > 0)
1468 {
1469 return RunImp<GridwiseGemm32, GridwiseGemmCTranspose32>(arg, stream_config);
1470 }
1471 else
1472 {
1473 return 0;
1474 }
1475 }
1476 }
1477
1478 float Run(const BaseArgument* p_arg,
1479 const StreamConfig& stream_config = StreamConfig{}) override
1480 {
1481 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
1482 }
1483 };
1484
1485 static bool IsSupportedArgument(const Argument& arg)
1486 {
1487 // gfx11 doesn't support float atomic
1488 // Todo: Enable splitK for gfx12
1490 {
1491 return false;
1492 }
1494 {
1495 return false;
1496 }
1497 if(!is_bf16_atomic_supported() && std::is_same_v<EDataType, ck::bhalf_t> &&
1498 arg.k_batch_ > 1)
1499 {
1500 return false;
1501 }
1503 {
1504 if(!is_tf32_supported())
1505 {
1506 return false;
1507 }
1509 {
1510 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1511 {
1512 std::cout << "ComputeDataType for A and B should be same while using TF32"
1513 << std::endl;
1514 }
1515 return false;
1516 }
1517 }
1518
1519 if constexpr(!IsSplitKSupported)
1520 {
1521 if(arg.k_batch_ != 1)
1522 {
1523 return false;
1524 }
1525 }
1526 else
1527 {
1528 // Split-K autodeduction is not supported.
1529 if(arg.k_batch_ < 1)
1530 {
1531 return false;
1532 }
1533 }
1534
1535 const index_t ConvG = arg.b_g_k_c_xs_lengths_[0];
1536 const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
1537 const index_t ConvC = arg.b_g_k_c_xs_lengths_[2];
1538 const index_t output_spatial_acum = ck::accumulate_n<index_t>(
1539 arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
1540 const index_t input_spatial_acum = ck::accumulate_n<index_t>(
1541 arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
1542 // Specifialization
1543 if constexpr(ConvBackwardDataSpecialization ==
1545 {
1546 // check if it's 1x1, stride=1 pad = 0 conv
1547 for(int i = 0; i < NDimSpatial; i++)
1548 {
1549 if(!(arg.b_g_k_c_xs_lengths_[3 + i] == 1 && arg.conv_filter_strides_[i] == 1 &&
1550 arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
1551 {
1552 return false;
1553 }
1554 }
1555 }
1556
1557 // vector load for A matrix from global memory to LDS
1562 {
1563 if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
1564 {
1565 return false;
1566 }
1567 }
1570 {
1571 static_assert(NeedTransposeKernel == false);
1572
1573 if constexpr(ABlockTransferSrcScalarPerVector != 1)
1574 {
1575 if(ABlockTransferSrcVectorDim != 1)
1576 {
1577 return false;
1578 }
1579 if(output_spatial_acum % ABlockTransferSrcScalarPerVector != 0)
1580 {
1581 return false;
1582 }
1583 }
1584 }
1585 else
1586 {
1587 return false;
1588 }
1589
1590 // vector load for B matrix from global memory to LDS
1595 {
1596 if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0))
1597 {
1598 return false;
1599 }
1600 }
1601 else
1602 {
1603 return false;
1604 }
1605
1606 // vector store for Ds
1607 bool ds_valid = true;
1608
1609 static_for<0, NumDTensor, 1>{}([&](auto i) {
1610 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
1611
1619 {
1620 if(CTranspose == false)
1621 {
1622 // vector load D matrix from global memory
1623 if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
1624 {
1625 ds_valid = false;
1626 }
1627 }
1628 else
1629 {
1630 if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1631 {
1632 ds_valid = false;
1633 }
1634 }
1635 }
1636 else
1637 {
1638 ds_valid = false;
1639 }
1640 });
1641
1642 if(!ds_valid)
1643 {
1644 return false;
1645 }
1646
1647 // vector store for E
1654 {
1655 if(CTranspose == false)
1656 {
1657 // vector store C matrix into global memory
1658 if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
1659 {
1660 return false;
1661 }
1662 }
1663 else
1664 {
1665 if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1666 {
1667 return false;
1668 }
1669 }
1670 }
1671 else
1672 {
1673 return false;
1674 }
1675
1676 // Gridwise GEMM size
1677 bool isWave64 = get_warp_size() == 64;
1678 for(std::size_t i = 0; i < arg.a_grid_desc_m_k_container_.size(); i++)
1679 {
1680 bool valid = true;
1681 if(isWave64)
1682 {
1683 if constexpr(NXdlPerWave64 > 0)
1684 {
1685 if(!GridwiseGemmCTranspose64::CheckValidity(
1692 .block_2_ctile_map_,
1693 arg.k_batch_))
1694 {
1695 valid = false;
1696 }
1697 }
1698 }
1699 else
1700 {
1701 if constexpr(NXdlPerWave32 > 0)
1702 {
1703 if(!GridwiseGemmCTranspose32::CheckValidity(
1710 .block_2_ctile_map_,
1711 arg.k_batch_))
1712 {
1713 valid = false;
1714 }
1715 }
1716 }
1717 if(!valid)
1718 {
1719 return false;
1720 }
1721 }
1722
1723 if constexpr(NeedTransposeKernel)
1724 {
1725 if((ConvG * ConvC) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1726 {
1727 return false;
1728 }
1729
1730 if((ConvG * ConvK) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
1731 {
1732 return false;
1733 }
1734
1735 const index_t a_spatial_acum = ck::accumulate_n<index_t>(
1736 arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
1737 const index_t e_spatial_acum = ck::accumulate_n<index_t>(
1738 arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
1739
1740 if(a_spatial_acum % TransposeTransferInScalarPerVectorAligned != 0)
1741 {
1742 return false;
1743 }
1744
1745 if(e_spatial_acum % TransposeTransferOutScalarPerVectorAligned != 0)
1746 {
1747 return false;
1748 }
1749
1750 if(!arg.p_workspace_)
1751 {
1752 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1753 {
1754 std::cout
1755 << "Warning: Workspace for "
1756 "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument is not "
1757 "allocated, use SetWorkSpacePointer."
1758 << std::endl;
1759 }
1760 return false;
1761 }
1762 }
1763
1764 return true;
1765 }
1766
1767 bool IsSupportedArgument(const BaseArgument* p_arg) override
1768 {
1769 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
1770 }
1771
1772 static auto
1773 MakeArgument(const void* p_a, // output image
1774 const void* p_b, // weight
1775 const std::array<const void*, NumDTensor>& p_ds, // bias
1776 void* p_e, // input image
1777 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output image
1778 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides, // output image
1779 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
1780 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
1781 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
1782 ds_g_n_c_wis_lengths, // bias
1783 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
1784 ds_g_n_c_wis_strides, // bias
1785 const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
1786 const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides, // input image
1787 const std::array<index_t, NDimSpatial>& conv_filter_strides,
1788 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
1789 const std::array<index_t, NDimSpatial>& input_left_pads,
1790 const std::array<index_t, NDimSpatial>& input_right_pads,
1791 const AElementwiseOp& a_element_op,
1792 const BElementwiseOp& b_element_op,
1793 const CDEElementwiseOp& cde_element_op,
1794 const ck::index_t split_k = 1)
1795 {
1796 return Argument{p_a,
1797 p_b,
1798 p_ds,
1799 p_e,
1800 a_g_n_k_wos_lengths,
1801 a_g_n_k_wos_strides,
1802 b_g_k_c_xs_lengths,
1803 b_g_k_c_xs_strides,
1804 ds_g_n_c_wis_lengths,
1805 ds_g_n_c_wis_strides,
1806 e_g_n_c_wis_lengths,
1807 e_g_n_c_wis_strides,
1808 conv_filter_strides,
1809 conv_filter_dilations,
1810 input_left_pads,
1811 input_right_pads,
1812 a_element_op,
1813 b_element_op,
1814 cde_element_op,
1815 split_k};
1816 }
1817
1818 static auto MakeInvoker() { return Invoker{}; }
1819
1820 std::unique_ptr<BaseArgument> MakeArgumentPointer(
1821 const void* p_a, // output image
1822 const void* p_b, // weight
1823 const std::array<const void*, NumDTensor>& p_ds, // bias
1824 void* p_e, // input image
1825 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output image
1826 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides, // output image
1827 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
1828 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
1829 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
1830 ds_g_n_c_wis_lengths, // bias
1831 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
1832 ds_g_n_c_wis_strides, // bias
1833 const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
1834 const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides, // input image
1835 const std::array<index_t, NDimSpatial>& conv_filter_strides,
1836 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
1837 const std::array<index_t, NDimSpatial>& input_left_pads,
1838 const std::array<index_t, NDimSpatial>& input_right_pads,
1839 const AElementwiseOp& a_element_op,
1840 const BElementwiseOp& b_element_op,
1841 const CDEElementwiseOp& cde_element_op,
1842 const ck::index_t split_k = 1) override
1843 {
1844 return std::make_unique<Argument>(p_a,
1845 p_b,
1846 p_ds,
1847 p_e,
1848 a_g_n_k_wos_lengths,
1849 a_g_n_k_wos_strides,
1850 b_g_k_c_xs_lengths,
1851 b_g_k_c_xs_strides,
1852 ds_g_n_c_wis_lengths,
1853 ds_g_n_c_wis_strides,
1854 e_g_n_c_wis_lengths,
1855 e_g_n_c_wis_strides,
1856 conv_filter_strides,
1857 conv_filter_dilations,
1858 input_left_pads,
1859 input_right_pads,
1860 a_element_op,
1861 b_element_op,
1862 cde_element_op,
1863 split_k);
1864 }
1865
1866 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1867 {
1868 return std::make_unique<Invoker>(Invoker{});
1869 }
1870
1871 std::string GetTypeString() const override
1872 {
1873 auto str = std::stringstream();
1874
1875 // clang-format off
1876 str << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1"
1877 << "<"
1878 << BlockSize << ", "
1879 << MPerBlock << ", "
1880 << NPerBlock << ", "
1881 << KPerBlock << ", "
1882 << AK1 << ", "
1883 << BK1 << ", "
1884 << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", "
1885 << MPerXDL << ", "
1886 << NPerXDL << ", "
1887 << MXdlPerWave << ", "
1888 << NXdlPerWave << ", "
1889 << ABlockTransferSrcScalarPerVector << ", "
1890 << BBlockTransferSrcScalarPerVector << ", "
1891 << CShuffleMXdlPerWavePerShuffle << ", "
1892 << CShuffleNXdlPerWavePerShuffle;
1893
1896 str << ", TransposeTransferInScalarPerVectorAligned: "
1898 << "TransposeTransferOutScalarPerVectorAligned: " << TransposeTransferOutScalarPerVectorAligned;
1899 }
1900
1901
1902 str << ">";
1903
1904 return str.str();
1905 }
1906
1907 size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
1908 {
1909 auto arg = dynamic_cast<const Argument*>(p_arg);
1910 if(arg)
1911 {
1912 return arg->GetWorkspaceSizeBytes();
1913 }
1914 else
1915 throw std::runtime_error(
1916 "The argument pointer is not an object of "
1917 "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument structure!");
1918 }
1919
1921 void* p_workspace,
1922 const StreamConfig& = StreamConfig{}) const override
1923 {
1924 auto p_arg_ = dynamic_cast<Argument*>(p_arg);
1925 if(p_arg_)
1926 {
1927 p_arg_->p_workspace_ = p_workspace;
1928 }
1929 else
1930 throw std::runtime_error(
1931 "The argument pointer is not an object of "
1932 "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument structure!");
1933 }
1934};
1935
1936} // namespace device
1937} // namespace tensor_operation
1938} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:91
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
Definition utility/math.hpp:154
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MNKPadding
Definition gemm_specialization.hpp:20
constexpr bool is_NGCDHW_NGKDHW()
Definition device_grouped_conv_utils.hpp:112
constexpr bool is_NGCHW_GKCYX_NGKHW()
Definition device_grouped_conv_utils.hpp:64
std::string getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecialization &s)
Definition convolution_backward_data_specialization.hpp:17
constexpr bool is_NGCDHW_GKCZYX_NGKDHW()
Definition device_grouped_conv_utils.hpp:104
ConvolutionBackwardDataSpecialization
Definition convolution_backward_data_specialization.hpp:11
@ Filter1x1Stride1Pad0
Definition convolution_backward_data_specialization.hpp:13
constexpr bool is_NGCHW_NGKHW()
Definition device_grouped_conv_utils.hpp:72
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__global__ void kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op, const index_t batch_count, const std::array< index_t, NumInputs > input_batch_strides, const std::array< index_t, NumOutputs > output_batch_strides)
Definition gridwise_elementwise_2d.hpp:221
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
bool is_tf32_supported()
Definition host_utility/device_prop.hpp:132
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
auto accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op) -> decltype(std::accumulate(first, std::next(first, count), init, op))
Definition library/utility/numeric.hpp:11
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__global__ void kernel_elementwise_batched_dual(const InAGridDescTuple in_grid_desc_tuple_a, const InBGridDescTuple in_grid_desc_tuple_b, const OutAGridDescTuple out_grid_desc_tuple_a, const OutBGridDescTuple out_grid_desc_tuple_b, const InADataTypePointerTuple p_in_global_tuple_a, const InBDataTypePointerTuple p_in_global_tuple_b, const OutADataTypePointerTuple p_out_global_tuple_a, const OutBDataTypePointerTuple p_out_global_tuple_b, const Block2TileMapA block_2_tile_map_a, const Block2TileMapB block_2_tile_map_b, const ElementwiseOperation elementwise_op, const index_t a_grid_size, const index_t batch_count_a, const index_t batch_count_b, const std::array< index_t, NumInputsA > input_batch_strides_a, const std::array< index_t, NumInputsB > input_batch_strides_b, const std::array< index_t, NumOutputsA > output_batch_strides_a, const std::array< index_t, NumOutputsB > output_batch_strides_b)
Definition gridwise_elementwise_2d.hpp:117
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition block_to_ctile_map.hpp:261
Definition gridwise_elementwise_2d.hpp:278
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
decltype(MakeDsGridPointer()) DsGridPointer
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:411
Definition block_to_ctile_map.hpp:872
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:117
Definition functional2.hpp:33
Definition tensor_operation/gpu/device/tensor_layout.hpp:238
Definition tensor_operation/gpu/device/tensor_layout.hpp:243
Definition tensor_operation/gpu/device/tensor_layout.hpp:135
Definition tensor_operation/gpu/device/tensor_layout.hpp:362
Definition tensor_operation/gpu/device/tensor_layout.hpp:130
Definition tensor_operation/gpu/device/tensor_layout.hpp:357
Definition transform_conv_bwd_data_to_gemm_v1.hpp:44
__host__ __device__ auto MakeADescriptor_AK0_M_AK1() const
Definition transform_conv_bwd_data_to_gemm_v1.hpp:659
IndexType N_
Definition transform_conv_bwd_data_to_gemm_v1.hpp:1508
__host__ __device__ auto MakeBDescriptor_BK0_N_BK1() const
Definition transform_conv_bwd_data_to_gemm_v1.hpp:943
__host__ __device__ auto MakeCDescriptor_M_N() const
Definition transform_conv_bwd_data_to_gemm_v1.hpp:1150
Definition transform_conv_ngchw_to_nhwgc.hpp:31
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:666
std::size_t GetWorkspaceETensorSizeBytes() const
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1044
index_t conv_N_per_block_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1092
index_t num_workgroups_per_Conv_N_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1126
std::vector< DsGridDesc_M_N > ds_grid_desc_m_n_container_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1095
NGCHWTransposeDescType e_out_transpose_desc_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1103
const index_t k_batch_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1125
index_t gemms_count_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1128
std::array< index_t, NDimSpatial+3 > a_g_n_k_wos_lengths_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1118
EDataType * p_e_grid_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1088
Argument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOp &a_element_op, const BElementwiseOp &b_element_op, const CDEElementwiseOp &cde_element_op, ck::index_t split_k=1)
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:667
GKCYXTransposeDescType b_in_transpose_desc_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1105
long_index_t e_space_size_bytes
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1132
const ADataType * p_a_grid_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1085
std::array< index_t, NDimSpatial > conv_filter_strides_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1121
std::array< index_t, NDimSpatial+3 > e_g_n_c_wis_lengths_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1120
std::vector< std::array< GemmArgs, MaxGroupedGemmGroupsNum > > gemm_kernel_args_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1129
std::array< index_t, NDimSpatial > input_left_pads_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1122
GridwiseGemm64::DsGridPointer p_ds_grid_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1087
Block2TileMapInOutElementwise elementwise_block_2_ctile_map_transpose_a_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1099
std::vector< AGridDesc_M_K > a_grid_desc_m_k_container_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1093
const BDataType * p_b_grid_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1086
Block2TileMapInOutElementwise elementwise_block_2_ctile_map_transpose_e_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1100
CDEElementwiseOp cde_element_op_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1116
ComputePtrOffsetOfStridedBatch< I1, I1, NumDTensor > compute_ptr_offset_of_batch_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1109
bool bwd_needs_zero_out
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1131
NHWGCTransposeDescType e_in_transpose_desc_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1104
AElementwiseOp a_element_op_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1114
std::vector< BGridDesc_N_K > b_grid_desc_n_k_container_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1094
void Print() const
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1064
std::vector< index_t > gemms_grid_size_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1127
NHWGCTransposeDescType a_out_transpose_desc_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1104
index_t num_group_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1091
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_lengths_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1119
std::array< index_t, NDimSpatial > input_right_pads_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1123
ComputePtrOffsetOfStridedBatch< I1, I1, I0 > compute_ptr_offset_of_workspace_n_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1111
std::vector< EGridDesc_M_N > e_grid_desc_m_n_container_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1096
std::size_t GetWorkspaceSizeBytes() const
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1058
std::size_t GetWorkspaceBTensorSizeBytes() const
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1029
NGCHWTransposeDescType a_in_transpose_desc_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1103
GKYXCTransposeDescType b_out_transpose_desc_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1106
std::size_t GetWorkspaceATensorSizeBytes() const
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1014
BElementwiseOp b_element_op_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1115
ComputePtrOffsetOfStridedBatch< I1, I1, I0 > compute_ptr_offset_of_n_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1110
Block2TileMapWeiElementwise elementwise_block_2_ctile_map_transpose_b_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1101
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:534
index_t BlockStart_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:572
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:567
index_t BlockEnd_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:572
bool HasMainKBlockLoop_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:573
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:568
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:564
GroupedGemmBlock2ETileMap block_2_ctile_map_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:571
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:565
GemmArgs(AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, GroupedGemmBlock2ETileMap block_2_ctile_map, index_t BlockStart, index_t BlockEnd, bool HasMainKBlockLoop)
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:536
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1137
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1302
float RunMultiDGemm(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1143
DeviceOp::Argument Argument
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1138
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1478
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1452
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:306
static constexpr auto NXdlPerWave32
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:324
static constexpr index_t ElementwiseBlocksize
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:609
static constexpr index_t TransposeTransferOutScalarPerVectorAligned
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:593
static constexpr auto conv_ngchw_to_nhwgc_transformer
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:583
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1866
ADataType ABDataType
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:333
static auto MakeArgument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOp &a_element_op, const BElementwiseOp &b_element_op, const CDEElementwiseOp &cde_element_op, const ck::index_t split_k=1)
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1773
static constexpr index_t ClusterLengthNPerBlock
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:580
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeGKYXCTransposeDesc< NDimSpatial >({}, {}))> GKYXCTransposeDescType
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:605
std::conditional_t< is_NGCHW_NGKHW< ELayout, BLayout, ALayout >() &&NeedTransposeKernel, tensor_layout::convolution::NHWGC, std::conditional_t< is_NGCDHW_NGKDHW< ELayout, BLayout, ALayout >() &&NeedTransposeKernel, tensor_layout::convolution::NDHWGC, ELayout > > ELayoutAfterTranspose
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:368
remove_cvref_t< tuple_element_t< 3, ABDsEGridDesc > > EGridDesc_M_N
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:516
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1907
remove_cvref_t< tuple_element_t< 2, ABDsEGridDesc > > DsGridDesc_M_N
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:515
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeNGCHWTransposeDesc< NDimSpatial >({}, {}))> NGCHWTransposeDescType
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:596
decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{})) AGridDesc_M_K
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:518
static constexpr ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:510
static constexpr auto I3
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:338
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1767
static constexpr bool NeedTransposeKernel
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:347
GridwiseGemmMultipleD_xdl_cshuffle< GridwiseGemmCTransposeTemplateParameters > GridwiseGemmCTransposeBase
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:475
std::conditional_t< is_NGCHW_NGKHW< ELayout, BLayout, ALayout >() &&NeedTransposeKernel, tensor_layout::convolution::NHWGK, std::conditional_t< is_NGCDHW_NGKDHW< ELayout, BLayout, ALayout >() &&NeedTransposeKernel, tensor_layout::convolution::NDHWGK, ALayout > > ALayoutAfterTranspose
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:355
static constexpr auto I1
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:336
static constexpr auto I0
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:335
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1485
decltype(GridwiseGemmCTranspose64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{})) DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:521
GridwiseElementwise< Tuple< NHWGCTransposeDescType >, Tuple< NGCHWTransposeDescType >, Tuple< const EDataType * >, Tuple< EDataType * >, Block2TileMapInOutElementwise, element_wise::PassThrough, ElementwiseBlocksize, NPerBlock, MPerBlock, NPerBlock/ClusterLengthNPerBlock, MPerBlock/ClusterLengthMPerBlock, Sequence< 1, 0 >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, Sequence< TransposeTransferOutScalarPerVectorAligned >, I0, I1 > GridwiseElementwiseOutputTranspose
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:647
decltype(GridwiseGemmCTranspose64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{})) Block2ETileMap
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:528
static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1 &desc_k0_m_k1)
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:496
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:477
BlockToCTileMap_M00_N0_M01Adapt< MPerBlock, NPerBlock > Block2TileMapWeiElementwise
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:576
std::conditional_t< is_NGCHW_GKCYX_NGKHW< ELayout, BLayout, ALayout >() &&NeedTransposeKernel, tensor_layout::convolution::GKYXC, std::conditional_t< is_NGCDHW_GKCZYX_NGKDHW< ELayout, BLayout, ALayout >() && NeedTransposeKernel, tensor_layout::convolution::GKZYXC, BLayout > > BLayoutAfterTranspose
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:361
remove_cvref_t< tuple_element_t< 0, ABDsEGridDesc > > AGridDesc_AK0_M_AK1
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:513
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 DeviceOp
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:321
static auto GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform &conv_to_gemm_transform)
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:395
static constexpr index_t ClusterLengthMPerBlock
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:578
static auto MakeInvoker()
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1818
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:323
static constexpr bool isATensorColMajor
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:340
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOp &a_element_op, const BElementwiseOp &b_element_op, const CDEElementwiseOp &cde_element_op, const ck::index_t split_k=1) override
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1820
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:478
std::string GetTypeString() const override
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1871
TransformConvBwdDataToGemm_v1< NDimSpatial, ConvBackwardDataSpecialization, AK1, BK1, MPerBlock, NPerBlock, KPerBlock, DoPadGemmM, DoPadGemmN, ALayoutAfterTranspose, BLayoutAfterTranspose, ELayoutAfterTranspose, true, ABDataType, EDataType, 1, index_t, CTranspose > ConvToGemmBwdDataTransform
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:375
decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})) EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:524
void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:1920
GridwiseElementwise< Tuple< NGCHWTransposeDescType >, Tuple< NHWGCTransposeDescType >, Tuple< const ADataType * >, Tuple< ADataType * >, Block2TileMapInOutElementwise, element_wise::PassThrough, ElementwiseBlocksize, NPerBlock, MPerBlock, NPerBlock/ClusterLengthNPerBlock, MPerBlock/ClusterLengthMPerBlock, Sequence< 1, 0 >, Sequence< TransposeTransferInScalarPerVectorAligned >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, I1, I0 > GridwiseElementwiseInputTranspose
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:611
GridwiseGemmMultipleD_xdl_cshuffle< GridwiseGemmMultiDTemplateParams > GridwiseGemmBase
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:473
OffsettedBlockToCTileMap< Block2ETileMap > GroupedGemmBlock2ETileMap
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:531
GridwiseElementwise< Tuple< GKCYXTransposeDescType >, Tuple< GKYXCTransposeDescType >, Tuple< const BDataType * >, Tuple< BDataType * >, Block2TileMapWeiElementwise, element_wise::PassThrough, ElementwiseBlocksize, MPerBlock, NPerBlock, MPerBlock/ClusterLengthMPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 1, 0 >, Sequence< 1 >, Sequence< CDEBlockTransferScalarPerVector_NPerBlock >, I0, I1 > GridwiseElementwiseWeightTranspose
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:629
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeNHWGCTransposeDesc< NDimSpatial >({}, {}))> NHWGCTransposeDescType
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:599
static auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N e_grid_desc_m_n)
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:489
std::conditional_t< CTranspose, GridwiseGemmCTransposeBase< NXdlPerWave32 >, GridwiseGemm32 > GridwiseGemmCTranspose32
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:484
static constexpr index_t MaxGroupedGemmGroupsNum
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:315
static constexpr bool CTranspose
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:351
BlockToCTileMap_M00_N0_M01Adapt< NPerBlock, MPerBlock > Block2TileMapInOutElementwise
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:575
remove_cvref_t< tuple_element_t< 1, ABDsEGridDesc > > BGridDesc_BK0_N_BK1
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:514
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeGKCYXTransposeDesc< NDimSpatial >({}, {}))> GKCYXTransposeDescType
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:602
static constexpr index_t NumDTensor
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:326
std::conditional_t< CTranspose, GridwiseGemmCTransposeBase< math::max(NXdlPerWave64, 1)>, GridwiseGemm64 > GridwiseGemmCTranspose64
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:480
static constexpr auto I2
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:337
static constexpr index_t TransposeTransferInScalarPerVectorAligned
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:591
static constexpr bool IsSplitKSupported
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:328
decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform)) ABDsEGridDesc
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:511
decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})) BGridDesc_N_K
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:519
static constexpr GemmSpecialization GemmSpec
Definition device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp:327
Definition device_grouped_conv_bwd_data_multiple_d.hpp:36
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129