mxfp_convert.hpp Source File

mxfp_convert.hpp Source File

Composable Kernel: mxfp_convert.hpp Source File
mxfp_convert.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6namespace ck_tile {
7// Modified from include/ck/utility/mxfp_utils.hpp
8
// numeric_utils<T>: static helpers that pick apart the raw bit pattern of the
// narrow floating-point format T. Field widths come from traits::exp and
// traits::mant; constants come from _numeric.
//
// NOTE(review): this rendering omits source lines 10, 13, 14 and 26.
//   - The definition list at the end of the page places this struct at line 11.
//   - Lines 13/14 declare the aliases `traits` (numeric_traits< T >) and
//     `_numeric` (numeric< T >).
//   - convert_to_float reads numeric_utils<T>::bias, which no visible line
//     declares. It is presumably inherited via the missing struct header on
//     line 10 — TODO confirm.
9template <typename T>
11{
12
15 using raw_type = typename traits::bitwise_type;
16
// Mask with the low `traits::exp` bits set: the all-ones exponent field value.
17 static constexpr int exp_mask = (1 << traits::exp) - 1;
18
// Biased exponent field of a raw encoding: shift the mantissa bits out, then
// mask away everything above the exponent field (the sign bit).
19 static constexpr raw_type get_exponent(raw_type x)
20 {
21 // TODO: check if repeated calls are optimized.
22 return (x >> traits::mant) & exp_mask;
23 }
// Overload taking a value of T. Its body (line 26) is not visible here.
// Presumably it reinterprets x as raw_type and forwards to the overload
// above — TODO confirm.
24 static constexpr raw_type get_exponent(const T& x)
25 {
27 }
// True when x, shifted past the exponent and mantissa fields (leaving the sign
// bit and anything above it), equals _numeric::binary_zero.
// NOTE(review): this assumes raw_type carries no stray bits above the sign bit.
28 static constexpr bool is_positive(raw_type x)
29 {
30 return (x >> (traits::exp + traits::mant)) == _numeric::binary_zero;
31 }
// True when the exponent field is zero: a subnormal encoding (this includes +/-0).
32 static constexpr bool is_subnormal(raw_type x)
33 {
34 return get_exponent(x) == _numeric::binary_zero;
35 }
// Significand of a raw encoding, ignoring sign and exponent. It starts from the
// implicit leading 1.0 for normal encodings (0.0 for subnormals). Each mantissa
// bit i, counted from the LSB, is then added with weight 2^-(traits::mant - i).
36 // TODO: replace double with template arg?
37 static constexpr double get_mantissa(raw_type x)
38 {
39 double mantissa = is_subnormal(x) ? 0.0f : 1.0f;
40 for(raw_type i = 0; i < traits::mant; ++i)
41 {
42 mantissa += std::ldexp(static_cast<float>(x & 0b1), -(traits::mant - i));
43 x >>= 1;
44 }
45 return mantissa;
46 }
47};
48
49template <typename T>
50CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, float scale = 1.f)
51{
52 using utils = numeric_utils<T>;
53 float sign = utils::is_positive(data) ? 1.0 : -1.0;
54 int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias;
55 float mant = utils::get_mantissa(data);
56
57 return std::ldexp(sign * mant * scale, exp);
58}
59
// Encode a float as the raw bit pattern of the narrow floating-point format T.
//
// @param value  Input value. It is divided by `scale` before encoding.
// @param scale  Divisor applied to `value` (default 1).
// @return       The encoded bits, as typename T::raw_type.
//
// Outline of the visible logic:
//  1. If |value| > numeric<T>::max(), compare |value| against actual_max.
//     actual_max is the midpoint between max() and prev_val, where prev_val is
//     built from max()'s bits with the mantissa field decremented (presumably
//     the next-lower representable value).
//     - Below the midpoint, the result is assembled from `exp` and `mantissa`.
//       This is presumably max() with the input's sign — TODO confirm, since
//       lines 76-81 are missing.
//     - At or above the midpoint, the result depends on
//       numeric<T>::has_inf(). Those return statements are not visible.
//  2. Otherwise, the float's mantissa is rounded (round-half-to-even) to T's
//     width, producing a normal or subnormal encoding of T.
//
// NOTE(review): this rendering omits source lines 76, 78-79, 81, 83, 87, 89,
// 99, 107, 112-113, 119, 125-129, 174, 210 and 214. Several assignments and
// return statements are therefore not visible here.
60template <typename T>
61CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value, float scale = 1.f)
62{
63 using bitwise_type = typename numeric_traits<T>::bitwise_type;
64
65 value /= scale;
66
67 if(std::abs(value) > float(numeric<T>::max()))
68 {
69 float max_value = numeric<T>::max();
70
71 // cppcheck-suppress redundantAssignment
72 uint32_t max_bitwise = bit_cast<uint32_t>(max_value);
73
74 // cppcheck-suppress redundantAssignment
75 bitwise_type sign =
77 bitwise_type exp =
80 bitwise_type mantissa =
82
84 mant_prev &= ((1 << numeric_traits<T>::mant) - 1);
85 mant_prev--;
86
88 uint32_t prev_bit =
90 mant_prev;
91
92 float prev_val = bit_cast<float>(prev_bit);
93 float diff = max_value - prev_val;
94
95 float actual_max = max_value + (diff / 2);
96
97 if(std::abs(value) < actual_max)
98 {
100 (exp << numeric_traits<T>::mant) | mantissa;
101 }
102 else
103 {
104 if constexpr(!numeric<T>::has_inf())
105 {
106
108 }
109 else
110 {
111 exp++;
114 }
115 }
116 }
// General path. mfmt is the float mantissa width. x, head, mantissa, exponent,
// bias and sign are presumably filled from value's bits on the missing lines
// 119 and 125-129 — TODO confirm.
117 const int mfmt = numeric_traits<float>::mant;
118 uint32_t x;
120
121 uint32_t head, mantissa;
122 int32_t exponent, bias;
123 uint32_t sign;
124
130
// NOTE(review): if x holds bit_cast<uint32_t>(value), only +0.0f takes this
// early return. -0.0f (sign bit set) falls through to the general path.
131 if(x == 0)
132 {
133 return 0b0;
134 }
135
136 const int mini_bias = numeric_traits<T>::bias;
137 const int mini_denormal_act_exponent = 1 - mini_bias;
138
139 int act_exponent, out_exponent, exponent_diff;
140
141 bool is_subnorm = false;
142
// act_exponent: the unbiased float exponent (float subnormals use 1 - bias).
// exponent_diff: the extra right shift needed when the result lands in T's
// subnormal range (act_exponent <= 1 - mini_bias). It is 0 for normal results.
143 if(exponent == 0)
144 {
145 act_exponent = exponent - bias + 1;
146 exponent_diff = mini_denormal_act_exponent - act_exponent;
147 is_subnorm = true;
148 }
149 else
150 {
151 act_exponent = exponent - bias;
152 if(act_exponent <= mini_denormal_act_exponent)
153 {
154 exponent_diff = mini_denormal_act_exponent - act_exponent;
155 is_subnorm = true;
156 }
157 else
158 {
159 exponent_diff = 0;
160 }
// Make the implicit leading one of a normal float explicit.
161 mantissa += (1UL << mfmt);
162 }
163
// midpoint: true when the bits about to be dropped (T's narrower mantissa plus
// the subnormal shift) are exactly half an ULP. It is computed on the unshifted
// mantissa, so bits lost by the `>>= exponent_diff` below still count.
// NOTE(review): the clamp to 63 assumes a 64-bit `unsigned long`. Where long is
// 32 bits, `1UL << shift_amount` is undefined for shift_amount >= 32.
164 auto shift_amount = (mfmt - numeric_traits<T>::mant + exponent_diff);
165 shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
166 bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
167
// min_subnorm is numeric<T>::epsilon() carrying value's sign. The early return
// below encodes it as mantissa field 1 with a zero exponent, i.e. the smallest
// subnormal of T. Smaller subnormal-range values go to whichever of 0 and
// min_subnorm is nearer; ties take the "closer to 0" branch.
// NOTE(review): confirm that epsilon() equals T's smallest subnormal.
168 float min_subnorm = float(numeric<T>::epsilon()) * (sign ? -1 : 1);
169
170 if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
171 {
172 // closer to 0
173 if(std::abs(value) <= std::abs(min_subnorm - value))
175 else
176 return 1 | (sign << (numeric_traits<T>::exp + numeric_traits<T>::mant));
177 }
178
// Shift the mantissa into T's subnormal position when needed.
// NOTE(review): exponent_diff is only assigned non-negative values above,
// assuming T's bias does not exceed the float bias. The `== -1` branch
// therefore appears unreachable.
179 if(exponent_diff > 0)
180 mantissa >>= exponent_diff;
181 else if(exponent_diff == -1)
182 mantissa <<= -exponent_diff;
// out_exponent is T's biased exponent. It drops by one (to 0 for subnormal
// results) when no implicit one survives the shift.
183 bool implicit_one = mantissa & (1 << mfmt);
184 out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
185
// Round half to even on the (mfmt - T::mant) bits being dropped. Adding the
// dropped bits to themselves carries into the kept LSB exactly when they are
// >= half an ULP. At an exact midpoint with an even kept LSB, `mantissa - 1`
// suppresses that carry.
186 uint32_t drop_mask = (1UL << (mfmt - numeric_traits<T>::mant)) - 1;
187 bool odd = mantissa & (1UL << (mfmt - numeric_traits<T>::mant));
188 mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
189
// Renormalise after rounding:
//   - a subnormal that rounded up into the implicit-one bit becomes the
//     smallest normal (exponent 1);
//   - a normal whose significand carried into bit mfmt + 1 is shifted back
//     down and its exponent incremented.
190 if(out_exponent == 0)
191 {
192 if((1UL << mfmt) & mantissa)
193 {
194 out_exponent = 1;
195 }
196 }
197 else
198 {
199 if((1UL << (mfmt + 1)) & mantissa)
200 {
201 mantissa >>= 1;
202 out_exponent++;
203 }
204 }
205
// Drop the extra float mantissa bits, keeping T's mantissa width plus the
// implicit bit.
206 mantissa >>= (mfmt - numeric_traits<T>::mant);
207
// The result rounded to zero. The statement on line 210 is not visible here.
208 if(out_exponent == 0 && mantissa == 0)
209 {
211 }
212
// Clear the implicit leading one, then pack the exponent and mantissa fields.
// The start of the return expression (line 214) is not visible here.
213 mantissa &= (1UL << numeric_traits<T>::mant) - 1;
215 (out_exponent << numeric_traits<T>::mant) | mantissa;
216}
217
218} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, float scale=1.f)
Definition mxfp_convert.hpp:50
CK_TILE_HOST_DEVICE T::raw_type convert_to_type(float value, float scale=1.f)
Definition mxfp_convert.hpp:61
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
int32_t int32_t
Definition integer.hpp:10
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
unsigned int uint32_t
Definition stdint.h:126
Definition tile/core/numeric/numeric.hpp:81
Definition mxfp_convert.hpp:11
static constexpr bool is_positive(raw_type x)
Definition mxfp_convert.hpp:28
static constexpr raw_type get_exponent(raw_type x)
Definition mxfp_convert.hpp:19
static constexpr double get_mantissa(raw_type x)
Definition mxfp_convert.hpp:37
static constexpr int exp_mask
Definition mxfp_convert.hpp:17
numeric_traits< T > traits
Definition mxfp_convert.hpp:13
typename traits::bitwise_type raw_type
Definition mxfp_convert.hpp:15
numeric< T > _numeric
Definition mxfp_convert.hpp:14
static constexpr bool is_subnormal(raw_type x)
Definition mxfp_convert.hpp:32
static constexpr raw_type get_exponent(const T &x)
Definition mxfp_convert.hpp:24
Definition tile/core/numeric/numeric.hpp:18
static CK_TILE_HOST_DEVICE constexpr T max()
Definition tile/core/numeric/numeric.hpp:26
static CK_TILE_HOST_DEVICE constexpr T epsilon()
Definition tile/core/numeric/numeric.hpp:29