gemm_aquant_pipeline_ag_bg_cr_policy.hpp Source File

gemm_aquant_pipeline_ag_bg_cr_policy.hpp Source File

Composable Kernel: gemm_aquant_pipeline_ag_bg_cr_policy.hpp Source File
gemm_aquant_pipeline_ag_bg_cr_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
8
9namespace ck_tile {
10
12{
14 using Base::I0;
15 using Base::I1;
16 using Base::I2;
17
// GetVectorSizeAQ: compile-time query that forwards to
// GetABQGlobalVectorLoadSize, instantiated with Problem, AQDataType,
// MPerBlock and KPerBlockAQ, and returns its result unchanged.
// NOTE(review): AQLayout and AQDataType are not declared in the visible lines
// (listing lines 21-22 are absent) - presumably aliases taken from Problem; confirm.
18 template <typename Problem>
19 CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAQ()
20 {
23 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
24 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
// Block K extent divided by the quantization group size along K.
25 constexpr index_t KPerBlockAQ = KPerBlock / Problem::QuantGroupSize::kK;
26
// Compilation fails unless AQLayout is exactly the row-major GEMM layout.
27 static_assert(std::is_same_v<AQLayout, ck_tile::tensor_layout::gemm::RowMajor>);
28 return GetABQGlobalVectorLoadSize<Problem, AQDataType, MPerBlock, KPerBlockAQ>();
29 }
30
// MakeAQDramTileDistribution (name per the cross-reference index; the signature
// on listing line 32 is absent from this extract): selects one of three tile
// distribution encoding patterns at compile time and returns the result of that
// pattern's make_2d_static_tile_distribution().
//   1. PreshuffleQuant            -> tile_distribution_encoding_pattern_aq (preshuffle arguments)
//   2. !PreshuffleQuant, TransposeC -> a TransposeC-specific pattern
//   3. otherwise                  -> tile_distribution_encoding_pattern_aq (plain arguments)
// NOTE(review): AQLayout is used below but its declaration (listing line 34) is
// not visible here - presumably an alias taken from Problem; confirm.
31 template <typename Problem>
33 {
35 using BlockGemmShape = typename Problem::BlockGemmShape;
36
37 constexpr index_t BlockSize = Problem::kBlockSize;
38 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
39 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
// Block K extent divided by the quantization group size along K.
40 constexpr index_t KPerBlockAQ = KPerBlock / Problem::QuantGroupSize::kK;
41 constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
42 constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
43 using WarpTile = typename Problem::BlockGemmShape::WarpTile;
// Warp-level GEMM type; both operand types are Problem::ComputeDataType and the
// three WarpTile entries are passed as the per-wave M, N, K sizes.
44 using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
45 typename Problem::ComputeDataType,
46 typename Problem::CDataType,
47 WarpTile::at(I0),
48 WarpTile::at(I1),
49 WarpTile::at(I2),
50 Problem::TransposeC>;
51
// Compilation fails unless AQLayout is exactly the row-major GEMM layout.
52 static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
53 if constexpr(PreshuffleQuant)
54 {
// Preshuffle path: the pattern receives MPerBlock / WarpGemm::kM and
// integer_least_multiple(WarpGemm::kM * KPerBlockAQ, get_warp_size())
// (presumably that product rounded up to a multiple of the warp size - confirm
// against integer_least_multiple in math.hpp).
55 using TileEncodingPattern = tile_distribution_encoding_pattern_aq<
56 BlockGemmShape,
57 WarpGemm,
58 BlockSize,
59 MPerBlock / WarpGemm::kM,
60 ck_tile::integer_least_multiple(WarpGemm::kM * KPerBlockAQ, get_warp_size()),
61 KPerBlockAQ,
62 VecLoadSize,
63 PreshuffleQuant>;
64
65 return TileEncodingPattern::make_2d_static_tile_distribution();
66 }
67 else
68 {
69 if constexpr(Problem::TransposeC)
70 {
// NOTE(review): the pattern template name and its first argument(s) (listing
// line 72) are missing from this extract; only the trailing arguments are visible.
71 using TileEncodingPatternTransposeC =
73 WarpGemm,
74 BlockSize,
75 MPerBlock,
76 KPerBlockAQ,
77 VecLoadSize>;
78 return TileEncodingPatternTransposeC::make_2d_static_tile_distribution();
79 }
80 else
81 {
// Default path: same pattern template as the preshuffle branch, but given
// MPerBlock and KPerBlockAQ directly (KPerBlockAQ is passed twice).
82 using TileEncodingPattern = tile_distribution_encoding_pattern_aq<BlockGemmShape,
83 WarpGemm,
84 BlockSize,
85 MPerBlock,
86 KPerBlockAQ,
87 KPerBlockAQ,
88 VecLoadSize,
89 PreshuffleQuant>;
90
91 return TileEncodingPattern::make_2d_static_tile_distribution();
92 }
93 }
94 }
95
// GetBlockGemm: assembles the block-level GEMM configuration for Problem.
// It derives a WarpGemm type via WarpGemmDispatcher and wraps it, together with
// the A/B/C data types and BlockWarps, in BlockGemmASmemBSmemCRegV1CustomPolicy.
// Compile-time requirements enforced below:
//   - Problem::QuantGroupSize::kK is a multiple of WarpTile::at(I2)
//   - Problem::ComputeDataType is fp8_t or bf8_t
//   - Problem::CDataType is float
// NOTE(review): listing line 120 (presumably the return statement that consumes
// BlockGemmPolicy) is absent from this extract; it is not reconstructed here.
96 template <typename Problem>
97 CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
98 {
99 using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
100 using WarpTile = typename Problem::BlockGemmShape::WarpTile;
101
// The condition requires the quant group size along K to be divisible by the
// warp GEMM K size (WarpTile::at(I2)); the message states that same direction.
102 static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0,
103 "QuantGroupSize::kK must be a multiple of KPerWarpGemm!");
104
// Both operand types are Problem::ComputeDataType; the three WarpTile entries
// are passed as the per-wave M, N, K sizes.
105 using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
106 typename Problem::ComputeDataType,
107 typename Problem::CDataType,
108 WarpTile::at(I0),
109 WarpTile::at(I1),
110 WarpTile::at(I2),
111 Problem::TransposeC>;
// Only 8-bit float compute types with a float C type are accepted.
112 static_assert(std::is_same_v<typename Problem::ComputeDataType, fp8_t> ||
113 std::is_same_v<typename Problem::ComputeDataType, bf8_t>);
114 static_assert(std::is_same_v<typename Problem::CDataType, float>);
115 using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
116 typename Problem::BDataType,
117 typename Problem::CDataType,
118 BlockWarps,
119 WarpGemm>;
121 }
122};
123
124} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
Definition tile/core/numeric/math.hpp:155
int32_t index_t
Definition integer.hpp:9
Definition block_universal_gemm_as_aquant_bs_cr.hpp:56
Definition block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp:18
Definition gemm_aquant_pipeline_ag_bg_cr_policy.hpp:12
UniversalGemmPipelineAgBgCrPolicy Base
Definition gemm_aquant_pipeline_ag_bg_cr_policy.hpp:13
static CK_TILE_HOST_DEVICE constexpr auto GetBlockGemm()
Definition gemm_aquant_pipeline_ag_bg_cr_policy.hpp:97
static CK_TILE_HOST_DEVICE constexpr auto MakeAQDramTileDistribution()
Definition gemm_aquant_pipeline_ag_bg_cr_policy.hpp:32
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeAQ()
Definition gemm_aquant_pipeline_ag_bg_cr_policy.hpp:19
static constexpr auto I1
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:50
static constexpr auto I2
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:51
static constexpr auto I0
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:49
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:693
Definition gemm_group_quant_utils.hpp:124
Definition gemm_group_quant_utils.hpp:57