streamk_common.hpp Source File

streamk_common.hpp Source File#

Composable Kernel: streamk_common.hpp Source File
streamk_common.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7
8namespace ck_tile {
14
25template <ck_tile::StreamKReductionStrategy ReductionStrategy>
27estimate_num_wgs_per_tile(index_t sk_ctas, index_t iters_per_sk_cta, index_t iters_per_tile)
28{
29 // In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup
30 // writing final results to a given macro tile in C.
31 int num_wgs_per_tile = 1;
32
33 // Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
34 if(sk_ctas > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
35 {
36 // Estimate the number of workgroups per macro tile.
37 num_wgs_per_tile =
38 (iters_per_tile / iters_per_sk_cta) + ((iters_per_tile % iters_per_sk_cta) != 0);
39 }
40
41 return std::max(num_wgs_per_tile, 1);
42}
43} // namespace ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
StreamKReductionStrategy
Definition streamk_common.hpp:10
@ Atomic
Definition streamk_common.hpp:11
@ Reduction
Definition streamk_common.hpp:12
ck_tile::index_t estimate_num_wgs_per_tile(index_t sk_ctas, index_t iters_per_sk_cta, index_t iters_per_tile)
Estimates the number of Stream-K workgroups per macro tile in the C tensor.
Definition streamk_common.hpp:27
int32_t index_t
Definition integer.hpp:9
unsigned int uint32_t
Definition stdint.h:126