gemm_pipeline_agmem_bgmem_creg_v2.hpp Source File

gemm_pipeline_agmem_bgmem_creg_v2.hpp Source File

Composable Kernel: gemm_pipeline_agmem_bgmem_creg_v2.hpp Source File
gemm_pipeline_agmem_bgmem_creg_v2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
12// A Tile Window: global memory
13// B Tile Window: global memory
14// C Distributed tensor: register
15template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV2DefaultPolicy>
17{
21
25
29
32
35
36 static constexpr index_t APackedSize =
38 static constexpr index_t BPackedSize =
40
41 static constexpr index_t BlockSize = Problem::kBlockSize;
42
43 static constexpr index_t kMPerBlock = BlockGemmShape::kM;
44 static constexpr index_t kNPerBlock = BlockGemmShape::kN;
45 static constexpr index_t kKPerBlock = BlockGemmShape::kK;
46
47 template <bool IsWave32Host = false>
48 static constexpr index_t GetVectorSizeA()
49 {
50 return Problem::VectorSizeA;
51 }
52 template <bool IsWave32Host = false>
53 static constexpr index_t GetVectorSizeB()
54 {
55 return Problem::VectorSizeB;
56 }
57 static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
58
59 static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
60 static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
61
62 static constexpr bool kPadM = Problem::kPadM;
63 static constexpr bool kPadN = Problem::kPadN;
64 static constexpr bool kPadK = Problem::kPadK;
65
66 static constexpr bool Preshuffle = Problem::Preshuffle;
67
68 static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
69
70 // For the basic GEMM pipeline, DoubleSmemBuffer is naturally set to false.
71 static constexpr bool DoubleSmemBuffer = false;
72
73 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
74 {
75 // clang-format off
76 return concat('_', "pipeline_AGmemBGmemCRegV2",
78 // clang-format on
79 }
80 CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
81
83 {
84 return integer_divide_ceil(sizeof(ADataType) *
85 Policy::template MakeALdsBlockDescriptor<Problem>()
86 .get_element_space_size() /
88 16) *
89 16 +
90 sizeof(BDataType) *
91 Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size() /
93 }
94
96 {
97 return Policy::template GetSmemSize<Problem>();
98 }
99
100 template <typename AsDramBlockWindowTmp,
101 typename BsDramBlockWindowTmp,
102 typename AElementFunction,
103 typename BElementFunction,
104 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
106 bool>* = nullptr>
107 CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
108 const AElementFunction& a_element_func,
109 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
110 const BElementFunction& b_element_func,
111 index_t num_loop,
112 void* p_smem) const
113 {
114
115 using ADramBlockWindowTmp =
116 remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
117 using BDramBlockWindowTmp =
118 remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
119
120 static_assert(
121 std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
122 std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
123 "wrong!");
124
125 static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
126 kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
127 kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
128 "wrong!");
129
130 // A tile in LDS
131 ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
132
133 constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
134
135 auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
136
137 constexpr index_t a_lds_block_space_size_aligned =
139 sizeof(ADataType) * a_lds_block_desc.get_element_space_size() / APackedSize, 16) *
140 16;
141
142 // B tile in LDS
143 BDataType* p_b_lds = static_cast<BDataType*>(
144 static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
145
146 constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
147
148 auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
149
150 // A DRAM tile window for load
151 auto as_copy_dram_window = generate_tuple(
152 [&](auto idx) {
153 return make_tile_window(
154 a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
156 a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
157 Policy::template MakeADramTileDistribution<Problem>());
158 },
159 number<AsLayout::size()>{});
160
161 // A LDS tile window for store
162 auto a_copy_lds_window =
163 make_tile_window(a_lds_block,
165 {0, 0},
166 as_copy_dram_window[number<0>{}].get_tile_distribution());
167
168 // B DRAM tile window for load
169 auto bs_copy_dram_window = generate_tuple(
170 [&](auto idx) {
171 return make_tile_window(
172 b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
174 b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
175 Policy::template MakeBDramTileDistribution<Problem>());
176 },
177 number<BsLayout::size()>{});
178
179 // B LDS tile window for store
180 auto b_copy_lds_window =
181 make_tile_window(b_lds_block,
183 {0, 0},
184 bs_copy_dram_window[number<0>{}].get_tile_distribution());
185
186 // Block GEMM
187 constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
188
189 // Tile distribution for load from lds
190 constexpr auto a_lds_load_tile_distr =
191 make_static_tile_distribution(decltype(block_gemm)::MakeABlockDistributionEncode());
192 constexpr auto b_lds_load_tile_distr =
193 make_static_tile_distribution(decltype(block_gemm)::MakeBBlockDistributionEncode());
194
195 // A LDS tile for block GEMM
196 auto a_lds_gemm_window =
197 make_tile_window(a_lds_block,
199 {0, 0},
200 a_lds_load_tile_distr);
201
202 // B LDS tile for block GEMM
203 auto b_lds_gemm_window =
204 make_tile_window(b_lds_block,
206 {0, 0},
207 b_lds_load_tile_distr);
208
209 // Acc register tile
210 auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
211
212 // prefetch
213 // global read 0
214 // Load tile — during value loading, an elementwise function is executed for each A0,
215 // A1, … AN. The values A0, A1, … AN are read by the same thread.
216 auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
217 // Load tile — during value loading, an elementwise function is executed for each B0,
218 // B1, … BN. The values B0, B1, … BN are read by the same thread.
219 auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
220
221 {
222 // move to 1
223 move_tile_window(as_copy_dram_window, {0, kKPerBlock});
224 move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
225
226 // initialize C
227 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
228
229 // LDS write 0
230 store_tile(a_copy_lds_window, elementwise_As_res);
231 // global read 1
232 elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
233
234 // LDS write 0
235 store_tile(b_copy_lds_window, elementwise_Bs_res);
236 // global read 1
237 elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
238 }
239
240 index_t iCounter = num_loop - 2;
241
242 do
243 {
245
246 // GEMM i
247 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
248
250
251 // move to i + 2
252 move_tile_window(as_copy_dram_window, {0, kKPerBlock});
253 move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
254
255 // LDS write i + 1
256 store_tile(a_copy_lds_window, elementwise_As_res);
257 // global read i + 2
258 elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func);
259
260 // LDS write i + 1
261 store_tile(b_copy_lds_window, elementwise_Bs_res);
262 // global read i + 2
263 elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
264
265 iCounter--;
266
267 } while(iCounter > 0);
268
269 // tail
270 {
272
273 // GEMM num_loop - 2
274 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
275
277
278 // LDS write num_loop - 1
279 store_tile(a_copy_lds_window, elementwise_As_res);
280
281 store_tile(b_copy_lds_window, elementwise_Bs_res);
282
284
285 // GEMM num_loop - 1
286 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
287 }
288
289 return c_block_tile;
290 }
291
292 template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
293 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
294 const BDramBlockWindowTmp& b_dram_block_window_tmp,
295 index_t num_loop,
296 void* p_smem) const
297 {
298 return operator()(
299 a_dram_block_window_tmp,
300 [](auto& e, const ADataType & a) { e = a; },
301 b_dram_block_window_tmp,
302 [](auto& e, const BDataType & b) { e = b; },
303 num_loop,
304 p_smem);
305 }
306
307 template <typename ADramBlockWindowTmp,
308 typename BDramBlockWindowTmp,
309 typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
311 bool>* = nullptr>
312 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
313 const BDramBlockWindowTmp& b_dram_block_window_tmp,
314 index_t num_loop,
315 void* p_smem) const
316 {
317 return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
318 ck_tile::make_tuple(b_dram_block_window_tmp),
319 num_loop,
320 p_smem);
321 }
322};
323
324} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_ &tile_window, ElementWise_ elementwise, number< i_access >={}, bool_constant< oob_conditional_check >={})
Load tile with elementwise function.
Definition load_tile.hpp:41
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:17
static constexpr bool kPadM
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:62
remove_cvref_t< typename Problem::BsDataTypeTuple > BsDataType
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:19
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:24
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BDramBlockWindowTmp &b_dram_block_window_tmp, index_t num_loop, void *p_smem) const
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:312
static constexpr bool Preshuffle
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:66
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BDramBlockWindowTmp &b_dram_block_window_tmp, index_t num_loop, void *p_smem) const
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:293
static constexpr index_t GetVectorSizeB()
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:53
static constexpr bool DoubleSmemBuffer
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:71
remove_cvref_t< std::tuple_element_t< 0, BsDataType > > BDataType
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:34
static constexpr bool kPadK
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:64
remove_cvref_t< typename Problem::AElementWise > AElementWise
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:22
static constexpr index_t kNPerBlock
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:44
static constexpr bool kPadN
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:63
remove_cvref_t< typename Problem::AsLayoutTuple > AsLayout
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:26
static constexpr index_t BPackedSize
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:38
static constexpr index_t kMPerBlock
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:43
remove_cvref_t< typename Problem::CDataType > CDataType
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:20
static constexpr index_t APackedSize
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:36
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BsDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *p_smem) const
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:107
static CK_TILE_HOST_DEVICE constexpr index_t GetStaticLdsSize()
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:82
remove_cvref_t< typename Problem::BsLayoutTuple > BsLayout
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:27
static constexpr index_t kKPerBlock
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:45
static constexpr index_t GetVectorSizeC()
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:57
static constexpr index_t NumWaveGroups
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:68
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:95
remove_cvref_t< typename Problem::AsDataTypeTuple > AsDataType
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:18
remove_cvref_t< std::tuple_element_t< 0, AsLayout > > ALayout
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:30
remove_cvref_t< typename Problem::BElementWise > BElementWise
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:23
static constexpr index_t GetSmemPackB()
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:60
remove_cvref_t< std::tuple_element_t< 0, BsLayout > > BLayout
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:31
remove_cvref_t< std::tuple_element_t< 0, AsDataType > > ADataType
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:33
remove_cvref_t< typename Problem::CLayout > CLayout
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:28
static CK_TILE_HOST const std::string GetName()
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:73
static constexpr index_t GetVectorSizeA()
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:48
static constexpr index_t BlockSize
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:41
static constexpr index_t GetSmemPackA()
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:59
static CK_TILE_HOST_DEVICE constexpr auto TransposeC()
Definition gemm_pipeline_agmem_bgmem_creg_v2.hpp:80
Definition tile/core/numeric/numeric.hpp:81