wmma_gemm.hpp Source File

wmma_gemm.hpp Source File

Composable Kernel: wmma_gemm.hpp Source File
wmma_gemm.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7#include "ck/utility/math.hpp"
9
10namespace ck {
11
30
31/*
32 * The WMMA wave tile is always MxNxK = 16x16x16
33 * WAVE32
34 -----------------------------------
35 |RC0| | | | | | | | | | | | | | | | SubGroup 0
36 |RC1| | | | | | | | | | | | | | | |
37 |RC2| | | | | | | | | | | | | | | |
38 |RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
39 |RC4|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
40 |RC5|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
41 |RC6| | | | | | | | | | | | | | | |
42 |RC7| | | | | | | | | | | | | | | |
43 -----------------------------------
44 | | | | | | | | | | | | | | | | | SubGroup 1
45 | | | | | | | | | | | | | | | | |
46 | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
47 | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
48 | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
49 | | | | | | | | | | | | | | | | |
50 | | | | | | | | | | | | | | | | |
51 | | | | | | | | | | | | | | | | |
52 -----------------------------------
53
54
55 * WAVE64
56 -----------------------------------
57 |RC0|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 0
58 |RC1|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
59 |RC2|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
60 |RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
61 -----------------------------------
62 | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 1
63 | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
64 | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
65 | | | | | | | | | | | | | | | | |
66 -----------------------------------
67 | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 2
68 | 3 |3|3|3|3|3|3|3|4|4|4|4|4|4|4|4|
69 | 2 |3|4|5|6|7|8|9|0|1|2|3|4|5|6|7|
70 | | | | | | | | | | | | | | | | |
71 -----------------------------------
72 | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 3
73 | 4 |4|5|5|5|5|5|5|5|5|5|5|6|6|6|6|
74 | 8 |9|0|1|2|3|4|5|6|7|8|9|0|1|2|3|
75 | | | | | | | | | | | | | | | | |
76 -----------------------------------
77
78* RC = Register for storing the accumulated result
79* T = Thread ID
80*/
81
82template <WmmaInstr Instr, index_t WaveSize, typename = void>
84{
85};
86
87// A-swizzled
88template <index_t WaveSize>
90 WaveSize,
91 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
92{
93 // Absolute fixing property
94 // * Data Pixel
95 static constexpr index_t m_per_wmma = 16;
96 static constexpr index_t n_per_wmma = 16;
97 static constexpr index_t k_per_wmma = 16;
98 static constexpr index_t src_a_data_size = 2;
99 static constexpr index_t src_b_data_size = 2;
100 static constexpr index_t acc_data_size = 4;
101 static constexpr index_t acc_pack_number = 1;
102 // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
104
105 // Wave mode dependent propety
106 static constexpr index_t wave_size = Number<WaveSize>{};
107 // * Fixed on gfx11, Will be wave mode dependent for future architectures
110 // * num_acc_vgprs_per_wave alone M direction
111 // * num_subgroups alone M direction
115
116 template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
117 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
118 {
119 if constexpr(wave_size == 32)
120 {
122 }
123 else if constexpr(wave_size == 64)
124 {
126 }
127 }
128};
129
130template <index_t WaveSize>
132 WaveSize,
133 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
134{
135 // Absolute fixing property
136 static constexpr index_t m_per_wmma = 16;
137 static constexpr index_t n_per_wmma = 16;
138 static constexpr index_t k_per_wmma = 16;
139 static constexpr index_t src_a_data_size = 2;
140 static constexpr index_t src_b_data_size = 2;
141 static constexpr index_t acc_data_size = 4;
142 static constexpr index_t acc_pack_number = 1;
144
145 // Wave mode dependent propety
146 static constexpr index_t wave_size = Number<WaveSize>{};
152
153 template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
154 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
155 {
156 if constexpr(wave_size == 32)
157 {
159 }
160 else if constexpr(wave_size == 64)
161 {
163 }
164 }
165};
166
167template <index_t WaveSize>
169 WaveSize,
170 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
171{
172 // Absolute fixing property
173 static constexpr index_t m_per_wmma = 16;
174 static constexpr index_t n_per_wmma = 16;
175 static constexpr index_t k_per_wmma = 16;
176 static constexpr index_t src_a_data_size = 2;
177 static constexpr index_t src_b_data_size = 2;
178 static constexpr index_t acc_data_size = 2;
179 static constexpr index_t acc_pack_number = 2;
181
182 // Wave mode dependent propety
183 static constexpr index_t wave_size = Number<WaveSize>{};
189
190 template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
191 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
192 {
193 if constexpr(wave_size == 32)
194 {
196 }
197 else if constexpr(wave_size == 64)
198 {
200 }
201 }
202};
203template <index_t WaveSize>
205 WaveSize,
206 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
207{
208 // Absolute fixing property
209 static constexpr index_t m_per_wmma = 16;
210 static constexpr index_t n_per_wmma = 16;
211 static constexpr index_t k_per_wmma = 16;
212 static constexpr index_t src_a_data_size = 2;
213 static constexpr index_t src_b_data_size = 2;
214 static constexpr index_t acc_data_size = 2;
215 static constexpr index_t acc_pack_number = 2;
217
218 // Wave mode dependent propety
219 static constexpr index_t wave_size = Number<WaveSize>{};
225
226 template <index_t MPerWmma,
227 index_t NPerWmma,
228 index_t Opsel,
229 class FloatA,
230 class FloatB,
231 class FloatC>
232 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
233 {
234 if constexpr(wave_size == 32)
235 {
237 }
238 else if constexpr(wave_size == 64)
239 {
241 }
242 }
243};
244
245template <index_t WaveSize>
247 WaveSize,
248 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
249{
250 // Absolute fixing property
251 static constexpr index_t m_per_wmma = 16;
252 static constexpr index_t n_per_wmma = 16;
253 static constexpr index_t k_per_wmma = 16;
254 static constexpr index_t src_a_data_size = 2;
255 static constexpr index_t src_b_data_size = 2;
256 static constexpr index_t acc_data_size = 4;
257 static constexpr index_t acc_pack_number = 1;
259
260 // Wave mode dependent propety
261 static constexpr index_t wave_size = Number<WaveSize>{};
267
268 template <index_t MPerWmma,
269 index_t NPerWmma,
270 class FloatA,
271 class FloatB,
272 class FloatC,
273 bool neg_a = true,
274 bool neg_b = true,
275 bool clamp = false>
276 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
277 {
278 if constexpr(wave_size == 32)
279 {
281 a, b, reg_c);
282 }
283 else if constexpr(wave_size == 64)
284 {
286 a, b, reg_c);
287 }
288 }
289};
290
291// gfx12
292
293// A-swizzled
294template <index_t WaveSize>
296 WaveSize,
297 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
298{
299 // Absolute fixing property
300 // * Data Pixel
301 static constexpr index_t m_per_wmma = 16;
302 static constexpr index_t n_per_wmma = 16;
303 static constexpr index_t k_per_wmma = 16;
304 // static constexpr index_t src_a_data_size = 2;
305 // static constexpr index_t src_b_data_size = 2;
306 // static constexpr index_t acc_data_size = 4;
307 // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
308 static constexpr index_t acc_data_size = 4;
309 static constexpr index_t acc_pack_number = 1;
311
312 // Wave mode dependent propety
313 static constexpr index_t wave_size = Number<WaveSize>{};
314 // * Fixed for gfx11, Will be wave mode dependent on gfx12
315 // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
316 // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
317 // * num_acc_vgprs_per_wave alone M direction
318 // * num_subgroups alone M direction
321
322 template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
323 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
324 {
325 static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
326 if constexpr(wave_size == 32)
327 {
329 }
330 }
331};
332
333template <index_t WaveSize>
335 WaveSize,
336 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
337{
338 // Absolute fixing property
339 static constexpr index_t m_per_wmma = 16;
340 static constexpr index_t n_per_wmma = 16;
341 static constexpr index_t k_per_wmma = 16;
342 // static constexpr index_t src_a_data_size = 2;
343 // static constexpr index_t src_b_data_size = 2;
344 static constexpr index_t acc_data_size = 4;
345 static constexpr index_t acc_pack_number = 1;
347
348 // Wave mode dependent propety
349 static constexpr index_t wave_size = Number<WaveSize>{};
350 // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
351 // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
354
355 template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
356 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
357 {
358 static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
359 if constexpr(wave_size == 32)
360 {
362 }
363 }
364};
365
366template <index_t WaveSize>
368 WaveSize,
369 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
370{
371 // Absolute fixing property
372 static constexpr index_t m_per_wmma = 16;
373 static constexpr index_t n_per_wmma = 16;
374 static constexpr index_t k_per_wmma = 16;
375 // static constexpr index_t src_a_data_size = 2;
376 // static constexpr index_t src_b_data_size = 2;
377 static constexpr index_t acc_data_size = 4;
378 static constexpr index_t acc_pack_number = 1;
380
381 // Wave mode dependent propety
382 static constexpr index_t wave_size = Number<WaveSize>{};
383 // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
384 // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
387
388 template <index_t MPerWmma,
389 index_t NPerWmma,
390 class FloatA,
391 class FloatB,
392 class FloatC,
393 bool neg_a = true,
394 bool neg_b = true,
395 bool clamp = false>
396 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
397 {
398 static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
399 if constexpr(wave_size == 32)
400 {
402 a, b, reg_c);
403 }
404 }
405};
406
407template <index_t WaveSize>
409 WaveSize,
410 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
411{
412 // Absolute fixing property
413 static constexpr index_t m_per_wmma = 16;
414 static constexpr index_t n_per_wmma = 16;
415 static constexpr index_t k_per_wmma = 16;
416 static constexpr index_t acc_data_size = 4;
417 static constexpr index_t acc_pack_number = 1;
419
420 // Wave mode dependent propety
421 static constexpr index_t wave_size = Number<WaveSize>{};
424
425 template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
426 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
427 {
428 static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
429 if constexpr(wave_size == 32)
430 {
431#ifdef __gfx12__
433#else
434 ignore = a;
435 ignore = b;
436 ignore = reg_c;
437#endif
438 }
439 }
440};
441
442template <index_t WaveSize>
444 WaveSize,
445 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
446{
447 // Absolute fixing property
448 static constexpr index_t m_per_wmma = 16;
449 static constexpr index_t n_per_wmma = 16;
450 static constexpr index_t k_per_wmma = 16;
451 static constexpr index_t acc_data_size = 4;
452 static constexpr index_t acc_pack_number = 1;
454
455 // Wave mode dependent propety
456 static constexpr index_t wave_size = Number<WaveSize>{};
459
460 template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
461 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
462 {
463 static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
464 if constexpr(wave_size == 32)
465 {
466#ifdef __gfx12__
468#else
469 ignore = a;
470 ignore = b;
471 ignore = reg_c;
472#endif
473 }
474 }
475};
476
477template <index_t WaveSize>
479 WaveSize,
480 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
481{
482 // Absolute fixing property
483 static constexpr index_t m_per_wmma = 16;
484 static constexpr index_t n_per_wmma = 16;
485 static constexpr index_t k_per_wmma = 16;
486 static constexpr index_t acc_data_size = 4;
487 static constexpr index_t acc_pack_number = 1;
489
490 // Wave mode dependent propety
491 static constexpr index_t wave_size = Number<WaveSize>{};
494
495 template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
496 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
497 {
498 static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
499 if constexpr(wave_size == 32)
500 {
501#ifdef __gfx12__
503#else
504 ignore = a;
505 ignore = b;
506 ignore = reg_c;
507#endif
508 }
509 }
510};
511
512template <index_t WaveSize>
514 WaveSize,
515 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
516{
517 // Absolute fixing property
518 static constexpr index_t m_per_wmma = 16;
519 static constexpr index_t n_per_wmma = 16;
520 static constexpr index_t k_per_wmma = 16;
521 static constexpr index_t acc_data_size = 4;
522 static constexpr index_t acc_pack_number = 1;
524
525 // Wave mode dependent propety
526 static constexpr index_t wave_size = Number<WaveSize>{};
529
530 template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
531 __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
532 {
533 static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
534 if constexpr(wave_size == 32)
535 {
536#ifdef __gfx12__
538#else
539 ignore = a;
540 ignore = b;
541 ignore = reg_c;
542#endif
543 }
544 }
545};
546
547template <typename src_type_a,
548 typename src_type_b,
549 typename dst_type,
550 index_t MPerWmma,
551 index_t NPerWmma>
553{
554 template <typename src_type_a_,
555 typename src_type_b_,
556 typename dst_type_,
557 index_t MPerWmma_,
558 index_t NPerWmma_>
559 static constexpr auto GetWmma();
560
561 template <>
563 {
564#ifdef __gfx12__
566#else
568#endif
569 }
570
571 template <>
573 {
574#ifdef __gfx12__
576#else
578#endif
579 }
580
581 template <>
586
587 template <>
592
593 template <>
595 {
596#ifdef __gfx12__
598#else
600#endif
601 }
602
603#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
604 template <>
606 {
608 }
609#endif
610
611 template <>
616
617 template <>
622
623 template <>
628
629 template <>
634
635 // get_warp_size do not return the correct wavesize, hardcode to 32 as workaround
638
639 __host__ __device__ constexpr WmmaSelector()
640 {
641 static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
642
643 static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
644
645 static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
646
647 static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
648 selected_wmma.acc_data_size * selected_wmma.acc_pack_number ==
649 selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
650 "WRONG! Invalid Number of Accumulator Register");
651 }
652};
653
654template <typename src_type_a,
655 typename src_type_b,
656 typename dst_type,
657 index_t MPerWmma,
658 index_t NPerWmma,
659 index_t KPack,
660 bool TransposeC = false,
661 bool AssemblyBackend = false>
663{
664 static constexpr auto I0 = Number<0>{};
665 static constexpr auto I1 = Number<1>{};
666 static constexpr auto I2 = Number<2>{};
667 static constexpr auto I3 = Number<3>{};
668 static constexpr auto I4 = Number<4>{};
669 static constexpr auto I5 = Number<5>{};
670
673
674 __host__ __device__ constexpr WmmaGemm()
675 {
676 static_assert(NPerWmma == 16 && MPerWmma == 16,
677 "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
678
679 static_assert(KPack % wmma_instr.k_per_wmma == 0, "KPack should be multiple of k_per_wmma");
680 }
681
682 // WMMA output supporting C = A * B
683 // Vector Write
684 // MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
685 template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
686 __host__ __device__ static constexpr auto
688 const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
689 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
690 {
691 const auto MBlockxRepeat =
692 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
693 const auto NBlockxRepeat =
694 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
695 const auto MWave =
696 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
697 const auto NWave =
698 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
699
701 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
703 make_pass_through_transform(MBlockxRepeat),
706 Number<wmma_instr.num_acc_vgprs_per_wave>{})),
707 make_pass_through_transform(NBlockxRepeat),
711 Sequence<1>{},
712 Sequence<2>{},
713 Sequence<3>{},
714 Sequence<4>{},
715 Sequence<5>{}),
717 Sequence<1>{},
719 Sequence<3>{},
720 Sequence<4>{},
721 Sequence<5>{}));
722 }
723
724 // Transposed WMMA Output C' = B' * A'
725 template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
726 __host__ __device__ static constexpr auto
728 const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
729 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
730 {
731 const auto MBlockxRepeat =
732 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
733 const auto NBlockxRepeat =
734 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
735 const auto MWave =
736 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
737 const auto NWave =
738 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
739
741 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
743 make_pass_through_transform(MBlockxRepeat),
746 make_pass_through_transform(NBlockxRepeat),
749 Number<wmma_instr.num_acc_vgprs_per_wave>{}))),
751 Sequence<1>{},
752 Sequence<2>{},
753 Sequence<3>{},
754 Sequence<4>{},
755 Sequence<5>{}),
757 Sequence<1>{},
758 Sequence<2>{},
759 Sequence<3>{},
760 Sequence<4>{},
761 Sequence<5, 6>{}));
762 }
763
764 __device__ static constexpr index_t GetRegSizePerWmma()
765 {
766 return wmma_instr.num_acc_vgprs_per_wave * wmma_instr.acc_pack_number;
767 }
768
769 __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
770
771 template <class FloatA, class FloatB, class FloatC>
772 __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
773 {
774 static_assert(
788#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
791#endif
792 false,
793 "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
794 "((f8 or bf8, f8 or bf8), float), (int8, int32) or (int4, int32)!");
795 static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) {
796 // Integer wmma operators need extra input flags to indicate if the input is signed or
797 // unsigned. At the moment CK supports only signed integer inputs, so these flags are
798 // hardcoded.
799 if constexpr(!TransposeC)
800 {
801 wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);
802 }
803 else
804 {
805 wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave[k], p_a_wave[k], p_c_thread);
806 }
807 });
808 }
809
810 __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }
811
812 __device__ static auto GetSubGroupId()
813 {
814 static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups ==
815 wmma_instr.wave_size,
816 "");
817 return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
818 }
819
820 __device__ static auto GetLaneIdUnderSubGroup()
821 {
822 return GetLaneId() % wmma_instr.num_thread_per_subgroups;
823 }
824 __device__ static auto GetSwizzledLaneIdLow()
825 {
826 return ((GetLaneIdUnderSubGroup() & 1) << 3) | (GetLaneIdUnderSubGroup() >> 1);
827 }
828
829 __host__ __device__ static auto CalculateAThreadOriginDataIndex()
830 {
831#ifdef __gfx12__
832 return GetLaneIdUnderSubGroup();
833#else
834 return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
835#endif
836 }
837
838 __host__ __device__ static auto CalculateBThreadOriginDataIndex()
839 {
840#ifdef __gfx12__
841 return GetLaneIdUnderSubGroup();
842#else
843 return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
844#endif
845 }
846
847 __device__ static CIndex GetBeginOfThreadBlk()
848 {
849 index_t n_offset = GetLaneIdUnderSubGroup();
850 index_t m_offset = GetSubGroupId() * wmma_instr.num_acc_vgprs_per_wave;
851
852 return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
853 }
854
855 __device__ static CIndex3D GetBeginOfThreadBlk3D()
856 {
857 index_t n_offset = GetLaneIdUnderSubGroup();
858 index_t m_offset = GetSubGroupId();
859
860 return TransposeC ? CIndex3D{n_offset, m_offset, I0} : CIndex3D{m_offset, n_offset, I0};
861 }
862
865 static constexpr auto wmma_instr = wmma.selected_wmma;
866
867 __host__ __device__ static constexpr auto
875};
876
877} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
WmmaInstr
Definition wmma_gemm.hpp:13
@ wmma_f32_16x16x16_bf16_gfx12
Definition wmma_gemm.hpp:23
@ wmma_i32_16x16x16_iu8_gfx12
Definition wmma_gemm.hpp:24
@ wmma_f32_16x16x16_bf8f8_gfx12
Definition wmma_gemm.hpp:27
@ wmma_f32_16x16x16_f16_gfx12
Definition wmma_gemm.hpp:22
@ wmma_i32_16x16x16_iu4
Definition wmma_gemm.hpp:20
@ wmma_i32_16x16x16_iu8
Definition wmma_gemm.hpp:19
@ wmma_f32_16x16x16_bf8bf8_gfx12
Definition wmma_gemm.hpp:28
@ wmma_f32_16x16x16_f8f8_gfx12
Definition wmma_gemm.hpp:25
@ wmma_f32_16x16x16_bf16
Definition wmma_gemm.hpp:16
@ wmma_f32_16x16x16_f16
Definition wmma_gemm.hpp:15
@ wmma_bf16_16x16x16_bf16
Definition wmma_gemm.hpp:18
@ wmma_f16_16x16x16_f16
Definition wmma_gemm.hpp:17
@ wmma_f32_16x16x16_f8bf8_gfx12
Definition wmma_gemm.hpp:26
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition utility/sequence.hpp:43
static constexpr auto I0
Definition wmma_gemm.hpp:664
__host__ static __device__ constexpr auto MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA &c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
Definition wmma_gemm.hpp:727
static __device__ constexpr index_t GetRegSizePerWmma()
Definition wmma_gemm.hpp:764
static __device__ auto GetLaneId()
Definition wmma_gemm.hpp:810
static __device__ constexpr index_t GetWaveSize()
Definition wmma_gemm.hpp:769
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition wmma_gemm.hpp:772
__host__ __device__ constexpr WmmaGemm()
Definition wmma_gemm.hpp:674
static constexpr auto wmma
Definition wmma_gemm.hpp:863
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition wmma_gemm.hpp:829
static __device__ auto GetSubGroupId()
Definition wmma_gemm.hpp:812
__host__ static __device__ constexpr auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
Definition wmma_gemm.hpp:868
static __device__ auto GetSwizzledLaneIdLow()
Definition wmma_gemm.hpp:824
static constexpr auto I3
Definition wmma_gemm.hpp:667
static constexpr auto I5
Definition wmma_gemm.hpp:669
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition wmma_gemm.hpp:838
static __device__ CIndex GetBeginOfThreadBlk()
Definition wmma_gemm.hpp:847
static constexpr auto I4
Definition wmma_gemm.hpp:668
MultiIndex< 3 > CIndex3D
Definition wmma_gemm.hpp:672
__host__ static __device__ constexpr auto MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA &c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
Definition wmma_gemm.hpp:687
static constexpr auto I2
Definition wmma_gemm.hpp:666
static __device__ CIndex3D GetBeginOfThreadBlk3D()
Definition wmma_gemm.hpp:855
static constexpr auto I1
Definition wmma_gemm.hpp:665
static __device__ auto GetLaneIdUnderSubGroup()
Definition wmma_gemm.hpp:820
MultiIndex< 2 > CIndex
Definition wmma_gemm.hpp:671
static constexpr auto wmma_instr
Definition wmma_gemm.hpp:865
Definition wmma_gemm.hpp:553
static constexpr auto selected_wmma
Definition wmma_gemm.hpp:636
static constexpr auto GetWmma()
__host__ __device__ constexpr WmmaSelector()
Definition wmma_gemm.hpp:639
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition amd_wmma.hpp:96
Definition amd_wmma.hpp:216
Definition amd_wmma.hpp:72
Definition amd_wmma.hpp:192
Definition amd_wmma.hpp:50
Definition amd_wmma.hpp:170
Definition amd_wmma.hpp:25
Definition amd_wmma.hpp:149
Definition amd_wmma.hpp:121
Definition amd_wmma.hpp:241
Definition functional2.hpp:33
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:232
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:191
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:154
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:356
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:531
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:496
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:117
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:323
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:461
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:426
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:276
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:396
Definition wmma_gemm.hpp:84