block_fmha_fwd_appendkv_pipeline.hpp Source File

block_fmha_fwd_appendkv_pipeline.hpp Source File

Composable Kernel: block_fmha_fwd_appendkv_pipeline.hpp Source File
block_fmha_fwd_appendkv_pipeline.hpp
Go to the documentation of this file.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"

10namespace ck_tile {
11
12template <typename Problem_, typename Policy_ = BlockFmhaFwdAppendKVPipelineDefaultPolicy>
14{
17 using QDataType = typename Problem::QDataType;
18 using KDataType = typename Problem::KDataType;
19 using VDataType = typename Problem::VDataType;
20
21 using VLayout = typename Problem::VLayout;
22
23 static constexpr index_t kBlockSize = Problem::kBlockSize;
24
25 static constexpr index_t kM0 = Problem::kM0;
26 static constexpr index_t kN0 = Problem::kN0;
27 static constexpr index_t kK0 = Problem::kK0;
28 static constexpr index_t kN1 = Problem::kN1;
29
30 static constexpr auto RotaryEnum = Problem::RotaryEnum;
31 static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
32
33 static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
34 static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
35 static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
36 static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
37
38 // last dimension vector length used to create tensor view(and decide buffer_load vector length)
39 // ... together with tensor distribution. tensor dist should able to overwrite this
40 static constexpr index_t kAlignmentQ =
41 kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
42 static constexpr index_t kAlignmentK =
43 kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
44 static constexpr index_t kAlignmentV = []() {
45 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
46 return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
47 else
48 return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
49 }();
50
51 static constexpr index_t kBlockPerCu = []() {
52 if constexpr(Problem::kBlockPerCu != -1)
53 return Problem::kBlockPerCu;
54 else
55 {
56 if constexpr(kK0 <= 32)
57 {
58 return 2;
59 }
60 else if constexpr(kK0 <= 64)
61 {
62 return 3;
63 }
64 else if constexpr(kK0 <= 128)
65 {
66 return 2;
67 }
68 else if constexpr(kK0 <= 256)
69 {
70 return 1;
71 }
72 }
73 }();
74
75 template <typename QDramBlockWindow,
76 typename KDramBlockWindow,
77 typename KPageBlockNavigator,
78 typename KnewDramBlockWindow,
79 typename VDramBlockWindow,
80 typename VPageBlockNavigator,
81 typename VnewDramBlockWindow,
82 typename QElementFunction,
83 typename KnewElementFunction,
84 typename VnewElementFunction,
85 typename QRotaryCosDramBlockWindow,
86 typename QRotarySinDramBlockWindow,
87 typename KnewRotaryCosDramBlockWindow,
88 typename KnewRotarySinDramBlockWindow>
90 operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile
91 const QElementFunction& q_element_func,
92 KDramBlockWindow& k_dram_block_window, // N0*K0 tile
93 index_t i_page_block_k,
94 const KPageBlockNavigator& k_page_block_navigator,
95 const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile
96 const KnewElementFunction& knew_element_func,
97 VDramBlockWindow& v_dram_block_window, // N1*N0 tile
98 index_t i_page_block_v,
99 const VPageBlockNavigator& v_page_block_navigator,
100 const VnewDramBlockWindow& vnew_dram_block_window, // N1*N0 tile
101 const VnewElementFunction& vnew_element_func,
102 const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window,
103 const QRotarySinDramBlockWindow q_rotary_sin_dram_block_window,
104 const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window,
105 const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window,
106 index_t rotary_dim,
107 bool skip_rotate_q,
108 bool skip_rotate_append_kv) const
109 {
110 if(!skip_rotate_append_kv)
111 {
112 // append Knew to K
113 auto knew_window = make_tile_window(
114 knew_dram_block_window, Policy::template MakeKnewDramTileDistribution<Problem>());
115
116 auto knew_tile = [&]() {
117 auto knew = load_tile(knew_window);
118 return tile_elementwise_in(knew_element_func, knew);
119 }();
120
121 // optionally apply rotary embedding to Knew
123 {
124 auto rotary_cos_window =
125 make_tile_window(knew_rotary_cos_dram_block_window,
126 Policy::template MakeRotaryCosSinTileDistribution<
127 Problem,
128 /*IsRotaryCosSinForQ=*/false>());
129
130 auto rotary_sin_window =
131 make_tile_window(knew_rotary_sin_dram_block_window,
132 Policy::template MakeRotaryCosSinTileDistribution<
133 Problem,
134 /*IsRotaryCosSinForQ=*/false>());
135
136 // We assume that each thread owns contiguous elements on head dimention. And we
137 // will use the distribution to enable/disable threads in order to override partial
138 // knew_tile content
139 auto [thread_start, thread_end] =
140 Policy::template GetKnewThreadRangeAlongK<Problem>();
141 ignore = thread_start;
142
144 knew_window,
145 rotary_cos_window,
146 rotary_sin_window,
147 rotary_dim,
148 thread_end);
149 }
150
151 store_tile(k_dram_block_window, knew_tile);
152
153 // write tile to another block if nesscary
154 if constexpr(kIsPagedKV)
155 {
156 if(k_page_block_navigator.is_cross_block(i_page_block_k, k_dram_block_window))
157 {
158 k_page_block_navigator.move_to_block(
159 i_page_block_k, k_dram_block_window, i_page_block_k + 1);
160 store_tile(k_dram_block_window, knew_tile);
161 }
162 }
163
164 // append Vnew to V
165 auto vnew_window = make_tile_window(
166 vnew_dram_block_window, Policy::template MakeVnewDramTileDistribution<Problem>());
167
168 auto vnew_tile = [&]() {
169 auto vnew = load_tile(vnew_window);
170 return tile_elementwise_in(vnew_element_func, vnew);
171 }();
172
173 store_tile(v_dram_block_window, vnew_tile);
174
175 // write tile to another block if nesscary
176 if constexpr(kIsPagedKV)
177 {
178 if(v_page_block_navigator.is_cross_block(i_page_block_v, v_dram_block_window))
179 {
180 v_page_block_navigator.move_to_block(
181 i_page_block_v, v_dram_block_window, i_page_block_v + 1);
182 store_tile(v_dram_block_window, vnew_tile);
183 }
184 }
185 }
186
187 if(!skip_rotate_q)
188 {
189 // optionally apply rotary embedding to Q
191 {
192 auto q_window = make_tile_window(
193 q_dram_block_window, Policy::template MakeQDramTileDistribution<Problem>());
194
195 auto q_tile = [&]() {
196 auto q = load_tile(q_window);
197 return tile_elementwise_in(q_element_func, q);
198 }();
199
200 auto rotary_cos_window =
201 make_tile_window(q_rotary_cos_dram_block_window,
202 Policy::template MakeRotaryCosSinTileDistribution<
203 Problem,
204 /*IsRotaryCosSinForQ=*/true>());
205
206 auto rotary_sin_window =
207 make_tile_window(q_rotary_sin_dram_block_window,
208 Policy::template MakeRotaryCosSinTileDistribution<
209 Problem,
210 /*IsRotaryCosSinForQ=*/true>());
211
212 // We assume that each thread owns contiguous elements on head dimention. And we
213 // will use the distribution to enable/disable threads in order to override partial
214 // q_tile content
215 auto [thread_start, thread_end] = Policy::template GetQThreadRangeAlongK<Problem>();
216 ignore = thread_start;
217
219 q_tile, q_window, rotary_cos_window, rotary_sin_window, rotary_dim, thread_end);
220
221 store_tile(q_dram_block_window, q_tile);
222 }
223 }
224 }
225
226 template <typename QDramBlockWindow,
227 typename KDramBlockWindow,
228 typename KPageBlockNavigator,
229 typename KnewDramBlockWindow,
230 typename VDramBlockWindow,
231 typename VPageBlockNavigator,
232 typename VnewDramBlockWindow,
233 typename QRotaryCosDramBlockWindow,
234 typename QRotarySinDramBlockWindow,
235 typename KnewRotaryCosDramBlockWindow,
236 typename KnewRotarySinDramBlockWindow>
238 operator()(QDramBlockWindow& q_dram_block_window,
239 KDramBlockWindow& k_dram_block_window,
240 index_t i_page_block_k,
241 const KPageBlockNavigator& k_page_block_navigator,
242 const KnewDramBlockWindow& knew_dram_block_window,
243 VDramBlockWindow& v_dram_block_window,
244 index_t i_page_block_v,
245 const VPageBlockNavigator& v_page_block_navigator,
246 const VnewDramBlockWindow& vnew_dram_block_window,
247 const QRotaryCosDramBlockWindow& q_rotary_cos_dram_block_window,
248 const QRotarySinDramBlockWindow& q_rotary_sin_dram_block_window,
249 const KnewRotaryCosDramBlockWindow& knew_rotary_cos_dram_block_window,
250 const KnewRotarySinDramBlockWindow& knew_rotary_sin_dram_block_window,
251 index_t rotary_dim,
252 bool skip_rotate_q,
253 bool skip_rotate_append_kv) const
254 {
255 return operator()(q_dram_block_window,
256 identity{},
257 k_dram_block_window,
258 i_page_block_k,
259 k_page_block_navigator,
260 knew_dram_block_window,
261 identity{},
262 v_dram_block_window,
263 i_page_block_v,
264 v_page_block_navigator,
265 vnew_dram_block_window,
266 identity{},
267 q_rotary_cos_dram_block_window,
268 q_rotary_sin_dram_block_window,
269 knew_rotary_cos_dram_block_window,
270 knew_rotary_sin_dram_block_window,
271 rotary_dim,
272 skip_rotate_q,
273 skip_rotate_append_kv);
274 }
275};
276
277} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
@ NONE
Definition block_rotary_embedding.hpp:13
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
constexpr detail::ignore_t ignore
Definition tile/core/utility/ignore.hpp:20
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
Definition block_fmha_fwd_appendkv_pipeline.hpp:14
static constexpr bool kIsPagedKV
Definition block_fmha_fwd_appendkv_pipeline.hpp:31
typename Problem::QDataType QDataType
Definition block_fmha_fwd_appendkv_pipeline.hpp:17
typename Problem::KDataType KDataType
Definition block_fmha_fwd_appendkv_pipeline.hpp:18
static constexpr index_t kAlignmentQ
Definition block_fmha_fwd_appendkv_pipeline.hpp:40
static constexpr index_t kN1
Definition block_fmha_fwd_appendkv_pipeline.hpp:28
static constexpr index_t kM0
Definition block_fmha_fwd_appendkv_pipeline.hpp:25
typename Problem::VLayout VLayout
Definition block_fmha_fwd_appendkv_pipeline.hpp:21
static constexpr bool kPadSeqLenK
Definition block_fmha_fwd_appendkv_pipeline.hpp:34
static constexpr index_t kN0
Definition block_fmha_fwd_appendkv_pipeline.hpp:26
static constexpr index_t kK0
Definition block_fmha_fwd_appendkv_pipeline.hpp:27
static constexpr bool kPadHeadDimV
Definition block_fmha_fwd_appendkv_pipeline.hpp:36
static constexpr index_t kBlockSize
Definition block_fmha_fwd_appendkv_pipeline.hpp:23
remove_cvref_t< Policy_ > Policy
Definition block_fmha_fwd_appendkv_pipeline.hpp:16
CK_TILE_HOST_DEVICE auto operator()(QDramBlockWindow &q_dram_block_window, const QElementFunction &q_element_func, KDramBlockWindow &k_dram_block_window, index_t i_page_block_k, const KPageBlockNavigator &k_page_block_navigator, const KnewDramBlockWindow &knew_dram_block_window, const KnewElementFunction &knew_element_func, VDramBlockWindow &v_dram_block_window, index_t i_page_block_v, const VPageBlockNavigator &v_page_block_navigator, const VnewDramBlockWindow &vnew_dram_block_window, const VnewElementFunction &vnew_element_func, const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window, const QRotarySinDramBlockWindow q_rotary_sin_dram_block_window, const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window, const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window, index_t rotary_dim, bool skip_rotate_q, bool skip_rotate_append_kv) const
Definition block_fmha_fwd_appendkv_pipeline.hpp:90
static constexpr auto RotaryEnum
Definition block_fmha_fwd_appendkv_pipeline.hpp:30
static constexpr bool kPadHeadDimQ
Definition block_fmha_fwd_appendkv_pipeline.hpp:35
CK_TILE_HOST_DEVICE auto operator()(QDramBlockWindow &q_dram_block_window, KDramBlockWindow &k_dram_block_window, index_t i_page_block_k, const KPageBlockNavigator &k_page_block_navigator, const KnewDramBlockWindow &knew_dram_block_window, VDramBlockWindow &v_dram_block_window, index_t i_page_block_v, const VPageBlockNavigator &v_page_block_navigator, const VnewDramBlockWindow &vnew_dram_block_window, const QRotaryCosDramBlockWindow &q_rotary_cos_dram_block_window, const QRotarySinDramBlockWindow &q_rotary_sin_dram_block_window, const KnewRotaryCosDramBlockWindow &knew_rotary_cos_dram_block_window, const KnewRotarySinDramBlockWindow &knew_rotary_sin_dram_block_window, index_t rotary_dim, bool skip_rotate_q, bool skip_rotate_append_kv) const
Definition block_fmha_fwd_appendkv_pipeline.hpp:238
static constexpr bool kPadSeqLenQ
Definition block_fmha_fwd_appendkv_pipeline.hpp:33
static constexpr index_t kAlignmentV
Definition block_fmha_fwd_appendkv_pipeline.hpp:44
remove_cvref_t< Problem_ > Problem
Definition block_fmha_fwd_appendkv_pipeline.hpp:15
typename Problem::VDataType VDataType
Definition block_fmha_fwd_appendkv_pipeline.hpp:19
static constexpr index_t kBlockPerCu
Definition block_fmha_fwd_appendkv_pipeline.hpp:51
static constexpr index_t kAlignmentK
Definition block_fmha_fwd_appendkv_pipeline.hpp:42
static CK_TILE_HOST_DEVICE void apply(DistributedTensor &tile, OtherDramBlockWindow other_window, RotaryCosDramBlockWindow rotary_cos_window, RotarySinDramBlockWindow rotary_sin_window, index_t rotary_dim, index_t thread_end)
Definition block_rotary_embedding.hpp:44
Definition tile/core/utility/functional.hpp:86