gridwise_batchnorm_forward_blockwise_welford.hpp Source File

gridwise_batchnorm_forward_blockwise_welford.hpp Source File

Composable Kernel: gridwise_batchnorm_forward_blockwise_welford.hpp Source File
gridwise_batchnorm_forward_blockwise_welford.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
13namespace ck {
14
// Kernel entry point for the blockwise-Welford batchnorm forward pass.
// NOTE(review): the declaration line itself is absent from this extracted listing; the
// cross-reference index at the end of the page gives it as
// `__global__ void kernel_batchnorm_forward_with_blockwise_welford(...)`.
// The first template parameter supplies the implementation. The remaining template parameters
// only fix the types that appear in the argument list.
15template <typename GridwiseBatchrNormForwardWithBlockwiseWelford_,
16 typename XDataType,
17 typename YDataType,
18 typename AccDataType,
19 typename ScaleDataType,
20 typename BiasDataType,
21 typename MeanVarDataType,
22 typename YElementwiseOp,
23 typename XYGridDesc_M_K,
24 typename ScaleBiasGridDesc_M,
25 typename MeanVarGridDesc_M,
26 typename GetReduceCountPerThreadFunctor>
// Arguments are taken by value here; the same list, in the same order, is what Run() receives.
28 const XYGridDesc_M_K x_grid_desc_m_k,
29 const XYGridDesc_M_K y_grid_desc_m_k,
30 const ScaleBiasGridDesc_M scale_grid_desc_m,
31 const ScaleBiasGridDesc_M bias_grid_desc_m,
32 const MeanVarGridDesc_M mean_var_grid_desc_m,
33 const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
34 index_t num_k_block_tile_iteration,
35 AccDataType epsilon,
36 const XDataType* const __restrict__ p_x,
37 const ScaleDataType* const __restrict__ p_scale,
38 const BiasDataType* const __restrict__ p_bias,
39 const YElementwiseOp y_elementwise_op,
40 YDataType* const __restrict__ p_y,
41 bool updateMovingAverage,
42 AccDataType averageFactor,
43 MeanVarDataType* const __restrict__ resultRunningMean,
44 MeanVarDataType* const __restrict__ resultRunningVariance,
45 bool saveMeanInvVariance,
46 MeanVarDataType* const __restrict__ resultSaveMean,
47 MeanVarDataType* const __restrict__ resultSaveInvVariance)
48{
// The body contains no logic of its own: every argument is forwarded unchanged to the static
// Run() of the implementation type.
49 GridwiseBatchrNormForwardWithBlockwiseWelford_::Run(x_grid_desc_m_k,
50 y_grid_desc_m_k,
51 scale_grid_desc_m,
52 bias_grid_desc_m,
53 mean_var_grid_desc_m,
54 get_reduce_count_per_thread,
55 num_k_block_tile_iteration,
56 epsilon,
57 p_x,
58 p_scale,
59 p_bias,
60 y_elementwise_op,
61 p_y,
62 updateMovingAverage,
63 averageFactor,
64 resultRunningMean,
65 resultRunningVariance,
66 saveMeanInvVariance,
67 resultSaveMean,
68 resultSaveInvVariance);
69};
70
// Gridwise batchnorm-forward implementation. Each block owns one tile of M_BlockTileSize rows
// (its M origin is block id * M_BlockTileSize) and walks the K dimension tile by tile, so the
// mean/variance reduction for a row never leaves the block.
// NOTE(review): the `struct` declaration line that belongs between the template header and the
// opening brace is missing from this extracted listing, as are several alias and declaration
// lines further down. Comments below flag the gaps where they matter.
71template <typename XDataType,
72 typename YDataType,
73 typename AccDataType,
74 typename ScaleDataType,
75 typename BiasDataType,
76 typename MeanVarDataType,
77 typename YElementwiseOp,
78 typename XYGridDesc_M_K,
79 typename ScaleBiasGridDesc_M,
80 typename MeanVarGridDesc_M,
81 typename GetReduceCountPerThreadFunctor,
82 index_t BlockSize,
83 index_t MThreadClusterSize,
84 index_t KThreadClusterSize,
85 index_t MThreadSliceSize,
86 index_t KThreadSliceSize,
87 index_t XSrcYDstVectorDim,
88 index_t XSrcVectorSize,
89 index_t YDstVectorSize,
90 index_t ScaleSrcVectorSize,
91 index_t BiasSrcVectorSize,
92 index_t MeanVarSrcDstVectorSize>
94{
// The x-load vector width must evenly divide the per-thread slice along the vectorized
// dimension: the M slice when XSrcYDstVectorDim is 0, the K slice when it is 1. Any other value
// of XSrcYDstVectorDim fails the assertion.
95 static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
96 (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
97 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
98
// Same divisibility requirement for the y-store vector width.
99 static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
100 (XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
101 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
102
// True when vector access is along dimension 0 (M). According to the cross-reference index at
// the end of this page, this flag selects Sequence<1, 0> over Sequence<0, 1> for both
// ThreadBufferDimAccessOrder and ThreadClusterArrangeOrder; those alias lines are missing here.
103 static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0);
104
106
109
112
// NOTE(review): the initializer of thread_cluster_desc is missing from this listing. The index
// lists make_cluster_descriptor(lengths, order), which is presumably what is called here with
// the cluster lengths and arrange order — confirm against the real header.
113 static constexpr auto thread_cluster_desc =
115
120
123
125 BlockSize,
128
130
// Compile-time indices used to pick the M (I0) and K (I1) components of the cluster index.
131 static constexpr auto I0 = Number<0>{};
132 static constexpr auto I1 = Number<1>{};
133
// Extent of one block tile along M and K: threads in the cluster times elements per thread.
134 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
135 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
136
// Runs the whole forward pass for the calling block's M tile.
//   x/y/scale/bias/mean_var descriptors: used to size the global buffer views and to drive the
//       threadwise transfers for the matching pointers.
//   get_reduce_count_per_thread: invoked with this thread's K-cluster id; the result becomes the
//       threadwise Welford object's max_count_.
//   num_k_block_tile_iteration: number of K tiles (each K_BlockTileSize wide) visited in Step 1
//       and again in Step 2.
//   epsilon: added to the variance before the square root in Steps 2 and 4.
//   p_x, p_scale, p_bias: inputs (pointers to const); p_y: output.
//   y_elementwise_op: handed to the y store transfer.
//   updateMovingAverage, averageFactor, resultRunningMean, resultRunningVariance: Step 3.
//   saveMeanInvVariance, resultSaveMean, resultSaveInvVariance: Step 4.
137 __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
138 const XYGridDesc_M_K& y_grid_desc_m_k,
139 const ScaleBiasGridDesc_M& scale_grid_desc_m,
140 const ScaleBiasGridDesc_M& bias_grid_desc_m,
141 const MeanVarGridDesc_M& mean_var_grid_desc_m,
142 const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
143 index_t num_k_block_tile_iteration,
144 AccDataType epsilon,
145 const XDataType* const __restrict__ p_x,
146 const ScaleDataType* const __restrict__ p_scale,
147 const BiasDataType* const __restrict__ p_bias,
148 const YElementwiseOp y_elementwise_op,
149 YDataType* const __restrict__ p_y,
150 bool updateMovingAverage,
151 AccDataType averageFactor,
152 MeanVarDataType* const __restrict__ resultRunningMean,
153 MeanVarDataType* const __restrict__ resultRunningVariance,
154 bool saveMeanInvVariance,
155 MeanVarDataType* const __restrict__ resultSaveMean,
156 MeanVarDataType* const __restrict__ resultSaveInvVariance)
157 {
158 using ck::math::sqrt;
159
// NOTE(review): the type portions of the per-thread buffer declarations are missing from this
// listing; only the names x_thread_buf and y_thread_buf survive. The buffers mean_thread_buf,
// var_thread_buf, scale_thread_buf and bias_thread_buf used below are presumably declared on the
// dropped lines as well — confirm against the real header.
161 x_thread_buf;
162
164
166
168 y_thread_buf;
169
172
173 const index_t thread_local_id = get_thread_local_1d_id();
174 const index_t block_global_id = get_block_1d_id();
175
// Map the flat thread id to its (M, K) position inside the thread cluster.
176 const auto thread_cluster_idx =
177 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
178
179 const auto thread_m_cluster_id = thread_cluster_idx[I0];
180 const auto thread_k_cluster_id = thread_cluster_idx[I1];
181
// Per-thread slice shapes: an M x K slice for x/y and an M-only slice for the per-row values.
// NOTE(review): the argument lines of the two descriptor initializers are missing here.
182 using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
183 using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
184 constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
186 constexpr auto thread_buffer_desc_m =
188
// x loader, converting XDataType to AccDataType. Its initial window origin is
//   M: block_global_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize
//   K: thread_k_cluster_id * KThreadSliceSize
189 auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
190 AccDataType,
191 XYGridDesc_M_K,
192 decltype(thread_buffer_desc_m_k),
193 ThreadBufferLengths_M_K,
195 XSrcYDstVectorDim,
196 XSrcVectorSize,
197 1,
198 true>(
199 x_grid_desc_m_k,
200 make_multi_index(block_global_id * M_BlockTileSize +
201 thread_m_cluster_id * MThreadSliceSize,
202 thread_k_cluster_id * KThreadSliceSize));
203
// y store, starting at the same (M, K) origin as the x loader and applying y_elementwise_op.
// NOTE(review): the transfer class name and its first template argument are missing here.
204 auto threadwise_y_store =
206 YDataType,
207 decltype(thread_buffer_desc_m_k),
208 XYGridDesc_M_K,
209 YElementwiseOp,
210 ThreadBufferLengths_M_K,
212 XSrcYDstVectorDim,
213 YDstVectorSize,
215 1,
216 true>(
217 y_grid_desc_m_k,
218 make_multi_index(block_global_id * M_BlockTileSize +
219 thread_m_cluster_id * MThreadSliceSize,
220 thread_k_cluster_id * KThreadSliceSize),
221 y_elementwise_op);
222
// Scale and bias loaders are one-dimensional: only the M origin is needed.
223 auto threadwise_scale_load =
225 AccDataType,
226 ScaleBiasGridDesc_M,
227 decltype(thread_buffer_desc_m),
228 ThreadBufferLengths_M,
230 0,
231 ScaleSrcVectorSize,
232 1,
233 true>(
234 scale_grid_desc_m,
235 make_multi_index(block_global_id * M_BlockTileSize +
236 thread_m_cluster_id * MThreadSliceSize));
237
238 auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2<BiasDataType,
239 AccDataType,
240 ScaleBiasGridDesc_M,
241 decltype(thread_buffer_desc_m),
242 ThreadBufferLengths_M,
244 0,
245 BiasSrcVectorSize,
246 1,
247 true>(
248 bias_grid_desc_m,
249 make_multi_index(block_global_id * M_BlockTileSize +
250 thread_m_cluster_id * MThreadSliceSize));
251
// Window steps of exactly one block tile along K, forward and backward; M never moves.
252 constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
253 constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize);
254
// Global-address-space buffer views over the raw pointers, sized by each descriptor's element
// space. Only the y view is non-const because it is the only one written here.
255 const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
256 p_x, x_grid_desc_m_k.GetElementSpaceSize());
257
258 const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
259 p_scale, scale_grid_desc_m.GetElementSpaceSize());
260
261 const auto bias_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
262 p_bias, bias_grid_desc_m.GetElementSpaceSize());
263
264 auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
265 p_y, y_grid_desc_m_k.GetElementSpaceSize());
266
267 // Step 1: do welford reduction to get mean and variance
268
269 auto threadwise_welford = ThreadwiseWelford();
270 threadwise_welford.max_count_ = get_reduce_count_per_thread(thread_k_cluster_id);
271
// Zero the per-row accumulators before the reduction.
// NOTE(review): the loop opener for this lambda body is missing from the listing; presumably a
// static_for over MThreadSliceSize with index I — confirm.
273 mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
274 var_thread_buf(I) = type_convert<AccDataType>(0.0f);
275 });
276
// Forward sweep over K: load one tile, advance the x window by one tile, accumulate.
// On exit the x window sits one tile past the last tile that was read.
277 for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
278 {
279
280 threadwise_x_load.Run(x_grid_desc_m_k,
281 x_global_val_buf,
282 thread_buffer_desc_m_k,
283 make_tuple(I0, I0),
284 x_thread_buf);
285
286 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
287 threadwise_welford.Run(x_thread_buf, mean_thread_buf, var_thread_buf);
288 }
289
// Combine the per-thread partial results across the block, one row (I) at a time. `count` is
// reset from cur_count_ for each row because BlockwiseWelford::Run takes it by reference.
// NOTE(review): the loop opener and the statement guarded by `if constexpr(I > 0)` are missing
// from the listing; the index lists block_sync_lds(), which is presumably the guarded call.
291 if constexpr(I > 0)
293
294 int count = threadwise_welford.cur_count_;
295 BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
296 });
297
298 // Step 2: do normalization and output y
299
300 threadwise_scale_load.Run(scale_grid_desc_m,
301 scale_global_val_buf,
302 thread_buffer_desc_m,
303 make_tuple(I0),
304 scale_thread_buf);
305
306 threadwise_bias_load.Run(bias_grid_desc_m,
307 bias_global_val_buf,
308 thread_buffer_desc_m,
309 make_tuple(I0),
310 bias_thread_buf);
311
// Step 2 walks the K tiles in reverse. One backward step brings the x window from
// one-past-the-end back onto the last tile; the y window, still at the first tile, is moved
// forward by (num_k_block_tile_iteration - 1) tiles so both windows start on the same tile.
// NOTE(review): presumably reversed so the most recently read x tiles are revisited first —
// confirm the intent.
312 auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
313
314 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
315 threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
316
317 for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
318 {
319 threadwise_x_load.Run(x_grid_desc_m_k,
320 x_global_val_buf,
321 thread_buffer_desc_m_k,
322 make_tuple(I0, I0),
323 x_thread_buf);
324
// Per row iM, fold scale, bias, mean and variance into one multiply-add:
//   multiplier      = scale / sqrt(var + epsilon)
//   fused_mean_bias = bias - mean * multiplier
//   y               = x * multiplier + fused_mean_bias
// NOTE(review): the openers of the two nested loops (over iM, then iK) are missing here.
326 AccDataType multiplier =
327 scale_thread_buf[Number<iM>{}] / sqrt(var_thread_buf[iM] + epsilon);
328
329 AccDataType fused_mean_bias =
330 bias_thread_buf[Number<iM>{}] - mean_thread_buf[iM] * multiplier;
331
333 constexpr auto offset =
334 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
335
336 // normalize
337 y_thread_buf(Number<offset>{}) =
338 x_thread_buf[Number<offset>{}] * multiplier + fused_mean_bias;
339 });
340 });
341
342 threadwise_y_store.Run(thread_buffer_desc_m_k,
343 make_tuple(I0, I0),
344 y_thread_buf,
345 y_grid_desc_m_k,
346 y_global_val_buf);
347
// Step both windows one tile back together so x and y stay aligned.
348 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
349 threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
350 }
351
352 // Step 3: update the moving average of mean and variance (optional)
353
// Only threads with K-cluster id 0 perform the update for their rows.
354 if(updateMovingAverage && thread_k_cluster_id == 0)
355 {
// NOTE(review): the type portions of these two buffer declarations are missing here.
357 running_mean_thread_buf;
359 running_var_thread_buf;
360
361 auto running_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
362 resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize());
363
364 auto running_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
365 resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize());
366
367 auto threadwise_mean_var_load =
368 ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
369 AccDataType,
370 MeanVarGridDesc_M,
371 decltype(thread_buffer_desc_m),
372 ThreadBufferLengths_M,
374 0,
375 MeanVarSrcDstVectorSize,
376 1,
377 true>(
378 mean_var_grid_desc_m,
379 make_multi_index(block_global_id * M_BlockTileSize +
380 thread_m_cluster_id * MThreadSliceSize));
381
// The same loader object is used for both buffers; its window is not moved between the two
// Run() calls.
382 threadwise_mean_var_load.Run(mean_var_grid_desc_m,
383 running_mean_global_buf,
384 thread_buffer_desc_m,
385 make_tuple(I0),
386 running_mean_thread_buf);
387
388 threadwise_mean_var_load.Run(mean_var_grid_desc_m,
389 running_var_global_buf,
390 thread_buffer_desc_m,
391 make_tuple(I0),
392 running_var_thread_buf);
393
394 AccDataType oneMinusAverageFactor = type_convert<AccDataType>(1.0) - averageFactor;
395
// Blend: running = running * (1 - averageFactor) + batch statistic * averageFactor.
// NOTE(review): the loop opener for this lambda body is missing from the listing.
397 running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor +
398 mean_thread_buf[I] * averageFactor;
399 running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor +
400 var_thread_buf[I] * averageFactor;
401 });
402
// Write the blended values back through the same pointers they were read from.
// NOTE(review): the transfer class name and some template arguments are missing here.
403 auto threadwise_mean_var_store =
405 MeanVarDataType,
406 decltype(thread_buffer_desc_m),
407 MeanVarGridDesc_M,
409 ThreadBufferLengths_M,
411 0,
412 MeanVarSrcDstVectorSize,
414 1,
415 true>(
416 mean_var_grid_desc_m,
417 make_multi_index(block_global_id * M_BlockTileSize +
418 thread_m_cluster_id * MThreadSliceSize),
419 PassThroughOp{});
420
421 threadwise_mean_var_store.Run(thread_buffer_desc_m,
422 make_tuple(I0),
423 running_mean_thread_buf,
424 mean_var_grid_desc_m,
425 running_mean_global_buf);
426
427 threadwise_mean_var_store.Run(thread_buffer_desc_m,
428 make_tuple(I0),
429 running_var_thread_buf,
430 mean_var_grid_desc_m,
431 running_var_global_buf);
432 };
433
434 // Step 4: save mean and inv-variance (optional)
435
// As in Step 3, only threads with K-cluster id 0 write.
436 if(saveMeanInvVariance && thread_k_cluster_id == 0)
437 {
438 auto result_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
439 resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize());
440
441 auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
442 resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
443
// This overwrites var_thread_buf in place, so it must stay after Steps 2 and 3, which both read
// the plain variance.
// NOTE(review): the loop opener for this lambda body is missing from the listing.
444 // calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
446 var_thread_buf(I) =
447 type_convert<AccDataType>(1.0f) / sqrt(epsilon + var_thread_buf[I]);
448 });
449
// NOTE(review): the transfer class name and some template arguments are missing here.
450 auto threadwise_mean_inv_var_store =
452 MeanVarDataType,
453 decltype(thread_buffer_desc_m),
454 MeanVarGridDesc_M,
456 ThreadBufferLengths_M,
458 0,
459 MeanVarSrcDstVectorSize,
461 1,
462 true>(
463 mean_var_grid_desc_m,
464 make_multi_index(block_global_id * M_BlockTileSize +
465 thread_m_cluster_id * MThreadSliceSize),
466 PassThroughOp{});
467
468 threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
469 make_tuple(I0),
470 mean_thread_buf,
471 mean_var_grid_desc_m,
472 result_mean_global_buf);
473
// var_thread_buf holds the inverse standard deviation at this point, not the variance.
474 threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
475 make_tuple(I0),
476 var_thread_buf,
477 mean_var_grid_desc_m,
478 result_inv_var_global_buf);
479 };
480 }
481};
482
483} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__global__ void kernel_batchnorm_forward_with_blockwise_welford(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K y_grid_desc_m_k, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M bias_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, index_t num_k_block_tile_iteration, AccDataType epsilon, const XDataType *const __restrict__ p_x, const ScaleDataType *const __restrict__ p_scale, const BiasDataType *const __restrict__ p_bias, const YElementwiseOp y_elementwise_op, YDataType *const __restrict__ p_y, bool updateMovingAverage, AccDataType averageFactor, MeanVarDataType *const __restrict__ resultRunningMean, MeanVarDataType *const __restrict__ resultRunningVariance, bool saveMeanInvVariance, MeanVarDataType *const __restrict__ resultSaveMean, MeanVarDataType *const __restrict__ resultSaveInvVariance)
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:27
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
static __device__ void Run(T &mean_value, T &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:94
static constexpr bool reorder_thread_cluster
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:103
static __device__ void Run(const XYGridDesc_M_K &x_grid_desc_m_k, const XYGridDesc_M_K &y_grid_desc_m_k, const ScaleBiasGridDesc_M &scale_grid_desc_m, const ScaleBiasGridDesc_M &bias_grid_desc_m, const MeanVarGridDesc_M &mean_var_grid_desc_m, const GetReduceCountPerThreadFunctor &get_reduce_count_per_thread, index_t num_k_block_tile_iteration, AccDataType epsilon, const XDataType *const __restrict__ p_x, const ScaleDataType *const __restrict__ p_scale, const BiasDataType *const __restrict__ p_bias, const YElementwiseOp y_elementwise_op, YDataType *const __restrict__ p_y, bool updateMovingAverage, AccDataType averageFactor, MeanVarDataType *const __restrict__ resultRunningMean, MeanVarDataType *const __restrict__ resultRunningVariance, bool saveMeanInvVariance, MeanVarDataType *const __restrict__ resultSaveMean, MeanVarDataType *const __restrict__ resultSaveInvVariance)
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:137
static constexpr auto thread_cluster_desc
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:113
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadBufferDimAccessOrder
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:107
ThreadwiseWelford< AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M > ThreadwiseWelford
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:121
static constexpr auto I0
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:131
BlockwiseWelford< AccDataType, BlockSize, ThreadClusterLengths_M_K, ThreadClusterArrangeOrder > BlockwiseWelford
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:124
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadClusterArrangeOrder
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:110
static constexpr index_t K_BlockTileSize
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:135
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}))) ThreadReduceDstDesc_M
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:118
decltype(make_naive_tensor_descriptor_packed( make_tuple(Number< MThreadSliceSize >{}, Number< KThreadSliceSize >{}))) ThreadReduceSrcDesc_M_K
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:116
Sequence< MThreadClusterSize, KThreadClusterSize > ThreadClusterLengths_M_K
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:105
static constexpr index_t M_BlockTileSize
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:134
tensor_operation::element_wise::PassThrough PassThroughOp
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:129
static constexpr auto I1
Definition gridwise_batchnorm_forward_blockwise_welford.hpp:132
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/functional.hpp:100
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340