FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference
Classes |
Public Types |
Public Member Functions |
Static Public Member Functions |
Static Public Attributes |
List of all members
ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference
#include <fmha_fwd_v3_kernel.hpp>
Classes | |
| struct | FmhaFwdEmptyKargs |
| struct | FmhaFwdCommonKargs |
| struct | FmhaFwdMaskKargs |
| struct | FmhaFwdCommonLSEKargs |
| struct | FmhaFwdBatchModeKargs |
| struct | FmhaFwdGroupModeKargs |
Public Types | |
| using | FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_> |
| using | EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_> |
| using | QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType> |
| using | KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType> |
| using | VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType> |
| using | LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType> |
| using | ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType> |
| using | SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType> |
| using | FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask> |
| using | Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs> |
Public Member Functions | |
| CK_TILE_DEVICE void | operator() (Kargs kargs) const |
Static Public Member Functions | |
| template<bool Cond = !kIsGroupMode> | |
| static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > | MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, const ck_tile::index_t *cu_seqlen_q_ptr=nullptr, const ck_tile::index_t *cu_seqlen_kv_ptr=nullptr) |
| template<bool Cond = kIsGroupMode> | |
| static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > | MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, const void *seqstart_padded_q_ptr=nullptr, const void *seqstart_padded_k_ptr=nullptr) |
| static CK_TILE_HOST constexpr auto | GridSize (ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_) |
| static CK_TILE_DEVICE constexpr auto | RemapTileIndices (int32_t tg_idx, int32_t tg_idy, int32_t remap_option) |
| static CK_TILE_DEVICE constexpr auto | GetTileIndex (const Kargs &) |
| static CK_TILE_HOST constexpr auto | BlockSize () |
| static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t | GetSmemSize () |
Static Public Attributes | |
| static constexpr ck_tile::index_t | kBlockSize = FmhaPipeline::kBlockSize |
| static constexpr ck_tile::index_t | kBlockPerCu = FmhaPipeline::kBlockPerCu |
| static constexpr bool | kIsGroupMode = FmhaPipeline::kIsGroupMode |
| static constexpr bool | kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ |
| static constexpr bool | kPadSeqLenK = FmhaPipeline::kPadSeqLenK |
| static constexpr bool | kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ |
| static constexpr bool | kPadHeadDimV = FmhaPipeline::kPadHeadDimV |
| static constexpr bool | kStoreLSE = FmhaPipeline::kStoreLSE |
| static constexpr bool | kHasMask = FmhaMask::IsMasking |
Member Typedef Documentation
◆ EpiloguePipeline
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_> |
◆ FmhaMask
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask> |
◆ FmhaPipeline
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_> |
◆ Kargs
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs> |
◆ KDataType
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType> |
◆ LSEDataType
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType> |
◆ ODataType
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType> |
◆ QDataType
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType> |
◆ SaccDataType
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType> |
◆ VDataType
template<typename FmhaPipeline_, typename EpiloguePipeline_>
| using ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType> |
Member Function Documentation
◆ BlockSize()
template<typename FmhaPipeline_, typename EpiloguePipeline_>
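| static CK_TILE_HOST constexpr auto ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::BlockSize () |
inline static constexpr |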
◆ GetSmemSize()
template<typename FmhaPipeline_, typename EpiloguePipeline_>
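| static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::GetSmemSize () |
inline static constexpr |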
◆ GetTileIndex()
template<typename FmhaPipeline_, typename EpiloguePipeline_>
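| static CK_TILE_DEVICE constexpr auto ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::GetTileIndex (const Kargs &) |
inline static constexpr |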
◆ GridSize()
template<typename FmhaPipeline_, typename EpiloguePipeline_>
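| static CK_TILE_HOST constexpr auto ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::GridSize (ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_) |
inline static constexpr |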
◆ MakeKargs() [1/2]
template<typename FmhaPipeline_, typename EpiloguePipeline_>
template<bool Cond = !kIsGroupMode>
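| static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, const ck_tile::index_t *cu_seqlen_q_ptr=nullptr, const ck_tile::index_t *cu_seqlen_kv_ptr=nullptr) |
inline static constexpr |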
◆ MakeKargs() [2/2]
template<typename FmhaPipeline_, typename EpiloguePipeline_>
template<bool Cond = kIsGroupMode>
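| static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, const void *seqstart_padded_q_ptr=nullptr, const void *seqstart_padded_k_ptr=nullptr) |
inline static constexpr |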
◆ operator()()
template<typename FmhaPipeline_, typename EpiloguePipeline_>
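| CK_TILE_DEVICE void ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::operator() (Kargs kargs) const |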
inline |
◆ RemapTileIndices()
template<typename FmhaPipeline_, typename EpiloguePipeline_>
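| static CK_TILE_DEVICE constexpr auto ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::RemapTileIndices (int32_t tg_idx, int32_t tg_idy, int32_t remap_option) |
inline static constexpr |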
Member Data Documentation
◆ kBlockPerCu
template<typename FmhaPipeline_, typename EpiloguePipeline_>
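| static constexpr ck_tile::index_t ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::kBlockPerCu = FmhaPipeline::kBlockPerCu |
static constexpr |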
◆ kBlockSize
template<typename FmhaPipeline_, typename EpiloguePipeline_>
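| static constexpr ck_tile::index_t ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::kBlockSize = FmhaPipeline::kBlockSize |
static constexpr |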
◆ kHasMask
template<typename FmhaPipeline_, typename EpiloguePipeline_>
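| static constexpr bool ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::kHasMask = FmhaMask::IsMasking |
static constexpr |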
◆ kIsGroupMode
template<typename FmhaPipeline_, typename EpiloguePipeline_>
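| static constexpr bool ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::kIsGroupMode = FmhaPipeline::kIsGroupMode |
static constexpr |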
◆ kPadHeadDimQ
template<typename FmhaPipeline_, typename EpiloguePipeline_>
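| static constexpr bool ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ |
static constexpr |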
◆ kPadHeadDimV
template<typename FmhaPipeline_, typename EpiloguePipeline_>
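| static constexpr bool ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::kPadHeadDimV = FmhaPipeline::kPadHeadDimV |
static constexpr |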
◆ kPadSeqLenK
template<typename FmhaPipeline_, typename EpiloguePipeline_>
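| static constexpr bool ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::kPadSeqLenK = FmhaPipeline::kPadSeqLenK |
static constexpr |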
◆ kPadSeqLenQ
template<typename FmhaPipeline_, typename EpiloguePipeline_>
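| static constexpr bool ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ |
static constexpr |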
◆ kStoreLSE
template<typename FmhaPipeline_, typename EpiloguePipeline_>
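| static constexpr bool ck_tile::FmhaFwdV3Kernel< FmhaPipeline_, EpiloguePipeline_ >::kStoreLSE = FmhaPipeline::kStoreLSE |
static constexpr |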
The documentation for this struct was generated from the following file: