blockwise_gemm_pipeline_wmmaops.hpp Source File

blockwise_gemm_pipeline_wmmaops.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_wmmaops.hpp Source File
blockwise_gemm_pipeline_wmmaops.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10template <index_t BlockSize,
11 index_t MPerBlock,
12 index_t NPerBlock,
13 index_t KPerBlock,
14 index_t ABufferLoadWidth,
15 index_t BBufferLoadWidth,
16 index_t ALDSWriteWidth,
17 index_t BLDSWriteWidth,
18 index_t ALDSReadWidth,
19 index_t BLDSReadWidth,
20 index_t MRepeat,
21 index_t NRepeat,
22 index_t MPerWmma,
23 index_t NPerWmma,
24 index_t KPerWmma>
26{
27 static constexpr index_t WaveSize = 32;
28 static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerWmma);
29 static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerWmma);
30
31 static constexpr index_t A_LDS_Read_Width = ALDSReadWidth;
32 static constexpr index_t B_LDS_Read_Width = BLDSReadWidth;
33
35 MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
37 NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);
38
39 static constexpr index_t A_LDS_Write_Inst_Num =
40 MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
41 static constexpr index_t B_LDS_Write_Inst_Num =
42 NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);
43
44 static constexpr index_t A_LDS_Read_Inst_Num =
45 WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
46 static constexpr index_t B_LDS_Read_Inst_Num =
47 WaveNumM * NPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);
48
49 static constexpr index_t C_WMMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
50 (BlockSize / WaveSize) /
51 (MPerWmma * NPerWmma * KPerWmma);
52
53 static constexpr auto Print()
54 {
55 printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerWmma: %d, %d, %d\n",
56 BlockSize,
58 MPerBlock,
59 NPerBlock,
60 KPerBlock,
61 MPerWmma,
62 NPerWmma,
63 KPerWmma);
64
65 printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
66 "%d, %d\n C WMMA inst: %d\n"
67 "A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
68 "%d, %d\n",
78 ALDSWriteWidth,
79 BLDSWriteWidth,
80 ABufferLoadWidth,
81 BBufferLoadWidth);
82 }
83};
84
85} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition blockwise_gemm_pipeline_wmmaops.hpp:26
static constexpr auto Print()
Definition blockwise_gemm_pipeline_wmmaops.hpp:53