block_fmha_pipeline_qr_ks_vs_fp8.hpp Source File

block_fmha_pipeline_qr_ks_vs_fp8.hpp Source File

Composable Kernel: block_fmha_pipeline_qr_ks_vs_fp8.hpp Source File
block_fmha_pipeline_qr_ks_vs_fp8.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10
11namespace ck_tile {
12
13// FMHA forward block pipeline for FP8 inputs ("qr_ks_vs"): Q is loaded once for the whole head
// dimension (kQLoadOnce) and kept as a distributed tile (q_tile) instead of being written to LDS;
// K and V tiles are staged through LDS (k_lds_window / v_lds_window) for the two block GEMMs.
// NOTE(review): this text comes from a doxygen listing; source lines whose numbers are skipped
// (e.g. 17-30, 32-33, 88, 102, 114, 147, 150, 153, 192, 275, 281, 292, 298, 301, ...) are missing
// here and must be restored from the real header before this compiles.
14template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
15struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
16{
31
34 static constexpr bool kQLoadOnce = true; // q_tile loads the whole block length (hdim) at once
35 static_assert(kQLoadOnce == Policy::QLoadOnce);
36
37 static constexpr index_t kBlockSize = Problem::kBlockSize;
38
 // Block tile sizes: M0/N0/K0 drive the Q*K^T GEMM, N1/K1 drive the P*V GEMM.
39 static constexpr index_t kM0 = BlockFmhaShape::kM0;
40 static constexpr index_t kN0 = BlockFmhaShape::kN0;
41 static constexpr index_t kK0 = BlockFmhaShape::kK0;
42 static constexpr index_t kN1 = BlockFmhaShape::kN1;
43 static constexpr index_t kK1 = BlockFmhaShape::kK1;
44 static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
45
46 static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
47 static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
48 static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
49 static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
50 static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
51 static constexpr auto BiasEnum = Problem::BiasEnum;
52 static constexpr bool kStoreLSE = Problem::kStoreLSE;
53 static constexpr bool kHasDropout = Problem::kHasDropout;
54
55 // last dimension vector length used to create tensor view (and decide buffer_load vector length)
56 // ... together with tensor distribution. tensor dist should be able to overwrite this
 // A padded dimension forces alignment 1 (scalar access).
57 static constexpr index_t kAlignmentQ =
58 kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
 // K is gated on kPadHeadDimQ because Q and K share the QK head dimension (kQKHeaddim).
59 static constexpr index_t kAlignmentK =
60 kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
 // The padding flag that gates V's alignment depends on VLayout: kPadHeadDimV for RowMajor,
 // kPadSeqLenK otherwise (presumably the vectorized dimension of each layout — confirm).
61 static constexpr index_t kAlignmentV = []() {
62 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
63 return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
64 else
65 return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
66 }();
67
68 static constexpr index_t kAlignmentO =
69 kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
70 static constexpr index_t kAlignmentBias =
71 kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
72
 // Blocks-per-CU hint: Problem::kBlockPerCu when it is not -1, otherwise a default chosen
 // from kQKHeaddim.
 // NOTE(review): source line 88 (the condition inside the <= 128 branch) is missing from this
 // listing, and no branch returns a value for kQKHeaddim > 256 — confirm such head dims are
 // never instantiated with this pipeline.
73 static constexpr index_t kBlockPerCu = []() {
74 if constexpr(Problem::kBlockPerCu != -1)
75 return Problem::kBlockPerCu;
76 else
77 {
78 if constexpr(kQKHeaddim <= 32)
79 {
80 return 2;
81 }
82 else if constexpr(kQKHeaddim <= 64)
83 {
84 return 3;
85 }
86 else if constexpr(kQKHeaddim <= 128)
87 {
89 return 1;
90 else
91 return 2;
92 }
93 else if constexpr(kQKHeaddim <= 256)
94 {
95 return 1;
96 }
97 }
98 }();
99
100 static constexpr const char* name = "qr_fp8";
101
 // Shared-memory (LDS) bytes needed by this pipeline, as reported by the policy.
 // NOTE(review): the signature (source line 102) is missing from this listing; the doxygen
 // index lists it as `static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()`.
103 {
104 return Policy::template GetSmemSize<Problem>();
105 }
106
 // Computes one attention output accumulator tile (o_acc) for the query rows covered by
 // q_dram_block_window_tmp, looping over the seqlen_k tile range permitted by `mask`.
 //  - scale_s:    softmax scale; multiplied by descale_qk before the main loop.
 //  - descale_qk: descale factor for the Q*K^T accumulator, folded into scale_s.
 //  - descale_sv: descale factor applied to the final output together with 1/l.
 //  - smem_ptr:   LDS workspace; the V tile starts at offset 0 and the K tile at
 //                Policy::GetSmemSizeQ<Problem>() bytes.
 //  - randval / lse / position_encoding / dropout: accepted but ignored ("not supported").
 // Returns o_acc scaled by descale_sv / l. When masking is enabled, rows with l == 0 yield 0,
 // and a q tile with no key tiles to visit returns the zero-initialized o_acc.
 // NOTE(review): the return-type line (source line 114) is missing from this listing.
107 template <typename QDramBlockWindowTmp,
108 typename KDramBlockWindowTmp,
109 typename VDramBlockWindowTmp,
110 typename BiasDramBlockWindowTmp,
111 typename RandValDramBlockWindowTmp,
112 typename LSEDramBlockWindowTmp,
113 typename PositionEncoding>
115 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
116 const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
117 const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
118 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
119 RandValDramBlockWindowTmp& /*randval_dram_block_window_tmp*/, // not supported
120 LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported
121 FmhaMask mask,
122 PositionEncoding /*position_encoding*/,
123 float scale_s,
124 float descale_qk,
125 float descale_sv,
126 void* smem_ptr,
127 BlockDropout& /*dropout*/) const // not supported
128 {
 // Window element types must match the problem's Q/K/V data types.
129 static_assert(
130 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
131 std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
132 std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
133 "wrong!");
134
 // Window lengths must match the block shape.
135 static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
136 kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
137 kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
138 kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
139 kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
140 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
141 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
142 "wrong!");
143
144 // K tile in LDS
145 KDataType* k_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
146 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQ<Problem>()));
148 k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
149 auto k_lds_window =
151
152 // V tile in LDS
154 reinterpret_cast<VDataType*>(smem_ptr),
155 Policy::template MakeVLdsBlockDescriptor<Problem>());
156 auto v_lds_window = make_tile_window(
157 v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
158
159 // Block GEMM
 // gemm_0 computes S = Q*K^T, gemm_1 computes O += P*V.
160 constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
161 constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
162
163 auto q_dram_window = make_tile_window(
164 q_dram_block_window_tmp.get_bottom_tensor_view(),
165 q_dram_block_window_tmp.get_window_lengths(),
166 q_dram_block_window_tmp.get_window_origin(),
167 Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());
168
 // Q is loaded once here and sliced along kK0 in every iteration of the loop below.
169 auto q = load_tile(q_dram_window);
170
171 using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
172 auto s_acc = SaccBlockTileType{};
173
174 // reduction function for softmax
175 const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
176 const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
177
178 // infer Sacc, S, P, M, L, Oacc type
179 using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
180
181 using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
182 SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
183
184 using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
185
186 // init Oacc, M, L
 // m: running row max; l: running row sum of exp(S - m); o_acc: unnormalized output.
 // NOTE(review): source line 192 (presumably the initialization of m) is missing here.
187 auto o_acc = OaccBlockTileType{};
188 auto m = MLBlockTileType{};
189 auto l = MLBlockTileType{};
190
191 clear_tile(o_acc);
193 clear_tile(l);
194
 // Range of key positions this query tile must visit, as decided by the mask.
195 const auto q_origin = q_dram_window.get_window_origin();
196 const auto [seqlen_k_start, seqlen_k_end] =
197 mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
198
199 const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
200
201 // check early exit if masked and no work to do.
202 if constexpr(FmhaMask::IsMasking)
203 {
204 if(num_total_loop <= 0)
205 {
206 // Note: o_acc is all cleared here, so return it
207 // Note: q loaded but no fence, ignore it.
208 return o_acc;
209 }
210 }
211
 // K, bias and V windows all start at seqlen_k_start and are advanced by kN0 per iteration.
212 auto k_dram_block_window =
213 make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
214 k_dram_block_window_tmp.get_window_lengths(),
215 {seqlen_k_start, 0});
216
217 const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
218 auto bias_dram_window = make_tile_window(
219 bias_dram_block_window_tmp.get_bottom_tensor_view(),
220 bias_dram_block_window_tmp.get_window_lengths(),
221 {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
222 Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
223
224 auto v_dram_window =
225 make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
226 v_dram_block_window_tmp.get_window_lengths(),
227 {0, seqlen_k_start}, // TODO: hdim split?
228 Policy::template MakeVDramTileDistribution<Problem>());
229
230 // auto q_tile = tile_elementwise_in(q_element_func, q);
231 auto q_tile = q;
232
233 // prefetch K tile
 // k0_loops: number of kK0 slices across the QK head dim (the QK loop needs at least two);
 // k1_loops: number of kK1 slices across kN0 for the P*V GEMM.
234 index_t i_total_loops = 0;
235 constexpr index_t k0_loops = kQKHeaddim / kK0;
236 constexpr index_t k1_loops = kN0 / kK1;
237
238 static_assert(2 <= k0_loops);
239 static_assert(1 <= k1_loops);
240
 // Fold the Q*K^T descale into the softmax scale so a single multiply applies both.
241 scale_s = scale_s * descale_qk;
242 do
243 {
244 // STAGE 1, QK gemm
245 auto k_dram_window = make_tile_window(
246 k_dram_block_window.get_bottom_tensor_view(),
247 k_dram_block_window.get_window_lengths(),
248 k_dram_block_window.get_window_origin(),
249 Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
250 // load
251
 // Prime the K pipeline: slice 0 is stored to LDS while slice 1 is fetched into k_block_tile.
252 auto k_block_tile = load_tile(k_dram_window);
253 {
254 move_tile_window(k_dram_window, {0, kK0});
255 clear_tile(s_acc); // initialize C
256 store_tile(k_lds_window, k_block_tile);
257 k_block_tile = load_tile(k_dram_window);
258 }
259
260 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
261 {
262 __builtin_amdgcn_sched_barrier(
263 0); // prevent from messing up the order of global loads
264 }
265 const auto bias_tile = load_tile(bias_dram_window); // load bias tile
266 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
267 {
268 __builtin_amdgcn_sched_barrier(
269 0); // prevent from messing up the order of global loads
270 }
271
 // Steady state: GEMM on K slice i (in LDS), write slice i + 1 to LDS, fetch slice i + 2.
 // NOTE(review): source lines 275/281 are missing here (possibly LDS barriers, since the
 // doxygen index references block_sync_lds) — confirm against the real header.
272 if constexpr(k0_loops > 2)
273 {
274 static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
276 gemm_0(s_acc,
277 get_slice_tile(q_tile,
278 sequence<0, i_k0 * kK0>{},
279 sequence<kM0, (i_k0 + 1) * kK0>{}),
280 k_lds_window);
282 move_tile_window(k_dram_window, {0, kK0});
283
284 store_tile(k_lds_window,
285 k_block_tile); // LDS write i + 1
286 k_block_tile = load_tile(k_dram_window); // global read i + 2
287 });
288 }
289
 // Issue the first V load early so it overlaps with the tail of the QK GEMM.
290 const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
 // Tail: GEMM on the second-to-last K slice, store the last slice, then GEMM on it.
291 { // tail
293 gemm_0(s_acc,
294 get_slice_tile(q_tile,
295 sequence<0, (k0_loops - 2) * kK0>{},
296 sequence<kM0, (k0_loops - 1) * kK0>{}),
297 k_lds_window);
299
300 store_tile(k_lds_window, k_block_tile);
302
303 gemm_0(s_acc,
304 get_slice_tile(q_tile,
305 sequence<0, (k0_loops - 1) * kK0>{},
306 sequence<kM0, k0_loops * kK0>{}),
307 k_lds_window);
308 }
309
310 // STAGE 2, scale_s, add bias, mask, softmax
 // With elementwise bias, s_acc is scaled here and the bias added (multiplied by log2e_v on
 // the fast-exp2 path so exp2() can be used later). Without bias, s_acc is scaled here only
 // when CK_TILE_FMHA_FWD_FAST_EXP2 is off; otherwise scale_s is applied inside exp2() below.
 // Presumably the caller folds log2(e) into scale_s on the fast-exp2 path — confirm.
 // NOTE(review): source line 313 (the call head taking the lambda below) is missing here.
311 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
312 {
314 [&](auto& x, const auto& y) {
315#if !CK_TILE_FMHA_FWD_FAST_EXP2
316 x = scale_s * x + type_convert<SaccDataType>((y));
317#else
318 x = scale_s * x + log2e_v<SaccDataType> * type_convert<SaccDataType>((y));
319#endif
320 },
321 s_acc,
322 bias_tile);
323 }
324 else
325 {
326#if !CK_TILE_FMHA_FWD_FAST_EXP2
327 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
328#endif
329 }
330 move_tile_window(bias_dram_window, {0, kN0});
 // Element-wise masking is done only for tiles that mask.IsEdgeTile reports as edge tiles;
 // elements for which mask.IsOutOfBound(row, col) holds are set to -inf.
