block_gemm_areg_breg_creg_v2.hpp Source File

block_gemm_areg_breg_creg_v2.hpp Source File#

Composable Kernel: block_gemm_areg_breg_creg_v2.hpp Source File
block_gemm_areg_breg_creg_v2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// This BlockGemm provides enhanced control over the instruction issue order
12// A is block distributed tensor
13// B is block distributed tensor
14// C is block distributed tensor
15template <typename Problem_, typename Policy_>
17{
18 private:
19 template <typename PipelineProblem_, typename GemmPolicy_>
20 struct GemmTraits_
21 {
28
29 static constexpr index_t kBlockSize = Problem::kBlockSize;
30
31 static constexpr index_t MPerBlock = BlockGemmShape::kM;
32 static constexpr index_t NPerBlock = BlockGemmShape::kN;
33 static constexpr index_t KPerBlock = BlockGemmShape::kK;
34
35 static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
36 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
37
38 static constexpr index_t MWarp = config.template at<1>();
39 static constexpr index_t NWarp = config.template at<2>();
40 static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
41 static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
42 static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
43
44 static constexpr auto BlockGemmLoopOrder = Policy::BlockGemmLoopOrder;
45
46 static constexpr index_t KPack = WarpGemm::kKPerThread;
47 };
48
49 public:
52
53 using Traits = GemmTraits_<Problem, Policy>;
54
55 using WarpGemm = typename Traits::WarpGemm;
56 using BlockGemmShape = typename Traits::BlockGemmShape;
57 static constexpr auto BlockGemmLoopOrder = Traits::BlockGemmLoopOrder;
58
62
63 static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
64 static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
65 static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
66
67 static constexpr index_t MWarp = Traits::MWarp;
68 static constexpr index_t NWarp = Traits::NWarp;
69 static constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1);
70
72 {
73 if constexpr(UseDefaultScheduler)
74 {
75 constexpr auto a_block_outer_dstr_encoding =
78 tuple<>,
79 tuple<>,
82
83 constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
84 a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
85
86 return a_block_dstr_encode;
87 }
88 else
89 {
91 {
92 constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
99
100 constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
101 a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
102
103 return a_block_dstr_encode;
104 }
105 else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
106 {
107 constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
114
115 constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
116 a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
117
118 return a_block_dstr_encode;
119 }
120 }
121 }
122
124 {
125 if constexpr(UseDefaultScheduler)
126 {
127 constexpr auto b_block_outer_dstr_encoding =
130 tuple<>,
131 tuple<>,
134 constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
135 b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
136
137 return b_block_dstr_encode;
138 }
139 else
140 {
142 {
143 constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
150 constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
151 b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
152
153 return b_block_dstr_encode;
154 }
155 else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
156 {
157 constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
164 constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
165 b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
166 return b_block_dstr_encode;
167 }
168 }
169 }
170
172 {
173 if constexpr(UseDefaultScheduler)
174 {
175 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
178 tuple<>,
179 tuple<>,
182 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
183 c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
184
185 return c_block_dstr_encode;
186 }
187 else
188 {
189 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
196 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
197 c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
198
199 return c_block_dstr_encode;
200 }
201 }
202
203 // C += A * B
204 template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
205 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
206 const ABlockTensor& a_block_tensor,
207 const BBlockTensor& b_block_tensor) const
208 {
209 static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
210 std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
211 std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
212 "wrong!");
213
214 // check ABC-block-distribution
215 static_assert(
216 std::is_same_v<remove_cvref_t<decltype(MakeABlockDistributionEncode())>,
217 remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
218 .get_static_tile_distribution_encoding())>>,
219 "A distribution is wrong!");
220 static_assert(
221 std::is_same_v<remove_cvref_t<decltype(MakeBBlockDistributionEncode())>,
222 remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
223 .get_static_tile_distribution_encoding())>>,
224 "B distribution is wrong!");
225 static_assert(
226 std::is_same_v<remove_cvref_t<decltype(MakeCBlockDistributionEncode())>,
227 remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
228 .get_static_tile_distribution_encoding())>>,
229 "C distribution is wrong!");
230
231 using AWarpDstr = typename WarpGemm::AWarpDstr;
232 using BWarpDstr = typename WarpGemm::BWarpDstr;
233 using CWarpDstr = typename WarpGemm::CWarpDstr;
234
235 using AWarpTensor = typename WarpGemm::AWarpTensor;
236 using BWarpTensor = typename WarpGemm::BWarpTensor;
237 using CWarpTensor = typename WarpGemm::CWarpTensor;
238
239 constexpr auto a_warp_y_lengths =
240 to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
241 constexpr auto b_warp_y_lengths =
242 to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
243 constexpr auto c_warp_y_lengths =
244 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
245
246 constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
247 constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
248 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
249
250 // hot loop:
252 {
253 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
254 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
255 // read A warp tensor from A Block window
256 AWarpTensor a_warp_tensor;
257 a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
258 merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
259 merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
260
261 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
262 // read B warp tensor from B block tensor
263 BWarpTensor b_warp_tensor;
264 b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
265 merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
266 merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
267
268 CWarpTensor c_warp_tensor;
269 c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
270 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
271 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
272
273 // warp GEMM
274 WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
275
276 // write C warp tensor into C block tensor
277 c_block_tensor.set_y_sliced_thread_data(
278 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
279 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
280 c_warp_tensor.get_thread_buffer());
281 });
282 });
283 });
284 }
285 else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
286 {
287 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
288 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
289 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
290 // read A warp tensor from A Block window
291 AWarpTensor a_warp_tensor;
292
293 a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
294 merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
295 merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
296
297 // read B warp tensor from B block tensor
298 BWarpTensor b_warp_tensor;
299
300 b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
301 merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
302 merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
303
304 // read C warp tensor from C block tensor
305 CWarpTensor c_warp_tensor;
306
307 c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
308 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
309 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
310
311 // warp GEMM
312 WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
313
314 // write C warp tensor into C block tensor
315 c_block_tensor.set_y_sliced_thread_data(
316 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
317 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
318 c_warp_tensor.get_thread_buffer());
319 });
320 });
321 });
322 }
323 }
324
325 CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
326 {
327 if constexpr(UseDefaultScheduler)
328 {
329 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
332 tuple<>,
333 tuple<>,
336
337 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
338 c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
339 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
340 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
341 return c_block_tensor;
342 }
343 else
344 {
345 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
352
353 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
354 c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
355 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
356 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
357 return c_block_tensor;
358 }
359 }
360
361 // C = A * B
362 template <typename ABlockTensor, typename BBlockTensor>
363 CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
364 const BBlockTensor& b_block_tensor) const
365 {
366 auto c_block_tensor = MakeCBlockTile();
367 operator()(c_block_tensor, a_block_tensor, b_block_tensor);
368 return c_block_tensor;
369 }
370};
371
372} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
@ KMN
Definition block_gemm_areg_breg_creg_v2_custom_policy.hpp:12
@ MNK
Definition block_gemm_areg_breg_creg_v2_custom_policy.hpp:13
Definition block_gemm_areg_breg_creg_v2.hpp:17
remove_cvref_t< typename Traits::BDataType > BDataType
Definition block_gemm_areg_breg_creg_v2.hpp:60
static CK_TILE_DEVICE constexpr auto MakeBBlockDistributionEncode()
Definition block_gemm_areg_breg_creg_v2.hpp:123
static constexpr index_t NIterPerWarp
Definition block_gemm_areg_breg_creg_v2.hpp:65
remove_cvref_t< Problem_ > Problem
Definition block_gemm_areg_breg_creg_v2.hpp:50
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
Definition block_gemm_areg_breg_creg_v2.hpp:325
remove_cvref_t< typename Traits::ADataType > ADataType
Definition block_gemm_areg_breg_creg_v2.hpp:59
static constexpr bool UseDefaultScheduler
Definition block_gemm_areg_breg_creg_v2.hpp:69
typename Traits::BlockGemmShape BlockGemmShape
Definition block_gemm_areg_breg_creg_v2.hpp:56
remove_cvref_t< Policy_ > Policy
Definition block_gemm_areg_breg_creg_v2.hpp:51
remove_cvref_t< typename Traits::CDataType > CDataType
Definition block_gemm_areg_breg_creg_v2.hpp:61
static CK_TILE_DEVICE constexpr auto MakeCBlockDistributionEncode()
Definition block_gemm_areg_breg_creg_v2.hpp:171
static constexpr index_t MIterPerWarp
Definition block_gemm_areg_breg_creg_v2.hpp:64
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, const ABlockTensor &a_block_tensor, const BBlockTensor &b_block_tensor) const
Definition block_gemm_areg_breg_creg_v2.hpp:205
static constexpr index_t NWarp
Definition block_gemm_areg_breg_creg_v2.hpp:68
static CK_TILE_DEVICE constexpr auto MakeABlockDistributionEncode()
Definition block_gemm_areg_breg_creg_v2.hpp:71
typename Traits::WarpGemm WarpGemm
Definition block_gemm_areg_breg_creg_v2.hpp:55
CK_TILE_DEVICE auto operator()(const ABlockTensor &a_block_tensor, const BBlockTensor &b_block_tensor) const
Definition block_gemm_areg_breg_creg_v2.hpp:363
GemmTraits_< Problem, Policy > Traits
Definition block_gemm_areg_breg_creg_v2.hpp:53
static constexpr index_t MWarp
Definition block_gemm_areg_breg_creg_v2.hpp:67
static constexpr index_t KIterPerWarp
Definition block_gemm_areg_breg_creg_v2.hpp:63
static constexpr auto BlockGemmLoopOrder
Definition block_gemm_areg_breg_creg_v2.hpp:57
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192