device_gemm_wmma_cshuffle_v3_b_scale.hpp Source File

device_gemm_wmma_cshuffle_v3_b_scale.hpp Source File#

Composable Kernel: device_gemm_wmma_cshuffle_v3_b_scale.hpp Source File
device_gemm_wmma_cshuffle_v3_b_scale.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
20
21namespace ck {
22namespace tensor_operation {
23namespace device {
24
25template <typename ALayout,
26 typename BLayout,
27 typename CLayout,
28 typename ADataType,
29 typename BDataType,
30 typename BScaleDataType,
31 typename CDataType,
32 typename AccDataType,
33 typename CShuffleDataType,
34 typename AElementwiseOperation,
35 typename BElementwiseOperation,
36 typename CElementwiseOperation,
37 GemmSpecialization GemmSpec,
38 index_t BlockSize,
39 index_t ScaleBlockN, // scale block for N
40 index_t ScaleBlockK, // scale block for K
41 index_t MPerBlock,
42 index_t NPerBlock,
43 index_t KPerBlock,
44 index_t AK1,
45 index_t BK1,
46 index_t MPerWmma,
47 index_t NPerWmma,
48 index_t MRepeat,
49 index_t NRepeat,
50 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
51 typename ABlockTransferThreadClusterArrangeOrder,
52 typename ABlockTransferSrcAccessOrder,
53 index_t ABlockTransferSrcVectorDim,
54 index_t ABlockTransferSrcScalarPerVector,
55 index_t ABlockTransferDstScalarPerVector_AK1,
56 bool ABlockLdsExtraM,
57 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
58 typename BBlockTransferThreadClusterArrangeOrder,
59 typename BBlockTransferSrcAccessOrder,
60 index_t BBlockTransferSrcVectorDim,
61 index_t BBlockTransferSrcScalarPerVector,
62 index_t BBlockTransferDstScalarPerVector_BK1,
63 bool BBlockLdsExtraN,
64 index_t CShuffleMRepeatPerShuffle,
65 index_t CShuffleNRepeatPerShuffle,
66 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
67 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
70 typename ComputeTypeA = CDataType,
71 typename ComputeTypeB = ComputeTypeA,
72 bool PermuteA = false,
73 bool PermuteB = false>
75 BLayout,
76 CLayout,
77 ADataType,
78 BDataType,
79 BScaleDataType,
80 CDataType,
81 ScaleBlockN,
82 ScaleBlockK,
83 AElementwiseOperation,
84 BElementwiseOperation,
85 CElementwiseOperation>
86{
87
88 // GridwiseGemm
90 ALayout,
91 BLayout,
92 Tuple<>, // DsLayout
93 CLayout,
96 BScaleDataType,
97 AccDataType,
98 CShuffleDataType,
99 Tuple<>, // DsDataType
100 CDataType,
101 AElementwiseOperation,
102 BElementwiseOperation,
103 CElementwiseOperation,
104 GemmSpec,
105 BlockSize,
106 ScaleBlockN,
107 ScaleBlockK,
108 MPerBlock,
109 NPerBlock,
110 KPerBlock,
111 AK1,
112 BK1,
113 MPerWmma,
114 NPerWmma,
115 MRepeat,
116 NRepeat,
117 ABlockTransferThreadClusterLengths_AK0_M_AK1,
118 ABlockTransferThreadClusterArrangeOrder,
119 ABlockTransferSrcAccessOrder,
120 ABlockTransferSrcVectorDim,
121 ABlockTransferSrcScalarPerVector,
122 ABlockTransferDstScalarPerVector_AK1,
123 false,
124 ABlockLdsExtraM,
125 BBlockTransferThreadClusterLengths_BK0_N_BK1,
126 BBlockTransferThreadClusterArrangeOrder,
127 BBlockTransferSrcAccessOrder,
128 BBlockTransferSrcVectorDim,
129 BBlockTransferSrcScalarPerVector,
130 BBlockTransferDstScalarPerVector_BK1,
131 false,
132 BBlockLdsExtraN,
133 CShuffleMRepeatPerShuffle,
134 CShuffleNRepeatPerShuffle,
135 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
137 BlkGemmPipeSched,
138 BlkGemmPipelineVer,
139 ComputeTypeA,
140 ComputeTypeB,
141 PermuteA,
142 PermuteB>;
143
// Argument: alias for the Argument type nested in GridwiseGemm. MakeArgument() and
// MakeArgumentPointer() construct objects of this type.
144 using Argument = typename GridwiseGemm::Argument;
145
150 Tuple<>,
151 CDataType,
152 MPerBlock,
153 NPerBlock,
154 KPerBlock,
155 BlockSize,
156 AK1,
157 BK1,
158 GemmSpec,
160 BlkGemmPipeSched,
161 BlkGemmPipelineVer,
162 ComputeTypeA,
163 ComputeTypeB>;
164
165 // Invoker: alias for the Invoker type nested in DeviceGemmCommon; returned by
// MakeInvoker() and MakeInvokerPointer().
166 using Invoker = typename DeviceGemmCommon::Invoker;
167
168 static bool IsSupportedArgument(const Argument& arg)
169 {
171 }
172
173 // polymorphic
174 bool IsSupportedArgument(const BaseArgument* p_arg) override
175 {
176 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
177 }
178
179 index_t GetKPerBlock() override { return KPerBlock; }
180
181 bool GetPermuteB() override { return PermuteB; }
182
183 static auto MakeArgument(const ADataType* p_a,
184 const BDataType* p_b,
185 CDataType* p_c,
186 index_t M,
187 index_t N,
188 index_t K,
189 index_t StrideA,
190 index_t StrideB,
191 index_t StrideC,
192 index_t StrideScaleB,
193 const BScaleDataType* p_b_scale,
194 index_t KBatch,
195 AElementwiseOperation a_element_op,
196 BElementwiseOperation b_element_op,
197 CElementwiseOperation cde_element_op)
198 {
199 return Argument{std::array<const void*, 1>{p_a},
200 std::array<const void*, 1>{p_b},
201 std::array<const void*, 0>{}, // p_ds_grid_
202 p_c,
203 M,
204 N,
205 K,
206 std::array<index_t, 1>{StrideA},
207 std::array<index_t, 1>{StrideB},
208 std::array<index_t, 0>{}, // StrideDs_
209 StrideC,
210 StrideScaleB,
211 p_b_scale,
212 KBatch,
213 a_element_op,
214 b_element_op,
215 cde_element_op};
216 }
217
218 static auto MakeInvoker() { return Invoker{}; }
219
220 // polymorphic
221 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
222 const void* p_b,
223 void* p_c,
224 index_t M,
225 index_t N,
226 index_t K,
227 index_t StrideA,
228 index_t StrideB,
229 index_t StrideC,
230 index_t StrideScaleB,
231 const void* p_b_scale,
232 index_t KBatch,
233 AElementwiseOperation a_element_op,
234 BElementwiseOperation b_element_op,
235 CElementwiseOperation c_element_op) override
236 {
237 return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
238 std::array<const void*, 1>{p_b},
239 std::array<const void*, 0>{}, // p_ds_grid_
240 static_cast<CDataType*>(p_c),
241 M,
242 N,
243 K,
244 std::array<index_t, 1>{StrideA},
245 std::array<index_t, 1>{StrideB},
246 std::array<index_t, 0>{}, // StrideDs_
247 StrideC,
248 StrideScaleB,
249 static_cast<const BScaleDataType*>(p_b_scale),
250 KBatch,
251 a_element_op,
252 b_element_op,
253 c_element_op);
254 }
255
256 // polymorphic
257 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
258 {
259 return std::make_unique<Invoker>(Invoker{});
260 }
261
262 // polymorphic
263 std::string GetTypeString() const override
264 {
265 auto str = std::stringstream();
266
267 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
270
271 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
277
278 // clang-format off
279 str << "DeviceGemm_Wmma_CShuffleV3_BScale"
280 << "<"
281 << getGemmSpecializationString(GemmSpec) << ", "
282 << std::string(ALayout::name)[0]
283 << std::string(BLayout::name)[0]
284 << std::string(CLayout::name)[0]
285 << ">"
286 << " BlkSize: "
287 << BlockSize << ", "
288 << "BlkTile: "
289 << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
290 << "WaveTile: "
291 << MPerWmma<<"x"<<NPerWmma << ", "
292 << "WaveMap: "
293 << MRepeat<<"x" << NRepeat<<", "
294 << "VmemReadVec: "
295 << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
296 << "BlkGemmPipelineScheduler: "
297 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
298 << "BlkGemmPipelineVersion: "
299 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
300 << "BlkGemmPipelinePrefetchStages: "
301 << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
302 << "KPack: "
304 // clang-format on
305
306 return str.str();
307 }
308};
309
310} // namespace device
311} // namespace tensor_operation
312} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
Definition gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp:127
static constexpr index_t KPack
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:154
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:186
Definition utility/tuple.hpp:117
Definition device_base.hpp:197
Definition device_gemm_wmma_cshuffle_v3_b_scale.hpp:86
GridwiseGemm_wmma_cshuffle_v3_b_scale< ALayout, BLayout, Tuple<>, CLayout, Tuple< ADataType >, Tuple< BDataType >, BScaleDataType, AccDataType, CShuffleDataType, Tuple<>, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemm
Definition device_gemm_wmma_cshuffle_v3_b_scale.hpp:89
typename GridwiseGemm::Argument Argument
Definition device_gemm_wmma_cshuffle_v3_b_scale.hpp:144
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_wmma_cshuffle_v3_b_scale.hpp:174
bool GetPermuteB() override
Definition device_gemm_wmma_cshuffle_v3_b_scale.hpp:181
static auto MakeInvoker()
Definition device_gemm_wmma_cshuffle_v3_b_scale.hpp:218
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t StrideScaleB, const void *p_b_scale, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_wmma_cshuffle_v3_b_scale.hpp:221
std::string GetTypeString() const override
Definition device_gemm_wmma_cshuffle_v3_b_scale.hpp:263
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_wmma_cshuffle_v3_b_scale.hpp:257
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_wmma_cshuffle_v3_b_scale.hpp:168
index_t GetKPerBlock() override
Definition device_gemm_wmma_cshuffle_v3_b_scale.hpp:179
DeviceGemm_Wmma_CShuffleV3_Common< GridwiseGemm, Tuple< ADataType >, Tuple< BDataType >, Tuple<>, CDataType, MPerBlock, NPerBlock, KPerBlock, BlockSize, AK1, BK1, GemmSpec, Sequence< CShuffleBlockTransferScalarPerVector_NPerBlock >, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > DeviceGemmCommon
Definition device_gemm_wmma_cshuffle_v3_b_scale.hpp:146
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t StrideScaleB, const BScaleDataType *p_b_scale, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation cde_element_op)
Definition device_gemm_wmma_cshuffle_v3_b_scale.hpp:183
typename DeviceGemmCommon::Invoker Invoker
Definition device_gemm_wmma_cshuffle_v3_b_scale.hpp:166
Definition device_gemm_wmma_cshuffle_v3_common.hpp:43
Definition device_gemm_v2.hpp:93