331 if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
332 {
333 const auto k_origin = k_dram_block_window.get_window_origin();
334 bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
335 k_origin.at(number<0>{}),
336 number<kM0>{},
337 number<kN0>{});
338 if(need_perpixel_check)
339 {
341 s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
342 const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
343 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
344 return mask.IsOutOfBound(row, col);
345 });
346 }
347 }
348
349 const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
350 auto m_local = block_tile_reduce<SMPLComputeDataType>(
351 s,
352 sequence<1>{},
353 f_max,
354 -numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
355 block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
356
357 const auto m_old = m; // m{j-1}
 // NOTE(review): source line 358 (the call head taking this lambda) is missing here.
359 [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
360
361 auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
362 s.get_tile_distribution()); // Pcompute{j}
363
 // When bias or masking can make an entire row -inf, its max is -inf; substituting 0 makes
 // exp(s - m) evaluate exp(-inf) = 0 instead of NaN from (-inf) - (-inf).
364 static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
367 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
368 FmhaMask::IsMasking)
369 {
370 return raw_m == -numeric<SMPLComputeDataType>::infinity()
371 ? type_convert<SMPLComputeDataType>(0.f)
372 : raw_m;
373 }
374 else
375 {
376 return raw_m;
377 }
378 };
379
 // P = exp(S - m) per element. On the fast-exp2 path without bias, scale_s is applied
 // here to both S and the row max, because s_acc was left unscaled above.
