block_reduce2d.hpp Source File

block_reduce2d.hpp Source File

Composable Kernel: block_reduce2d.hpp Source File
block_reduce2d.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// BlockReduce2d implements a hierarchical 2D reduction operator that reduces data along the second
12// dimension using a user-specified reduction function.
13//
14// The reduction is performed in a three-stage hierarchical approach:
15//
16// STAGE 1: Thread-level reduction (BlockReduce2d)
17// ===============================================
18// - Each thread processes multiple elements from the input tensor within its assigned data
19// partition
20// - Reduction is performed locally within each thread by iterating over assigned elements
21// - ReducePacksPerXDim controls how many elements sweep_tile processes in one iteration per
22// dimension
23// (e.g., {1,1} = 1 element at a time from each dimension, {2,4} = 2 from dim0, 4 from dim1)
24// - Results are accumulated into a thread-local output tensor stored in registers
25// - The output tensor distribution is derived from the input tensor's distribution using
26// make_reduce_tile_distribution_encoding() to handle dimension reduction
27//
28// STAGE 2: Warp-level reduction (BlockReduce2dSync)
29// ================================================
30// - Performs inter-thread reduction within each warp
31// - Uses warp shuffle operations to exchange data between threads in the same warp
32// - Implements a tree-reduction pattern with power-of-2 stages
33// - Only reduces along dimensions that map to lane IDs within the warp
34//
35// STAGE 3: Cross-warp reduction (BlockReduce2dCrossWarpSync)
36// ========================================================
37// - Performs reduction across multiple warps within the same thread block
38// - Uses shared memory (LDS) to facilitate data exchange between warps
39// - Each warp's lane-0 thread stores its partial results to shared memory
40// - All threads participate in loading and reducing data from shared memory
41// - Implements block-level synchronization to ensure memory consistency
42
43// BlockReduce2d: Thread-level reduction (Stage 1)
44template <typename Problem_, typename Policy_ = void>
46{
47 // Thread-level reduction implementation
49 using XDataType = typename Problem::XDataType;
50 using ComputeDataType = typename Problem::ComputeDataType;
51
53
54 private:
55 template <bool kProcessIndex,
56 typename XDistributedTensor_,
57 typename YDistributedTensor_,
58 typename YIndexDistributedTensor_,
59 typename ReduceFunc,
60 typename IndexCalculatorFunc,
61 typename ReducePacksPerXDim>
62 CK_TILE_DEVICE void reduce_impl(const XDistributedTensor_& x_tensor,
63 YDistributedTensor_& y_tensor,
64 YIndexDistributedTensor_& y_index_tensor,
65 const ReduceFunc& reduce_func,
66 const IndexCalculatorFunc& index_calculator,
67 ReducePacksPerXDim)
68 {
70 [&](auto... idx_) {
71 constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]);
72
73 (..., [&](auto idx) {
74 auto val = ck_tile::type_convert<ComputeDataType>(x_tensor[idx]);
75
76 if constexpr(kProcessIndex)
77 {
78
79 const auto x_indices = get_x_indices_from_distributed_indices(
80 XDistributedTensor_::get_tile_distribution(), idx);
81 const auto new_idx = index_calculator(x_indices);
82 auto current_idx = y_index_tensor(idx_0);
83
84 AccumulateWithIndex{}(
85 reduce_func, y_tensor(idx_0), current_idx, val, new_idx);
86
87 y_index_tensor(idx_0) =
89 }
90 else
91 {
92 Accumulate{}(reduce_func, y_tensor(idx_0), val);
93 }
94 }(idx_));
95 },
96 ReducePacksPerXDim{});
97 }
98
99 public:
100 // Overload for non-index tracking
// Accumulates the elements of x_tensor into y_tensor using reduce_func; no indices are produced.
// Forwards to reduce_impl with kProcessIndex = false, so only the Accumulate branch of
// reduce_impl is compiled and the index-tensor / index-calculator arguments are never used.
// ReducePacksPerXDim is a tag type forwarded unchanged to reduce_impl (defaults to {1,1}).
101 template <
102 typename XDistributedTensor_,
103 typename YDistributedTensor_,
104 typename ReduceFunc,
105 typename ReducePacksPerXDim =
106 uniform_sequence_gen_t<2, 1>> // {1,1} = process 1 element at a time from each dimension
107 CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
108 YDistributedTensor_& y_tensor,
109 const ReduceFunc& reduce_func,
110 ReducePacksPerXDim = {})
111 {
112 reduce_impl<false>(
113 x_tensor,
114 y_tensor,
115 y_tensor, // placeholder for the index tensor; not accessed when kProcessIndex is false
116 reduce_func,
117 [](auto) { return 0; }, // placeholder index calculator; never invoked on this path
118 ReducePacksPerXDim{});
119 }
120
121 // Overload for index tracking
// Reduces x_tensor into y_tensor and, when Problem::kOutputIndex is true, also maintains
// y_index_tensor: reduce_impl calls index_calculator with the X indices of each element and
// passes the result to AccumulateWithIndex together with the current entry of y_index_tensor.
// When Problem::kOutputIndex is false only the values are accumulated, and y_index_tensor and
// index_calculator are not used.
// ReducePacksPerXDim is a tag type forwarded unchanged to reduce_impl (defaults to {1,1}).
122 template <typename XDistributedTensor_,
123 typename YDistributedTensor_,
124 typename YIndexDistributedTensor_,
125 typename ReduceFunc,
126 typename IndexCalculatorFunc,
127 typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
128 CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
129 YDistributedTensor_& y_tensor,
130 YIndexDistributedTensor_& y_index_tensor,
131 const ReduceFunc& reduce_func,
132 const IndexCalculatorFunc& index_calculator,
133 ReducePacksPerXDim = {})
134 {
// The compile-time flag from Problem selects the index-tracking or value-only branch.
135 reduce_impl<Problem::kOutputIndex>(x_tensor,
136 y_tensor,
137 y_index_tensor,
138 reduce_func,
139 index_calculator,
140 ReducePacksPerXDim{});
141 }
142
143#if 0
144 constexpr auto I0 = number<0>{};
145 constexpr auto I1 = number<1>{};
146 constexpr auto spans = XDistributedTensor_::get_distributed_spans();
147
148 // FIXME: hard coded to reduce 2nd axis
149 sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
150 constexpr auto y_dstr_idx = make_tuple(dstr_idx_i0);
151
152 auto y = y_tensor[y_dstr_idx];
153
154 sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
155 constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
156 const auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
157
158 y = reduce_func(y, x);
159 });
160
161 y_tensor(y_dstr_idx) = y;
162 });
163#endif
164
165 template <typename XDistributedTensor_>
167 {
168 static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");
169
170 // FIXME: hard coded to reduce 2nd axis
171 constexpr auto reduce_dims = sequence<1>{};
172
173 constexpr auto dstr =
175 XDistributedTensor_::get_tile_distribution()
176 .get_static_tile_distribution_encoding(),
177 reduce_dims));
178
180
181 return tensor;
182 }
183
184 template <typename XDistributedTensor_, typename IndexDataType = index_t>
186 {
187 static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");
188
189 // FIXME: hard coded to reduce 2nd axis
190 constexpr auto reduce_dims = sequence<1>{};
191
192 constexpr auto dstr =
194 XDistributedTensor_::get_tile_distribution()
195 .get_static_tile_distribution_encoding(),
196 reduce_dims));
197
199
200 return tensor;
201 }
202
203 // uniform_sequence_gen_t<NSize, Value> generates sequence of NSize elements filled with Value
204 // e.g., uniform_sequence_gen_t<2, 1> → {1, 1} and uniform_sequence_gen_t<3, 4> → {4, 4, 4}
// Convenience overload: creates the output tile with MakeYBlockTile, fills it with reduce_init,
// runs the value-only reduction of x_tensor into it, and returns the tile by value.
205 template <typename XDistributedTensor_,
206 typename ReduceFunc,
207 typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
208 CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor,
209 const ComputeDataType& reduce_init,
210 const ReduceFunc& reduce_func,
211 ReducePacksPerXDim = {})
212 {
213 auto y_tensor = MakeYBlockTile<XDistributedTensor_>();
214 set_tile(y_tensor, reduce_init); // every output element starts from reduce_init
215 (*this)(x_tensor, y_tensor, reduce_func, ReducePacksPerXDim{}); // value-only overload
216
217 return y_tensor;
218 }
219};
220
221// BlockReduce2dSync: Warp-level reduction (Stage 2)
222template <typename Problem_, typename Policy_ = void>
224{
226
227 private:
228 template <bool kProcessIndex,
229 typename YDistributedTensor_,
230 typename YIndexDistributedTensor_,
231 typename ReduceFunc>
232 CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
233 YIndexDistributedTensor_& y_index_tensor,
234 const ReduceFunc& reduce_func)
235 {
236 using Dstr = typename YDistributedTensor_::StaticTileDistribution;
237 using DstrEncode = typename Dstr::DstrEncode;
238 using DstrEncodeDetail = typename DstrEncode::detail;
239
240 constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
241 constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
242
243 constexpr index_t idim_p_lane = NDimP - 1;
244
245 // const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
246 // const auto rs_idx =
247 // y_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
248
249 constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
250
251 // loop over thread data
253 auto v_local = y_tensor.get_thread_buffer()[i];
254
255 using IndexDataType = typename YIndexDistributedTensor_::DataType;
256 IndexDataType idx_local{};
257
258 if constexpr(kProcessIndex)
259 {
260 idx_local = y_index_tensor.get_thread_buffer()[i];
261 }
262
263 // cross-lane reduce for replication
264 // only reduce over the R dimensions that correspond to the lane
265 // (i.e. the lane id maps to this R dimension)
266 static_for<0, NDimR, 1>{}([&](auto idim_r) {
267 // FIXME: nasty to use does_p_own_r_
268 if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
269 {
270 constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
271
272 constexpr index_t lid_over_rid_derivative =
273 DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
274
275 static_assert(is_power_of_two_integer(r_length),
276 "wrong! only support power of 2 reduction");
277
278 constexpr index_t nstage = integer_log2_floor(r_length);
279
280 // reduction sweep forward
281 static_for<0, nstage, 1>{}([&](auto istage) {
282 // xor
283 index_t src_lane =
284 (__lane_id()) ^
285 (number<lid_over_rid_derivative << istage.value>{}.value);
286
287 // pull data from remote lane
288 const auto v_remote = warp_shuffle(v_local, src_lane);
289
290 if constexpr(kProcessIndex)
291 {
292 const auto idx_remote = warp_shuffle(idx_local, src_lane);
293
295 reduce_func, v_local, idx_local, v_remote, idx_remote);
296 }
297 else
298 {
299 Accumulate{}(reduce_func, v_local, v_remote);
300 }
301 });
302 }
303 });
304
305 // TODO - Do we need to broadcast to other lanes?
306 y_tensor.get_thread_buffer()(i) = v_local;
307
308 if constexpr(kProcessIndex)
309 {
310 y_index_tensor.get_thread_buffer()(i) = idx_local;
311 }
312 });
313 }
314
315 public:
// Value-only warp-level reduction: updates y_tensor's thread buffer in place using reduce_func.
// y_tensor is passed twice because reduce_impl also takes an index tensor; with
// kProcessIndex = false only that argument's DataType is used and its contents are not touched.
316 template <typename YDistributedTensor_, typename ReduceFunc>
317 CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func)
318 {
319 reduce_impl<false>(y_tensor, y_tensor, reduce_func);
320 }
321
// Index-tracking warp-level reduction: y_index_tensor is read and written alongside y_tensor
// only when Problem::kOutputIndex is true; otherwise only y_tensor is modified.
322 template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
323 CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
324 YIndexDistributedTensor_& y_index_tensor,
325 const ReduceFunc& reduce_func)
326 {
327 reduce_impl<Problem::kOutputIndex>(y_tensor, y_index_tensor, reduce_func);
328 }
329};
330
331// BlockReduce2dCrossWarpSync: Cross-warp reduction (Stage 3)
332template <typename Problem_, typename Policy_ = void>
334{
336 using BlockShape = typename Problem::BlockShape;
337
338 template <typename YDistributedTensor_>
340 {
341 constexpr index_t num_reduce_warps = [&]() {
342 using Dstr = typename YDistributedTensor_::StaticTileDistribution;
343 using DstrEncode = typename Dstr::DstrEncode;
344 using DstrEncodeDetail = typename DstrEncode::detail;
345
346 constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
347
348 constexpr index_t idim_p_warp = 0;
349
350 index_t len_ = 1;
351 static_for<0, NDimR, 1>{}([&](auto idim_r) {
352 if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
353 {
354 constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
355 len_ *= r_length;
356 }
357 });
358 return len_;
359 }();
360 return num_reduce_warps;
361 }
362
363 // returns the required shared memory size in bytes
364 template <typename YDistributedTensor_>
366 {
367 using DataType = typename YDistributedTensor_::DataType;
368 constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
369
370 // we need to store all data from every wave into smem
371 // e.g. 2x2 reduce along N
372 // -------------> reduce N
373 // | w0 | w1 | ___> | w01 |
374 // | w2 | w3 | | w23 |
375 //
376 // -> store data from every wave into LDS
377 //
378 //
379 // -------------> reduce N
380 // | w0 | w1 | w2 | w3 | -----> | w0123 |
381 //
382 // -> also store data from every wave into LDS
383 constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
384 return num_warps * thread_buf_size * sizeof(DataType);
385 }
386
387 // returns the size in bytes - separate shared memory size calculation for indices
388 template <typename YIndexDistributedTensor_>
390 {
391 using IndexDataType = typename YIndexDistributedTensor_::DataType;
392 constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();
393 constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
394 return num_warps * thread_buf_size * sizeof(IndexDataType);
395 }
396
397 private:
398 template <bool kProcessIndex,
399 typename YDistributedTensor_,
400 typename YIndexDistributedTensor_,
401 typename ReduceFunc>
402 CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
403 YIndexDistributedTensor_& y_index_tensor,
404 void* smem,
405 void* smem_indices_ptr,
406 const ReduceFunc& reduce_func)
407 {
408 using DataType = typename YDistributedTensor_::DataType;
409 using IndexDataType = typename YIndexDistributedTensor_::DataType;
410
411 constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
412
413 DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
414 IndexDataType* smem_indices = nullptr;
415 if constexpr(kProcessIndex)
416 {
417 smem_indices = reinterpret_cast<IndexDataType*>(smem_indices_ptr);
418 }
419
420 const index_t lane_id = get_lane_id();
421 const index_t warp_id = get_warp_id();
422
423 constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
424 constexpr index_t num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
425
426 if constexpr(num_reduce_warps == 1)
427 return;
428
429 // Each warp's lane 0 writes its partial results to shared memory
430 const index_t smem_offset = warp_id;
431 if(lane_id == 0)
432 {
433 static_for<0, thread_buf_size, 1>{}([&](auto i) {
434 // Store the i-th element of this warp's thread_buffer into SMEM
435 smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
436 if constexpr(kProcessIndex)
437 {
438 smem_indices[smem_offset + i * num_warps] =
439 y_index_tensor.get_thread_buffer()[i];
440 }
441 });
442 }
444
445 // We let each warp hold a duplicate copy of the data and do the reduction itself.
446 const index_t local_warp_id = warp_id / num_reduce_warps;
447 const index_t local_smem_os = local_warp_id * num_reduce_warps;
448
449 static_for<0, thread_buf_size, 1>{}([&](auto i) {
450 DataType v[num_reduce_warps];
451 [[maybe_unused]] std::
452 conditional_t<kProcessIndex, IndexDataType[num_reduce_warps], IndexDataType> idx_v;
453
454 static_for<0, num_reduce_warps, 1>{}([&](auto idx) {
455 v[idx] = smem_ptr[i * num_warps + local_smem_os + idx];
456 if constexpr(kProcessIndex)
457 {
458 idx_v[idx] = smem_indices[i * num_warps + local_smem_os + idx];
459 }
460 });
461
462 static_assert(is_power_of_two_integer(num_reduce_warps),
463 "wrong! only support power of 2 reduction");
464
465 constexpr index_t nstage = integer_log2_floor(num_reduce_warps);
466
467 static_for<0, nstage, 1>{}([&](auto istage) {
468 constexpr index_t stride = 1 << istage.value;
469 static_for<0, num_reduce_warps, stride * 2>{}([&](auto idx_) {
470 constexpr index_t i0 = idx_();
471 constexpr index_t i1 = idx_ + stride;
472 if constexpr(i1 < num_reduce_warps)
473 {
474 if constexpr(kProcessIndex)
475 {
476 AccumulateWithIndex{}(reduce_func, v[i0], idx_v[i0], v[i1], idx_v[i1]);
477 }
478 else
479 {
480 Accumulate{}(reduce_func, v[i0], v[i1]);
481 }
482 }
483 });
484 });
485
486 y_tensor.get_thread_buffer()(i) = v[0];
487 if constexpr(kProcessIndex)
488 {
489 y_index_tensor.get_thread_buffer()(i) = idx_v[0];
490 }
491 });
492 }
493
494 public:
// Value-only cross-warp reduction: updates y_tensor's thread buffer in place, using smem as
// scratch storage for the per-warp partial results. nullptr is passed for the index buffer,
// which reduce_impl never dereferences when kProcessIndex is false.
// NOTE(review): smem is presumably sized with GetSmemSize<YDistributedTensor_>() — confirm at
// call sites.
495 template <typename YDistributedTensor_, typename ReduceFunc>
496 CK_TILE_DEVICE void
497 operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
498 {
499 reduce_impl<false>(y_tensor, y_tensor, smem, nullptr, reduce_func);
500 }
501
// Index-tracking cross-warp reduction: smem receives the per-warp values and smem_indices the
// per-warp indices. The index path (including every access to smem_indices and y_index_tensor)
// is compiled only when Problem::kOutputIndex is true.
// NOTE(review): buffers are presumably sized with GetSmemSize / GetIndicesSmemSize — confirm at
// call sites.
502 template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
503 CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
504 YIndexDistributedTensor_& y_index_tensor,
505 void* smem,
506 void* smem_indices,
507 const ReduceFunc& reduce_func)
508 {
509 reduce_impl<Problem::kOutputIndex>(
510 y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
511 }
512};
513
514template <typename Problem_, typename Policy_ = void>
516{
518 using BlockShape = typename Problem::BlockShape;
519
520 template <typename YDistributedTensor_>
522 {
523 constexpr index_t num_reduce_warps = [&]() {
524 using Dstr = typename YDistributedTensor_::StaticTileDistribution;
525 using DstrEncode = typename Dstr::DstrEncode;
526 using DstrEncodeDetail = typename DstrEncode::detail;
527
528 constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
529
530 constexpr index_t idim_p_warp = 0;
531
532 index_t len_ = 1;
533 static_for<0, NDimR, 1>{}([&](auto idim_r) {
534 if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
535 {
536 constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
537 len_ *= r_length;
538 }
539 });
540 return len_;
541 }();
542 return num_reduce_warps;
543 }
544
545 // returns the required shared memory size in bytes
546 template <typename YDistributedTensor_>
548 {
549 using DataType = typename YDistributedTensor_::DataType;
550 constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
551
552 // we need to store all data from every wave into smem
553 // e.g. 2x2 reduce along N
554 // -------------> reduce N
555 // | w0 | w1 | ___> | w01 |
556 // | w2 | w3 | | w23 |
557 //
558 // -> store data from every wave into LDS
559 //
560 //
561 // -------------> reduce N
562 // | w0 | w1 | w2 | w3 | -----> | w0123 |
563 //
564 // -> also store data from every wave into LDS
565 constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
566 return num_warps * thread_buf_size * sizeof(DataType);
567 }
568
569 // returns the size in bytes - separate shared memory size calculation for indices
570 template <typename YIndexDistributedTensor_>
572 {
573 using IndexDataType = typename YIndexDistributedTensor_::DataType;
574 constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();
575 constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
576 return num_warps * thread_buf_size * sizeof(IndexDataType);
577 }
578
579 private:
580 template <bool kProcessIndex,
581 typename YDistributedTensor_,
582 typename YIndexDistributedTensor_,
583 typename ReduceFunc>
584 CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
585 YIndexDistributedTensor_& y_index_tensor,
586 void* smem,
587 void* smem_indices_ptr,
588 const ReduceFunc& reduce_func)
589 {
590 using DataType = typename YDistributedTensor_::DataType;
591 using IndexDataType = typename YIndexDistributedTensor_::DataType;
592
593 constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
594
595 DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
596 IndexDataType* smem_indices = nullptr;
597 if constexpr(kProcessIndex)
598 {
599 smem_indices = reinterpret_cast<IndexDataType*>(smem_indices_ptr);
600 }
601
602 const index_t lane_id = get_lane_id();
603 const index_t warp_id = get_warp_id();
604 constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
605 constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
606 const index_t smem_offset = warp_id;
607
608 // skip if nothing to do
609 if constexpr(num_reduce_warps == 1)
610 return;
611
612 // store into smem only for lane-0 within one warp
613 if(lane_id == 0)
614 {
615 static_for<0, thread_buf_size, 1>{}([&](auto i) {
616 smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
617 if constexpr(kProcessIndex)
618 {
619 smem_indices[smem_offset + i * num_warps] =
620 y_index_tensor.get_thread_buffer()[i];
621 }
622 });
623 }
625
626 // load from smem; here we let every thread do the compute :)
627 index_t local_warp_id = warp_id / num_reduce_warps;
628 index_t local_smem_os = local_warp_id * num_reduce_warps;
629
630 DataType all_scratch[thread_buf_size * num_reduce_warps];
631 [[maybe_unused]] std::conditional_t<kProcessIndex,
632 IndexDataType[thread_buf_size * num_reduce_warps],
633 IndexDataType> all_indices;
634
635 // Load data from shared memory
636 static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
637 static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
638 all_scratch[i_0 * num_reduce_warps + i_1] =
639 smem_ptr[i_0 * num_warps + local_smem_os + i_1];
640
641 if constexpr(kProcessIndex)
642 {
643 all_indices[i_0 * num_reduce_warps + i_1] =
644 smem_indices[i_0 * num_warps + local_smem_os + i_1];
645 }
646 });
647 });
648 block_sync_lds(); // TODO: we don't need sync here
649
650 // Perform reduction
651 static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
652 // TODO: use descriptor for this
653 auto v_local = all_scratch[i_0 * num_reduce_warps];
654
655 IndexDataType idx_local{};
656 if constexpr(kProcessIndex)
657 {
658 idx_local = all_indices[i_0 * num_reduce_warps];
659 }
660
661 // fold the remaining warps' partial results into v_local with reduce_func
662 static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
663 constexpr auto i_1 = number<i_1_n1 + 1>{};
664 const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
665
666 if constexpr(kProcessIndex)
667 {
668 const IndexDataType idx_remote = all_indices[i_0 * num_reduce_warps + i_1];
669
670 bool changed = false;
671 v_local = reduce_func(v_local, v_remote, changed);
672 if(changed)
673 {
674 idx_local = idx_remote;
675 }
676 }
677 else
678 {
679 v_local = reduce_func(v_local, v_remote);
680 }
681 });
682
683 y_tensor.get_thread_buffer()(i_0) = v_local;
684 if constexpr(kProcessIndex)
685 {
686 y_index_tensor.get_thread_buffer()(i_0) = idx_local;
687 }
688 });
689 }
690
691 public:
// Value-only cross-warp reduction: updates y_tensor's thread buffer in place, using smem as
// scratch storage for the per-warp partial results. On this path reduce_impl calls
// reduce_func(a, b) with two arguments and never dereferences the index buffer (nullptr here).
// NOTE(review): smem is presumably sized with GetSmemSize<YDistributedTensor_>() — confirm at
// call sites.
692 template <typename YDistributedTensor_, typename ReduceFunc>
693 CK_TILE_DEVICE void
694 operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
695 {
696 reduce_impl<false>(y_tensor, y_tensor, smem, nullptr, reduce_func);
697 }
698
// Index-tracking cross-warp reduction: smem receives the per-warp values and smem_indices the
// per-warp indices. The index path is compiled only when Problem::kOutputIndex is true; in
// that case reduce_impl calls reduce_func(a, b, changed) with a bool flag and takes the remote
// index whenever the call sets changed to true. Otherwise reduce_func(a, b) is used and
// y_index_tensor is not modified.
// NOTE(review): buffers are presumably sized with GetSmemSize / GetIndicesSmemSize — confirm at
// call sites.
699 template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
700 CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
701 YIndexDistributedTensor_& y_index_tensor,
702 void* smem,
703 void* smem_indices,
704 const ReduceFunc& reduce_func)
705 {
706 reduce_impl<Problem::kOutputIndex>(
707 y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
708 }
709};
710
711} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE constexpr auto make_reduce_tile_distribution_encoding(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition tile_distribution_encoding.hpp:762
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
Definition tile/core/numeric/math.hpp:462
CK_TILE_DEVICE index_t get_lane_id()
Definition arch.hpp:101
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
CK_TILE_DEVICE T warp_shuffle(const T &v_local, uint32_t src_lane)
Definition utility.hpp:78
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F &f, UnpacksPerXDim={})
Definition sweep_tile.hpp:231
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x)
Definition tile/core/numeric/math.hpp:455
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition reduce_operator_accumulate.hpp:41
Accumulate with index tracking reductions, provides deterministic first occurring index.
Definition reduce_operator_accumulate.hpp:12
Definition block_reduce2d.hpp:334
static CK_TILE_HOST_DEVICE constexpr index_t GetIndicesSmemSize()
Definition block_reduce2d.hpp:389
static CK_TILE_DEVICE constexpr index_t GetReduceWarps()
Definition block_reduce2d.hpp:339
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, YIndexDistributedTensor_ &y_index_tensor, void *smem, void *smem_indices, const ReduceFunc &reduce_func)
Definition block_reduce2d.hpp:503
remove_cvref_t< Problem_ > Problem
Definition block_reduce2d.hpp:335
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition block_reduce2d.hpp:365
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, void *smem, const ReduceFunc &reduce_func)
Definition block_reduce2d.hpp:497
typename Problem::BlockShape BlockShape
Definition block_reduce2d.hpp:336
CK_TILE_DEVICE void operator()(const XDistributedTensor_ &x_tensor, YDistributedTensor_ &y_tensor, YIndexDistributedTensor_ &y_index_tensor, const ReduceFunc &reduce_func, const IndexCalculatorFunc &index_calculator, ReducePacksPerXDim={})
Definition block_reduce2d.hpp:128
CK_TILE_DEVICE constexpr BlockReduce2d()
Definition block_reduce2d.hpp:52
typename Problem::ComputeDataType ComputeDataType
Definition block_reduce2d.hpp:50
static CK_TILE_DEVICE auto MakeYBlockTile()
Definition block_reduce2d.hpp:166
CK_TILE_DEVICE void operator()(const XDistributedTensor_ &x_tensor, YDistributedTensor_ &y_tensor, const ReduceFunc &reduce_func, ReducePacksPerXDim={})
Definition block_reduce2d.hpp:107
remove_cvref_t< Problem_ > Problem
Definition block_reduce2d.hpp:48
CK_TILE_DEVICE auto operator()(const XDistributedTensor_ &x_tensor, const ComputeDataType &reduce_init, const ReduceFunc &reduce_func, ReducePacksPerXDim={})
Definition block_reduce2d.hpp:208
typename Problem::XDataType XDataType
Definition block_reduce2d.hpp:49
static CK_TILE_DEVICE auto MakeYIndexBlockTile()
Definition block_reduce2d.hpp:185
Definition block_reduce2d.hpp:516
static CK_TILE_HOST_DEVICE constexpr index_t GetIndicesSmemSize()
Definition block_reduce2d.hpp:571
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, YIndexDistributedTensor_ &y_index_tensor, void *smem, void *smem_indices, const ReduceFunc &reduce_func)
Definition block_reduce2d.hpp:700
remove_cvref_t< Problem_ > Problem
Definition block_reduce2d.hpp:517
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, void *smem, const ReduceFunc &reduce_func)
Definition block_reduce2d.hpp:694
typename Problem::BlockShape BlockShape
Definition block_reduce2d.hpp:518
static CK_TILE_DEVICE constexpr index_t GetReduceWarps()
Definition block_reduce2d.hpp:521
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition block_reduce2d.hpp:547
Definition block_reduce2d.hpp:224
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, const ReduceFunc &reduce_func)
Definition block_reduce2d.hpp:317
CK_TILE_DEVICE void operator()(YDistributedTensor_ &y_tensor, YIndexDistributedTensor_ &y_index_tensor, const ReduceFunc &reduce_func)
Definition block_reduce2d.hpp:323
remove_cvref_t< Problem_ > Problem
Definition block_reduce2d.hpp:225
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43