threadwise_contraction_dl.hpp Source File

threadwise_contraction_dl.hpp Source File

Composable Kernel: threadwise_contraction_dl.hpp Source File
threadwise_contraction_dl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7#include "ck/utility/math.hpp"
8
9namespace ck {
10
11// C[TM0, TM1, TN0, TN1] += A[TK, TM0, TM1] * B[TK, TN0, TN1]
// Contraction over a single K dimension: for every (tk, tm0, tm1, tn0, tn1) in
// TK x TM0 x TM1 x TN0 x TN1, the product of A[tk, tm0, tm1] and B[tk, tn0, tn1] is accumulated
// into C[tm0, tm1, tn0, tn1]. All five loops are static_for, so the contraction is fully unrolled
// and every buffer offset is a compile-time constant.
12// Tensor element can be vectorized data
13// Assume:
14// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are
15// known at compile-time (enforced by the enable_if parameter and by the constructor)
16// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
// 3. TKLengths has exactly 1 entry; TMLengths and TNLengths have exactly 2 entries
//    (static_assert in the constructor)
// NOTE(review): despite the "_TK0_TM0_TM1_TK1" / "_TK0_TN0_TN1_TK1" names, Run() indexes the A and
// B descriptors with 3-component indices (tk, tm0, tm1) / (tk, tn0, tn1); presumably this variant
// expects 3-D A/B descriptors — confirm against callers.
17template <typename FloatA,
18 typename FloatB,
19 typename FloatC,
20 typename AThreadDesc_TK0_TM0_TM1_TK1,
21 typename BThreadDesc_TK0_TN0_TN1_TK1,
22 typename CThreadDesc_TM0_TM1_TN0_TN1,
23 typename TKLengths,
24 typename TMLengths,
25 typename TNLengths,
26 typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
27 BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
28 CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
29 bool>::type = false>
31{
// Constructor: compile-time checks only; its body contains no runtime statements.
33 {
34 static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
35 BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
36 CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
37 "wrong! Desc should be known at compile-time");
38
39 // TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1,
40 // CThreadDesc_TM0_TM1_TN0_TN1 Size with TKLengths, TMLengths and TNLengths
41
42 // TODO remove this restriction
// (Run() reads only TKLengths[0], TMLengths[0..1] and TNLengths[0..1].)
43 static_assert(TKLengths::Size() == 1 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
44 "wrong!");
45 }
46
// Run: accumulate the contraction into c_buf.
//   a_buf, b_buf : source buffers, read with operator[] at compile-time offsets
//   c_buf        : accumulator buffer, accessed with operator() at compile-time offsets
//                  (presumably the mutable accessor of the buffer type — TODO confirm)
//   AOriginIdx / BOriginIdx / COriginIdx : unnamed tag arguments; only their types are used.
//     Each is default-constructed, converted with to_multi_index and added to the loop
//     indices, so it must be a compile-time index type (checked by a static_assert at the top
//     of the body).
47 template <typename ABuffer,
48 typename AOriginIdx,
49 typename BBuffer,
50 typename BOriginIdx,
51 typename CBuffer,
52 typename COriginIdx>
53 __device__ static void Run(const ABuffer& a_buf,
54 AOriginIdx,
55 const BBuffer& b_buf,
56 BOriginIdx,
57 CBuffer& c_buf,
58 COriginIdx)
59 {
63 "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
64
65 static_assert(
69 "wrong! inconsistent type");
70
71 constexpr auto I0 = Number<0>{};
72 constexpr auto I1 = Number<1>{};
73
74 constexpr auto TK = TKLengths{}[I0];
75 constexpr auto TM0 = TMLengths{}[I0];
76 constexpr auto TM1 = TMLengths{}[I1];
77 constexpr auto TN0 = TNLengths{}[I0];
78 constexpr auto TN1 = TNLengths{}[I1];
79
80 constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
81 constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
82 constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
83
84 static_for<0, TK, 1>{}([&](auto tk) {
85 static_for<0, TM0, 1>{}([&](auto tm0) {
86 static_for<0, TM1, 1>{}([&](auto tm1) {
87 static_for<0, TN0, 1>{}([&](auto tn0) {
88 static_for<0, TN1, 1>{}([&](auto tn1) {
// Offsets are constexpr: descriptor, origin index and loop indices are all compile-time values.
89 constexpr index_t a_offset =
90 AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
91 a_origin_idx + make_multi_index(tk, tm0, tm1));
92 constexpr index_t b_offset =
93 BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
94 b_origin_idx + make_multi_index(tk, tn0, tn1));
95 constexpr index_t c_offset =
96 CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
97 c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
98
// Per-element update. The call that receives these operands starts on a line not shown here;
// presumably inner_product (listed in the cross-references), which per the header comment
// accumulates a*b into c — confirm.
100 b_buf[Number<b_offset>{}],
101 c_buf(Number<c_offset>{}));
102 });
103 });
104 });
105 });
106 });
107 }
108};
109
110// C[TM0, TM1, TN0, TN1] += A[TK0, TM0, TM1, TK1] * B[TK0, TN0, TN1, TK1]
// Contraction over a K dimension split as TK0 x TK1. For every (tk0, tm0, tm1, tn0, tn1), the
// TK1 elements A[tk0, tm0, tm1, 0..TK1) and B[tk0, tn0, tn1, 0..TK1) are gathered into the
// vectors a_vec / b_vec, which are then combined into C[tm0, tm1, tn0, tn1]. All loops are
// static_for, so the contraction is fully unrolled and all buffer offsets are compile-time
// constants.
111// Tensor element can be vectorized data
112// Assume:
113// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are
114// known at compile-time (enforced by the enable_if parameter and by the constructor)
115// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
// 3. TKLengths, TMLengths and TNLengths each have exactly 2 entries (static_assert in the
//    constructor)
116template <typename FloatA,
117 typename FloatB,
118 typename FloatC,
119 typename AThreadDesc_TK0_TM0_TM1_TK1,
120 typename BThreadDesc_TK0_TN0_TN1_TK1,
121 typename CThreadDesc_TM0_TM1_TN0_TN1,
122 typename TKLengths,
123 typename TMLengths,
124 typename TNLengths,
125 typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
126 BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
127 CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
128 bool>::type = false>
130{
// Constructor: compile-time checks only; its body contains no runtime statements.
132 {
133 static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
134 BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
135 CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
136 "wrong! Desc should be known at compile-time");
137
138 // TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1,
139 // CThreadDesc_TM0_TM1_TN0_TN1 Size with TKLengths, TMLengths and TNLengths
140
141 // TODO remove this restriction
// (Run() reads only TKLengths[0..1], TMLengths[0..1] and TNLengths[0..1].)
142 static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
143 "wrong!");
144 }
145
// Run: accumulate the contraction into c_buf.
//   a_buf, b_buf : source buffers, read element-wise with operator[] at compile-time offsets
//   c_buf        : accumulator buffer, accessed with operator() at a compile-time offset
//                  (presumably the mutable accessor of the buffer type — TODO confirm)
//   AOriginIdx / BOriginIdx / COriginIdx : unnamed tag arguments; only their types are used.
//     Each is default-constructed, converted with to_multi_index and added to the loop
//     indices, so it must be a compile-time index type (checked by a static_assert at the top
//     of the body).
146 template <typename ABuffer,
147 typename AOriginIdx,
148 typename BBuffer,
149 typename BOriginIdx,
150 typename CBuffer,
151 typename COriginIdx>
152 __device__ static void Run(const ABuffer& a_buf,
153 AOriginIdx,
154 const BBuffer& b_buf,
155 BOriginIdx,
156 CBuffer& c_buf,
157 COriginIdx)
158 {
162 "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
163
164 static_assert(
168 "wrong! inconsistent type");
169
170 constexpr auto I0 = Number<0>{};
171 constexpr auto I1 = Number<1>{};
172
173 constexpr index_t TK0 = TKLengths{}[I0];
174 constexpr index_t TK1 = TKLengths{}[I1];
175 constexpr index_t TM0 = TMLengths{}[I0];
176 constexpr index_t TM1 = TMLengths{}[I1];
177 constexpr index_t TN0 = TNLengths{}[I0];
178 constexpr index_t TN1 = TNLengths{}[I1];
179
180 constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
181 constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
182 constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
183
184 static_for<0, TK0, 1>{}([&](auto tk0) {
185 static_for<0, TM0, 1>{}([&](auto tm0) {
186 static_for<0, TM1, 1>{}([&](auto tm1) {
187 static_for<0, TN0, 1>{}([&](auto tn0) {
188 static_for<0, TN1, 1>{}([&](auto tn1) {
191
// Gather step: element tk1 of a_vec / b_vec receives A[tk0, tm0, tm1, tk1] /
// B[tk0, tn0, tn1, tk1]. a_vec and b_vec are declared on lines not shown here. From their use
// they are presumably vector_type<FloatA, TK1> / vector_type<FloatB, TK1> — confirm.
// NOTE(review): a_vec depends only on (tk0, tm0, tm1) and b_vec only on (tk0, tn0, tn1), yet
// both are re-filled on every (tm0, tm1, tn0, tn1) iteration. Presumably the compiler removes
// the redundant copies after unrolling — worth checking in the generated code.
192 static_for<0, TK1, 1>{}([&](auto tk1) {
193 constexpr index_t a_offset =
194 AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
195 a_origin_idx + make_multi_index(tk0, tm0, tm1, tk1));
196
197 constexpr index_t b_offset =
198 BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
199 b_origin_idx + make_multi_index(tk0, tn0, tn1, tk1));
200
201 a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
202 b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
203 });
204
// View each gathered vector as one value of vector_type<Float, TK1>::type
// (AsType<a_vector_t>()[I0] below), so the product call gets whole TK1-wide operands.
205 using a_vector_t = typename vector_type<FloatA, TK1>::type;
206 using b_vector_t = typename vector_type<FloatB, TK1>::type;
207
208 constexpr index_t c_offset =
209 CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
210 c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
211
// The call that receives these operands starts on a line not shown here. Presumably it is
// inner_product (listed in the cross-references), which per the header comment accumulates
// the TK1-wide product into c_buf(c_offset) — confirm.
213 a_vec.template AsType<a_vector_t>()[I0],
214 b_vec.template AsType<b_vector_t>()[I0],
215 c_buf(Number<c_offset>{}));
216 });
217 });
218 });
219 });
220 });
221 }
222};
223
224} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
__host__ __device__ constexpr auto to_multi_index(const T &x)
Definition array_multi_index.hpp:28
__device__ void inner_product(const TA &a, const TB &b, TC &c)
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
__device__ constexpr ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
Definition threadwise_contraction_dl.hpp:131
static __device__ void Run(const ABuffer &a_buf, AOriginIdx, const BBuffer &b_buf, BOriginIdx, CBuffer &c_buf, COriginIdx)
Definition threadwise_contraction_dl.hpp:152
static __device__ void Run(const ABuffer &a_buf, AOriginIdx, const BBuffer &b_buf, BOriginIdx, CBuffer &c_buf, COriginIdx)
Definition threadwise_contraction_dl.hpp:53
__device__ constexpr ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1()
Definition threadwise_contraction_dl.hpp:32
Definition is_known_at_compile_time.hpp:14
Definition type.hpp:177
Definition functional2.hpp:33
Definition dtype_vector.hpp:10