gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp Source File

gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp Source File

Composable Kernel: gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp Source File
gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
19
20namespace ck {
21
22// GEMM:
23// input : A[M, K]
24// input : B[N, K]
25// input : D0[M, N], D1[M, N], ...
26// output : E[M, N]
27// output : F[M, N0], where N0 is number of blocks along N dimension
28// output : G[M, N0], where N0 is number of blocks along N dimension
29// C = a_op(A) * b_op(B)
30// E = cde_op(C, D0, D1, ...)
31// F, G = welford(E)
32// Assume:
33// D0, D1, ... and E have the same layout
34// Calculate mean & variance along N dimension for E
35template <typename ABDataType,
36 typename AccDataType,
37 typename CShuffleDataType,
38 typename DsDataType,
39 typename EMeanVarDataType,
40 typename AElementwiseOperation,
41 typename BElementwiseOperation,
42 typename CDEElementwiseOperation,
43 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
44 typename AGridDesc_M_K,
45 typename BGridDesc_N_K,
46 typename DsGridDesc_M_N,
47 typename EGridDesc_M_N,
48 typename MeanVarGridDesc_M_NBlock,
49 typename CountGridDesc_M_NBlock,
50 index_t NumGemmKPrefetchStage,
51 index_t BlockSize,
52 index_t MPerBlock,
53 index_t NPerBlock,
54 index_t KPerBlock,
55 index_t AK1Value,
56 index_t BK1Value,
57 index_t MPerXdl,
58 index_t NPerXdl,
59 index_t MXdlPerWave,
60 index_t NXdlPerWave,
61 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
62 typename ABlockTransferThreadClusterArrangeOrder,
63 typename ABlockTransferSrcAccessOrder,
64 index_t ABlockTransferSrcVectorDim,
65 index_t ABlockTransferSrcScalarPerVector,
66 index_t ABlockTransferDstScalarPerVector_AK1,
67 bool AThreadTransferSrcResetCoordinateAfterRun,
68 index_t ABlockLdsExtraM,
69 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
70 typename BBlockTransferThreadClusterArrangeOrder,
71 typename BBlockTransferSrcAccessOrder,
72 index_t BBlockTransferSrcVectorDim,
73 index_t BBlockTransferSrcScalarPerVector,
74 index_t BBlockTransferDstScalarPerVector_BK1,
75 bool BThreadTransferSrcResetCoordinateAfterRun,
76 index_t BBlockLdsExtraN,
77 index_t CShuffleMXdlPerWavePerShuffle,
78 index_t CShuffleNXdlPerWavePerShuffle,
79 typename PostShuffleThreadClusterSize_M_N,
80 index_t PostShuffleScalarPerVector,
81 LoopScheduler LoopSched,
84{
// Number of D tensors, i.e. the number of entries in the tuple-like DsDataType.
85 static constexpr index_t NumDTensor = DsDataType::Size();
86
// Compile-time index constants Number<0> .. Number<7>; used below to select descriptor
// dimensions (GetLength(I0), ...) and elements of tuples / multi-indices.
87 static constexpr auto I0 = Number<0>{};
88 static constexpr auto I1 = Number<1>{};
89 static constexpr auto I2 = Number<2>{};
90 static constexpr auto I3 = Number<3>{};
91 static constexpr auto I4 = Number<4>{};
92 static constexpr auto I5 = Number<5>{};
93 static constexpr auto I6 = Number<6>{};
94 static constexpr auto I7 = Number<7>{};
95
96 // K1 should be Number<...>
// AK1 / BK1 wrap the AK1Value / BK1Value template parameters as Number<> constants.
97 static constexpr auto AK1 = Number<AK1Value>{};
98 static constexpr auto BK1 = Number<BK1Value>{};
// K0 extent of one block tile: KPerBlock split by the respective K1 value.
// NOTE(review): this is an integer division; no check that KPerBlock is a multiple of
// AK1Value / BK1Value is visible in this listing -- confirm it is enforced elsewhere.
99 static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
100 static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
101
103
106
107 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
108 {
109 // A matrix in LDS memory, dst of blockwise copy
113 }
114
115 __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
116 {
117 // B matrix in LDS memory, dst of blockwise copy
121 }
122
123 __host__ __device__ static constexpr auto
125 {
126 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
127 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
128
129 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
133 I1,
135
136 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
137 }
138
139 // ck::Tuple<const D0DataType*, const D1DataType*, ...>
// Builds a tuple holding one null `const DDataType*` per entry of DsDataType. Within this
// struct the result is only used for its type: the DsGridPointer alias is declared as
// decltype(MakeDsGridPointer()).
// NOTE(review): the final argument of generate_tuple (original line 148) is absent from this
// listing.
140 static constexpr auto MakeDsGridPointer()
141 {
142 return generate_tuple(
143 [&](auto i) {
// Element type of the i-th D tensor, with cv-qualifiers and references removed.
144 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
145
// Placeholder value: a typed null pointer, never a real address.
146 return static_cast<const DDataType*>(nullptr);
147 },
149 }
150
// Returns the shared-memory (LDS) size in bytes computed as the larger of
//   (aligned A tile + aligned B tile) * sizeof(ABDataType)   and
//   C-shuffle tile * sizeof(CShuffleDataType).
// A maximum rather than a sum is used: Run() builds the C-shuffle buffer from the start of
// the same p_shared pointer that the A/B tiles are carved from.
151 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
152 {
153 // LDS allocation for A and B: be careful of alignment
154 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
155 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
156
157 // lds max alignment
// Expressed in elements (it is applied to element counts below), not in bytes.
158 constexpr auto max_lds_align = math::lcm(AK1, BK1);
159
// Round each tile's element-space size up to a multiple of max_lds_align.
160 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
161 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
162
163 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
164 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
165
166 // LDS allocation for C shuffle in LDS
// NOTE(review): the initializer of this descriptor (original line 168) is absent from this
// listing.
167 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
169
// Element count of the C-shuffle tile (no extra alignment rounding is applied here).
170 constexpr auto c_block_size =
171 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
172
// Convert both element counts to bytes and keep the larger requirement.
173 return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
174 sizeof(ABDataType),
175 c_block_size * sizeof(CShuffleDataType));
176 }
177
178 // A desc for source in blockwise copy
179 __host__ __device__ static constexpr auto
180 MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
181 {
182 const auto M = a_grid_desc_m_k.GetLength(I0);
183 const auto K = a_grid_desc_m_k.GetLength(I1);
184
185 const auto AK0 = K / AK1;
186
187 return transform_tensor_descriptor(a_grid_desc_m_k,
192 }
193
194 // B desc for source in blockwise copy
195 __host__ __device__ static constexpr auto
196 MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
197 {
198 const auto N = b_grid_desc_n_k.GetLength(I0);
199 const auto K = b_grid_desc_n_k.GetLength(I1);
200
201 const auto BK0 = K / BK1;
202
203 return transform_tensor_descriptor(b_grid_desc_n_k,
208 }
209
210 // E desc for destination in blockwise copy
211 template <typename EGridDescriptor_M_N>
212 __host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
213 const EGridDescriptor_M_N& e_grid_desc_m_n)
214 {
215 const auto M = e_grid_desc_m_n.GetLength(I0);
216 const auto N = e_grid_desc_m_n.GetLength(I1);
217
218 const auto MBlock = M / MPerBlock;
219 const auto NBlock = N / NPerBlock;
220
221 const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
222 e_grid_desc_m_n,
227
228 return e_grid_desc_mblock_mperblock_nblock_nperblock;
229 }
230
231 // Ds desc for source in blockwise copy
232 template <typename DsGridDescriptor_M_N>
233 __host__ __device__ static constexpr auto
235 const DsGridDescriptor_M_N& ds_grid_desc_m_n)
236 {
237 return generate_tuple(
238 [&](auto i) {
240 },
242 }
243
244 template <typename GridDescriptor_M_N>
245 __host__ __device__ static constexpr auto
246 MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
247 {
248 const auto M = grid_desc_m_n.GetLength(I0);
249 const auto NBlock = grid_desc_m_n.GetLength(I1);
250 const auto MBlock = M / MPerBlock;
251
252 const auto grid_desc_mblock_mperblock_nblock = transform_tensor_descriptor(
253 grid_desc_m_n,
258
259 return grid_desc_mblock_mperblock_nblock;
260 }
261
262 // return block_id to E matrix tile idx (m0, n0) mapping
263 __host__ __device__ static constexpr auto
264 MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
265 {
267 e_grid_desc_m_n);
268 }
269
271
272 // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
// Decides whether this kernel configuration can handle the given problem.
// Returns true only if every check below passes:
//   1. E has lengths (M, N), with M and K read from A and N read from B;
//   2. every D tensor has lengths (M, N);
//   3. M, N and K are multiples of MPerBlock, NPerBlock and KPerBlock;
//   4. GridwiseGemmPipe supports the resulting number of K iterations;
//   5. block_2_etile_map accepts the E descriptor;
//   6. A, B and E each occupy at most 2^31 bytes.
// NOTE(review): B's K length (GetLength(I1)) is never compared against A's, and the D
// tensors are not included in the byte-size check -- confirm callers guarantee both.
273 template <typename Block2ETileMap>
274 __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
275 const BGridDesc_N_K& b_grid_desc_n_k,
276 const DsGridDesc_M_N& ds_grid_desc_m_n,
277 const EGridDesc_M_N& e_grid_desc_m_n,
278 const Block2ETileMap& block_2_etile_map)
279 {
// Compile-time check: the block tile must split evenly into per-wave XDL tiles in M and N.
280 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
281 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
282 "Invalid tuning param!");
283
// Problem sizes: M and K come from A[M, K], N comes from B[N, K].
284 const auto M = a_grid_desc_m_k.GetLength(I0);
285 const auto N = b_grid_desc_n_k.GetLength(I0);
286 const auto K = a_grid_desc_m_k.GetLength(I1);
287
288 // check consistency of desc
289 if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
290 {
291 return false;
292 }
293
294 bool valid = true;
295
// Same (M, N) requirement for each of the NumDTensor D descriptors; results are AND-ed.
296 static_for<0, NumDTensor, 1>{}([&](auto i) {
297 valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
298 N == ds_grid_desc_m_n[i].GetLength(I1));
299 });
300
301 if(!valid)
302 {
303 return false;
304 }
305
306 // check tile size
// No partial tiles are accepted here: all three extents must divide evenly.
307 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
308 {
309 return false;
310 }
311
312 // check gridwise gemm pipeline
// Exact division: K % KPerBlock == 0 was verified just above.
313 const auto num_k_loop = K / KPerBlock;
314
315 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
316 {
317 return false;
318 }
319
320 // check block-to-E-tile
321 if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
322 {
323 return false;
324 }
325
326 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
327 // check tensor size: cannot be larger than 2GB each
// TwoGB is 2^31, held in long_index_t.
328 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
329
// Byte footprint = element-space size * element size; only A, B and E are tested.
330 if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
331 b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
332 e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EMeanVarDataType) <= TwoGB))
333 {
334 return false;
335 }
336
337 return true;
338 }
339
// Forwards the number of K iterations (K / KPerBlock) to
// GridwiseGemmPipe::CalculateHasMainLoop and returns its answer.
// K is the full K extent of the problem. The division is an integer division; CheckValidity
// rejects K values that are not a multiple of KPerBlock.
// NOTE(review): GridwiseGemmPipe's definition is not visible in this listing, so the exact
// criterion for "has main loop" should be confirmed against the pipeline implementation.
340 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
341 {
342 const index_t num_loop = K / KPerBlock;
343
344 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
345 }
346
348 remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
350 remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
353 EGridDesc_M_N{}))>;
356 MeanVarGridDesc_M_NBlock{}))>;
359 CountGridDesc_M_NBlock{}))>;
362 DsGridDesc_M_N{}))>;
363
365 remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
366
367 using DsGridPointer = decltype(MakeDsGridPointer());
368
369 template <bool HasMainKBlockLoop,
370 typename AGridDesc_AK0_M_AK1,
371 typename BGridDesc_BK0_N_BK1,
372 typename Block2ETileMap>
373 __device__ static void
374 Run(const ABDataType* __restrict__ p_a_grid,
375 const ABDataType* __restrict__ p_b_grid,
376 DsGridPointer p_ds_grid,
377 EMeanVarDataType* __restrict__ p_e_grid,
378 EMeanVarDataType* __restrict__ p_welford_mean_grid,
379 EMeanVarDataType* __restrict__ p_welford_var_grid,
380 int32_t* __restrict__ p_welford_count,
381 void* __restrict__ p_shared,
382 const AElementwiseOperation& a_element_op,
383 const BElementwiseOperation& b_element_op,
384 const CDEElementwiseOperation& cde_element_op,
385 const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
386 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
388 ds_grid_desc_mblock_mperblock_nblock_nperblock,
390 e_grid_desc_mblock_mperblock_nblock_nperblock,
392 mean_var_grid_desc_mblock_mperblock_nblock,
393 const CountGridDescriptor_MBlock_MPerBlock_NBlock& count_grid_desc_mblock_mperblock_nblock,
394 const Block2ETileMap& block_2_etile_map,
395 index_t NRaw)
396 {
397 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
398 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
399
400 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
401 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
402
403 const auto ds_grid_buf = generate_tuple(
404 [&](auto i) {
406 p_ds_grid[i],
407 ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
408 },
410
412 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
413
415 p_welford_mean_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
416
418 p_welford_var_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
419
420 auto welford_count_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
421 p_welford_count, count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
422
423 // divide block work by [M, N]
424 const auto block_work_idx =
425 block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
426
427 if(!block_2_etile_map.ValidCTileIndex(
428 block_work_idx,
429 make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
430 e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
431 {
432 return;
433 }
434
435 // HACK: this force m/n_block_data_idx_on_grid into SGPR
436 const index_t m_block_data_idx_on_grid =
437 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
438
439 const index_t n_block_data_idx_on_grid =
440 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
441
442 // lds max alignment
443 constexpr auto max_lds_align = math::lcm(AK1, BK1);
444
445 // A matrix in LDS memory, dst of blockwise copy
446 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
447
448 // B matrix in LDS memory, dst of blockwise copy
449 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
450
451 // A matrix blockwise copy
452 auto a_blockwise_copy =
454 AElementwiseOperation,
458 ABlockTransferThreadClusterLengths_AK0_M_AK1,
459 ABlockTransferThreadClusterArrangeOrder,
460 ABDataType,
461 ABDataType,
462 decltype(a_grid_desc_ak0_m_ak1),
463 decltype(a_block_desc_ak0_m_ak1),
464 ABlockTransferSrcAccessOrder,
466 ABlockTransferSrcVectorDim,
467 2,
468 ABlockTransferSrcScalarPerVector,
469 ABlockTransferDstScalarPerVector_AK1,
470 1,
471 1,
472 AThreadTransferSrcResetCoordinateAfterRun,
473 true,
474 NumGemmKPrefetchStage>(
475 a_grid_desc_ak0_m_ak1,
476 make_multi_index(0, m_block_data_idx_on_grid, 0),
477 a_element_op,
478 a_block_desc_ak0_m_ak1,
479 make_multi_index(0, 0, 0),
481
482 // B matrix blockwise copy
483 auto b_blockwise_copy =
485 BElementwiseOperation,
489 BBlockTransferThreadClusterLengths_BK0_N_BK1,
490 BBlockTransferThreadClusterArrangeOrder,
491 ABDataType,
492 ABDataType,
493 decltype(b_grid_desc_bk0_n_bk1),
494 decltype(b_block_desc_bk0_n_bk1),
495 BBlockTransferSrcAccessOrder,
497 BBlockTransferSrcVectorDim,
498 2,
499 BBlockTransferSrcScalarPerVector,
500 BBlockTransferDstScalarPerVector_BK1,
501 1,
502 1,
503 BThreadTransferSrcResetCoordinateAfterRun,
504 true,
505 NumGemmKPrefetchStage>(
506 b_grid_desc_bk0_n_bk1,
507 make_multi_index(0, n_block_data_idx_on_grid, 0),
508 b_element_op,
509 b_block_desc_bk0_n_bk1,
510 make_multi_index(0, 0, 0),
512
513 // GEMM definition
514 // c_mtx += transpose(a_mtx) * b_mtx
515 // a_mtx[K0PerBlock, MPerBlock] is in LDS
516 // b_mtx[K0PerBlock, NPerBlock] is in LDS
517 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
518 // register
519 // sanity check
520 constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
521 constexpr bool is_single_rate_mfma =
523 lcm_AK1_BK1 <= 4) ||
524 (is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
526 lcm_AK1_BK1 < 32))
527 ? true
528 : false;
529 constexpr auto is_scale_mfma = false;
530 constexpr index_t KPack = math::max(lcm_AK1_BK1,
531 MfmaSelector<ABDataType,
532 MPerXdl,
533 NPerXdl,
534 ABDataType,
535 is_single_rate_mfma,
536 is_scale_mfma>::selected_mfma.k_per_blk);
537
539 BlockSize,
540 ABDataType,
541 ABDataType,
542 AccDataType,
543 decltype(a_block_desc_ak0_m_ak1),
544 decltype(b_block_desc_bk0_n_bk1),
545 MPerXdl,
546 NPerXdl,
547 MXdlPerWave,
548 NXdlPerWave,
549 KPack,
550 LoopSched>();
551
552 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
553
554 // LDS allocation for A and B: be careful of alignment
555 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
556 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
557
559 static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
560
562 static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
563 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
564
565 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
566 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
567
568 // gridwise GEMM pipeline
569 const auto gridwise_gemm_pipeline =
571
572 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
573 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
574 KPerBlock);
575
576 gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
577 a_block_desc_ak0_m_ak1,
578 a_blockwise_copy,
579 a_grid_buf,
580 a_block_buf,
581 a_block_slice_copy_step,
582 b_grid_desc_bk0_n_bk1,
583 b_block_desc_bk0_n_bk1,
584 b_blockwise_copy,
585 b_grid_buf,
586 b_block_buf,
587 b_block_slice_copy_step,
588 blockwise_gemm,
589 c_thread_buf,
590 num_k_block_main_loop);
591
592 // shuffle C, Welford and write out
593 {
594 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
595 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
596 "wrong!");
597
598 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
599 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
600
601 // TODO: hacky, fix it!
602 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
603 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
604
605 // TODO: hacky, fix it!
606 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
607 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
608 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
609
610 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
611 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
612 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
613 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
614 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
615 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
616 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
617 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
618
619 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
621
622 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
623 static_cast<CShuffleDataType*>(p_shared),
624 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
625
626 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
627 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
631 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
632 M1, // M1 = MWave
633 M2, // M2 * M3 * M4 = MPerXdl
634 M3,
635 M4)),
638 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
639 N1, // N1 = NWave
640 N2))), // N2 = NPerXdl
644
645 // calculate origin of thread output tensor on global memory
646 // blockwise GEMM c matrix starting index
647 const auto c_thread_mtx_on_block =
648 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
649
650 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
651 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
652
653 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
655 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
658
659 const auto m_thread_data_on_block_idx =
660 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
661 make_multi_index(m_thread_data_on_block));
662
663 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
668
669 const auto n_thread_data_on_block_idx =
670 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
671 make_multi_index(n_thread_data_on_block));
672
673 // shuffle: threadwise copy C from VGPR to LDS
674 auto c_thread_copy_vgpr_to_lds =
676 CShuffleDataType,
677 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
678 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
680 Sequence<CShuffleMXdlPerWavePerShuffle,
681 CShuffleNXdlPerWavePerShuffle,
682 I1,
683 I1,
684 M2,
685 I1,
686 M4,
687 I1>,
689 7,
690 1,
692 1,
693 true>{
694 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
696 0,
697 m_thread_data_on_block_idx[I1],
698 n_thread_data_on_block_idx[I1],
699 m_thread_data_on_block_idx[I2],
700 m_thread_data_on_block_idx[I3],
701 m_thread_data_on_block_idx[I4],
702 n_thread_data_on_block_idx[I2]),
704
705 // space filling curve for threadwise C in VGPR
706 constexpr auto sfc_c_vgpr =
709 Sequence<CShuffleMXdlPerWavePerShuffle,
710 CShuffleNXdlPerWavePerShuffle,
711 1,
712 1,
713 M2,
714 1,
715 M4,
716 1>,
717 false>{};
718
719 // space filling curve for shuffled blockwise C in global mem
720 constexpr auto sfc_der_global =
723 Sequence<1,
724 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
725 1,
726 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
727 false>{};
728
729 // LDS c_shuffle_block_desc_mperblock_nperblock
730 constexpr auto c_shuffle_block_desc_mperblock_nperblock = transform_tensor_descriptor(
731 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
735 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
738 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))),
741
742 static_assert(PostShuffleThreadClusterSize_M_N::At(I0) *
743 PostShuffleThreadClusterSize_M_N::At(I1) ==
744 BlockSize,
745 "wrong!");
746
747 static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) %
748 PostShuffleThreadClusterSize_M_N::At(I0) ==
749 0 &&
750 (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) %
751 PostShuffleThreadClusterSize_M_N::At(I1) ==
752 0,
753 "wrong!");
754
755 constexpr index_t PostShuffleThreadSliceSize_M =
756 (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) /
757 PostShuffleThreadClusterSize_M_N::At(I0);
758
759 constexpr index_t PostShuffleThreadSliceSize_N =
760 (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) /
761 PostShuffleThreadClusterSize_M_N::At(I1);
762
763 constexpr auto PostShuffleThreadSliceSize_M_N =
765
766 // VGPR post_shuffle_thread_desc_m_n
767 constexpr auto post_shuffle_thread_desc_m_n = make_naive_tensor_descriptor_packed(
770
772 post_shuffle_thread_desc_m_n.GetElementSpaceSize());
773
774 // To apply D0, D1, ... and Welford.
775 // threadwise copy from LDS to VGPR
776 constexpr auto post_shuffle_thread_cluster_desc =
777 make_cluster_descriptor(PostShuffleThreadClusterSize_M_N{}, Sequence<0, 1>{});
778
779 const auto post_shuffle_thread_cluster_idx =
780 post_shuffle_thread_cluster_desc.CalculateBottomIndex(
782
783 const auto post_shuffle_thread_data_idx_begin =
784 post_shuffle_thread_cluster_idx * PostShuffleThreadSliceSize_M_N;
785
786 // To apply D0, D1, ... and Welford.
787 // Copy c shuffle from LDS back to VGPR
788 auto post_shuffle_thread_copy_lds_to_vgpr =
789 ThreadwiseTensorSliceTransfer_v2<CShuffleDataType,
790 AccDataType,
791 decltype(c_shuffle_block_desc_mperblock_nperblock),
792 decltype(post_shuffle_thread_desc_m_n),
793 decltype(PostShuffleThreadSliceSize_M_N),
795 1,
796 PostShuffleScalarPerVector,
797 1,
798 true>{c_shuffle_block_desc_mperblock_nperblock,
799 post_shuffle_thread_data_idx_begin};
800
801 // D0, D1, ..., Dn
802 constexpr auto post_shuffle_thread_desc_I1_mperblock_I1_nperblock =
806 I1,
808
809 // FIXME: Decrease usage of VGPR
810 // Apply pointwise lambda function from multi-source (Global and LDS) into VGPR
811 auto ds_thread_buf = generate_tuple(
812 [&](auto) {
814 post_shuffle_thread_desc_I1_mperblock_I1_nperblock.GetElementSpaceSize());
815 },
817
818 // Copy D0, D1, ..., Dn from global to VGPR
819 auto ds_thread_copy_global_to_vgpr = generate_tuple(
820 [&](auto I) {
821 using DDataType = remove_cvref_t<tuple_element_t<I.value, DsDataType>>;
823 DDataType,
824 AccDataType,
825 decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]),
826 decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock),
827 Sequence<I1,
828 PostShuffleThreadSliceSize_M,
829 I1,
830 PostShuffleThreadSliceSize_N>,
832 3,
833 PostShuffleScalarPerVector,
834 1,
835 true>(
836 ds_grid_desc_mblock_mperblock_nblock_nperblock[I],
838 I0,
839 m_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I0],
840 I0,
841 n_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I1]));
842 },
844
845 auto e_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
846 AccDataType,
847 EMeanVarDataType,
848 decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock),
849 decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
851 Sequence<I1,
852 PostShuffleThreadSliceSize_M,
853 I1,
854 PostShuffleThreadSliceSize_N>, // SliceLengths
855 Sequence<0, 1, 2, 3>, // DimAccessOrder
856 3, // DstVectorDim
857 PostShuffleScalarPerVector,
859 1,
860 true>{
861 e_grid_desc_mblock_mperblock_nblock_nperblock,
863 m_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I0],
864 I0,
865 n_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I1]),
867
868 // Welford
869 constexpr auto thread_welford_src_desc_m_k = make_naive_tensor_descriptor_packed(
872
873 constexpr auto thread_welford_dst_desc_m = make_naive_tensor_descriptor_packed(
875
876 using ThreadwiseWelford = ThreadwiseWelford<AccDataType,
877 decltype(thread_welford_src_desc_m_k),
878 decltype(thread_welford_dst_desc_m)>;
879
880 using BlockwiseWelford = BlockwiseWelford<AccDataType,
881 BlockSize,
882 PostShuffleThreadClusterSize_M_N,
884 false>;
885
886 constexpr int num_shuffleM =
887 MPerBlock / (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl);
888
889 constexpr int num_shuffleN =
890 NPerBlock / (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl);
891
892 using mean_var_vgpr_type =
894 thread_welford_dst_desc_m.GetElementSpaceSize()));
895
896 using welford_count_vgpr_type =
898 thread_welford_dst_desc_m.GetElementSpaceSize()));
899
900 Array<ThreadwiseWelford, num_shuffleM> threadwise_welfords;
903 Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs;
904
905 int max_count = PostShuffleThreadSliceSize_N * num_shuffleN;
906 const auto nblock = mean_var_grid_desc_mblock_mperblock_nblock.GetLength(I2);
907
908 // tail block
909 if(block_work_idx[I1] % nblock == nblock - 1)
910 {
911 constexpr index_t NPerShuffleBlock =
912 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl;
913
914 int NPerBlockTail = NRaw - NPerBlock * (nblock - 1);
915 int thread_max_len =
916 PostShuffleThreadSliceSize_N * (post_shuffle_thread_cluster_idx[I1] + 1);
917 int shuffle_step = 0;
918 while(thread_max_len <= NPerBlockTail && shuffle_step < num_shuffleN)
919 {
920 ++shuffle_step;
921 thread_max_len += NPerShuffleBlock;
922 }
923
924 int delta = 0;
925 if(thread_max_len - NPerBlockTail > PostShuffleThreadSliceSize_N)
926 delta = 0;
927 else if(NPerBlockTail > thread_max_len)
928 delta = PostShuffleThreadSliceSize_N;
929 else
930 delta = PostShuffleThreadSliceSize_N - thread_max_len + NPerBlockTail;
931
932 max_count = shuffle_step * PostShuffleThreadSliceSize_N + delta;
933 }
934
935 static_for<0, num_shuffleM, 1>{}([&](auto i) {
936 threadwise_welfords(i).max_count_ = max_count;
938 thread_welford_dst_desc_m.GetElementSpaceSize());
939
941 thread_welford_dst_desc_m.GetElementSpaceSize());
942
943 welford_count_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
944 thread_welford_dst_desc_m.GetElementSpaceSize());
945
947 mean_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
948 var_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
949 welford_count_thread_bufs(i)(j) = 0;
950 });
951 });
952
953 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
954
955 static_assert(num_access == sfc_der_global.GetNumOfAccess(), "wrong!");
956
957 int shuffleM_index = __builtin_amdgcn_readfirstlane(0);
958 static_for<0, num_access, 1>{}([&](auto access_id) {
959 // make sure it's safe to read from LDS
961
962 // each thread shuffle data from VGPR to LDS
963 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
964 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
965 c_thread_buf,
966 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
967 c_shuffle_block_buf);
968
969 // make sure it's safe to write to LDS
971
972 // Get shuffle data from LDS to VGPR
973 post_shuffle_thread_copy_lds_to_vgpr.Run(c_shuffle_block_desc_mperblock_nperblock,
974 c_shuffle_block_buf,
975 post_shuffle_thread_desc_m_n,
976 make_tuple(I0, I0),
977 e_thread_buf);
978
979 // Global read D0, D1, ...
980 static_for<0, NumDTensor, 1>{}([&](auto Id) {
981 auto& d_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(Id);
982 d_thread_copy_global_to_vgpr.Run(
983 ds_grid_desc_mblock_mperblock_nblock_nperblock[Id],
984 ds_grid_buf[Id],
985 post_shuffle_thread_desc_I1_mperblock_I1_nperblock,
986 make_tuple(I0, I0, I0, I0),
987 ds_thread_buf(Id));
988
989 if constexpr(access_id < num_access - 1)
990 {
991 // move on D0, D1, ...
992 constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id);
993 d_thread_copy_global_to_vgpr.MoveSrcSliceWindow(
994 ds_grid_desc_mblock_mperblock_nblock_nperblock[Id], de_global_step);
995 }
996 });
997
998 // cde_element_op(e, c, d0, d1, ...);
999 static_for<0, post_shuffle_thread_desc_m_n.GetElementSize(), 1>{}([&](auto i) {
1000 const auto c_ds_src_data_refs = concat_tuple_of_reference(
1001 tie(e_thread_buf[i]),
1002 generate_tie([&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; },
1004 auto e_dst_data_refs = tie(e_thread_buf(i));
1005 unpack2(cde_element_op, e_dst_data_refs, c_ds_src_data_refs);
1006 });
1007
1008 // Global write E
1009 e_thread_copy_vgpr_to_global.Run(post_shuffle_thread_desc_I1_mperblock_I1_nperblock,
1010 make_tuple(I0, I0, I0, I0),
1011 e_thread_buf,
1012 e_grid_desc_mblock_mperblock_nblock_nperblock,
1013 e_grid_buf);
1014
1015 if constexpr(access_id < num_access - 1)
1016 {
1017 // move on E
1018 constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id);
1019 e_thread_copy_vgpr_to_global.MoveDstSliceWindow(
1020 e_grid_desc_mblock_mperblock_nblock_nperblock, de_global_step);
1021 }
1022
1023 // Threadwise welford
1024 auto& threadwise_welford = threadwise_welfords(shuffleM_index);
1025 auto& mean_thread_buf = mean_thread_bufs(shuffleM_index);
1026 auto& var_thread_buf = var_thread_bufs(shuffleM_index);
1027
1028 threadwise_welford.Run(e_thread_buf, mean_thread_buf, var_thread_buf);
1029
1030 if constexpr(access_id < num_access - 1)
1031 {
1032 constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id);
1033 constexpr int shuffleMInc =
1034 de_global_step[I1] /
1035 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
1036 shuffleM_index = __builtin_amdgcn_readfirstlane(shuffleM_index + shuffleMInc);
1037 }
1038 }); // copy c, d, e + welford
1039
1040 // Blockwise welford and write out
1041 static_for<0, num_shuffleM, 1>{}([&](auto i) {
1042 auto& mean_thread_buf = mean_thread_bufs(i);
1043 auto& var_thread_buf = var_thread_bufs(i);
1044 auto& count_thread_buf = welford_count_thread_bufs(i);
1045
1048 count_thread_buf(j) = threadwise_welfords(i).cur_count_;
1050 mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j));
1051 });
1052
1053 if(post_shuffle_thread_cluster_idx[I1] == 0)
1054 {
1055 constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed(
1057
1058 constexpr int shuffleMPerBlock =
1059 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1);
1060
1061 auto mean_var_count_thread_copy_index = make_multi_index(
1062 block_work_idx[I0], // mblock
1063 shuffleMPerBlock * i + post_shuffle_thread_data_idx_begin[I0], // mperblock
1064 block_work_idx[I1]); // nblock
1065
1066 auto mean_var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
1067 AccDataType,
1068 EMeanVarDataType,
1069 decltype(thread_welford_desc_I_m_I),
1070 decltype(mean_var_grid_desc_mblock_mperblock_nblock),
1074 1,
1075 1,
1077 1,
1078 true>{mean_var_grid_desc_mblock_mperblock_nblock,
1079 mean_var_count_thread_copy_index,
1081
1082 mean_var_thread_copy_vgpr_to_global.Run(
1083 thread_welford_desc_I_m_I,
1084 make_tuple(I0, I0, I0),
1085 mean_thread_buf,
1086 mean_var_grid_desc_mblock_mperblock_nblock,
1087 mean_grid_buf); // write mean
1088
1089 mean_var_thread_copy_vgpr_to_global.Run(
1090 thread_welford_desc_I_m_I,
1091 make_tuple(I0, I0, I0),
1092 var_thread_buf,
1093 mean_var_grid_desc_mblock_mperblock_nblock,
1094 var_grid_buf); // write variance
1095
1096 // Stride of count is [0, 1]. Only the first row in count[0, 0:nblock] need
1097 // to be written.
1098 if(i == 0 && block_work_idx[I0] == 0 &&
1099 post_shuffle_thread_cluster_idx[I0] == 0)
1100 {
1101 auto count_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
1102 int32_t,
1103 int32_t,
1104 decltype(thread_welford_desc_I_m_I),
1105 decltype(count_grid_desc_mblock_mperblock_nblock),
1109 1,
1110 1,
1112 1,
1113 false>{count_grid_desc_mblock_mperblock_nblock,
1114 mean_var_count_thread_copy_index,
1116
1117 count_thread_copy_vgpr_to_global.Run(
1118 thread_welford_desc_I_m_I,
1119 make_tuple(I0, I0, I0),
1120 count_thread_buf,
1121 count_grid_desc_mblock_mperblock_nblock,
1122 welford_count_grid_buf); // write count
1123 }
1124 }
1125 });
1126
1127 } // shuffle C + Ds + welford + write out
1128 } // run
1129};
1130
1131} // namespace ck
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
Definition blockwise_gemm_xdlops.hpp:620
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
signed int int32_t
Definition stdint.h:123
Definition utility/array.hpp:14
Definition block_to_ctile_map.hpp:261
Definition blockwise_welford.hpp:25
static __device__ void Run(T &mean_value, T &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
Definition gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp:84
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const ADataType *__restrict__ p_b_grid, DsGridPointer p_ds_grid, EMeanVarDataType *__restrict__ p_e_grid, EMeanVarDataType *__restrict__ p_welford_mean_grid, EMeanVarDataType *__restrict__ p_welford_var_grid, int32_t *__restrict__ p_welford_count, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock &mean_var_grid_desc_mblock_mperblock_nblock, const CountGridDescriptor_MBlock_MPerBlock_NBlock &count_grid_desc_mblock_mperblock_nblock, const Block2ETileMap &block_2_etile_map, index_t NRaw)
Definition gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp:374
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition threadwise_welford.hpp:18
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340