multi_index_transform.hpp Source File

multi_index_transform.hpp Source File

Composable Kernel: multi_index_transform.hpp Source File
multi_index_transform.hpp
Go to the documentation of this file.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/utility/multi_index.hpp"

9namespace ck {
10
11template <typename LowLength>
13{
16
17 using UpLengths = decltype(make_tuple(LowLength{}));
18
20
21 __host__ __device__ constexpr PassThrough() = default;
22
23 __host__ __device__ constexpr PassThrough(const LowLength& low_length)
24 : up_lengths_{make_tuple(low_length)}
25 {
26 }
27
28 __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
29
30 __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
31
32 __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
33
34 template <typename LowIdx, typename UpIdx>
35 __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low,
36 const UpIdx& idx_up)
37 {
38 static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
39 "wrong! inconsistent # of dimension");
40
41 idx_low(Number<0>{}) = idx_up[Number<0>{}];
42 }
43
44 template <typename LowIdxDiff,
45 typename UpIdxDiff,
46 typename LowIdx,
47 typename UpIdx,
48 index_t Hack>
49 __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
50 const UpIdxDiff& idx_diff_up,
51 LowIdx& idx_low,
52 const UpIdx&,
54 {
55 static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
56 UpIdx::Size() == 1,
57 "wrong! inconsistent # of dimension");
58
59 constexpr auto I0 = Number<0>{};
60
61 idx_diff_low(I0) = idx_diff_up[I0];
62
63 idx_low += idx_diff_low;
64 }
65
66 __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
67
68 __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
69 {
70 return true;
71 }
72
73 template <typename UpIdx>
74 __host__ __device__ static constexpr bool
76 {
77 return true;
78 }
79
80 __host__ __device__ static constexpr bool IsKnownAtCompileTime()
81 {
83 }
84
85 __host__ __device__ void Print() const
86 {
87 printf("{");
88 printf("PassThrough, ");
89 printf("up_lengths_");
91 printf("}");
92 }
93};
94
95template <typename LowLength,
96 typename LeftPadLength,
97 typename RightPadLength,
98 bool SkipIsValidCheck = false>
99struct Pad
100{
103
104 using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{}));
105
107 LeftPadLength left_pad_length_;
108 RightPadLength right_pad_length_;
109
110 __host__ __device__ constexpr Pad() = default;
111
112 __host__ __device__ constexpr Pad(const LowLength& low_length,
113 const LeftPadLength& left_pad_length,
114 const RightPadLength& right_pad_length)
115 : up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)},
116 left_pad_length_{left_pad_length},
117 right_pad_length_{right_pad_length}
118 {
119 }
120
121 __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
122
123 __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
124
125 __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
126
127 template <typename LowIdx, typename UpIdx>
128 __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
129 const UpIdx& idx_up) const
130 {
131 static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
132 "wrong! inconsistent # of dimension");
133
134 idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
135 }
136
137 template <typename LowIdxDiff,
138 typename UpIdxDiff,
139 typename LowIdx,
140 typename UpIdx,
141 index_t Hack>
142 __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
143 const UpIdxDiff& idx_diff_up,
144 LowIdx& idx_low,
145 const UpIdx&,
147 {
148 static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
149 UpIdx::Size() == 1,
150 "wrong! inconsistent # of dimension");
151
152 constexpr auto I0 = Number<0>{};
153
154 idx_diff_low(I0) = idx_diff_up[I0];
155
156 idx_low += idx_diff_low;
157 }
158
159 __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
160
161 __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
162 {
163 return SkipIsValidCheck;
164 }
165
166 template <typename UpIdx>
167 __host__ __device__ constexpr bool
169 {
170 return SkipIsValidCheck ||
171 ((idx_up[Number<0>{}] >= left_pad_length_) &&
173 }
174
181
182 __host__ __device__ void Print() const
183 {
184 printf("{");
185 printf("Pad, ");
186 printf("up_lengths_");
188 printf("left_pad_length %d", index_t{left_pad_length_});
189 printf("right_pad_length %d", index_t{right_pad_length_});
190 printf("}");
191 }
192};
193
194template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
196{
199
200 using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{}));
201
203 LeftPadLength left_pad_length_;
204
205 __host__ __device__ constexpr LeftPad() = default;
206
207 __host__ __device__ constexpr LeftPad(const LowLength& low_length,
208 const LeftPadLength& left_pad_length)
209 : up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length}
210 {
211 }
212
213 __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
214
215 __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
216
217 __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
218
219 template <typename LowIdx, typename UpIdx>
220 __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
221 const UpIdx& idx_up) const
222 {
223 static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
224 "wrong! inconsistent # of dimension");
225
226 idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
227 }
228
229 template <typename LowIdxDiff,
230 typename UpIdxDiff,
231 typename LowIdx,
232 typename UpIdx,
233 index_t Hack>
234 __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
235 const UpIdxDiff& idx_diff_up,
236 LowIdx& idx_low,
237 const UpIdx&,
239 {
240 static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
241 UpIdx::Size() == 1,
242 "wrong! inconsistent # of dimension");
243
244 constexpr auto I0 = Number<0>{};
245
246 idx_diff_low(I0) = idx_diff_up[I0];
247
248 idx_low += idx_diff_low;
249 }
250
251 __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
252
253 __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
254 {
255 return SkipIsValidCheck;
256 }
257
258 template <typename UpIdx>
259 __host__ __device__ constexpr bool
261 {
262 return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_length_);
263 }
264
265 __host__ __device__ static constexpr bool IsKnownAtCompileTime()
266 {
269 }
270
271 __host__ __device__ void Print() const
272 {
273 printf("{");
274 printf("LeftPad, ");
275 printf("up_lengths_");
277 printf("left_pad_length_ %d", index_t{left_pad_length_});
278 printf("}");
279 }
280};
281
282template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
284{
287
288 using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{}));
289
291 LowLength low_length_;
292 RightPadLength right_pad_length_;
293
294 __host__ __device__ constexpr RightPad() = default;
295
296 __host__ __device__ constexpr RightPad(const LowLength& low_length,
297 const RightPadLength& right_pad_length)
298 : up_lengths_{make_tuple(low_length + right_pad_length)},
299 low_length_{low_length},
300 right_pad_length_{right_pad_length}
301 {
302 }
303
304 __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
305
306 __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
307
308 __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
309
310 template <typename LowIdx, typename UpIdx>
311 __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low,
312 const UpIdx& idx_up)
313 {
314 static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
315 "wrong! inconsistent # of dimension");
316
317 idx_low(Number<0>{}) = idx_up[Number<0>{}];
318 }
319
320 template <typename LowIdxDiff,
321 typename UpIdxDiff,
322 typename LowIdx,
323 typename UpIdx,
324 index_t Hack>
325 __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
326 const UpIdxDiff& idx_diff_up,
327 LowIdx& idx_low,
328 const UpIdx&,
330 {
331 static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
332 UpIdx::Size() == 1,
333 "wrong! inconsistent # of dimension");
334
335 constexpr auto I0 = Number<0>{};
336
337 idx_diff_low(I0) = idx_diff_up[I0];
338
339 idx_low += idx_diff_low;
340 }
341
342 __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
343
344 __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
345 {
346 return SkipIsValidCheck;
347 }
348
349 template <typename UpIdx>
350 __host__ __device__ constexpr bool
352 {
353 return SkipIsValidCheck || (idx_up[Number<0>{}] < low_length_);
354 }
355
362
363 __host__ __device__ void Print() const
364 {
365 printf("{");
366 printf("RightPad, ");
367 printf("up_lengths_");
369 printf("low_length_ %d", index_t{low_length_});
370 printf("left_pad_length_ %d", index_t{right_pad_length_});
371 printf("}");
372 }
373};
374
375// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1]
376// UpLengths and Coefficients can be either of the followings:
377// 1) Tuple of index_t, which is known at run-time, or
378// 2) Tuple of Number, which is known at compile-time, or
379// 3) Tuple of mixture of index_t and Number, which is known partially at run-time and partially
380// at compile-time
381template <typename UpLengths,
382 typename Coefficients,
383 typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
384struct Embed
385{
386 static constexpr index_t NDimUp = UpLengths::Size();
387
390
391 UpLengths up_lengths_;
392 Coefficients coefficients_;
393
394 __host__ __device__ constexpr Embed() = default;
395
396 __host__ __device__ constexpr Embed(const UpLengths& up_lengths,
397 const Coefficients& coefficients)
398 : up_lengths_{up_lengths}, coefficients_{coefficients}
399 {
400 }
401
402 __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
403
404 __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; }
405
406 __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
407
408 template <typename LowIdx, typename UpIdx>
409 __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
410 const UpIdx& idx_up) const
411 {
412 static_assert(LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
413 "wrong! inconsistent # of dimension");
414
415 idx_low(Number<0>{}) = 0;
416
417 static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) {
418 idx_low(Number<0>{}) += idx_up[i] * this->coefficients_[i];
419 });
420 }
421
422 template <typename LowIdxDiff,
423 typename UpIdxDiff,
424 typename LowIdx,
425 typename UpIdx,
426 index_t Hack>
427 __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
428 const UpIdxDiff& idx_diff_up,
429 LowIdx& idx_low,
430 const UpIdx&,
431 Number<Hack>) const
432 {
433 static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp &&
434 LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
435 "wrong! inconsistent # of dimension");
436
437 idx_diff_low(Number<0>{}) = 0;
438
440 [&](auto i) { idx_diff_low(Number<0>{}) += idx_diff_up[i] * coefficients_[i]; });
441
442 idx_low += idx_diff_low;
443 }
444
445 __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
446
447 __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
448 {
449 return true;
450 }
451
452 template <typename UpIdx>
453 __host__ __device__ static constexpr bool
455 {
456 return true;
457 }
458
459 __host__ __device__ static constexpr bool IsKnownAtCompileTime()
460 {
463 }
464
465 __host__ __device__ void Print() const
466 {
467 printf("{");
468 printf("Embed, ");
469 printf("up_lengths_ ");
471 printf("coefficients_ ");
473 printf("}");
474 }
475};
476
477// Implementation of "Merge" transformation primitive that uses regular to do lowering of
478// multi-index and use carry-and-borrow check to do lowering of multi-index delta
479template <typename LowLengths>
481{
482 static constexpr index_t NDimLow = LowLengths::Size();
483
486
488 decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
489
490 using UpLengths =
491 decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
492
493 LowLengths low_lengths_;
496
497 __host__ __device__ constexpr Merge_v1_carry_check() = default;
498
499 __host__ __device__ constexpr Merge_v1_carry_check(const LowLengths& low_lengths)
500 : low_lengths_{low_lengths},
502 container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
503 up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
504 {
505 static_assert(LowerIndex::Size() == NDimLow, "wrong!");
506 }
507
508 __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
509
510 __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
511
512 __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
513
514 template <typename LowIdx, typename UpIdx>
515 __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
516 const UpIdx& idx_up) const
517 {
518 static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
519 "wrong! inconsistent # of dimension");
520
521 index_t tmp = idx_up[Number<0>{}];
522
523 // normal division
524 static_for<0, NDimLow - 1, 1>{}([&](auto i) {
525 idx_low(i) = tmp / this->low_lengths_scan_[i];
526 tmp -= idx_low[i] * this->low_lengths_scan_[i];
527 });
528
529 idx_low(Number<NDimLow - 1>{}) = tmp;
530 }
531
532 template <typename LowIdxDiff,
533 typename UpIdxDiff,
534 typename LowIdx,
535 typename UpIdx,
536 index_t Hack>
537 __host__ __device__ void UpdateLowerIndex_1a(LowIdxDiff& idx_diff_low,
538 const UpIdxDiff& idx_diff_up,
539 LowIdx& idx_low,
540 const UpIdx& /* idx_up_new */,
541 Number<Hack>) const
542 {
543 static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
544 LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
545 "wrong! inconsistent # of dimension");
546
547 // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
548 // However,
549 // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const
550 // can be calculated at compile-time.
551 // 2) If idx_diff_up is not known at compile-time, but its value
552 // doesn't change during the whole kernel execution, then
553 // idx_diff_low_const also
554 // doesn't change during the whole kernel execution. Compiler generated
555 // ISA should
556 // only caclculate idx_diff_low_const once and save it durinng the whole
557 // kernel execution
558 // If neither 1) nor 2) is satisfied, then the calculation will also be
559 // computed at
560 // run-time each time this function is called, and can be very expensive.
561 LowerIndex idx_diff_low_const;
562 LowerIndex idx_low_length_minus_idx_diff_low_const;
563 LowerIndex idx_low_length_plus_idx_diff_low_const;
564
565#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
566 index_t tmp = idx_diff_up[Number<0>{}];
567
568 static_for<0, NDimLow - 1, 1>{}([&](auto i) {
569 idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
570 tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
571 });
572
573 idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;
574
575 static_for<0, NDimLow, 1>{}([&](auto i) {
576 idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i];
577
578 idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i];
579 });
580#else
581 // Hack: this force result into SGPR. Need to make sure the result is thread invariant
582 index_t tmp = idx_diff_up[Number<0>{}];
583
584 static_for<0, NDimLow - 1, 1>{}([&](auto i) {
585 idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
586 tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
587 });
588
589 idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);
590
591 static_for<0, NDimLow, 1>{}([&](auto i) {
592 idx_low_length_minus_idx_diff_low_const(i) =
593 __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]);
594
595 idx_low_length_plus_idx_diff_low_const(i) =
596 __builtin_amdgcn_readfirstlane(low_lengths_[i] + idx_diff_low_const[i]);
597 });
598#endif
599
600 if constexpr(Hack == 1)
601 {
602 // do carry check on each low dimension in reversed order
603 // do not need to check the first dimension
604 index_t carry = 0;
605
606 static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
607 index_t idx_low_tmp = idx_low[i] + carry;
608
609 bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
610
611 idx_diff_low(i) =
612 do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
613
614 idx_diff_low(i) += carry;
615
616 carry = do_carry ? 1 : 0;
617 });
618
619 idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
620
621 idx_low += idx_diff_low;
622 }
623 else if constexpr(Hack == 2)
624 {
625 // do carry check on each low dimension in reversed order
626 // do not need to check the first dimension
627 index_t borrow = 0;
628
629 static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
630 index_t idx_low_tmp = idx_low[i] - borrow;
631
632 bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];
633
634 idx_diff_low(i) =
635 do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i];
636
637 idx_diff_low(i) -= borrow;
638
639 borrow = do_borrow ? 1 : 0;
640 });
641
642 idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow;
643
644 idx_low += idx_diff_low;
645 }
646 else
647 {
648 // do carry check on each low dimension in reversed order
649 // do not need to check the first dimension
650 index_t carry = 0;
651
652 static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
653 index_t idx_low_tmp = idx_low[i] + carry;
654
655 bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
656 bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];
657
658 idx_diff_low(i) =
659 do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
660 idx_diff_low(i) =
661 do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i];
662
663 idx_diff_low(i) += carry;
664
665 carry = do_carry ? 1 : 0;
666 carry = do_borrow ? -1 : carry;
667 });
668
669 idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
670
671 idx_low += idx_diff_low;
672 }
673 }
674
675 template <typename LowIdxDiff,
676 typename UpIdxDiff,
677 typename LowIdx,
678 typename UpIdx,
679 index_t Hack>
680 __host__ __device__ void UpdateLowerIndex_1b(LowIdxDiff& idx_diff_low,
681 const UpIdxDiff& idx_diff_up,
682 LowIdx& idx_low,
683 const UpIdx& /* idx_up_new */,
684 Number<Hack>) const
685 {
686 static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
687 LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
688 "wrong! inconsistent # of dimension");
689
690 // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
691 // However,
692 // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const
693 // can be calculated at compile-time.
694 // 2) If idx_diff_up is not known at compile-time, but its value
695 // doesn't change during the whole kernel execution, then
696 // idx_diff_low_const also
697 // doesn't change during the whole kernel execution. Compiler generated
698 // ISA should
699 // only caclculate idx_diff_low_const once and save it durinng the whole
700 // kernel execution
701 // If neither 1) nor 2) is satisfied, then the calculation will also be
702 // computed at
703 // run-time each time this function is called, and can be very expensive.
704 LowerIndex idx_diff_low_const;
705 LowerIndex idx_low_length_minus_idx_diff_low_const;
706 LowerIndex idx_low_length_plus_idx_diff_low_const;
707
708#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
709 index_t tmp = idx_diff_up[Number<0>{}];
710
711 static_for<0, NDimLow - 1, 1>{}([&](auto i) {
712 idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
713 tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
714 });
715
716 idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;
717
718 static_for<0, NDimLow, 1>{}([&](auto i) {
719 idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i];
720
721 idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i];
722 });
723#else
724 // Hack: this force result into SGPR. Need to make sure the result is thread invariant
725 index_t tmp = idx_diff_up[Number<0>{}];
726
727 static_for<0, NDimLow - 1, 1>{}([&](auto i) {
728 idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
729 tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
730 });
731
732 idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);
733
734 static_for<0, NDimLow, 1>{}([&](auto i) {
735 idx_low_length_minus_idx_diff_low_const(i) =
736 __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]);
737
738 idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i];
739 });
740#endif
741
742 if constexpr(Hack == 1)
743 {
744 // do carry check on each low dimension in reversed order
745 // do not need to check the first dimension
746 index_t carry = 0;
747
748 static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
749 index_t idx_low_tmp = idx_low[i] + carry;
750
751 bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
752
753 idx_diff_low(i) =
754 do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
755
756 idx_diff_low(i) += carry;
757
758 carry = do_carry ? 1 : 0;
759 });
760
761 idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
762
763 idx_low += idx_diff_low;
764 }
765 else if constexpr(Hack == 2)
766 {
767 // do carry check on each low dimension in reversed order
768 // do not need to check the first dimension
769 index_t borrow = 0;
770
771 static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
772 index_t negative_idx_low_tmp = borrow - idx_low[i];
773
774 bool do_borrow = negative_idx_low_tmp > idx_diff_low_const[i];
775
776 idx_diff_low(i) =
777 do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i];
778
779 idx_diff_low(i) -= borrow;
780
781 borrow = do_borrow ? 1 : 0;
782 });
783
784 idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow;
785
786 idx_low += idx_diff_low;
787 }
788 else
789 {
790 // do carry check on each low dimension in reversed order
791 // do not need to check the first dimension
792 index_t carry = 0;
793
794 static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
795 index_t idx_low_tmp = idx_low[i] + carry;
796
797 bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
798 bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];
799
800 idx_diff_low(i) =
801 do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
802 idx_diff_low(i) =
803 do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i];
804
805 idx_diff_low(i) += carry;
806
807 carry = do_carry ? 1 : 0;
808 carry = do_borrow ? -1 : carry;
809 });
810
811 idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;
812
813 idx_low += idx_diff_low;
814 }
815 }
816
817 template <typename LowIdxDiff,
818 typename UpIdxDiff,
819 typename LowIdx,
820 typename UpIdx,
821 index_t Hack>
822 __host__ __device__ void UpdateLowerIndex_2(LowIdxDiff& idx_diff_low,
823 const UpIdxDiff& idx_diff_up,
824 LowIdx& idx_low,
825 const UpIdx& /* idx_up_new */,
826 Number<Hack>) const
827 {
828 static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
829 LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
830 "wrong! inconsistent # of dimension");
831
832 // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
833 // However,
834 // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const
835 // can be calculated at compile-time.
836 // 2) If idx_diff_up is not known at compile-time, but its value
837 // doesn't change during the whole kernel execution, then
838 // idx_diff_low_const also
839 // doesn't change during the whole kernel execution. Compiler generated
840 // ISA should
841 // only caclculate idx_diff_low_const once and save it durinng the whole
842 // kernel execution
843 // If neither 1) nor 2) is satisfied, then the calculation will also be
844 // computed at run-time each time this function is called, and can be
845 // very expensive.
846 LowerIndex idx_diff_low_const;
847
848#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
849 index_t tmp = idx_diff_up[Number<0>{}];
850
851 static_for<0, NDimLow - 1, 1>{}([&](auto i) {
852 idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
853 tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
854 });
855
856 idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;
857#else
858 // Hack: this force result into SGPR. Need to make sure the result is thread invariant
859 index_t tmp = idx_diff_up[Number<0>{}];
860
861 static_for<0, NDimLow - 1, 1>{}([&](auto i) {
862 idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
863 tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
864 });
865
866 idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);
867#endif
868
869 if constexpr(Hack == 1)
870 {
871 // do carry check on each low dimension in reversed order
872 // do not need to check the first dimension
873 bool do_carry = 0;
874
875 static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
876 idx_diff_low(i) = idx_diff_low_const[i] + do_carry;
877
878 index_t idx_low_tmp = idx_low[i] + idx_diff_low[i];
879
880 do_carry = idx_low_tmp >= low_lengths_[i];
881
882#if 0
883 // TODO: use exec-mask inline asm, which use 1 VALU
884 if(do_carry)
885 {
886 idx_diff_low(i) -= low_lengths_[i];
887 }
888#elif 1
889 // this use 2 VALU
890 idx_diff_low(i) = do_carry ? idx_diff_low[i] - low_lengths_[i] : idx_diff_low[i];
891#elif 1
892 // this use 2 VALU
893 index_t idx_diff_low_tmp = idx_diff_low[i] - low_lengths_[i];
894 idx_diff_low(i) = do_carry ? idx_diff_low_tmp : idx_diff_low[i];
895#endif
896
897 idx_low(i) += idx_diff_low[i];
898 });
899
900 constexpr auto I0 = Number<0>{};
901
902 idx_diff_low(I0) = idx_diff_low_const[I0] + do_carry;
903
904 idx_low(I0) += idx_diff_low[I0];
905 }
906 else if constexpr(Hack == 2)
907 {
908 // do borrow check on each low dimension in reversed order
909 // do not need to check the first dimension
910 bool do_borrow = 0;
911
912 static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
913 idx_diff_low(i) = idx_diff_low_const[i] - do_borrow;
914
915 index_t idx_low_tmp = idx_low[i] + idx_diff_low[i];
916
917 do_borrow = idx_low_tmp < 0;
918
919#if 0
920 // TODO: use exec-mask inline asm
921 if(do_borrow)
922 {
923 idx_diff_low(i) += low_lengths_[i];
924 }
925#elif 1
926 idx_diff_low(i) = do_borrow ? idx_diff_low[i] + low_lengths_[i] : idx_diff_low[i];
927#elif 1
928 index_t idx_diff_low_tmp = idx_diff_low[i] + low_lengths_[i];
929 idx_diff_low(i) = do_borrow ? idx_diff_low_tmp : idx_diff_low[i];
930#endif
931
932 idx_low(i) += idx_diff_low[i];
933 });
934
935 constexpr auto I0 = Number<0>{};
936
937 idx_diff_low(I0) = idx_diff_low_const[I0] - do_borrow;
938
939 idx_low(I0) += idx_diff_low[I0];
940 }
941 else
942 {
943 // not implemented
944 }
945 }
946
947 template <typename LowIdxDiff,
948 typename UpIdxDiff,
949 typename LowIdx,
950 typename UpIdx,
951 index_t Hack>
952 __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
953 const UpIdxDiff& idx_diff_up,
954 LowIdx& idx_low,
955 const UpIdx& idx_up_new,
956 Number<Hack>) const
957 {
958#if 1
959 UpdateLowerIndex_1a(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
960#elif 0
961 UpdateLowerIndex_1b(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
962#else
963 UpdateLowerIndex_2(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
964#endif
965 }
966
967 __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
968
969 __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
970 {
971 return true;
972 }
973
980
981 template <typename UpIdx>
982 __host__ __device__ static constexpr bool
984 {
985 return true;
986 }
987
988 __host__ __device__ void Print() const
989 {
990 printf("{");
991 printf("Merge_v1_carry_check, ");
992 printf("low_lengths_ ");
994 printf("low_lengths_scan_ ");
996 printf("up_lengths_ ");
998 printf("}");
999 }
1000};
1001
1002template <typename LowLengths>
1004{
1005 template <index_t I>
1006 __host__ __device__ constexpr auto operator()(Number<I> i) const
1007 {
1008 return MagicDivision::CalculateMagicMultiplier(LowLengths{}[i]);
1009 }
1010};
1011
1012template <typename LowLengths>
1014{
1015 template <index_t I>
1016 __host__ __device__ constexpr auto operator()(Number<I> i) const
1017 {
1018 return MagicDivision::CalculateMagicShift(LowLengths{}[i]);
1019 }
1020};
1021
1022// Implementation of "Merge" transformation primitive that uses magic-number-division to do lowering
1023// of both multi-index and delta of multi-index
1024// Caution:
1025// 1. The magic number division implementation being used would produce correct result if the
1026// dividended is uint32_t and its value is with in 31-bit value range of uint32_t.
1027// 2. The magic number division for int32_t dividened has not been implemented, the int32_t
1028// dividend would be bit-wise interpreted as uint32_t and magic number division implementation for
1029// uint32_t is then used.
1030// 3. For Merge primitive, upper-index is the dividend.
1031// 4. When upper-index is uint32_t, its value need to be within 31-bit range.
1032// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be
1033// non-negative.
1034template <typename LowLengths>
1036{
1037 static constexpr index_t NDimLow = LowLengths::Size();
1038
1041
1043 decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
1044
1047 Number<NDimLow>{}));
1048
1051 Number<NDimLow>{}));
1052
1053 LowLengths low_lengths_;
1057
1058 __host__ __device__ constexpr Merge_v2_magic_division() = default;
1059
1060 __host__ __device__ constexpr Merge_v2_magic_division(const LowLengths& low_lengths)
1061 : low_lengths_{low_lengths},
1063 [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths[i]); },
1064 Number<NDimLow>{})},
1065 low_lengths_magic_divisor_shift_{generate_tuple(
1066 [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths[i]); },
1067 Number<NDimLow>{})},
1068 up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
1069 {
1070 static_assert(LowerIndex::Size() == NDimLow, "wrong!");
1071 }
1072
1073 __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
1074
1075 __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
1076
1077 __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
1078
1079 template <typename LowIdx, typename UpIdx>
1080 __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
1081 const UpIdx& idx_up) const
1082 {
1083 static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
1084 "wrong! inconsistent # of dimension");
1085
1086 index_t tmp = idx_up[Number<0>{}];
1087
1088 static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
1089 index_t tmp2 =
1093 idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
1094 tmp = tmp2;
1095 });
1096
1097 idx_low(Number<0>{}) = tmp;
1098 }
1099
1100 template <typename LowIdxDiff,
1101 typename UpIdxDiff,
1102 typename LowIdx,
1103 typename UpIdx,
1104 index_t Hack>
1105 __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
1106 const UpIdxDiff&,
1107 LowIdx& idx_low,
1108 const UpIdx& idx_up_new,
1109 Number<Hack>) const
1110 {
1111 static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
1112 LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
1113 "wrong! inconsistent # of dimension");
1114
1115 index_t tmp = idx_up_new[Number<0>{}];
1116
1117 static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
1118 index_t tmp2 =
1122
1123 index_t idx_low_old = idx_low[i];
1124
1125 idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
1126 tmp = tmp2;
1127
1128 idx_diff_low(i) = idx_low[i] - idx_low_old;
1129 });
1130
1131 idx_diff_low(Number<0>{}) = tmp - idx_low(Number<0>{});
1132
1133 idx_low(Number<0>{}) = tmp;
1134 }
1135
1136 __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
1137
1138 __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
1139 {
1140 return true;
1141 }
1142
1150
1151 template <typename UpIdx>
1152 __host__ __device__ static constexpr bool
1154 {
1155 return true;
1156 }
1157
1158 __host__ __device__ void Print() const
1159 {
1160 printf("{");
1161 printf("Merge_v2_magic_division, ");
1162 printf("low_lengths_ ");
1164 printf("low_lengths_magic_divisor_multiplier_ ");
1166 printf("low_lengths_magic_divisor_shift_ ");
1168 printf("up_lengths_ ");
1170 printf("}");
1171 }
1172};
1173
1174// Implementation of "Merge" transformation primitive that uses magic-number-division to do lowering
1175// of both multi-index and delta of multi-index
1176// Caution:
1177// 1. The magic number division implementation being used produces a correct result if the
1178// dividend is uint32_t and its value is within the 31-bit value range of uint32_t.
1179// 2. Magic number division for an int32_t dividend has not been implemented; the int32_t
1180// dividend is bit-wise reinterpreted as uint32_t and the magic number division implementation for
1181// uint32_t is then used.
1182// 3. For the Merge primitive, the upper-index is the dividend.
1183// 4. When the upper-index is uint32_t, its value needs to be within the 31-bit range.
1184// 5. When the upper-index is int32_t (i.e. when index_t is int32_t), its value needs to be
1185// non-negative.
1186template <typename LowLengths>
1188{
1189 static constexpr index_t NDimLow = LowLengths::Size();
1190
1193
1195 decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
1196
1198 decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
1199
1202 Number<NDimLow>{}));
1203
1206 Number<NDimLow>{}));
1207
1208 LowLengths low_lengths_;
1213
// Defaulted default constructor; members are default-initialized.
1214 __host__ __device__ constexpr Merge_v2r2_magic_division() = default;
1215
1216 __host__ __device__ constexpr Merge_v2r2_magic_division(const LowLengths& low_lengths)
1217 : low_lengths_{low_lengths},
1219 container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
1220 low_lengths_scan_magic_divisor_multiplier_{generate_tuple(
1221 [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths_scan_[i]); },
1222 Number<NDimLow>{})},
1223 low_lengths_scan_magic_divisor_shift_{generate_tuple(
1224 [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths_scan_[i]); },
1225 Number<NDimLow>{})},
1226 up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
1227 {
1228 static_assert(LowerIndex::Size() == NDimLow, "wrong!");
1229 }
1230
1231 __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
1232
1233 __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
1234
1235 __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
1236
1237 template <typename LowIdx, typename UpIdx>
1238 __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
1239 const UpIdx& idx_up) const
1240 {
1241 static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
1242 "wrong! inconsistent # of dimension");
1243
1244 index_t tmp = idx_up[Number<0>{}];
1245
1246 static_for<0, NDimLow - 1, 1>{}([&, this](auto i) {
1247 idx_low(i) =
1251
1252 tmp -= idx_low[i] * this->low_lengths_scan_[i];
1253 });
1254
1255 idx_low(Number<NDimLow - 1>{}) = tmp;
1256 }
1257
1258 template <typename LowIdxDiff,
1259 typename UpIdxDiff,
1260 typename LowIdx,
1261 typename UpIdx,
1262 index_t Hack>
1263 __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
1264 const UpIdxDiff&,
1265 LowIdx& idx_low,
1266 const UpIdx& idx_up_new,
1267 Number<Hack>) const
1268 {
1269 static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
1270 LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
1271 "wrong! inconsistent # of dimension");
1272
1273 index_t tmp = idx_up_new[Number<0>{}];
1274
1275 static_for<0, NDimLow - 1, 1>{}([&, this](auto i) {
1276 index_t idx_low_old = idx_low[i];
1277
1278 idx_low(i) =
1282
1283 idx_diff_low(i) = idx_low[i] - idx_low_old;
1284
1285 tmp -= idx_low[i] * this->low_lengths_scan_[i];
1286 });
1287
1288 idx_diff_low(Number<NDimLow - 1>{}) = tmp - idx_low[Number<NDimLow - 1>{}];
1289
1290 idx_low(Number<NDimLow - 1>{}) = tmp;
1291 }
1292
1293 __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
1294
1295 __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
1296 {
1297 return true;
1298 }
1299
1307
1308 template <typename UpIdx>
1309 __host__ __device__ static constexpr bool
1311 {
1312 return true;
1313 }
1314
1315 __host__ __device__ void Print() const
1316 {
1317 printf("{");
1318 printf("Merge_v2r2_magic_division, ");
1319 printf("low_lengths_ ");
1321 printf("low_lengths_scan ");
1323 printf("low_lengths_scan_magic_divisor_multiplier_ ");
1325 printf("low_lengths_scan_magic_divisor_shift_ ");
1327 printf("up_lengths_ ");
1329 printf("}");
1330 }
1331};
1332
1333// Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to
1334// be used for low_lengths that are known at compile time and are power of 2, otherwise performance
1335// will be very bad
1336template <typename LowLengths>
1338{
1339 static constexpr index_t NDimLow = LowLengths::Size();
1340
1343
1345 decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
1346
1348 decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
1349
1350 LowLengths low_lengths_;
1353
// Defaulted default constructor; members are default-initialized.
1354 __host__ __device__ constexpr Merge_v3_division_mod() = default;
1355
1356 __host__ __device__ constexpr Merge_v3_division_mod(const LowLengths& low_lengths)
1357 : low_lengths_{low_lengths},
1359 container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
1360 up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
1361 {
1362 static_assert(LowerIndex::Size() == NDimLow, "wrong!");
1363 }
1364
1365 __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }
1366
1367 __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
1368
1369 __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
1370
1371 template <typename LowIdx, typename UpIdx>
1372 __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
1373 const UpIdx& idx_up) const
1374 {
1375 static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
1376 "wrong! inconsistent # of dimension");
1377
1378 index_t tmp = idx_up[Number<0>{}];
1379
1380 // division and mod
1381 static_for<0, NDimLow - 1, 1>{}([&](auto i) {
1382 idx_low(i) = tmp / this->low_lengths_scan_[i];
1383 tmp %= this->low_lengths_scan_[i];
1384 });
1385
1386 idx_low(Number<NDimLow - 1>{}) = tmp;
1387 }
1388
1389 template <typename LowIdxDiff,
1390 typename UpIdxDiff,
1391 typename LowIdx,
1392 typename UpIdx,
1393 index_t Hack>
1394 __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
1395 const UpIdxDiff&,
1396 LowIdx& idx_low,
1397 const UpIdx& idx_up_new,
1398 Number<Hack>) const
1399 {
1400 static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
1401 LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
1402 "wrong! inconsistent # of dimension");
1403
1404 constexpr auto I0 = Number<0>{};
1405 constexpr auto INm1 = Number<NDimLow - 1>{};
1406
1407 index_t tmp = idx_up_new[I0];
1408
1409 static_for<0, NDimLow - 1, 1>{}([&](auto i) {
1410 const index_t tmp2 = idx_low[i];
1411 idx_low(i) = tmp / this->low_lengths_scan_[i];
1412 idx_diff_low(i) = idx_low[i] - tmp2;
1413 tmp %= this->low_lengths_scan_[i];
1414 });
1415
1416 const index_t tmp2 = idx_low[INm1];
1417 idx_low(INm1) = tmp;
1418 idx_diff_low(INm1) = idx_low[INm1] - tmp2;
1419 }
1420
1421 __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
1422
1423 __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
1424 {
1425 return true;
1426 }
1427
1434
1435 template <typename UpIdx>
1436 __host__ __device__ static constexpr bool
1438 {
1439 return true;
1440 }
1441
1442 __host__ __device__ void Print() const
1443 {
1444 printf("{");
1445 printf("Merge_v3_direct_division_mod, ");
1446 printf("low_lengths_ ");
1448 printf("low_lengths_scan_ ");
1450 printf("up_lengths_ ");
1452 printf("}");
1453 }
1454};
1455
1456template <typename UpLengths, bool Use24BitIntegerCalculation>
1458{
1459 static constexpr index_t NDimUp = UpLengths::Size();
1460
1463
1465 decltype(container_reverse_exclusive_scan(UpLengths{}, math::multiplies{}, Number<1>{}));
1466
1467 UpLengths up_lengths_;
1469
// Defaulted default constructor; members are default-initialized.
1470 __host__ __device__ constexpr UnMerge() = default;
1471
1472 __host__ __device__ constexpr UnMerge(const UpLengths& up_lengths)
1473 : up_lengths_{up_lengths},
1475 container_reverse_exclusive_scan(up_lengths, math::multiplies{}, Number<1>{})}
1476 {
1477 }
1478
1479 __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
1480
1481 __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; }
1482
1483 __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
1484
1485 template <typename LowIdx, typename UpIdx>
1486 __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
1487 const UpIdx& idx_up) const
1488 {
1489 if constexpr(!Use24BitIntegerCalculation)
1490 {
1491 idx_low(Number<0>{}) = idx_up[Number<NDimUp - 1>{}];
1492
1493 static_for<0, NDimUp - 1, 1>{}(
1494 [&](auto i) { idx_low(Number<0>{}) += idx_up[i] * up_lengths_scan_[i]; });
1495 }
1496 else
1497 {
1498 idx_low(Number<0>{}) = idx_up[Number<NDimUp - 1>{}];
1499
1500 static_for<0, NDimUp - 1, 1>{}([&](auto i) {
1501 idx_low(Number<0>{}) =
1502 (0x00ffffff & idx_low[Number<0>{}]) +
1503 (0x00ffffff & idx_up[i]) * (0x00ffffff & up_lengths_scan_[i]);
1504 });
1505 }
1506 }
1507
1508 template <typename LowIdxDiff,
1509 typename UpIdxDiff,
1510 typename LowIdx,
1511 typename UpIdx,
1512 index_t Hack>
1513 __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
1514 const UpIdxDiff& idx_diff_up,
1515 LowIdx& idx_low,
1516 const UpIdx&,
1517 Number<Hack>) const
1518 {
1519 CalculateLowerIndex(idx_diff_low, idx_diff_up);
1520
1521 idx_low += idx_diff_low;
1522 }
1523
1524 __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
1525
1526 __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
1527 {
1528 return true;
1529 }
1530
1531 template <typename UpIdx>
1532 __host__ __device__ static constexpr bool
1534 {
1535 return true;
1536 }
1537
1538 __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1539 {
1542 }
1543
1544 __host__ __device__ void Print() const
1545 {
1546 printf("{");
1547 printf("UnMerge, ");
1548 printf("up_lengths_");
1550 printf("up_lengths_scan_");
1552 printf("}");
1553 }
1554};
1555
1565{
1566 static constexpr auto I0 = Number<0>{};
1567 static constexpr auto I1 = Number<1>{};
1568 static constexpr auto I2 = Number<2>{};
1569 static constexpr auto I3 = Number<3>{};
1570
1571 using LowerIndex = MultiIndex<4>; // N, Ho, Wo, K
1572 using UpperIndex = MultiIndex<3>; // K0, M, K1
1573
1583
1585 low_lengths_magic_divisor_multiplier_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_
1587 low_lengths_magic_divisor_shift_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_
1588
// Defaulted default constructor.
// NOTE(review): unlike the other transforms in this file it is not declared constexpr.
1589 __host__ __device__ ConvBwdDataImplicitGemmOutTransform() = default;
1590
1591 __host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform(index_t N,
1592 index_t Ho,
1593 index_t Wo,
1594 index_t K,
1595 index_t XDot,
1596 index_t HTilde,
1597 index_t WTilde,
1598 index_t WTildeSlice,
1599 index_t HWTildeSlice,
1600 index_t IHTildeSliceBegin,
1601 index_t IWTildeSliceBegin,
1602 index_t HRatio,
1603 index_t WRatio,
1604 index_t XDotSlice_K,
1605 index_t K0,
1606 index_t MPadded,
1607 index_t K1,
1608 index_t MPad,
1609 index_t KPad)
1610 : N_{N},
1611 Ho_{Ho},
1612 Wo_{Wo},
1613 K_{K},
1614 XDot_{XDot},
1615 HTilde_{HTilde},
1616 WTilde_{WTilde},
1617 WTildeSlice_{WTildeSlice},
1618 TildeSlice_{HWTildeSlice},
1619 IHTildeSliceBegin_{IHTildeSliceBegin},
1620 IWTildeSliceBegin_{IWTildeSliceBegin},
1621 HRatio_{HRatio},
1622 WRatio_{WRatio},
1623 XDotSlice_K_{XDotSlice_K},
1624 MPad_{MPad},
1625 KPad_{KPad},
1626 up_lengths_{make_tuple(K0, MPadded, K1)},
1628 MagicDivision::CalculateMagicMultiplier(XDotSlice_K_),
1629 MagicDivision::CalculateMagicMultiplier(K_),
1630 MagicDivision::CalculateMagicMultiplier(TildeSlice_),
1631 MagicDivision::CalculateMagicMultiplier(WTildeSlice_)},
1633 MagicDivision::CalculateMagicShift(K_),
1634 MagicDivision::CalculateMagicShift(TildeSlice_),
1635 MagicDivision::CalculateMagicShift(WTildeSlice_)}
1636 {
1637 }
1638
1639 __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 4; }
1640
1641 __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 3; }
1642
1643 __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
1644
1645 template <typename UpIdx>
1646 __host__ __device__ constexpr auto CalculateLowerIndexN(const UpIdx& idx_up) const
1647 {
1648 index_t NStep{0}, HStep{0}, WStep{0};
1649 // Merge
1650 // NStep = M_id / TildeSlice_
1651 NStep = MagicDivision::DoMagicDivision(idx_up[I1],
1654 HStep = idx_up[I1] - NStep * TildeSlice_;
1655 // HStep = HStep / WTildeSlice_
1656 HStep = MagicDivision::DoMagicDivision(HStep,
1659 WStep = idx_up[I1] - NStep * TildeSlice_ - HStep * WTildeSlice_;
1660 // Slice
1661 HStep += IHTildeSliceBegin_;
1662 WStep += IWTildeSliceBegin_;
1663
1664 return make_tuple(NStep, HStep, WStep, 0);
1665 }
1666
1667 template <typename UpIdx>
1668 __host__ __device__ constexpr auto CalculateLowerIndexK(const UpIdx& idx_up) const
1669 {
1670 // UnMerge
1671 // K_idx <- K0_idx * K1 + K1_idx
1672 index_t K_idx = idx_up[I0] * up_lengths_[I2] + idx_up[I2];
1673 // Merge
1674 // YStep = K_idx / XDotSlice_K_
1675 index_t YStep =
1679 index_t KStep = K_idx - YStep * XDotSlice_K_;
1680 // Xstep = KStep / K_
1681 index_t XStep =
1685 KStep -= XStep * K_;
1686 // Embed
1687 YStep *= HRatio_;
1688 XStep *= WRatio_;
1689
1690 return make_tuple(0, YStep, XStep, KStep);
1691 }
1692
1693 template <typename LowIdx, typename UpIdx>
1694 __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
1695 const UpIdx& idx_up) const
1696 {
1697 idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
1698 }
1699
1700 template <typename LowIdxDiff,
1701 typename UpIdxDiff,
1702 typename LowIdx,
1703 typename UpIdx,
1704 index_t Hack>
1705 __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
1706 const UpIdxDiff& /* idx_diff_up */,
1707 LowIdx& idx_low,
1708 const UpIdx& idx_up,
1709 Number<Hack>) const
1710 {
1711 LowIdx low_old = idx_low;
1712 idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
1713 idx_diff_low = idx_low - low_old;
1714 }
1715
1716 __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
1717
1718 __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
1719 {
1720 return true;
1721 }
1722
1723 template <typename UpIdx>
1724 __host__ __device__ constexpr bool
1726 {
1727 // Padding
1728 index_t K_idx = idx_up[Number<0>{}] * up_lengths_[Number<2>{}] + idx_up[Number<2>{}];
1729 index_t& M_idx = idx_up[Number<1>{}];
1730
1731 bool pad_valid = M_idx < up_lengths_[Number<1>{}] - MPad_ &&
1732 K_idx < up_lengths_[Number<0>{}] * up_lengths_[Number<2>{}] - KPad_;
1733 return pad_valid;
1734 }
1735
1736 __host__ __device__ static constexpr bool IsKnownAtCompileTime() { return false; }
1737
1738 __host__ __device__ void Print() const
1739 {
1740 printf("{");
1741 printf("ConvBwdDataImplicitGemmOutTransform, ");
1742 printf("up_lengths_");
1744 printf("}");
1745 }
1746};
1747
1748template <typename LowerIndex>
1750{
1751 LowerIndex low_idx_;
1752
1753 __host__ __device__ constexpr Freeze() = default;
1754
1755 __host__ __device__ constexpr Freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}
1756
1757 __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
1758
1759 __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 0; }
1760
1761 __host__ __device__ static constexpr auto GetUpperLengths() { return Tuple<>{}; }
1762
1763 template <typename LowIdx, typename UpIdx>
1764 __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
1765 const UpIdx& /* idx_up */) const
1766 {
1767 static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 0,
1768 "wrong! inconsistent # of dimension");
1769
1770 idx_low(Number<0>{}) = low_idx_;
1771 }
1772
1773 template <typename LowIdxDiff,
1774 typename UpIdxDiff,
1775 typename LowIdx,
1776 typename UpIdx,
1777 index_t Hack>
1778 __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
1779 const UpIdxDiff& /* idx_diff_up */,
1780 LowIdx& /* idx_low */,
1781 const UpIdx& /* idx_up_new */,
1783 {
1784 idx_diff_low(Number<0>{}) = 0;
1785 }
1786
1787 __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
1788
1789 __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
1790 {
1791 return true;
1792 }
1793
1794 template <typename UpIdx>
1795 __host__ __device__ static constexpr bool
1797 {
1798 return true;
1799 }
1800
1801 __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1802 {
1804 }
1805
1806 __host__ __device__ void Print() const
1807 {
1808 printf("Freeze");
1809 printf("low_idx_ %d", index_t{low_idx_});
1810 }
1811};
1812
1813// Insert a dangling upper dimension without lower dimension
1814template <typename UpperLength>
1816{
1817 using UpLengths = decltype(make_tuple(UpperLength{}));
1818
1820
1821 __host__ __device__ constexpr Insert() = default;
1822
1823 __host__ __device__ constexpr Insert(const UpperLength& up_length)
1824 : up_lengths_{make_tuple(up_length)}
1825 {
1826 }
1827
1828 __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 0; }
1829
1830 __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
1831
1832 __host__ __device__ constexpr auto GetUpperLengths() const { return up_lengths_; }
1833
1834 template <typename LowIdx, typename UpIdx>
1835 __host__ __device__ constexpr void CalculateLowerIndex(LowIdx&, const UpIdx&) const
1836 {
1837 static_assert(LowIdx::Size() == 0 && UpIdx::Size() == 1,
1838 "wrong! inconsistent # of dimension");
1839 }
1840
1841 template <typename LowIdxDiff,
1842 typename UpIdxDiff,
1843 typename LowIdx,
1844 typename UpIdx,
1845 index_t Hack>
1846 __host__ __device__ static void
1847 UpdateLowerIndex(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&, Number<Hack>)
1848 {
1849 static_assert(LowIdxDiff::Size() == 0 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 0 &&
1850 UpIdx::Size() == 1,
1851 "wrong! inconsistent # of dimension");
1852 }
1853
1854 __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
1855
1856 __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
1857 {
1858 return true;
1859 }
1860
1861 template <typename UpIdx>
1862 __host__ __device__ static constexpr bool
1864 {
1865 return true;
1866 }
1867
1868 __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1869 {
1871 }
1872
1873 __host__ __device__ void Print() const
1874 {
1875 printf("Insert");
1877 }
1878};
1879
1880template <typename VectorSize, typename UpLength>
1882{
1885
1886 using UpLengths = decltype(make_tuple(UpLength{}));
1887
1889 VectorSize vector_size_;
1890
1891 __host__ __device__ constexpr Vectorize() = default;
1892
1893 __host__ __device__ constexpr Vectorize(const VectorSize& vector_size,
1894 const UpLength& up_length)
1895 : vector_size_{vector_size}, up_lengths_{make_tuple(up_length)}
1896 {
1897 }
1898
1899 __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
1900
1901 __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
1902
1903 __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
1904
1905 template <typename LowIdx, typename UpIdx>
1906 __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
1907 const UpIdx& idx_up) const
1908 {
1909 static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
1910 "wrong! inconsistent # of dimension");
1911
1912 idx_low(Number<0>{}) = vector_size_ * idx_up[Number<0>{}];
1913 }
1914
1915 template <typename LowIdxDiff,
1916 typename UpIdxDiff,
1917 typename LowIdx,
1918 typename UpIdx,
1919 index_t Hack>
1920 __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
1921 const UpIdxDiff& idx_diff_up,
1922 LowIdx& idx_low,
1923 const UpIdx&,
1924 Number<Hack>) const
1925 {
1926 static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
1927 UpIdx::Size() == 1,
1928 "wrong! inconsistent # of dimension");
1929
1930 constexpr auto I0 = Number<0>{};
1931
1932 idx_diff_low(I0) = vector_size_ * idx_diff_up[I0];
1933
1934 idx_low += idx_diff_low;
1935 }
1936
1937 __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
1938
1939 __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
1940 {
1941 return true;
1942 }
1943
1944 template <typename UpIdx>
1945 __host__ __device__ static constexpr bool
1947 {
1948 return true;
1949 }
1950
1951 __host__ __device__ static constexpr bool IsKnownAtCompileTime()
1952 {
1954 }
1955
1956 __host__ __device__ void Print() const
1957 {
1958 printf("{");
1959 printf("Vectorize, ");
1960 printf("up_lengths_");
1962 printf("}");
1963 }
1964};
1965
1966template <typename LowLength, typename SliceBegin, typename SliceEnd>
1967struct Slice
1968{
1971
1972 using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{}));
1973
1975 SliceBegin slice_begin_;
1976 SliceEnd slice_end_;
1977
1978 __host__ __device__ constexpr Slice() = default;
1979
1980 __host__ __device__ constexpr Slice(const LowLength&,
1981 const SliceBegin& slice_begin,
1982 const SliceEnd& slice_end)
1983 : up_lengths_{make_tuple(slice_end - slice_begin)},
1984 slice_begin_{slice_begin},
1985 slice_end_{slice_end}
1986 {
1987 }
1988
1989 __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
1990
1991 __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
1992
1993 __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
1994
1995 template <typename LowIdx, typename UpIdx>
1996 __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
1997 const UpIdx& idx_up) const
1998 {
1999 static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
2000 "wrong! inconsistent # of dimension");
2001
2002 idx_low(Number<0>{}) = idx_up[Number<0>{}] + slice_begin_;
2003 }
2004
2005 template <typename LowIdxDiff,
2006 typename UpIdxDiff,
2007 typename LowIdx,
2008 typename UpIdx,
2009 index_t Hack>
2010 __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
2011 const UpIdxDiff& idx_diff_up,
2012 LowIdx& idx_low,
2013 const UpIdx&,
2015 {
2016 static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
2017 UpIdx::Size() == 1,
2018 "wrong! inconsistent # of dimension");
2019
2020 constexpr auto I0 = Number<0>{};
2021
2022 idx_diff_low(I0) = idx_diff_up[I0];
2023
2024 idx_low += idx_diff_low;
2025 }
2026
2027 __host__ __device__ static constexpr bool IsLinearTransform() { return true; }
2028
2029 __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
2030 {
2031 return true;
2032 }
2033
2034 template <typename UpIdx>
2035 __host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx&) const
2036 {
2037 return true;
2038 }
2039
2046
2047 __host__ __device__ void Print() const
2048 {
2049 printf("{");
2050 printf("Slice, ");
2051 printf("up_lengths_");
2053 printf("slice_begin_ %d", index_t{slice_begin_});
2054 printf("slice_end %d", index_t{slice_end_});
2055 printf("}");
2056 }
2057};
2058
2059/*
2060 * \brief lower_idx = upper_idx % modulus.
2061 * TODO: Need an improved implementation since the modulo operation is expensive.
2062 */
2063template <typename Modulus, typename UpLength>
2065{
2068 using UpLengths = decltype(make_tuple(UpLength{}));
2069
2070 Modulus modulus_;
2072
2073 __host__ __device__ constexpr Modulo() = default;
2074
2075 __host__ __device__ constexpr Modulo(const Modulus& modulus, const UpLength& up_length)
2076 : modulus_{modulus}, up_lengths_{make_tuple(up_length)}
2077 {
2078 }
2079
2080 __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
2081
2082 __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }
2083
2084 __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
2085
2086 template <typename LowIdx, typename UpIdx>
2087 __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
2088 const UpIdx& idx_up) const
2089 {
2090 static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
2091 "wrong! inconsistent # of dimension");
2092
2093 idx_low(Number<0>{}) = idx_up[Number<0>{}] % modulus_;
2094 }
2095
2096 template <typename LowIdxDiff,
2097 typename UpIdxDiff,
2098 typename LowIdx,
2099 typename UpIdx,
2100 index_t Hack>
2101 __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
2102 const UpIdxDiff& idx_diff_up,
2103 LowIdx& idx_low,
2104 const UpIdx& up_idx,
2105 Number<Hack>) const
2106 {
2107 static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
2108 UpIdx::Size() == 1,
2109 "wrong! inconsistent # of dimension");
2110
2111 constexpr auto I0 = Number<0>{};
2112
2113 const auto idx_low_old = idx_low;
2114 idx_low(I0) = (up_idx(I0) + idx_diff_up(I0)) % modulus_;
2115 idx_diff_low(I0) = idx_low - idx_low_old;
2116 }
2117
2118 __host__ __device__ static constexpr bool IsLinearTransform() { return false; }
2119
2120 __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
2121 {
2122 return true;
2123 }
2124
2125 template <typename UpIdx>
2126 __host__ __device__ static constexpr bool
2128 {
2129 return true;
2130 }
2131
2132 __host__ __device__ static constexpr bool IsKnownAtCompileTime()
2133 {
2135 }
2136
2137 __host__ __device__ void Print() const
2138 {
2139 printf("{");
2140 printf("Modulus, ");
2141 printf("up_lengths_");
2143 printf("}");
2144 }
2145};
2146
2147template <typename LowLengths, bool ApplyModulo>
2148struct Xor
2149{
2152
2153 using UpLengths = LowLengths;
2154
2156
2157 __host__ __device__ constexpr Xor() : up_lengths_{} {}
2158
2159 __host__ __device__ constexpr Xor(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}
2160
2161 __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 2; }
2162
2163 __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 2; }
2164
2165 __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
2166
2167 template <typename LowIdx, typename UpIdx>
2168 __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
2169 const UpIdx& idx_up) const
2170 {
2171 static_assert(LowIdx::Size() == 2 && UpIdx::Size() == 2,
2172 "wrong! inconsistent # of dimension");
2173
2174 idx_low(Number<0>{}) = idx_up[Number<0>{}];
2175
2176 if constexpr(ApplyModulo)
2177 {
2178 idx_low(Number<1>{}) =
2179 idx_up[Number<1>{}] ^ (idx_up[Number<0>{}] % up_lengths_[Number<1>{}]);
2180 }
2181 else
2182 {
2183 idx_low(Number<1>{}) = idx_up[Number<1>{}] ^ idx_up[Number<0>{}];
2184 }
2185 }
2186
2187 template <typename LowIdxDiff,
2188 typename UpIdxDiff,
2189 typename LowIdx,
2190 typename UpIdx,
2191 index_t Hack>
2192 __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
2193 const UpIdxDiff&,
2194 LowIdx& idx_low,
2195 const UpIdx& idx_up,
2196 Number<Hack>) const
2197 {
2198 static_assert(LowIdxDiff::Size() == 2 && UpIdxDiff::Size() == 2 && LowIdx::Size() == 2 &&
2199 UpIdx::Size() == 2,
2200 "wrong! inconsistent # of dimension");
2201
2202 const auto idx_low_old = idx_low;
2203
2204 CalculateLowerIndex(idx_low, idx_up);
2205
2206 idx_diff_low = idx_low - idx_low_old;
2207 }
2208
2209 __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
2210 {
2211 return true;
2212 }
2213
2214 template <typename UpIdx>
2215 __host__ __device__ static constexpr bool
2217 {
2218 return true;
2219 }
2220
2221 __host__ __device__ static constexpr bool IsKnownAtCompileTime()
2222 {
2224 }
2225
2226 __host__ __device__ void Print() const
2227 {
2228 printf("Xor{");
2229
2230 //
2231 printf("up_lengths_: ");
2232 print(up_lengths_);
2233 printf(", ");
2234
2235 printf("}");
2236 }
2237};
2238} // namespace ck
Definition utility/math.hpp:13
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
__host__ __device__ constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition utility/container_helper.hpp:111
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto container_reverse_exclusive_scan(const Array< TData, NSize > &x, Reduce f, TData init)
Definition utility/container_helper.hpp:213
__host__ __device__ void print_multi_index(const Tuple< Xs... > &x)
Definition statically_indexed_array_multi_index.hpp:147
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
index_t KPad_
Definition multi_index_transform.hpp:1581
__host__ static __device__ constexpr index_t GetNumOfLowerDimension()
Definition multi_index_transform.hpp:1639
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition multi_index_transform.hpp:1694
__host__ __device__ void Print() const
Definition multi_index_transform.hpp:1738
index_t IWTildeSliceBegin_
Definition multi_index_transform.hpp:1578
index_t IHTildeSliceBegin_
Definition multi_index_transform.hpp:1578
index_t WRatio_
Definition multi_index_transform.hpp:1579
index_t WTildeSlice_
Definition multi_index_transform.hpp:1577
index_t WTilde_
Definition multi_index_transform.hpp:1576
index_t N_
Definition multi_index_transform.hpp:1574
__host__ __device__ constexpr auto CalculateLowerIndexN(const UpIdx &idx_up) const
Definition multi_index_transform.hpp:1646
index_t K_
Definition multi_index_transform.hpp:1574
__host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform(index_t N, index_t Ho, index_t Wo, index_t K, index_t XDot, index_t HTilde, index_t WTilde, index_t WTildeSlice, index_t HWTildeSlice, index_t IHTildeSliceBegin, index_t IWTildeSliceBegin, index_t HRatio, index_t WRatio, index_t XDotSlice_K, index_t K0, index_t MPadded, index_t K1, index_t MPad, index_t KPad)
Definition multi_index_transform.hpp:1591
Tuple< index_t, index_t, index_t > up_lengths_
Definition multi_index_transform.hpp:1582
index_t Ho_
Definition multi_index_transform.hpp:1574
__host__ static __device__ constexpr bool IsLinearTransform()
Definition multi_index_transform.hpp:1716
__host__ static __device__ constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition multi_index_transform.hpp:1718
Tuple< index_t, index_t, index_t, index_t > low_lengths_magic_divisor_multiplier_
Definition multi_index_transform.hpp:1585
Tuple< index_t, index_t, index_t, index_t > low_lengths_magic_divisor_shift_
Definition multi_index_transform.hpp:1587
index_t XDot_
Definition multi_index_transform.hpp:1575
index_t XDotSlice_K_
Definition multi_index_transform.hpp:1580
MultiIndex< 4 > LowerIndex
Definition multi_index_transform.hpp:1571
index_t HRatio_
Definition multi_index_transform.hpp:1579
__host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &idx_up) const
Definition multi_index_transform.hpp:1725
__host__ __device__ ConvBwdDataImplicitGemmOutTransform()=default
static constexpr auto I2
Definition multi_index_transform.hpp:1568
static constexpr auto I0
Definition multi_index_transform.hpp:1566
index_t MPad_
Definition multi_index_transform.hpp:1581
__host__ __device__ constexpr auto CalculateLowerIndexK(const UpIdx &idx_up) const
Definition multi_index_transform.hpp:1668
index_t TildeSlice_
Definition multi_index_transform.hpp:1577
index_t Wo_
Definition multi_index_transform.hpp:1574
MultiIndex< 3 > UpperIndex
Definition multi_index_transform.hpp:1572
__host__ static __device__ constexpr index_t GetNumOfUpperDimension()
Definition multi_index_transform.hpp:1641
__host__ __device__ constexpr const auto & GetUpperLengths() const
Definition multi_index_transform.hpp:1643
index_t HTilde_
Definition multi_index_transform.hpp:1576
__host__ static __device__ constexpr bool IsKnownAtCompileTime()
Definition multi_index_transform.hpp:1736
static constexpr auto I3
Definition multi_index_transform.hpp:1569
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up, Number< Hack >) const
Definition multi_index_transform.hpp:1705
static constexpr auto I1
Definition multi_index_transform.hpp:1567
__host__ static __device__ constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition multi_index_transform.hpp:447
__host__ static __device__ constexpr bool IsKnownAtCompileTime()
Definition multi_index_transform.hpp:459
__host__ __device__ constexpr Embed()=default
static constexpr index_t NDimUp
Definition multi_index_transform.hpp:386
__host__ __device__ constexpr Embed(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform.hpp:396
Coefficients coefficients_
Definition multi_index_transform.hpp:392
MultiIndex< NDimUp > UpperIndex
Definition multi_index_transform.hpp:389
__host__ static __device__ constexpr index_t GetNumOfLowerDimension()
Definition multi_index_transform.hpp:402
__host__ static __device__ constexpr index_t GetNumOfUpperDimension()
Definition multi_index_transform.hpp:404
__host__ static __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition multi_index_transform.hpp:454
UpLengths up_lengths_
Definition multi_index_transform.hpp:391
__host__ __device__ constexpr const auto & GetUpperLengths() const
Definition multi_index_transform.hpp:406
MultiIndex< 1 > LowerIndex
Definition multi_index_transform.hpp:388
__host__ __device__ void Print() const
Definition multi_index_transform.hpp:465
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >) const
Definition multi_index_transform.hpp:427
__host__ static __device__ constexpr bool IsLinearTransform()
Definition multi_index_transform.hpp:445
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition multi_index_transform.hpp:409
LowerIndex low_idx_
Definition multi_index_transform.hpp:1751
__host__ static __device__ constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition multi_index_transform.hpp:1789
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &) const
Definition multi_index_transform.hpp:1764
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &, const UpIdx &, Number< Hack >)
Definition multi_index_transform.hpp:1778
__host__ __device__ void Print() const
Definition multi_index_transform.hpp:1806
__host__ __device__ constexpr Freeze()=default
__host__ static __device__ constexpr bool IsKnownAtCompileTime()
Definition multi_index_transform.hpp:1801
__host__ __device__ constexpr Freeze(const LowerIndex &low_idx)
Definition multi_index_transform.hpp:1755
__host__ static __device__ constexpr index_t GetNumOfLowerDimension()
Definition multi_index_transform.hpp:1757
__host__ static __device__ constexpr index_t GetNumOfUpperDimension()
Definition multi_index_transform.hpp:1759
__host__ static __device__ constexpr bool IsLinearTransform()
Definition multi_index_transform.hpp:1787
__host__ static __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition multi_index_transform.hpp:1796
__host__ static __device__ constexpr auto GetUpperLengths()
Definition multi_index_transform.hpp:1761
__host__ __device__ constexpr Insert()=default
__host__ static __device__ constexpr index_t GetNumOfLowerDimension()
Definition multi_index_transform.hpp:1828
__host__ static __device__ constexpr bool IsLinearTransform()
Definition multi_index_transform.hpp:1854
__host__ static __device__ constexpr index_t GetNumOfUpperDimension()
Definition multi_index_transform.hpp:1830
__host__ static __device__ constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition multi_index_transform.hpp:1856
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx &, const UpIdx &) const
Definition multi_index_transform.hpp:1835
__host__ __device__ void Print() const
Definition multi_index_transform.hpp:1873
UpLengths up_lengths_
Definition multi_index_transform.hpp:1819
decltype(make_tuple(UpperLength{})) UpLengths
Definition multi_index_transform.hpp:1817
__host__ static __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition multi_index_transform.hpp:1863
__host__ __device__ constexpr Insert(const UpperLength &up_length)
Definition multi_index_transform.hpp:1823
__host__ static __device__ constexpr bool IsKnownAtCompileTime()
Definition multi_index_transform.hpp:1868
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &, const UpIdxDiff &, LowIdx &, const UpIdx &, Number< Hack >)
Definition multi_index_transform.hpp:1847
__host__ __device__ constexpr auto GetUpperLengths() const
Definition multi_index_transform.hpp:1832
__host__ __device__ constexpr const auto & GetUpperLengths() const
Definition multi_index_transform.hpp:217
__host__ static __device__ constexpr bool IsLinearTransform()
Definition multi_index_transform.hpp:251
MultiIndex< 1 > UpperIndex
Definition multi_index_transform.hpp:198
__host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &idx_up) const
Definition multi_index_transform.hpp:260
__host__ __device__ constexpr LeftPad(const LowLength &low_length, const LeftPadLength &left_pad_length)
Definition multi_index_transform.hpp:207
decltype(make_tuple(LowLength{}+LeftPadLength{})) UpLengths
Definition multi_index_transform.hpp:200
MultiIndex< 1 > LowerIndex
Definition multi_index_transform.hpp:197
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition multi_index_transform.hpp:220
LeftPadLength left_pad_length_
Definition multi_index_transform.hpp:203
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >)
Definition multi_index_transform.hpp:234
__host__ static __device__ constexpr index_t GetNumOfLowerDimension()
Definition multi_index_transform.hpp:213
__host__ static __device__ constexpr bool IsKnownAtCompileTime()
Definition multi_index_transform.hpp:265
__host__ __device__ constexpr LeftPad()=default
__host__ __device__ void Print() const
Definition multi_index_transform.hpp:271
__host__ static __device__ constexpr index_t GetNumOfUpperDimension()
Definition multi_index_transform.hpp:215
__host__ static __device__ constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition multi_index_transform.hpp:253
UpLengths up_lengths_
Definition multi_index_transform.hpp:202
Definition magic_division.hpp:27
__host__ static __device__ constexpr uint32_t CalculateMagicShift(uint32_t divisor)
Definition magic_division.hpp:64
static __device__ constexpr uint32_t DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
Definition magic_division.hpp:127
__host__ static __device__ constexpr uint32_t CalculateMagicMultiplier(uint32_t divisor)
Definition magic_division.hpp:57
__host__ __device__ void UpdateLowerIndex_2(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >) const
Definition multi_index_transform.hpp:822
LowLengths low_lengths_
Definition multi_index_transform.hpp:493
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number< 1 >{})) LowLengthsScan
Definition multi_index_transform.hpp:487
__host__ __device__ void UpdateLowerIndex_1b(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >) const
Definition multi_index_transform.hpp:680
__host__ static __device__ constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition multi_index_transform.hpp:969
UpLengths up_lengths_
Definition multi_index_transform.hpp:495
static constexpr index_t NDimLow
Definition multi_index_transform.hpp:482
MultiIndex< NDimLow > LowerIndex
Definition multi_index_transform.hpp:484
__host__ static __device__ constexpr bool IsKnownAtCompileTime()
Definition multi_index_transform.hpp:974
__host__ static __device__ constexpr bool IsLinearTransform()
Definition multi_index_transform.hpp:967
MultiIndex< 1 > UpperIndex
Definition multi_index_transform.hpp:485
__host__ __device__ constexpr Merge_v1_carry_check()=default
__host__ static __device__ constexpr index_t GetNumOfLowerDimension()
Definition multi_index_transform.hpp:508
__host__ __device__ constexpr Merge_v1_carry_check(const LowLengths &low_lengths)
Definition multi_index_transform.hpp:499
__host__ static __device__ constexpr index_t GetNumOfUpperDimension()
Definition multi_index_transform.hpp:510
__host__ __device__ constexpr const auto & GetUpperLengths() const
Definition multi_index_transform.hpp:512
__host__ static __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition multi_index_transform.hpp:983
__host__ __device__ void Print() const
Definition multi_index_transform.hpp:988
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &idx_up_new, Number< Hack >) const
Definition multi_index_transform.hpp:952
LowLengthsScan low_lengths_scan_
Definition multi_index_transform.hpp:494
__host__ __device__ void UpdateLowerIndex_1a(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >) const
Definition multi_index_transform.hpp:537
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition multi_index_transform.hpp:515
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number< 1 >{}))) UpLengths
Definition multi_index_transform.hpp:490
__host__ static __device__ constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition multi_index_transform.hpp:1138
__host__ static __device__ constexpr bool IsLinearTransform()
Definition multi_index_transform.hpp:1136
MultiIndex< 1 > UpperIndex
Definition multi_index_transform.hpp:1040
__host__ __device__ constexpr const auto & GetUpperLengths() const
Definition multi_index_transform.hpp:1077
LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_
Definition multi_index_transform.hpp:1055
__host__ static __device__ constexpr bool IsKnownAtCompileTime()
Definition multi_index_transform.hpp:1143
UpLengths up_lengths_
Definition multi_index_transform.hpp:1056
decltype(generate_tuple( lambda_merge_generate_MagicDivision_calculate_magic_multiplier< LowLengths >{}, Number< NDimLow >{})) LowLengthsMagicDivisorMultipiler
Definition multi_index_transform.hpp:1045
static constexpr index_t NDimLow
Definition multi_index_transform.hpp:1037
__host__ static __device__ constexpr index_t GetNumOfUpperDimension()
Definition multi_index_transform.hpp:1075
__host__ static __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition multi_index_transform.hpp:1153
LowLengthsMagicDivisorMultipiler low_lengths_magic_divisor_multiplier_
Definition multi_index_transform.hpp:1054
__host__ __device__ constexpr Merge_v2_magic_division()=default
__host__ static __device__ constexpr index_t GetNumOfLowerDimension()
Definition multi_index_transform.hpp:1073
__host__ __device__ constexpr Merge_v2_magic_division(const LowLengths &low_lengths)
Definition multi_index_transform.hpp:1060
LowLengths low_lengths_
Definition multi_index_transform.hpp:1053
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number< 1 >{}))) UpLengths
Definition multi_index_transform.hpp:1042
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up_new, Number< Hack >) const
Definition multi_index_transform.hpp:1105
MultiIndex< NDimLow > LowerIndex
Definition multi_index_transform.hpp:1039
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition multi_index_transform.hpp:1080
decltype(generate_tuple( lambda_merge_generate_MagicDivision_calculate_magic_shift< LowLengths >{}, Number< NDimLow >{})) LowLengthsMagicDivisorShift
Definition multi_index_transform.hpp:1049
__host__ __device__ void Print() const
Definition multi_index_transform.hpp:1158
LowLengths low_lengths_
Definition multi_index_transform.hpp:1208
__host__ __device__ void Print() const
Definition multi_index_transform.hpp:1315
__host__ static __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition multi_index_transform.hpp:1310
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number< 1 >{})) LowLengthsScan
Definition multi_index_transform.hpp:1194
__host__ static __device__ constexpr bool IsLinearTransform()
Definition multi_index_transform.hpp:1293
MultiIndex< 1 > UpperIndex
Definition multi_index_transform.hpp:1192
decltype(generate_tuple( lambda_merge_generate_MagicDivision_calculate_magic_shift< LowLengthsScan >{}, Number< NDimLow >{})) LowLengthsScanMagicDivisorShift
Definition multi_index_transform.hpp:1204
__host__ static __device__ constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition multi_index_transform.hpp:1295
LowLengthsScanMagicDivisorShift low_lengths_scan_magic_divisor_shift_
Definition multi_index_transform.hpp:1211
UpLengths up_lengths_
Definition multi_index_transform.hpp:1212
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition multi_index_transform.hpp:1238
MultiIndex< NDimLow > LowerIndex
Definition multi_index_transform.hpp:1191
LowLengthsScanMagicDivisorMultipiler low_lengths_scan_magic_divisor_multiplier_
Definition multi_index_transform.hpp:1210
decltype(generate_tuple( lambda_merge_generate_MagicDivision_calculate_magic_multiplier< LowLengthsScan >{}, Number< NDimLow >{})) LowLengthsScanMagicDivisorMultipiler
Definition multi_index_transform.hpp:1200
__host__ static __device__ constexpr index_t GetNumOfUpperDimension()
Definition multi_index_transform.hpp:1233
__host__ __device__ constexpr Merge_v2r2_magic_division(const LowLengths &low_lengths)
Definition multi_index_transform.hpp:1216
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up_new, Number< Hack >) const
Definition multi_index_transform.hpp:1263
LowLengthsScan low_lengths_scan_
Definition multi_index_transform.hpp:1209
__host__ __device__ constexpr const auto & GetUpperLengths() const
Definition multi_index_transform.hpp:1235
__host__ static __device__ constexpr bool IsKnownAtCompileTime()
Definition multi_index_transform.hpp:1300
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number< 1 >{}))) UpLengths
Definition multi_index_transform.hpp:1197
static constexpr index_t NDimLow
Definition multi_index_transform.hpp:1189
__host__ static __device__ constexpr index_t GetNumOfLowerDimension()
Definition multi_index_transform.hpp:1231
__host__ __device__ constexpr Merge_v2r2_magic_division()=default
UpLengths up_lengths_
Definition multi_index_transform.hpp:1352
MultiIndex< 1 > UpperIndex
Definition multi_index_transform.hpp:1342
__host__ __device__ constexpr Merge_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform.hpp:1356
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up_new, Number< Hack >) const
Definition multi_index_transform.hpp:1394
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number< 1 >{}))) UpLengths
Definition multi_index_transform.hpp:1347
__host__ static __device__ constexpr bool IsLinearTransform()
Definition multi_index_transform.hpp:1421
__host__ __device__ void Print() const
Definition multi_index_transform.hpp:1442
__host__ static __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition multi_index_transform.hpp:1437
__host__ __device__ constexpr Merge_v3_division_mod()=default
__host__ static __device__ constexpr index_t GetNumOfLowerDimension()
Definition multi_index_transform.hpp:1365
LowLengthsScan low_lengths_scan_
Definition multi_index_transform.hpp:1351
__host__ static __device__ constexpr index_t GetNumOfUpperDimension()
Definition multi_index_transform.hpp:1367
LowLengths low_lengths_
Definition multi_index_transform.hpp:1350
MultiIndex< NDimLow > LowerIndex
Definition multi_index_transform.hpp:1341
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number< 1 >{})) LowLengthsScan
Definition multi_index_transform.hpp:1344
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition multi_index_transform.hpp:1372
__host__ __device__ constexpr const auto & GetUpperLengths() const
Definition multi_index_transform.hpp:1369
__host__ static __device__ constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition multi_index_transform.hpp:1423
__host__ static __device__ constexpr bool IsKnownAtCompileTime()
Definition multi_index_transform.hpp:1428
static constexpr index_t NDimLow
Definition multi_index_transform.hpp:1339
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition multi_index_transform.hpp:2087
__host__ __device__ constexpr Modulo(const Modulus &modulus, const UpLength &up_length)
Definition multi_index_transform.hpp:2075
__host__ static __device__ constexpr bool IsLinearTransform()
Definition multi_index_transform.hpp:2118
Modulus modulus_
Definition multi_index_transform.hpp:2070
__host__ __device__ void Print() const
Definition multi_index_transform.hpp:2137
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &up_idx, Number< Hack >) const
Definition multi_index_transform.hpp:2101
decltype(make_tuple(UpLength{})) UpLengths
Definition multi_index_transform.hpp:2068
MultiIndex< 1 > UpperIndex
Definition multi_index_transform.hpp:2067
__host__ static __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition multi_index_transform.hpp:2127
MultiIndex< 1 > LowerIndex
Definition multi_index_transform.hpp:2066
__host__ static __device__ constexpr bool IsKnownAtCompileTime()
Definition multi_index_transform.hpp:2132
UpLengths up_lengths_
Definition multi_index_transform.hpp:2071
__host__ __device__ constexpr Modulo()=default
__host__ static __device__ constexpr index_t GetNumOfLowerDimension()
Definition multi_index_transform.hpp:2080
__host__ __device__ constexpr const auto & GetUpperLengths() const
Definition multi_index_transform.hpp:2084
__host__ static __device__ constexpr index_t GetNumOfUpperDimension()
Definition multi_index_transform.hpp:2082
__host__ static __device__ constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition multi_index_transform.hpp:2120
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >)
Definition multi_index_transform.hpp:142
__host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &idx_up) const
Definition multi_index_transform.hpp:168
LeftPadLength left_pad_length_
Definition multi_index_transform.hpp:107
MultiIndex< 1 > UpperIndex
Definition multi_index_transform.hpp:102
__host__ static __device__ constexpr index_t GetNumOfLowerDimension()
Definition multi_index_transform.hpp:121
decltype(make_tuple(LowLength{}+LeftPadLength{}+RightPadLength{})) UpLengths
Definition multi_index_transform.hpp:104
__host__ static __device__ constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition multi_index_transform.hpp:161
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition multi_index_transform.hpp:128
__host__ static __device__ constexpr index_t GetNumOfUpperDimension()
Definition multi_index_transform.hpp:123
__host__ __device__ constexpr const auto & GetUpperLengths() const
Definition multi_index_transform.hpp:125
__host__ static __device__ constexpr bool IsLinearTransform()
Definition multi_index_transform.hpp:159
__host__ __device__ constexpr Pad()=default
UpLengths up_lengths_
Definition multi_index_transform.hpp:106
RightPadLength right_pad_length_
Definition multi_index_transform.hpp:108
MultiIndex< 1 > LowerIndex
Definition multi_index_transform.hpp:101
__host__ static __device__ constexpr bool IsKnownAtCompileTime()
Definition multi_index_transform.hpp:175
__host__ __device__ void Print() const
Definition multi_index_transform.hpp:182
__host__ __device__ constexpr Pad(const LowLength &low_length, const LeftPadLength &left_pad_length, const RightPadLength &right_pad_length)
Definition multi_index_transform.hpp:112
__host__ static __device__ constexpr bool IsLinearTransform()
Definition multi_index_transform.hpp:66
__host__ __device__ constexpr PassThrough()=default
UpLengths up_lengths_
Definition multi_index_transform.hpp:19
__host__ static __device__ constexpr index_t GetNumOfLowerDimension()
Definition multi_index_transform.hpp:28
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >)
Definition multi_index_transform.hpp:49
__host__ static __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition multi_index_transform.hpp:75
MultiIndex< 1 > UpperIndex
Definition multi_index_transform.hpp:15
__host__ __device__ void Print() const
Definition multi_index_transform.hpp:85
decltype(make_tuple(LowLength{})) UpLengths
Definition multi_index_transform.hpp:17
__host__ static __device__ constexpr void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up)
Definition multi_index_transform.hpp:35
__host__ __device__ constexpr const auto & GetUpperLengths() const
Definition multi_index_transform.hpp:32
__host__ __device__ constexpr PassThrough(const LowLength &low_length)
Definition multi_index_transform.hpp:23
__host__ static __device__ constexpr index_t GetNumOfUpperDimension()
Definition multi_index_transform.hpp:30
__host__ static __device__ constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition multi_index_transform.hpp:68
MultiIndex< 1 > LowerIndex
Definition multi_index_transform.hpp:14
__host__ static __device__ constexpr bool IsKnownAtCompileTime()
Definition multi_index_transform.hpp:80
__host__ static __device__ constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition multi_index_transform.hpp:344
__host__ static __device__ constexpr bool IsKnownAtCompileTime()
Definition multi_index_transform.hpp:356
__host__ static __device__ constexpr void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up)
Definition multi_index_transform.hpp:311
decltype(make_tuple(LowLength{}+RightPadLength{})) UpLengths
Definition multi_index_transform.hpp:288
__host__ __device__ constexpr RightPad(const LowLength &low_length, const RightPadLength &right_pad_length)
Definition multi_index_transform.hpp:296
MultiIndex< 1 > UpperIndex
Definition multi_index_transform.hpp:286
__host__ static __device__ constexpr index_t GetNumOfUpperDimension()
Definition multi_index_transform.hpp:306
__host__ __device__ constexpr RightPad()=default
UpLengths up_lengths_
Definition multi_index_transform.hpp:290
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >)
Definition multi_index_transform.hpp:325
__host__ __device__ constexpr const auto & GetUpperLengths() const
Definition multi_index_transform.hpp:308
__host__ __device__ void Print() const
Definition multi_index_transform.hpp:363
__host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &idx_up) const
Definition multi_index_transform.hpp:351
__host__ static __device__ constexpr bool IsLinearTransform()
Definition multi_index_transform.hpp:342
__host__ static __device__ constexpr index_t GetNumOfLowerDimension()
Definition multi_index_transform.hpp:304
LowLength low_length_
Definition multi_index_transform.hpp:291
MultiIndex< 1 > LowerIndex
Definition multi_index_transform.hpp:285
RightPadLength right_pad_length_
Definition multi_index_transform.hpp:292
__host__ static __device__ constexpr index_t GetNumOfUpperDimension()
Definition multi_index_transform.hpp:1991
__host__ static __device__ constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition multi_index_transform.hpp:2029
__host__ __device__ constexpr Slice(const LowLength &, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition multi_index_transform.hpp:1980
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition multi_index_transform.hpp:1996
__host__ static __device__ constexpr bool IsLinearTransform()
Definition multi_index_transform.hpp:2027
SliceEnd slice_end_
Definition multi_index_transform.hpp:1976
decltype(make_tuple(SliceEnd{} - SliceBegin{})) UpLengths
Definition multi_index_transform.hpp:1972
__host__ __device__ void Print() const
Definition multi_index_transform.hpp:2047
__host__ __device__ constexpr const auto & GetUpperLengths() const
Definition multi_index_transform.hpp:1993
__host__ static __device__ constexpr bool IsKnownAtCompileTime()
Definition multi_index_transform.hpp:2040
UpLengths up_lengths_
Definition multi_index_transform.hpp:1974
SliceBegin slice_begin_
Definition multi_index_transform.hpp:1975
__host__ __device__ constexpr Slice()=default
MultiIndex< 1 > LowerIndex
Definition multi_index_transform.hpp:1969
__host__ static __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >)
Definition multi_index_transform.hpp:2010
__host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &) const
Definition multi_index_transform.hpp:2035
MultiIndex< 1 > UpperIndex
Definition multi_index_transform.hpp:1970
__host__ static __device__ constexpr index_t GetNumOfLowerDimension()
Definition multi_index_transform.hpp:1989
Definition utility/tuple.hpp:186
Definition utility/tuple.hpp:117
MultiIndex< NDimUp > UpperIndex
Definition multi_index_transform.hpp:1462
MultiIndex< 1 > LowerIndex
Definition multi_index_transform.hpp:1461
UpLengths up_lengths_
Definition multi_index_transform.hpp:1467
decltype(container_reverse_exclusive_scan(UpLengths{}, math::multiplies{}, Number< 1 >{})) UpLengthsScan
Definition multi_index_transform.hpp:1464
__host__ static __device__ constexpr index_t GetNumOfLowerDimension()
Definition multi_index_transform.hpp:1479
__host__ static __device__ constexpr bool IsKnownAtCompileTime()
Definition multi_index_transform.hpp:1538
__host__ static __device__ constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition multi_index_transform.hpp:1526
__host__ __device__ constexpr UnMerge()=default
__host__ __device__ void Print() const
Definition multi_index_transform.hpp:1544
__host__ static __device__ constexpr index_t GetNumOfUpperDimension()
Definition multi_index_transform.hpp:1481
UpLengthsScan up_lengths_scan_
Definition multi_index_transform.hpp:1468
__host__ __device__ constexpr const auto & GetUpperLengths() const
Definition multi_index_transform.hpp:1483
static constexpr index_t NDimUp
Definition multi_index_transform.hpp:1459
__host__ static __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition multi_index_transform.hpp:1533
__host__ static __device__ constexpr bool IsLinearTransform()
Definition multi_index_transform.hpp:1524
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >) const
Definition multi_index_transform.hpp:1513
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition multi_index_transform.hpp:1486
__host__ __device__ constexpr UnMerge(const UpLengths &up_lengths)
Definition multi_index_transform.hpp:1472
MultiIndex< 1 > UpperIndex
Definition multi_index_transform.hpp:1884
__host__ static __device__ constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition multi_index_transform.hpp:1939
__host__ static __device__ constexpr index_t GetNumOfLowerDimension()
Definition multi_index_transform.hpp:1899
__host__ static __device__ constexpr bool IsLinearTransform()
Definition multi_index_transform.hpp:1937
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &, Number< Hack >) const
Definition multi_index_transform.hpp:1920
__host__ __device__ constexpr Vectorize()=default
__host__ __device__ constexpr Vectorize(const VectorSize &vector_size, const UpLength &up_length)
Definition multi_index_transform.hpp:1893
MultiIndex< 1 > LowerIndex
Definition multi_index_transform.hpp:1883
VectorSize vector_size_
Definition multi_index_transform.hpp:1889
__host__ __device__ constexpr const auto & GetUpperLengths() const
Definition multi_index_transform.hpp:1903
__host__ static __device__ constexpr index_t GetNumOfUpperDimension()
Definition multi_index_transform.hpp:1901
decltype(make_tuple(UpLength{})) UpLengths
Definition multi_index_transform.hpp:1886
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition multi_index_transform.hpp:1906
__host__ static __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition multi_index_transform.hpp:1946
__host__ static __device__ constexpr bool IsKnownAtCompileTime()
Definition multi_index_transform.hpp:1951
UpLengths up_lengths_
Definition multi_index_transform.hpp:1888
__host__ __device__ void Print() const
Definition multi_index_transform.hpp:1956
__host__ static __device__ constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
Definition multi_index_transform.hpp:2209
__host__ static __device__ constexpr index_t GetNumOfLowerDimension()
Definition multi_index_transform.hpp:2161
__host__ static __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx &)
Definition multi_index_transform.hpp:2216
MultiIndex< 2 > UpperIndex
Definition multi_index_transform.hpp:2151
MultiIndex< 2 > LowerIndex
Definition multi_index_transform.hpp:2150
__host__ static __device__ constexpr index_t GetNumOfUpperDimension()
Definition multi_index_transform.hpp:2163
UpLengths up_lengths_
Definition multi_index_transform.hpp:2155
__host__ __device__ void UpdateLowerIndex(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up, Number< Hack >) const
Definition multi_index_transform.hpp:2192
__host__ __device__ constexpr Xor()
Definition multi_index_transform.hpp:2157
__host__ __device__ constexpr const auto & GetUpperLengths() const
Definition multi_index_transform.hpp:2165
__host__ __device__ constexpr Xor(const LowLengths &low_lengths)
Definition multi_index_transform.hpp:2159
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx &idx_low, const UpIdx &idx_up) const
Definition multi_index_transform.hpp:2168
__host__ __device__ void Print() const
Definition multi_index_transform.hpp:2226
__host__ static __device__ constexpr bool IsKnownAtCompileTime()
Definition multi_index_transform.hpp:2221
LowLengths UpLengths
Definition multi_index_transform.hpp:2153
Definition is_known_at_compile_time.hpp:14
__host__ __device__ constexpr auto operator()(Number< I > i) const
Definition multi_index_transform.hpp:1006
Definition multi_index_transform.hpp:1014
__host__ __device__ constexpr auto operator()(Number< I > i) const
Definition multi_index_transform.hpp:1016
Definition utility/math.hpp:34
Definition functional2.hpp:33