reference_batched_dropout_randval.hpp Source File

reference_batched_dropout_randval.hpp Source File#

Composable Kernel: reference_batched_dropout_randval.hpp Source File
reference_batched_dropout_randval.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7#include "ck_tile/host/host_tensor.hpp"
8#include <thread>
9
10namespace ck_tile {
11
12template <typename RandValOutputDataType>
13CK_TILE_HOST void
15 index_t batch,
16 uint64_t drop_seed,
17 uint64_t drop_offset)
18{
19 const index_t nhead = randval_b_m_n.mDesc.get_lengths()[0];
20 const index_t real_seqlen_q = randval_b_m_n.mDesc.get_lengths()[1];
21 const index_t real_seqlen_k = randval_b_m_n.mDesc.get_lengths()[2];
22
23 static_assert(std::is_same_v<RandValOutputDataType, uint8_t>);
24
25 // BlockDropout generates random numbers by 32x32 tiles. Even when warp gemm 16x16 is used, the
26 // order of values in the bigger 32x32 tile must be the same because fwd and bwd may use
27 // different warp gemms (16x16 or 32x32).
28 // To compute 32x32 tiles, WarpGemmMfmaF16F16F32M32N32K16SwizzleA is used. It is
29 // WarpGemmAttributeMfmaImplF16F16F32M32N32K8 with SFactor = 2 (swizzling factor).
30 // Matrix element to register mapping for WarpGemmAttributeMfmaImplF16F16F32M32N32K8:
31 // C i: (8 * floor(GPR_num / 4) % 32) + 4 * floor(lane / 32) + (GPR_num % 4)
32 // C j: (lane % 32)
33 // With SFactor = 2 it becomes:
34 // C i: (16 * floor(GPR_num / 8) % 32) + 8 * floor(lane / 32) + (GPR_num % 8)
35 // C j: (lane % 32)
36 // See ck_tile/ops/fmha/block/block_dropout.hpp for more details.
37
38 // The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values
39 constexpr index_t philox_per_tile = 64;
40 constexpr index_t warp_gemm_mn = 32;
41
42 const index_t rows = integer_divide_ceil(real_seqlen_q, warp_gemm_mn);
43 const index_t cols = integer_divide_ceil(real_seqlen_k, warp_gemm_mn);
44
45 auto f = [&](index_t i_h, index_t row, index_t col) {
46 uint2 rowcol = make_uint2(row, col);
47 for(index_t lane = 0; lane < philox_per_tile; lane++)
48 {
49 const uint64_t ph_head_offset = drop_offset + (batch * nhead + i_h) * philox_per_tile;
50 const index_t ph_offset = lane;
51 philox ph(drop_seed, ph_head_offset + ph_offset);
52
53 uint8_t random_uint8_t[16];
54 ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
55
56 for(auto r = 0; r < 16; r++)
57 {
58 index_t i = (16 * (r / 8) % 32) + 8 * (lane / 32) + (r % 8);
59 index_t j = (lane % 32);
60 index_t m = row * warp_gemm_mn + i;
61 index_t n = col * warp_gemm_mn + j;
62
63 if(m < real_seqlen_q && n < real_seqlen_k)
64 {
65 randval_b_m_n(i_h, m, n) = random_uint8_t[r];
66 }
67 }
68 }
69 };
70
71 make_ParallelTensorFunctor(f, nhead, rows, cols)(std::thread::hardware_concurrency());
72}
73
74} // namespace ck_tile
Definition philox_rand.hpp:12
CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t *out, const unsigned long long subsequence) const
Definition philox_rand.hpp:42
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:329
CK_TILE_HOST void reference_batched_dropout_randval(HostTensor< RandValOutputDataType > &randval_b_m_n, index_t batch, uint64_t drop_seed, uint64_t drop_offset)
Definition reference_batched_dropout_randval.hpp:14
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
int32_t index_t
Definition integer.hpp:9
unsigned char uint8_t
Definition stdint.h:124
unsigned __int64 uint64_t
Definition stdint.h:136
const std::vector< std::size_t > & get_lengths() const
Definition tile/host/host_tensor.hpp:198
Definition tile/host/host_tensor.hpp:336
Descriptor mDesc
Definition tile/host/host_tensor.hpp:800