reference_batched_contraction.hpp Source File

reference_batched_contraction.hpp Source File#

Composable Kernel: reference_batched_contraction.hpp Source File
reference_batched_contraction.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <cstdlib>
7#include <functional>
8#include <numeric>
9#include <thread>
10
11#include "ck_tile/core.hpp"
13
14namespace ck_tile {
15
16template <typename ADataType,
17 typename BDataType,
18 typename DDataType,
19 typename EDataType,
20 typename AccDataType,
21 typename CDEElementWise>
22
24 const ck_tile::HostTensor<ADataType>& a_full_dims,
25 const ck_tile::HostTensor<BDataType>& b_full_dims,
26 const std::vector<ck_tile::HostTensor<DDataType>>& ds_full_dims_host,
27 ck_tile::HostTensor<EDataType>& e_full_dims_host_ref,
28 ck_tile::index_t G_total,
29 ck_tile::index_t M_total,
30 ck_tile::index_t N_total,
31 ck_tile::index_t K_total,
32 const CDEElementWise& cde_elementwise)
33{
34 std::cout << "Calculating reference using optimized flat indexing with parallel processing..."
35 << std::endl;
36
37 // Parallel computation over G and M dimensions using pattern from reference_batched_gemm.hpp
38 auto f_gm = [&](auto g_flat, auto m_flat) {
39 for(ck_tile::index_t n_flat = 0; n_flat < N_total; ++n_flat)
40 {
41 AccDataType sum = 0;
42
43 // Compute dot product over K dimension
44 for(ck_tile::index_t k_flat = 0; k_flat < K_total; ++k_flat)
45 {
46 auto a_val =
47 a_full_dims.mData[g_flat * M_total * K_total + m_flat * K_total + k_flat];
48 auto b_val =
49 b_full_dims.mData[g_flat * N_total * K_total + n_flat * K_total + k_flat];
50 sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
51 }
52
53 // Apply elementwise operation with D tensors
54 EDataType result = static_cast<EDataType>(sum);
55 if(ds_full_dims_host.size() == 0)
56 {
57 ;
58 }
59 else if(ds_full_dims_host.size() == 1)
60 {
61 cde_elementwise(result,
64 ds_full_dims_host[0].mData[g_flat * M_total * N_total +
65 m_flat * N_total + n_flat]));
66 }
67 else if(ds_full_dims_host.size() == 2)
68 {
69 cde_elementwise(
70 result,
73 ds_full_dims_host[0]
74 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
76 ds_full_dims_host[1]
77 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
78 }
79 else if(ds_full_dims_host.size() == 3)
80 {
81 cde_elementwise(
82 result,
85 ds_full_dims_host[0]
86 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
88 ds_full_dims_host[1]
89 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
91 ds_full_dims_host[2]
92 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
93 }
94 else if(ds_full_dims_host.size() == 4)
95 {
96 cde_elementwise(
97 result,
100 ds_full_dims_host[0]
101 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
103 ds_full_dims_host[1]
104 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
106 ds_full_dims_host[2]
107 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
109 ds_full_dims_host[3]
110 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
111 }
112 else
113 {
114 throw std::runtime_error("Unsupported NumDTensor for reference calculation");
115 }
116
117 // Store result
118 e_full_dims_host_ref.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat] =
119 static_cast<EDataType>(result);
120 }
121 };
122
123 // Execute parallel computation using hardware concurrency
124 // Parallelize over G_total and M_total dimensions for optimal CPU utilization
125 make_ParallelTensorFunctor(f_gm, G_total, M_total)(std::thread::hardware_concurrency());
126}
127
128template <typename ADataType,
129 typename BDataType,
130 typename DDataType,
131 typename EDataType,
132 typename AccDataType,
133 typename CDEElementWise>
135 const HostTensor<ADataType>& a_full_dims,
136 const HostTensor<BDataType>& b_full_dims,
137 const std::vector<HostTensor<DDataType>>& ds_full_dims_host,
138 HostTensor<EDataType>& e_full_dims_host_ref,
139 const std::vector<index_t>& G_dims,
140 const std::vector<index_t>& M_dims,
141 const std::vector<index_t>& N_dims,
142 const std::vector<index_t>& K_dims,
143 const std::vector<index_t>& A_dims,
144 const std::vector<index_t>& B_dims,
145 const std::vector<index_t>& E_dims,
146 const CDEElementWise& cde_elementwise)
147{
148 std::cout << "Calculating reference using multi-dimensional indexing..." << std::endl;
149
150 std::vector<std::size_t> g_idx(G_dims.size());
151 std::vector<std::size_t> m_idx(M_dims.size());
152 std::vector<std::size_t> n_idx(N_dims.size());
153 std::vector<std::size_t> k_idx(K_dims.size());
154 std::vector<std::size_t> a_idx, b_idx, e_idx;
155
156 a_idx.reserve(A_dims.size());
157 b_idx.reserve(B_dims.size());
158 e_idx.reserve(E_dims.size());
159
160 auto calculate_total_elements = [](const std::vector<ck_tile::index_t>& dims) {
161 return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<ck_tile::index_t>());
162 };
163
164 for(ck_tile::index_t g_flat = 0; g_flat < calculate_total_elements(G_dims); ++g_flat)
165 {
166 ck_tile::index_t temp = g_flat;
167 for(int i = G_dims.size() - 1; i >= 0; --i)
168 {
169 g_idx[i] = temp % G_dims[i];
170 temp /= G_dims[i];
171 }
172
173 for(ck_tile::index_t m_flat = 0; m_flat < calculate_total_elements(M_dims); ++m_flat)
174 {
175 temp = m_flat;
176 for(int i = M_dims.size() - 1; i >= 0; --i)
177 {
178 m_idx[i] = temp % M_dims[i];
179 temp /= M_dims[i];
180 }
181
182 for(ck_tile::index_t n_flat = 0; n_flat < calculate_total_elements(N_dims); ++n_flat)
183 {
184 temp = n_flat;
185 for(int i = N_dims.size() - 1; i >= 0; --i)
186 {
187 n_idx[i] = temp % N_dims[i];
188 temp /= N_dims[i];
189 }
190
191 AccDataType sum = 0;
192
193 for(ck_tile::index_t k_flat = 0; k_flat < calculate_total_elements(K_dims);
194 ++k_flat)
195 {
196 temp = k_flat;
197 for(int i = K_dims.size() - 1; i >= 0; --i)
198 {
199 k_idx[i] = temp % K_dims[i];
200 temp /= K_dims[i];
201 }
202
203 a_idx.clear();
204 b_idx.clear();
205
206 a_idx.insert(a_idx.end(), g_idx.begin(), g_idx.end());
207 a_idx.insert(a_idx.end(), m_idx.begin(), m_idx.end());
208 a_idx.insert(a_idx.end(), k_idx.begin(), k_idx.end());
209
210 b_idx.insert(b_idx.end(), g_idx.begin(), g_idx.end());
211 b_idx.insert(b_idx.end(), n_idx.begin(), n_idx.end());
212 b_idx.insert(b_idx.end(), k_idx.begin(), k_idx.end());
213
214 auto a_val = a_full_dims(a_idx);
215 auto b_val = b_full_dims(b_idx);
216
217 sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
218 }
219
220 e_idx.clear();
221 e_idx.insert(e_idx.end(), g_idx.begin(), g_idx.end());
222 e_idx.insert(e_idx.end(), m_idx.begin(), m_idx.end());
223 e_idx.insert(e_idx.end(), n_idx.begin(), n_idx.end());
224
225 EDataType result = static_cast<EDataType>(sum);
226 if(ds_full_dims_host.size() == 0)
227 {
228 ;
229 }
230 else if(ds_full_dims_host.size() == 1)
231 {
232 cde_elementwise(result,
234 ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)));
235 }
236 else if(ds_full_dims_host.size() == 2)
237 {
238 cde_elementwise(result,
240 ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
241 ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)));
242 }
243 else if(ds_full_dims_host.size() == 3)
244 {
245 cde_elementwise(result,
247 ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
248 ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
249 ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)));
250 }
251 else if(ds_full_dims_host.size() == 4)
252 {
253 cde_elementwise(result,
255 ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
256 ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
257 ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)),
258 ck_tile::type_convert<float>(ds_full_dims_host[3](e_idx)));
259 }
260 else
261 {
262 throw std::runtime_error("Unsupported NumDTensor for reference calculation");
263 }
264
265 e_full_dims_host_ref(e_idx) = static_cast<EDataType>(result);
266 }
267 }
268 }
269}
270
271} // namespace ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:329
void calculate_reference_multi_dimensional(const HostTensor< ADataType > &a_full_dims, const HostTensor< BDataType > &b_full_dims, const std::vector< HostTensor< DDataType > > &ds_full_dims_host, HostTensor< EDataType > &e_full_dims_host_ref, const std::vector< index_t > &G_dims, const std::vector< index_t > &M_dims, const std::vector< index_t > &N_dims, const std::vector< index_t > &K_dims, const std::vector< index_t > &A_dims, const std::vector< index_t > &B_dims, const std::vector< index_t > &E_dims, const CDEElementWise &cde_elementwise)
Definition reference_batched_contraction.hpp:134
int32_t index_t
Definition integer.hpp:9
void calculate_reference_flat_indexing(const ck_tile::HostTensor< ADataType > &a_full_dims, const ck_tile::HostTensor< BDataType > &b_full_dims, const std::vector< ck_tile::HostTensor< DDataType > > &ds_full_dims_host, ck_tile::HostTensor< EDataType > &e_full_dims_host_ref, ck_tile::index_t G_total, ck_tile::index_t M_total, ck_tile::index_t N_total, ck_tile::index_t K_total, const CDEElementWise &cde_elementwise)
Definition reference_batched_contraction.hpp:23
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
Definition tile/host/host_tensor.hpp:336
Data mData
Definition tile/host/host_tensor.hpp:801