fused_moegemm_pipeline_flatmm_uk.hpp Source File

fused_moegemm_pipeline_flatmm_uk.hpp Source File

Composable Kernel: fused_moegemm_pipeline_flatmm_uk.hpp Source File
fused_moegemm_pipeline_flatmm_uk.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
12/*
13This pipeline deals with a gemm (actually 2 gemms) with one very small operand (token) and one very big operand (weight).
14We need to design the pipeline such that all waves are placed along the gemm-N dim (gemm-m has only 1 wave):
15
16 <----- gemm-N ------>
17 +----+----+----+----+
18 | w0 | w1 | w2 | w3 | gemm-m
19 +----+----+----+----+
20*/
21template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
23{
26
27 using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape
28
29 using ADataType = typename Problem::ADataType;
30 using GDataType = typename Problem::GDataType;
31 using DDataType = typename Problem::DDataType;
32 using AccDataType = typename Problem::AccDataType;
33 using ODataType = typename Problem::ODataType;
34 using AScaleDataType = typename Problem::AScaleDataType;
35 using GScaleDataType = typename Problem::GScaleDataType;
36 using DScaleDataType = typename Problem::DScaleDataType;
37 using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
38 using TopkWeightDataType = typename Problem::TopkWeightDataType;
39 using IndexDataType = typename Problem::IndexDataType;
40 using YDataType = typename Problem::YDataType;
41
42 using Traits = typename Problem::Traits;
43
44 static constexpr bool IsGateOnly = Traits::IsGateOnly;
45 static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
46 static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
47 static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
48
49 static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
50 static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
51 static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
52 static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
53
58
59 static constexpr index_t kBlockPerCu = []() {
60 if constexpr(Problem::kBlockPerCu != -1)
61 return Problem::kBlockPerCu;
62 else
63 {
64 // Problem::kBlockPerCu not specified (-1): minimize occupancy, fall back to 2
65 return 2;
66 }
67 }();
68
69 static constexpr const char* name = "flatmm_uk";
70
72 {
73#if 1
74 constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
75 constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
76 constexpr index_t smem_bridge =
77 BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
78 return max(smem_0 + smem_1, smem_bridge);
79#else
80 // keep it here purposely in case we have regression
81 return 65536;
82#endif
83 }
84
85 // this is the thread-offset along row/col
87 {
88 constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
89 const auto a_coord = a_dist.calculate_index();
90 return a_coord;
91 }
92
93 // this is the thread-offset along row/col
95 {
96 constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
97 const auto o_coord = o_dist.calculate_index();
98 return o_coord;
99 }
100
102 {
103 constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA;
104 constexpr index_t MLans = BlockShape::BlockSize / KLans;
105 constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
106
107 return MRepeat;
108 }
109
110 // TODO: properly support scatter/gather
112 {
113 constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA;
114 constexpr index_t MLans = BlockShape::BlockSize / KLans;
115 constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
116
117 auto base_coord = threadIdx.x / KLans + base_offset;
118
120 static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });
121
122 return coords;
123 }
124
125 template <typename ROW_COORDS>
126 CK_TILE_DEVICE auto GetRowID(const ROW_COORDS coords, const IndexDataType* sorted_token_ids_ptr)
127 {
128 constexpr index_t n_size = coords.size();
129
131 static_for<0, n_size, 1>{}([&](auto i) {
132 row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans;
133#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
134 row_ids.at(i) &= 0xffffff;
135#endif
136 });
137
138 return row_ids;
139 }
140
141 template <typename ROW_COORDS>
142 CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords,
143 const TopkWeightDataType* sorted_weight_ptr)
144 {
145 constexpr index_t n_size = coords.size();
146
148 static_for<0, n_size, 1>{}([&](auto i) {
149 w.at(i) = sorted_weight_ptr[coords[i]]; // base_coord + i * MLans;
150 });
151
152 return w;
153 }
154
155 // TODO: this row id is computed before the shuffle/atomic step; need to use the acc distribution
157 {
158 constexpr index_t MLanes = BlockShape::Warp_M1;
159 constexpr index_t Repeat_M = BlockShape::Repeat_M1;
160
161 auto base_coord = threadIdx.x % MLanes + base_offset;
162
164 static_for<0, Repeat_M, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLanes; });
165
166 return coords;
167 }
168
169 template <typename Karg>
170 CK_TILE_DEVICE auto operator()(const Karg& kargs,
171 CK_TILE_LDS_ADDR void* smem,
172 index_t sorted_tile_id,
173 index_t intermediate_tile_id)
174 {
175 constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
176 ck_tile::index_t shared_intermediate_size_0 =
177 kargs.intermediate_size * hidden_radio_0; // total gate+up
178 ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size;
179
180 // after weight shuffling, gate-only: [nr0, kr0, w0], gate+up: [nr0_gate + nr0_up, kr0, w0]
181
182 index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
183 index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W
184 index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1;
185 index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1;
186
187 const IndexDataType expert_id = amd_wave_read_first_lane(
188 reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
189 index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
190 index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;
191
192 // nr*kr*w
193 index_t interm_idx_nr0 = amd_wave_read_first_lane(
194 intermediate_tile_id *
195 BlockShape::Block_Nr0); // intermediate_tile_id * Block_N / (N in W)
196
197 index_t interm_idx_kr1 = amd_wave_read_first_lane(
198 intermediate_tile_id *
199 BlockShape::Block_Kr1); // intermediate_tile_id * Block_K / (K in W) of the 2nd gemm -- TODO confirm
200
201 auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0);
202 auto row_ids_a = GetRowID(
203 row_coords_a, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
204 auto a_coords = generate_tuple(
205 [&](auto i) {
206 return row_ids_a[i] * kargs.stride_token +
207 threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
208 },
209 number<row_ids_a.size()>{});
210
211 auto a_res =
212 make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
213 kargs.num_tokens * kargs.stride_token * sizeof(ADataType),
214 std::true_type{});
215
216 auto make_gu_win = [&](const auto* ptr_) {
218 ptr_,
220 make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
222 number<1>{});
223
224 auto win_ = make_tile_window_linear_raw(
225 view_,
229 {0, 0, 0},
230 Policy::template MakeGlobalTileDistribution_G<Problem>(),
232 return win_;
233 };
234
235 const GDataType* gu_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
236 static_cast<long_index_t>(expert_id) * expert_stride_0 +
237 interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
238
239 auto g_win = make_gu_win(gu_ptr);
240 // Note: gu swizzled, [nr_u+nr_g, kr, w], hence base offset to up is just interm*hidden
241 auto u_win = make_gu_win(gu_ptr + kargs.intermediate_size * kargs.hidden_size);
242
243 auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
244 auto u_res = u_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
245 auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
246 number<decltype(g_win)::NumAccess_NonLinear>{});
247
248 const auto d_win = [&]() {
249 const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
250 static_cast<long_index_t>(expert_id) * expert_stride_1 +
251 interm_idx_kr1 * BlockShape::Block_W1;
252 // note: the intermediate index (interm_idx_kr1 here) is along the gemm-k dim of the 2nd gemm
253
255 d_ptr,
256 make_tuple(nr_1, kr_1, BlockShape::Block_W1),
257 make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
259 number<1>{});
260
261 const auto d_window_ = make_tile_window_linear_raw(
262 d_view_,
266 {0, 0, 0},
267 Policy::template MakeGlobalTileDistribution_D<Problem>(),
269 return d_window_;
270 }();
271 auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
272
273 // TODO: load D order is N0.K0...127, N64.K0...127, N0.K128...255, N64.K128...255
274 // block-k=512, block-n=128
275 // wg |<----- W_ ----->|
276 // Nr(2)*Nw(4)* Kr *Kr0(4)*Kr1(4) * [Kl(4)*Nl(16)*Kv(8)]->one issue
277 // y p y y p p y
278 // 1 2 0(imm)
279 auto d_coords = [&]() {
280 constexpr index_t Nr_ = 2;
281 constexpr index_t Nw_ = 4;
282 constexpr index_t Kr0_ = 4;
283 constexpr index_t Kr1_ = 4;
284 constexpr index_t Kl_ = 4;
285 constexpr index_t Nl_ = 16;
286 constexpr index_t Kv_ = 8;
287 constexpr index_t W_ = Kl_ * Nl_ * Kv_;
288 constexpr index_t num_offsets_ = Nr_ * Kr0_;
289 index_t base_os_ = (threadIdx.x % 64) * Kv_ + (threadIdx.x / 64) *
290 shared_intermediate_size_1 *
291 Nl_; // Kr0_ * Kr1_ * W_;
292 return generate_tuple(
293 [&](auto i) {
294 constexpr auto i_nr_ = number<i % Nr_>{};
295 constexpr auto i_kr0_ = number<i / Nr_>{};
296
297 return i_nr_ * shared_intermediate_size_1 * Nw_ * Nl_ + i_kr0_ * Kr1_ * W_ +
298 base_os_;
299 },
301 }();
302
303 auto o_coords = generate_tuple(
304 [&](auto i) {
305 return row_ids_a[i] * kargs.stride_token +
306 threadIdx.x % (BlockShape::Block_N1 / kAlignmentO) * kAlignmentO;
307 },
308 number<row_ids_a.size()>{});
309
310 auto o_flags =
311 generate_tuple([&](auto i) { return cmp_lt_to_exec(row_ids_a[i], kargs.num_tokens); },
312 number<row_ids_a.size()>{});
313
314 auto bridge_sst_win = [&]() {
315 constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
316 constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
318 reinterpret_cast<YDataType*>(smem), desc_),
319 desc_.get_lengths(),
320 {0, 0},
321 dist_);
322 }();
323
324 auto o_res =
325 make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
326 kargs.num_tokens * kargs.stride_token * sizeof(ODataType),
327 std::true_type{});
328 auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0);
329 auto w_scale = GetWeightScale(
330 row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
331
332 auto uk_0 = Policy::template GetUK_0<Problem>();
333
334 auto y_pre = [&]() {
335 if constexpr(IsGateOnly)
336 {
337 auto acc_0 = uk_0(a_res,
338 a_coords,
339 g_res,
340 g_coords,
341 smem,
342 kargs.hidden_size,
343 BlockShape::Block_K0, // tile offset for B matrix each unroll
344 BlockShape::Block_Kr0 *
345 BlockShape::Block_W0); // tile offset for B matrix each unroll
346
348 acc_0,
349 [&](auto idx0, auto idx1) {
350 fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
351 typename Problem::GateActivation{}(v_, v_);
352 acc_0(idx0) = v_.x;
353 acc_0(idx1) = v_.y;
354 },
356
357 return cast_tile<YDataType>(acc_0);
358 }
359 else
360 {
361 uint32x8_t gu_res;
362 gu_res[0] = g_res[0];
363 gu_res[1] = g_res[1];
364 gu_res[2] = g_res[2];
365 gu_res[3] = g_res[3];
366 gu_res[4] = u_res[0];
367 gu_res[5] = u_res[1];
368 gu_res[6] = u_res[2];
369 gu_res[7] = u_res[3];
370
371 auto acc_0 = uk_0(a_res,
372 a_coords,
373 gu_res,
374 g_coords,
375 smem,
376 kargs.hidden_size,
377 BlockShape::Block_K0, // tile offset for B matrix each unroll
378 BlockShape::Block_Kr0 * BlockShape::Block_W0,
379 bool_constant<true>{}); // tile offset for B matrix each unroll
380
382 acc_0.at(number<0>{}),
383 [&](auto idx0, auto idx1) {
384 fp32x2_t v_{acc_0.at(number<0>{})(idx0), acc_0.at(number<0>{})(idx1)};
385 typename Problem::GateActivation{}(v_, v_);
386 acc_0.at(number<0>{})(idx0) = v_.x;
387 acc_0.at(number<0>{})(idx1) = v_.y;
388 },
390
391 auto reduced_acc_0 =
392 tile_elementwise_in([&](const auto& a_, const auto& b_) { return a_ * b_; },
393 acc_0.at(number<0>{}),
394 acc_0.at(number<1>{}));
395
396 return cast_tile<YDataType>(reduced_acc_0);
397 }
398 }();
399
401
402 store_tile(bridge_sst_win, y_pre);
404
405 auto uk_1 = Policy::template GetUK_1<Problem>();
406 uk_1(d_res,
407 d_coords,
408 o_res,
409 o_coords,
410 o_flags,
411 smem,
412 kargs.hidden_size, // total n number
413 w_scale,
414 BlockShape::Block_Nr1 * kr_1 * BlockShape::Block_W1, // along N
415 BlockShape::Block_N1); // along N
416 }
417};
418
419} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_LDS_ADDR
Definition config.hpp:58
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE auto cmp_lt_to_exec(const X &x, const Y &y)
Definition utility.hpp:133
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
@ GST_O
Definition fused_moegemm_traits.hpp:48
@ GLD_B
Definition fused_moegemm_traits.hpp:45
@ SLD_A
Definition fused_moegemm_traits.hpp:42
@ GLD_A
Definition fused_moegemm_traits.hpp:44
uint32_t uint32x8_t
Definition vector_type.hpp:165
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F &f, UnpacksPerXDim={})
Definition sweep_tile.hpp:231
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
int64_t long_index_t
Definition integer.hpp:11
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_DEVICE auto make_tile_window_linear_raw(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition tile_window_linear.hpp:1029
CK_TILE_DEVICE constexpr auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition tile_window_linear.hpp:993
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr, uint32_t size=0xffffffff, ForceSGPR={})
Definition tile/core/arch/amd_buffer_addressing.hpp:97
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
float fp32x2_t
Definition pk_fp4.hpp:22
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition fused_moegemm_pipeline_flatmm_uk.hpp:23
typename Problem::IndexDataType IndexDataType
Definition fused_moegemm_pipeline_flatmm_uk.hpp:39
typename Problem::ADataType ADataType
Definition fused_moegemm_pipeline_flatmm_uk.hpp:29
static constexpr index_t kAlignmentO
Definition fused_moegemm_pipeline_flatmm_uk.hpp:52
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition fused_moegemm_pipeline_flatmm_uk.hpp:71
typename Problem::DDataType DDataType
Definition fused_moegemm_pipeline_flatmm_uk.hpp:31
static constexpr bool PadIntermediateSize
Definition fused_moegemm_pipeline_flatmm_uk.hpp:47
CK_TILE_DEVICE auto operator()(const Karg &kargs, CK_TILE_LDS_ADDR void *smem, index_t sorted_tile_id, index_t intermediate_tile_id)
Definition fused_moegemm_pipeline_flatmm_uk.hpp:170
static constexpr const char * name
Definition fused_moegemm_pipeline_flatmm_uk.hpp:69
CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset)
Definition fused_moegemm_pipeline_flatmm_uk.hpp:111
static CK_TILE_HOST_DEVICE auto GetOCoord()
Definition fused_moegemm_pipeline_flatmm_uk.hpp:94
typename Problem::DScaleDataType DScaleDataType
Definition fused_moegemm_pipeline_flatmm_uk.hpp:36
CK_TILE_DEVICE constexpr auto GetNumRowCoords_A()
Definition fused_moegemm_pipeline_flatmm_uk.hpp:101
static constexpr index_t kAlignmentD
Definition fused_moegemm_pipeline_flatmm_uk.hpp:51
static constexpr index_t kAlignmentA
Definition fused_moegemm_pipeline_flatmm_uk.hpp:49
static constexpr index_t GLD_B
Definition fused_moegemm_pipeline_flatmm_uk.hpp:56
static constexpr index_t kBlockPerCu
Definition fused_moegemm_pipeline_flatmm_uk.hpp:59
typename Problem::GScaleDataType GScaleDataType
Definition fused_moegemm_pipeline_flatmm_uk.hpp:35
typename Problem::TopkWeightDataType TopkWeightDataType
Definition fused_moegemm_pipeline_flatmm_uk.hpp:38
static constexpr bool IsGateOnly
Definition fused_moegemm_pipeline_flatmm_uk.hpp:44
typename Problem::ODataType ODataType
Definition fused_moegemm_pipeline_flatmm_uk.hpp:33
typename Problem::BlockShape BlockShape
Definition fused_moegemm_pipeline_flatmm_uk.hpp:27
static constexpr index_t kAlignmentG
Definition fused_moegemm_pipeline_flatmm_uk.hpp:50
static constexpr bool PadHiddenSize
Definition fused_moegemm_pipeline_flatmm_uk.hpp:46
static constexpr index_t GST_O
Definition fused_moegemm_pipeline_flatmm_uk.hpp:57
static constexpr index_t SLD_A
Definition fused_moegemm_pipeline_flatmm_uk.hpp:54
CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords, const TopkWeightDataType *sorted_weight_ptr)
Definition fused_moegemm_pipeline_flatmm_uk.hpp:142
static constexpr index_t GLD_A
Definition fused_moegemm_pipeline_flatmm_uk.hpp:55
remove_cvref_t< Problem_ > Problem
Definition fused_moegemm_pipeline_flatmm_uk.hpp:24
typename Problem::YSmoothScaleDataType YSmoothScaleDataType
Definition fused_moegemm_pipeline_flatmm_uk.hpp:37
CK_TILE_DEVICE auto GetRowID(const ROW_COORDS coords, const IndexDataType *sorted_token_ids_ptr)
Definition fused_moegemm_pipeline_flatmm_uk.hpp:126
typename Problem::GDataType GDataType
Definition fused_moegemm_pipeline_flatmm_uk.hpp:30
typename Problem::YDataType YDataType
Definition fused_moegemm_pipeline_flatmm_uk.hpp:40
CK_TILE_DEVICE auto GetRowCoords_O(index_t base_offset)
Definition fused_moegemm_pipeline_flatmm_uk.hpp:156
remove_cvref_t< Policy_ > Policy
Definition fused_moegemm_pipeline_flatmm_uk.hpp:25
static constexpr bool UseSmoothQuant
Definition fused_moegemm_pipeline_flatmm_uk.hpp:45
typename Problem::Traits Traits
Definition fused_moegemm_pipeline_flatmm_uk.hpp:42
typename Problem::AccDataType AccDataType
Definition fused_moegemm_pipeline_flatmm_uk.hpp:32
typename Problem::AScaleDataType AScaleDataType
Definition fused_moegemm_pipeline_flatmm_uk.hpp:34
static CK_TILE_HOST_DEVICE auto GetACoord()
Definition fused_moegemm_pipeline_flatmm_uk.hpp:86
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
CK_TILE_HOST_DEVICE constexpr auto & at(index_t i)
Definition tile/core/container/array.hpp:110
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43