blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp Source File

blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp Source File
blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Maximum Global Memory throughput pipeline with >=32KB data in flight
11// GlobalPrefetchStages: >=2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 0
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPacks>
39
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeDataType,
44 typename AccDataType,
45 typename ATileDesc,
46 typename BTileDesc,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
51 index_t MPerBlock,
52 index_t NPerBlock,
53 index_t KPerBlock,
54 index_t MPerXDL,
55 index_t NPerXDL,
56 index_t MRepeat,
57 index_t NRepeat,
58 index_t KPack
59 // ,bool TransposeC //disable transposec right now...
60 >
62 BlockSize,
63 ADataType,
64 BDataType,
65 ComputeDataType,
66 AccDataType,
67 ATileDesc,
68 BTileDesc,
69 AMmaTileDesc,
70 BMmaTileDesc,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
73 MPerBlock,
74 NPerBlock,
75 KPerBlock,
76 MPerXDL,
77 NPerXDL,
78 MRepeat,
79 NRepeat,
80 KPack>
82 ADataType,
83 BDataType,
84 ComputeDataType,
85 AccDataType,
86 ATileDesc,
87 BTileDesc,
88 AMmaTileDesc,
89 BMmaTileDesc,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
92 MPerBlock,
93 NPerBlock,
94 KPerBlock,
95 MPerXDL,
96 NPerXDL,
97 MRepeat,
98 NRepeat,
99 KPack>
100
101{
103 ADataType,
104 BDataType,
105 ComputeDataType,
106 AccDataType,
107 ATileDesc,
108 BTileDesc,
109 AMmaTileDesc,
110 BMmaTileDesc,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
113 MPerBlock,
114 NPerBlock,
115 KPerBlock,
116 MPerXDL,
117 NPerXDL,
118 MRepeat,
119 NRepeat,
120 KPack>;
121 using Base::I0;
122 using Base::KRepeat;
123 using Base::xdlops_gemm;
124
136
139
140 using Base::AMmaKStride;
141 using Base::BMmaKStride;
142 using Base::WaveSize;
143
145
146 static constexpr index_t WgpPerCU =
147 (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
149 32768 / WgpPerCU,
150 (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
151 static constexpr index_t PrefetchStages =
154 : 2;
155
156 static constexpr index_t PrefillStages = 1;
158
159 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
160 {
161 return num_loop > PrefetchStages;
162 }
163
164 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
165 {
166 if(num_loop % PrefetchStages == 1)
167 {
168 return TailNumber::One;
169 }
170 else if(num_loop % PrefetchStages == 2)
171 {
172 return TailNumber::Two;
173 }
174 else if(num_loop % PrefetchStages == 3)
175 {
176 return TailNumber::Three;
177 }
178 else if(num_loop % PrefetchStages == 4)
179 {
180 return TailNumber::Four;
181 }
182 else if(num_loop % PrefetchStages == 5)
183 {
184 return TailNumber::Five;
185 }
186 else if(num_loop % PrefetchStages == 6)
187 {
188 return TailNumber::Six;
189 }
190 else if(num_loop % PrefetchStages == 7)
191 {
192 return TailNumber::Seven;
193 }
194 else
195 {
196 return TailNumber::Full;
197 }
198 }
199
200 template <bool HasMainLoop,
201 TailNumber TailNum,
202 typename AGridDesc,
203 typename ABlockDesc,
204 typename ABlockTransfer,
205 typename AGridBuffer,
206 typename ABlockBuffer,
207 typename ABlockTransferStep,
208 typename BGridDesc,
209 typename BBlockDesc,
210 typename BBlockTransfer,
211 typename BGridBuffer,
212 typename BBlockBuffer,
213 typename BBlockTransferStep,
214 typename CThreadBuffer>
215 __device__ void Run(const AGridDesc& a_grid_desc,
216 const ABlockDesc& a_block_desc,
217 ABlockTransfer& a_blockwise_copy,
218 const AGridBuffer& a_grid_buf,
219 ABlockBuffer& a_block_buf,
220 const ABlockTransferStep& a_block_copy_step,
221 const BGridDesc& b_grid_desc,
222 const BBlockDesc& b_block_desc,
223 BBlockTransfer& b_blockwise_copy,
224 const BGridBuffer& b_grid_buf,
225 BBlockBuffer& b_block_buf,
226 const BBlockTransferStep& b_block_copy_step,
227 CThreadBuffer& c_thread_buf,
228 index_t num_loop) const
229 {
231 a_thread_desc_.GetElementSpaceSize());
233 b_thread_desc_.GetElementSpaceSize());
234
235 // Global prefetch 1
236 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
237 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
238
239 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
240 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
241
242 // Initialize C
243 c_thread_buf.Clear();
244
245 // Local prefill 1
246 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
247 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
248
249 // Global prefetch [2, PrefetchStages]
250 static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) {
251 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
252 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
253
254 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
255 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
256 });
257
258 // main body
259 if constexpr(HasMainLoop)
260 {
261 index_t i = 0;
262 do
263 {
264 static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) {
265 // -------------------------------------------------------------------------------------------
267 static_for<0, KRepeat, 1>{}([&](auto k) {
268 static_for<0, MRepeat, 1>{}([&](auto m0) {
271 a_block_buf,
273 make_tuple(m0, I0, k, I0),
274 a_thread_buf);
275 static_for<0, NRepeat, 1>{}([&](auto n0) {
276 b_thread_copy_.Run(
279 b_block_buf,
281 make_tuple(n0, I0, k, I0),
282 b_thread_buf);
283 });
284 });
285 });
286
287 static_for<0, KRepeat, 1>{}([&](auto k0) {
288 static_for<0, MRepeat, 1>{}([&](auto m0) {
289 static_for<0, NRepeat, 1>{}([&](auto n0) {
292
293 static_for<0, KPack, 1>{}([&](auto ik) {
294 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
295 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
296 make_tuple(m0, I0, k0, ik))>{}];
297 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
298 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
299 make_tuple(n0, I0, k0, ik))>{}];
300 });
301
302 using mfma_input_type =
304 xdlops_gemm.K1PerXdlops>::type;
305
306 constexpr index_t c_offset =
307 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
308
309 xdlops_gemm.Run(
310 a_thread_vec.template AsType<mfma_input_type>(),
311 b_thread_vec.template AsType<mfma_input_type>(),
312 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
313 });
314 });
315 });
316
318 a_blockwise_copy.RunWrite(
319 a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
320 b_blockwise_copy.RunWrite(
321 b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
322
323 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
324 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
325
326 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
327 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
328 });
329
330 i += PrefetchStages;
331 } while(i < (num_loop - PrefetchStages));
332 }
333
334 // tail
335
336 auto LoopTailFunc = [&](auto tail_num) {
337 static_for<1, tail_num, 1>{}([&](auto iprefetch) {
339 static_for<0, KRepeat, 1>{}([&](auto k) {
340 static_for<0, MRepeat, 1>{}([&](auto m0) {
343 a_block_buf,
345 make_tuple(m0, I0, k, I0),
346 a_thread_buf);
347 static_for<0, NRepeat, 1>{}([&](auto n0) {
350 b_block_buf,
352 make_tuple(n0, I0, k, I0),
353 b_thread_buf);
354 });
355 });
356 });
357
358 static_for<0, KRepeat, 1>{}([&](auto k0) {
359 static_for<0, MRepeat, 1>{}([&](auto m0) {
360 static_for<0, NRepeat, 1>{}([&](auto n0) {
363
364 static_for<0, KPack, 1>{}([&](auto ik) {
365 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
366 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
367 make_tuple(m0, I0, k0, ik))>{}];
368 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
369 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
370 make_tuple(n0, I0, k0, ik))>{}];
371 });
372
373 using mfma_input_type =
375 xdlops_gemm.K1PerXdlops>::type;
376
377 constexpr index_t c_offset =
378 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
379
380 xdlops_gemm.Run(
381 a_thread_vec.template AsType<mfma_input_type>(),
382 b_thread_vec.template AsType<mfma_input_type>(),
383 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
384 });
385 });
386 });
387
389 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch);
390 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch);
391 });
392
394 static_for<0, KRepeat, 1>{}([&](auto k) {
395 static_for<0, MRepeat, 1>{}([&](auto m0) {
398 a_block_buf,
400 make_tuple(m0, I0, k, I0),
401 a_thread_buf);
402 static_for<0, NRepeat, 1>{}([&](auto n0) {
405 b_block_buf,
407 make_tuple(n0, I0, k, I0),
408 b_thread_buf);
409 });
410 });
411 });
412
413 static_for<0, KRepeat, 1>{}([&](auto k0) {
414 static_for<0, MRepeat, 1>{}([&](auto m0) {
415 static_for<0, NRepeat, 1>{}([&](auto n0) {
418
419 static_for<0, KPack, 1>{}([&](auto ik) {
420 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
421 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
422 make_tuple(m0, I0, k0, ik))>{}];
423 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
424 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
425 make_tuple(n0, I0, k0, ik))>{}];
426 });
427
428 using mfma_input_type =
429 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
430
431 constexpr index_t c_offset =
432 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
433
434 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
435 b_thread_vec.template AsType<mfma_input_type>(),
436 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
437 });
438 });
439 });
440 };
441
442 if constexpr(TailNum == TailNumber::One)
443 {
445 static_for<0, KRepeat, 1>{}([&](auto k) {
446 static_for<0, MRepeat, 1>{}([&](auto m0) {
449 a_block_buf,
451 make_tuple(m0, I0, k, I0),
452 a_thread_buf);
453 static_for<0, NRepeat, 1>{}([&](auto n0) {
456 b_block_buf,
458 make_tuple(n0, I0, k, I0),
459 b_thread_buf);
460 });
461 });
462 });
463
464 static_for<0, KRepeat, 1>{}([&](auto k0) {
465 static_for<0, MRepeat, 1>{}([&](auto m0) {
466 static_for<0, NRepeat, 1>{}([&](auto n0) {
469
470 static_for<0, KPack, 1>{}([&](auto ik) {
471 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
472 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
473 make_tuple(m0, I0, k0, ik))>{}];
474 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
475 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
476 make_tuple(n0, I0, k0, ik))>{}];
477 });
478
479 using mfma_input_type =
480 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
481
482 constexpr index_t c_offset =
483 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
484
485 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
486 b_thread_vec.template AsType<mfma_input_type>(),
487 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
488 });
489 });
490 });
491 }
492 else if constexpr(TailNum == TailNumber::Two)
493 {
494 LoopTailFunc(Number<2>{});
495 }
496 else if constexpr(TailNum == TailNumber::Three)
497 {
498 LoopTailFunc(Number<3>{});
499 }
500 else if constexpr(TailNum == TailNumber::Four)
501 {
502 LoopTailFunc(Number<4>{});
503 }
504 else if constexpr(TailNum == TailNumber::Five)
505 {
506 LoopTailFunc(Number<5>{});
507 }
508 else if constexpr(TailNum == TailNumber::Six)
509 {
510 LoopTailFunc(Number<6>{});
511 }
512 else if constexpr(TailNum == TailNumber::Seven)
513 {
514 LoopTailFunc(Number<7>{});
515 }
516 else if constexpr(TailNum == TailNumber::Full)
517 {
518 LoopTailFunc(Number<PrefetchStages>{});
519 }
520 }
521
522 protected:
523 using Base::a_thread_copy_;
524 using Base::a_thread_desc_;
525 using Base::b_thread_copy_;
526 using Base::b_thread_desc_;
527 using Base::c_thread_desc_;
528};
529
530template <index_t BlockSize,
531 typename ADataType,
532 typename BDataType,
533 typename ComputeDataType,
534 typename AccDataType,
535 typename ATileDesc,
536 typename BTileDesc,
537 typename AMmaTileDesc,
538 typename BMmaTileDesc,
539 index_t ABlockTransferSrcScalarPerVector,
540 index_t BBlockTransferSrcScalarPerVector,
541 index_t MPerBlock,
542 index_t NPerBlock,
543 index_t KPerBlock,
544 index_t MPerXDL,
545 index_t NPerXDL,
546 index_t MRepeat,
547 index_t NRepeat,
548 index_t KPack
549 // ,bool TransposeC //disable transposec right now...
550 >
552 BlockSize,
553 ADataType,
554 BDataType,
555 ComputeDataType,
556 AccDataType,
557 ATileDesc,
558 BTileDesc,
559 AMmaTileDesc,
560 BMmaTileDesc,
561 ABlockTransferSrcScalarPerVector,
562 BBlockTransferSrcScalarPerVector,
563 MPerBlock,
564 NPerBlock,
565 KPerBlock,
566 MPerXDL,
567 NPerXDL,
568 MRepeat,
569 NRepeat,
570 KPack>
572 ADataType,
573 BDataType,
574 ComputeDataType,
575 AccDataType,
576 ATileDesc,
577 BTileDesc,
578 AMmaTileDesc,
579 BMmaTileDesc,
580 ABlockTransferSrcScalarPerVector,
581 BBlockTransferSrcScalarPerVector,
582 MPerBlock,
583 NPerBlock,
584 KPerBlock,
585 MPerXDL,
586 NPerXDL,
587 MRepeat,
588 NRepeat,
589 KPack>
590
591{
593 ADataType,
594 BDataType,
595 ComputeDataType,
596 AccDataType,
597 ATileDesc,
598 BTileDesc,
599 AMmaTileDesc,
600 BMmaTileDesc,
601 ABlockTransferSrcScalarPerVector,
602 BBlockTransferSrcScalarPerVector,
603 MPerBlock,
604 NPerBlock,
605 KPerBlock,
606 MPerXDL,
607 NPerXDL,
608 MRepeat,
609 NRepeat,
610 KPack>;
611 using Base::A_K1;
612 using Base::B_K1;
613 using Base::I0;
614 using Base::I1;
615 using Base::KPerThread;
616 using Base::xdlops_gemm;
617
629
632 using Base::WaveSize;
633
635
639
640 static constexpr index_t WgpPerCU =
641 (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
643 32768 / WgpPerCU,
644 (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
645 static constexpr index_t PrefetchStages =
648 : 2;
649
650 static constexpr index_t PrefillStages = 1;
652
653 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
654 {
655 return num_loop > PrefetchStages;
656 }
657
658 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
659 {
660 if(num_loop % PrefetchStages == 1)
661 {
662 return TailNumber::One;
663 }
664 else if(num_loop % PrefetchStages == 2)
665 {
666 return TailNumber::Two;
667 }
668 else if(num_loop % PrefetchStages == 3)
669 {
670 return TailNumber::Three;
671 }
672 else if(num_loop % PrefetchStages == 4)
673 {
674 return TailNumber::Four;
675 }
676 else if(num_loop % PrefetchStages == 5)
677 {
678 return TailNumber::Five;
679 }
680 else if(num_loop % PrefetchStages == 6)
681 {
682 return TailNumber::Six;
683 }
684 else if(num_loop % PrefetchStages == 7)
685 {
686 return TailNumber::Seven;
687 }
688 else
689 {
690 return TailNumber::Full;
691 }
692 }
693
694 template <bool HasMainLoop,
695 TailNumber TailNum,
696 typename AGridDesc,
697 typename ABlockDesc,
698 typename ABlockTransfer,
699 typename AGridBuffer,
700 typename ABlockBuffer,
701 typename ABlockTransferStep,
702 typename BGridDesc,
703 typename BBlockDesc,
704 typename BBlockTransfer,
705 typename BGridBuffer,
706 typename BBlockBuffer,
707 typename BBlockTransferStep,
708 typename CThreadBuffer,
709 typename BScaleGridBuffer,
710 typename BScaleGridDesc,
711 typename BScaleThreadDesc,
712 typename BScaleThreadTransfer,
713 typename BScaleThreadTransferStep>
714 __device__ void Run(const AGridDesc& a_grid_desc,
715 const ABlockDesc& a_block_desc,
716 ABlockTransfer& a_blockwise_copy,
717 const AGridBuffer& a_grid_buf,
718 ABlockBuffer& a_block_buf,
719 const ABlockTransferStep& a_block_copy_step,
720 const BGridDesc& b_grid_desc,
721 const BBlockDesc& b_block_desc,
722 BBlockTransfer& b_blockwise_copy,
723 const BGridBuffer& b_grid_buf,
724 BBlockBuffer& b_block_buf,
725 const BBlockTransferStep& b_block_copy_step,
726 CThreadBuffer& c_thread_buf,
727 const BScaleGridDesc& b_scale_grid_desc,
728 // BScaleThreadCopy
729 const BScaleThreadDesc& b_scale_thread_desc,
730 BScaleThreadTransfer& b_scale_thread_copy,
731 const BScaleGridBuffer& b_scale_grid_buf,
732 const BScaleThreadTransferStep& b_scale_thread_copy_step,
733 // num loop
734 index_t num_loop,
735 index_t num_loop_per_scale) const
736 {
737 ignore = num_loop_per_scale;
738
740 a_thread_desc_.GetElementSpaceSize());
742 b_thread_desc_.GetElementSpaceSize());
743
745 b_scale_thread_desc.GetElementSpaceSize());
746
747 // Global prefetch 1
748 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
749 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
750
751 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
752 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
753
754 static_for<0, NRepeat, 1>{}([&](auto n0) {
755 b_scale_thread_copy.Run(b_scale_grid_desc,
756 b_scale_grid_buf,
757 b_scale_thread_desc,
758 make_tuple(n0, I0),
759 b_scale_thread_buf);
760
761 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
762 b_scale_thread_copy_step.At(Number<0>{}));
763 });
764 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
765 b_scale_thread_copy_step.At(Number<1>{}));
766
767 // Initialize C
768 c_thread_buf.Clear();
769
770 // Local prefill 1
771 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
772 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
773
774 // Global prefetch [2, PrefetchStages]
775 static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) {
776 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
777 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
778
779 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
780 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
781 });
782
783 auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>(); // need?
784
785 // main body
786 if constexpr(HasMainLoop)
787 {
788 index_t i = 0;
789 do
790 {
791 static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) {
792 // -------------------------------------------------------------------------------------------
794 static_for<0, KRepeat, 1>{}([&](auto k0) {
795 static_for<0, MRepeat, 1>{}([&](auto m0) {
798 a_block_buf,
800 make_tuple(m0, I0, k0, I0),
801 a_thread_buf);
802 static_for<0, NRepeat, 1>{}([&](auto n0) {
803 b_thread_copy_.Run(
806 b_block_buf,
808 make_tuple(n0, I0, k0, I0),
809 b_thread_buf);
810 });
811 });
812 __builtin_amdgcn_sched_barrier(0);
813 // NOTE: Synchronize threads in a workgroup at the start of each MAC
814 // cluster, but except the first, as we can shorten non-MAC cluster a bit
815 // and there's no observable negative impact. The desired effect is waves in
816 // a workgroup executing MAC in sync. This avoids some out-of-sync waves
817 // hijacking MAC resource from other workgroups and reducing the chance of
818 // latency hiding by waiting for the rest of the workgroup at the eventual
819 // sync point.
820 if constexpr(k0.value != 0 || KRepeat == 1)
821 {
822 __builtin_amdgcn_s_barrier();
823 __builtin_amdgcn_sched_barrier(0);
824 }
826 static_for<0, MRepeat, 1>{}([&](auto m0) {
827 static_for<0, NRepeat, 1>{}([&](auto n0) {
830
831 static_for<0, KPack, 1>{}([&](auto ik) {
832 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
833 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
834 make_tuple(m0, I0, k0, k_ + ik))>{}];
835 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
836 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
837 make_tuple(n0, I0, k0, k_ + ik))>{}];
838 });
839
840 using mfma_input_type =
842 xdlops_gemm.K1PerXdlops>::type;
843
844 constexpr index_t c_offset =
845 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
846
847 // The block_sync_lds() here performs double duty:
848 // A) safeguard against data hazard because barrier from
849 // blockwise_gemm is moved here B) reduce VMEM FIFO congestion
850 // by applying small delays to different wavefronts It is
851 // performed near the end of MAC cluster to minimize lgkmcnt
852 // penalty
853 if constexpr(k0.value == KRepeat - 1 &&
854 k_.value == KPerInnerLoop - KPack &&
855 m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
856 {
857 __builtin_amdgcn_sched_barrier(0);
859 __builtin_amdgcn_sched_barrier(0);
860 }
861 xdlops_gemm.Run(
862 a_thread_vec.template AsType<mfma_input_type>(),
863 b_thread_vec.template AsType<mfma_input_type>(),
864 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
865 if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
866 {
867 __builtin_amdgcn_sched_barrier(0);
868 __builtin_amdgcn_s_setprio(1);
869 __builtin_amdgcn_sched_barrier(0);
870 }
871 });
872
873 // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t)
874 // {
875 // constexpr index_t c_offset =
876 // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
877 // c_thread_buf(Number<c_offset>{}) +=
878 // c_thread_buf_per_scale[Number<t>{}] *
879 // type_convert<AccDataType>(b_scale_thread_buf[n0]);
880 // });
881 });
882 });
883 __builtin_amdgcn_sched_barrier(0);
884 __builtin_amdgcn_s_setprio(0);
885 __builtin_amdgcn_sched_barrier(0);
886 });
887
888 // static_for<0, NRepeat, 1>{}([&](auto n0) {
889 // b_scale_thread_copy.Run(b_scale_grid_desc,
890 // b_scale_grid_buf,
891 // b_scale_thread_desc,
892 // make_tuple(n0, I0),
893 // b_scale_thread_buf);
894
895 // b_scale_thread_copy.MoveSrcSliceWindow(
896 // b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
897 // });
898 // b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
899 // b_scale_thread_copy_step.At(Number<1>{}));
900
901 // block_sync_lds();
902 a_blockwise_copy.RunWrite(
903 a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
904 b_blockwise_copy.RunWrite(
905 b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
906
907 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
908 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
909
910 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
911 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
912 });
913 i += PrefetchStages;
914 } while(i < (num_loop - PrefetchStages));
915 }
916
917 // tail
918
919 auto LoopTailFunc = [&](auto tail_num) {
920 static_for<1, tail_num, 1>{}([&](auto iprefetch) {
922 static_for<0, KRepeat, 1>{}([&](auto k0) {
923 static_for<0, MRepeat, 1>{}([&](auto m0) {
926 a_block_buf,
928 make_tuple(m0, I0, k0, I0),
929 a_thread_buf);
930 static_for<0, NRepeat, 1>{}([&](auto n0) {
933 b_block_buf,
935 make_tuple(n0, I0, k0, I0),
936 b_thread_buf);
937 });
938 });
939
940 __builtin_amdgcn_sched_barrier(0);
941 if constexpr(k0.value != 0 || KRepeat == 1)
942 {
943 __builtin_amdgcn_s_barrier();
944 __builtin_amdgcn_sched_barrier(0);
945 }
947 static_for<0, MRepeat, 1>{}([&](auto m0) {
948 static_for<0, NRepeat, 1>{}([&](auto n0) {
951
952 static_for<0, KPack, 1>{}([&](auto ik) {
953 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
954 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
955 make_tuple(m0, I0, k0, k_ + ik))>{}];
956 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
957 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
958 make_tuple(n0, I0, k0, k_ + ik))>{}];
959 });
960
961 using mfma_input_type =
963 xdlops_gemm.K1PerXdlops>::type;
964
965 constexpr index_t c_offset =
966 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
967
968 if constexpr(k0.value == KRepeat - 1 &&
969 k_.value == KPerInnerLoop - KPack &&
970 m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
971 {
972 __builtin_amdgcn_sched_barrier(0);
974 __builtin_amdgcn_sched_barrier(0);
975 }
976 xdlops_gemm.Run(
977 a_thread_vec.template AsType<mfma_input_type>(),
978 b_thread_vec.template AsType<mfma_input_type>(),
979 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
980 if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
981 {
982 __builtin_amdgcn_sched_barrier(0);
983 __builtin_amdgcn_s_setprio(1);
984 __builtin_amdgcn_sched_barrier(0);
985 }
986 });
987
988 // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
989 // constexpr index_t c_offset =
990 // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
991 // c_thread_buf(Number<c_offset>{}) +=
992 // c_thread_buf_per_scale[Number<t>{}] *
993 // type_convert<AccDataType>(b_scale_thread_buf[n0]);
994 // });
995 });
996 });
997 __builtin_amdgcn_sched_barrier(0);
998 __builtin_amdgcn_s_setprio(0);
999 __builtin_amdgcn_sched_barrier(0);
1000 });
1001
1002 // static_for<0, NRepeat, 1>{}([&](auto n0) {
1003 // b_scale_thread_copy.Run(b_scale_grid_desc,
1004 // b_scale_grid_buf,
1005 // b_scale_thread_desc,
1006 // make_tuple(n0, I0),
1007 // b_scale_thread_buf);
1008
1009 // b_scale_thread_copy.MoveSrcSliceWindow(
1010 // b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{}));
1011 // });
1012 // b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
1013 // b_scale_thread_copy_step.At(Number<1>{}));
1014
1015 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch);
1016 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch);
1017 });
1019 static_for<0, KRepeat, 1>{}([&](auto k0) {
1020 static_for<0, MRepeat, 1>{}([&](auto m0) {
1023 a_block_buf,
1025 make_tuple(m0, I0, k0, I0),
1026 a_thread_buf);
1027 static_for<0, NRepeat, 1>{}([&](auto n0) {
1030 b_block_buf,
1032 make_tuple(n0, I0, k0, I0),
1033 b_thread_buf);
1034 });
1035 });
1036
1037 __builtin_amdgcn_sched_barrier(0);
1038 if constexpr(k0.value != 0 || KRepeat == 1)
1039 {
1040 __builtin_amdgcn_s_barrier();
1041 __builtin_amdgcn_sched_barrier(0);
1042 }
1044 static_for<0, MRepeat, 1>{}([&](auto m0) {
1045 static_for<0, NRepeat, 1>{}([&](auto n0) {
1048
1049 static_for<0, KPack, 1>{}([&](auto ik) {
1050 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1051 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1052 make_tuple(m0, I0, k0, k_ + ik))>{}];
1053 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1054 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
1055 make_tuple(n0, I0, k0, k_ + ik))>{}];
1056 });
1057
1058 using mfma_input_type =
1060 xdlops_gemm.K1PerXdlops>::type;
1061
1062 constexpr index_t c_offset =
1063 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
1064
1065 if constexpr(k0.value == KRepeat - 1 &&
1066 k_.value == KPerInnerLoop - KPack &&
1067 m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
1068 {
1069 __builtin_amdgcn_sched_barrier(0);
1071 __builtin_amdgcn_sched_barrier(0);
1072 }
1073 xdlops_gemm.Run(
1074 a_thread_vec.template AsType<mfma_input_type>(),
1075 b_thread_vec.template AsType<mfma_input_type>(),
1076 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1077 if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
1078 {
1079 __builtin_amdgcn_sched_barrier(0);
1080 __builtin_amdgcn_s_setprio(1);
1081 __builtin_amdgcn_sched_barrier(0);
1082 }
1083 });
1084
1085 // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
1086 // constexpr index_t c_offset =
1087 // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
1088 // c_thread_buf(Number<c_offset>{}) +=
1089 // c_thread_buf_per_scale[Number<t>{}] *
1090 // type_convert<AccDataType>(b_scale_thread_buf[n0]);
1091 // });
1092 });
1093 });
1094 __builtin_amdgcn_sched_barrier(0);
1095 __builtin_amdgcn_s_setprio(0);
1096 __builtin_amdgcn_sched_barrier(0);
1097 });
1098 };
1099
1100 if constexpr(TailNum == TailNumber::One)
1101 {
1103 static_for<0, KRepeat, 1>{}([&](auto k0) {
1104 static_for<0, MRepeat, 1>{}([&](auto m0) {
1107 a_block_buf,
1109 make_tuple(m0, I0, k0, I0),
1110 a_thread_buf);
1111 static_for<0, NRepeat, 1>{}([&](auto n0) {
1114 b_block_buf,
1116 make_tuple(n0, I0, k0, I0),
1117 b_thread_buf);
1118 });
1119 });
1120
1121 __builtin_amdgcn_sched_barrier(0);
1122 if constexpr(k0.value != 0 || KRepeat == 1)
1123 {
1124 __builtin_amdgcn_s_barrier();
1125 __builtin_amdgcn_sched_barrier(0);
1126 }
1128 static_for<0, MRepeat, 1>{}([&](auto m0) {
1129 static_for<0, NRepeat, 1>{}([&](auto n0) {
1132
1133 static_for<0, KPack, 1>{}([&](auto ik) {
1134 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1135 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1136 make_tuple(m0, I0, k0, k_ + ik))>{}];
1137 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1138 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
1139 make_tuple(n0, I0, k0, k_ + ik))>{}];
1140 });
1141
1142 using mfma_input_type =
1144 xdlops_gemm.K1PerXdlops>::type;
1145
1146 constexpr index_t c_offset =
1147 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
1148
1149 if constexpr(k0.value == KRepeat - 1 &&
1150 k_.value == KPerInnerLoop - KPack &&
1151 m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
1152 {
1153 __builtin_amdgcn_sched_barrier(0);
1155 __builtin_amdgcn_sched_barrier(0);
1156 }
1157 xdlops_gemm.Run(
1158 a_thread_vec.template AsType<mfma_input_type>(),
1159 b_thread_vec.template AsType<mfma_input_type>(),
1160 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1161 if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
1162 {
1163 __builtin_amdgcn_sched_barrier(0);
1164 __builtin_amdgcn_s_setprio(1);
1165 __builtin_amdgcn_sched_barrier(0);
1166 }
1167 });
1168
1169 // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
1170 // constexpr index_t c_offset =
1171 // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
1172 // c_thread_buf(Number<c_offset>{}) +=
1173 // c_thread_buf_per_scale[Number<t>{}] *
1174 // type_convert<AccDataType>(b_scale_thread_buf[n0]);
1175 // });
1176 });
1177 });
1178 __builtin_amdgcn_sched_barrier(0);
1179 __builtin_amdgcn_s_setprio(0);
1180 __builtin_amdgcn_sched_barrier(0);
1181 });
1182 }
1183 else if constexpr(TailNum == TailNumber::Two)
1184 {
1185 LoopTailFunc(Number<2>{});
1186 }
1187 else if constexpr(TailNum == TailNumber::Three)
1188 {
1189 LoopTailFunc(Number<3>{});
1190 }
1191 else if constexpr(TailNum == TailNumber::Four)
1192 {
1193 LoopTailFunc(Number<4>{});
1194 }
1195 else if constexpr(TailNum == TailNumber::Five)
1196 {
1197 LoopTailFunc(Number<5>{});
1198 }
1199 else if constexpr(TailNum == TailNumber::Six)
1200 {
1201 LoopTailFunc(Number<6>{});
1202 }
1203 else if constexpr(TailNum == TailNumber::Seven)
1204 {
1205 LoopTailFunc(Number<7>{});
1206 }
1207 else if constexpr(TailNum == TailNumber::Full)
1208 {
1209 LoopTailFunc(Number<PrefetchStages>{});
1210 }
1211 }
1212
1213 protected:
1214 // K->M loopover
1220 I1));
1221
1227 I1));
1228
1231 decltype(a_block_desc_m0_m1_m2_k),
1232 decltype(a_thread_desc_),
1235 3,
1236 A_K1,
1237 A_K1>;
1238
1241 decltype(b_block_desc_n0_n1_n2_k),
1242 decltype(b_thread_desc_),
1245 3,
1246 B_K1,
1247 B_K1>;
1248
1251 using Base::c_thread_desc_;
1252};
1253
1254} // namespace ck
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
Definition ck.hpp:209
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition blockwise_gemm_pipeline_xdlops_base.hpp:222
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:280
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:239
static constexpr auto xdlops_gemm
Definition blockwise_gemm_pipeline_xdlops_base.hpp:54
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition blockwise_gemm_pipeline_xdlops_base.hpp:57
static __device__ auto CalculateBThreadOriginDataIndex()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:147
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:360
static constexpr auto I1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:37
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:266
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:294
static constexpr index_t AMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:60
static __device__ auto CalculateAThreadOriginDataIndex()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:125
static constexpr index_t WaveSize
Definition blockwise_gemm_pipeline_xdlops_base.hpp:46
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:253
static constexpr index_t B_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:51
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:111
static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops_base.hpp:36
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:160
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:189
static constexpr index_t KRepeat
Definition blockwise_gemm_pipeline_xdlops_base.hpp:64
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:359
static constexpr index_t A_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:50
static constexpr index_t BMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:61
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:341
static constexpr index_t KPerThread
Definition blockwise_gemm_pipeline_xdlops_base.hpp:63
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:307
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:324
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp:102
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp:215
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp:592
ThreadwiseTensorSliceTransfer_v4< BDataType, ComputeDataTypeBuf, decltype(b_block_desc_n0_n1_n2_k), decltype(b_thread_desc_), Sequence< 1, 1, 1, KPerInnerLoop >, Sequence< 0, 1, 2, 3 >, 3, B_K1, B_K1 > BThreadCopy
Definition blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp:1239
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataTypeBuf, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPerInnerLoop >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp:1229
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, const BScaleGridDesc &b_scale_grid_desc, const BScaleThreadDesc &b_scale_thread_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, const BScaleThreadTransferStep &b_scale_thread_copy_step, index_t num_loop, index_t num_loop_per_scale) const
Definition blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp:714
Definition blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp:37
Definition utility/sequence.hpp:43
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition functional2.hpp:33
Definition dtype_vector.hpp:10