layernorm2d_fwd_pipeline_one_pass.hpp Source File

layernorm2d_fwd_pipeline_one_pass.hpp Source File#

Composable Kernel: layernorm2d_fwd_pipeline_one_pass.hpp Source File
layernorm2d_fwd_pipeline_one_pass.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9#include <string>
10#include <type_traits>
11
12namespace ck_tile {
13
14template <typename Problem_, typename Policy_ = Layernorm2dFwdPipelineDefaultPolicy>
16{
19
28
31
// Compile-time switches for this pipeline, all derived from Problem / Problem::Traits.
// kHasGamma / kHasBeta are true unless the corresponding data type is ck_tile::null_type.
// NOTE(review): neither flag is referenced in the visible part of operator(), which loads
// gamma and beta unconditionally -- confirm whether the null_type case is handled elsewhere.
32 static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
33 static constexpr bool kHasBeta = !std::is_same_v<BetaDataType, ck_tile::null_type>;
// kSaveMean and kSaveInvStd read the same trait (kSaveMeanInvStd), so they are always equal.
34 static constexpr bool kSaveMean = Problem::Traits::kSaveMeanInvStd;
35 static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;
36
// kNeedCrossWarpSync also selects the value of `name` ("bpr" vs "wpr").
37 static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
// kPadM is hard-coded to false here; it is not taken from Problem (see the TODO).
38 static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
39 static constexpr bool kPadN = Problem::Traits::kPadN;
// kFastFDiv: operator() uses __builtin_amdgcn_rcpf for 1/sqrt when this is set and
// ComputeDataType is float.
40 static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
// kWelford: chooses between the Welford post-scale branch and the explicit
// divide-by-row_size branch in operator().
41 static constexpr bool kWelford = Problem::Traits::kWelford;
// The three values below are declared `auto`; their types come from Problem::Traits
// (presumably the X-bias / fused-add / fused-quant enums -- TODO confirm).
42 static constexpr auto kXbias = Problem::Traits::kXbias;
43 static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
44 static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
45
// Short tag for this pipeline variant, picked at compile time by an immediately-invoked
// lambda: "bpr" when kNeedCrossWarpSync is true, otherwise "wpr".
46 static constexpr const char* name = []() {
47 if constexpr(kNeedCrossWarpSync)
48 return "bpr"; // block per row
49 else
50 return "wpr"; // warp per row
51 }();
52
// Body of GetSmemSize(): forwards to Policy::GetSmemSize<Problem>() and returns its result.
// NOTE(review): the signature line for this body is absent from this listing.
54 {
55 return Policy::template GetSmemSize<Problem>();
56 }
57
// One-pass layernorm forward over a single tile.
//
// Visible steps, in order:
//   1. Re-wrap the incoming windows with the Policy's tile distributions and load x,
//      x_residual and x_bias.
//   2. Cast x to ComputeDataType (`acc`); two guarded blocks then add x_bias and x_resi to
//      `acc`, the second one also storing `acc` to y_residual_window.
//   3. Reduce `acc` into mean/var (per-thread, then the two sync stages), then finalise.
//   4. inv_std = 1 / sqrt(var + epsilon).
//   5. Optionally store mean / inv_std (kSaveMean / kSaveInvStd).
//   6. ln = (acc - mean) * inv_std * gamma + beta, handed to Epilogue.
//
// Parameters of note:
//   epsilon  - added to the variance before the square root.
//   row_size - divisor for mean/var in the non-Welford branch.
//   smem     - passed to the cross-warp reduce stage and to one of the Epilogue calls.
//   Epilogue - the argument is unnamed and unused; a fresh `Epilogue{}` is constructed
//              at each call site.
//
// NOTE(review): several lines of this function are missing from this listing (the
// max_count initializer, the conditions guarding the bias / residual / epilogue blocks,
// and the body of the kWelford branch). Comments below describe only what is visible.
58 template <typename XWindow,
59 typename XResidualWindow,
60 typename XBiasWindow,
61 typename GammaWindow,
62 typename BetaWindow,
63 typename YWindow,
64 typename YResidualWindow,
65 typename MeanWindow,
66 typename InvStdWindow,
67 typename SmoothScaleWindow,
68 typename YScaleWindow,
69 typename Epilogue>
70 CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
71 const XResidualWindow& x_residual_window_,
72 const XBiasWindow& x_bias_window_,
73 const GammaWindow& gamma_window_,
74 const BetaWindow& beta_window_,
75 YWindow& y_window_,
76 const YResidualWindow& y_residual_window_,
77 MeanWindow& mean_window,
78 InvStdWindow& inv_std_window,
79 const SmoothScaleWindow& sm_scale_window_,
80 YScaleWindow& y_scale_window,
81 ComputeDataType epsilon,
82 ck_tile::index_t row_size,
83 void* smem,
84 Epilogue) const
85 {
// x, x_residual and y_residual share the X block distribution; x_bias, gamma and beta
// share the gamma/beta distribution.
86 const auto x_window =
87 make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
88 const auto x_bias_window = make_tile_window(
89 x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
90 const auto gamma_window = make_tile_window(
91 gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
92 const auto beta_window = make_tile_window(
93 beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
94 const auto x_residual_window = make_tile_window(
95 x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
96 auto y_residual_window = make_tile_window(
97 y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
98
// These three loads are issued unconditionally, before the guarded blocks that use
// x_bias and x_resi.
99 auto x = load_tile(x_window);
100 auto x_resi = load_tile(x_residual_window);
101 const auto x_bias = load_tile(x_bias_window);
102
// cur_count / max_count are handed to the reduce stages below.
// NOTE(review): the initializer of max_count is missing from this listing.
103 int cur_count = 0;
104 int max_count =
106 auto block_norm_reduce = Policy::template GetBlockNormReduce<Problem>();
107 auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
108 auto block_norm_reduce_cross_warp_sync =
109 Policy::template GetBlockNormReduceCrossWarpSync<Problem>();
110
// mean/var tiles are created from the ComputeDataType view of x and zeroed before
// the reduction accumulates into them.
111 using XTensorType = decltype(cast_tile<ComputeDataType>(x));
112 auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
113 auto var = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
114 clear_tile(mean);
115 clear_tile(var);
116 // load gamma/beta (TODO: support no gamma/beta?)
117 const auto gamma = load_tile(gamma_window);
118 const auto beta = load_tile(beta_window);
119
// `acc` is the working copy of x in ComputeDataType; all later math reads it.
120 auto acc = cast_tile<ComputeDataType>(x);
121
// Bias add. NOTE(review): the condition guarding this block is missing from this listing.
123 {
124 sweep_tile(x, [&](auto idx) {
125 // compute x = bias + x
// x_bias is indexed with the second tile index only, so one bias value is reused
// for every value of the first index.
126 constexpr auto j_idx = make_tuple(idx[number<1>{}]);
127 acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
128 });
129 }
130
// Residual add, followed by a store of the summed value to y_residual_window.
// NOTE(review): the condition guarding this block, and a line just before the
// store_tile call, are missing from this listing.
133 {
134 sweep_tile(x_resi, [&](auto idx) {
135 // compute x = x_resi + x
136 acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
137 });
139 store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
140 }
141
142 // compute reduce each-thread->cross-lane->cross-warp
143 block_norm_reduce(acc, mean, var, cur_count, max_count);
144 block_norm_reduce_sync(mean, var, cur_count);
145 block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
// NOTE(review): kWelford is a compile-time constant but this is a plain `if`, so both
// branches are instantiated; `if constexpr` may be intended -- confirm.
// The statement inside the kWelford branch is missing from this listing.
146 if(kWelford)
147 {
149 }
150 else
151 {
// Non-Welford finalisation: mean = sum / N, var = sum_sq / N - mean^2, which implies
// the reduction left a plain sum in `mean` and a sum of squares in `var`.
// NOTE(review): row_size is converted via MeanDataType rather than ComputeDataType;
// if MeanDataType is narrower than the type held by mean/var, large row sizes could
// lose precision -- confirm this is intended.
152 sweep_tile(mean, [&](auto idx) {
153 mean(idx) = mean(idx) / type_convert<MeanDataType>(row_size);
154 var(idx) = var(idx) / type_convert<MeanDataType>(row_size) - mean(idx) * mean(idx);
155 });
156 }
157 // compute inv-std
// Element-wise 1 / sqrt(var + epsilon). The first branch multiplies by
// __builtin_amdgcn_rcpf (presumably a hardware float reciprocal -- TODO confirm) and the
// second uses a regular division.
// NOTE(review): the condition is a plain `if`, so both return statements are compiled
// for every ComputeDataType and must deduce the same return type.
158 auto inv_std = tile_elementwise_in(
159 [&](const auto& v_) {
160 if(kFastFDiv && std::is_same_v<ComputeDataType, float>)
161 {
162 return type_convert<ComputeDataType>(1.0f) *
163 __builtin_amdgcn_rcpf(sqrt(v_ + epsilon));
164 }
165 else
166 {
167 return type_convert<ComputeDataType>(1.0f) / sqrt(v_ + epsilon);
168 }
169 },
170 var);
171
// Optional write-back of the statistics, cast to their storage types.
172 if constexpr(kSaveMean)
173 store_tile(mean_window, cast_tile<MeanDataType>(mean));
174 if constexpr(kSaveInvStd)
175 store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));
176
177 // layernorm computation
// `ln` has the same distribution as `acc`. The lambda captures `mean` by copy (as mean_)
// and everything else by reference.
178 auto ln = make_static_distributed_tensor<ComputeDataType>(acc.get_tile_distribution());
179 sweep_tile(ln, [&, mean_ = mean](auto idx) {
// i_idx (first tile index) selects the mean / inv_std entry;
// j_idx (second tile index) selects the gamma / beta entry.
180 constexpr auto i_idx = make_tuple(idx[number<0>{}]);
181 constexpr auto j_idx = make_tuple(idx[number<1>{}]);
182
183 const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
184 const auto beta_ = type_convert<ComputeDataType>(beta[j_idx]);
185
186 auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
187 ln(idx) = ln_;
188 });
189
// Two epilogue call shapes: one receives the scale windows and smem, the other only
// y_window_ and ln with a null scratch pointer.
// NOTE(review): the condition selecting between them is missing from this listing.
192 {
193 Epilogue{}(y_window_, sm_scale_window_, y_scale_window, ln, smem);
194 }
195 else
196 Epilogue{}(y_window_, ln, nullptr);
197 }
198};
199} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_size)
Definition block_norm_reduce.hpp:361
CK_TILE_DEVICE constexpr void block_tile_welford_post_scale_var(VarDistributedTensor_ &var_tensor, int count, bool_constant< FastFdiv_ >={})
Definition block_norm_reduce.hpp:393
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
@ SMOOTH_DYNAMIC_QUANT
Definition layernorm2d_fwd_traits.hpp:42
@ DYNAMIC_QUANT
Definition layernorm2d_fwd_traits.hpp:43
@ ADD_BIAS
Definition layernorm2d_fwd_traits.hpp:14
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F &f, UnpacksPerXDim={})
Definition sweep_tile.hpp:231
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition bfloat16.hpp:413
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
@ PRE_ADD_STORE
Definition layernorm2d_fwd_traits.hpp:27
@ PRE_ADD
Definition layernorm2d_fwd_traits.hpp:29
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition layernorm2d_fwd_pipeline_one_pass.hpp:16
static constexpr auto kFusedQuant
Definition layernorm2d_fwd_pipeline_one_pass.hpp:44
XDataType YResidualDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:30
ck_tile::remove_cvref_t< typename Problem::BetaDataType > BetaDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:23
ck_tile::remove_cvref_t< typename Problem::YDataType > YDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:25
ck_tile::remove_cvref_t< typename Problem::MeanDataType > MeanDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:26
static constexpr bool kHasBeta
Definition layernorm2d_fwd_pipeline_one_pass.hpp:33
ck_tile::remove_cvref_t< typename Problem::XBiasDataType > XBiasDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:21
ck_tile::remove_cvref_t< typename Problem::GammaDataType > GammaDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:22
ck_tile::remove_cvref_t< Policy_ > Policy
Definition layernorm2d_fwd_pipeline_one_pass.hpp:18
static constexpr auto kFusedAdd
Definition layernorm2d_fwd_pipeline_one_pass.hpp:43
static constexpr const char * name
Definition layernorm2d_fwd_pipeline_one_pass.hpp:46
static constexpr auto kXbias
Definition layernorm2d_fwd_pipeline_one_pass.hpp:42
ck_tile::remove_cvref_t< typename Problem::XDataType > XDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:20
ck_tile::remove_cvref_t< typename Problem::InvStdDataType > InvStdDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:27
static constexpr bool kFastFDiv
Definition layernorm2d_fwd_pipeline_one_pass.hpp:40
XDataType XResidualDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:29
static constexpr bool kPadN
Definition layernorm2d_fwd_pipeline_one_pass.hpp:39
static constexpr bool kSaveMean
Definition layernorm2d_fwd_pipeline_one_pass.hpp:34
static constexpr bool kNeedCrossWarpSync
Definition layernorm2d_fwd_pipeline_one_pass.hpp:37
static constexpr bool kPadM
Definition layernorm2d_fwd_pipeline_one_pass.hpp:38
ck_tile::remove_cvref_t< Problem_ > Problem
Definition layernorm2d_fwd_pipeline_one_pass.hpp:17
static constexpr bool kHasGamma
Definition layernorm2d_fwd_pipeline_one_pass.hpp:32
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition layernorm2d_fwd_pipeline_one_pass.hpp:53
static constexpr bool kSaveInvStd
Definition layernorm2d_fwd_pipeline_one_pass.hpp:35
CK_TILE_DEVICE auto operator()(const XWindow &x_window_, const XResidualWindow &x_residual_window_, const XBiasWindow &x_bias_window_, const GammaWindow &gamma_window_, const BetaWindow &beta_window_, YWindow &y_window_, const YResidualWindow &y_residual_window_, MeanWindow &mean_window, InvStdWindow &inv_std_window, const SmoothScaleWindow &sm_scale_window_, YScaleWindow &y_scale_window, ComputeDataType epsilon, ck_tile::index_t row_size, void *smem, Epilogue) const
Definition layernorm2d_fwd_pipeline_one_pass.hpp:70
ck_tile::remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition layernorm2d_fwd_pipeline_one_pass.hpp:24
static constexpr bool kWelford
Definition layernorm2d_fwd_pipeline_one_pass.hpp:41
Definition tile/core/numeric/integral_constant.hpp:13