380 constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
381 sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
382 constexpr auto i_idx = make_tuple(idx0);
383#if CK_TILE_FMHA_FWD_FAST_EXP2
384 auto row_max = scale_s * get_validated_m(m[i_idx]);
385#endif
386 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
387 constexpr auto i_j_idx = make_tuple(idx0, idx1);
388#if CK_TILE_FMHA_FWD_FAST_EXP2
389 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
390 {
391 p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
392 }
393 else
394 {
395 p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
396 }
397#else
398 p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
399#endif
400 });
401 });
402
403 auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
404 p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
405
406 block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
407 // l{j}, Oacc{j}
 // Online-softmax update: tmp = exp(m_old - m_new) rescales the running sum l and the
 // accumulated o_acc before this iteration's P*V contribution is added.
408 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
409 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
410 constexpr auto i_idx = make_tuple(idx0);
411#if CK_TILE_FMHA_FWD_FAST_EXP2
412 const auto tmp = [&]() {
413 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
414 {
415 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
416 }
417 else
418 {
419 auto row_max = scale_s * get_validated_m(m[i_idx]);
420 return exp2(scale_s * m_old[i_idx] - row_max);
421 }
422 }();
423#else
424 const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
425#endif
426 l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
427 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
428 constexpr auto i_j_idx = make_tuple(idx0, idx1);
429 // FIXME: this use different equation from FA v2 paper,
430 // but produce correct result.
431 // Is the equation wrong?
432 o_acc(i_j_idx) *= tmp;
433 });
434 });
435
 // Store the prefetched V slice to LDS (shuffled first for RowMajor V), then advance by kK1.
