naive_attention.hpp Source File

naive_attention.hpp Source File

Composable Kernel: naive_attention.hpp Source File
naive_attention.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9#include <thread>
10#include <string>
11
12namespace ck_tile {
13
15{
16 DEFAULT, // maybe this tensor is not used, set some irrelevant value
17 BSHD, // [batch, seqlen, nhead, hdim]
18 BHSD, // [batch, nhead, seqlen, hdim]
19 BS3HD, // [batch, nhead, 3, seqlen, hdim], used when qkv are packed
20 PHSD, // [pages, nhead, page_size, hdim]
21 // PHSDX, // [pages, nhead, page_size/x, hdim, x], where <# used pages>*page_size = seqlen
22 PHDSX, // [pages, nhead, hdim/x, page_size, x], where <# used pages>*page_size = seqlen
23 PHDS, // [pages, nhead, hdim, page_size], where <# used pages>*page_size = seqlen
24
25 // scale layout used for dynamic dequant
26 SCALE_HS, // [nhead, tokens] or [nhead, tokens-per-group], nhe KVCache quant
27 SCALE_SH, // [tokens, nhead]
28};
29
30// will used to specialize kernel variation
32{
33 FLASH_BATCHED = 0, // standard flash attention, or xformer/sdpa, used for training
35 DECODE_PAGED, // decode attn, where kv token from another buffer called kvcache
36};
37
39{
40 NO = 0,
42 // FP8/INT8 quant for KVCache, per-token quant
43 // [num_tokens, nhead, hdim] -> [nhead, num_tokens]
45};
46
47// TODO: for simplicity, this will be used as host/device arg
49{
50 void* q_ptr;
51 void* k_ptr;
52 void* v_ptr;
53 void* o_ptr;
54 void* context_len_ptr; // [batch] used when seqlen kv come from a pointer(each element is a
55 // number, not cumsum)
56 void* page_table_ptr; // [batch, max_pages_per_seq] seqlen_kv is in different block(paged attn)
57 void* kscale_ptr; // [nhead, max_kv_tokens] used for kvcache dequant
58 void* vscale_ptr; // [nhead, max_kv_tokens] used for kvcache dequant
59 float scale_s;
60 int hdim;
61 int hdim_v; // could be cross-attn, where V and Q/K hdim are different
64 int batch_ratio_kv; // batch_q / batch_kv
65 int seqlen_q; // in decode case, this should be 1
66 int seqlen_kv; // if context_len_ptr is not nullptr, ignore this field
69 int nhead_ratio_kv; // nhead_q / nhead_kv
70 int page_size; // if paged, the seqlen-kv per each block
72 int max_kv_tokens; // used as stride to access kv scale ptr
73};
74
75// this is trait for host API
77{
78 std::string q_type;
79 std::string k_type;
80 std::string v_type;
81 std::string o_type;
82 std::string q_layout;
83 std::string k_layout;
84 std::string v_layout;
85 std::string o_layout;
86 int variation; // sync with naive_attention_variation_enum
87 int quant_algo; // sync with naive_attention_quant_algo
88};
89
90// this is trait for kernel template
91template <naive_attention_variation_enum variation_, naive_attention_quant_algo quant_algo_>
93{
94 static constexpr naive_attention_variation_enum variation = variation_;
95 static constexpr naive_attention_quant_algo quant_algo = quant_algo_;
96};
97
98// for simplicity, please do not use const-reference type for the template type
99template <typename QType,
100 typename KType,
101 typename VType,
102 typename OType,
103 typename AccType,
104 typename KVScaleType,
109 naive_attention_layout_enum KScaleLayout,
110 naive_attention_layout_enum VScaleLayout,
111 typename Traits>
113{
114 static constexpr bool is_kvcache_i8 =
115 std::is_same_v<KType, int8_t> && std::is_same_v<VType, int8_t>;
116 static constexpr bool is_kvcache_fp8 =
117 std::is_same_v<KType, fp8_t> && std::is_same_v<VType, fp8_t>;
118
119 static constexpr int v_per_token_quant_group_size = 64;
120 static constexpr int kBlockSize = 256;
121 // TODO: hardcode
122 using SoftmaxType = float; // always using float to do softmax compute
123 using QuantComputeType = float; // used for quant/dequant scale compute
124 using QCompute = KType; // src A of gemm1, same type as K
125 using PType = VType; // src A of gemm2, same type as V
126 using OAccType = float; // always float, in case int8 FA
127
128 using p_vec_type = ext_vector_t<PType, 16 / sizeof(PType)>;
130
// scale_max<T>::value: the largest magnitude the dynamic (smooth) quant maps values of type T
// onto; operator() divides an absmax by it to get a dequant scale
// (e.g. p_dequant_scale = p_max / scale_max<PType>::value). The primary template is a
// placeholder (value 1) for types that are never quantized.
// NOTE(review): 240 is the max finite value of the FNUZ fp8 e4m3 encoding; if fp8_t is the
// OCP e4m3 encoding on some target (max 448) this stays safe but leaves range unused — confirm.
131 // clang-format off
132 template <typename T_> struct scale_max { static constexpr float value = 1; /* dummy code */ };
133 template <> struct scale_max<int8_t> { static constexpr float value = 127.0; };
134 template <> struct scale_max<fp8_t> { static constexpr float value = 240.0; };
135 // clang-format on
136
// Trivial constructor. The dispatch macro default-constructs the kernel (k_{}), and all runtime
// inputs are delivered through naive_attention_fwd_args at launch; the kernel object appears to
// carry no per-instance state.
137 __host__ __device__ naive_attention_fwd_kernel() {}
138
139 template <typename T, naive_attention_layout_enum Layout>
141 {
142 int b, s, h, d; // batch, seqlen, nhead, hdim
144 __device__ addresser(int b_, int s_, int h_, int d_, void* base_ptr_)
145 : b(b_), s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(base_ptr_))
146 {
147 }
148
149 // TODO: all the batch/nhead offset will accumulate to the base pointer
150 __device__ T* get_base(int i_b, int i_h)
151 {
153 return base_ptr + i_b * s * h * d + i_h * d;
154 else if constexpr(Layout == naive_attention_layout_enum::BHSD)
155 return base_ptr + i_b * s * h * d + i_h * s * d;
156 }
157
158 __device__ int get_offset(int i_s, int i_d)
159 {
161 return i_s * h * d + i_d;
162 else if constexpr(Layout == naive_attention_layout_enum::BHSD)
163 return i_s * d + i_d;
164 }
165
166 // below set of API will directly use pointer inside this struct
167 __device__ void init(int i_b, int i_h) { base_ptr = get_base(i_b, i_h); }
168 __device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
169 __device__ void store(T value, int i_s, int i_d) { base_ptr[get_offset(i_s, i_d)] = value; }
170 };
171
172 template <typename T, naive_attention_layout_enum Layout>
174 {
175 int s, h, d; // page_size, nhead, hdim
176 static constexpr int x = 16 / sizeof(T); // pack 4 dword
178 int* page_table_ptr; // TODO: page table always int
179 int i_h; // store current head
180
181 __device__ page_addresser(int s_, int h_, int d_, void* base_ptr_, void* pptr_)
182 : s(s_),
183 h(h_),
184 d(d_),
185 base_ptr(reinterpret_cast<T*>(base_ptr_)),
186 page_table_ptr(reinterpret_cast<int*>(pptr_))
187 {
188 }
189
190 __device__ int64_t get_phy_page_idx(int i_s)
191 {
192 // dynamic compute page idx is simple but slow
193 int page_idx = i_s / s;
194 int phy = page_table_ptr[page_idx];
195 return static_cast<int64_t>(phy);
196 }
197
198 __device__ int get_phy_page_offset(int i_s)
199 {
200 // dynamic compute page idx is simple but slow
201 return i_s % s;
202 }
203
204 __device__ int64_t get_offset(int i_s, int i_d)
205 {
206 int page_offset = get_phy_page_offset(i_s);
207 int64_t page_idx = get_phy_page_idx(i_s);
208 int64_t base_ = page_idx * h * s * d;
210 return static_cast<int64_t>(i_h * s * d + page_offset * d + i_d) + base_;
211 else if constexpr(Layout == naive_attention_layout_enum::PHDSX)
212 {
213 int d_r = i_d / x;
214 int d_x = i_d % x;
215 return static_cast<int64_t>(i_h * d * s + d_r * s * x + page_offset * x + d_x) +
216 base_;
217 }
218 else if constexpr(Layout == naive_attention_layout_enum::PHDS)
219 {
220 return static_cast<int64_t>(i_h * d * s + i_d * s + page_offset) + base_;
221 }
222 }
223
224 // below set of API will directly use pointer inside this struct
225 __device__ void init(int /*i_b*/, int i_h_) { i_h = i_h_; }
226 __device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
227 __device__ void store(T /*value*/, int /*i_s*/, int /*i_d*/) {}
228 };
229
230 template <typename T, naive_attention_layout_enum Layout>
232 {
233 int s, h, d; // seqlen(tokens), nhead, hdim
235 __device__ kvscale_addresser(int s_, int h_, int d_, void* p_)
236 : s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(p_))
237 {
238 }
239 __device__ int get_offset(int i_s, int i_h, int i_d)
240 {
242 {
243 // [nhead, tokens]
244 (void)i_d;
245 return i_h * s + i_s;
246 }
247 else if constexpr(Layout == naive_attention_layout_enum::DEFAULT)
248 {
249 return 0;
250 }
251 // [h, 2, d]
252 // return i_h * 2 * d + i_kv * d + i_d;
253 }
254 __device__ T load(int i_s, int i_h, int i_d) { return base_ptr[get_offset(i_s, i_h, i_d)]; }
255 };
256
// Threads per workgroup (kBlockSize = 256). cross_wave_reduce hard-codes 4 waves x 64 lanes,
// so changing this value also requires updating that helper.
257 __device__ __host__ static constexpr int get_block_size() { return kBlockSize; }
258
259 // for simplicity, 1 WG always computes 1 token along q and all tokens along kv;
260 // it computes all hdim from q and WG_SIZE columns of hdim from v
261 // 1) in prefill case, seqlen_q >= 1, seqlen_kv >= 1, batch_q=batch_kv
262 // 2) in decode case, seqlen_q = 1, batch_q is input num-tokens, batch_kv is 1
263 // 3) in paged-attn case, we still use 1 WG to compute all of seqlen-kv for simplicity
264 // TODO: could support split-kv to validate intermediate logsum
// Grid mapping: x = ceil(hdim_v / wg_size) tiles of the V head dim, y = seqlen_q,
// z = batch_q * nhead_q (operator() decodes z as i_bq = z / nhead_q, i_hq = z % nhead_q).
// NOTE(review): seqlen_q and batch_q * nhead_q are not checked against device grid limits.
265 __host__ static dim3 get_grid_size(naive_attention_fwd_args args)
266 {
267 constexpr int wg_size = get_block_size();
268 auto g =
269 dim3((args.hdim_v + wg_size - 1) / wg_size, args.seqlen_q, args.batch_q * args.nhead_q);
270 return g;
271 }
272
273 // reduce single pixel within a wave: every lane contributes `local` and every lane
// receives reduce_f folded over all 64 lanes of its wave (XOR butterfly: in stage i each
// lane combines its value with the one from lane ^ (1 << i); 6 stages cover 64 lanes).
// Assumptions, not checked here: a 64-lane wave with all lanes active, a 32-bit T (the value
// travels through ds_bpermute as an int32_t via bit_cast), and a commutative/associative
// reduce_f so that every lane ends with the same result.
274 template <typename T, typename F>
275 __device__ constexpr T wave_reduce(T local, F reduce_f)
276 {
277 // constexpr int wave_size = 64;
278 constexpr int reduce_stage = 6; // 1<<6=64
279 T v_local = local;
280#pragma unroll
281 for(int i_stage = 0; i_stage < reduce_stage; i_stage++)
282 {
283 int src_lane = __lane_id() ^ (1 << i_stage);
284 int32_t v_remote_tmp =
285 __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local)); // lane id -> byte address (4 bytes per lane)
286 T v_remote = bit_cast<T>(v_remote_tmp);
287 v_local = reduce_f(v_local, v_remote);
288 }
289 return v_local;
290 }
291
292 // Note: this function must be called after wave_reduce (it only combines slot lane_id of each wave, so every lane of a wave must already hold that wave's result)
293 // Note: better not use this under if...else... with thread divergence: it calls __syncthreads(), which every thread of the block must reach
// Combines the per-wave results into a single block-wide value returned to every thread.
// Hard-codes 4 waves x 64 lanes (= kBlockSize, 256 threads); smem must hold at least
// blockDim.x elements of T. The leading barrier protects smem from readers of a previous call;
// there is no trailing barrier, so callers must __syncthreads() before rewriting smem.
294 template <typename T, typename F>
295 __device__ constexpr T cross_wave_reduce(T local, F reduce_f, T* smem)
296 {
297 constexpr int waves = 4;
298 constexpr int wave_size = 64;
299 int lane_id = threadIdx.x % wave_size;
300
301 __syncthreads();
302 smem[threadIdx.x] = local;
303 __syncthreads();
304
305 // the data within single wave is the same
306 // but for simplicity, we still use data from each lane.
307 T v_local = smem[lane_id];
308#pragma unroll
309 for(int i_stage = 1; i_stage < waves; i_stage++)
310 {
311 T v_remote = smem[i_stage * wave_size + lane_id];
312 v_local = reduce_f(v_local, v_remote);
313 }
314 return v_local;
315 }
316
317 // kernel entry point
319 {
320 constexpr int wg_size = get_block_size();
321 __shared__ char smem[wg_size * 4 * sizeof(float)]; // should enough
322 char* smem_quant_q = smem + wg_size * 2 * sizeof(float); // second half, should enough
323 int i_dv = blockIdx.x * wg_size + threadIdx.x; // index of hdim_v
324 int i_sq = blockIdx.y; // index of seqlen_q
325 int i_batch = blockIdx.z; // index of batch_q * nhead_q
326 int i_bq = i_batch / args.nhead_q; // index of batch_q
327 int i_hq = i_batch % args.nhead_q; // index of nhead_q
328
329 int i_bk = i_bq / args.batch_ratio_kv;
330 int i_hk = i_hq / args.nhead_ratio_kv;
331
332 void* page_table_ptr = [&]() {
333 if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
334 {
335 return reinterpret_cast<int*>(args.page_table_ptr) + i_bq * args.max_pages_per_seq;
336 }
337 else
338 {
339 return nullptr;
340 }
341 }();
342
343 auto q_addr = [&]() {
344 if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
345 {
347 args.batch_q, args.seqlen_q, args.nhead_q, args.hdim, args.q_ptr};
348 }
349 else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
350 {
352 args.batch_q, args.seqlen_q, args.nhead_q, args.hdim, args.q_ptr};
353 }
354 }();
355 auto k_addr = [&]() {
356 if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
357 {
359 args.batch_kv, args.seqlen_kv, args.nhead_kv, args.hdim, args.k_ptr};
360 }
361 else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
362 {
364 args.page_size, args.nhead_kv, args.hdim, args.k_ptr, page_table_ptr};
365 }
366 }();
367 auto v_addr = [&]() {
368 if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
369 {
371 args.batch_kv, args.seqlen_kv, args.nhead_kv, args.hdim_v, args.v_ptr};
372 }
373 else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
374 {
376 args.page_size, args.nhead_kv, args.hdim_v, args.v_ptr, page_table_ptr};
377 }
378 }();
379 auto o_addr = [&]() {
380 if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
381 {
383 args.batch_q, args.seqlen_q, args.nhead_q, args.hdim_v, args.o_ptr};
384 }
385 else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
386 {
388 args.batch_q, args.seqlen_q, args.nhead_q, args.hdim_v, args.o_ptr};
389 }
390 }();
391
392 q_addr.init(i_bq, i_hq);
393 k_addr.init(i_bk, i_hk);
394 v_addr.init(i_bk, i_hk);
395 o_addr.init(i_bq, i_hq);
396
397 auto f_max = [](auto x_, auto y_) { return max(x_, y_); };
398 auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
399 auto f_absmax_f32 = [](float v_0_, float v_1_) {
400 // float rtn;
401 // asm volatile("v_max_f32 %0, abs(%1), abs(%2)" : "=v"(rtn) : "v"(v_0_), "v"(v_1_));
402 // return rtn;
403 return max(abs(v_0_), abs(v_1_));
404 };
405
406 int seqlen_kv = [&]() {
407 if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
408 {
409 return args.seqlen_kv;
410 }
411 else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
412 {
413 return reinterpret_cast<int*>(args.context_len_ptr)[i_bq];
414 }
415 }();
416
418 SoftmaxType l{0};
419 // AccType o_acc = {0};
420 OAccType o_acc = {0};
421
422 int sk_loops = (seqlen_kv + wg_size - 1) / wg_size;
423 QuantComputeType q_dequant_scale = .0f;
425 args.max_kv_tokens, args.nhead_kv, args.hdim, args.kscale_ptr};
427 args.max_kv_tokens, args.nhead_kv, args.hdim_v, args.vscale_ptr};
428
429 if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
430 {
431 // AccType is i32 now, seqlen_q = 1, hdim up to 256
432 AccType q = 0;
433 AccType k_s = 0;
434 if(static_cast<int>(threadIdx.x) < args.hdim)
435 {
436 q = type_convert<AccType>(q_addr.load(0, threadIdx.x));
437 k_s = type_convert<AccType>(kscale_addr.load(i_hk, threadIdx.x, 0));
438 }
439 // 1) we apply the k scale to q
440 AccType q_forwarded = q * k_s;
441
442 // 2) apply smooth-quant
443 // find absmax
444 AccType qf_max = wave_reduce(q_forwarded, f_absmax_f32);
445 qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<AccType*>(smem));
446
447 // per-token scale
449
450 // devide by scale
451 q = q / q_dequant_scale;
452
453 // fp32->i8
454 QCompute quantized_q = static_cast<QCompute>(q);
455 __syncthreads();
456 reinterpret_cast<QCompute*>(smem)[threadIdx.x] = quantized_q;
457 __syncthreads();
458
459 // after above process, we have 2 data
460 // 1) int8 q data stored in smem(no need to reload)
461 // 2) per-token scale q_dequant_scale, to be mul after 1st gemm
462 }
463 else if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERTOKEN)
464 {
465 if(std::is_same_v<QType, fp16_t> || std::is_same_v<QType, bf16_t>)
466 {
467 // dyanmic quant q here
468 float q = 0;
469 if(static_cast<int>(threadIdx.x) < args.hdim)
470 {
471 q = type_convert<float>(q_addr.load(i_sq, threadIdx.x));
472 }
473
474 // apply smooth-quant
475 // find absmax
476 float q_max = wave_reduce(q, f_absmax_f32);
477 q_max = cross_wave_reduce(q_max, f_absmax_f32, reinterpret_cast<float*>(smem));
478
479 // per-token scale
480 q_dequant_scale =
482
483 // devide by scale
484 q = q / q_dequant_scale;
485
486 QCompute quantized_q = type_convert<QCompute>(q);
487 __syncthreads();
488 reinterpret_cast<QCompute*>(smem_quant_q)[threadIdx.x] = quantized_q;
489 __syncthreads();
490
491 // after above process, we have 2 data
492 // 1) fp8 q data stored in smem(no need to reload from global)
493 // 2) per-token scale q_dequant_scale, to be mul after 1st gemm
494 }
495 }
496
497 for(int i_loop1 = 0; i_loop1 < sk_loops; i_loop1++)
498 {
499 int i_sk = i_loop1 * wg_size + threadIdx.x;
500 // gemm-1
502 if(i_sk < seqlen_kv)
503 {
504 AccType s_acc{0}; // clear for every loop
505 for(auto i_dq = 0; i_dq < args.hdim; i_dq++)
506 {
507 auto q = [&]() {
508 if constexpr(Traits::quant_algo ==
510 Traits::quant_algo ==
512 {
513 return reinterpret_cast<QCompute*>(smem_quant_q)[i_dq];
514 }
515 else
516 return q_addr.load(i_sq, i_dq); // q will have duplicate load
517 }();
518 auto k = [&]() { return k_addr.load(i_sk, i_dq); }();
519
521 }
522 // scale
523 s_softmax = type_convert<SoftmaxType>(s_acc);
524 s_softmax *=
526 if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
527 {
528 s_softmax *= q_dequant_scale; // post scale the per-token factor
529 }
530 else if constexpr(Traits::quant_algo ==
532 {
533 SoftmaxType k_per_token_scale =
534 type_convert<SoftmaxType>(kscale_addr.load(i_sk, i_hk, 0));
535 s_softmax *= q_dequant_scale;
536 s_softmax *= k_per_token_scale;
537 }
538 }
539
540 // s->p
541 QuantComputeType p_dequant_scale = 1.;
542 {
543 // softmax, find max
544 SoftmaxType old_max = row_max;
545 SoftmaxType cur_max = wave_reduce(s_softmax, f_max);
546
547 cur_max = cross_wave_reduce(cur_max, f_max, reinterpret_cast<SoftmaxType*>(smem));
548 row_max = max(old_max, cur_max); // update row_max
549 // softmax, exp(i_elem - max)
550 SoftmaxType p_compute = __builtin_amdgcn_exp2f(s_softmax - row_max);
551
552 // compute exp_sum
553 SoftmaxType row_sum = wave_reduce(p_compute, f_sum);
554 row_sum = cross_wave_reduce(row_sum, f_sum, reinterpret_cast<SoftmaxType*>(smem));
555
556 // l, pre-scall o_acc
557 SoftmaxType tmp = __builtin_amdgcn_exp2f(old_max - row_max);
558 l = tmp * l + row_sum;
560
561 // prepare the p_compute into smem, to let every thread read same p_compute and do
562 // 2nd gemm
563 if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
564 {
565 QuantComputeType v_s = 0;
566 if(static_cast<int>(threadIdx.x) < args.hdim_v)
567 {
568 v_s =
569 type_convert<QuantComputeType>(vscale_addr.load(i_hk, threadIdx.x, 1));
570 }
571
572 // 1) we apply the v scale to p
573 QuantComputeType p_forwarded = p_compute * v_s;
574
575 // 2) apply smooth-quant
576 // find absmax
577 QuantComputeType pf_max = wave_reduce(p_forwarded, f_absmax_f32);
578 pf_max = cross_wave_reduce(
579 pf_max, f_absmax_f32, reinterpret_cast<QuantComputeType*>(smem));
580
581 // per-token scale
582 p_dequant_scale = pf_max / scale_max<PType>::value; // 127.0;
583
584 // devide by scale
585 p_compute = p_compute / p_dequant_scale;
586
587 // fp32->i8
588 PType quantized_p = static_cast<PType>(p_compute);
589 __syncthreads();
590 reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
591 __syncthreads();
592 // after above process, we have 2 data
593 // 1) int8 p data stored in smem(no need to reload)
594 // 2) per-token scale p_dequant_scale, to be mul after 2nd gemm
595 }
596 else if constexpr(Traits::quant_algo ==
598 {
599 // forward apply the v scale to p_compute, this is compute friendly
600 auto v_scale = type_convert<QuantComputeType>(vscale_addr.load(i_sk, i_hk, 0));
601 p_compute *= v_scale;
602 // smooth-quant
603 // find absmax
604 QuantComputeType p_max = wave_reduce(p_compute, f_absmax_f32);
605 p_max = cross_wave_reduce(
606 p_max, f_absmax_f32, reinterpret_cast<QuantComputeType*>(smem));
607
608 // per-token scale
609 p_dequant_scale = p_max / scale_max<PType>::value; // 240.0;
610
611 // devide by scale
612 p_compute = p_compute / p_dequant_scale;
613
614 // fp32->i8
615 PType quantized_p = type_convert<PType>(p_compute);
616 __syncthreads();
617 reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
618 __syncthreads();
619 // after above process, we have 2 data
620 // 1) fp8_t p data stored in smem(no need to reload)
621 // 2) per-token scale p_dequant_scale, to be mul after 2nd gemm
622 }
623 else
624 {
625 __syncthreads();
626 reinterpret_cast<PType*>(smem)[threadIdx.x] = type_convert<PType>(p_compute);
627 __syncthreads();
628 }
629 }
630
631 // gemm-2, simple loop over vector by vector
632 constexpr int gemm_2_loop = wg_size / p_vec_elem;
633 {
634 AccType o_acc_local = {0};
635 int sk_start = i_loop1 * wg_size; // we start from the first seqlen_kv element
636 for(int i_loop2 = 0; i_loop2 < gemm_2_loop; i_loop2++)
637 {
638 p_vec_type p_vec = reinterpret_cast<p_vec_type*>(smem)[i_loop2];
639#pragma unroll
640 for(int i_j = 0; i_j < p_vec_elem; i_j++)
641 {
642 int sv_offset = i_loop2 * p_vec_elem + i_j;
643 int i_sv = sk_start + sv_offset;
644
645 VType v = 0;
646 if(i_dv < args.hdim_v && i_sv < seqlen_kv)
647 {
648 v = v_addr.load(i_sv, i_dv);
649 }
650
651 AccType v_compute = [&]() { return type_convert<AccType>(v); }();
652
653 o_acc_local += type_convert<AccType>(p_vec[i_j]) * v_compute;
654 }
655 }
656
657 OAccType post_scale_o_acc_local = [&]() {
658 if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
659 {
660 // apply pr scale to local acc
662 p_dequant_scale);
663 }
664 else if constexpr(Traits::quant_algo ==
666 {
667 // apply pr scale to local acc
669 p_dequant_scale);
670 }
671 else
672 {
673 return type_convert<OAccType>(o_acc_local);
674 }
675 }();
676 o_acc += post_scale_o_acc_local;
677 }
678 }
679
680 // post scale o_acc
681 {
682 SoftmaxType tmp = l == 0.f ? 0.f : 1.f / l; // in case masking
684 }
685
686 // store O
687 if(i_dv < args.hdim_v)
688 o_addr.store(type_convert<OType>(o_acc), i_sq, i_dv);
689 }
690};
691
// CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_: instantiates naive_attention_fwd_kernel for the
// compile-time choices that must already be in scope at the expansion site (q/k/v/o_type_,
// acc_type_, kvscale_type_, q/k/v/o_layout_, k/v_scale_layout_, variation_, quant_algo_) and
// launches it on stream config `s` with args `a`, storing launch_kernel's return value in `r`.
// Grid comes from get_grid_size(a), block size from get_block_size(); the 0 is presumably the
// dynamic LDS byte count (the kernel uses a static __shared__ buffer) — confirm with make_kernel.
692#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_() \
693 { \
694 using ktraits_ = naive_attention_fwd_kernel_traits< \
695 static_cast<naive_attention_variation_enum>(variation_), \
696 static_cast<naive_attention_quant_algo>(quant_algo_)>; \
697 using k_ = naive_attention_fwd_kernel<q_type_, \
698 k_type_, \
699 v_type_, \
700 o_type_, \
701 acc_type_, \
702 kvscale_type_, \
703 q_layout_, \
704 k_layout_, \
705 v_layout_, \
706 o_layout_, \
707 k_scale_layout_, \
708 v_scale_layout_, \
709 ktraits_>; \
710 dim3 grids = k_::get_grid_size(a); \
711 r = ck_tile::launch_kernel(s, \
712 ck_tile::make_kernel(k_{}, grids, k_::get_block_size(), 0, a)); \
713 }
714
// CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_ (spelling kept: it is invoked and #undef'd by this
// name): maps the string layouts in traits `t` to compile-time layout enums and then expands
// CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_. Supported combinations:
//   variation 0: q/k/v/o all "bshd", or all "bhsd" (k/v scale layouts DEFAULT)
//   variation 2: q/o "bhsd", k "phdsx", v "phds" (k/v scale layouts SCALE_HS)
// Any other combination matches no branch: nothing is launched and `r` is left unchanged.
715#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_() \
716 if(t.variation == 0 && t.q_layout == "bshd" && t.k_layout == "bshd" && t.v_layout == "bshd" && \
717 t.o_layout == "bshd") \
718 { \
719 constexpr auto q_layout_ = naive_attention_layout_enum::BSHD; \
720 constexpr auto k_layout_ = naive_attention_layout_enum::BSHD; \
721 constexpr auto v_layout_ = naive_attention_layout_enum::BSHD; \
722 constexpr auto o_layout_ = naive_attention_layout_enum::BSHD; \
723 constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
724 constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
725 constexpr int variation_ = 0; \
726 CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
727 } \
728 else if(t.variation == 0 && t.q_layout == "bhsd" && t.k_layout == "bhsd" && \
729 t.v_layout == "bhsd" && t.o_layout == "bhsd") \
730 { \
731 constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \
732 constexpr auto k_layout_ = naive_attention_layout_enum::BHSD; \
733 constexpr auto v_layout_ = naive_attention_layout_enum::BHSD; \
734 constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \
735 constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
736 constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
737 constexpr int variation_ = 0; \
738 CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
739 } \
740 else if(t.variation == 2 && t.q_layout == "bhsd" && t.k_layout == "phdsx" && \
741 t.v_layout == "phds" && t.o_layout == "bhsd") \
742 { \
743 constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \
744 constexpr auto k_layout_ = naive_attention_layout_enum::PHDSX; \
745 constexpr auto v_layout_ = naive_attention_layout_enum::PHDS; \
746 constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \
747 constexpr auto k_scale_layout_ = naive_attention_layout_enum::SCALE_HS; \
748 constexpr auto v_scale_layout_ = naive_attention_layout_enum::SCALE_HS; \
749 constexpr int variation_ = 2; \
750 CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
751 }
752
753//
757{
758 float r = -1;
759 // TODO: do not explicitly create too much instance!
760 if(t.q_type == "fp16" && t.k_type == "fp16" && t.v_type == "fp16" && t.o_type == "fp16" &&
761 t.quant_algo == 0)
762 {
763 using q_type_ = fp16_t;
764 using k_type_ = fp16_t;
765 using v_type_ = fp16_t;
766 using o_type_ = fp16_t;
767 using acc_type_ = float;
768 using kvscale_type_ = float;
769 constexpr int quant_algo_ = 0;
771 }
772 else if(t.q_type == "bf16" && t.k_type == "bf16" && t.v_type == "bf16" && t.o_type == "bf16" &&
773 t.quant_algo == 0)
774 {
775 using q_type_ = bf16_t;
776 using k_type_ = bf16_t;
777 using v_type_ = bf16_t;
778 using o_type_ = bf16_t;
779 using acc_type_ = float;
780 using kvscale_type_ = float;
781 constexpr int quant_algo_ = 0;
783 }
784 else if(t.q_type == "bf16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "bf16" &&
785 t.quant_algo == 2)
786 {
787 using q_type_ = bf16_t;
788 using k_type_ = fp8_t;
789 using v_type_ = fp8_t;
790 using o_type_ = bf16_t;
791 using acc_type_ = float; // NOTE!
792 using kvscale_type_ = float;
793 constexpr int quant_algo_ = 2;
795 }
796 else if(t.q_type == "fp16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "fp16" &&
797 t.quant_algo == 2)
798 {
799 using q_type_ = fp16_t;
800 using k_type_ = fp8_t;
801 using v_type_ = fp8_t;
802 using o_type_ = fp16_t;
803 using acc_type_ = float; // NOTE!
804 using kvscale_type_ = float;
805 constexpr int quant_algo_ = 2;
807 }
808 else if(t.q_type == "bf16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "bf16" &&
809 t.quant_algo == 2)
810 {
811 using q_type_ = bf16_t;
812 using k_type_ = int8_t;
813 using v_type_ = int8_t;
814 using o_type_ = bf16_t;
815 using acc_type_ = int32_t; // NOTE!
816 using kvscale_type_ = float;
817 constexpr int quant_algo_ = 2;
819 }
820 return r;
821}
822
823#undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_
824#undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_
825
826} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_()
Definition naive_attention.hpp:715
Definition tile/core/algorithm/cluster_descriptor.hpp:13
naive_attention_variation_enum
Definition naive_attention.hpp:32
@ FLASH_BATCHED
Definition naive_attention.hpp:33
@ FLASH_GROUPED
Definition naive_attention.hpp:34
@ DECODE_PAGED
Definition naive_attention.hpp:35
CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t, naive_attention_fwd_args a, ck_tile::stream_config s)
Definition naive_attention.hpp:754
int8_t int8_t
Definition int8.hpp:20
bfloat16_t bf16_t
Definition bfloat16.hpp:113
_Float16 fp16_t
Definition half.hpp:110
_BitInt(8) fp8_t
Definition float8.hpp:204
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
int32_t int32_t
Definition integer.hpp:10
naive_attention_layout_enum
Definition naive_attention.hpp:15
@ BS3HD
Definition naive_attention.hpp:19
@ SCALE_HS
Definition naive_attention.hpp:26
@ SCALE_SH
Definition naive_attention.hpp:27
@ DEFAULT
Definition naive_attention.hpp:16
@ PHSD
Definition naive_attention.hpp:20
@ BSHD
Definition naive_attention.hpp:17
@ PHDS
Definition naive_attention.hpp:23
@ BHSD
Definition naive_attention.hpp:18
@ PHDSX
Definition naive_attention.hpp:22
@ NO
Definition naive_attention.hpp:40
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition bfloat16.hpp:400
typename impl::ext_vector< T, N >::type ext_vector_t
Definition vector_type.hpp:84
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
naive_attention_quant_algo
Definition naive_attention.hpp:39
@ KV_8BIT_PERHEAD
Definition naive_attention.hpp:41
@ KV_8BIT_PERTOKEN
Definition naive_attention.hpp:44
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
signed __int64 int64_t
Definition stdint.h:135
Layout wrapper that performs the tensor descriptor logic.
Definition layout.hpp:24
Definition naive_attention.hpp:49
int page_size
Definition naive_attention.hpp:70
int max_kv_tokens
Definition naive_attention.hpp:72
void * page_table_ptr
Definition naive_attention.hpp:56
void * o_ptr
Definition naive_attention.hpp:53
int seqlen_kv
Definition naive_attention.hpp:66
int hdim_v
Definition naive_attention.hpp:61
int hdim
Definition naive_attention.hpp:60
void * k_ptr
Definition naive_attention.hpp:51
int batch_kv
Definition naive_attention.hpp:63
int nhead_kv
Definition naive_attention.hpp:68
int nhead_ratio_kv
Definition naive_attention.hpp:69
void * kscale_ptr
Definition naive_attention.hpp:57
int max_pages_per_seq
Definition naive_attention.hpp:71
void * v_ptr
Definition naive_attention.hpp:52
int batch_q
Definition naive_attention.hpp:62
int nhead_q
Definition naive_attention.hpp:67
void * q_ptr
Definition naive_attention.hpp:50
int seqlen_q
Definition naive_attention.hpp:65
void * context_len_ptr
Definition naive_attention.hpp:54
float scale_s
Definition naive_attention.hpp:59
void * vscale_ptr
Definition naive_attention.hpp:58
int batch_ratio_kv
Definition naive_attention.hpp:64
Definition naive_attention.hpp:141
__device__ void init(int i_b, int i_h)
Definition naive_attention.hpp:167
__device__ T * get_base(int i_b, int i_h)
Definition naive_attention.hpp:150
int b
Definition naive_attention.hpp:142
int h
Definition naive_attention.hpp:142
int d
Definition naive_attention.hpp:142
int s
Definition naive_attention.hpp:142
T * base_ptr
Definition naive_attention.hpp:143
__device__ T load(int i_s, int i_d)
Definition naive_attention.hpp:168
__device__ addresser(int b_, int s_, int h_, int d_, void *base_ptr_)
Definition naive_attention.hpp:144
__device__ void store(T value, int i_s, int i_d)
Definition naive_attention.hpp:169
__device__ int get_offset(int i_s, int i_d)
Definition naive_attention.hpp:158
T * base_ptr
Definition naive_attention.hpp:234
__device__ T load(int i_s, int i_h, int i_d)
Definition naive_attention.hpp:254
__device__ int get_offset(int i_s, int i_h, int i_d)
Definition naive_attention.hpp:239
int d
Definition naive_attention.hpp:233
int h
Definition naive_attention.hpp:233
int s
Definition naive_attention.hpp:233
__device__ kvscale_addresser(int s_, int h_, int d_, void *p_)
Definition naive_attention.hpp:235
Definition naive_attention.hpp:174
int h
Definition naive_attention.hpp:175
__device__ int get_phy_page_offset(int i_s)
Definition naive_attention.hpp:198
T * base_ptr
Definition naive_attention.hpp:177
int s
Definition naive_attention.hpp:175
int i_h
Definition naive_attention.hpp:179
int d
Definition naive_attention.hpp:175
static constexpr int x
Definition naive_attention.hpp:176
__device__ T load(int i_s, int i_d)
Definition naive_attention.hpp:226
__device__ page_addresser(int s_, int h_, int d_, void *base_ptr_, void *pptr_)
Definition naive_attention.hpp:181
__device__ int64_t get_phy_page_idx(int i_s)
Definition naive_attention.hpp:190
__device__ void init(int, int i_h_)
Definition naive_attention.hpp:225
__device__ int64_t get_offset(int i_s, int i_d)
Definition naive_attention.hpp:204
__device__ void store(T, int, int)
Definition naive_attention.hpp:227
int * page_table_ptr
Definition naive_attention.hpp:178
static constexpr float value
Definition naive_attention.hpp:134
static constexpr float value
Definition naive_attention.hpp:133
Definition naive_attention.hpp:132
static constexpr float value
Definition naive_attention.hpp:132
Definition naive_attention.hpp:93
static constexpr naive_attention_variation_enum variation
Definition naive_attention.hpp:94
static constexpr naive_attention_quant_algo quant_algo
Definition naive_attention.hpp:95
float QuantComputeType
Definition naive_attention.hpp:123
VType PType
Definition naive_attention.hpp:125
ext_vector_t< PType, 16/sizeof(PType)> p_vec_type
Definition naive_attention.hpp:128
__device__ constexpr T wave_reduce(T local, F reduce_f)
Definition naive_attention.hpp:275
static constexpr bool is_kvcache_i8
Definition naive_attention.hpp:114
__device__ static __host__ constexpr int get_block_size()
Definition naive_attention.hpp:257
static constexpr int v_per_token_quant_group_size
Definition naive_attention.hpp:119
static constexpr int p_vec_elem
Definition naive_attention.hpp:129
static __host__ dim3 get_grid_size(naive_attention_fwd_args args)
Definition naive_attention.hpp:265
static constexpr bool is_kvcache_fp8
Definition naive_attention.hpp:116
float OAccType
Definition naive_attention.hpp:126
static constexpr int kBlockSize
Definition naive_attention.hpp:120
KType QCompute
Definition naive_attention.hpp:124
__device__ void operator()(naive_attention_fwd_args args)
Definition naive_attention.hpp:318
__host__ __device__ naive_attention_fwd_kernel()
Definition naive_attention.hpp:137
__device__ constexpr T cross_wave_reduce(T local, F reduce_f, T *smem)
Definition naive_attention.hpp:295
float SoftmaxType
Definition naive_attention.hpp:122
Definition naive_attention.hpp:77
std::string q_layout
Definition naive_attention.hpp:82
std::string v_layout
Definition naive_attention.hpp:84
std::string o_layout
Definition naive_attention.hpp:85
std::string k_type
Definition naive_attention.hpp:79
std::string k_layout
Definition naive_attention.hpp:83
int variation
Definition naive_attention.hpp:86
std::string v_type
Definition naive_attention.hpp:80
std::string q_type
Definition naive_attention.hpp:78
int quant_algo
Definition naive_attention.hpp:87
std::string o_type
Definition naive_attention.hpp:81
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition ck_tile/host/stream_config.hpp:30
static constexpr index_t vector_size
Definition vector_type.hpp:98