tile_fmha_traits.hpp Source File

tile_fmha_traits.hpp Source File

Composable Kernel: tile_fmha_traits.hpp Source File
tile_fmha_traits.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
// Compile-time trait bundle: every template argument is re-exported below as a
// static constexpr member with the trailing '_' dropped, so code templated on this
// type can query e.g. Traits::kStoreLSE or Traits::BiasEnum.
// NOTE(review): this listing is missing the struct declaration line that belongs
// between the closing '>' of the template header and '{' (source line 24 is absent).
12template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
13 bool kPadSeqLenK_ /* padding for seqlen_k */,
14 bool kPadHeadDimQ_ /* padding for hdim_q */,
15 bool kPadHeadDimV_ /* padding for hdim_v */,
16 bool kHasLogitsSoftCap_,
17 BlockAttentionBiasEnum BiasEnum_,
18 bool kHasBiasGrad_,
19 bool kStoreLSE_,
20 bool kHasDropout_,
21 bool kDoFp8StaticQuant_,
22 index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
23 bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
25{
26 static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
27 static constexpr bool kPadSeqLenK = kPadSeqLenK_;
28 static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
29 static constexpr bool kPadHeadDimV = kPadHeadDimV_;
30 static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
31 static constexpr auto BiasEnum = BiasEnum_;
32 static constexpr bool kHasBiasGrad = kHasBiasGrad_;
33 static constexpr bool kStoreLSE = kStoreLSE_;
34 static constexpr bool kHasDropout = kHasDropout_;
35 static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
36 static constexpr index_t kBlockPerCu = kBlockPerCu_; // -1 (default) = do not override occupancy
37 static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
38};
39
// Compile-time trait bundle whose head-dim padding parameters are index_t values
// rather than bool flags. The static_asserts at the bottom restrict both
// kPadHeadDimQ and kPadHeadDimV to one of {0, 1, 8}.
// NOTE(review): presumably 0 = no padding and 1 / 8 select a padding granularity
// (e.g. an element/vector width) — TODO confirm against the consuming pipeline.
// NOTE(review): the struct declaration line (source line 45) is missing from this listing.
40template <index_t kPadHeadDimQ_ /* padding for hdim_q */,
41 index_t kPadHeadDimV_ /* padding for hdim_v */,
42 BlockAttentionBiasEnum BiasEnum_,
43 bool kHasBiasGrad_,
44 index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
46{
47 static constexpr index_t kPadHeadDimQ = kPadHeadDimQ_;
48 static constexpr index_t kPadHeadDimV = kPadHeadDimV_;
49 static constexpr auto BiasEnum = BiasEnum_;
50 static constexpr bool kHasBiasGrad = kHasBiasGrad_;
51 static constexpr index_t kBlockPerCu = kBlockPerCu_; // -1 (default) = do not override occupancy
52
53 static_assert(kPadHeadDimQ == 0 || kPadHeadDimQ == 8 || kPadHeadDimQ == 1);
54 static_assert(kPadHeadDimV == 0 || kPadHeadDimV == 8 || kPadHeadDimV == 1);
55};
56
// Compile-time trait bundle that adds a paged-KV flag (kIsPagedKV) in place of the
// dropout flag. Each template argument is re-exported as a static constexpr member
// with the trailing '_' dropped.
// NOTE(review): the struct declaration line (source line 69) is missing from this listing.
57template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
58 bool kPadSeqLenK_ /* padding for seqlen_k */,
59 bool kPadHeadDimQ_ /* padding for hdim_q */,
60 bool kPadHeadDimV_ /* padding for hdim_v */,
61 bool kHasLogitsSoftCap_,
62 BlockAttentionBiasEnum BiasEnum_,
63 bool kHasBiasGrad_,
64 bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
65 bool kIsPagedKV_,
66 bool kDoFp8StaticQuant_,
67 index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
68 bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
70{
71 static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
72 static constexpr bool kPadSeqLenK = kPadSeqLenK_;
73 static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
74 static constexpr bool kPadHeadDimV = kPadHeadDimV_;
75 static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
76 static constexpr auto BiasEnum = BiasEnum_;
77 static constexpr bool kHasBiasGrad = kHasBiasGrad_;
78 static constexpr bool kStoreLSE = kStoreLSE_;
79 static constexpr bool kIsPagedKV = kIsPagedKV_;
80 static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
81 static constexpr index_t kBlockPerCu = kBlockPerCu_; // -1 (default) = do not override occupancy
82 static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
83};
84
// Compile-time trait bundle with split-related flags (kHasUnevenSplits,
// kMergeNumHeadGroupsSeqLenQ). Each template argument is re-exported as a
// static constexpr member with the trailing '_' dropped.
// NOTE(review): the template-parameter order here is kDoFp8StaticQuant_, kIsPagedKV_,
// whereas the paged-KV trait bundle earlier in this file uses kIsPagedKV_,
// kDoFp8StaticQuant_. Both are bool, so passing them in the wrong order compiles
// silently — double-check call sites when instantiating either bundle.
// NOTE(review): the struct declaration line (source line 98) is missing from this listing.
85template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
86 bool kPadSeqLenK_ /* padding for seqlen_k */,
87 bool kPadHeadDimQ_ /* padding for hdim_q */,
88 bool kPadHeadDimV_ /* padding for hdim_v */,
89 bool kHasLogitsSoftCap_,
90 BlockAttentionBiasEnum BiasEnum_,
91 bool kHasBiasGrad_,
92 bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
93 bool kDoFp8StaticQuant_,
94 bool kIsPagedKV_,
95 bool kHasUnevenSplits_,
96 bool kMergeNumHeadGroupsSeqLenQ_ = false,
97 index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
99{
100 static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
101 static constexpr bool kPadSeqLenK = kPadSeqLenK_;
102 static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
103 static constexpr bool kPadHeadDimV = kPadHeadDimV_;
104 static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
105 static constexpr auto BiasEnum = BiasEnum_;
106 static constexpr bool kHasBiasGrad = kHasBiasGrad_;
107 static constexpr bool kStoreLSE = kStoreLSE_;
108 static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
109 static constexpr bool kIsPagedKV = kIsPagedKV_;
110 // determine if some split (length) is not divisible by tile size
111 static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
112 static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
113 static constexpr index_t kBlockPerCu = kBlockPerCu_; // -1 (default) = do not override occupancy
114};
115
// Compile-time trait bundle parameterized by log2 of the maximum split count.
// kMaxSplits = 2^kLogMaxSplits_. The static_assert requires kMaxSplits either to
// fit within one warp (<= get_warp_size()) or to be a whole multiple of the warp size.
// NOTE(review): kLogMaxSplits_ is not range-checked. index_t is int32_t, so
// kLogMaxSplits_ == 31 may yield a negative kMaxSplits (1 << 31 wraps), which then
// trivially satisfies 'kMaxSplits <= get_warp_size()'. Negative or >= 32 values make
// the shift an invalid constant expression. Consider adding
// static_assert(kLogMaxSplits_ >= 0 && kLogMaxSplits_ < 31).
// NOTE(review): the struct declaration line (source line 122) is missing from this listing.
116template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
117 bool kPadHeadDimV_ /* padding for hdim_v */,
118 bool kStoreLSE_,
119 bool kDoFp8StaticQuant_,
120 index_t kLogMaxSplits_,
121 index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
123{
124 static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
125 static constexpr bool kPadHeadDimV = kPadHeadDimV_;
126 static constexpr bool kStoreLSE = kStoreLSE_;
127 static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
128
129 static constexpr index_t kMaxSplits = (1 << kLogMaxSplits_); // 2^kLogMaxSplits_
130 static_assert(kMaxSplits <= get_warp_size() || kMaxSplits % get_warp_size() == 0);
131 static constexpr index_t kBlockPerCu = kBlockPerCu_; // -1 (default) = do not override occupancy
132};
133
// Compile-time trait bundle carrying only the four padding flags and an optional
// occupancy override; each argument is re-exported as a static constexpr member.
// NOTE(review): the struct declaration line (source line 139) is missing from this listing.
134template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
135 bool kPadSeqLenK_ /* padding for seqlen_k */,
136 bool kPadHeadDimQ_ /* padding for hdim_q */,
137 bool kPadHeadDimV_ /* padding for hdim_v */,
138 index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
140{
141 static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
142 static constexpr bool kPadSeqLenK = kPadSeqLenK_;
143 static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
144 static constexpr bool kPadHeadDimV = kPadHeadDimV_;
145 static constexpr index_t kBlockPerCu = kBlockPerCu_; // -1 (default) = do not override occupancy
146};
147
// Compile-time trait bundle with seqlen_q / hdim_v padding flags and an occupancy hint.
// Unlike most bundles in this file, kBlockPerCu defaults to 2 (a hint), not -1.
// NOTE(review): the struct declaration line (source line 151) is missing from this listing.
148template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
149 bool kPadHeadDimV_ /* padding for hdim_v */,
150 index_t kBlockPerCu_ = 2 /* hint to occupancy */>
152{
153 static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
154 static constexpr bool kPadHeadDimV = kPadHeadDimV_;
155 static constexpr index_t kBlockPerCu = kBlockPerCu_;
156};
157
// Compile-time trait bundle with seqlen_q / hdim_q padding flags and an occupancy hint.
// Unlike most bundles in this file, kBlockPerCu defaults to 2 (a hint), not -1.
// NOTE(review): the struct declaration line (source line 161) is missing from this listing.
158template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
159 bool kPadHeadDimQ_ /* padding for hdim_q */,
160 index_t kBlockPerCu_ = 2 /* hint to occupancy */>
162{
163 static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
164 static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
165 static constexpr index_t kBlockPerCu = kBlockPerCu_;
166};
167
// Compile-time trait bundle: the four padding flags plus kStoreLSE and an optional
// occupancy override; each argument is re-exported as a static constexpr member.
// NOTE(review): the struct declaration line (source line 174) is missing from this listing.
168template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
169 bool kPadSeqLenK_ /* padding for seqlen_k */,
170 bool kPadHeadDimQ_ /* padding for hdim_q */,
171 bool kPadHeadDimV_ /* padding for hdim_v */,
172 bool kStoreLSE_,
173 index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
175{
176 static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
177 static constexpr bool kPadSeqLenK = kPadSeqLenK_;
178 static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
179 static constexpr bool kPadHeadDimV = kPadHeadDimV_;
180 static constexpr bool kStoreLSE = kStoreLSE_;
181 static constexpr index_t kBlockPerCu = kBlockPerCu_; // -1 (default) = do not override occupancy
182};
183
184} // namespace ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
BlockAttentionBiasEnum
Definition block_attention_bias_enum.hpp:12
int32_t index_t
Definition integer.hpp:9
Definition tile_fmha_traits.hpp:162
static constexpr index_t kBlockPerCu
Definition tile_fmha_traits.hpp:165
static constexpr bool kPadHeadDimQ
Definition tile_fmha_traits.hpp:164
static constexpr bool kPadSeqLenQ
Definition tile_fmha_traits.hpp:163
Definition tile_fmha_traits.hpp:152
static constexpr index_t kBlockPerCu
Definition tile_fmha_traits.hpp:155
static constexpr bool kPadSeqLenQ
Definition tile_fmha_traits.hpp:153
static constexpr bool kPadHeadDimV
Definition tile_fmha_traits.hpp:154
Definition tile_fmha_traits.hpp:46
static constexpr index_t kPadHeadDimQ
Definition tile_fmha_traits.hpp:47
static constexpr index_t kPadHeadDimV
Definition tile_fmha_traits.hpp:48
static constexpr bool kHasBiasGrad
Definition tile_fmha_traits.hpp:50
static constexpr auto BiasEnum
Definition tile_fmha_traits.hpp:49
static constexpr index_t kBlockPerCu
Definition tile_fmha_traits.hpp:51
Definition tile_fmha_traits.hpp:140
static constexpr bool kPadHeadDimQ
Definition tile_fmha_traits.hpp:143
static constexpr bool kPadSeqLenK
Definition tile_fmha_traits.hpp:142
static constexpr index_t kBlockPerCu
Definition tile_fmha_traits.hpp:145
static constexpr bool kPadSeqLenQ
Definition tile_fmha_traits.hpp:141
static constexpr bool kPadHeadDimV
Definition tile_fmha_traits.hpp:144
Definition tile_fmha_traits.hpp:70
static constexpr auto BiasEnum
Definition tile_fmha_traits.hpp:76
static constexpr bool kDoFp8StaticQuant
Definition tile_fmha_traits.hpp:80
static constexpr bool kPadHeadDimQ
Definition tile_fmha_traits.hpp:73
static constexpr index_t kBlockPerCu
Definition tile_fmha_traits.hpp:81
static constexpr bool kHasLogitsSoftCap
Definition tile_fmha_traits.hpp:75
static constexpr bool kSkipMinSeqlenQ
Definition tile_fmha_traits.hpp:82
static constexpr bool kStoreLSE
Definition tile_fmha_traits.hpp:78
static constexpr bool kPadSeqLenQ
Definition tile_fmha_traits.hpp:71
static constexpr bool kIsPagedKV
Definition tile_fmha_traits.hpp:79
static constexpr bool kPadHeadDimV
Definition tile_fmha_traits.hpp:74
static constexpr bool kPadSeqLenK
Definition tile_fmha_traits.hpp:72
static constexpr bool kHasBiasGrad
Definition tile_fmha_traits.hpp:77
Definition tile_fmha_traits.hpp:123
static constexpr bool kPadSeqLenQ
Definition tile_fmha_traits.hpp:124
static constexpr bool kPadHeadDimV
Definition tile_fmha_traits.hpp:125
static constexpr bool kDoFp8StaticQuant
Definition tile_fmha_traits.hpp:127
static constexpr index_t kMaxSplits
Definition tile_fmha_traits.hpp:129
static constexpr bool kStoreLSE
Definition tile_fmha_traits.hpp:126
static constexpr index_t kBlockPerCu
Definition tile_fmha_traits.hpp:131
Definition tile_fmha_traits.hpp:99
static constexpr index_t kBlockPerCu
Definition tile_fmha_traits.hpp:113
static constexpr bool kStoreLSE
Definition tile_fmha_traits.hpp:107
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition tile_fmha_traits.hpp:112
static constexpr bool kPadHeadDimQ
Definition tile_fmha_traits.hpp:102
static constexpr bool kHasLogitsSoftCap
Definition tile_fmha_traits.hpp:104
static constexpr bool kHasUnevenSplits
Definition tile_fmha_traits.hpp:111
static constexpr bool kPadSeqLenK
Definition tile_fmha_traits.hpp:101
static constexpr bool kHasBiasGrad
Definition tile_fmha_traits.hpp:106
static constexpr bool kDoFp8StaticQuant
Definition tile_fmha_traits.hpp:108
static constexpr auto BiasEnum
Definition tile_fmha_traits.hpp:105
static constexpr bool kPadHeadDimV
Definition tile_fmha_traits.hpp:103
static constexpr bool kPadSeqLenQ
Definition tile_fmha_traits.hpp:100
static constexpr bool kIsPagedKV
Definition tile_fmha_traits.hpp:109
Definition tile_fmha_traits.hpp:175
static constexpr bool kStoreLSE
Definition tile_fmha_traits.hpp:180
static constexpr bool kPadHeadDimQ
Definition tile_fmha_traits.hpp:178
static constexpr bool kPadSeqLenK
Definition tile_fmha_traits.hpp:177
static constexpr bool kPadHeadDimV
Definition tile_fmha_traits.hpp:179
static constexpr index_t kBlockPerCu
Definition tile_fmha_traits.hpp:181
static constexpr bool kPadSeqLenQ
Definition tile_fmha_traits.hpp:176
Definition tile_fmha_traits.hpp:25
static constexpr bool kDoFp8StaticQuant
Definition tile_fmha_traits.hpp:35
static constexpr bool kHasBiasGrad
Definition tile_fmha_traits.hpp:32
static constexpr bool kPadHeadDimQ
Definition tile_fmha_traits.hpp:28
static constexpr bool kPadSeqLenQ
Definition tile_fmha_traits.hpp:26
static constexpr bool kPadHeadDimV
Definition tile_fmha_traits.hpp:29
static constexpr bool kHasDropout
Definition tile_fmha_traits.hpp:34
static constexpr bool kHasLogitsSoftCap
Definition tile_fmha_traits.hpp:30
static constexpr index_t kBlockPerCu
Definition tile_fmha_traits.hpp:36
static constexpr bool kSkipMinSeqlenQ
Definition tile_fmha_traits.hpp:37
static constexpr bool kPadSeqLenK
Definition tile_fmha_traits.hpp:27
static constexpr bool kStoreLSE
Definition tile_fmha_traits.hpp:33
static constexpr auto BiasEnum
Definition tile_fmha_traits.hpp:31