block_gemm_asmem_breg_creg_v1.hpp Source File

block_gemm_asmem_breg_creg_v1.hpp Source File

Composable Kernel: block_gemm_asmem_breg_creg_v1.hpp Source File
block_gemm_asmem_breg_creg_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// A is a block window on shared memory
12// B is a block-distributed tensor
13// C is a block-distributed tensor
14template <typename Problem_, typename Policy_ = BlockGemmASmemBRegCRegV1DefaultPolicy>
16{
23
// Block size forwarded unchanged from Problem::kBlockSize. It is not referenced by the
// code visible in this struct; presumably it is consumed by callers (e.g. as the
// workgroup thread count) -- TODO confirm.
24 static constexpr index_t kBlockSize = Problem::kBlockSize;
25
26 // C += A * B
// Accumulating block GEMM: adds (A block window) x (B block tensor) into c_block_tensor
// in place.
//
// Template parameters / arguments:
//   CBlockTensor       - distributed C tensor. Its DataType must be CDataType, and its tile
//                        distribution encoding must equal c_block_dstr_encode built below
//                        (both enforced by static_assert).
//   a_block_window_tmp - tile window over A (the file header says A lives in shared memory).
//                        Window length dim 0 gives MPerBlock, dim 1 gives KPerBlock.
//   b_block_tensor_tmp - distributed B tensor. get_lengths()[0] gives NPerBlock.
// MPerBlock / NPerBlock / KPerBlock must equal BlockGemmShape::kM / kN / kK.
27 template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockTensorTmp>
28 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
29 const ABlockWindowTmp& a_block_window_tmp,
30 const BBlockTensorTmp& b_block_tensor_tmp) const
31 {
// Operand element types (cv-qualifiers stripped) must match the Problem's A/B/C types.
32 static_assert(
33 std::is_same_v<ADataType, remove_cv_t<typename ABlockWindowTmp::DataType>> &&
34 std::is_same_v<BDataType, remove_cv_t<typename BBlockTensorTmp::DataType>> &&
35 std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
36 "wrong!");
37
// Block extents are read from the operands' compile-time lengths.
38 constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
39 constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}];
40 constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
41
42 static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
43 KPerBlock == BlockGemmShape::kK,
44 "wrong!");
45
// The policy returns a tuple-like config: element 0 is the warp-GEMM object (its type
// becomes WG), elements 1 and 2 are the MWarp x NWarp warp grid.
46 constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
47
48 using WG = remove_cvref_t<decltype(config.template at<0>())>;
49
50 constexpr index_t MWarp = config.template at<1>();
51 constexpr index_t NWarp = config.template at<2>();
52
// Number of WG::kM / WG::kN / WG::kK warp tiles each warp steps through along M, N, K.
// NOTE(review): nothing here asserts these divisions are exact; a block extent that is
// not a multiple of (warps * warp-tile) would silently truncate.
53 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
54 constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
55 constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
56
// Offset between consecutive mIter / kIter A warp windows (used by move_tile_window below).
57 constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
58 constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
59
// This warp's position along M in the MWarp x NWarp grid (warp id divided by NWarp).
60 const index_t iMWarp = get_warp_id() / NWarp;
61
// Block-level ("outer") distributions for B and C. Each is embedded with the warp-GEMM's
// own B / C warp encoding to form the full block distribution encoding.
62 constexpr auto b_block_outer_dstr_encoding =
69
70 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
77
78 constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
79 b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
80
81 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
82 c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
83
84 constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode);
85
86 // construct B-block-tensor from B-block-tensor-tmp
87 // FIXME: need method to check b_block_tensor and b_block_tensor_tmp have equivalent
88 // distribution
// The caller's thread buffer is copied verbatim, so each thread's elements are
// reinterpreted under this block GEMM's B distribution without any layout check.
89 auto b_block_tensor =
91
92 b_block_tensor.get_thread_buffer() = b_block_tensor_tmp.get_thread_buffer();
93
94 // construct A-warp-window
// Window over A's bottom tensor view whose origin is A's window origin shifted by
// iMWarp * WG::kM along dim 0, distributed with the warp-GEMM's A warp encoding.
95 auto a_warp_window_tmp = make_tile_window(
96 a_block_window_tmp.get_bottom_tensor_view(),
98 a_block_window_tmp.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
99 make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
100
// NOTE(review): the disabled branch below refers to b_warp_window_tmp and
// NPerBlockPerIter, neither of which is declared in this function, so it would not
// compile if re-enabled. It appears to be left over from a variant that windows B
// rather than A; consider deleting it.
101#if 0 // FIXME: using array will cause register spill
102 array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
103 {b_warp_window_tmp}};
104
105 for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
106 {
107 for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
108 {
109 move_tile_window(b_warp_windows(nIter)(kIter),
110 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
111 }
112 }
113#else
// One A warp window per (mIter, kIter): a copy of a_warp_window_tmp moved by
// (mIter * MPerBlockPerIter, kIter * KPerBlockPerIter).
115 statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
116 MIterPerWarp>
117 a_warp_windows;
118
119 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
120 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
121 a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
122
123 move_tile_window(a_warp_windows(mIter)(kIter),
124 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
125 });
126 });
127#endif
128
129 // check C-block-distribution
// The caller's C tensor must carry exactly c_block_dstr_encode (presumably a tile
// obtained from MakeCBlockTile(), which builds its encoding the same way).
130 static_assert(
131 std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
132 remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
133 .get_static_tile_distribution_encoding())>>,
134 "wrong!");
135
136 using BWarpDstr = typename WG::BWarpDstr;
137 using CWarpDstr = typename WG::CWarpDstr;
138
139 using BWarpTensor = typename WG::BWarpTensor;
140 using CWarpTensor = typename WG::CWarpTensor;
141
// Y-dimension lengths of one warp tile, plus all-zero Y offsets. They are appended to the
// leading (iter, iter) Y index when slicing a warp tile out of a block tensor.
142 constexpr auto b_warp_y_lengths =
143 to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
144 constexpr auto c_warp_y_lengths =
145 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
146
147 constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
148 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
149
150 // hot loop:
// Loop order is K (outer), M, N (inner). Each A warp tile is loaded once per
// (kIter, mIter) and reused for every nIter. B and C warp tiles are sliced from the
// per-thread buffers with extent 1 in the two leading Y dims.
151 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
152 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
153 // read A warp tensor from A Block window
154 const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
155 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
156 // read B warp tensor from B block tensor
157 BWarpTensor b_warp_tensor;
158
159 b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
160 merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
161 merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
162
163 // read C warp tensor from C block tensor
164 CWarpTensor c_warp_tensor;
165
166 c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
167 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
168 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
169
170 // warp GEMM
// Three-argument warp GEMM on the current C tile; presumably c += a * b, matching
// this method's "C += A * B" contract -- confirm against WG.
171 WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
172
173 // write C warp tensor into C block tensor
174 c_block_tensor.set_y_sliced_thread_data(
175 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
176 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
177 c_warp_tensor.get_thread_buffer());
178 });
179 });
180 });
181 }
182
// Creates a distributed C tile of CDataType laid out the way the accumulating operator()
// expects. M/N extents come from BlockGemmShape, and the warp-GEMM type and MWarp x NWarp
// grid come from Policy::GetWarpGemmMWarpNWarp<Problem>(). The distribution is
// c_block_outer_dstr_encoding embedded with WG::CWarpDstrEncoding.
// NOTE(review): operator() rebuilds its own c_block_outer_dstr_encoding and static_asserts
// equality with its C argument's encoding, so the two outer encodings must be kept in sync.
// NOTE(review): the initial contents are whatever make_static_distributed_tensor
// produces, which is not visible here -- confirm whether the tile is zero-filled.
183 CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
184 {
185 constexpr index_t MPerBlock = BlockGemmShape::kM;
186 constexpr index_t NPerBlock = BlockGemmShape::kN;
187
188 constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
189
190 using WG = remove_cvref_t<decltype(config.template at<0>())>;
191
192 constexpr index_t MWarp = config.template at<1>();
193 constexpr index_t NWarp = config.template at<2>();
194
// Warp-tile iterations along M and N; K does not shape the C tile.
195 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
196 constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
197 // constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
198
199 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
206
207 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
208 c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
209 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
210 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
211 return c_block_tensor;
212 }
213
214 // C = A * B
// Convenience overload: creates a C tile via MakeCBlockTile(), accumulates A x B into it
// with the in-place operator(), and returns the tile by value.
// NOTE(review): this only equals A * B if MakeCBlockTile() returns a zero-filled tile;
// that tile's initialization is not visible here -- confirm, or clear it before
// accumulating.
215 template <typename ABlockWindowTmp, typename BBlockTensorTmp>
216 CK_TILE_DEVICE auto operator()(const ABlockWindowTmp& a_block_window_tmp,
217 const BBlockTensorTmp& b_block_tensor_tmp) const
218 {
219 auto c_block_tensor = MakeCBlockTile();
220 operator()(c_block_tensor, a_block_window_tmp, b_block_tensor_tmp);
221 return c_block_tensor;
222 }
223};
224
225} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
array< index_t, N > multi_index
Definition tile/core/container/multi_index.hpp:17
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
Definition block_gemm_asmem_breg_creg_v1.hpp:16
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
Definition block_gemm_asmem_breg_creg_v1.hpp:183
remove_cvref_t< Problem_ > Problem
Definition block_gemm_asmem_breg_creg_v1.hpp:17
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, const ABlockWindowTmp &a_block_window_tmp, const BBlockTensorTmp &b_block_tensor_tmp) const
Definition block_gemm_asmem_breg_creg_v1.hpp:28
static constexpr index_t kBlockSize
Definition block_gemm_asmem_breg_creg_v1.hpp:24
remove_cvref_t< Policy_ > Policy
Definition block_gemm_asmem_breg_creg_v1.hpp:18
remove_cvref_t< typename Problem::CDataType > CDataType
Definition block_gemm_asmem_breg_creg_v1.hpp:21
remove_cvref_t< typename Problem::ADataType > ADataType
Definition block_gemm_asmem_breg_creg_v1.hpp:19
remove_cvref_t< typename Problem::BDataType > BDataType
Definition block_gemm_asmem_breg_creg_v1.hpp:20
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition block_gemm_asmem_breg_creg_v1.hpp:22
CK_TILE_DEVICE auto operator()(const ABlockWindowTmp &a_block_window_tmp, const BBlockTensorTmp &b_block_tensor_tmp) const
Definition block_gemm_asmem_breg_creg_v1.hpp:216
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192