device_multiple_reduce_multiblock.hpp Source File

device_multiple_reduce_multiblock.hpp Source File

Composable Kernel: device_multiple_reduce_multiblock.hpp Source File
device_multiple_reduce_multiblock.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
11
17
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24template <index_t NumReduction,
25 typename InDataType,
26 typename AccDataType,
27 typename OutDataTypeTuple,
28 index_t Rank,
29 index_t NumReduceDim,
30 typename ReduceOperation,
31 typename InElementwiseOperationTuple,
32 typename AccElementwiseOperationTuple,
33 InMemoryDataOperationEnum OutMemoryDataOperation,
34 bool PropagateNan,
35 index_t BlockSize,
36 index_t MThreadClusterSize,
37 index_t KThreadClusterSize,
38 index_t MThreadSliceSize,
39 index_t KThreadSliceSize,
40 index_t InSrcVectorDim,
41 index_t InSrcVectorSize,
42 typename OutDstVectorSizeSeq>
44 NumReduceDim,
45 NumReduction,
46 InElementwiseOperationTuple,
47 AccElementwiseOperationTuple>
48{
49 static_assert(Rank <= 6, "Bigger Rank size is not supported!");
50 static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
51 "Invalid thread cluster size assignments!");
52
53 static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
54 (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
55 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
56
57 static_assert(NumReduction == OutDataTypeTuple::Size() &&
58 NumReduction == InElementwiseOperationTuple::Size() &&
59 NumReduction == AccElementwiseOperationTuple::Size() &&
60 NumReduction == OutDstVectorSizeSeq::Size(),
61 "All tuple should have the same size as the number of Reductions!");
62
63 static_assert(sequence_all_of(OutDstVectorSizeSeq{},
64 [](auto vectorSize) {
65 return (MThreadSliceSize % vectorSize == 0);
66 }),
67 "The OutDstVectorSize should completely divide the MThreadSliceSize!");
68
69 static constexpr bool CheckDataTypeTuple()
70 {
71 bool flag = true;
72
73 static_for<0, NumReduction, 1>{}([&](auto I) {
74 using OutDataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
75 flag =
76 flag && ck::reduce::InMemoryDataOperationSupportedOnDataType<OutMemoryDataOperation,
77 OutDataType>::value;
78 });
79
80 return flag;
81 };
82
83 static_assert(CheckDataTypeTuple(),
84 "The OutDataType must support the specified OutMemoryDataOperation!");
85
86 static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
87
88 static constexpr index_t NumInputDim = Rank;
89 static constexpr index_t NumOutputDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
90 static constexpr bool reduceAllDim = (NumInvariantDim == 0);
91
92 // So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added
93 // later
94 static constexpr bool use_multiblock =
95 (OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);
96
97 static_assert(
98 ReduceOperation::IsCompatibleInMemoryDataOperation(OutMemoryDataOperation),
99 "The reduction accumulation operation must be compatible with the OutMemoryDataOperation!");
100
101 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
102 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
103
105 {
106 return generate_tuple(
107 [&](auto I) {
108 using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
109
110 return static_cast<DataType*>(nullptr);
111 },
113 };
114
116
117 static auto MakeSrc2dDescriptor(const std::array<index_t, NumInputDim>& inLengths,
118 const std::array<index_t, NumInputDim>& inStrides,
119 int blkGroupSize,
120 int numBlockTileIteration)
121 {
122 const auto tupleSrcLengths =
123 generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInputDim>{});
124 const auto tupleSrcStrides =
125 generate_tuple([&](auto I) { return inStrides[I]; }, Number<NumInputDim>{});
126
127 const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
128
129 const auto in_grid_desc_m_k = [&]() {
130 if constexpr(reduceAllDim)
131 {
132 const auto one_dim_inDesc = transform_tensor_descriptor(
133 inDesc,
134 make_tuple(make_merge_transform(tupleSrcLengths)),
137
138 return transform_tensor_descriptor(one_dim_inDesc,
140 1, one_dim_inDesc.GetLength(Number<0>{})))),
143 }
144 else
145 {
146 using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
148
149 const auto reduceDimLengths = generate_tuple(
150 [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
151 const auto invariantDimLengths =
152 generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
153
155 inDesc,
156 make_tuple(make_merge_transform(invariantDimLengths),
157 make_merge_transform(reduceDimLengths)),
158 make_tuple(InvariantDims{}, ReduceDims{}),
160 }
161 }();
162
163 const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
164 const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
165
166 const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
167 const auto inPad_M =
168 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
169 const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
170
171 auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
172 in_grid_desc_m_k,
173 make_tuple(make_right_pad_transform(invariantLength, inPad_M),
174 make_right_pad_transform(reduceLength, inPad_K)),
177
178 return (in_grid_desc_m_k_padded);
179 };
180
181 static auto MakeDst1dDescriptor(const std::array<index_t, NumOutputDim>& outLengths,
182 const std::array<index_t, NumOutputDim>& outStrides)
183 {
184 const auto tupleDstLengths =
185 generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumOutputDim>{});
186 const auto tupleDstStrides =
187 generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumOutputDim>{});
188
189 auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
190
191 auto out_grid_desc_m = transform_tensor_descriptor(
192 outDesc,
193 make_tuple(make_merge_transform(tupleDstLengths)),
196
197 const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
198
199 const auto outPad =
200 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
201
202 auto out_grid_desc_m_padded = transform_tensor_descriptor(
203 out_grid_desc_m,
204 make_tuple(make_right_pad_transform(invariantLength, outPad)),
207 return (out_grid_desc_m_padded);
208 };
209
211 {
212 return generate_tuple(
213 [&](auto I) {
214 (void)I;
215 return MakeDst1dDescriptor(std::array<index_t, NumOutputDim>{},
216 std::array<index_t, NumOutputDim>{});
217 },
219 };
220
222 std::array<index_t, NumInputDim>{}, std::array<index_t, NumInputDim>{}, 1, 1));
224
225 static auto MakeDst1dDescriptorForBufferSet(const std::array<index_t, NumOutputDim>& outLengths,
226 const std::array<index_t, NumOutputDim>& outStrides)
227 {
228 const auto tupleDstLengths =
229 generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumOutputDim>{});
230 const auto tupleDstStrides =
231 generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumOutputDim>{});
232
233 auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
234
235 auto out_grid_desc_m = transform_tensor_descriptor(
236 outDesc,
237 make_tuple(make_merge_transform(tupleDstLengths)),
240
241 const auto length = out_grid_desc_m.GetLength(Number<0>{});
242
243 const auto pad = math::integer_least_multiple(length, BlockSize) - length;
244
245 auto out_grid_desc_m_padded =
246 transform_tensor_descriptor(out_grid_desc_m,
250 return (out_grid_desc_m_padded);
251 };
252
254 {
255 return generate_tuple(
256 [&](auto I) {
257 (void)I;
258 return MakeDst1dDescriptorForBufferSet(std::array<index_t, NumOutputDim>{},
259 std::array<index_t, NumOutputDim>{});
260 },
262 };
263
265
266 struct Argument : public BaseArgument
267 {
268 Argument(const std::array<index_t, NumInputDim>& inLengths,
269 const std::array<index_t, NumInputDim>& inStrides,
270 const std::array<index_t, NumOutputDim>& outLengths,
271 const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
272 const std::array<int, NumReduceDim>& reduceDims,
273 const std::array<double, NumReduction>& alphas,
274 const std::array<double, NumReduction>& betas,
275 const void* in_dev,
276 const std::array<void*, NumReduction>& out_dev_buffers,
277 const InElementwiseOperationTuple in_elementwise_op_tuple,
278 const AccElementwiseOperationTuple acc_elementwise_op_tuple)
279 : outLengths_{outLengths},
280 outStridesArray_{outStridesArray},
281 in_elementwise_op_tuple_{in_elementwise_op_tuple},
282 acc_elementwise_op_tuple_{acc_elementwise_op_tuple}
283 {
286
287 for(size_t i = 0; i < NumReduction; i++)
288 {
289 alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
290 beta_values_(i) = static_cast<AccDataType>(betas[i]);
291 };
292
293 in_dev_ = static_cast<const InDataType*>(in_dev);
294
296 [&](auto iR) {
297 using OutDataTypePointer =
300 return static_cast<OutDataType*>(out_dev_buffers[iR]);
301 },
303
306
307 if constexpr(use_multiblock)
308 {
309
310 int iterations = 1;
311 while(true)
312 {
313 int testBlkGroupSize =
314 (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
315 (K_BlockTileSize * iterations);
316
317 // we want the blkGroupSize be not more than 128
318 if(testBlkGroupSize <= 128)
319 break;
320
321 iterations++;
322 };
323
324 blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
325 (K_BlockTileSize * iterations);
326
327 numBlockTileIteration = iterations;
328 }
329 else
330 {
331 blkGroupSize = 1;
334 };
335
338
340 [&](auto I) { return MakeDst1dDescriptor(outLengths, outStridesArray[I]); },
342
344 [&](auto I) {
345 return MakeDst1dDescriptorForBufferSet(outLengths, outStridesArray[I]);
346 },
348
351
354 }
355
356 std::array<index_t, NumInputDim> inLengths_;
357 std::array<index_t, NumInputDim> inStrides_;
358
359 std::array<index_t, NumOutputDim> outLengths_;
360 std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray_;
361
364
365 const InDataType* in_dev_;
367
371
372 InElementwiseOperationTuple in_elementwise_op_tuple_;
373 AccElementwiseOperationTuple acc_elementwise_op_tuple_;
374
377
380 size_t gridSize;
381
383 };
384
385 struct Invoker : public BaseInvoker
386 {
387 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
388 {
389 using GridwiseMultipleReduce =
391 InDataType,
393 AccDataType,
396 ReduceOperation,
397 InElementwiseOperationTuple,
398 AccElementwiseOperationTuple,
399 OutMemoryDataOperation,
400 PropagateNan,
401 BlockSize,
402 MThreadClusterSize,
403 KThreadClusterSize,
404 MThreadSliceSize,
405 KThreadSliceSize,
406 InSrcVectorDim,
407 InSrcVectorSize,
408 OutDstVectorSizeSeq>;
409
410 const auto kernel_main =
411 kernel_multiple_reduce_multiblock<GridwiseMultipleReduce,
412 NumReduction,
413 InDataType,
415 AccDataType,
418 InElementwiseOperationTuple,
419 AccElementwiseOperationTuple>;
420
421 float avg_time = 0;
422
423 if constexpr(use_multiblock)
424 {
425 auto identity_values = generate_tuple(
426 [&](auto iR) {
427 using OutDataType = remove_cvref_t<decltype(OutDataTypeTuple{}[iR])>;
429 OutMemoryDataOperation);
430 },
432
434 NumReduction,
435 BlockSize,
437 OutDataTypeTuple>;
438
439 avg_time += launch_and_time_kernel(stream_config,
440 kernel_pre,
441 dim3(arg.gridSize_pre),
442 dim3(BlockSize),
443 0,
446 identity_values);
447 };
448
449 avg_time += launch_and_time_kernel(stream_config,
450 kernel_main,
451 dim3(arg.gridSize),
452 dim3(BlockSize),
453 0,
458 arg.blkGroupSize,
460 arg.alpha_values_,
461 arg.in_dev_,
462 arg.beta_values_,
463 arg.out_dev_buffers_);
464
465 return (avg_time);
466 };
467
468 float Run(const BaseArgument* p_arg,
469 const StreamConfig& stream_config = StreamConfig{}) override
470 {
471 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
472 };
473 };
474
475 bool IsSupportedArgument(const BaseArgument* p_arg) override
476 {
477 const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
478
479 if constexpr(use_multiblock)
480 {
481 for(size_t i = 0; i < pArg->beta_values_.Size(); i++)
482 if(pArg->beta_values_[i] != 0.0f)
483 return (false);
484 };
485
486 if constexpr(InSrcVectorDim == 0)
487 {
488 if constexpr(NumInvariantDim == 0)
489 {
490 return (false);
491 }
492 else
493 {
494 if(pArg->inStrides_[NumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
495 return (false);
496
497 if(pArg->inLengths_[NumInvariantDim - 1] % InSrcVectorSize != 0)
498 return (false);
499 };
500 }
501 else
502 {
503 if(pArg->inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1)
504 return (false);
505
506 if(pArg->inLengths_[Rank - 1] % InSrcVectorSize != 0)
507 return (false);
508 };
509 // To improve
510 bool valid = true;
511 static_for<0, NumReduction, 1>{}([&](auto I) {
512 if(pArg->outStridesArray_[I.value][NumOutputDim - 1] != 1 &&
513 OutDstVectorSizeSeq::At(I) != 1)
514 valid = false;
515
516 if(pArg->outLengths_[NumOutputDim - 1] % OutDstVectorSizeSeq::At(I) != 0)
517 valid = false;
518 });
519
520 if(!valid)
521 return (false);
522
523 if constexpr(use_multiblock)
524 {
525 // blkGroupSize of 1 should be handled by Blockwise path using
526 // InMemoryDataOperationEnum::Set
527 if(pArg->blkGroupSize == 1)
528 return (false);
529
530 // This is very strong restriction, but needed to avoid some failure
531 if(pArg->outLengths_[NumOutputDim - 1] % M_BlockTileSize != 0)
532 return (false);
533 }
534 else
535 {
536 // cases with very small reduce_total_length should be handled by ThreadWise kernel
537 if(pArg->reduce_total_length / KThreadSliceSize < 2)
538 return (false);
539 };
540
541 return (true);
542 };
543
544 std::unique_ptr<BaseArgument> MakeArgumentPointer(
545 const std::array<index_t, NumInputDim> inLengths,
546 const std::array<index_t, NumInputDim> inStrides,
547 const std::array<index_t, NumOutputDim> outLengths,
548 const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
549 const std::array<int, NumReduceDim> reduceDims,
550 const std::array<double, NumReduction> alphas,
551 const std::array<double, NumReduction> betas,
552 const void* in_dev,
553 const std::array<void*, NumReduction> out_dev_buffers,
554 const InElementwiseOperationTuple in_elementwise_op_tuple,
555 const AccElementwiseOperationTuple acc_elementwise_op_tuple) override
556 {
557 return std::make_unique<Argument>(inLengths,
558 inStrides,
559 outLengths,
560 outStridesArray,
561 reduceDims,
562 alphas,
563 betas,
564 in_dev,
565 out_dev_buffers,
566 in_elementwise_op_tuple,
567 acc_elementwise_op_tuple);
568 };
569
570 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
571 {
572 return std::make_unique<Invoker>();
573 };
574
575 std::string GetTypeString() const override
576 {
577 auto str = std::stringstream();
578
579 // clang-format off
580 str << (OutMemoryDataOperation == InMemoryDataOperationEnum::Set? "DeviceMultipleReduceBlockWise<" : "DeviceMultipleReduceMultiBlock<") << BlockSize << ",";
581 str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
582 str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
583 str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << ",";
584 str << "OutDstVectorSize";
585 static_for<0, OutDstVectorSizeSeq::Size(), 1>{}([&](auto I) {str << "_" << OutDstVectorSizeSeq::At(I); });
586 str << ">";
587 // clang-format on
588
589 return str.str();
590 }
591};
592
593} // namespace device
594} // namespace tensor_operation
595} // namespace ck
auto pad(ck::index_t mpb, ck::index_t npb, ck::index_t kpb, ck::tensor_operation::device::GemmSpecialization gemm, CDesc_MRaw_NRaw conv)
Definition helper.hpp:70
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
Definition reduction_operator.hpp:473
Definition convolution_backward_data_specialization.hpp:8
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition device_reduce_common.hpp:20
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition device_reduce_common.hpp:75
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_multiple_buffer_set_value(const Grid1dBufferDescTuple grid_1d_buffer_desc_tuple, DataTypePointerTuple p_global_tuple, DataTypeTuple value_tuple)
Definition gridwise_set_multiple_buffer_value.hpp:17
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__global__ void kernel_multiple_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M_Tuple out_grid_desc_m_tuple, const InElementwiseOperationTuple in_elementwise_op_tuple, const AccElementwiseOperationTuple acc_elementwise_op_tuple, index_t block_group_size, index_t num_k_block_tile_iteration, Array< AccDataType, NumReduction > alpha_values, const InDataType *const __restrict__ p_in_value_global, Array< AccDataType, NumReduction > beta_values, OutDataTypePointerTuple p_out_value_global_tuple)
Definition gridwise_2d_multiple_reduction_multiblock.hpp:26
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr bool sequence_all_of(Seq, F f)
Definition utility/sequence.hpp:912
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition ck/stream_config.hpp:10
Definition utility/array.hpp:14
__host__ static __device__ constexpr index_t Size()
Definition utility/array.hpp:20
Definition gridwise_2d_multiple_reduction_multiblock.hpp:69
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition reduction_operator.hpp:485
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_multiple_reduce.hpp:25
Definition device_multiple_reduce_multiblock.hpp:267
OutGridDesc_M_Tuple_2 out_grid_desc_m_tuple_2
Definition device_multiple_reduce_multiblock.hpp:370
std::array< index_t, NumInputDim > inLengths_
Definition device_multiple_reduce_multiblock.hpp:356
InGridDesc_M_K in_grid_desc_m_k
Definition device_multiple_reduce_multiblock.hpp:368
long_index_t invariant_total_length
Definition device_multiple_reduce_multiblock.hpp:375
Array< AccDataType, NumReduction > beta_values_
Definition device_multiple_reduce_multiblock.hpp:363
long_index_t reduce_total_length
Definition device_multiple_reduce_multiblock.hpp:376
size_t gridSize_pre
Definition device_multiple_reduce_multiblock.hpp:382
int blkGroupSize
Definition device_multiple_reduce_multiblock.hpp:378
int numBlockTileIteration
Definition device_multiple_reduce_multiblock.hpp:379
const InDataType * in_dev_
Definition device_multiple_reduce_multiblock.hpp:365
size_t gridSize
Definition device_multiple_reduce_multiblock.hpp:380
OutGridDesc_M_Tuple out_grid_desc_m_tuple
Definition device_multiple_reduce_multiblock.hpp:369
OutDataTypePointerTuple out_dev_buffers_
Definition device_multiple_reduce_multiblock.hpp:366
Argument(const std::array< index_t, NumInputDim > &inLengths, const std::array< index_t, NumInputDim > &inStrides, const std::array< index_t, NumOutputDim > &outLengths, const std::array< std::array< index_t, NumOutputDim >, NumReduction > &outStridesArray, const std::array< int, NumReduceDim > &reduceDims, const std::array< double, NumReduction > &alphas, const std::array< double, NumReduction > &betas, const void *in_dev, const std::array< void *, NumReduction > &out_dev_buffers, const InElementwiseOperationTuple in_elementwise_op_tuple, const AccElementwiseOperationTuple acc_elementwise_op_tuple)
Definition device_multiple_reduce_multiblock.hpp:268
std::array< std::array< index_t, NumOutputDim >, NumReduction > outStridesArray_
Definition device_multiple_reduce_multiblock.hpp:360
std::array< index_t, NumOutputDim > outLengths_
Definition device_multiple_reduce_multiblock.hpp:359
std::array< index_t, NumInputDim > inStrides_
Definition device_multiple_reduce_multiblock.hpp:357
InElementwiseOperationTuple in_elementwise_op_tuple_
Definition device_multiple_reduce_multiblock.hpp:372
Array< AccDataType, NumReduction > alpha_values_
Definition device_multiple_reduce_multiblock.hpp:362
AccElementwiseOperationTuple acc_elementwise_op_tuple_
Definition device_multiple_reduce_multiblock.hpp:373
Definition device_multiple_reduce_multiblock.hpp:386
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_multiple_reduce_multiblock.hpp:387
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_multiple_reduce_multiblock.hpp:468
Definition device_multiple_reduce_multiblock.hpp:48
static constexpr index_t NumInvariantDim
Definition device_multiple_reduce_multiblock.hpp:86
decltype(GenerateOutGrid1dDescTuple_2()) OutGridDesc_M_Tuple_2
Definition device_multiple_reduce_multiblock.hpp:264
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_multiple_reduce_multiblock.hpp:475
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_multiple_reduce_multiblock.hpp:570
static constexpr bool use_multiblock
Definition device_multiple_reduce_multiblock.hpp:94
static auto GenerateOutGrid1dDescTuple()
Definition device_multiple_reduce_multiblock.hpp:210
static auto GenerateOutGrid1dDescTuple_2()
Definition device_multiple_reduce_multiblock.hpp:253
static constexpr bool CheckDataTypeTuple()
Definition device_multiple_reduce_multiblock.hpp:69
static auto MakeDst1dDescriptor(const std::array< index_t, NumOutputDim > &outLengths, const std::array< index_t, NumOutputDim > &outStrides)
Definition device_multiple_reduce_multiblock.hpp:181
std::string GetTypeString() const override
Definition device_multiple_reduce_multiblock.hpp:575
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, NumInputDim > inLengths, const std::array< index_t, NumInputDim > inStrides, const std::array< index_t, NumOutputDim > outLengths, const std::array< std::array< index_t, NumOutputDim >, NumReduction > outStridesArray, const std::array< int, NumReduceDim > reduceDims, const std::array< double, NumReduction > alphas, const std::array< double, NumReduction > betas, const void *in_dev, const std::array< void *, NumReduction > out_dev_buffers, const InElementwiseOperationTuple in_elementwise_op_tuple, const AccElementwiseOperationTuple acc_elementwise_op_tuple) override
Definition device_multiple_reduce_multiblock.hpp:544
decltype(GenerateOutGrid1dDescTuple()) OutGridDesc_M_Tuple
Definition device_multiple_reduce_multiblock.hpp:223
static constexpr index_t K_BlockTileSize
Definition device_multiple_reduce_multiblock.hpp:102
static constexpr bool reduceAllDim
Definition device_multiple_reduce_multiblock.hpp:90
static auto MakeSrc2dDescriptor(const std::array< index_t, NumInputDim > &inLengths, const std::array< index_t, NumInputDim > &inStrides, int blkGroupSize, int numBlockTileIteration)
Definition device_multiple_reduce_multiblock.hpp:117
static auto GenerateOutDataTypePointerTuple()
Definition device_multiple_reduce_multiblock.hpp:104
static constexpr index_t NumInputDim
Definition device_multiple_reduce_multiblock.hpp:88
static constexpr index_t M_BlockTileSize
Definition device_multiple_reduce_multiblock.hpp:101
decltype(GenerateOutDataTypePointerTuple()) OutDataTypePointerTuple
Definition device_multiple_reduce_multiblock.hpp:115
static constexpr index_t NumOutputDim
Definition device_multiple_reduce_multiblock.hpp:89
static auto MakeDst1dDescriptorForBufferSet(const std::array< index_t, NumOutputDim > &outLengths, const std::array< index_t, NumOutputDim > &outStrides)
Definition device_multiple_reduce_multiblock.hpp:225
decltype(MakeSrc2dDescriptor( std::array< index_t, NumInputDim >{}, std::array< index_t, NumInputDim >{}, 1, 1)) InGridDesc_M_K
Definition device_multiple_reduce_multiblock.hpp:221