437 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
438 {
439 auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
440 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
441 shuffle_tile(v_shuffle_tmp, v_prefetch);
442 store_tile(v_lds_window,
443 v_shuffle_tmp); // store the prefetch
444 }
445 else
446 {
447 store_tile(v_lds_window,
448 v_prefetch); // store the prefetch
449 }
450 move_tile_window(v_dram_window, {0, kK1});
451
 // Convert P to the GEMM input type (PDataType) for the P*V GEMM.
452 const auto p = cast_tile<PDataType>(p_compute);
453
454 // STAGE 3, KV gemm
 // For each kK1 slice except the last: fetch the next V slice from DRAM, GEMM on the
 // slice already in LDS, then store the fetched slice to LDS.
455 if constexpr(k1_loops > 1)
456 {
457 static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
458 const auto v = load_tile(v_dram_window); // load next v
460 gemm_1(o_acc,
462 p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
463 v_lds_window);
465 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
466 {
467 auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
468 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
469 shuffle_tile(v_shuffle_tmp, v);
470 store_tile(v_lds_window, v_shuffle_tmp);
471 }
472 else
473 {
474 store_tile(v_lds_window, v);
475 }
476 move_tile_window(v_dram_window, {0, kK1});
477 });
478 }
479 // move K tile windows
480 move_tile_window(k_dram_block_window, {kN0, 0});
481 // tail
 // GEMM on the last kK1 slice of P with the V slice already in LDS.
