gridwise_gemm_xdlops_v2r4r2.hpp Source File

gridwise_gemm_xdlops_v2r4r2.hpp Source File

Composable Kernel: gridwise_gemm_xdlops_v2r4r2.hpp Source File
gridwise_gemm_xdlops_v2r4r2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7#include "ck/utility/env.hpp"
19
20namespace ck {
21
// Kernel entry for the split-K xdlops GEMM (v2r4r2).
// Declares a statically sized __shared__ byte buffer of
// GridwiseGemm::GetSharedMemoryNumberOfByte() bytes and forwards every argument to
// GridwiseGemm::Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>.
// The body is compiled only for the gfx targets listed in the #if below, and only when
// GridwiseGemm::IsValidCompilationParameter<CGlobalMemoryDataOperation>() is true.
// On any other target, every argument is marked unused and the kernel is a no-op.
// NOTE(review): b2c_map is a reference parameter of a __global__ function. This is only
// safe if Block2CTileMap never reads its own storage on the device (e.g. an empty map
// that derives the tile from block ids). Confirm this, or take the map by value.
22template <typename GridwiseGemm,
23 bool HasMainKBlockLoop,
24 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
25 typename Block2CTileMap,
26 typename AElementwiseOperation,
27 typename BElementwiseOperation,
28 typename CElementwiseOperation>
29__global__ void
30#if CK_USE_LAUNCH_BOUNDS
32#endif
33 kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg,
34 const Block2CTileMap& b2c_map,
35 const AElementwiseOperation a_element_op,
36 const BElementwiseOperation b_element_op,
37 const CElementwiseOperation c_element_op)
38{
39#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx12__)
40 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
41 {
42 constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
43
// LDS scratch area: Run uses it for the A/B tiles and later reuses it for the C-shuffle stage.
44 __shared__ uint8_t p_shared[shared_size];
45
46 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
47 karg, static_cast<void*>(p_shared), b2c_map, a_element_op, b_element_op, c_element_op);
48 }
49#else
// Unsupported target: silence unused-parameter warnings; the kernel does nothing.
50 ignore = karg;
51 ignore = b2c_map;
52 ignore = a_element_op;
53 ignore = b_element_op;
54 ignore = c_element_op;
55#endif // end of #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx12__)
56}
57
58template <index_t BlockSize,
59 typename FloatA,
60 typename FloatB,
61 typename FloatAcc,
62 typename FloatC,
63 typename ALayout,
64 typename BLayout,
65 typename CLayout,
66 typename AElementwiseOperation,
67 typename BElementwiseOperation,
68 typename CElementwiseOperation,
70 index_t NumGemmKPrefetchStage,
71 index_t MPerBlock,
72 index_t NPerBlock,
73 index_t K0PerBlock,
74 index_t MPerXdl,
75 index_t NPerXdl,
76 index_t K1Value,
77 index_t MRepeat,
78 index_t NRepeat,
79 typename ABlockTransferThreadClusterLengths_K0_M_K1,
80 typename ABlockTransferThreadClusterArrangeOrder,
81 typename ABlockTransferSrcAccessOrder,
82 index_t ABlockTransferSrcVectorDim,
83 index_t ABlockTransferSrcScalarPerVector,
84 index_t ABlockTransferDstScalarPerVector_K1,
85 bool AThreadTransferSrcResetCoordinateAfterRun,
86 bool ABlockLdsExtraM,
87 typename BBlockTransferThreadClusterLengths_K0_N_K1,
88 typename BBlockTransferThreadClusterArrangeOrder,
89 typename BBlockTransferSrcAccessOrder,
90 index_t BBlockTransferSrcVectorDim,
91 index_t BBlockTransferSrcScalarPerVector,
92 index_t BBlockTransferDstScalarPerVector_K1,
93 bool BThreadTransferSrcResetCoordinateAfterRun,
94 bool BBlockLdsExtraN,
95 index_t CShuffleMRepeatPerShuffle,
96 index_t CShuffleNRepeatPerShuffle,
97 index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
98 typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
101 typename ComputeTypeA = FloatC,
102 typename ComputeTypeB = ComputeTypeA,
103 typename LDSTypeA = ComputeTypeA,
104 typename LDSTypeB = ComputeTypeB>
106{
// Compile-time dimension indices used to address tensor-descriptor dimensions.
107 static constexpr auto I0 = Number<0>{};
108 static constexpr auto I1 = Number<1>{};
109 static constexpr auto I2 = Number<2>{};
110 static constexpr auto I3 = Number<3>{};
111 static constexpr auto I4 = Number<4>{};
112 static constexpr auto I5 = Number<5>{};
113 static constexpr auto I6 = Number<6>{};
114 static constexpr auto I7 = Number<7>{};
115
116 // K1 should be Number<...>
117 static constexpr auto K1 = Number<K1Value>{};
// M01/N01 are fixed at 1. MakeCBlockClusterAdaptor ignores its M01/N01 arguments.
118 static constexpr auto M01 = 1;
119 static constexpr auto N01 = 1;
120
// Padder constructed with MPerBlock, NPerBlock and K1 * K0PerBlock tile sizes.
// MakeCGridDescriptor_M_N uses it to pad the C descriptor.
// NOTE(review): its type (presumably parameterized by GemmSpec) is on a line not visible in this view.
121 static constexpr auto gemm_padder =
123 MPerBlock, NPerBlock, K1* K0PerBlock};
126
129
131 {
// Global-memory pointers: A and B are read-only inputs; C is the output written by Run.
132 const FloatA* p_a_grid;
133 const FloatB* p_b_grid;
134 FloatC* p_c_grid;
146
// Stores every argument verbatim; nothing is derived or validated here.
// NOTE(review): MPadded/NPadded/KPadded/K0Padded are presumably computed by the caller
// with CalculateMPadded/CalculateNPadded/CalculateKPadded/CalculateK0Padded for the same
// k_batch. Confirm at the call site: CheckValidity and Run read these fields directly.
147 Argument(const FloatA* p_a_grid_,
148 const FloatB* p_b_grid_,
149 FloatC* p_c_grid_,
150 index_t M_,
151 index_t N_,
152 index_t K_,
153 index_t StrideA_,
154 index_t StrideB_,
155 index_t StrideC_,
156 index_t MPadded_,
157 index_t NPadded_,
158 index_t KPadded_,
159 index_t K0Padded_,
160 index_t k_batch_)
161 : p_a_grid(p_a_grid_),
162 p_b_grid(p_b_grid_),
163 p_c_grid(p_c_grid_),
164 M(M_),
165 N(N_),
166 K(K_),
167 StrideA(StrideA_),
168 StrideB(StrideB_),
169 StrideC(StrideC_),
170 MPadded(MPadded_),
171 NPadded(NPadded_),
172 KPadded(KPadded_),
173 K0Padded(K0Padded_),
174 k_batch(k_batch_)
175 {
176 }
177
178 void Print() const
179 {
180 std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
181 << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
182 << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
183 << "KP:" << KPadded << ", " << "K0Padded:" << K0Padded << ", "
184 << "KB:" << k_batch << "}" << std::endl;
185 }
186 };
187
188 __host__ __device__ static auto CalculateGridSize(const Argument& karg)
189 {
190 return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock),
191 math::integer_divide_ceil(karg.M, MPerBlock),
192 karg.k_batch);
193 }
194
195 // prefer this to be called on host
196 __host__ __device__ static auto CalculateMPadded(index_t M)
197 {
198 return math::integer_least_multiple(M, MPerBlock);
199 }
200
201 __host__ __device__ static auto CalculateNPadded(index_t N)
202 {
203 return math::integer_least_multiple(N, NPerBlock);
204 }
205
206 __host__ __device__ static auto CalculateK0Padded(index_t K, index_t K_Batch = 1)
207 {
208 // k_batch * k0 * k0_per_block * k1
209 auto K_t = K_Batch * K0PerBlock * K1;
210 return (K + K_t - 1) / K_t * K0PerBlock;
211 }
212
213 __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
214 {
215 auto K0Padded = CalculateK0Padded(K, K_Batch);
216 return K_Batch * K0Padded * K1;
217 }
218
// Builds the global A descriptor used by Run's A blockwise copy. That copy indexes it
// as [KBatch, K0, M, K1] (4-D origin (k_batch_id, 0, m_offset, 0) in Run).
// Starts from a 2-D M x K view whose strides depend on ALayout, then applies
// GemmSpec-dependent padding.
// NOTE(review): MPad - M is used as a right-pad length, so MPad must be >= M
// (presumably MPad comes from CalculateMPadded(M)).
219 __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M,
220 index_t MPad,
221 index_t K,
222 index_t StrideA,
223 index_t KBatch,
224 index_t K0Padded,
225 index_t KPad)
226 {
227 const auto a_grid_desc_m_k = [&]() {
// First branch: unit stride along K; second branch: unit stride along M.
// The layout tests that select between them are on lines not visible here.
229 {
230 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
231 }
233 {
234 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
235 }
236 }();
237
// Padding branch whose condition is not visible here: it builds a K-padded view first,
// then right-pads M to MPad.
242 {
243
244 const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
245 a_grid_desc_m_k,
249
250 // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
252 a_grid_desc_m_kpad,
254 make_right_pad_transform(M, MPad - M)),
257 }
// M-padding branch (its condition continues on a line not visible here):
// right-pads M to MPad directly on the unpadded M x K view.
258 else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
260 {
261 // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
263 a_grid_desc_m_k,
265 make_right_pad_transform(M, MPad - M)),
268 }
// K-padding only: builds a K-padded view first (transform arguments not visible here).
269 else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
270 {
271 const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
272 a_grid_desc_m_k,
276
278 a_grid_desc_m_kpad,
283 }
// Remaining specializations: transforms the unpadded view directly. Presumably this
// is only the KBatch/K0/K1 split with no padding; confirm against the full source.
284 else
285 {
287 a_grid_desc_m_k,
292 }
293 }
294
// Builds the global B descriptor used by Run's B blockwise copy. That copy indexes it
// as [KBatch, K0, N, K1] (4-D origin (k_batch_id, 0, n_offset, 0) in Run).
// Note the parameter order is (K, NPad, N, ...), which differs from the A builder's
// (M, MPad, K, ...). Run passes (karg.K, karg.NPadded, karg.N, ...) accordingly.
// NOTE(review): NPad - N is used as a right-pad length, so NPad must be >= N.
295 __host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K,
296 index_t NPad,
297 index_t N,
298 index_t StrideB,
299 index_t KBatch,
300 index_t K0Padded,
301 index_t KPad)
302 {
303 const auto b_grid_desc_k_n = [&]() {
// First branch: unit stride along N; second branch: unit stride along K.
// The layout tests that select between them are on lines not visible here.
305 {
306 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
307 }
309 {
310 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
311 }
312 }();
313
// Padding branch whose condition is not visible here: it builds a K-padded view first,
// then right-pads N to NPad.
318 {
319
320 const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
321 b_grid_desc_k_n,
325
326 // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
328 b_grid_desc_kpad_n,
330 make_right_pad_transform(N, NPad - N)),
333 }
// N-padding branch (its condition continues on a line not visible here):
// right-pads N to NPad directly on the unpadded K x N view.
334 else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
336 {
337 // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
339 b_grid_desc_k_n,
341 make_right_pad_transform(N, NPad - N)),
344 }
// K-padding only: builds a K-padded view first (transform arguments not visible here).
345 else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
346 {
347 const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
348 b_grid_desc_k_n,
352
354 b_grid_desc_kpad_n,
359 }
// Remaining specializations: transforms the unpadded view directly (presumably no
// padding; confirm against the full source).
360 else
361 {
363 b_grid_desc_k_n,
368 }
369 }
370
// Builds the 2-D M x N descriptor of C and passes it through gemm_padder.PadCDescriptor_M_N.
// First branch: unit stride along N; second branch: unit stride along M.
// The CLayout tests that select between them are on lines not visible here.
371 __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
372 {
373 const auto c_grid_desc_m_n = [&]() {
375 {
376 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
377 }
379 {
380 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
381 }
382 }();
383
384 return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n);
385 }
386
// Returns the LDS bytes one workgroup needs. This is the larger of:
// (a) the A tile plus the B tile, each rounded up to a multiple of K1 elements, and
// (b) the C-shuffle staging tile.
// Run places B immediately after a_block_space_size LDSTypeA elements of A, and later
// reuses the same base address for the C-shuffle buffer. Keep this sizing in sync with Run.
387 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
388 {
389 constexpr auto max_lds_align = K1;
390
391 // A matrix in LDS memory, dst of blockwise copy
392 constexpr auto a_k0_m_k1_block_desc = [&]() {
393 if constexpr(ABlockLdsExtraM)
394 {
398 }
399 else
400 {
402 make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
403 }
404 }();
405
406 // B matrix in LDS memory, dst of blockwise copy
407 constexpr auto b_k0_n_k1_block_desc = [&]() {
408 if constexpr(BBlockLdsExtraN)
409 {
413 }
414 else
415 {
417 make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
418 }
419 }();
420
421 // LDS allocation for A and B: be careful of alignment
422 constexpr auto a_block_space_size =
423 math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
424
425 constexpr auto b_block_space_size =
426 math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
427
// Element count of the C-shuffle tile (the descriptor expression is on a line not visible here).
428 constexpr auto c_block_size =
430
431 return math::max(a_block_space_size * sizeof(LDSTypeA) +
432 b_block_space_size * sizeof(LDSTypeB),
433 c_block_size * sizeof(FloatC));
434 }
435
// Aliases for the per-wave xdlops repeat counts along M and N.
436 static constexpr auto MXdlPerWave = MRepeat;
437 static constexpr auto NXdlPerWave = NRepeat;
439
// Validates a runtime Argument against this instance's compile-time configuration.
// Returns false on the first failed check. When CK_LOGGING is enabled, it also prints a
// diagnostic to std::cout. Checks, in order:
//  - M, N and K tile divisibility (guard conditions not visible here; presumably only
//    for GemmSpecs that do not pad the corresponding dimension),
//  - divisibility of the A/B source vector widths and the C store vector width,
//  - whether GridwiseGemmPipe supports K0Padded / K0PerBlock loop iterations.
// NOTE(review): uses std::cout / EnvIsEnabled despite __host__ __device__; presumably it
// is only ever called on the host.
440 __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
441 {
// M must be a multiple of MPerBlock.
446 {
447 if(!(karg.M % MPerBlock == 0))
448 {
449 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
450 {
451 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
452 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
453 << std::endl;
454 }
455 return false;
456 }
457 }
458
// N must be a multiple of NPerBlock.
463 {
464 if(!(karg.N % NPerBlock == 0))
465 {
466 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
467 {
468 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
469 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
470 << std::endl;
471 }
472 return false;
473 }
474 }
475
// K must split evenly into k_batch * K0PerBlock * K1 steps.
480 {
481
482 auto K_t = karg.k_batch * K0PerBlock * K1;
483 if(!(karg.K % K_t == 0))
484 {
485 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
486 {
487 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
488 << karg.K << " " << __FILE__ << ":" << __LINE__
489 << ", in function: " << __func__ << std::endl;
490 }
491 return false;
492 }
493 }
494
// A source vector width must divide K in the first branch and M otherwise. The layout
// test is not visible here; presumably the first branch is the K-contiguous A layout.
496 {
497 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
498 {
499 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
500 {
501 std::cout << "Arg K (" << karg.K
502 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
503 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
504 << __LINE__ << ", in function: " << __func__ << std::endl;
505 }
506 return false;
507 }
508 }
509 else
510 {
511 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
512 {
513 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
514 {
515 std::cout << "Arg M (" << karg.M
516 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
517 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
518 << __LINE__ << ", in function: " << __func__ << std::endl;
519 }
520 return false;
521 }
522 }
523
// B source vector width must divide N in the first branch and K otherwise
// (layout test not visible here).
525 {
526 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
527 {
528 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
529 {
530 std::cout << "Arg N (" << karg.N
531 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
532 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
533 << __LINE__ << ", in function: " << __func__ << std::endl;
534 }
535 return false;
536 }
537 }
538 else
539 {
540 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
541 {
542 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
543 {
544 std::cout << "Arg K (" << karg.K
545 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
546 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
547 << __LINE__ << ", in function: " << __func__ << std::endl;
548 }
549 return false;
550 }
551 }
552
// C store vector width must divide N in the first branch and M otherwise
// (layout test not visible here).
554 {
555 if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
556 {
557 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
558 {
559 std::cout << "Arg N (" << karg.N
560 << ") value is not a multiple of "
561 "CBlockTransferScalarPerVector_NWaveNPerXDL ("
562 << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__
563 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
564 }
565 return false;
566 }
567 }
568 else
569 {
570 if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
571 {
572 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
573 {
574 std::cout << "Arg M (" << karg.M
575 << ") value is not a multiple of "
576 "CBlockTransferScalarPerVector_NWaveNPerXDL ("
577 << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__
578 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
579 }
580 return false;
581 }
582 }
583
// The pipeline must support this many K0PerBlock iterations per split-K batch.
584 const auto num_k_loop = karg.K0Padded / K0PerBlock;
585 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
586 {
587 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
588 {
589 std::cout << "The number of k loops (" << num_k_loop
590 << ") value is not supported by GridwiseGemm Pipeline."
591 << " K0Padded: " << karg.K0Padded << ", K0PerBlock: " << K0PerBlock << " "
592 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
593 << std::endl;
594 }
595 return false;
596 }
597
598 return true;
599 }
600
601 __host__ __device__ static auto GetKPad(index_t K, index_t KBatch)
602 {
603 const index_t K0Padded =
604 math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
605 const index_t KPad = KBatch * K0Padded * K1;
606 return KPad;
607 }
608
609 __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0Padded)
610 {
611 const index_t num_loop = K0Padded / K0PerBlock;
612 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
613 }
614
// Reshapes an M x N C descriptor into [MBlock, MPerBlock, NBlock, NPerBlock].
// MBlock/NBlock use truncating division, so M and N are expected to be tile multiples.
// Presumably this is guaranteed by the padded descriptor from MakeCGridDescriptor_M_N;
// the transform lines are not visible here.
615 template <typename CGridDesc>
616 __host__ __device__ static constexpr auto
617 MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc)
618 {
619 const auto M = c_m_n_grid_desc.GetLength(I0);
620 const auto N = c_m_n_grid_desc.GetLength(I1);
621
622 const auto MBlock = M / MPerBlock;
623 const auto NBlock = N / NPerBlock;
624
626 c_m_n_grid_desc,
631 }
632
// return block_id to C matrix tile idx mapping; KBatch is forwarded to the map as well.
// The M01/N01 parameters are accepted for interface compatibility but ignored. A constant
// 8 is passed instead (presumably a tile-grouping factor; the map type is on a line not
// visible here, so confirm).
634 template <typename CGridDesc>
635 __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor(
636 const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch)
637 {
639 c_m_n_grid_desc, 8, KBatch);
640 }
641
// Returns the LDS descriptor of one C-shuffle tile. MWave/NWave are the number of waves
// per block along M and N.
// NOTE(review): the function name and most of the descriptor expression are on lines not
// visible in this view. Presumably the tile spans CShuffleMRepeatPerShuffle * MWave * MPerXdl
// by CShuffleNRepeatPerShuffle * NWave * NPerXdl, matching the LDS-to-global slice in Run.
642 __host__ __device__ static constexpr auto
644 {
645 constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
646 constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);
647
651 I1,
653 }
654
655 // return block_id to C matrix tile idx (m0, n0, k_split) mapping
// Takes no arguments, so the returned map presumably derives the tile from the launch
// grid (see CalculateGridSize). The returned expression is on a line not visible here.
656 __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
657 {
659 }
660
663
// Device-side body of the split-K GEMM for one workgroup.
//  1. Builds the A/B/C global descriptors from karg and maps the block id to
//     (k_batch_id, m tile, n tile); blocks with an invalid tile index return immediately.
//  2. Runs the K pipeline: A/B tiles are copied global -> LDS and accumulated by the
//     blockwise xdlops GEMM into per-thread registers (c_thread_buf).
//  3. Writes C in chunks: registers -> LDS (reusing p_shared_block) -> global, using
//     CGlobalMemoryDataOperation as the destination memory operation.
// p_shared_block must provide at least GetSharedMemoryNumberOfByte() bytes.
664 template <bool HasMainKBlockLoop,
665 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
666 typename Block2CTileMap>
667 __device__ static void Run(const Argument& karg,
668 void* __restrict__ p_shared_block,
669 const Block2CTileMap& block_2_ctile_map,
670 const AElementwiseOperation a_element_op = AElementwiseOperation{},
671 const BElementwiseOperation b_element_op = BElementwiseOperation{},
672 const CElementwiseOperation c_element_op = CElementwiseOperation{})
673 {
674 const FloatA* p_a_grid = karg.p_a_grid;
675 const FloatB* p_b_grid = karg.p_b_grid;
676 FloatC* p_c_grid = karg.p_c_grid;
677 const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
678 karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded);
679 const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
680 karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded);
681 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
682
683 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
685
686 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
687 p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize());
688 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
689 p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize());
691 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
692
693 // divide block work by [KBatch, M, N]
694 const auto block_work_idx =
695 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
696
// block_work_idx depends only on the block id, so the whole workgroup takes this early
// exit together, before any LDS use.
697 if(!block_2_ctile_map.ValidCTileIndex(
698 block_work_idx,
699 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
700 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
701 {
702 return;
703 }
704
705 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
706 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);
707 const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
708
709 // HACK: this force m/n_block_data_idx_on_grid into SGPR
710 const index_t m_block_data_idx_on_grid =
711 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
712
713 const index_t n_block_data_idx_on_grid =
714 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
715
716 // lds max alignment
717 constexpr auto max_lds_align = K1;
718
719 // A matrix in LDS memory, dst of blockwise copy
720 constexpr auto a_k0_m_k1_block_desc = [&]() {
721 if constexpr(ABlockLdsExtraM)
722 {
726 }
727 else
728 {
730 make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
731 }
732 }();
733
// 4-D view of the same A LDS tile with a leading batch dimension, matching the 4-D
// global descriptor used by the blockwise copy.
734 constexpr auto a_b_k0_m_k1_block_desc = [&]() {
735 if constexpr(ABlockLdsExtraM)
736 {
741 K1,
742 I1));
743 }
744 else
745 {
748 max_lds_align);
749 }
750 }();
751 // B matrix in LDS memory, dst of blockwise copy
752 constexpr auto b_k0_n_k1_block_desc = [&]() {
753 if constexpr(BBlockLdsExtraN)
754 {
758 }
759 else
760 {
762 make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
763 }
764 }();
765
// 4-D view of the same B LDS tile with a leading batch dimension.
766 constexpr auto b_b_k0_n_k1_block_desc = [&]() {
767 if constexpr(BBlockLdsExtraN)
768 {
773 K1,
774 I1));
775 }
776 else
777 {
780 max_lds_align);
781 }
782 }();
783 // A matrix blockwise copy
// The source window starts at (k_batch_id, 0, m_block_data_idx_on_grid, 0) and copies
// a 1 x K0PerBlock x MPerBlock x K1 slice per step.
784 auto a_blockwise_copy =
785 ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
786 AElementwiseOperation,
787 ck::tensor_operation::element_wise::PassThrough,
789 Sequence<1, K0PerBlock, MPerBlock, K1>,
790 ABlockTransferThreadClusterLengths_K0_M_K1,
791 ABlockTransferThreadClusterArrangeOrder,
792 FloatA,
793 LDSTypeA,
794 decltype(a_b_k0_m_k1_grid_desc),
795 decltype(a_b_k0_m_k1_block_desc),
796 ABlockTransferSrcAccessOrder,
797 Sequence<0, 2, 1, 3>,
798 ABlockTransferSrcVectorDim,
799 3,
800 ABlockTransferSrcScalarPerVector,
801 ABlockTransferDstScalarPerVector_K1,
802 1,
803 1,
804 AThreadTransferSrcResetCoordinateAfterRun,
805 true>(
806 a_b_k0_m_k1_grid_desc,
807 make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
808 a_element_op,
809 a_b_k0_m_k1_block_desc,
810 make_multi_index(0, 0, 0, 0),
811 ck::tensor_operation::element_wise::PassThrough{});
812
813 // B matrix blockwise copy
814 auto b_blockwise_copy =
815 ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
816 BElementwiseOperation,
817 ck::tensor_operation::element_wise::PassThrough,
819 Sequence<1, K0PerBlock, NPerBlock, K1>,
820 BBlockTransferThreadClusterLengths_K0_N_K1,
821 BBlockTransferThreadClusterArrangeOrder,
822 FloatB,
823 LDSTypeB,
824 decltype(b_b_k0_n_k1_grid_desc),
825 decltype(b_b_k0_n_k1_block_desc),
826 BBlockTransferSrcAccessOrder,
827 Sequence<0, 2, 1, 3>,
828 BBlockTransferSrcVectorDim,
829 3,
830 BBlockTransferSrcScalarPerVector,
831 BBlockTransferDstScalarPerVector_K1,
832 1,
833 1,
834 BThreadTransferSrcResetCoordinateAfterRun,
835 true>(
836 b_b_k0_n_k1_grid_desc,
837 make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
838 b_element_op,
839 b_b_k0_n_k1_block_desc,
840 make_multi_index(0, 0, 0, 0),
841 ck::tensor_operation::element_wise::PassThrough{});
842
843 // GEMM definition
844 // c_mtx += transpose(a_mtx) * b_mtx
845 // a_mtx[K0PerBlock, MPerBlock] is in LDS
846 // b_mtx[K0PerBlock, NPerBlock] is in LDS
847 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
848 // register
849 // sanity check
850
852 BlockSize,
853 LDSTypeA,
854 LDSTypeB,
855 FloatAcc,
856 decltype(a_k0_m_k1_block_desc),
857 decltype(b_k0_n_k1_block_desc),
858 MPerXdl,
859 NPerXdl,
860 MRepeat,
861 NRepeat,
862 K1,
863 LoopSched,
864 ComputeTypeA,
865 ComputeTypeB>();
866
867 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
868
869 // LDS allocation for A and B: be careful of alignment
// Layout: A at the start of p_shared_block, B right after a_block_space_size LDSTypeA
// elements. This must match the sizing in GetSharedMemoryNumberOfByte.
870 constexpr auto a_block_space_size =
871 math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
872
873 auto p_a_block = reinterpret_cast<LDSTypeA*>(p_shared_block);
874 auto p_b_block = reinterpret_cast<LDSTypeB*>(p_a_block + a_block_space_size);
875
// Each pipeline iteration advances the source windows by K0PerBlock along the K0 dim.
876 constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
877 constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
878
880 p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
882 p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
883
884 // gridwise GEMM pipeline
// Iteration count = (K0 length * K1 length of the grid view) / (K0PerBlock * K1), i.e. the
// number of K0PerBlock steps in this split-K batch.
885 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
886 (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) /
887 (K0PerBlock * K1));
888
889 const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
890
891 gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_b_k0_m_k1_grid_desc,
892 a_b_k0_m_k1_block_desc,
893 a_blockwise_copy,
894 a_grid_buf,
895 a_block_buf,
896 a_block_slice_copy_step,
897 b_b_k0_n_k1_grid_desc,
898 b_b_k0_n_k1_block_desc,
899 b_blockwise_copy,
900 b_grid_buf,
901 b_block_buf,
902 b_block_slice_copy_step,
903 blockwise_gemm,
904 c_thread_buf,
905 num_k_block_main_loop);
906
907 // output: register to global memory
908 {
909 constexpr index_t MWave = MPerBlock / (MRepeat * MPerXdl);
910 constexpr index_t NWave = NPerBlock / (NRepeat * NPerXdl);
911
912 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
913 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
914
915 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
916 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
917
918 constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
919 constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
920 constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
921 constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
922 constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
923 constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
924 constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
925 constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
926
// C-shuffle staging buffer: it reuses the same LDS base (p_shared_block) as the A/B tiles.
// The pipeline has already finished at this point, and the "safe to do ds_write" barrier
// in the loop below guards the reuse (barrier line not visible here).
927 constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
929
931 static_cast<FloatC*>(p_shared_block),
932 c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
933
934 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
935 c_block_desc_mblock_mperblock_nblock_nperblock,
937 make_freeze_transform(I0), // freeze mblock
938 make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle,
939 M1,
940 M2,
941 M3,
942 M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
943 make_freeze_transform(I0), // freeze nblock
944 make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle,
945 N1,
946 N2))), // presumably N1 = NWave, N2 = NPerXdl (mirrors the M dims)
947 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
949 Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
950
951 // calculate origin of thread output tensor on global memory
952 // blockwise GEMM c matrix starting index
953 const auto c_thread_mtx_on_block =
954 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
955
956 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
957 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
958
// Decompose this thread's flat M offset into (M0, M1, M2, M3, M4) coordinates.
959 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
961 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
962 make_tuple(Sequence<0, 1, 2, 3, 4>{}),
963 make_tuple(Sequence<0>{}));
964
965 const auto m_thread_data_on_block_idx =
966 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
967 make_multi_index(m_thread_data_on_block));
968
// Decompose this thread's flat N offset into (N0, N1, N2) coordinates.
969 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
972 make_tuple(Sequence<0, 1, 2>{}),
973 make_tuple(Sequence<0>{}));
974
975 const auto n_thread_data_on_block_idx =
976 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
977 make_multi_index(n_thread_data_on_block));
978
979 // VGPR to LDS
980 auto c_thread_copy_vgpr_to_lds =
981 ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
982 FloatC,
983 decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
984 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
985 ck::tensor_operation::element_wise::PassThrough,
986 Sequence<CShuffleMRepeatPerShuffle,
987 CShuffleNRepeatPerShuffle,
988 I1,
989 I1,
990 M2,
991 I1,
992 M4,
993 I1>,
994 Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
995 7,
996 1,
998 1,
999 true>{
1000 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1002 0,
1003 m_thread_data_on_block_idx[I1],
1004 n_thread_data_on_block_idx[I1],
1005 m_thread_data_on_block_idx[I2],
1006 m_thread_data_on_block_idx[I3],
1007 m_thread_data_on_block_idx[I4],
1008 n_thread_data_on_block_idx[I2]),
1009 ck::tensor_operation::element_wise::PassThrough{}};
1010
1011 // LDS to global
// CElementwiseOperation is applied here, and the write uses CGlobalMemoryDataOperation.
// The destination window starts at this block's (block_m_id, 0, block_n_id, 0) tile.
1012 auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1013 ThisThreadBlock, // index_t BlockSize,
1014 CElementwiseOperation, // ElementwiseOperation,
1015 CGlobalMemoryDataOperation, // DstInMemOp,
1016 Sequence<1,
1017 CShuffleMRepeatPerShuffle * MWave * MPerXdl,
1018 1,
1019 CShuffleNRepeatPerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1020 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1021 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1022 FloatC, // typename SrcData,
1023 FloatC, // typename DstData,
1024 decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
1025 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1026 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1027 3, // index_t VectorDim,
1028 CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
1029 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1030 false> // bool ThreadTransferDstResetCoordinateAfterRun
1031 {c_block_desc_mblock_mperblock_nblock_nperblock,
1032 make_multi_index(0, 0, 0, 0),
1033 c_grid_desc_mblock_mperblock_nblock_nperblock,
1034 make_multi_index(block_m_id, 0, block_n_id, 0),
1035 c_element_op};
1036
1037 constexpr auto mxdlperwave_forward_step =
1038 make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXdl, 0, 0);
1039 constexpr auto nxdlperwave_forward_step =
1040 make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXdl);
1041 constexpr auto nxdlperwave_backward_step =
1042 make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXdl);
1043
// Write C in CShuffleMRepeatPerShuffle x CShuffleNRepeatPerShuffle chunks.
// The N sweep alternates direction: forward when mxdlperwave % (2 * CShuffleMRepeatPerShuffle)
// == 0, backward otherwise. As a result, the destination window moves by only one N step
// between writes, plus one M step after each complete N sweep.
1044 static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) {
1045 constexpr auto mxdlperwave = mxdlperwave_iter;
1046
1047 static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) {
1048 constexpr bool nxdlperwave_forward_sweep =
1049 (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
1050
1051 constexpr index_t nxdlperwave_value =
1052 nxdlperwave_forward_sweep
1053 ? nxdlperwave_iter
1054 : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
1055
1056 constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
1057
1058 // make sure it's safe to do ds_write
1060
1061 // VGPR to LDS
1062 c_thread_copy_vgpr_to_lds.Run(
1063 c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
1064 make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
1065 c_thread_buf,
1066 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1067 c_block_buf);
1068
1069 // make sure it's safe to do ds_read
1071
1072 // LDS to global
1073 c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock,
1074 c_block_buf,
1075 c_grid_desc_mblock_mperblock_nblock_nperblock,
1076 c_grid_buf);
1077
1078 // move on nxdlperwave dimension
1079 if constexpr(nxdlperwave_forward_sweep &&
1080 (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
1081 {
1082 c_block_copy_lds_to_global.MoveDstSliceWindow(
1083 c_grid_desc_mblock_mperblock_nblock_nperblock,
1084 nxdlperwave_forward_step);
1085 }
1086 else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
1087 {
1088 c_block_copy_lds_to_global.MoveDstSliceWindow(
1089 c_grid_desc_mblock_mperblock_nblock_nperblock,
1090 nxdlperwave_backward_step);
1091 }
1092 });
1093
1094 // move on mxdlperwave dimension
1095 if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
1096 {
1097 c_block_copy_lds_to_global.MoveDstSliceWindow(
1098 c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step);
1099 }
1100 });
1101 }
1102 }
1103
1104 static std::string GetTypeString()
1105 {
1106 auto str = std::stringstream();
1107
1108 // clang-format off
1109 str << "GemmXdlSplitKCShuffle_"
1110 << getGemmSpecializationString(GemmSpec) << "_"
1111 << std::string(ALayout::name)[0]
1112 << std::string(BLayout::name)[0]
1113 << std::string(CLayout::name)[0]
1114 << "_"
1115 << "B" << BlockSize << "_"
1116 << "Vec" << ABlockTransferSrcScalarPerVector << "x"
1117 << BBlockTransferSrcScalarPerVector << "x"
1118 << CBlockTransferScalarPerVector_NWaveNPerXDL << "_"
1119 << MPerBlock << "x"
1120 << NPerBlock << "x"
1121 << K0PerBlock << "x"
1122 << K1 ;
1123 // clang-format on
1124
1125 return str.str();
1126 }
1127};
1128
1129} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
Definition blockwise_gemm_xdlops.hpp:620
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
__global__ void kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg, const Block2CTileMap &b2c_map, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:33
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__host__ __device__ constexpr auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition tensor_descriptor_helper.hpp:132
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
unsigned char uint8_t
Definition stdint.h:124
Simple tile mapping which creates 3D grid of block of threads.
Definition block_to_ctile_map.hpp:977
Definition block_to_ctile_map.hpp:541
index_t KPadded
Definition gridwise_gemm_xdlops_v2r4r2.hpp:143
Argument(const FloatA *p_a_grid_, const FloatB *p_b_grid_, FloatC *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t MPadded_, index_t NPadded_, index_t KPadded_, index_t K0Padded_, index_t k_batch_)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:147
index_t M
Definition gridwise_gemm_xdlops_v2r4r2.hpp:135
index_t K0Padded
Definition gridwise_gemm_xdlops_v2r4r2.hpp:144
index_t StrideA
Definition gridwise_gemm_xdlops_v2r4r2.hpp:138
index_t StrideC
Definition gridwise_gemm_xdlops_v2r4r2.hpp:140
void Print() const
Definition gridwise_gemm_xdlops_v2r4r2.hpp:178
index_t MPadded
Definition gridwise_gemm_xdlops_v2r4r2.hpp:141
index_t k_batch
Definition gridwise_gemm_xdlops_v2r4r2.hpp:145
FloatC * p_c_grid
Definition gridwise_gemm_xdlops_v2r4r2.hpp:134
index_t NPadded
Definition gridwise_gemm_xdlops_v2r4r2.hpp:142
const FloatA * p_a_grid
Definition gridwise_gemm_xdlops_v2r4r2.hpp:132
const FloatB * p_b_grid
Definition gridwise_gemm_xdlops_v2r4r2.hpp:133
index_t N
Definition gridwise_gemm_xdlops_v2r4r2.hpp:136
index_t StrideB
Definition gridwise_gemm_xdlops_v2r4r2.hpp:139
index_t K
Definition gridwise_gemm_xdlops_v2r4r2.hpp:137
Definition gridwise_gemm_xdlops_v2r4r2.hpp:106
__host__ static __device__ constexpr auto GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition gridwise_gemm_xdlops_v2r4r2.hpp:643
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition device_base.hpp:197
Definition matrix_padder.hpp:134
#define CK_ENV(name)
Definition utility/env.hpp:129