block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp Source File

block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp Source File

Composable Kernel: block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp Source File
block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
11
12namespace ck_tile {
13
14template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy>
16{
// Compile-time configuration of the pipeline, derived from Problem, Policy and
// BlockFmhaShape. (The type-alias source lines 17-31 and 33-34 are absent from
// this listing.)
32
// This pipeline loads Q once per block; the policy must be configured the same way.
35 static constexpr bool kQLoadOnce = true;
36 static_assert(kQLoadOnce == Policy::QLoadOnce);
37
38 static constexpr index_t kBlockSize = Problem::kBlockSize;
39
40 static constexpr index_t kM0 = BlockFmhaShape::kM0;
41 static constexpr index_t kN0 = BlockFmhaShape::kN0;
42 static constexpr index_t kK0 = BlockFmhaShape::kK0;
43 static constexpr index_t kN1 = BlockFmhaShape::kN1;
44 static constexpr index_t kK1 = BlockFmhaShape::kK1;
45 static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
46 static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
47
48 static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
49
50 static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
51 static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
52 static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
53 static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
// V head-dim padding is forced on whenever kQKHeaddim < kSubQKHeaddim;
// otherwise the Problem's setting is used.
54 static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::kPadHeadDimV;
55 static constexpr auto BiasEnum = Problem::BiasEnum;
56 static constexpr bool kStoreLSE = Problem::kStoreLSE;
57 static constexpr bool kHasDropout = Problem::kHasDropout;
58 static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
59
60 // last dimension vector length used to create tensor view(and decide buffer_load vector length)
61 // ... together with tensor distribution. tensor dist should able to overwrite this
// Padding along a dimension drops the vector alignment on that dimension to 1.
62 static constexpr index_t kAlignmentQ =
63 kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
64 static constexpr index_t kAlignmentK =
65 kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
// NOTE(review): the RowMajor branch tests Problem::kPadHeadDimV, not the derived
// kPadHeadDimV above, so the kQKHeaddim < kSubQKHeaddim override does not lower
// the V alignment (while kAlignmentO does use the derived flag). Confirm this
// asymmetry is intended.
66 static constexpr index_t kAlignmentV = []() {
67 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
68 return Problem::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
69 else
70 return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
71 }();
72
73 static constexpr index_t kAlignmentO =
74 kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
75 static constexpr index_t kAlignmentBias =
76 kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
77
// Blocks per CU: the Problem's value if it is not -1, otherwise a heuristic keyed
// on kQKHeaddim. The condition inside the 96/128 branch (source line 93) is
// absent from this listing.
78 static constexpr index_t kBlockPerCu = []() {
79 if constexpr(Problem::kBlockPerCu != -1)
80 return Problem::kBlockPerCu;
81 else
82 {
83 if constexpr(kQKHeaddim == 32)
84 {
85 return 2;
86 }
87 else if constexpr(kQKHeaddim == 64)
88 {
89 return 2;
90 }
91 else if constexpr(kQKHeaddim == 96 || kQKHeaddim == 128)
92 {
94 return 1;
95 else
96 return 2;
97 }
98 else if constexpr(kQKHeaddim == 256)
99 {
100 return 1;
101 }
102 else
103 {
104 return 1;
105 };
106 }
107 }();
108
// NOTE(review): the label reads "qr_async" although this is the whole-K-prefetch
// pipeline; confirm whether the name is a copy-paste leftover.
109 static constexpr const char* name = "qr_async";
110
// When kHasDropout is false, dropout compiles down to NullBlockDropout.
111 using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
112
114 {
// Shared-memory (LDS) byte count for this pipeline, delegated to
// Policy::GetSmemSize; presumably the size callers must provide via smem_ptr.
115 return Policy::template GetSmemSize<Problem>();
116 }
117
// Main pipeline entry point.
// For one kM0-row block of Q it computes O = softmax(S) * V, where
// S = scale_s * Q * K^T (plus elementwise bias or ALiBi, plus masking).
// It iterates over kN0-wide K/V tiles and keeps a running row max (m) and
// running row sum (l), so the softmax is accumulated online.
// Optionally stores LSE; returns the normalized o_acc tile.
// NOTE(review): this rendered listing omits several source lines (e.g. the
// return-type line 135 and many statement heads); the comments below describe
// only the visible code.
118 template <typename QDramBlockWindowTmp,
119 typename KDramBlockWindowTmp,
120 typename VDramBlockWindowTmp,
121 typename BiasDramBlockWindowTmp,
122 typename RandValDramBlockWindowTmp,
123 typename LSEDramBlockWindowTmp,
124 typename QElementFunction,
125 typename KElementFunction,
126 typename VElementFunction,
127 typename BiasElementFunction,
128 typename LSEElementFunction,
129 typename SAccElementFunction,
130 typename PComputeElementFunction,
131 typename OAccElementFunction,
132 typename PositionEncoding,
133 typename AttentionVariantParams,
134 typename BlockIndices>
136 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
137 const QElementFunction& q_element_func,
138 const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
139 const KElementFunction& k_element_func,
140 const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
141 const VElementFunction& v_element_func,
142 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
143 const BiasElementFunction& bias_element_func,
144 RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
145 LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
146 const LSEElementFunction& lse_element_func,
147 const SAccElementFunction& s_acc_element_func,
148 const PComputeElementFunction& p_compute_element_func,
149 const OAccElementFunction& o_acc_element_func,
150 FmhaMask mask,
151 PositionEncoding position_encoding,
152 float scale_s,
153 const AttentionVariant& /* unused */,
154 const AttentionVariantParams& /* unused */,
155 const BlockIndices& /* unused */,
156 void* smem_ptr,
157 DropoutType& dropout) const
158 {
159 ignore = q_element_func;
160 ignore = k_element_func;
161
// Compile-time checks: window data types and window lengths must match the tile shape.
162 static_assert(
163 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
164 std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
165 std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
166 "wrong!");
167
168 static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
169 kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
170 kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
171 kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
172 kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
173 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
174 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
175 "wrong!");
176
177 constexpr auto I0 = number<0>{};
178 constexpr auto I1 = number<1>{};
179
// k0_loops: number of kK0-wide slices along the QK head dim (gemm_0 unroll).
// k1_loops: number of kK1-wide slices along kN0 (gemm_1 unroll).
180 constexpr index_t k0_loops = kQKHeaddim / kK0;
181 constexpr index_t k1_loops = kN0 / kK1;
182 static_assert(2 <= k0_loops);
183 static_assert(2 <= k1_loops);
184
185 constexpr bool kPreloadWholeNextIterationK =
186 Policy::template IsPreloadWholeNextIterationK<Problem>();
187
188 constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers<Problem>();
189 constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>();
190 constexpr auto NumPrefetchV = Policy::template GetNumPrefetchV<Problem>();
191
192 static_assert(NumKLdsBuffers >= 2);
193
194 auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
195 q_dram_block_window_tmp.get_window_lengths(),
196 q_dram_block_window_tmp.get_window_origin(),
197 Policy::template MakeQRegTileDistribution<Problem>());
198
// Ask the mask for the K column range [seqlen_k_start, seqlen_k_end) this Q block must visit.
199 const auto q_origin = q_dram_window.get_window_origin();
200 const auto [seqlen_k_start, seqlen_k_end] =
201 mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
202
203 auto k_dram_block_window =
204 make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
205 k_dram_block_window_tmp.get_window_lengths(),
206 {seqlen_k_start, 0});
207
208 auto k_dram_window =
209 make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
210 k_dram_block_window.get_window_lengths(),
211 k_dram_block_window.get_window_origin(),
212 Policy::template MakeKDramTileDistribution<Problem>());
213
214 using k_tile_type = decltype(load_tile(k_dram_window));
215
// Register staging for K slices. The branch bodies (source lines 218 and 220)
// are absent from this listing. Presumably k_tiles holds k0_loops slices when the
// whole next-iteration K is preloaded, and a single slice otherwise.
216 auto k_tiles = [&]() {
217 if constexpr(kPreloadWholeNextIterationK)
219 else
221 }();
222
// Issue the first K slice load, then the Q load.
223 k_tiles[I0] = load_tile(k_dram_window);
224 move_tile_window(k_dram_window, {0, kK0});
225
226 auto q_tile = load_tile(q_dram_window);
227
228 __builtin_amdgcn_sched_barrier(0);
229
230 // K tile in LDS
231 KDataType* k_lds_ptr = static_cast<KDataType*>(smem_ptr);
233 k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
234 auto k_lds_window = make_tile_window(
235 k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
236
237 using k_lds_window_type =
238 decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{}));
239
241
// Split the K LDS window into NumKLdsBuffers sub-windows of kN0 x kK0 each.
242 static_for<0, NumKLdsBuffers, 1>{}([&](auto i_buf) {
243 k_lds_windows[i_buf] = get_slice_tile(
244 k_lds_window, sequence<i_buf * kN0, 0>{}, sequence<(i_buf + 1) * kN0, kK0>{});
245 });
246
247 auto v_dram_window =
248 make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
249 v_dram_block_window_tmp.get_window_lengths(),
250 {0, seqlen_k_start}, // TODO: hdim split?
251 Policy::template MakeVDramTileDistribution<Problem>());
252 // V tile in LDS
254 reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
255 Policy::template GetExclusiveKLdsBytes<Problem>()),
256 Policy::template MakeVLdsBlockDescriptor<Problem>());
257 auto v_lds_window = make_tile_window(
258 v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
259
260 using v_tile_type = decltype(load_tile(v_dram_window));
261
263
264 using v_lds_window_type =
265 decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{}));
266
268
// Split the V LDS window into NumVLdsBuffers sub-windows of kN1 x kK1 each.
269 static_for<0, NumVLdsBuffers, 1>{}([&](auto i_buf) {
270 v_lds_windows[i_buf] = get_slice_tile(
271 v_lds_window, sequence<i_buf * kN1, 0>{}, sequence<(i_buf + 1) * kN1, kK1>{});
272 });
273
274 // Block GEMM
275 constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
276 constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
277
278 using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
279 auto s_acc = SaccBlockTileType{};
280
281 // reduction function for softmax
282 const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
283 const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
284
285 // infer Sacc, S, P, M, L, Oacc type
286 using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
287
288 using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
289 SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
290
291 using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
292
293 // init Oacc, M, L
294 auto o_acc = OaccBlockTileType{};
295 auto m = MLBlockTileType{};
296 auto l = MLBlockTileType{};
297
298 clear_tile(o_acc);
300 clear_tile(l);
301
// Number of kN0-wide K/V tiles this Q block iterates over.
302 const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
303
304 // check early exit if no work to do
305 if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
306 {
307 if(num_total_loop <= 0)
308 {
309 if constexpr(kStoreLSE)
310 {
311 auto lse =
312 make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
313
315
316 store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
317 }
318
319 // Note: here occ are all cleard, return it
320 // Note: q loaded but no fence, ignore it.
321 return o_acc;
322 }
323 }
324
325 const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
326 auto bias_dram_window =
327 make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
328 bias_dram_block_window_tmp.get_window_lengths(),
329 {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
330 Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
331
332 auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
333 randval_dram_block_window_tmp, seqlen_k_start);
334
335 q_tile = tile_elementwise_in(q_element_func, q_tile);
336
337 index_t i_total_loops = 0;
338
339 do
340 {
// STAGE 1, QK gemm: s_acc = Q * K^T over k0_loops head-dim slices.
// K loads (DRAM -> registers -> LDS) and the first V prefetch are interleaved
// with the gemm.
341 if constexpr(kPreloadWholeNextIterationK)
342 {
343 if(i_total_loops == 0) // executed by fist iteration
344 {
345 if(num_total_loop > 1) // there are multiple iterations
346 {
347 static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
349 k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],
350 tile_elementwise_in(k_element_func, k_tiles[number<i_k0>{}]));
351
352 k_tiles[number<i_k0 + 1>{}] = load_tile(k_dram_window);
353 if constexpr(i_k0 < k0_loops - 2)
354 move_tile_window(k_dram_window, {0, kK0});
355
356 if constexpr(i_k0 == 0)
357 clear_tile(s_acc);
358
360 // execute current unroll of gemm_0
361 gemm_0(s_acc,
362 get_slice_tile(q_tile,
363 sequence<0, i_k0 * kK0>{},
364 sequence<kM0, (i_k0 + 1) * kK0>{}),
365 k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
366 });
367
369 k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}],
370 tile_elementwise_in(k_element_func, k_tiles[number<k0_loops - 1>{}]));
371
372 // prefetch first v_tile
373 v_tiles[I0] = load_tile(v_dram_window);
374 move_tile_window(v_dram_window, {0, kK1});
375
// Advance to the next kN0 tile and rewind the head-dim offset to 0.
376 move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0});
377
378 // prefetch all k_tiles for next iteration
379 static_for<0, k0_loops, 1>{}([&](auto i_k0) {
380 k_tiles[number<i_k0>{}] = load_tile(k_dram_window);
381
382 if constexpr(i_k0 < k0_loops - 1)
383 move_tile_window(k_dram_window, {0, kK0});
384 });
385
// Return the window to head-dim offset 0 of the tile just prefetched.
386 move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0});
387
389 // execute last unroll of gemm_0
390 gemm_0(s_acc,
391 get_slice_tile(q_tile,
392 sequence<0, (k0_loops - 1) * kK0>{},
393 sequence<kM0, k0_loops * kK0>{}),
394 k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
395 }
396 else // there is only single iteration
397 {
398 static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
400 k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],
401 tile_elementwise_in(k_element_func, k_tiles[number<i_k0>{}]));
402
403 k_tiles[number<i_k0 + 1>{}] = load_tile(k_dram_window);
404 if constexpr(i_k0 < k0_loops - 2)
405 move_tile_window(k_dram_window, {0, kK0});
406
407 if constexpr(i_k0 == 0)
408 clear_tile(s_acc);
409
411 // execute current unroll of gemm_0
412 gemm_0(s_acc,
413 get_slice_tile(q_tile,
414 sequence<0, i_k0 * kK0>{},
415 sequence<kM0, (i_k0 + 1) * kK0>{}),
416 k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
417 });
418
420 k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}],
421 tile_elementwise_in(k_element_func, k_tiles[number<k0_loops - 1>{}]));
422
423 // prefetch first v_tile
424 v_tiles[I0] = load_tile(v_dram_window);
425 move_tile_window(v_dram_window, {0, kK1});
426
428 gemm_0(s_acc,
429 get_slice_tile(q_tile,
430 sequence<0, (k0_loops - 1) * kK0>{},
431 sequence<kM0, k0_loops * kK0>{}),
432 k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
433
434 // move_tile_window(k_dram_window, {0, -k0_loops * kK0});
435 }
436 }
437 else // executed by intermediate and last iteration
438 {
439 if(i_total_loops < num_total_loop - 1) // intermediate iteration
440 {
441 store_tile(k_lds_windows[I0],
442 tile_elementwise_in(k_element_func, k_tiles[I0]));
443
444 // prefetch first v_tile
445 v_tiles[I0] = load_tile(v_dram_window);
446 move_tile_window(v_dram_window, {0, kK1});
447
448 clear_tile(s_acc);
450 gemm_0(s_acc,
451 get_slice_tile(q_tile, sequence<0, 0>{}, sequence<kM0, kK0>{}),
452 k_lds_windows[I0]);
453
454 store_tile(k_lds_windows[I1],
455 tile_elementwise_in(k_element_func, k_tiles[I1]));
456
457 move_tile_window(k_dram_window, {kN0, 0});
458
459 // prefetch first k_tile for next iteration
460 k_tiles[I0] = load_tile(k_dram_window);
461 move_tile_window(k_dram_window, {0, kK0});
462
463 k_tiles[I1] = load_tile(k_dram_window);
464 if constexpr(1 < k0_loops - 1)
465 move_tile_window(k_dram_window, {0, kK0});
466
468 gemm_0(s_acc,
469 get_slice_tile(q_tile, sequence<0, kK0>{}, sequence<kM0, 2 * kK0>{}),
470 k_lds_windows[I1]);
471
472 // during the gemm-loop, also prefetch other k_tiles for next iteration
473 static_for<2, k0_loops, 1>{}([&](auto i_k0) {
475 k_tiles[number<i_k0>{}]);
476
477 k_tiles[number<i_k0>{}] = load_tile(k_dram_window);
478 if constexpr(i_k0 < k0_loops - 1)
479 move_tile_window(k_dram_window, {0, kK0});
480
482 gemm_0(s_acc,
483 get_slice_tile(q_tile,
484 sequence<0, i_k0 * kK0>{},
485 sequence<kM0, (i_k0 + 1) * kK0>{}),
486 k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
487 });
488
// Rewind the head-dim offset back to 0 for the next iteration.
489 move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0});
490 }
491 else // last iteration
492 {
493 store_tile(k_lds_windows[I0],
494 tile_elementwise_in(k_element_func, k_tiles[I0]));
495
496 // prefetch first v_tile
497 v_tiles[I0] = load_tile(v_dram_window);
498 move_tile_window(v_dram_window, {0, kK1});
499
500 clear_tile(s_acc);
502 gemm_0(s_acc,
503 get_slice_tile(q_tile, sequence<0, 0>{}, sequence<kM0, kK0>{}),
504 k_lds_windows[I0]);
505
506 static_for<1, k0_loops, 1>{}([&](auto i_k0) {
508 k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],
509 tile_elementwise_in(k_element_func, k_tiles[number<i_k0>{}]));
510
512 gemm_0(s_acc,
513 get_slice_tile(q_tile,
514 sequence<0, i_k0 * kK0>{},
515 sequence<kM0, (i_k0 + 1) * kK0>{}),
516 k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
517 });
518 };
519 };
520 }
521 else // only preload one unroll of K for next iteration
522 {
523 static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
525 tile_elementwise_in(k_element_func, k_tiles[I0]));
526 if constexpr(i_k0 == 0)
527 clear_tile(s_acc);
528
529 if constexpr(i_k0 < k0_loops - 1)
530 k_tiles[I0] = load_tile(k_dram_window);
531 if constexpr(i_k0 < k0_loops - 2)
532 move_tile_window(k_dram_window, {0, kK0});
533
535 // execute current unroll of gemm_0
536 gemm_0(s_acc,
537 get_slice_tile(q_tile,
538 sequence<0, i_k0 * kK0>{},
539 sequence<kM0, (i_k0 + 1) * kK0>{}),
540 k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
541 });
542
543 store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}],
544 tile_elementwise_in(k_element_func, k_tiles[I0]));
545
546 // prefetch first v_tile
547 v_tiles[I0] = load_tile(v_dram_window);
548 move_tile_window(v_dram_window, {0, kK1});
549
551 gemm_0(s_acc,
552 get_slice_tile(q_tile,
553 sequence<0, (k0_loops - 1) * kK0>{},
554 sequence<kM0, k0_loops * kK0>{}),
555 k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
556 };
557
558 __builtin_amdgcn_sched_barrier(0);
559
560 const auto bias_tile = load_tile(bias_dram_window); // load bias tile
561
// Prefetch the remaining V slices into v_tiles[1 .. NumPrefetchV-1].
562 static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) {
563 v_tiles[i_buf] = load_tile(v_dram_window);
564 move_tile_window(v_dram_window, {0, kK1});
565 });
566
567 // STAGE 2, scale_s, add bias, mask, softmax
569 {
570 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
571 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
573 [&](auto& x, const auto& y) {
574#if !CK_TILE_FMHA_FWD_FAST_EXP2
575 x += type_convert<SaccDataType>(bias_element_func(y));
576#else
578 type_convert<SaccDataType>(bias_element_func(y));
579#endif
580 },
581 s_acc,
582 bias_tile);
583 }
584 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
585 {
586 const auto k_origin = k_dram_block_window.get_window_origin();
587 constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
588 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
589 sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
590 sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
591 const auto tile_idx = get_x_indices_from_distributed_indices(
592 s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
593
594 const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
595 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
596 constexpr auto i_j_idx = make_tuple(idx0, idx1);
597
598 s_acc(i_j_idx) *= scale_s;
599 position_encoding.update(s_acc(i_j_idx), row, col);
600 });
601 });
602 }
603 else
604 {
605 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
606#if !CK_TILE_FMHA_FWD_FAST_EXP2
607 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
608#endif
609 }
610 move_tile_window(bias_dram_window, {0, kN0});
// On edge tiles only, overwrite out-of-bound (masked) scores with -inf.
611 if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
612 {
613 const auto k_origin = k_dram_block_window.get_window_origin();
614 bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
615 k_origin.at(number<0>{}),
616 number<kM0>{},
617 number<kN0>{});
618 if(need_perpixel_check)
619 {
621 s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
622 const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
623 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
624 return mask.IsOutOfBound(row, col);
625 });
626 }
627 }
628
629 const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
631 s,
632 sequence<1>{},
633 f_max,
634 -numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
636
637 const auto m_old = m; // m{j-1}
639 [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
640
642 s.get_tile_distribution()); // Pcompute{j}
643
// The condition lines of this helper (source 645-647 and 650-651) are absent
// from this listing. Under masking it presumably replaces a row max of -inf
// (a fully masked row) with 0, so the exp/exp2 below stays finite.
644 static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
648 FmhaMask::IsMasking)
649 {
652 : raw_m;
653 }
654 else
655 {
656 return raw_m;
657 }
658 };
659
// P = exp(S - m) (or exp2 with scale_s folded in when FAST_EXP2 is enabled).
660 constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
661 sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
662 constexpr auto i_idx = make_tuple(idx0);
663#if CK_TILE_FMHA_FWD_FAST_EXP2
664 auto row_max = scale_s * get_validated_m(m[i_idx]);
665#endif
666 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
667 constexpr auto i_j_idx = make_tuple(idx0, idx1);
668#if CK_TILE_FMHA_FWD_FAST_EXP2
671 {
672 p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
673 }
674 else
675 {
676 p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
677 }
678#else
679 p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
680#endif
681 });
682 });
683
685 p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
686
// Online-softmax rescale: l = exp(m_old - m) * l + rowsum(P), and o_acc is
// scaled by the same factor.
688 // l{j}, Oacc{j}
689 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
690 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
691 constexpr auto i_idx = make_tuple(idx0);
692#if CK_TILE_FMHA_FWD_FAST_EXP2
693 const auto tmp = [&]() {
696 {
697 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
698 }
699 else
700 {
701 auto row_max = scale_s * get_validated_m(m[i_idx]);
702 return exp2(scale_s * m_old[i_idx] - row_max);
703 }
704 }();
705#else
706 const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
707#endif
708 l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
709 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
710 constexpr auto i_j_idx = make_tuple(idx0, idx1);
711 // FIXME: this use different equation from FA v2 paper,
712 // but produce correc result.
713 // Is the equation wrong?
714 o_acc(i_j_idx) *= tmp;
715 });
716 });
717
// NOTE(review): randval_ptr is computed (smem_ptr + GetSmemSizeK) but never
// used. dropout.Run receives smem_ptr itself, i.e. offset 0, where the K LDS
// buffers live. Confirm whether randval_ptr was intended as the scratch pointer.
718 if constexpr(kHasDropout)
719 {
720 auto randval_ptr =
721 reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>();
722 dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
723 smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
724 }
725
726 __builtin_amdgcn_sched_barrier(0x7f);
727
// Store the first prefetched V slice into V LDS buffer 0 (shuffled first for RowMajor V).
728 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
729 {
731 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
732 shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
733
735 v_lds_windows[I0],
736 tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
737 }
738 else
739 {
740 store_tile(v_lds_windows[I0],
741 tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch
742 }
743
744 __builtin_amdgcn_sched_barrier(0);
745
746 const auto p =
747 cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
748
// Single-slice K preload mode: start loading the first K slice of the next tile here.
749 if constexpr(!kPreloadWholeNextIterationK)
750 {
751 if(i_total_loops < num_total_loop - 1)
752 {
753 move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0});
754 k_tiles[I0] = load_tile(k_dram_window);
755 move_tile_window(k_dram_window, {0, kK0});
756 };
757
758 __builtin_amdgcn_sched_barrier(0);
759 }
760
761 // STAGE 3, KV gemm
762 if constexpr(k1_loops > 1)
763 {
// Each step loads the next V slice, runs gemm_1 on the current V LDS buffer,
// then stores the next slice into buffer (i_k1 + 1) % NumVLdsBuffers.
764 if constexpr(NumPrefetchV == 1) // NumVLdsBuffers == 2
765 {
766 static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
767 v_tiles[I0] = load_tile(v_dram_window);
768
770 gemm_1(o_acc,
772 p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
773 v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);
774
775 if constexpr(std::is_same_v<VLayout,
776 ck_tile::tensor_layout::gemm::RowMajor>)
777 {
779 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
780 shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
781 store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
782 tile_elementwise_in(v_element_func, v_shuffle_tmp));
783 }
784 else
785 {
786 store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
787 tile_elementwise_in(v_element_func, v_tiles[I0]));
788 }
789
790 move_tile_window(v_dram_window, {0, kK1});
791 });
792 }
793 else // NumVLdsBuffers == 3 or 2
794 {
// Loads are issued only while slices remain beyond those already prefetched
// into v_tiles.
795 static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
796 if constexpr(i_k1 < k1_loops - NumPrefetchV)
797 v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
798
800 gemm_1(o_acc,
802 p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
803 v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);
804
805 if constexpr(std::is_same_v<VLayout,
806 ck_tile::tensor_layout::gemm::RowMajor>)
807 {
809 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
810 shuffle_tile(v_shuffle_tmp,
811 v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]);
812 store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
813 tile_elementwise_in(v_element_func, v_shuffle_tmp));
814 }
815 else
816 {
818 v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
819 tile_elementwise_in(v_element_func,
820 v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]));
821 }
822
823 if constexpr(i_k1 < k1_loops - NumPrefetchV)
824 move_tile_window(v_dram_window, {0, kK1});
825 });
826 }
827 }
828 // move K tile windows
829 move_tile_window(k_dram_block_window, {kN0, 0});
830
// Last gemm_1 unroll, on the final kK1 slice of P.
832 gemm_1(o_acc,
833 get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
834 v_lds_windows[number<(k1_loops - 1) % NumVLdsBuffers>{}]);
835
// When the policy overlaps the first K LDS buffer with the last V LDS buffer,
// a workgroup barrier follows the final gemm_1. Presumably this keeps the next
// iteration's K store from overwriting V data still being read.
836 if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer<Problem>())
837 {
838 __builtin_amdgcn_sched_barrier(0);
839 __builtin_amdgcn_s_barrier();
840 };
841
842 } while(++i_total_loops < num_total_loop);
843
844 // store lse
845 if constexpr(kStoreLSE)
846 {
847 auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
848
849 constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
850 sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
851 constexpr auto i_idx = make_tuple(idx0);
852#if CK_TILE_FMHA_FWD_FAST_EXP2
855 {
856 lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
857 }
858 else
859 {
860 lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
861 }
862#else
863 lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
864#endif
865 });
866
867 store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
868 }
869
870 // finally, O
// Normalize each row by the running sum l. With masking enabled, rows whose
// l == 0 produce 0 instead of dividing by zero.
871 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
872
873 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
874 constexpr auto i_idx = make_tuple(idx0);
875 const auto tmp = [&]() {
876 if constexpr(FmhaMask::IsMasking)
877 {
878 return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
879 }
880 else
881 return 1 / l[i_idx];
882 }();
883 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
884 constexpr auto i_j_idx = make_tuple(idx0, idx1);
885 o_acc(i_j_idx) *= tmp;
886 });
887 });
888
889 o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
890
891 return o_acc;
892 }
893
// Convenience overload without element-wise functions.
// Forwards every argument to the full operator() and passes identity{} for the
// Q, K, V, bias, LSE, S-acc, P-compute and O-acc element functions.
// (The return-type source line 903 is absent from this listing.)
894 template <typename QDramBlockWindowTmp,
895 typename KDramBlockWindowTmp,
896 typename VDramBlockWindowTmp,
897 typename BiasDramBlockWindowTmp,
898 typename RandValDramBlockWindowTmp,
899 typename LSEDramBlockWindowTmp,
900 typename PositionEncoding,
901 typename AttentionVariantParams,
902 typename BlockIndices>
904 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
905 const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
906 const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
907 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
908 RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
909 LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
910 FmhaMask mask,
911 PositionEncoding position_encoding,
912 float scale_s,
913 const AttentionVariant& variant,
914 const AttentionVariantParams& variant_params,
915 const BlockIndices& block_indices,
916 void* smem_ptr,
917 DropoutType& dropout) const
918 {
919 return operator()(q_dram_block_window_tmp,
920 identity{},
921 k_dram_block_window_tmp,
922 identity{},
923 v_dram_block_window_tmp,
924 identity{},
925 bias_dram_block_window_tmp,
926 identity{},
927 randval_dram_block_window_tmp,
928 lse_dram_block_window_tmp,
929 identity{},
930 identity{},
931 identity{},
932 identity{},
933 mask,
934 position_encoding,
935 scale_s,
936 variant,
937 variant_params,
938 block_indices,
939 smem_ptr,
940 dropout);
941 }
942};
943
944} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition bfloat16.hpp:428
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition block_reduce.hpp:21
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition shuffle_tile.hpp:154
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
constexpr detail::ignore_t ignore
Definition tile/core/utility/ignore.hpp:20
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:16
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:26
static constexpr index_t kM0
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:40
static constexpr index_t kBlockSize
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:38
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowTmp &k_dram_block_window_tmp, const KElementFunction &k_element_func, const VDramBlockWindowTmp &v_dram_block_window_tmp, const VElementFunction &v_element_func, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const BiasElementFunction &bias_element_func, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &, const AttentionVariantParams &, const BlockIndices &, void *smem_ptr, DropoutType &dropout) const
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:136
remove_cvref_t< typename BlockFmhaShape::VLayout > VLayout
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:34
static constexpr bool kIsGroupMode
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:50
static constexpr index_t kN1
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:43
remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:19
static constexpr bool kPadHeadDimV
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:54
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:113
static constexpr index_t kAlignmentO
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:73
remove_cvref_t< typename Problem::AttentionVariant > AttentionVariant
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:31
static constexpr bool kPadHeadDimQ
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:53
static constexpr index_t kAlignmentQ
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:62
static constexpr bool kPadSeqLenK
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:52
static constexpr index_t kN0
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:41
static constexpr index_t kAlignmentBias
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:75
static constexpr bool kQLoadOnce
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:35
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:29
static constexpr index_t kK0
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:42
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:33
static constexpr index_t kK1
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:44
remove_cvref_t< typename Problem::PDataType > PDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:27
remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:21
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:24
static constexpr index_t kAlignmentV
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:66
static constexpr index_t kQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:45
static constexpr bool kHasDropout
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:57
static constexpr bool kStoreLSE
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:56
remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:22
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:30
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr, DropoutType &dropout) const
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:904
remove_cvref_t< typename Problem::RandValOutputDataType > RandValOutputDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:25
remove_cvref_t< Policy_ > Policy
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:18
static constexpr bool kHasLogitsSoftCap
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:58
static constexpr auto BiasEnum
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:55
static constexpr const char * name
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:109
remove_cvref_t< Problem_ > Problem
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:17
remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:20
static constexpr index_t kAlignmentK
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:64
static constexpr bool kPadSeqLenQ
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:51
static constexpr index_t kSubQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:46
remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:28
static constexpr index_t kBlockPerCu
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:78
remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:23
std::conditional_t< kHasDropout, BlockDropout, NullBlockDropout > DropoutType
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp:111
Definition tile/core/utility/functional.hpp:86
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
#define C_LOG2E
Definition tile/core/numeric/math.hpp:469