482 {
484 gemm_1(o_acc,
485 get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
486 v_lds_window);
488 }
489 } while(++i_total_loops < num_total_loop);
490
491 // finally, O
 // Normalize: o = o_acc * descale_sv / l. With masking, rows with l == 0 (no unmasked
 // keys) produce 0 instead of dividing by zero.
492 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
493
494 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
495 constexpr auto i_idx = make_tuple(idx0);
496 auto tmp = [&]() {
497 if constexpr(FmhaMask::IsMasking)
498 {
499 return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
500 }
501 else
502 return 1 / l[i_idx];
503 }();
504 tmp = tmp * descale_sv;
505 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
506 constexpr auto i_j_idx = make_tuple(idx0, idx1);
507 o_acc(i_j_idx) *= tmp;
508 });
509 });
510
511 return o_acc;
512 }
513};
514
515} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
__host__ T exp(T x)
Definition math_v2.hpp:391
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition block_reduce.hpp:21
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition shuffle_tile.hpp:154
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
Definition block_dropout.hpp:53
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:16
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:102
static constexpr bool kHasDropout
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:53
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:29
static constexpr index_t kBlockPerCu
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:73
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:24
static constexpr index_t kM0
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:39
static constexpr index_t kBlockSize
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:37
remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:23
static constexpr index_t kAlignmentO
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:68
static constexpr auto BiasEnum
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:51
static constexpr index_t kN0
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:40
remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:21
static constexpr index_t kAlignmentK
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:59
static constexpr const char * name
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:100
static constexpr index_t kK0
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:41
remove_cvref_t< Problem_ > Problem
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:17
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:26
static constexpr index_t kQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:44
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:32
remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:28
static constexpr index_t kAlignmentQ
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:57
static constexpr index_t kK1
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:43
remove_cvref_t< Policy_ > Policy
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:18
static constexpr bool kPadHeadDimQ
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:49
static constexpr index_t kAlignmentV
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:61
remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:19
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, RandValDramBlockWindowTmp &, LSEDramBlockWindowTmp &, FmhaMask mask, PositionEncoding, float scale_s, float descale_qk, float descale_sv, void *smem_ptr, BlockDropout &) const
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:115
static constexpr bool kIsGroupMode
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:46
remove_cvref_t< typename Problem::PDataType > PDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:27
remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:22
static constexpr index_t kAlignmentBias
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:70
static constexpr bool kQLoadOnce
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:34
remove_cvref_t< typename BlockFmhaShape::VLayout > VLayout
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:33
static constexpr bool kPadSeqLenK
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:48
static constexpr bool kStoreLSE
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:52
static constexpr bool kPadHeadDimV
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:50
static constexpr bool kPadSeqLenQ
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:47
remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:20
static constexpr index_t kN1
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:42
remove_cvref_t< typename Problem::RandValOutputDataType > RandValOutputDataType
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:25
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_pipeline_qr_ks_vs_fp8.hpp:30
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49