29template <
typename Shape,
typename UnrolledDescriptorType>
33using is_tuple =
decltype(std::declval<T&>().IsTuple());
43template <
typename... Ts>
44__host__ __device__
constexpr static auto
45GenerateColumnMajorPackedStrides(
const Tuple<Ts...>&
shape)
50 if constexpr(i.value == 0)
56 return TupleReduce<Number<0>{}.value, i.value>([](
auto x,
auto y) {
return x * y; },
60 Number<
decltype(unrolled_shape)::Size()>{});
70template <
typename LayoutShape,
typename LayoutStr
ides>
71__host__ __device__
constexpr auto MakeUnrolledDescriptor(
const LayoutShape&
shape,
72 const LayoutStrides& strides)
75 if constexpr(is_same_v<LayoutStrides, Tuple<>>)
78 const auto unrolled_strides = GenerateColumnMajorPackedStrides(unrolled_shape);
79 static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
80 "Size of strides and shape are not consistent.");
86 static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
87 "Size of strides and shape are not consistent.");
104template <
typename Shape,
typename Str
ides>
107 using UnrolledDescriptorType =
decltype(detail::MakeUnrolledDescriptor(Shape{}, Strides{}));
109 detail::MakeUnrolledDescriptor(
shape, strides));
119template <
typename Shape>
122 using UnrolledDescriptorType =
decltype(detail::MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
124 detail::MakeUnrolledDescriptor(
shape, Tuple<>{}));
136__host__ __device__ T
constexpr get(
const T& dim)
148template <
index_t idx,
typename... Dims>
149__host__ __device__
constexpr auto get(
const Tuple<Dims...>& tuple)
151 return tuple.At(Number<idx>{});
161template <index_t
idx,
typename Shape,
typename UnrolledDesc>
165 const auto new_shape = get<idx>(
shape);
166 static_assert(is_detected<is_tuple,
decltype(new_shape)>
::value,
167 "Shape of sub layout must be tuple");
169 constexpr auto old_shape_dims =
decltype(UnrollNestedTuple(
shape))::Size();
170 constexpr auto new_shape_dims =
decltype(UnrollNestedTuple(new_shape))::Size();
171 constexpr auto shape_offset =
decltype(UnrollNestedTuple(TupleSlice<0, idx>(
shape)))::Size();
173 const auto unrolled_shape = UnrollNestedTuple(
shape);
174 const auto transforms = generate_tuple(
177 if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
180 return make_freeze_transform(Number<0>{});
184 return make_pass_through_transform(unrolled_shape.At(i));
187 Number<old_shape_dims>{});
189 const auto lower_dims =
190 generate_tuple([&](
auto i) {
return Sequence<i.value>{}; }, Number<old_shape_dims>{});
191 const auto upper_dims = generate_tuple(
193 if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
198 return Sequence<i.value - shape_offset>{};
201 Number<old_shape_dims>{});
203 const auto& flatten_desc =
layout.GetUnrolledDescriptor();
204 auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
216__host__ __device__
constexpr auto get(
const T& elem)
218 return get<Idxs...>(get<Idx>(elem));
230__host__ __device__ T
constexpr size(
const T& dim)
242template <index_t
idx,
typename Shape,
typename UnrolledDescriptorType>
245 return layout.template GetLength<idx>();
254template <
typename... ShapeDims>
255__host__ __device__
constexpr auto size(
const Tuple<ShapeDims...>&
shape)
257 const auto unrolled_shape = UnrollNestedTuple(
shape);
258 return TupleReduce<0, unrolled_shape.Size()>([](
auto x,
auto y) {
return x * y; },
268template <
typename Shape,
typename UnrolledDescriptorType>
271 return layout.GetLengths();
281template <
index_t idx,
typename... Ts>
282__host__ __device__
constexpr auto size(
const Tuple<Ts...>& tuple)
284 return size(tuple.At(Number<idx>{}));
296__host__ __device__
constexpr auto size(
const T& elem)
298 return size(get<Idx, Idxs...>(elem));
308template <
typename Shape,
typename UnrolledDescriptorType>
309__host__ __device__
constexpr auto
312 return Shape::Size();
322template <
typename... Dims>
323__host__ __device__
constexpr auto rank([[maybe_unused]]
const Tuple<Dims...>& tuple)
325 return Tuple<Dims...>::Size();
335template <index_t IDim>
336__host__ __device__
constexpr index_t rank([[maybe_unused]]
const Number<IDim>& dim)
348__host__ __device__
constexpr index_t rank([[maybe_unused]]
const index_t& dim) {
return 1; }
357template <
index_t... Idxs,
typename T>
358__host__ __device__
constexpr auto rank(
const T& elem)
360 return rank(get<Idxs...>(elem));
370template <
typename Shape,
typename UnrolledDescriptorType>
374 return TupleDepth(
shape);
383template <
typename... Dims>
384__host__ __device__
constexpr auto depth(
const Tuple<Dims...>& tuple)
386 return TupleDepth(tuple);
396template <index_t IDim>
397__host__ __device__
constexpr index_t depth([[maybe_unused]]
const Number<IDim>& dim)
409__host__ __device__
constexpr index_t depth([[maybe_unused]]
const index_t& dim) {
return 0; }
418template <
index_t... Idxs,
typename T>
419__host__ __device__
constexpr auto depth(
const T& elem)
421 return depth(get<Idxs...>(elem));
430template <
typename LayoutType>
431__host__ __device__
constexpr const auto&
shape(
const LayoutType&
layout)
445template <
typename Shape,
typename UnrolledDesc,
typename TileLengths>
447 const TileLengths& tile_lengths)
449 auto& unrolled_desc =
layout.GetUnrolledDescriptor();
451 constexpr auto do_pads_seq =
452 generate_sequence_v2([](
auto) {
return Number<1>{}; }, Number<Shape::Size()>{});
455 tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
457 const auto padded_shape = generate_tuple(
458 [&](
auto i) {
return padded_desc.GetLength(Number<i>{}); }, Number<TileLengths::Size()>{});
473template <index_t Idx,
typename Shape,
typename UnrolledDesc,
typename NewLengths,
typename NewIdxs>
475 const NewLengths& new_lengths,
476 [[maybe_unused]]
const NewIdxs& new_indexes)
479 auto& unrolled_desc =
layout.GetUnrolledDescriptor();
480 constexpr auto dims = Shape::Size();
482 const auto transforms = generate_tuple(
484 if constexpr(i == Idx)
486 return make_unmerge_transform(new_lengths);
490 return make_pass_through_transform(layout_shape.At(i));
495 constexpr auto lower_dims =
496 generate_tuple([&](
auto i) {
return Sequence<i.value>{}; }, Number<dims>{});
497 constexpr auto upper_dims = generate_tuple(
499 if constexpr(is_detected<is_tuple, tuple_element_t<i.value, NewIdxs>>
::value)
501 constexpr auto idxs_tuple = tuple_element_t<i.value, NewIdxs>{};
502 return to_sequence(idxs_tuple);
506 constexpr index_t index = tuple_element_t<i.value, NewIdxs>{};
507 return Sequence<index>{};
512 const auto unmerged_desc =
513 transform_tensor_descriptor(unrolled_desc, transforms, lower_dims, upper_dims);
514 const auto unmerged_shape =
515 generate_tuple([&](
auto i) {
return unmerged_desc.GetLength(Number<i>{}); },
516 Number<
decltype(unmerged_desc)::GetNumOfVisibleDimension()>{});
auto pad(ck::index_t mpb, ck::index_t npb, ck::index_t kpb, ck::tensor_operation::device::GemmSpecialization gemm, CDesc_MRaw_NRaw conv)
Definition helper.hpp:70
__host__ __device__ constexpr const auto & shape(const LayoutType &layout)
Get Layout shape.
Definition layout_utils.hpp:431
__host__ __device__ constexpr auto depth(const Layout< Shape, UnrolledDescriptorType > &layout)
Get depth of the layout shape (return 0 if scalar).
Definition layout_utils.hpp:371
__host__ __device__ constexpr auto unmerge(const Layout< Shape, UnrolledDesc > &layout, const NewLengths &new_lengths, const NewIdxs &new_indexes)
Unmerge selected dim in layout.
Definition layout_utils.hpp:474
__host__ __device__ constexpr auto make_layout(const Shape &shape, const Strides &strides)
Make layout function.
Definition layout_utils.hpp:105
__host__ __device__ constexpr auto rank(const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition layout_utils.hpp:310
decltype(std::declval< T & >().IsTuple()) is_tuple
Definition device_grouped_conv_fwd_multiple_abd.hpp:23
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<> &element)
Definition tuple_helper.hpp:120
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Layout wrapper that performs the tensor descriptor logic.
Definition layout.hpp:24
__host__ __device__ constexpr const auto & layout(const Tensor< BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType > &tensor)
Get Tensor Layout.
Definition tensor_utils.hpp:162