gridwise_normalization_splitk_2nd.hpp Source File

gridwise_normalization_splitk_2nd.hpp Source File

Composable Kernel: gridwise_normalization_splitk_2nd.hpp Source File
gridwise_normalization_splitk_2nd.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7#include "ck/utility/math.hpp"
12
13namespace ck {
14
// Second stage of a split-K normalization.
//
// For the M rows owned by a block, Run():
//   1. merges partial (mean, variance, Welford count) values read from global
//      memory at (m, kblock) positions: each thread folds the kblock columns
//      it owns with a threadwise Welford merge, then the per-thread results
//      are combined with BlockwiseWelford, and the merged variance is turned
//      into inv_std = 1 / sqrt(variance + epsilon);
//   2. optionally writes mean and inv_std (only for non-null output pointers,
//      and only from K block 0 / K thread 0);
//   3. normalizes this block's K range of X:
//        y = (x - mean) * inv_std * gamma + beta
//      and writes Y through a transfer constructed with y_elementwise_op.
//
// NOTE(review): this listing comes from a rendered documentation page. The
// embedded original line numbers have gaps (e.g. 41, 64, 66-67, 73, 96-97,
// 170, 186, 191, 211-212, 387, 389, 424, 444-445). As a result, these are not
// visible here:
//   - the struct name;
//   - the ThreadwiseWelford alias;
//   - the descriptor initializers;
//   - the thread-buffer types;
//   - some static_for headers and transfer template arguments.
// The comments below describe only what the visible lines show.
15template <typename MeanVarDataType,
16 typename XDataType,
17 typename GammaDataType,
18 typename BetaDataType,
19 typename YDataType,
20 typename SaveMeanInvStdDataType,
21 typename ComputeDataType,
22 typename YElementwiseOperation,
23 typename MeanVarGridDesc_M_KBlock,
24 typename CountGridDesc_M_KBlock,
25 typename XYGammaBetaGridDesc_M_K,
26 typename SaveMeanInvStdGridDesc_M,
27 index_t BlockSize,
28 index_t MThreadClusterSize,
29 index_t KThreadClusterSize,
30 index_t MThreadSliceSize,
31 index_t KThreadSliceSize,
32 index_t XSrcVectorDim,
33 index_t XSrcVectorSize,
34 index_t GammaSrcVectorDim,
35 index_t GammaSrcVectorSize,
36 index_t BetaSrcVectorDim,
37 index_t BetaSrcVectorSize,
38 index_t YDstVectorDim,
39 index_t YDstVectorSize,
40 index_t SaveMeanInvStdDstVectorSize>
42{
// The vector access width must evenly divide the per-thread slice length
// along whichever dimension (0 = M, 1 = K) is vectorized.
43 static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
44 (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
45 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
46
47 static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
48 (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
49 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
50
51 static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
52 "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
53 "configuration, please check!");
54
// Gamma, Beta and Y reuse X's per-thread buffer layout:
//   - all four transfers use thread_buffer_desc_m_k;
//   - beta_thread_buf aliases gamma_thread_buf;
//   - y_thread_buf aliases x_thread_buf.
// Their vector sizes must therefore equal XSrcVectorSize.
55 static_assert(XSrcVectorSize == YDstVectorSize);
56 static_assert(XSrcVectorSize == GammaSrcVectorSize);
57 static_assert(XSrcVectorSize == BetaSrcVectorSize);
58
// True when X is vectorized along M (dim 0).
// NOTE(review): presumably selects the thread-cluster arrange order. It is
// used in lines not shown in this listing.
59 static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);
60
61 static constexpr auto I0 = Number<0>{};
62 static constexpr auto I1 = Number<1>{};
63
65
68
71
// Maps the 1-D thread id to an (m, k) position in the block's thread cluster
// (see CalculateBottomIndex in Run). Initializer not shown.
72 static constexpr auto thread_cluster_desc =
74
78
// Per-thread register-buffer descriptors (initializers not shown):
//   - thread_buffer_desc_m is used by the saved mean / inv_std stores;
//   - thread_buffer_desc_m_1 is used by the mean / var / count loads.
80 static constexpr auto thread_buffer_desc_m =
82
84 static constexpr auto thread_buffer_desc_m_1 =
86
90
93
// Block-level Welford reduction applied per M row in step 1, presumably
// across the K threads of the cluster (remaining template arguments not shown).
94 using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
95 BlockSize,
98
100
// Tile sizes:
//   M_BlockTileSize
//       M rows handled by one block.
//   K_BlockTileSize
//       K columns covered by one iteration of the normalization loop.
//   K_BlockTileStepSize
//       K advance applied to the X/Gamma/Beta/Y windows after each
//       XSrcVectorSize-wide per-thread load/store.
//   ThreadBufferNumber
//       KThreadSliceSize / XSrcVectorSize; presumably the number of
//       XSrcVectorSize-wide sub-buffers per thread per K tile (the tuple and
//       loop counts that use it are not shown).
//       ThreadBufferNumber * K_BlockTileStepSize == K_BlockTileSize whenever
//       XSrcVectorSize divides KThreadSliceSize.
101 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
102 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
103 static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;
104
105 static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
106
// Runs steps 1-3 for the (M tile, K block) selected by get_block_1d_id().
// Consecutive block ids share an M tile and differ in K block:
//   block_m_cluster_id = id / k_grid_size
//   block_k_cluster_id = id % k_grid_size
//
// Parameters:
//   mean_var_grid_desc_m_kblock
//       (m, kblock) layout shared by the partial mean and variance arrays.
//   count_grid_desc_m_kblock
//       (m, kblock) layout of the partial Welford counts.
//   x/gamma/beta/y_grid_desc_m_k
//       (m, k) layouts of X, Gamma, Beta and Y.
//   save_mean_grid_desc_m / save_inv_std_grid_desc_m
//       (m) layouts of the optional outputs.
//   num_k_mean_var_count_iteration
//       Iterations of the merge loop. Each iteration advances the kblock
//       column by KThreadClusterSize.
//   num_k_block_tile_iteration
//       Number of K_BlockTileSize tiles normalized by this block. Also scales
//       this block's starting K offset.
//   k_grid_size
//       Number of blocks along K per M tile.
//   epsilon
//       Added to the merged variance before the square root.
//   p_mean_global / p_variance_global / p_welford_count_global
//       Partial statistics; presumably produced by the first split-K stage.
//   p_x_global / p_gamma_global / p_beta_global
//       Inputs.
//   p_y_global
//       Output.
//   p_save_mean_global / p_save_inv_std_global
//       Optional outputs; each store is skipped when its pointer is nullptr.
//   y_elementwise_op
//       Passed to the Y store transfer's constructor.
107 __device__ static void Run(const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock,
108 const CountGridDesc_M_KBlock& count_grid_desc_m_kblock,
109 const XYGammaBetaGridDesc_M_K& x_grid_desc_m_k,
110 const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k,
111 const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k,
112 const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k,
113 const SaveMeanInvStdGridDesc_M& save_mean_grid_desc_m,
114 const SaveMeanInvStdGridDesc_M& save_inv_std_grid_desc_m,
115 index_t num_k_mean_var_count_iteration,
116 index_t num_k_block_tile_iteration,
117 index_t k_grid_size,
118 ComputeDataType epsilon,
119 const MeanVarDataType* const p_mean_global,
120 const MeanVarDataType* const p_variance_global,
121 const int32_t* const p_welford_count_global,
122 const XDataType* const __restrict__ p_x_global,
123 const GammaDataType* const __restrict__ p_gamma_global,
124 const BetaDataType* const __restrict__ p_beta_global,
125 YDataType* const __restrict__ p_y_global,
126 SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
127 SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
128 const YElementwiseOperation y_elementwise_op)
129 {
130 // Thread/Block id
131 const index_t thread_local_id = get_thread_local_1d_id();
132 const index_t block_global_id = get_block_1d_id();
133 const index_t block_m_cluster_id = block_global_id / k_grid_size;
134 const index_t block_k_cluster_id = block_global_id % k_grid_size;
135 const auto thread_cluster_idx =
136 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
137
138 const auto thread_m_cluster_id = thread_cluster_idx[I0];
139 const auto thread_k_cluster_id = thread_cluster_idx[I1];
140
// Each raw pointer is wrapped in a buffer sized by its descriptor's element
// space size.
141 // Global Memory
142 const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
143 p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());
144
145 const auto var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
146 p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize());
147
148 const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
149 p_welford_count_global, count_grid_desc_m_kblock.GetElementSpaceSize());
150
151 const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
152 p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
153
154 const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
155 p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
156
157 const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
158 p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
159
160 auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
161 p_y_global, y_grid_desc_m_k.GetElementSpaceSize());
162
163 auto save_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
164 p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize());
165
166 auto save_inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
167 p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize());
168
169 // VGPR
171 in_mean_thread_buf;
173 in_var_thread_buf;
175 in_welford_count_thread_buf;
177 mean_thread_buf;
179 var_thread_buf;
181 welford_count_thread_buf;
// inv_std_thread_buf aliases var_thread_buf. After the blockwise merge, the
// variance of each row is overwritten in place by 1 / sqrt(var + epsilon).
182 auto& inv_std_thread_buf = var_thread_buf;
183
184 auto x_thread_buf = generate_tuple(
185 [&](auto) {
187 ComputeDataType,
188 MThreadSliceSize * XSrcVectorSize,
189 true>{};
190 },
192
193 auto gamma_thread_buf = generate_tuple(
194 [&](auto) {
196 ComputeDataType,
197 MThreadSliceSize * GammaSrcVectorSize,
198 true>{};
199 },
201
// Buffer reuse (presumably to limit register usage):
//   - beta is loaded into gamma's buffer only after gamma has been applied;
//   - y is computed in place over x.
202 auto& beta_thread_buf = gamma_thread_buf;
203 auto& y_thread_buf = x_thread_buf;
204
// Window origins for all transfers below:
//   - M: every transfer starts at this thread's first M row,
//     block_m_cluster_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize.
//   - kblock (mean/var/count loads): starts at thread_k_cluster_id,
//     independent of block_k_cluster_id. Every K block of an M tile
//     therefore merges the same statistics.
205 // IO
206 auto threadwise_mean_var_load_m_kblock =
207 ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
208 ComputeDataType,
209 MeanVarGridDesc_M_KBlock,
210 decltype(thread_buffer_desc_m_1),
213 1,
214 1,
215 1,
216 true>(
217 mean_var_grid_desc_m_kblock,
218 make_multi_index(block_m_cluster_id * M_BlockTileSize +
219 thread_m_cluster_id * MThreadSliceSize,
220 thread_k_cluster_id));
221
222 auto threadwise_count_load_m_kblock =
224 int32_t,
225 CountGridDesc_M_KBlock,
226 decltype(thread_buffer_desc_m_1),
229 1,
230 1,
231 1,
232 true>(
233 count_grid_desc_m_kblock,
234 make_multi_index(block_m_cluster_id * M_BlockTileSize +
235 thread_m_cluster_id * MThreadSliceSize,
236 thread_k_cluster_id));
237
// X/Gamma/Beta/Y windows start at this K block's range plus this thread's
// vector offset:
//   block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration
//     + thread_k_cluster_id * <vector size>
238 auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
239 ComputeDataType,
240 XYGammaBetaGridDesc_M_K,
241 decltype(thread_buffer_desc_m_k),
244 XSrcVectorDim,
245 XSrcVectorSize,
246 1,
247 true>(
248 x_grid_desc_m_k,
249 make_multi_index(block_m_cluster_id * M_BlockTileSize +
250 thread_m_cluster_id * MThreadSliceSize,
251 block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
252 thread_k_cluster_id * XSrcVectorSize));
253
254 auto threadwise_gamma_load =
256 ComputeDataType,
257 XYGammaBetaGridDesc_M_K,
258 decltype(thread_buffer_desc_m_k),
261 GammaSrcVectorDim,
262 GammaSrcVectorSize,
263 1,
264 true>(
265 gamma_grid_desc_m_k,
266 make_multi_index(block_m_cluster_id * M_BlockTileSize +
267 thread_m_cluster_id * MThreadSliceSize,
268 block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
269 thread_k_cluster_id * GammaSrcVectorSize));
270
271 auto threadwise_beta_load =
273 ComputeDataType,
274 XYGammaBetaGridDesc_M_K,
275 decltype(thread_buffer_desc_m_k),
278 BetaSrcVectorDim,
279 BetaSrcVectorSize,
280 1,
281 true>(
282 beta_grid_desc_m_k,
283 make_multi_index(block_m_cluster_id * M_BlockTileSize +
284 thread_m_cluster_id * MThreadSliceSize,
285 block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
286 thread_k_cluster_id * BetaSrcVectorSize));
287
288 auto threadwise_y_store =
290 YDataType,
291 decltype(thread_buffer_desc_m_k),
292 XYGammaBetaGridDesc_M_K,
293 YElementwiseOperation,
296 YDstVectorDim,
297 YDstVectorSize,
299 1,
300 true>(
301 y_grid_desc_m_k,
302 make_multi_index(block_m_cluster_id * M_BlockTileSize +
303 thread_m_cluster_id * MThreadSliceSize,
304 block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration +
305 thread_k_cluster_id * YDstVectorSize),
306 y_elementwise_op);
307
// Stores of the merged mean / inv_std for this thread's M rows (step 2 only).
308 auto threadwise_mean_store =
310 SaveMeanInvStdDataType,
311 decltype(thread_buffer_desc_m),
312 SaveMeanInvStdGridDesc_M,
315 Sequence<0>, // DimAccessOrder
316 0, // SrcVectorDim
317 SaveMeanInvStdDstVectorSize, // ScalarPerVector
319 1,
320 true>(
321 save_mean_grid_desc_m,
322 make_multi_index(block_m_cluster_id * M_BlockTileSize +
323 thread_m_cluster_id * MThreadSliceSize),
324 PassThroughOp{});
325
326 auto threadwise_inv_std_store =
328 SaveMeanInvStdDataType,
329 decltype(thread_buffer_desc_m),
330 SaveMeanInvStdGridDesc_M,
333 Sequence<0>, // DimAccessOrder
334 0, // SrcVectorDim
335 SaveMeanInvStdDstVectorSize, // ScalarPerVector
337 1,
338 true>(
339 save_inv_std_grid_desc_m,
340 make_multi_index(block_m_cluster_id * M_BlockTileSize +
341 thread_m_cluster_id * MThreadSliceSize),
342 PassThroughOp{});
343
344 // step1: Merge mean and variance
// Per-thread accumulators start at (mean 0, var 0, count 0). The thread then
// folds the partials at kblock columns thread_k_cluster_id,
// thread_k_cluster_id + KThreadClusterSize, ... into them with a threadwise
// Welford merge.
345 constexpr auto mean_var_count_thread_copy_step_I0_k =
346 make_multi_index(I0, KThreadClusterSize);
347
349 mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
350 var_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
351 welford_count_thread_buf(I) = 0;
352 });
353
354 for(index_t k = 0; k < num_k_mean_var_count_iteration; ++k)
355 {
356 threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock,
357 mean_global_val_buf,
359 make_tuple(I0, I0),
360 in_mean_thread_buf);
361
362 threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock,
363 var_global_val_buf,
365 make_tuple(I0, I0),
366 in_var_thread_buf);
367
368 threadwise_count_load_m_kblock.Run(count_grid_desc_m_kblock,
369 welford_count_global_val_buf,
371 make_tuple(I0, I0),
372 in_welford_count_thread_buf);
373
374 ThreadwiseWelford::Run(in_mean_thread_buf,
375 in_var_thread_buf,
376 in_welford_count_thread_buf,
377 mean_thread_buf,
378 var_thread_buf,
379 welford_count_thread_buf);
380
381 threadwise_mean_var_load_m_kblock.MoveSrcSliceWindow(
382 mean_var_grid_desc_m_kblock, mean_var_count_thread_copy_step_I0_k);
383 threadwise_count_load_m_kblock.MoveSrcSliceWindow(count_grid_desc_m_kblock,
384 mean_var_count_thread_copy_step_I0_k);
385 }
386
// For each M row I:
//   1. Combine the per-thread results with BlockwiseWelford::Run (the call's
//      first line is not shown).
//   2. Replace the variance with the inverse standard deviation, in place.
// All K threads use the merged mean / inv_std in step 3, so the blockwise
// result must be visible to every thread.
// NOTE(review): the statement guarded by `if constexpr(I > 0)` (source
// line 389) is not shown. Presumably it is a block_sync_lds() between
// successive blockwise reductions.
388 if constexpr(I > 0)
390
392 mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I));
393
394 inv_std_thread_buf(I) =
395 type_convert<ComputeDataType>(1.0f) / ck::math::sqrt(var_thread_buf(I) + epsilon);
396 });
397
398 // step2: save mean and inverse std for backward (optional)
// Every K block of an M tile computed identical statistics above. Only
// K block 0 writes them, and within it only threads with
// thread_k_cluster_id == 0 (one writer per group of M rows).
399 if(block_k_cluster_id == 0 && thread_k_cluster_id == 0)
400 {
401 if(p_save_mean_global != nullptr)
402 {
403 threadwise_mean_store.Run(thread_buffer_desc_m,
404 make_tuple(I0),
405 mean_thread_buf,
406 save_mean_grid_desc_m,
407 save_mean_global_val_buf);
408 }
409 if(p_save_inv_std_global != nullptr)
410 {
411 threadwise_inv_std_store.Run(thread_buffer_desc_m,
412 make_tuple(I0),
413 inv_std_thread_buf,
414 save_inv_std_grid_desc_m,
415 save_inv_std_global_val_buf);
416 }
417 }
418
419 // step3: normalization
// For each of this block's K tiles:
//   1. Load X and Gamma.
//   2. Compute (x - mean) * inv_std * gamma in place in the X buffer.
//   3. Load Beta into the Gamma buffer and add it.
//   4. Store Y.
// Every per-thread load/store advances its window by K_BlockTileStepSize
// along K.
420 constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
421
422 for(index_t k = 0; k < num_k_block_tile_iteration; ++k)
423 {
425 threadwise_x_load.Run(x_grid_desc_m_k,
426 x_global_val_buf,
428 make_tuple(I0, I0),
429 x_thread_buf(i));
430 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
431 });
432
434 threadwise_gamma_load.Run(gamma_grid_desc_m_k,
435 gamma_global_val_buf,
437 make_tuple(I0, I0),
438 gamma_thread_buf(i));
439
440 threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
441 thread_copy_fwd_step_m_k);
442 });
443
446 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
447 constexpr auto offset_m_k =
448 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
449
450 // normalize
451 y_thread_buf(iK0)(Number<offset_m_k>{}) =
452 (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
453 inv_std_thread_buf(iM);
454
455 // gamma
456 y_thread_buf(iK0)(Number<offset_m_k>{}) =
457 y_thread_buf(iK0)(Number<offset_m_k>{}) *
458 gamma_thread_buf(iK0)(Number<offset_m_k>{});
459 });
460 });
461 });
462
// Beta overwrites the gamma buffer (beta_thread_buf aliases gamma_thread_buf).
// This is safe because gamma has already been applied above.
464 threadwise_beta_load.Run(beta_grid_desc_m_k,
465 beta_global_val_buf,
467 make_tuple(I0, I0),
468 beta_thread_buf(i));
469 threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
470 thread_copy_fwd_step_m_k);
471 });
472
475 static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
476 constexpr auto offset_m_k =
477 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
478
479 // beta
480 y_thread_buf(iK0)(Number<offset_m_k>{}) =
481 y_thread_buf(iK0)(Number<offset_m_k>{}) +
482 beta_thread_buf(iK0)(Number<offset_m_k>{});
483 });
484 });
485 });
486
// y_thread_buf aliases x_thread_buf, so this stores the values normalized in
// place above. The store transfer was constructed with y_elementwise_op.
488 threadwise_y_store.Run(thread_buffer_desc_m_k,
489 make_tuple(I0, I0),
490 y_thread_buf(i),
491 y_grid_desc_m_k,
492 y_global_val_buf);
493 threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
494 });
495 } // end for (normalization)
496 }
497};
498
499} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
signed int int32_t
Definition stdint.h:123
static __device__ void Run(T &mean_value, T &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
Definition gridwise_normalization_splitk_2nd.hpp:42
static __device__ void Run(const MeanVarGridDesc_M_KBlock &mean_var_grid_desc_m_kblock, const CountGridDesc_M_KBlock &count_grid_desc_m_kblock, const XYGammaBetaGridDesc_M_K &x_grid_desc_m_k, const XYGammaBetaGridDesc_M_K &gamma_grid_desc_m_k, const XYGammaBetaGridDesc_M_K &beta_grid_desc_m_k, const XYGammaBetaGridDesc_M_K &y_grid_desc_m_k, const SaveMeanInvStdGridDesc_M &save_mean_grid_desc_m, const SaveMeanInvStdGridDesc_M &save_inv_std_grid_desc_m, index_t num_k_mean_var_count_iteration, index_t num_k_block_tile_iteration, index_t k_grid_size, ComputeDataType epsilon, const MeanVarDataType *const p_mean_global, const MeanVarDataType *const p_variance_global, const int32_t *const p_welford_count_global, const XDataType *const __restrict__ p_x_global, const GammaDataType *const __restrict__ p_gamma_global, const BetaDataType *const __restrict__ p_beta_global, YDataType *const __restrict__ p_y_global, SaveMeanInvStdDataType *const __restrict__ p_save_mean_global, SaveMeanInvStdDataType *const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op)
Definition gridwise_normalization_splitk_2nd.hpp:107
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition threadwise_tensor_slice_transfer.hpp:39
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:66
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition threadwise_welford.hpp:83
static __device__ void Run(const SrcMeanBufferType &src_mean_buf, const SrcVarBufferType &src_var_buf, const SrcCountBufferType &src_count_buf, DstMeanBufferType &dst_mean_buf, DstVarBufferType &dst_var_buf, DstCountBufferType &dst_count_buf)
Definition threadwise_welford.hpp:110
Definition utility/functional.hpp:100
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340