gridwise_gemm_xdl_cshuffle_v3.hpp Source File

gridwise_gemm_xdl_cshuffle_v3.hpp Source File#

Composable Kernel: gridwise_gemm_xdl_cshuffle_v3.hpp Source File
gridwise_gemm_xdl_cshuffle_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7#include "ck/utility/env.hpp"
17
18namespace ck {
19
// Currently there is no elegant way to put the single-LDS-buffer and double-LDS-buffer pipelines
// in the same kernel function. Blockers:
// 1. Two separate __shared__ declarations are the key to making sure data accesses operate on
//    two distinct LDS chunks.
// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may
//    not share the same LDS buffer when __shared__ is declared inside blkgemmpipe.
26template <typename GridwiseGemm,
27 bool HasMainKBlockLoop,
28 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
29 index_t MinimumOccupancy = 1,
31__global__ void
32#if CK_USE_LAUNCH_BOUNDS
33__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
34#endif
35 // __attribute__((amdgpu_waves_per_eu(1, 1)))
36 kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
37{
38#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
39 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
40 {
41 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
42
43 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
44
45 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
46 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
47 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
48 karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
49 p_shared,
50 karg);
51 }
52#else
53 ignore = karg;
54#endif // end of if (defined(__gfx9__))
55}
56
57template <typename GridwiseGemm,
58 bool HasMainKBlockLoop,
59 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
60 index_t MinimumOccupancy = 1,
62__global__ void
63#if CK_USE_LAUNCH_BOUNDS
64__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
65#endif
66 // __attribute__((amdgpu_waves_per_eu(1, 1)))
67 kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
68{
69#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
70 // Pass two lds pointer is the key to tell compiler that ds_read/write
71 // operate on different lds chunk at same time without order dependecy
72 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
73 {
74 __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
75 __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
76
77 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
78
79 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
80 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
81 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
82 karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
83 p_shared_0,
84 p_shared_1,
85 karg);
86 }
87#else
88 ignore = karg;
89#endif // end of if (defined(__gfx9__))
90}
91
197template <typename ALayout,
198 typename BLayout,
199 typename CLayout,
200 typename ADataType,
201 typename BDataType,
202 typename AccDataType,
203 typename CShuffleDataType,
204 typename CDataType,
205 typename AElementwiseOperation,
206 typename BElementwiseOperation,
207 typename CElementwiseOperation,
209 index_t BlockSize,
210 index_t MPerBlock,
211 index_t NPerBlock,
212 index_t KPerBlock,
213 index_t AK1Value,
214 index_t BK1Value,
215 index_t MPerXdl,
216 index_t NPerXdl,
217 index_t MXdlPerWave,
218 index_t NXdlPerWave,
219 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
220 typename ABlockTransferThreadClusterArrangeOrder,
221 typename ABlockTransferSrcAccessOrder,
222 index_t ABlockTransferSrcVectorDim,
223 index_t ABlockTransferSrcScalarPerVector,
224 index_t ABlockTransferDstScalarPerVector_AK1,
225 bool AThreadTransferSrcResetCoordinateAfterRun,
226 index_t ABlockLdsExtraM,
227 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
228 typename BBlockTransferThreadClusterArrangeOrder,
229 typename BBlockTransferSrcAccessOrder,
230 index_t BBlockTransferSrcVectorDim,
231 index_t BBlockTransferSrcScalarPerVector,
232 index_t BBlockTransferDstScalarPerVector_BK1,
233 bool BThreadTransferSrcResetCoordinateAfterRun,
234 index_t BBlockLdsExtraN,
235 index_t CShuffleMXdlPerWavePerShuffle,
236 index_t CShuffleNXdlPerWavePerShuffle,
237 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
238 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
241 typename ComputeTypeA = CDataType,
242 typename ComputeTypeB = ComputeTypeA,
243 bool PermuteA = false,
244 bool PermuteB = false,
245 bool DoElementwiseBeforeCShuffle = false>
247{
248 static constexpr auto I0 = Number<0>{};
249 static constexpr auto I1 = Number<1>{};
250 static constexpr auto I2 = Number<2>{};
251 static constexpr auto I3 = Number<3>{};
252 static constexpr auto I4 = Number<4>{};
253 static constexpr auto I5 = Number<5>{};
254 static constexpr auto I6 = Number<6>{};
255 static constexpr auto I7 = Number<7>{};
256
257 // K1 should be Number<...>
258 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
259 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
260 static constexpr auto AK1Number = Number<AK1Value>{};
261 static constexpr auto BK1Number = Number<BK1Value>{};
262
263 static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
264 static constexpr bool is_single_rate_mfma =
266 lcm_AK1_BK1 <= 4) ||
268 // gfx950 double rate mfma16x16 require at least 128 KPerBlock to consume
270 KPerBlock < 128 && MPerXdl == 16))
271 ? true
272 : false;
273 static constexpr auto is_scale_mfma = false;
274 static constexpr index_t KPack =
276 MfmaSelector<ComputeTypeA,
277 MPerXdl,
278 NPerXdl,
279 ComputeTypeA,
281 is_scale_mfma>::selected_mfma.k_per_blk);
282
284
285 static constexpr index_t APackedSize = []() {
287 return 2;
288 else
289 return 1;
290 }();
291
292 static constexpr index_t BPackedSize = []() {
294 return 2;
295 else
296 return 1;
297 }();
298
299 __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
300 {
301 return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
302 }
303
304 __host__ static auto CalculateMPadded(index_t M)
305 {
306 return math::integer_least_multiple(M, MPerBlock);
307 }
308
309 __host__ static auto CalculateNPadded(index_t N)
310 {
311 return math::integer_least_multiple(N, NPerBlock);
312 }
313
314 __host__ static auto CalculateKPadded(index_t K)
315 {
316 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
317 }
318
319 __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
320 {
321 auto K_t = K_Batch * KPerBlock;
322 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
323 }
324
325 __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
326 {
327 auto K_t = K_Batch * KPerBlock;
328 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
329 }
330
331 __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
332 {
333 auto K_t = K_Batch * KPerBlock;
334 return (K + K_t - 1) / K_t * KPerBlock;
335 }
336
337 __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
338 {
339 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
340 auto K_t = K_Batch * KReadVec;
341 return (K + K_t - 1) / K_t * KReadVec;
342 }
343
344 __host__ static auto CalculateMBlock(index_t M)
345 {
346 return math::integer_divide_ceil(M, MPerBlock);
347 }
348
349 __host__ static auto CalculateNBlock(index_t N)
350 {
351 return math::integer_divide_ceil(N, NPerBlock);
352 }
353
354 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
355 __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
356 {
357 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
358 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
359
361 TileDesc_K0_MN_K1{},
367 }
368
369 __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
370 index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
371 {
372 const auto a_grid_desc_mraw_kraw = [&]() {
374 {
375 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
376 }
378 {
379 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
380 }
381 }();
382
383 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
384
385 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
386 GemmSpec == GemmSpecialization::MNKPadding)
387 {
388 // pad both M and K
389 const auto a_grid_desc_m_k =
390 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
392 make_right_pad_transform(K, KPad - K)),
395
396 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
397 a_grid_desc_m_k,
402
403 return a_grid_desc_ak0_m_ak1;
404 }
405 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
406 GemmSpec == GemmSpecialization::MNPadding)
407 {
408 // pad M, but not K
409 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
410 a_grid_desc_mraw_kraw,
412 make_right_pad_transform(M, MPad - M)),
415
416 return a_grid_desc_ak0_m_ak1;
417 }
418 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
419 GemmSpec == GemmSpecialization::NKPadding)
420 {
421 // pad K, but not M
422 const auto a_grid_desc_m_k = transform_tensor_descriptor(
423 a_grid_desc_mraw_kraw,
427
428 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
429 a_grid_desc_m_k,
434
435 return a_grid_desc_ak0_m_ak1;
436 }
437 else
438 {
439 // not pad M or K
440 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
441 a_grid_desc_mraw_kraw,
446
447 return a_grid_desc_ak0_m_ak1;
448 }
449 }
450
451 __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
452 index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
453 {
454 const auto b_grid_desc_nraw_kraw = [&]() {
456 {
457 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
458 }
460 {
461 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
462 }
463 }();
464
465 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
466
468 GemmSpec != GemmSpecialization::Default),
469 "pk_i4_t does not support padding");
470
471 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
472 GemmSpec == GemmSpecialization::MNKPadding)
473 {
474 // pad both N and K
475 const auto b_grid_desc_n_k =
476 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
478 make_right_pad_transform(K, KPad - K)),
481
482 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
483 b_grid_desc_n_k,
488
489 return b_grid_desc_bk0_n_bk1;
490 }
491 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
492 GemmSpec == GemmSpecialization::MNPadding)
493 {
494 // pad N, but not K
495 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
496 b_grid_desc_nraw_kraw,
498 make_right_pad_transform(N, NPad - N)),
501
502 return b_grid_desc_bk0_n_bk1;
503 }
504 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
505 GemmSpec == GemmSpecialization::MKPadding)
506 {
507 // pad K, but not N
508 const auto b_grid_desc_n_k = transform_tensor_descriptor(
509 b_grid_desc_nraw_kraw,
513
514 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
515 b_grid_desc_n_k,
520
521 return b_grid_desc_bk0_n_bk1;
522 }
523 else
524 {
525 if constexpr(!PermuteB)
526 {
527 // not pad N or K
528 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
529 b_grid_desc_nraw_kraw,
534
535 return b_grid_desc_bk0_n_bk1;
536 }
537 else
538 {
539 // Pre-shuffled Weight
540 // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1]
541 constexpr index_t BK01 = KPerBlock / BK1Value;
542 const index_t BK0_ = StrideB / BK1Value;
543 const index_t BK00 = BK0_ / BK01;
544
545 const auto b_grid_desc_bk00_n_bk01_bk1_permute =
546 make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
547
548 const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
549 b_grid_desc_bk00_n_bk01_bk1_permute,
555
556 return b_grid_desc_bk0_n_bk1_permute;
557 }
558 }
559 }
560
561 template <typename ABlockDesc_AK0_M_AK1>
562 __host__ __device__ static constexpr auto
563 MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
564 {
565 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
566
567 return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
568 }
569
570 template <typename BBlockDesc_BK0_N_BK1>
571 __host__ __device__ static constexpr auto
572 MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
573 {
574 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
575
576 return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
577 }
578
579 __host__ __device__ static auto
581 {
582 const auto c_grid_desc_mraw_nraw = [&]() {
584 {
585 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
586 }
588 {
589 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
590 }
591 }();
592
593 // pad M and N
594 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
596 make_right_pad_transform(N, NPad - N)),
599#if 0
600 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
601
602 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
603 GemmSpec == GemmSpecialization::MNKPadding)
604 {
605 // pad M and N
606 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
608 make_right_pad_transform(N, NPad - N)),
611 }
612 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
613 GemmSpec == GemmSpecialization::MKPadding)
614 {
615 // pad M, but not N
617 c_grid_desc_mraw_nraw,
621 }
622 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
623 GemmSpec == GemmSpecialization::NKPadding)
624 {
625 // pad N, but not M
627 c_grid_desc_mraw_nraw,
631 }
632 else
633 {
634 // not pad M or N
635 return c_grid_desc_mraw_nraw;
636 }
637#endif
638 }
639
640 struct Problem
641 {
642 __host__ Problem(index_t M_,
643 index_t N_,
644 index_t K_,
645 index_t StrideA_,
646 index_t StrideB_,
647 index_t StrideC_,
648 index_t KBatch_,
649 AElementwiseOperation a_element_op,
650 BElementwiseOperation b_element_op,
651 CElementwiseOperation c_element_op)
652 : M{M_},
653 N{N_},
654 K{K_},
655 StrideA{StrideA_},
656 StrideB{StrideB_},
657 StrideC{StrideC_},
658 KBatch{KBatch_},
661 KRead{CalculateKRead(K_, KBatch_)},
662 KPadded{CalculateKPadded(K_, KBatch_)},
663 AK0{CalculateAK0Padded(K_, KBatch_)},
664 BK0{CalculateBK0Padded(K_, KBatch_)},
667 a_element_op_{a_element_op},
668 b_element_op_{b_element_op},
669 c_element_op_{c_element_op}
670 {
671 }
672
673 __host__ void Print() const
674 {
675 // clang-format off
676 std::cout << "problem {"
677 << "M:" << M << ", "
678 << "N:" << N << ", "
679 << "K:" << K << ", "
680 << "SA:" << StrideA << ", "
681 << "SB:" << StrideB << ", "
682 << "SC:" << StrideC << ", "
683 << "MP:" << MPadded << ", "
684 << "NP:" << NPadded << ", "
685 << "KRead:" << KRead << ", "
686 << "KP:" << KPadded << ", "
687 << "AK0:" << AK0 << ", "
688 << "BK0:" << BK0 << ", "
689 << "MBlock: " << MBlock << ", "
690 << "NBlock: " << NBlock << "}" << std::endl;
691 // clang-format off
692 }
693
709 AElementwiseOperation a_element_op_;
710 BElementwiseOperation b_element_op_;
711 CElementwiseOperation c_element_op_;
712 };
713
714 // Argument
716 {
717 __host__ Argument(const ADataType* p_a_grid_,
718 const BDataType* p_b_grid_,
719 CDataType* p_c_grid_,
720 index_t M_,
721 index_t N_,
722 index_t K_,
723 index_t StrideA_,
724 index_t StrideB_,
725 index_t StrideC_,
726 index_t k_batch_,
727 bool is_reduce_ = false,
728 AElementwiseOperation a_element_op = AElementwiseOperation{},
729 BElementwiseOperation b_element_op = BElementwiseOperation{},
730 CElementwiseOperation c_element_op = CElementwiseOperation{})
731 : Problem{M_,
732 N_,
733 K_,
734 StrideA_,
735 StrideB_,
736 StrideC_,
737 k_batch_,
741 p_a_grid{p_a_grid_},
742 p_b_grid{p_b_grid_},
743 p_c_grid{p_c_grid_},
744 is_reduce(is_reduce_)
745 {
746 }
747
748 __host__ __device__ inline bool IsReduceAdd() const
749 {
750 return (Problem::KBatch > 1) && is_reduce;
751 }
752
753 __host__ __device__ inline bool IsAtomicAdd() const
754 {
755 return (Problem::KBatch > 1) && (!is_reduce);
756 }
757
758 const ADataType* p_a_grid;
759 const BDataType* p_b_grid;
760 CDataType* p_c_grid;
762 };
763
765 {
766
767 __device__ SplitKBatchOffset(Argument& karg)
768 {
770 {
771 a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
772 }
774 {
775 a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
776 }
777
779 {
780 b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
781 }
783 {
784 if constexpr(!PermuteB)
785 {
786 b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
787 }
788 else
789 {
790 const int k0_offset = karg.KRead * karg.N;
791 b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
792 }
793 }
794
795 if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
796 {
797 karg.K = karg.KRead;
798 }
799 else
800 {
801 karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
802 }
803
804 if(karg.IsReduceAdd())
805 {
806 c_reduce_offset = blockIdx.z * karg.M * karg.N;
807 }
808 else
809 {
810 c_reduce_offset = 0;
811 }
812 }
813
817 };
818
819 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
820 {
821 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
822 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
823 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
824 // A matrix in LDS memory, dst of blockwise copy
825 if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
826 {
827 // bank conflict when writting the data into LDS, but don't worry, we have whole entire
828 // loop to hide it in v4. it may give you some benefit from less valu in compute address
832 }
833 // xor tensor transformation request more unnecessary vgpr usage, would cause register spill
834 // in some cases.
836 {
837 constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize;
838 constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
839 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
841 AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
843
844 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
845 a_lds_block_desc,
851
852 constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
853 a_lds_block_desc_permuted,
859
860 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
861 a_lds_block_desc_ak0_mldslayer_m_ak1,
868
869 return a_lds_block_desc_ak0_m_ak1;
870 }
871 else // ColumnMajor A
872 {
873 // kfold and mpair dimension is not always required.
874 // more dimension in merge_transform increase the difficulty of generating immarg offset
875 // for compiler.
876 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
877 constexpr auto M1 = MPerBlock / M0;
878
879 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
880 constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
881 constexpr auto KThreadRead = WaveSize / MPerXdl;
882 constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
883
884 constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
885 ? 1
886 : 128 / (AK1Number * M0 * sizeof(ADataType));
887 constexpr auto KThreadReadPerm =
888 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
889 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
890 : KThreadRead;
891
892 // 1<=mpair<=n0
893 constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
894 ? 1
895 : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
896 ? M0
897 : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
898
899 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
903 Number<kfold * M0 / mpair>{},
905 AK1Number));
906
907 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
908 a_lds_block_desc,
913 make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
920
921 constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
922 a_lds_block_desc_permuted,
931 Sequence<1>{},
932 Sequence<2>{},
933 Sequence<3>{},
934 Sequence<4>{},
935 Sequence<5>{}),
937 Sequence<2>{},
940 Sequence<6>{},
941 Sequence<7>{}));
942
943 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
944 a_lds_block_desc_unmerged,
947 Number<KThreadWrite / kfold / KThreadReadPerm>{},
955
956 return a_lds_block_desc_ak0_m_ak1;
957 }
958 }
959
960 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
961 {
962 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
963 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
964 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
965 // B matrix in LDS memory, dst of blockwise copy
966 if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
967 {
968 // bank conflict when writting the data into LDS, but don't worry, we have whole entire
969 // loop to hide it in v4. it may give you some benefit from less valu in compute address
973 }
975 {
976 // NLdsLayer * K0 as logical Bank
977 constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize;
978 constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
979 constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
981 BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
983
984 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
985 b_lds_block_desc,
991
992 constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
993 b_lds_block_desc_permuted,
999
1000 constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1001 b_lds_block_desc_bk0_nldslayer_n_bk1,
1008
1009 return b_lds_block_desc_bk0_n_bk1;
1010 }
1011 else // RowMajor B
1012 {
1013 constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
1014 constexpr auto N1 = NPerBlock / N0;
1015
1016 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
1017 constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
1018 constexpr auto KThreadRead = WaveSize / NPerXdl;
1019 constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
1020
1021 constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
1022 ? 1
1023 : 128 / (BK1Number * N0 * sizeof(BDataType));
1024 constexpr auto KThreadReadPerm =
1025 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
1026 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
1027 : KThreadRead;
1028
1029 // 1<=npair<=n0
1030 constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
1031 ? 1
1032 : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
1033 ? N0
1034 : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
1035
1036 constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
1040 Number<kfold * N0 / npair>{},
1041 Number<npair>{},
1042 BK1Number));
1043
1044 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
1045 b_lds_block_desc,
1046 make_tuple(
1050 make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
1053 make_tuple(
1055 make_tuple(
1057
1058 constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
1059 b_lds_block_desc_permuted,
1060 make_tuple(
1068 Sequence<1>{},
1069 Sequence<2>{},
1070 Sequence<3>{},
1071 Sequence<4>{},
1072 Sequence<5>{}),
1074 Sequence<2>{},
1077 Sequence<6>{},
1078 Sequence<7>{}));
1079
1080 constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1081 b_lds_block_desc_unmerged,
1084 Number<KThreadWrite / kfold / KThreadReadPerm>{},
1085 Number<kfold>{},
1092
1093 return b_lds_block_desc_bk0_n_bk1;
1094 }
1095 }
1096
1098 {
1099 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1100 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1101
1102 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1104 make_tuple(I1,
1106 I1,
1108
1109 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1110 }
1111
1114 BlkGemmPipelineVer,
1115 BlkGemmPipeSched,
1116 BlockSize,
1117 ADataType,
1118 BDataType,
1119 ComputeTypeA,
1120 AccDataType,
1127 ABlockTransferSrcScalarPerVector,
1128 BBlockTransferSrcScalarPerVector,
1129 MPerBlock,
1130 NPerBlock,
1131 KPerBlock,
1132 MPerXdl,
1133 NPerXdl,
1134 MXdlPerWave,
1135 NXdlPerWave,
1136 KPack>())>;
1137
1138 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1139 {
1140 // LDS allocation for A and B: be careful of alignment
1141 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1142 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1143
1144 // lds max alignment
1145 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1146
1147 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1148 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1149
1150 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1151 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1152
1153 // LDS allocation for C shuffle in LDS
1154 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1156
1157 constexpr auto c_block_size =
1158 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1159
1160 return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
1161 b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
1162 c_block_size * sizeof(CShuffleDataType));
1163 }
1164
1165 template <InMemoryDataOperationEnum CGlobalMemoryDataOperation>
1166 __device__ static bool constexpr IsValidCompilationParameter()
1167 {
1168 enum struct Arch : bool
1169 {
1170#if defined(__gfx950__)
1171 is_gfx950_build = true,
1172#else
1173 is_gfx950_build = false,
1174#endif
1175 };
1176
1177 // skip building the instances with K1>=32 && PackedSize != 2 on pre-gfx950
1178 if constexpr(static_cast<bool>(Arch::is_gfx950_build) ||
1179 (AK1Number < 32 && BK1Number < 32) ||
1180 (AK1Number >= 32 && APackedSize == 2) ||
1181 (BK1Number >= 32 && BPackedSize == 2))
1182 {
1183
1184 }
1185 else
1186 {
1187 return false;
1188 }
1189
1190 return ck::tensor_operation::device::IsValidGemmCompilationParameter<
1191 BlockSize,
1192 MPerBlock,
1193 NPerBlock,
1194 MPerXdl,
1195 NPerXdl,
1196 MXdlPerWave,
1197 NXdlPerWave,
1198 CDataType,
1199 CGlobalMemoryDataOperation>();
1200 }
1201 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
1202 __host__ static constexpr bool CheckValidity(const Argument& karg)
1203 {
1204 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1205 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1206 "Invalid tuning param!");
1207
1208 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1213 {
1214 if(!(karg.M % MPerBlock == 0))
1215 {
1216 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1217 {
1218 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1219 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1220 << std::endl;
1221 }
1222 return false;
1223 }
1224 }
1225
1226 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1231 {
1232 if(!(karg.N % NPerBlock == 0))
1233 {
1234 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1235 {
1236 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1237 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1238 << std::endl;
1239 }
1240 return false;
1241 }
1242 }
1243
1244 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1248 {
1249
1250 auto K_t = karg.KBatch * KPerBlock;
1251 if(!(karg.K % K_t == 0))
1252 {
1253 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1254 {
1255 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1256 << karg.K << " " << __FILE__ << ":" << __LINE__
1257 << ", in function: " << __func__ << std::endl;
1258 }
1259 return false;
1260 }
1261 }
1262 else
1263 {
1264 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1265 auto K_t = karg.KBatch * KReadVec;
1266 auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1267 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1268 {
1269 return false;
1270 }
1271 }
1272
1274 {
1275 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1276 {
1277 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1278 {
1279 std::cout << "Arg K (" << karg.K
1280 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1281 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1282 << __LINE__ << ", in function: " << __func__ << std::endl;
1283 }
1284 return false;
1285 }
1286 }
1287 else
1288 {
1289 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1290 {
1291 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1292 {
1293 std::cout << "Arg M (" << karg.M
1294 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1295 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1296 << __LINE__ << ", in function: " << __func__ << std::endl;
1297 }
1298 return false;
1299 }
1300 }
1301
1303 {
1304 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1305 {
1306 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1307 {
1308 std::cout << "Arg N (" << karg.N
1309 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1310 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1311 << __LINE__ << ", in function: " << __func__ << std::endl;
1312 }
1313 return false;
1314 }
1315 }
1316 else
1317 {
1318 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1319 {
1320 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1321 {
1322 std::cout << "Arg K (" << karg.K
1323 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1324 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1325 << __LINE__ << ", in function: " << __func__ << std::endl;
1326 }
1327 return false;
1328 }
1329 }
1330
1332 {
1333 if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1334 {
1335 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1336 {
1337 std::cout << "Arg N (" << karg.N
1338 << ") value is not a multiple of "
1339 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1340 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1341 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1342 << std::endl;
1343 }
1344 return false;
1345 }
1346 }
1347 else
1348 {
1349 if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1350 {
1351 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1352 {
1353 std::cout << "Arg M (" << karg.M
1354 << ") value is not a multiple of "
1355 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1356 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1357 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1358 << std::endl;
1359 }
1360 return false;
1361 }
1362 }
1363
1368 {
1369 if(!karg.IsReduceAdd())
1370 {
1371 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1372 {
1373 std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
1374 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1375 }
1376 if(karg.KBatch > 1)
1377 {
1378 return false;
1379 }
1380 }
1381 }
1382
1383 // check gridwise gemm pipeline
1384 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1385
1386 if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
1387 {
1388 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1389 {
1390 return false;
1391 }
1392 }
1393
1394 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1395 return true;
1396 }
1397
// Host-side query: does the blockwise GEMM pipeline need its main ("hot") K loop for this K?
// The number of K-block iterations is K / KPerBlock (integer, truncating division) and is
// forwarded to BlockwiseGemmPipe::BlockHasHotloop, which makes the actual decision.
// NOTE(review): truncation undercounts when K is not a multiple of KPerBlock; presumably
// callers pass an already padded / split-K value — confirm at the call sites.
1398 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1399 {
1400 const index_t num_loop = K / KPerBlock;
1401
1402 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1403 }
1404
// Host-side query: which TailNumber variant of the pipeline matches this K.
// Uses the same K / KPerBlock (truncating) iteration count as CalculateHasMainKBlockLoop and
// delegates the mapping to BlockwiseGemmPipe::BlockLoopTailNum. The result is meant to be
// matched against the TailNum template argument of Run / Run_2Lds.
1405 __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1406 {
1407 const index_t num_loop = K / KPerBlock;
1408
1409 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1410 }
1411
// Re-expresses a C grid descriptor given in (M, N) form as a 4-D
// (MBlock, MPerBlock, NBlock, NPerBlock) descriptor via transform_tensor_descriptor.
// Run / Run_2Lds read dimensions 0 and 2 of the result as the M and N tile counts.
// NOTE(review): the transform list (listing lines 1418-1421) is missing from this extraction;
// presumably it unmerges M into (MBlock, MPerBlock) and N into (NBlock, NPerBlock) — confirm
// against the upstream header.
1412 template <typename CGridDesc>
1413 __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1414 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1415 {
1416 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1417 c_grid_desc_m_n,
1422
1423 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1424 }
1425
1426 // return block_id to C matrix tile idx (m0, n0) mapping
1427 // if arch = gfx942
1429 // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
1430
// Device-side body of one workgroup's GEMM tile (single LDS allocation variant).
// Steps, as implemented below:
// 1. Map the 1-D block id to an (m, n) C tile via Block2CTileMap; out-of-range tiles return early.
// 2. Build A and B blockwise copies (global -> LDS), both starting at K offset 0 of this tile.
// 3. Run the blockwise GEMM pipeline over all K blocks, accumulating into c_thread_buf.
// 4. Shuffle C through LDS (reusing p_shared) and write it to global memory chunk by chunk,
//    stepping through the VGPR and global views with matching space-filling curves.
// Template parameters select the main-loop / tail variant (HasMainKBlockLoop, TailNum) and how
// C is written to global memory (CGlobalMemoryDataOperation).
// p_shared must point to LDS large enough for the A/B tiles and, afterwards, the C shuffle tile.
// NOTE(review): several listing lines (e.g. 1451, 1489, 1782, 1792) were dropped by the
// documentation extraction; see the upstream header for the complete statements.
1431 template <typename AGridDesc_AK0_M_K1,
1432 typename BGridDesc_BK0_N_K1,
1433 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1434 bool HasMainKBlockLoop,
1435 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1436 TailNumber TailNum = TailNumber::Odd>
1437 __device__ static void Run(const ADataType* p_a_grid,
1438 const BDataType* p_b_grid,
1439 CDataType* p_c_grid,
1440 void* p_shared,
1441 const Problem& problem,
1442 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1443 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1444 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1445 c_grid_desc_mblock_mperblock_nblock_nperblock)
1446 {
1447 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1448 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1449 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1450 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1452 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1453
1454 // divide block work by [M, N]
// NOTE(review): the meaning of the third constructor argument (4) is not visible here —
// confirm against the Block2CTileMap definition before changing it.
1455 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1456
1457 const auto block_work_idx =
1458 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1459
// Workgroups whose tile index lies outside the (MBlock, NBlock) extents of the C descriptor
// (dimensions 0 and 2) have no work and exit early.
1460 if(!block_2_ctile_map.ValidCTileIndex(
1461 block_work_idx,
1462 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1463 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1464 {
1465 return;
1466 }
1467
1468 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1469 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1470
1471 // HACK: this force m/n_block_data_idx_on_grid into SGPR
1472 const index_t m_block_data_idx_on_grid =
1473 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1474
1475 const index_t n_block_data_idx_on_grid =
1476 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1477
1478 // lds max alignment
1479 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1480
1481 // A matrix in LDS memory, dst of blockwise copy
1482 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1483
1484 // B matrix in LDS memory, dst of blockwise copy
1485 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1486
1487 // A matrix blockwise copy
1488 auto a_blockwise_copy =
1490 AElementwiseOperation,
1494 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1495 ABlockTransferThreadClusterArrangeOrder,
1496 ADataType,
1497 ADataType,
1498 decltype(a_grid_desc_ak0_m_ak1),
1499 decltype(a_block_desc_ak0_m_ak1),
1500 ABlockTransferSrcAccessOrder,
1502 ABlockTransferSrcVectorDim,
1503 2,
1504 ABlockTransferSrcScalarPerVector,
1505 ABlockTransferDstScalarPerVector_AK1,
1506 1,
1507 1,
1508 AThreadTransferSrcResetCoordinateAfterRun,
1509 true,
1510 BlockwiseGemmPipe::GlobalBufferNum>(
1511 a_grid_desc_ak0_m_ak1,
1512 make_multi_index(0, m_block_data_idx_on_grid, 0),
1513 problem.a_element_op_,
1514 a_block_desc_ak0_m_ak1,
1515 make_multi_index(0, 0, 0),
1517
1518 // B matrix blockwise copy
1519 auto b_blockwise_copy =
1521 BElementwiseOperation,
1525 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1526 BBlockTransferThreadClusterArrangeOrder,
1527 BDataType,
1528 BDataType,
1529 decltype(b_grid_desc_bk0_n_bk1),
1530 decltype(b_block_desc_bk0_n_bk1),
1531 BBlockTransferSrcAccessOrder,
1533 BBlockTransferSrcVectorDim,
1534 2,
1535 BBlockTransferSrcScalarPerVector,
1536 BBlockTransferDstScalarPerVector_BK1,
1537 1,
1538 1,
1539 BThreadTransferSrcResetCoordinateAfterRun,
1540 true,
1541 BlockwiseGemmPipe::GlobalBufferNum>(
1542 b_grid_desc_bk0_n_bk1,
1543 make_multi_index(0, n_block_data_idx_on_grid, 0),
1544 problem.b_element_op_,
1545 b_block_desc_bk0_n_bk1,
1546 make_multi_index(0, 0, 0),
1548
1549 // LDS allocation for A and B: be careful of alignment
1550 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1551 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1552
1553 // Cast after lds
// A occupies the start of p_shared; B starts right after the aligned A region. The byte
// offset divides by APackedSize — presumably because one ADataType storage unit holds
// APackedSize logical elements for packed types (assumption; confirm with APackedSize's
// definition and GetSharedMemoryNumberOfByte).
1555 static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1556
1558 reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
1559 sizeof(ADataType) /
1560 APackedSize),
1561 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1562
1563 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1564 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1565
1566 // Blockwise GEMM pipeline
1567 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1568 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1569 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1570
// Number of K-block iterations: total K extent of the A descriptor (AK0 * AK1) / KPerBlock,
// kept uniform across the wavefront via readfirstlane.
1571 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1572 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1573 KPerBlock);
1574
1575 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1576 a_block_desc_ak0_m_ak1,
1577 a_blockwise_copy,
1578 a_grid_buf,
1579 a_block_buf,
1580 a_block_slice_copy_step,
1581 b_grid_desc_bk0_n_bk1,
1582 b_block_desc_bk0_n_bk1,
1583 b_blockwise_copy,
1584 b_grid_buf,
1585 b_block_buf,
1586 b_block_slice_copy_step,
1587 c_thread_buf,
1588 num_k_block_main_loop);
1589
1590 // shuffle C and write out
1591 {
1592 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1593 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1594 "wrong!");
1595
1596 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1597 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1598
1599 // TODO: hacky, fix it!
1600 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1601 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1602
1603 // TODO: hacky, fix it!
1604 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1605 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1606 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1607
1608 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1609 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1610 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1611 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1612 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1613 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1614 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1615 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1616
1617 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1619
// The C shuffle tile reuses the same LDS allocation (p_shared) that held the A/B tiles.
1620 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1621 static_cast<CShuffleDataType*>(p_shared),
1622 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1623
1624 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1625 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1626 make_tuple(
1629 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1630 M1, // M1 = MWave
1631 M2, // M2 * M3 * M4 = MPerXdl
1632 M3,
1633 M4)),
1636 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1637 N1, // N1 = NWave
1638 N2))), // N2 = NPerXdl
1640 make_tuple(
1642
1643 // calculate origin of thread output tensor on global memory
1644 // blockwise GEMM c matrix starting index
1645 const auto c_thread_mtx_on_block =
1646 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1647
1648 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1649 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1650
1651 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1653 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1656
1657 const auto m_thread_data_on_block_idx =
1658 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1659 make_multi_index(m_thread_data_on_block));
1660
1661 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1666
1667 const auto n_thread_data_on_block_idx =
1668 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1669 make_multi_index(n_thread_data_on_block));
1670
// c_element_op_ is applied exactly once: on the VGPR->LDS copy when
// DoElementwiseBeforeCShuffle is true, otherwise on the LDS->global copy; the other copy
// gets pass_through. ("vpgr" in the name below refers to VGPRs.)
1672 const auto& vpgr_to_lds_element_op = [&] {
1673 if constexpr(DoElementwiseBeforeCShuffle)
1674 {
1675 return problem.c_element_op_;
1676 }
1677 else
1678 {
1679 return pass_through;
1680 }
1681 };
1682 const auto& lds_to_global_element_op = [&] {
1683 if constexpr(!DoElementwiseBeforeCShuffle)
1684 {
1685 return problem.c_element_op_;
1686 }
1687 else
1688 {
1689 return pass_through;
1690 }
1691 };
1692
1693 // shuffle: threadwise copy C from VGPR to LDS
1694 auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1695 AccDataType,
1696 CShuffleDataType,
1697 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1698 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1699 conditional_t<DoElementwiseBeforeCShuffle,
1700 CElementwiseOperation,
1702 Sequence<CShuffleMXdlPerWavePerShuffle,
1703 CShuffleNXdlPerWavePerShuffle,
1704 I1,
1705 I1,
1706 M2,
1707 I1,
1708 M4,
1709 I1>,
1711 7,
1712 1,
1714 1,
1715 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1717 0,
1718 m_thread_data_on_block_idx[I1],
1719 n_thread_data_on_block_idx[I1],
1720 m_thread_data_on_block_idx[I2],
1721 m_thread_data_on_block_idx[I3],
1722 m_thread_data_on_block_idx[I4],
1723 n_thread_data_on_block_idx[I2]),
1724 vpgr_to_lds_element_op()};
1725
1726 // shuffle: blockwise copy C from LDS to global
1727 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1728 ThisThreadBlock, // ThreadGroup
1729 conditional_t<!DoElementwiseBeforeCShuffle,
1730 CElementwiseOperation,
1732 CGlobalMemoryDataOperation, // DstInMemOp,
1733 Sequence<1,
1734 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1735 1,
1736 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1737 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1738 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1739 CShuffleDataType, // typename SrcData,
1740 CDataType, // typename DstData,
1741 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1742 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1743 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1744 3, // index_t VectorDim,
1745 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1746 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1747 false> // bool ThreadTransferDstResetCoordinateAfterRun>
1748 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1749 make_multi_index(0, 0, 0, 0),
1750 c_grid_desc_mblock_mperblock_nblock_nperblock,
1751 make_multi_index(block_m_id, 0, block_n_id, 0),
1752 lds_to_global_element_op()};
1753
1754 // space filling curve for threadwise C in VGPR
1755 constexpr auto sfc_c_vgpr =
1758 Sequence<CShuffleMXdlPerWavePerShuffle,
1759 CShuffleNXdlPerWavePerShuffle,
1760 1,
1761 1,
1762 M2,
1763 1,
1764 M4,
1765 1>>{};
1766
1767 // space filling curve for shuffled blockwise C in global mem
1768 constexpr auto sfc_c_global =
1771 Sequence<1,
1772 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1773 1,
1774 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1775
1776 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1777
1778 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1779
// Each access: VGPR chunk -> LDS, then LDS tile -> global, then advance the global window.
// NOTE(review): the synchronization statements on listing lines 1782 and 1792 are missing
// from this extraction (presumably block_sync_lds()); they are required because the
// LDS buffer is shared by all threads of the block.
1780 static_for<0, num_access, 1>{}([&](auto access_id) {
1781 // make sure it's safe to write to LDS
1783
1784 // each thread write its data from VGPR to LDS
1785 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1786 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1787 c_thread_buf,
1788 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1789 c_shuffle_block_buf);
1790
1791 // make sure it's safe to read from LDS
1793
1794 // each block copy its data from LDS to global
1795 c_shuffle_block_copy_lds_to_global.Run(
1796 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1797 c_shuffle_block_buf,
1798 c_grid_desc_mblock_mperblock_nblock_nperblock,
1799 c_grid_buf);
1800
1801 if constexpr(access_id < num_access - 1)
1802 {
1803 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1804
1805 // move on C
1806 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1807 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1808 }
1809 });
1810 }
1811 }
// Convenience overload of Run: builds the A, B and C grid descriptors from the Problem's
// sizes, padded sizes, strides and AK0/BK0, then forwards everything to the
// descriptor-taking Run overload with the same HasMainKBlockLoop / CGlobalMemoryDataOperation /
// TailNum template arguments.
// NOTE(review): listing line 1829 (the function producing the 4-D C descriptor from
// c_grid_desc_m_n, MBlock, NBlock) is missing from this extraction; presumably
// MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock — confirm upstream.
1813 template <bool HasMainKBlockLoop,
1814 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1815 TailNumber TailNum = TailNumber::Odd>
1816 __device__ static void Run(const ADataType* p_a_grid,
1817 const BDataType* p_b_grid,
1818 CDataType* p_c_grid,
1819 void* p_shared,
1820 const Problem& problem)
1821 {
1822 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1823 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1824 const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1825 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1826 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
1827 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1828 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1830 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1831
1832 Run<decltype(a_grid_desc_ak0_m_ak1),
1833 decltype(b_grid_desc_bk0_n_bk1),
1834 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1835 HasMainKBlockLoop,
1836 CGlobalMemoryDataOperation,
1837 TailNum>(p_a_grid,
1838 p_b_grid,
1839 p_c_grid,
1840 p_shared,
1841 problem,
1842 a_grid_desc_ak0_m_ak1,
1843 b_grid_desc_bk0_n_bk1,
1844 c_grid_desc_mblock_mperblock_nblock_nperblock);
1845 }
1846
// Device-side body of one workgroup's GEMM tile, two-LDS-allocation variant.
// It follows the same flow as Run, with these differences:
// - A and B tiles live in two separate LDS regions (p_shared_0 / p_shared_1). Each region is
//   wrapped as a "ping"/"pong" buffer and passed to the pipeline as tuples.
// - The C shuffle tile reuses p_shared_0.
// NOTE(review): several listing lines (e.g. 1868, 1906, 1982, 2103, 2127, 2183, 2193) were
// dropped by the documentation extraction; see the upstream header for the full statements.
1847 template <typename AGridDesc_AK0_M_K1,
1848 typename BGridDesc_BK0_N_K1,
1849 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1850 bool HasMainKBlockLoop,
1851 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1852 TailNumber TailNum = TailNumber::Odd>
1853 __device__ static void Run_2Lds(const ADataType* p_a_grid,
1854 const BDataType* p_b_grid,
1855 CDataType* p_c_grid,
1856 void* p_shared_0,
1857 void* p_shared_1,
1858 const Problem& problem,
1859 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1860 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1861 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1862 c_grid_desc_mblock_mperblock_nblock_nperblock)
1863 {
1864 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1865 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1866 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1867 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1869 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1870
1871 // divide block work by [M, N]
1872 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1873
1874 const auto block_work_idx =
1875 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1876
// Workgroups mapped outside the (MBlock, NBlock) tile grid exit early.
1877 if(!block_2_ctile_map.ValidCTileIndex(
1878 block_work_idx,
1879 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1880 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1881 {
1882 return;
1883 }
1884
1885 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1886 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1887
1888 // HACK: this force m/n_block_data_idx_on_grid into SGPR
1889 const index_t m_block_data_idx_on_grid =
1890 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1891
1892 const index_t n_block_data_idx_on_grid =
1893 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1894
1895 // lds max alignment
1896 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1897
1898 // A matrix in LDS memory, dst of blockwise copy
1899 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1900
1901 // B matrix in LDS memory, dst of blockwise copy
1902 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1903
1904 // A matrix blockwise copy
1905 auto a_blockwise_copy =
1907 AElementwiseOperation,
1911 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1912 ABlockTransferThreadClusterArrangeOrder,
1913 ADataType,
1914 ADataType,
1915 decltype(a_grid_desc_ak0_m_ak1),
1916 decltype(a_block_desc_ak0_m_ak1),
1917 ABlockTransferSrcAccessOrder,
1919 ABlockTransferSrcVectorDim,
1920 2,
1921 ABlockTransferSrcScalarPerVector,
1922 ABlockTransferDstScalarPerVector_AK1,
1923 1,
1924 1,
1925 AThreadTransferSrcResetCoordinateAfterRun,
1926 true,
1927 BlockwiseGemmPipe::GlobalBufferNum>(
1928 a_grid_desc_ak0_m_ak1,
1929 make_multi_index(0, m_block_data_idx_on_grid, 0),
1930 problem.a_element_op_,
1931 a_block_desc_ak0_m_ak1,
1932 make_multi_index(0, 0, 0),
1934
1935 // B matrix blockwise copy
1936 auto b_blockwise_copy =
1938 BElementwiseOperation,
1942 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1943 BBlockTransferThreadClusterArrangeOrder,
1944 BDataType,
1945 BDataType,
1946 decltype(b_grid_desc_bk0_n_bk1),
1947 decltype(b_block_desc_bk0_n_bk1),
1948 BBlockTransferSrcAccessOrder,
1950 BBlockTransferSrcVectorDim,
1951 2,
1952 BBlockTransferSrcScalarPerVector,
1953 BBlockTransferDstScalarPerVector_BK1,
1954 1,
1955 1,
1956 BThreadTransferSrcResetCoordinateAfterRun,
1957 true,
1958 BlockwiseGemmPipe::GlobalBufferNum>(
1959 b_grid_desc_bk0_n_bk1,
1960 make_multi_index(0, n_block_data_idx_on_grid, 0),
1961 problem.b_element_op_,
1962 b_block_desc_bk0_n_bk1,
1963 make_multi_index(0, 0, 0),
1965
1966 // LDS allocation for A and B: be careful of alignment
1967 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1968 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1969
1970 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1971 static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1972
// NOTE(review): the B offset here is a_block_space_size_aligned * sizeof(ADataType).
// Run computes the same offset divided by APackedSize (listing lines 1558-1560). If
// APackedSize can exceed 1, the two variants place B at different offsets. Verify this
// against the LDS size computation (GetSharedMemoryNumberOfByte) for packed A types.
1973 auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1974 bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
1975 a_block_space_size_aligned * sizeof(ADataType)),
1976 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1977
1978 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1979 static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1980
1981 auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1983 a_block_space_size_aligned * sizeof(ADataType)),
1984 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1985
// Ping (p_shared_0) and pong (p_shared_1) buffers are handed to the pipeline together.
1986 auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
1987 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
1988
1989 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1990 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1991
1992 // Blockwise GEMM pipeline
1993 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1994 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1995 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1996
// Number of K-block iterations: (AK0 * AK1 extent of the A descriptor) / KPerBlock.
1997 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1998 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1999 KPerBlock);
2000
2001 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
2002 a_block_desc_ak0_m_ak1,
2003 a_blockwise_copy,
2004 a_grid_buf,
2005 a_block_bufs,
2006 a_block_slice_copy_step,
2007 b_grid_desc_bk0_n_bk1,
2008 b_block_desc_bk0_n_bk1,
2009 b_blockwise_copy,
2010 b_grid_buf,
2011 b_block_bufs,
2012 b_block_slice_copy_step,
2013 c_thread_buf,
2014 num_k_block_main_loop);
2015
2016 // shuffle C and write out
2017 {
2018 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2019 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2020 "wrong!");
2021
2022 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2023 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
2024
2025 // TODO: hacky, fix it!
2026 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2027 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2028
2029 // TODO: hacky, fix it!
2030 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2031 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2032 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2033
2034 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2035 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2036 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2037 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2038 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2039 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2040 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2041 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2042
2043 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2045
// The C shuffle tile reuses the first LDS allocation (p_shared_0).
2046 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2047 static_cast<CShuffleDataType*>(p_shared_0),
2048 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2049
2050 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2051 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2052 make_tuple(
2055 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
2056 M1, // M1 = MWave
2057 M2, // M2 * M3 * M4 = MPerXdl
2058 M3,
2059 M4)),
2062 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2063 N1, // N1 = NWave
2064 N2))), // N2 = NPerXdl
2066 make_tuple(
2068
2069 // calculate origin of thread output tensor on global memory
2070 // blockwise GEMM c matrix starting index
2071 const auto c_thread_mtx_on_block =
2072 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2073
2074 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2075 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2076
2077 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2079 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
2082
2083 const auto m_thread_data_on_block_idx =
2084 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2085 make_multi_index(m_thread_data_on_block));
2086
2087 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2092
2093 const auto n_thread_data_on_block_idx =
2094 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2095 make_multi_index(n_thread_data_on_block));
2096
2097 // shuffle: threadwise copy C from VGPR to LDS
2098 auto c_thread_copy_vgpr_to_lds =
2100 CShuffleDataType,
2101 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2102 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2104 Sequence<CShuffleMXdlPerWavePerShuffle,
2105 CShuffleNXdlPerWavePerShuffle,
2106 I1,
2107 I1,
2108 M2,
2109 I1,
2110 M4,
2111 I1>,
2113 7,
2114 1,
2116 1,
2117 true>{
2118 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2120 0,
2121 m_thread_data_on_block_idx[I1],
2122 n_thread_data_on_block_idx[I1],
2123 m_thread_data_on_block_idx[I2],
2124 m_thread_data_on_block_idx[I3],
2125 m_thread_data_on_block_idx[I4],
2126 n_thread_data_on_block_idx[I2]),
2128
2129 // shuffle: blockwise copy C from LDS to global
// NOTE(review): this copy always uses CElementwiseOperation with problem.c_element_op_.
// Run instead selects the op by DoElementwiseBeforeCShuffle (listing lines 1729-1731, 1752).
// The VGPR->LDS op type and argument here (listing lines 2103, 2127) are missing from this
// extraction. Confirm whether DoElementwiseBeforeCShuffle should also be honoured in this
// variant.
2130 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
2131 ThisThreadBlock, // ThreadGroup
2132 CElementwiseOperation, // ElementwiseOperation,
2133 CGlobalMemoryDataOperation, // DstInMemOp,
2134 Sequence<1,
2135 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2136 1,
2137 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2138 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
2139 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2140 CShuffleDataType, // typename SrcData,
2141 CDataType, // typename DstData,
2142 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2143 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2144 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
2145 3, // index_t VectorDim,
2146 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
2147 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
2148 false> // bool ThreadTransferDstResetCoordinateAfterRun>
2149 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2150 make_multi_index(0, 0, 0, 0),
2151 c_grid_desc_mblock_mperblock_nblock_nperblock,
2152 make_multi_index(block_m_id, 0, block_n_id, 0),
2153 problem.c_element_op_};
2154
2155 // space filling curve for threadwise C in VGPR
2156 constexpr auto sfc_c_vgpr =
2159 Sequence<CShuffleMXdlPerWavePerShuffle,
2160 CShuffleNXdlPerWavePerShuffle,
2161 1,
2162 1,
2163 M2,
2164 1,
2165 M4,
2166 1>>{};
2167
2168 // space filling curve for shuffled blockwise C in global mem
2169 constexpr auto sfc_c_global =
2172 Sequence<1,
2173 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2174 1,
2175 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2176
2177 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2178
2179 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
2180
// Each access: VGPR chunk -> LDS, then LDS tile -> global, then advance the global window.
// NOTE(review): the synchronization statements on listing lines 2183 and 2193 are missing
// from this extraction (presumably block_sync_lds()).
2181 static_for<0, num_access, 1>{}([&](auto access_id) {
2182 // make sure it's safe to write to LDS
2184
2185 // each thread write its data from VGPR to LDS
2186 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2187 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2188 c_thread_buf,
2189 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2190 c_shuffle_block_buf);
2191
2192 // make sure it's safe to read from LDS
2194
2195 // each block copy its data from LDS to global
2196 c_shuffle_block_copy_lds_to_global.Run(
2197 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2198 c_shuffle_block_buf,
2199 c_grid_desc_mblock_mperblock_nblock_nperblock,
2200 c_grid_buf);
2201
2202 if constexpr(access_id < num_access - 1)
2203 {
2204 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
2205
2206 // move on C
2207 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
2208 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
2209 }
2210 });
2211 }
2212 }
2213
// Convenience overload of Run_2Lds: builds the A, B and C grid descriptors from the Problem's
// sizes, padded sizes, strides and AK0/BK0, then forwards both LDS pointers and the descriptors
// to the descriptor-taking Run_2Lds overload.
// NOTE(review): listing line 2232 (the function producing the 4-D C descriptor) is missing
// from this extraction; presumably MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock —
// confirm upstream.
2214 template <bool HasMainKBlockLoop,
2215 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2216 TailNumber TailNum = TailNumber::Odd>
2217 __device__ static void Run_2Lds(const ADataType* p_a_grid,
2218 const BDataType* p_b_grid,
2219 CDataType* p_c_grid,
2220 void* p_shared_0,
2221 void* p_shared_1,
2222 const Problem& problem)
2223 {
2224 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2225 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
2226 const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2227 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2228 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
2229 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
2230
2231 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2233 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2234
2235 Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
2236 decltype(b_grid_desc_bk0_n_bk1),
2237 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2238 HasMainKBlockLoop,
2239 CGlobalMemoryDataOperation,
2240 TailNum>(p_a_grid,
2241 p_b_grid,
2242 p_c_grid,
2243 p_shared_0,
2244 p_shared_1,
2245 problem,
2246 a_grid_desc_ak0_m_ak1,
2247 b_grid_desc_bk0_n_bk1,
2248 c_grid_desc_mblock_mperblock_nblock_nperblock);
2249 }
2250};
2251
2252} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
ushort bhalf_t
Definition data_type.hpp:30
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
_Float16 half_t
Definition data_type.hpp:31
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
constexpr auto BlockGemmPipeline_Selector()
Definition blockwise_gemm_pipeline_wmma_selector.hpp:32
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
unsigned int uint32_t
Definition stdint.h:126
signed int int32_t
Definition stdint.h:123
Definition block_to_ctile_map.hpp:271
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:716
const BElementwiseOperation b_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:643
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t k_batch_, bool is_reduce_=false, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, CElementwiseOperation c_element_op=CElementwiseOperation{})
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:717
const BDataType * p_b_grid
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:759
CDataType * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:760
__host__ __device__ bool IsReduceAdd() const
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:748
const AElementwiseOperation a_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:642
__host__ __device__ bool IsAtomicAdd() const
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:753
const ADataType * p_a_grid
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:758
const CElementwiseOperation c_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:644
bool is_reduce
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:761
index_t N
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:695
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:702
index_t KBatch
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:700
index_t StrideA
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:697
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t KBatch_, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:642
CElementwiseOperation c_element_op_
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:711
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:706
index_t M
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:694
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:708
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:701
index_t K
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:696
index_t StrideB
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:698
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:704
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:699
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:707
BElementwiseOperation b_element_op_
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:710
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:705
index_t KRead
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:703
AElementwiseOperation a_element_op_
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:709
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:673
index_t a_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:814
index_t b_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:815
__device__ SplitKBatchOffset(Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:767
index_t c_reduce_offset
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:816
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:247
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, BlkGemmPipeSched, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1112
static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1853
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition threadwise_tensor_slice_transfer.hpp:39
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition type.hpp:177
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129