Home | History | Annotate | Download | only in SPIRV
      1 // Copyright (c) 2015-2016 The Khronos Group Inc.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #ifndef LIBSPIRV_UTIL_HEX_FLOAT_H_
     16 #define LIBSPIRV_UTIL_HEX_FLOAT_H_
     17 
     18 #include <cassert>
     19 #include <cctype>
     20 #include <cmath>
     21 #include <cstdint>
     22 #include <iomanip>
     23 #include <limits>
     24 #include <sstream>
     25 
     26 #if defined(_MSC_VER) && _MSC_VER < 1800
     27 namespace std {
     28 bool isnan(double f)
     29 {
     30   return ::_isnan(f) != 0;
     31 }
     32 bool isinf(double f)
     33 {
     34   return ::_finite(f) == 0;
     35 }
     36 }
     37 #endif
     38 
     39 #include "bitutils.h"
     40 
     41 namespace spvutils {
     42 
     43 class Float16 {
     44  public:
     45   Float16(uint16_t v) : val(v) {}
     46   Float16() {}
     47   static bool isNan(const Float16& val) {
     48     return ((val.val & 0x7C00) == 0x7C00) && ((val.val & 0x3FF) != 0);
     49   }
     50   // Returns true if the given value is any kind of infinity.
     51   static bool isInfinity(const Float16& val) {
     52     return ((val.val & 0x7C00) == 0x7C00) && ((val.val & 0x3FF) == 0);
     53   }
     54   Float16(const Float16& other) { val = other.val; }
     55   uint16_t get_value() const { return val; }
     56 
     57   // Returns the maximum normal value.
     58   static Float16 max() { return Float16(0x7bff); }
     59   // Returns the lowest normal value.
     60   static Float16 lowest() { return Float16(0xfbff); }
     61 
     62  private:
     63   uint16_t val;
     64 };
     65 
     66 // To specialize this type, you must override uint_type to define
     67 // an unsigned integer that can fit your floating point type.
     68 // You must also add a isNan function that returns true if
     69 // a value is Nan.
     70 template <typename T>
     71 struct FloatProxyTraits {
     72   typedef void uint_type;
     73 };
     74 
     75 template <>
     76 struct FloatProxyTraits<float> {
     77   typedef uint32_t uint_type;
     78   static bool isNan(float f) { return std::isnan(f); }
     79   // Returns true if the given value is any kind of infinity.
     80   static bool isInfinity(float f) { return std::isinf(f); }
     81   // Returns the maximum normal value.
     82   static float max() { return std::numeric_limits<float>::max(); }
     83   // Returns the lowest normal value.
     84   static float lowest() { return std::numeric_limits<float>::lowest(); }
     85 };
     86 
     87 template <>
     88 struct FloatProxyTraits<double> {
     89   typedef uint64_t uint_type;
     90   static bool isNan(double f) { return std::isnan(f); }
     91   // Returns true if the given value is any kind of infinity.
     92   static bool isInfinity(double f) { return std::isinf(f); }
     93   // Returns the maximum normal value.
     94   static double max() { return std::numeric_limits<double>::max(); }
     95   // Returns the lowest normal value.
     96   static double lowest() { return std::numeric_limits<double>::lowest(); }
     97 };
     98 
     99 template <>
    100 struct FloatProxyTraits<Float16> {
    101   typedef uint16_t uint_type;
    102   static bool isNan(Float16 f) { return Float16::isNan(f); }
    103   // Returns true if the given value is any kind of infinity.
    104   static bool isInfinity(Float16 f) { return Float16::isInfinity(f); }
    105   // Returns the maximum normal value.
    106   static Float16 max() { return Float16::max(); }
    107   // Returns the lowest normal value.
    108   static Float16 lowest() { return Float16::lowest(); }
    109 };
    110 
    111 // Since copying a floating point number (especially if it is NaN)
    112 // does not guarantee that bits are preserved, this class lets us
    113 // store the type and use it as a float when necessary.
    114 template <typename T>
    115 class FloatProxy {
    116  public:
    117   typedef typename FloatProxyTraits<T>::uint_type uint_type;
    118 
    119   // Since this is to act similar to the normal floats,
    120   // do not initialize the data by default.
    121   FloatProxy() {}
    122 
    123   // Intentionally non-explicit. This is a proxy type so
    124   // implicit conversions allow us to use it more transparently.
    125   FloatProxy(T val) { data_ = BitwiseCast<uint_type>(val); }
    126 
    127   // Intentionally non-explicit. This is a proxy type so
    128   // implicit conversions allow us to use it more transparently.
    129   FloatProxy(uint_type val) { data_ = val; }
    130 
    131   // This is helpful to have and is guaranteed not to stomp bits.
    132   FloatProxy<T> operator-() const {
    133     return static_cast<uint_type>(data_ ^
    134                                   (uint_type(0x1) << (sizeof(T) * 8 - 1)));
    135   }
    136 
    137   // Returns the data as a floating point value.
    138   T getAsFloat() const { return BitwiseCast<T>(data_); }
    139 
    140   // Returns the raw data.
    141   uint_type data() const { return data_; }
    142 
    143   // Returns true if the value represents any type of NaN.
    144   bool isNan() { return FloatProxyTraits<T>::isNan(getAsFloat()); }
    145   // Returns true if the value represents any type of infinity.
    146   bool isInfinity() { return FloatProxyTraits<T>::isInfinity(getAsFloat()); }
    147 
    148   // Returns the maximum normal value.
    149   static FloatProxy<T> max() {
    150     return FloatProxy<T>(FloatProxyTraits<T>::max());
    151   }
    152   // Returns the lowest normal value.
    153   static FloatProxy<T> lowest() {
    154     return FloatProxy<T>(FloatProxyTraits<T>::lowest());
    155   }
    156 
    157  private:
    158   uint_type data_;
    159 };
    160 
    161 template <typename T>
    162 bool operator==(const FloatProxy<T>& first, const FloatProxy<T>& second) {
    163   return first.data() == second.data();
    164 }
    165 
    166 // Reads a FloatProxy value as a normal float from a stream.
    167 template <typename T>
    168 std::istream& operator>>(std::istream& is, FloatProxy<T>& value) {
    169   T float_val;
    170   is >> float_val;
    171   value = FloatProxy<T>(float_val);
    172   return is;
    173 }
    174 
    175 // This is an example traits. It is not meant to be used in practice, but will
    176 // be the default for any non-specialized type.
    177 template <typename T>
    178 struct HexFloatTraits {
    179   // Integer type that can store this hex-float.
    180   typedef void uint_type;
    181   // Signed integer type that can store this hex-float.
    182   typedef void int_type;
    183   // The numerical type that this HexFloat represents.
    184   typedef void underlying_type;
    185   // The type needed to construct the underlying type.
    186   typedef void native_type;
    187   // The number of bits that are actually relevant in the uint_type.
    188   // This allows us to deal with, for example, 24-bit values in a 32-bit
    189   // integer.
    190   static const uint32_t num_used_bits = 0;
    191   // Number of bits that represent the exponent.
    192   static const uint32_t num_exponent_bits = 0;
    193   // Number of bits that represent the fractional part.
    194   static const uint32_t num_fraction_bits = 0;
    195   // The bias of the exponent. (How much we need to subtract from the stored
    196   // value to get the correct value.)
    197   static const uint32_t exponent_bias = 0;
    198 };
    199 
    200 // Traits for IEEE float.
    201 // 1 sign bit, 8 exponent bits, 23 fractional bits.
    202 template <>
    203 struct HexFloatTraits<FloatProxy<float>> {
    204   typedef uint32_t uint_type;
    205   typedef int32_t int_type;
    206   typedef FloatProxy<float> underlying_type;
    207   typedef float native_type;
    208   static const uint_type num_used_bits = 32;
    209   static const uint_type num_exponent_bits = 8;
    210   static const uint_type num_fraction_bits = 23;
    211   static const uint_type exponent_bias = 127;
    212 };
    213 
    214 // Traits for IEEE double.
    215 // 1 sign bit, 11 exponent bits, 52 fractional bits.
    216 template <>
    217 struct HexFloatTraits<FloatProxy<double>> {
    218   typedef uint64_t uint_type;
    219   typedef int64_t int_type;
    220   typedef FloatProxy<double> underlying_type;
    221   typedef double native_type;
    222   static const uint_type num_used_bits = 64;
    223   static const uint_type num_exponent_bits = 11;
    224   static const uint_type num_fraction_bits = 52;
    225   static const uint_type exponent_bias = 1023;
    226 };
    227 
    228 // Traits for IEEE half.
    229 // 1 sign bit, 5 exponent bits, 10 fractional bits.
    230 template <>
    231 struct HexFloatTraits<FloatProxy<Float16>> {
    232   typedef uint16_t uint_type;
    233   typedef int16_t int_type;
    234   typedef uint16_t underlying_type;
    235   typedef uint16_t native_type;
    236   static const uint_type num_used_bits = 16;
    237   static const uint_type num_exponent_bits = 5;
    238   static const uint_type num_fraction_bits = 10;
    239   static const uint_type exponent_bias = 15;
    240 };
    241 
    242 enum round_direction {
    243   kRoundToZero,
    244   kRoundToNearestEven,
    245   kRoundToPositiveInfinity,
    246   kRoundToNegativeInfinity
    247 };
    248 
    249 // Template class that houses a floating pointer number.
    250 // It exposes a number of constants based on the provided traits to
    251 // assist in interpreting the bits of the value.
    252 template <typename T, typename Traits = HexFloatTraits<T>>
    253 class HexFloat {
    254  public:
    255   typedef typename Traits::uint_type uint_type;
    256   typedef typename Traits::int_type int_type;
    257   typedef typename Traits::underlying_type underlying_type;
    258   typedef typename Traits::native_type native_type;
    259 
    260   explicit HexFloat(T f) : value_(f) {}
    261 
    262   T value() const { return value_; }
    263   void set_value(T f) { value_ = f; }
    264 
    265   // These are all written like this because it is convenient to have
    266   // compile-time constants for all of these values.
    267 
    268   // Pass-through values to save typing.
    269   static const uint32_t num_used_bits = Traits::num_used_bits;
    270   static const uint32_t exponent_bias = Traits::exponent_bias;
    271   static const uint32_t num_exponent_bits = Traits::num_exponent_bits;
    272   static const uint32_t num_fraction_bits = Traits::num_fraction_bits;
    273 
    274   // Number of bits to shift left to set the highest relevant bit.
    275   static const uint32_t top_bit_left_shift = num_used_bits - 1;
    276   // How many nibbles (hex characters) the fractional part takes up.
    277   static const uint32_t fraction_nibbles = (num_fraction_bits + 3) / 4;
    278   // If the fractional part does not fit evenly into a hex character (4-bits)
    279   // then we have to left-shift to get rid of leading 0s. This is the amount
    280   // we have to shift (might be 0).
    281   static const uint32_t num_overflow_bits =
    282       fraction_nibbles * 4 - num_fraction_bits;
    283 
    284   // The representation of the fraction, not the actual bits. This
    285   // includes the leading bit that is usually implicit.
    286   static const uint_type fraction_represent_mask =
    287       spvutils::SetBits<uint_type, 0,
    288                         num_fraction_bits + num_overflow_bits>::get;
    289 
    290   // The topmost bit in the nibble-aligned fraction.
    291   static const uint_type fraction_top_bit =
    292       uint_type(1) << (num_fraction_bits + num_overflow_bits - 1);
    293 
    294   // The least significant bit in the exponent, which is also the bit
    295   // immediately to the left of the significand.
    296   static const uint_type first_exponent_bit = uint_type(1)
    297                                               << (num_fraction_bits);
    298 
    299   // The mask for the encoded fraction. It does not include the
    300   // implicit bit.
    301   static const uint_type fraction_encode_mask =
    302       spvutils::SetBits<uint_type, 0, num_fraction_bits>::get;
    303 
    304   // The bit that is used as a sign.
    305   static const uint_type sign_mask = uint_type(1) << top_bit_left_shift;
    306 
    307   // The bits that represent the exponent.
    308   static const uint_type exponent_mask =
    309       spvutils::SetBits<uint_type, num_fraction_bits, num_exponent_bits>::get;
    310 
    311   // How far left the exponent is shifted.
    312   static const uint32_t exponent_left_shift = num_fraction_bits;
    313 
    314   // How far from the right edge the fraction is shifted.
    315   static const uint32_t fraction_right_shift =
    316       static_cast<uint32_t>(sizeof(uint_type) * 8) - num_fraction_bits;
    317 
    318   // The maximum representable unbiased exponent.
    319   static const int_type max_exponent =
    320       (exponent_mask >> num_fraction_bits) - exponent_bias;
    321   // The minimum representable exponent for normalized numbers.
    322   static const int_type min_exponent = -static_cast<int_type>(exponent_bias);
    323 
    324   // Returns the bits associated with the value.
    325   uint_type getBits() const { return spvutils::BitwiseCast<uint_type>(value_); }
    326 
    327   // Returns the bits associated with the value, without the leading sign bit.
    328   uint_type getUnsignedBits() const {
    329     return static_cast<uint_type>(spvutils::BitwiseCast<uint_type>(value_) &
    330                                   ~sign_mask);
    331   }
    332 
    333   // Returns the bits associated with the exponent, shifted to start at the
    334   // lsb of the type.
    335   const uint_type getExponentBits() const {
    336     return static_cast<uint_type>((getBits() & exponent_mask) >>
    337                                   num_fraction_bits);
    338   }
    339 
    340   // Returns the exponent in unbiased form. This is the exponent in the
    341   // human-friendly form.
    342   const int_type getUnbiasedExponent() const {
    343     return static_cast<int_type>(getExponentBits() - exponent_bias);
    344   }
    345 
    346   // Returns just the significand bits from the value.
    347   const uint_type getSignificandBits() const {
    348     return getBits() & fraction_encode_mask;
    349   }
    350 
    351   // If the number was normalized, returns the unbiased exponent.
    352   // If the number was denormal, normalize the exponent first.
    353   const int_type getUnbiasedNormalizedExponent() const {
    354     if ((getBits() & ~sign_mask) == 0) {  // special case if everything is 0
    355       return 0;
    356     }
    357     int_type exp = getUnbiasedExponent();
    358     if (exp == min_exponent) {  // We are in denorm land.
    359       uint_type significand_bits = getSignificandBits();
    360       while ((significand_bits & (first_exponent_bit >> 1)) == 0) {
    361         significand_bits = static_cast<uint_type>(significand_bits << 1);
    362         exp = static_cast<int_type>(exp - 1);
    363       }
    364       significand_bits &= fraction_encode_mask;
    365     }
    366     return exp;
    367   }
    368 
    369   // Returns the signficand after it has been normalized.
    370   const uint_type getNormalizedSignificand() const {
    371     int_type unbiased_exponent = getUnbiasedNormalizedExponent();
    372     uint_type significand = getSignificandBits();
    373     for (int_type i = unbiased_exponent; i <= min_exponent; ++i) {
    374       significand = static_cast<uint_type>(significand << 1);
    375     }
    376     significand &= fraction_encode_mask;
    377     return significand;
    378   }
    379 
    380   // Returns true if this number represents a negative value.
    381   bool isNegative() const { return (getBits() & sign_mask) != 0; }
    382 
    383   // Sets this HexFloat from the individual components.
    384   // Note this assumes EVERY significand is normalized, and has an implicit
    385   // leading one. This means that the only way that this method will set 0,
    386   // is if you set a number so denormalized that it underflows.
    387   // Do not use this method with raw bits extracted from a subnormal number,
    388   // since subnormals do not have an implicit leading 1 in the significand.
    389   // The significand is also expected to be in the
    390   // lowest-most num_fraction_bits of the uint_type.
    391   // The exponent is expected to be unbiased, meaning an exponent of
    392   // 0 actually means 0.
    393   // If underflow_round_up is set, then on underflow, if a number is non-0
    394   // and would underflow, we round up to the smallest denorm.
    395   void setFromSignUnbiasedExponentAndNormalizedSignificand(
    396       bool negative, int_type exponent, uint_type significand,
    397       bool round_denorm_up) {
    398     bool significand_is_zero = significand == 0;
    399 
    400     if (exponent <= min_exponent) {
    401       // If this was denormalized, then we have to shift the bit on, meaning
    402       // the significand is not zero.
    403       significand_is_zero = false;
    404       significand |= first_exponent_bit;
    405       significand = static_cast<uint_type>(significand >> 1);
    406     }
    407 
    408     while (exponent < min_exponent) {
    409       significand = static_cast<uint_type>(significand >> 1);
    410       ++exponent;
    411     }
    412 
    413     if (exponent == min_exponent) {
    414       if (significand == 0 && !significand_is_zero && round_denorm_up) {
    415         significand = static_cast<uint_type>(0x1);
    416       }
    417     }
    418 
    419     uint_type new_value = 0;
    420     if (negative) {
    421       new_value = static_cast<uint_type>(new_value | sign_mask);
    422     }
    423     exponent = static_cast<int_type>(exponent + exponent_bias);
    424     assert(exponent >= 0);
    425 
    426     // put it all together
    427     exponent = static_cast<uint_type>((exponent << exponent_left_shift) &
    428                                       exponent_mask);
    429     significand = static_cast<uint_type>(significand & fraction_encode_mask);
    430     new_value = static_cast<uint_type>(new_value | (exponent | significand));
    431     value_ = BitwiseCast<T>(new_value);
    432   }
    433 
    434   // Increments the significand of this number by the given amount.
    435   // If this would spill the significand into the implicit bit,
    436   // carry is set to true and the significand is shifted to fit into
    437   // the correct location, otherwise carry is set to false.
    438   // All significands and to_increment are assumed to be within the bounds
    439   // for a valid significand.
    440   static uint_type incrementSignificand(uint_type significand,
    441                                         uint_type to_increment, bool* carry) {
    442     significand = static_cast<uint_type>(significand + to_increment);
    443     *carry = false;
    444     if (significand & first_exponent_bit) {
    445       *carry = true;
    446       // The implicit 1-bit will have carried, so we should zero-out the
    447       // top bit and shift back.
    448       significand = static_cast<uint_type>(significand & ~first_exponent_bit);
    449       significand = static_cast<uint_type>(significand >> 1);
    450     }
    451     return significand;
    452   }
    453 
    454   // These exist because MSVC throws warnings on negative right-shifts
    455   // even if they are not going to be executed. Eg:
    456   // constant_number < 0? 0: constant_number
    457   // These convert the negative left-shifts into right shifts.
    458 
    459   template <typename int_type>
    460   uint_type negatable_left_shift(int_type N, uint_type val)
    461   {
    462     if(N >= 0)
    463       return val << N;
    464 
    465     return val >> -N;
    466   }
    467 
    468   template <typename int_type>
    469   uint_type negatable_right_shift(int_type N, uint_type val)
    470   {
    471     if(N >= 0)
    472       return val >> N;
    473 
    474     return val << -N;
    475   }
    476 
    477   // Returns the significand, rounded to fit in a significand in
    478   // other_T. This is shifted so that the most significant
    479   // bit of the rounded number lines up with the most significant bit
    480   // of the returned significand.
    481   template <typename other_T>
    482   typename other_T::uint_type getRoundedNormalizedSignificand(
    483       round_direction dir, bool* carry_bit) {
    484     typedef typename other_T::uint_type other_uint_type;
    485     static const int_type num_throwaway_bits =
    486         static_cast<int_type>(num_fraction_bits) -
    487         static_cast<int_type>(other_T::num_fraction_bits);
    488 
    489     static const uint_type last_significant_bit =
    490         (num_throwaway_bits < 0)
    491             ? 0
    492             : negatable_left_shift(num_throwaway_bits, 1u);
    493     static const uint_type first_rounded_bit =
    494         (num_throwaway_bits < 1)
    495             ? 0
    496             : negatable_left_shift(num_throwaway_bits - 1, 1u);
    497 
    498     static const uint_type throwaway_mask_bits =
    499         num_throwaway_bits > 0 ? num_throwaway_bits : 0;
    500     static const uint_type throwaway_mask =
    501         spvutils::SetBits<uint_type, 0, throwaway_mask_bits>::get;
    502 
    503     *carry_bit = false;
    504     other_uint_type out_val = 0;
    505     uint_type significand = getNormalizedSignificand();
    506     // If we are up-casting, then we just have to shift to the right location.
    507     if (num_throwaway_bits <= 0) {
    508       out_val = static_cast<other_uint_type>(significand);
    509       uint_type shift_amount = static_cast<uint_type>(-num_throwaway_bits);
    510       out_val = static_cast<other_uint_type>(out_val << shift_amount);
    511       return out_val;
    512     }
    513 
    514     // If every non-representable bit is 0, then we don't have any casting to
    515     // do.
    516     if ((significand & throwaway_mask) == 0) {
    517       return static_cast<other_uint_type>(
    518           negatable_right_shift(num_throwaway_bits, significand));
    519     }
    520 
    521     bool round_away_from_zero = false;
    522     // We actually have to narrow the significand here, so we have to follow the
    523     // rounding rules.
    524     switch (dir) {
    525       case kRoundToZero:
    526         break;
    527       case kRoundToPositiveInfinity:
    528         round_away_from_zero = !isNegative();
    529         break;
    530       case kRoundToNegativeInfinity:
    531         round_away_from_zero = isNegative();
    532         break;
    533       case kRoundToNearestEven:
    534         // Have to round down, round bit is 0
    535         if ((first_rounded_bit & significand) == 0) {
    536           break;
    537         }
    538         if (((significand & throwaway_mask) & ~first_rounded_bit) != 0) {
    539           // If any subsequent bit of the rounded portion is non-0 then we round
    540           // up.
    541           round_away_from_zero = true;
    542           break;
    543         }
    544         // We are exactly half-way between 2 numbers, pick even.
    545         if ((significand & last_significant_bit) != 0) {
    546           // 1 for our last bit, round up.
    547           round_away_from_zero = true;
    548           break;
    549         }
    550         break;
    551     }
    552 
    553     if (round_away_from_zero) {
    554       return static_cast<other_uint_type>(
    555           negatable_right_shift(num_throwaway_bits, incrementSignificand(
    556               significand, last_significant_bit, carry_bit)));
    557     } else {
    558       return static_cast<other_uint_type>(
    559           negatable_right_shift(num_throwaway_bits, significand));
    560     }
    561   }
    562 
    563   // Casts this value to another HexFloat. If the cast is widening,
    564   // then round_dir is ignored. If the cast is narrowing, then
    565   // the result is rounded in the direction specified.
    566   // This number will retain Nan and Inf values.
    567   // It will also saturate to Inf if the number overflows, and
    568   // underflow to (0 or min depending on rounding) if the number underflows.
    569   template <typename other_T>
    570   void castTo(other_T& other, round_direction round_dir) {
    571     other = other_T(static_cast<typename other_T::native_type>(0));
    572     bool negate = isNegative();
    573     if (getUnsignedBits() == 0) {
    574       if (negate) {
    575         other.set_value(-other.value());
    576       }
    577       return;
    578     }
    579     uint_type significand = getSignificandBits();
    580     bool carried = false;
    581     typename other_T::uint_type rounded_significand =
    582         getRoundedNormalizedSignificand<other_T>(round_dir, &carried);
    583 
    584     int_type exponent = getUnbiasedExponent();
    585     if (exponent == min_exponent) {
    586       // If we are denormal, normalize the exponent, so that we can encode
    587       // easily.
    588       exponent = static_cast<int_type>(exponent + 1);
    589       for (uint_type check_bit = first_exponent_bit >> 1; check_bit != 0;
    590            check_bit = static_cast<uint_type>(check_bit >> 1)) {
    591         exponent = static_cast<int_type>(exponent - 1);
    592         if (check_bit & significand) break;
    593       }
    594     }
    595 
    596     bool is_nan =
    597         (getBits() & exponent_mask) == exponent_mask && significand != 0;
    598     bool is_inf =
    599         !is_nan &&
    600         ((exponent + carried) > static_cast<int_type>(other_T::exponent_bias) ||
    601          (significand == 0 && (getBits() & exponent_mask) == exponent_mask));
    602 
    603     // If we are Nan or Inf we should pass that through.
    604     if (is_inf) {
    605       other.set_value(BitwiseCast<typename other_T::underlying_type>(
    606           static_cast<typename other_T::uint_type>(
    607               (negate ? other_T::sign_mask : 0) | other_T::exponent_mask)));
    608       return;
    609     }
    610     if (is_nan) {
    611       typename other_T::uint_type shifted_significand;
    612       shifted_significand = static_cast<typename other_T::uint_type>(
    613           negatable_left_shift(
    614               static_cast<int_type>(other_T::num_fraction_bits) -
    615               static_cast<int_type>(num_fraction_bits), significand));
    616 
    617       // We are some sort of Nan. We try to keep the bit-pattern of the Nan
    618       // as close as possible. If we had to shift off bits so we are 0, then we
    619       // just set the last bit.
    620       other.set_value(BitwiseCast<typename other_T::underlying_type>(
    621           static_cast<typename other_T::uint_type>(
    622               (negate ? other_T::sign_mask : 0) | other_T::exponent_mask |
    623               (shifted_significand == 0 ? 0x1 : shifted_significand))));
    624       return;
    625     }
    626 
    627     bool round_underflow_up =
    628         isNegative() ? round_dir == kRoundToNegativeInfinity
    629                      : round_dir == kRoundToPositiveInfinity;
    630     typedef typename other_T::int_type other_int_type;
    631     // setFromSignUnbiasedExponentAndNormalizedSignificand will
    632     // zero out any underflowing value (but retain the sign).
    633     other.setFromSignUnbiasedExponentAndNormalizedSignificand(
    634         negate, static_cast<other_int_type>(exponent), rounded_significand,
    635         round_underflow_up);
    636     return;
    637   }
    638 
    639  private:
    640   T value_;
    641 
    642   static_assert(num_used_bits ==
    643                     Traits::num_exponent_bits + Traits::num_fraction_bits + 1,
    644                 "The number of bits do not fit");
    645   static_assert(sizeof(T) == sizeof(uint_type), "The type sizes do not match");
    646 };
    647 
    648 // Returns 4 bits represented by the hex character.
    649 inline uint8_t get_nibble_from_character(int character) {
    650   const char* dec = "0123456789";
    651   const char* lower = "abcdef";
    652   const char* upper = "ABCDEF";
    653   const char* p = nullptr;
    654   if ((p = strchr(dec, character))) {
    655     return static_cast<uint8_t>(p - dec);
    656   } else if ((p = strchr(lower, character))) {
    657     return static_cast<uint8_t>(p - lower + 0xa);
    658   } else if ((p = strchr(upper, character))) {
    659     return static_cast<uint8_t>(p - upper + 0xa);
    660   }
    661 
    662   assert(false && "This was called with a non-hex character");
    663   return 0;
    664 }
    665 
    666 // Outputs the given HexFloat to the stream.
    667 template <typename T, typename Traits>
    668 std::ostream& operator<<(std::ostream& os, const HexFloat<T, Traits>& value) {
    669   typedef HexFloat<T, Traits> HF;
    670   typedef typename HF::uint_type uint_type;
    671   typedef typename HF::int_type int_type;
    672 
    673   static_assert(HF::num_used_bits != 0,
    674                 "num_used_bits must be non-zero for a valid float");
    675   static_assert(HF::num_exponent_bits != 0,
    676                 "num_exponent_bits must be non-zero for a valid float");
    677   static_assert(HF::num_fraction_bits != 0,
    678                 "num_fractin_bits must be non-zero for a valid float");
    679 
    680   const uint_type bits = spvutils::BitwiseCast<uint_type>(value.value());
    681   const char* const sign = (bits & HF::sign_mask) ? "-" : "";
    682   const uint_type exponent = static_cast<uint_type>(
    683       (bits & HF::exponent_mask) >> HF::num_fraction_bits);
    684 
    685   uint_type fraction = static_cast<uint_type>((bits & HF::fraction_encode_mask)
    686                                               << HF::num_overflow_bits);
    687 
    688   const bool is_zero = exponent == 0 && fraction == 0;
    689   const bool is_denorm = exponent == 0 && !is_zero;
    690 
    691   // exponent contains the biased exponent we have to convert it back into
    692   // the normal range.
    693   int_type int_exponent = static_cast<int_type>(exponent - HF::exponent_bias);
    694   // If the number is all zeros, then we actually have to NOT shift the
    695   // exponent.
    696   int_exponent = is_zero ? 0 : int_exponent;
    697 
    698   // If we are denorm, then start shifting, and decreasing the exponent until
    699   // our leading bit is 1.
    700 
    701   if (is_denorm) {
    702     while ((fraction & HF::fraction_top_bit) == 0) {
    703       fraction = static_cast<uint_type>(fraction << 1);
    704       int_exponent = static_cast<int_type>(int_exponent - 1);
    705     }
    706     // Since this is denormalized, we have to consume the leading 1 since it
    707     // will end up being implicit.
    708     fraction = static_cast<uint_type>(fraction << 1);  // eat the leading 1
    709     fraction &= HF::fraction_represent_mask;
    710   }
    711 
    712   uint_type fraction_nibbles = HF::fraction_nibbles;
    713   // We do not have to display any trailing 0s, since this represents the
    714   // fractional part.
    715   while (fraction_nibbles > 0 && (fraction & 0xF) == 0) {
    716     // Shift off any trailing values;
    717     fraction = static_cast<uint_type>(fraction >> 4);
    718     --fraction_nibbles;
    719   }
    720 
    721   const auto saved_flags = os.flags();
    722   const auto saved_fill = os.fill();
    723 
    724   os << sign << "0x" << (is_zero ? '0' : '1');
    725   if (fraction_nibbles) {
    726     // Make sure to keep the leading 0s in place, since this is the fractional
    727     // part.
    728     os << "." << std::setw(static_cast<int>(fraction_nibbles))
    729        << std::setfill('0') << std::hex << fraction;
    730   }
    731   os << "p" << std::dec << (int_exponent >= 0 ? "+" : "") << int_exponent;
    732 
    733   os.flags(saved_flags);
    734   os.fill(saved_fill);
    735 
    736   return os;
    737 }
    738 
    739 // Returns true if negate_value is true and the next character on the
    740 // input stream is a plus or minus sign.  In that case we also set the fail bit
    741 // on the stream and set the value to the zero value for its type.
    742 template <typename T, typename Traits>
    743 inline bool RejectParseDueToLeadingSign(std::istream& is, bool negate_value,
    744                                         HexFloat<T, Traits>& value) {
    745   if (negate_value) {
    746     auto next_char = is.peek();
    747     if (next_char == '-' || next_char == '+') {
    748       // Fail the parse.  Emulate standard behaviour by setting the value to
    749       // the zero value, and set the fail bit on the stream.
    750       value = HexFloat<T, Traits>(typename HexFloat<T, Traits>::uint_type(0));
    751       is.setstate(std::ios_base::failbit);
    752       return true;
    753     }
    754   }
    755   return false;
    756 }
    757 
    758 // Parses a floating point number from the given stream and stores it into the
    759 // value parameter.
    760 // If negate_value is true then the number may not have a leading minus or
    761 // plus, and if it successfully parses, then the number is negated before
    762 // being stored into the value parameter.
    763 // If the value cannot be correctly parsed or overflows the target floating
    764 // point type, then set the fail bit on the stream.
    765 // TODO(dneto): Promise C++11 standard behavior in how the value is set in
    766 // the error case, but only after all target platforms implement it correctly.
    767 // In particular, the Microsoft C++ runtime appears to be out of spec.
    768 template <typename T, typename Traits>
    769 inline std::istream& ParseNormalFloat(std::istream& is, bool negate_value,
    770                                       HexFloat<T, Traits>& value) {
    771   if (RejectParseDueToLeadingSign(is, negate_value, value)) {
    772     return is;
    773   }
    774   T val;
    775   is >> val;
    776   if (negate_value) {
    777     val = -val;
    778   }
    779   value.set_value(val);
    780   // In the failure case, map -0.0 to 0.0.
    781   if (is.fail() && value.getUnsignedBits() == 0u) {
    782     value = HexFloat<T, Traits>(typename HexFloat<T, Traits>::uint_type(0));
    783   }
    784   if (val.isInfinity()) {
    785     // Fail the parse.  Emulate standard behaviour by setting the value to
    786     // the closest normal value, and set the fail bit on the stream.
    787     value.set_value((value.isNegative() | negate_value) ? T::lowest()
    788                                                         : T::max());
    789     is.setstate(std::ios_base::failbit);
    790   }
    791   return is;
    792 }
    793 
    794 // Specialization of ParseNormalFloat for FloatProxy<Float16> values.
    795 // This will parse the float as it were a 32-bit floating point number,
    796 // and then round it down to fit into a Float16 value.
    797 // The number is rounded towards zero.
    798 // If negate_value is true then the number may not have a leading minus or
    799 // plus, and if it successfully parses, then the number is negated before
    800 // being stored into the value parameter.
    801 // If the value cannot be correctly parsed or overflows the target floating
    802 // point type, then set the fail bit on the stream.
    803 // TODO(dneto): Promise C++11 standard behavior in how the value is set in
    804 // the error case, but only after all target platforms implement it correctly.
    805 // In particular, the Microsoft C++ runtime appears to be out of spec.
    806 template <>
    807 inline std::istream&
    808 ParseNormalFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>(
    809     std::istream& is, bool negate_value,
    810     HexFloat<FloatProxy<Float16>, HexFloatTraits<FloatProxy<Float16>>>& value) {
    811   // First parse as a 32-bit float.
    812   HexFloat<FloatProxy<float>> float_val(0.0f);
    813   ParseNormalFloat(is, negate_value, float_val);
    814 
    815   // Then convert to 16-bit float, saturating at infinities, and
    816   // rounding toward zero.
    817   float_val.castTo(value, kRoundToZero);
    818 
    819   // Overflow on 16-bit behaves the same as for 32- and 64-bit: set the
    820   // fail bit and set the lowest or highest value.
    821   if (Float16::isInfinity(value.value().getAsFloat())) {
    822     value.set_value(value.isNegative() ? Float16::lowest() : Float16::max());
    823     is.setstate(std::ios_base::failbit);
    824   }
    825   return is;
    826 }
    827 
    828 // Reads a HexFloat from the given stream.
    829 // If the float is not encoded as a hex-float then it will be parsed
    830 // as a regular float.
    831 // This may fail if your stream does not support at least one unget.
    832 // Nan values can be encoded with "0x1.<not zero>p+exponent_bias".
    833 // This would normally overflow a float and round to
    834 // infinity but this special pattern is the exact representation for a NaN,
    835 // and therefore is actually encoded as the correct NaN. To encode inf,
    836 // either 0x0p+exponent_bias can be specified or any exponent greater than
    837 // exponent_bias.
    838 // Examples using IEEE 32-bit float encoding.
    839 //    0x1.0p+128 (+inf)
    840 //    -0x1.0p-128 (-inf)
    841 //
    842 //    0x1.1p+128 (+Nan)
    843 //    -0x1.1p+128 (-Nan)
    844 //
    845 //    0x1p+129 (+inf)
    846 //    -0x1p+129 (-inf)
    847 template <typename T, typename Traits>
    848 std::istream& operator>>(std::istream& is, HexFloat<T, Traits>& value) {
    849   using HF = HexFloat<T, Traits>;
    850   using uint_type = typename HF::uint_type;
    851   using int_type = typename HF::int_type;
    852 
    853   value.set_value(static_cast<typename HF::native_type>(0.f));
    854 
    855   if (is.flags() & std::ios::skipws) {
    856     // If the user wants to skip whitespace , then we should obey that.
    857     while (std::isspace(is.peek())) {
    858       is.get();
    859     }
    860   }
    861 
    862   auto next_char = is.peek();
    863   bool negate_value = false;
    864 
    865   if (next_char != '-' && next_char != '0') {
    866     return ParseNormalFloat(is, negate_value, value);
    867   }
    868 
    869   if (next_char == '-') {
    870     negate_value = true;
    871     is.get();
    872     next_char = is.peek();
    873   }
    874 
    875   if (next_char == '0') {
    876     is.get();  // We may have to unget this.
    877     auto maybe_hex_start = is.peek();
    878     if (maybe_hex_start != 'x' && maybe_hex_start != 'X') {
    879       is.unget();
    880       return ParseNormalFloat(is, negate_value, value);
    881     } else {
    882       is.get();  // Throw away the 'x';
    883     }
    884   } else {
    885     return ParseNormalFloat(is, negate_value, value);
    886   }
    887 
    888   // This "looks" like a hex-float so treat it as one.
    889   bool seen_p = false;
    890   bool seen_dot = false;
    891   uint_type fraction_index = 0;
    892 
    893   uint_type fraction = 0;
    894   int_type exponent = HF::exponent_bias;
    895 
    896   // Strip off leading zeros so we don't have to special-case them later.
    897   while ((next_char = is.peek()) == '0') {
    898     is.get();
    899   }
    900 
    901   bool is_denorm =
    902       true;  // Assume denorm "representation" until we hear otherwise.
    903              // NB: This does not mean the value is actually denorm,
    904              // it just means that it was written 0.
    905   bool bits_written = false;  // Stays false until we write a bit.
    906   while (!seen_p && !seen_dot) {
    907     // Handle characters that are left of the fractional part.
    908     if (next_char == '.') {
    909       seen_dot = true;
    910     } else if (next_char == 'p') {
    911       seen_p = true;
    912     } else if (::isxdigit(next_char)) {
    913       // We know this is not denormalized since we have stripped all leading
    914       // zeroes and we are not a ".".
    915       is_denorm = false;
    916       int number = get_nibble_from_character(next_char);
    917       for (int i = 0; i < 4; ++i, number <<= 1) {
    918         uint_type write_bit = (number & 0x8) ? 0x1 : 0x0;
    919         if (bits_written) {
    920           // If we are here the bits represented belong in the fractional
    921           // part of the float, and we have to adjust the exponent accordingly.
    922           fraction = static_cast<uint_type>(
    923               fraction |
    924               static_cast<uint_type>(
    925                   write_bit << (HF::top_bit_left_shift - fraction_index++)));
    926           exponent = static_cast<int_type>(exponent + 1);
    927         }
    928         bits_written |= write_bit != 0;
    929       }
    930     } else {
    931       // We have not found our exponent yet, so we have to fail.
    932       is.setstate(std::ios::failbit);
    933       return is;
    934     }
    935     is.get();
    936     next_char = is.peek();
    937   }
    938   bits_written = false;
    939   while (seen_dot && !seen_p) {
    940     // Handle only fractional parts now.
    941     if (next_char == 'p') {
    942       seen_p = true;
    943     } else if (::isxdigit(next_char)) {
    944       int number = get_nibble_from_character(next_char);
    945       for (int i = 0; i < 4; ++i, number <<= 1) {
    946         uint_type write_bit = (number & 0x8) ? 0x01 : 0x00;
    947         bits_written |= write_bit != 0;
    948         if (is_denorm && !bits_written) {
    949           // Handle modifying the exponent here this way we can handle
    950           // an arbitrary number of hex values without overflowing our
    951           // integer.
    952           exponent = static_cast<int_type>(exponent - 1);
    953         } else {
    954           fraction = static_cast<uint_type>(
    955               fraction |
    956               static_cast<uint_type>(
    957                   write_bit << (HF::top_bit_left_shift - fraction_index++)));
    958         }
    959       }
    960     } else {
    961       // We still have not found our 'p' exponent yet, so this is not a valid
    962       // hex-float.
    963       is.setstate(std::ios::failbit);
    964       return is;
    965     }
    966     is.get();
    967     next_char = is.peek();
    968   }
    969 
    970   bool seen_sign = false;
    971   int8_t exponent_sign = 1;
    972   int_type written_exponent = 0;
    973   while (true) {
    974     if ((next_char == '-' || next_char == '+')) {
    975       if (seen_sign) {
    976         is.setstate(std::ios::failbit);
    977         return is;
    978       }
    979       seen_sign = true;
    980       exponent_sign = (next_char == '-') ? -1 : 1;
    981     } else if (::isdigit(next_char)) {
    982       // Hex-floats express their exponent as decimal.
    983       written_exponent = static_cast<int_type>(written_exponent * 10);
    984       written_exponent =
    985           static_cast<int_type>(written_exponent + (next_char - '0'));
    986     } else {
    987       break;
    988     }
    989     is.get();
    990     next_char = is.peek();
    991   }
    992 
    993   written_exponent = static_cast<int_type>(written_exponent * exponent_sign);
    994   exponent = static_cast<int_type>(exponent + written_exponent);
    995 
    996   bool is_zero = is_denorm && (fraction == 0);
    997   if (is_denorm && !is_zero) {
    998     fraction = static_cast<uint_type>(fraction << 1);
    999     exponent = static_cast<int_type>(exponent - 1);
   1000   } else if (is_zero) {
   1001     exponent = 0;
   1002   }
   1003 
   1004   if (exponent <= 0 && !is_zero) {
   1005     fraction = static_cast<uint_type>(fraction >> 1);
   1006     fraction |= static_cast<uint_type>(1) << HF::top_bit_left_shift;
   1007   }
   1008 
   1009   fraction = (fraction >> HF::fraction_right_shift) & HF::fraction_encode_mask;
   1010 
   1011   const int_type max_exponent =
   1012       SetBits<uint_type, 0, HF::num_exponent_bits>::get;
   1013 
   1014   // Handle actual denorm numbers
   1015   while (exponent < 0 && !is_zero) {
   1016     fraction = static_cast<uint_type>(fraction >> 1);
   1017     exponent = static_cast<int_type>(exponent + 1);
   1018 
   1019     fraction &= HF::fraction_encode_mask;
   1020     if (fraction == 0) {
   1021       // We have underflowed our fraction. We should clamp to zero.
   1022       is_zero = true;
   1023       exponent = 0;
   1024     }
   1025   }
   1026 
   1027   // We have overflowed so we should be inf/-inf.
   1028   if (exponent > max_exponent) {
   1029     exponent = max_exponent;
   1030     fraction = 0;
   1031   }
   1032 
   1033   uint_type output_bits = static_cast<uint_type>(
   1034       static_cast<uint_type>(negate_value ? 1 : 0) << HF::top_bit_left_shift);
   1035   output_bits |= fraction;
   1036 
   1037   uint_type shifted_exponent = static_cast<uint_type>(
   1038       static_cast<uint_type>(exponent << HF::exponent_left_shift) &
   1039       HF::exponent_mask);
   1040   output_bits |= shifted_exponent;
   1041 
   1042   T output_float = spvutils::BitwiseCast<T>(output_bits);
   1043   value.set_value(output_float);
   1044 
   1045   return is;
   1046 }
   1047 
   1048 // Writes a FloatProxy value to a stream.
   1049 // Zero and normal numbers are printed in the usual notation, but with
   1050 // enough digits to fully reproduce the value.  Other values (subnormal,
   1051 // NaN, and infinity) are printed as a hex float.
   1052 template <typename T>
   1053 std::ostream& operator<<(std::ostream& os, const FloatProxy<T>& value) {
   1054   auto float_val = value.getAsFloat();
   1055   switch (std::fpclassify(float_val)) {
   1056     case FP_ZERO:
   1057     case FP_NORMAL: {
   1058       auto saved_precision = os.precision();
   1059       os.precision(std::numeric_limits<T>::digits10);
   1060       os << float_val;
   1061       os.precision(saved_precision);
   1062     } break;
   1063     default:
   1064       os << HexFloat<FloatProxy<T>>(value);
   1065       break;
   1066   }
   1067   return os;
   1068 }
   1069 
   1070 template <>
   1071 inline std::ostream& operator<<<Float16>(std::ostream& os,
   1072                                          const FloatProxy<Float16>& value) {
   1073   os << HexFloat<FloatProxy<Float16>>(value);
   1074   return os;
   1075 }
   1076 }
   1077 
   1078 #endif  // LIBSPIRV_UTIL_HEX_FLOAT_H_
   1079