Home | History | Annotate | Download | only in fixedpoint
      1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // fixedpoint.h: fixed-point arithmetic, with basic operations and
     16 // a few math functions such as tanh.
     17 
     18 #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
     19 #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_
     20 
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <type_traits>

#include "../internal/common.h"
     25 
     26 namespace gemmlowp {
     27 
     28 // Part 1: Low-level integer-arithmetic primitives.
     29 // The implementations here are generic implementations valid for
     30 // scalar types (e.g. std::int32_t). Architecture-specific SIMD types
     31 // (e.g. NEON int32x4_t) may be supported by providing
     32 // specializations for them in separate files.
     33 //
     34 // The purpose of these primitives is two-fold:
     35 //  - They will be used to implement higher-level fixed-point
     36 //    abstractions, namely the FixedPoint class and its arithmetic
     37 //    operators.
     38 //  - They will be directly used to implement some more involved
     39 //    fixed-point computations, e.g. the fixed-point implementation
     40 //    of math functions such as tanh.
     41 
     42 // Some compile-time traits around raw types to handle SIMD aspects:
     43 // number of lanes, underlying scalar type.
     44 template <typename tIntegerType>
     45 struct FixedPointRawTypeTraits {};
     46 
     47 template <>
     48 struct FixedPointRawTypeTraits<std::int32_t> {
     49   typedef std::int32_t ScalarRawType;
     50   static const int kLanes = 1;
     51 };
     52 
     53 // Returns a SIMD value duplicating a scalar value across all lanes.
     54 template <typename tRawType>
     55 tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
     56   return x;
     57 }
     58 
     59 // Plain bit-wise AND
     60 template <typename tIntegerType>
     61 tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
     62   return a & b;
     63 }
     64 
     65 // Plain bit-wise OR
     66 template <typename tIntegerType>
     67 tIntegerType BitOr(tIntegerType a, tIntegerType b) {
     68   return a | b;
     69 }
     70 
     71 // Plain bit-wise XOR
     72 template <typename tIntegerType>
     73 tIntegerType BitXor(tIntegerType a, tIntegerType b) {
     74   return a ^ b;
     75 }
     76 
     77 // Plain bit-wise NOT
     78 template <typename tIntegerType>
     79 tIntegerType BitNot(tIntegerType a) {
     80   return ~a;
     81 }
     82 
     83 // Integer addition. Not saturating. Overflow is undefined behavior.
     84 template <typename tIntegerType>
     85 tIntegerType Add(tIntegerType a, tIntegerType b) {
     86   return a + b;
     87 }
     88 
     89 // Integer subtraction. Not saturating. Overflow is undefined behavior.
     90 template <typename tIntegerType>
     91 tIntegerType Mul(tIntegerType a, tIntegerType b) {
     92   return a * b;
     93 }
     94 
     95 template <typename tIntegerType>
     96 tIntegerType Sub(tIntegerType a, tIntegerType b) {
     97   return a - b;
     98 }
     99 
    100 // Integer unary negative. Not saturating. Overflow is undefined behavior.
    101 template <typename tIntegerType>
    102 tIntegerType Neg(tIntegerType a) {
    103   return -a;
    104 }
    105 
    106 // Integer arithmetic left-shift, equivalent to multiplying with a
    107 // power of two. Not saturating. Overflow is undefined behavior.
    108 template <typename tIntegerType>
    109 tIntegerType ShiftLeft(tIntegerType a, int offset) {
    110   return a << offset;
    111 }
    112 
    113 // Integer arithmetic right-shift. Not rounding.
    114 // Relying on implementation-defined, but in-practice-consistent,
    115 // C++ compiler behavior.
    116 template <typename tIntegerType>
    117 tIntegerType ShiftRight(tIntegerType a, int offset) {
    118   return a >> offset;
    119 }
    120 
    121 // Each bit of the result is set to the corresponding bit of either then_val or
    122 // else_val depending on whether the corresponding bit of if_mask is set.
    123 // Equivalent to the VBSL instruction in ARM NEON.
    124 template <typename tIntegerType>
    125 tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
    126                              tIntegerType else_val) {
    127   return BitXor(BitAnd(if_mask, then_val), BitAnd(BitNot(if_mask), else_val));
    128 }
    129 
    130 // For each input scalar, the corresponding bits of the result are set if the
    131 // input scalar is non-zero.
    132 template <typename tIntegerType>
    133 tIntegerType MaskIfNonZero(tIntegerType a) {
    134   static const tIntegerType zero = 0;
    135   return a ? BitNot(zero) : zero;
    136 }
    137 
    138 // For each input scalar, the corresponding bits of the result are set if the
    139 // input scalar is zero.
    140 template <typename tIntegerType>
    141 tIntegerType MaskIfZero(tIntegerType a) {
    142   return MaskIfNonZero<tIntegerType>(!a);
    143 }
    144 
    145 // For each pair of input scalars, the corresponding bits of the result are
    146 // set if the input scalars are equal.
    147 template <typename tIntegerType>
    148 tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
    149   return MaskIfNonZero<tIntegerType>(a == b);
    150 }
    151 
    152 // For each pair of input scalars, the corresponding bits of the result are
    153 // set if the input scalars are not equal.
    154 template <typename tIntegerType>
    155 tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
    156   return MaskIfNonZero<tIntegerType>(a != b);
    157 }
    158 
    159 // For each pair of input scalars, the corresponding bits of the result are
    160 // set if the input scalars a, b satisfy a > b.
    161 template <typename tIntegerType>
    162 tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
    163   return MaskIfNonZero<tIntegerType>(a > b);
    164 }
    165 
    166 // For each pair of input scalars, the corresponding bits of the result are
    167 // set if the input scalars a, b satisfy a >= b.
    168 template <typename tIntegerType>
    169 tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
    170   return MaskIfNonZero<tIntegerType>(a >= b);
    171 }
    172 
    173 // For each pair of input scalars, the corresponding bits of the result are
    174 // set if the input scalars a, b satisfy a < b.
    175 template <typename tIntegerType>
    176 tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
    177   return MaskIfNonZero<tIntegerType>(a < b);
    178 }
    179 
    180 // For each pair of input scalars, the corresponding bits of the result are
    181 // set if the input scalars a, b satisfy a <= b.
    182 template <typename tIntegerType>
    183 tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
    184   return MaskIfNonZero<tIntegerType>(a <= b);
    185 }
    186 
    187 // Returns true if all of the input scalars are nonzero.
    188 // This function may currently assume that each of the input scalars has either
    189 // all or none of its bits set. Otherwise, its behavior is currently undefined.
    190 template <typename tIntegerType>
    191 bool All(tIntegerType a) {
    192   return a;
    193 }
    194 
    195 // Returns true if any of the input scalars are nonzero.
    196 // This function may currently assume that each of the input scalars has either
    197 // all or none of its bits set. Otherwise, its behavior is currently undefined.
    198 template <typename tIntegerType>
    199 bool Any(tIntegerType a) {
    200   return a;
    201 }
    202 
    203 // Returns (a+b)/2, rounded to the nearest integer.
    204 // Equivalent to VRHADD in the ARM NEON instruction set.
    205 template <typename IntegerType>
    206 IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
    207   static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
    208   return a;
    209 }
    210 
    211 template <>
    212 inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
    213   std::int64_t a64 = a;
    214   std::int64_t b64 = b;
    215   std::int64_t sum = a64 + b64;
    216   std::int64_t sign = sum >= 0 ? 1 : -1;
    217   return static_cast<std::int32_t>((sum + sign) / 2);
    218 }
    219 
    220 // Returns the integer that represents the product of two fixed-point
    221 // numbers, interpreting all integers as fixed-point values in the
    222 // interval [-1, 1), rounding to the nearest value, and saturating
    223 // -1 * -1 to the maximum value (since 1 is not in the half-open
    224 // interval [-1, 1)).
    225 //
    226 // [The explanation below specializes to std::int32_t for example purpose.]
    227 //
    228 // The mapping between IntegerType and the interval [-1, 1) is unique and
    229 // implied by IntegerType, which is assumed to be signed. For example,
    230 // for IntegerType==std::int32_t, the mapping is
    231 //   real_value = integer_value / 2^31.
    232 // So in this case, and leaving aside rounding and saturating, this
    233 // function computes ((a / 2^31) * (b / 2^31)) * 2^31, which simplifies to
    234 //   (a * b) / 2^31.
    235 //
    236 // The 'doubling' part in the name of this function comes from the fact that
    237 // this operation is very close to a "multiply-high" operation, keeping only
    238 // the top half bits, except that that would be effectively computing
    239 //   (a * b) / 2^32,
    240 // so here we are computing 2x that, since
    241 //   1/2^31 = 2 * 1/2^32.
    242 // The idea is to use all of the available 32 bits in the destination int32
    243 // value.
    244 //
    245 // [End of the explanation specializing to int32.]
    246 //
    247 // This is equivalent to the VQRDMULH instruction in ARM NEON.
    248 template <typename IntegerType>
    249 IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
    250   static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
    251   return a;
    252 }
    253 
    254 // This function implements the same computation as the ARMv7 NEON VQRDMULH
    255 // instruction.
    256 template <>
    257 inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
    258                                                       std::int32_t b) {
    259   bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min();
    260   std::int64_t a_64(a);
    261   std::int64_t b_64(b);
    262   std::int64_t ab_64 = a_64 * b_64;
    263   std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
    264   std::int32_t ab_x2_high32 =
    265       static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31));
    266   return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
    267 }
    268 
    269 // Correctly-rounded-to-nearest division by a power-of-two.
    270 // Also known as a rounding arithmetic right shift.
    271 template <typename IntegerType>
    272 inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) {
    273   using ScalarIntegerType =
    274       typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
    275   static_assert(std::is_same<ScalarIntegerType, std::int32_t>::value,
    276                 "Currently only supporting int32 scalar and SIMD types");
    277   assert(exponent >= 0);
    278   assert(exponent <= 31);
    279   const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
    280   const IntegerType zero = Dup<IntegerType>(0);
    281   const IntegerType one = Dup<IntegerType>(1);
    282   const IntegerType remainder = BitAnd(x, mask);
    283   const IntegerType threshold =
    284       Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one));
    285   return Add(ShiftRight(x, exponent),
    286              BitAnd(MaskIfGreaterThan(remainder, threshold), one));
    287 }
    288 
    289 // Returns the product of a run-time integer value by a compile-time power
    290 // of two, with either a positive exponent (equivalent to an arithmetic
    291 // left shift, saturating) or a negative exponent (equivalent to an arithmetic
    292 // right shift, rounding to nearest).
    293 template <int Exponent, typename IntegerType,
    294           int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
    295 struct ImplSaturatingRoundingMultiplyByPOT {};
    296 
    297 template <int Exponent, typename IntegerType>
    298 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> {
    299   static IntegerType eval(IntegerType x) { return x; }
    300 };
    301 
    302 template <int Exponent, typename IntegerType>
    303 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> {
    304   static IntegerType eval(IntegerType x) {
    305     using ScalarIntegerType =
    306         typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
    307     static_assert(std::is_same<ScalarIntegerType, std::int32_t>::value,
    308                   "Currently only supporting int32 scalar and SIMD types");
    309     const IntegerType min =
    310         Dup<IntegerType>(std::numeric_limits<std::int32_t>::min());
    311     const IntegerType max =
    312         Dup<IntegerType>(std::numeric_limits<std::int32_t>::max());
    313 
    314     const std::int32_t threshold = ((1 << (31 - Exponent)) - 1);
    315     const IntegerType positive_mask =
    316         MaskIfGreaterThan(x, Dup<IntegerType>(threshold));
    317     const IntegerType negative_mask =
    318         MaskIfLessThan(x, Dup<IntegerType>(-threshold));
    319 
    320     IntegerType result = ShiftLeft(x, Exponent);
    321     result = SelectUsingMask(positive_mask, max, result);
    322     result = SelectUsingMask(negative_mask, min, result);
    323     return result;
    324   }
    325 };
    326 
    327 template <int Exponent, typename IntegerType>
    328 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> {
    329   static IntegerType eval(IntegerType x) {
    330     return RoundingDivideByPOT<IntegerType>(x, -Exponent);
    331   }
    332 };
    333 
    334 template <int Exponent, typename IntegerType>
    335 IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) {
    336   return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x);
    337 }
    338 
    339 // Part 2: the FixedPoint class.
    340 
    341 // A FixedPoint object represents a fixed-point value stored in the underlying
    342 // integer type tRawType, if tRawType is a plain scalar integer type.
    343 // Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which
    344 // case a FixedPoint object represents a corresponding SIMD vector of fixed
    345 // point values.
    346 //
    347 // tIntegerBits describes the range of the fixed-point format: if
    348 // tIntegerBits == m then the range of representable values is the half-open
    349 // interval [-2^m; 2^m) where the open boundary on the right side means that
    350 // 2^m is not representable (how close the maximum representable value is to
    351 // it, depends on bit-depth of tRawType).
    352 //
    353 // In "Q format notation",
    354 //   https://en.wikipedia.org/wiki/Q_(number_format)
    355 // we are describing the format
    356 //   Qm.n
    357 // where
    358 //   m = tIntegerBits
    359 // and
    360 //   n = NumberOfBits(tRawType) - (m + 1)
    361 // Note that the (m + 1) in the above line is because we adopt the convention
    362 // that we count the integer bits exclusively of the sign bit; so (m + 1) is
    363 // the total number of integer bits inclusive of the sign bit.
    364 //
    365 // Accordingly, the number of integral representable values in our range
    366 //   [-2^m ; 2^m)
    367 // is equal to 2^(m+1).
    368 template <typename tRawType, int tIntegerBits>
    369 class FixedPoint {
    370  public:
    371   typedef tRawType RawType;
    372 
    373   typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
    374   typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
    375 
    376   static const int kTotalBits = 8 * sizeof(ScalarRawType);
    377   static const int kIntegerBits = tIntegerBits;
    378   static const int kFractionalBits = kTotalBits - 1 - kIntegerBits;
    379   static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
    380                 "bad IntegerBits");
    381 
    382   typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType;
    383 
    384   static const ScalarRawType ScalarRawMin() {
    385     return std::numeric_limits<ScalarRawType>::min();
    386   }
    387 
    388   static const ScalarRawType ScalarRawMax() {
    389     return std::numeric_limits<ScalarRawType>::max();
    390   }
    391 
    392   static const ScalarRawType RawMin() {
    393     return VectorFromScalar(ScalarRawMin());
    394   }
    395 
    396   static const ScalarRawType RawMax() {
    397     return VectorFromScalar(ScalarRawMax());
    398   }
    399 
    400   static FixedPoint FromRaw(RawType x) {
    401     FixedPoint retval;
    402     retval.raw() = x;
    403     return retval;
    404   }
    405 
    406   static FixedPoint FromScalarRaw(ScalarRawType x) {
    407     FixedPoint retval;
    408     retval.raw() = Dup<RawType>(x);
    409     return retval;
    410   }
    411 
    412   static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) {
    413     return FromScalarRaw(x.raw());
    414   }
    415 
    416   template <int Exponent>
    417   static FixedPoint ConstantPOT() {
    418     static const int kOffset = kFractionalBits + Exponent;
    419     static_assert(
    420         kOffset < 31,
    421         "Constant not exactly representable in this fixed-point format");
    422     return FromScalarRaw(ScalarRawType(1) << kOffset);
    423   }
    424 
    425   static FixedPoint Zero() { return FromScalarRaw(0); }
    426 
    427   static FixedPoint One() {
    428     return FromScalarRaw(kIntegerBits == 0
    429                              ? ScalarRawMax()
    430                              : (ScalarRawType(1) << kFractionalBits));
    431   }
    432 
    433   static FixedPoint FromDouble(double x) {
    434     const double min_bound = static_cast<double>(ScalarRawMin());
    435     const double max_bound = static_cast<double>(ScalarRawMax());
    436     return FromScalarRaw(static_cast<std::int32_t>(std::min(
    437         std::max(round(x * static_cast<double>(1ll << kFractionalBits)),
    438                  min_bound),
    439         max_bound)));
    440   }
    441 
    442   RawType raw() const { return i_; }
    443   RawType& raw() { return i_; }
    444 
    445  private:
    446   RawType i_;
    447 };
    448 
    449 // Part 3: implementation of arithmetic operators for the
    450 // FixedPoint class, and a few related functions.
    451 
    452 // A FixedPoint multiplication is just a
    453 // SaturatingRoundingDoublingHighMul operation on the underlying
    454 // raw integer values. The IntegerBits simply add up, as is obvious
    455 // from the fact that the range is [-2^IntegerBits, 2^IntegerBits).
    456 template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b>
    457 FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(
    458     FixedPoint<tRawType, tIntegerBits_a> a,
    459     FixedPoint<tRawType, tIntegerBits_b> b) {
    460   FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c;
    461   c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw());
    462   return c;
    463 }
    464 
    465 // Tweaking IntegerBits gives exact multiplication by a power of two.
    466 template <int tExponent, typename tRawType, int tIntegerBits>
    467 FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(
    468     FixedPoint<tRawType, tIntegerBits> a) {
    469   FixedPoint<tRawType, tExponent + tIntegerBits> c;
    470   c.raw() = a.raw();
    471   return c;
    472 }
    473 
    474 // If we want to leave IntegerBits fixed, then multiplication
    475 // by a power of two has to be saturating/rounding, not exact anymore.
    476 template <int tExponent, typename tRawType, int tIntegerBits>
    477 FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(
    478     FixedPoint<tRawType, tIntegerBits> a) {
    479   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
    480       SaturatingRoundingMultiplyByPOT<tExponent>(a.raw()));
    481 }
    482 
    483 // Generic arithmetic operators.
    484 
    485 #define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName)                     \
    486   template <typename tRawType, int tIntegerBits>                               \
    487   FixedPoint<tRawType, tIntegerBits> FuncName(                                 \
    488       FixedPoint<tRawType, tIntegerBits> a) {                                  \
    489     return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \
    490   }
    491 
    492 #define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \
    493   template <typename tRawType, int tIntegerBits>            \
    494   FixedPoint<tRawType, tIntegerBits> FuncName(              \
    495       FixedPoint<tRawType, tIntegerBits> a,                 \
    496       FixedPoint<tRawType, tIntegerBits> b) {               \
    497     return FixedPoint<tRawType, tIntegerBits>::FromRaw(     \
    498         ImplFuncName(a.raw(), b.raw()));                    \
    499   }
    500 
    501 MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
    502 MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
    503 MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
    504 MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
    505 MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
    506 MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
    507 MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
    508 MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)
    509 
    510 #undef MAKE_FIXEDPOINT_UNARY_FUNC
    511 #undef MAKE_FIXEDPOINT_BINARY_FUNC
    512 
    513 #define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName)  \
    514   template <typename tRawType, int tIntegerBits>            \
    515   tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
    516     return FuncName(a.raw());                               \
    517   }
    518 
    519 #define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \
    520   template <typename tRawType, int tIntegerBits>            \
    521   tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a,   \
    522                     FixedPoint<tRawType, tIntegerBits> b) { \
    523     return FuncName(a.raw(), b.raw());                      \
    524   }
    525 
    526 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero)
    527 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero)
    528 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual)
    529 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual)
    530 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan)
    531 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual)
    532 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan)
    533 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual)
    534 
    535 #undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW
    536 #undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW
    537 
    538 template <typename tRawType, int tIntegerBits>
    539 FixedPoint<tRawType, tIntegerBits> SelectUsingMask(
    540     tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val,
    541     FixedPoint<tRawType, tIntegerBits> else_val) {
    542   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
    543       SelectUsingMask(if_mask, then_val.raw(), else_val.raw()));
    544 }
    545 
    546 template <typename tRawType, int tIntegerBits>
    547 bool operator==(FixedPoint<tRawType, tIntegerBits> a,
    548                 FixedPoint<tRawType, tIntegerBits> b) {
    549   return All(MaskIfEqual(a.raw(), b.raw()));
    550 }
    551 
    552 template <typename tRawType, int tIntegerBits>
    553 bool operator!=(FixedPoint<tRawType, tIntegerBits> a,
    554                 FixedPoint<tRawType, tIntegerBits> b) {
    555   return !(a == b);
    556 }
    557 
    558 // Conversion to floating-point.
    559 template <typename tRawType, int tIntegerBits>
    560 double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
    561   static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1,
    562                 "not applicable to SIMD types");
    563   typedef FixedPoint<tRawType, tIntegerBits> F;
    564   return x.raw() / static_cast<double>(1ll << F::kFractionalBits);
    565 }
    566 
    567 // Rescale changes the number of IntegerBits and updates the underlying
    568 // raw integer value accordingly.
    569 template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
    570 FixedPoint<tRawType, tIntegerBitsDst> Rescale(
    571     FixedPoint<tRawType, tIntegerBitsSrc> x) {
    572   static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
    573   FixedPoint<tRawType, tIntegerBitsDst> result;
    574   result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
    575   return result;
    576 }
    577 
    578 // CheckedFixedPointConstant allows to specify fixed-point constants
    579 // initialized as real numbers, in a way that does not compile floating-point
    580 // arithmetic in production code, yet still checks agreement with the
    581 // floating-point expressions when asserts are enabled.
    582 #ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
    583 template <typename FixedPointType>
    584 FixedPointType CheckedFixedPointConstant(
    585     typename FixedPointType::ScalarRawType raw_value, double double_value) {
    586   typedef typename FixedPointType::RawType RawType;
    587   const FixedPointType result = FixedPointType::FromScalarRaw(raw_value);
    588   assert(result == FixedPointType::FromDouble(double_value));
    589   return result;
    590 }
    591 #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \
    592                                              DoubleValue)                    \
    593   (CheckedFixedPointConstant<FixedPointType>(ScalarRawValue, DoubleValue))
    594 
    595 #else
    596 #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \
    597                                              DoubleValue)                    \
    598   (FixedPointType::FromScalarRaw(ScalarRawValue))
    599 #endif
    600 
    601 // Implementation of exponential function.
    602 
    603 // Returns exp(x) for x in [-1/4, 0).
    604 template <typename tRawType>
    605 FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl(
    606     FixedPoint<tRawType, 0> a) {
    607   typedef FixedPoint<tRawType, 0> F;
    608   const F constant_term =
    609       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0));
    610   const F constant_1_over_3 =
    611       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0);
    612   // We're evaluating a Taylor expansion around -1/8, so we do the change of
    613   // variable: x = a + 1/8.
    614   // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28.
    615   F x = a + F::template ConstantPOT<-3>();
    616   F x2 = x * x;
    617   F x3 = x2 * x;
    618   F x4 = x2 * x2;
    619   F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4);
    620   F x4_over_24_plus_x3_over_6_plus_x2_over_2 =
    621       SaturatingRoundingMultiplyByPOT<-1>(
    622           ((x4_over_4 + x3) * constant_1_over_3) + x2);
    623   return constant_term +
    624          constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2);
    625 }
    626 
    627 // Returns exp(x) for x < 0.
    628 template <typename tRawType, int tIntegerBits>
    629 FixedPoint<tRawType, 0> exp_on_negative_values(
    630     FixedPoint<tRawType, tIntegerBits> a) {
    631   typedef FixedPoint<tRawType, tIntegerBits> InputF;
    632   typedef FixedPoint<tRawType, 0> ResultF;
    633   static const int kFractionalBits = InputF::kFractionalBits;
    634   static const int kIntegerBits = InputF::kIntegerBits;
    635   static const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
    636   InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
    637   InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
    638   ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
    639       Rescale<0>(a_mod_quarter_minus_one_quarter));
    640   tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw();
    641 
    642 #define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)         \
    643   if (kIntegerBits > Exponent) {                                            \
    644     const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(       \
    645         ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \
    646     static constexpr int kShiftAmount =                                     \
    647         kIntegerBits > Exponent ? kFractionalBits + Exponent : 0;           \
    648     result = SelectUsingMask(                                               \
    649         MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \
    650         result * kMultiplier, result);                                      \
    651   }
    652 
    653   GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);
    654   GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);
    655   GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);
    656   GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);
    657   GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);
    658   GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);
    659   GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);
    660 
    661 #undef GEMMLOWP_EXP_BARREL_SHIFTER
    662 
    663   if (kIntegerBits > 5) {
    664     static const int b = kIntegerBits > 5 ? kFractionalBits + 5 : 0;
    665     const InputF clamp =
    666         GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0);
    667     result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
    668   }
    669 
    670   result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result);
    671   return result;
    672 }
    673 
    674 // Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)).
    675 
    676 // Returns (1 - x) / (1 + x) for x in (0, 1).
    677 template <typename tRawType>
    678 FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1(
    679     FixedPoint<tRawType, 0> a) {
    680   typedef FixedPoint<tRawType, 0> F0;
    681   typedef FixedPoint<tRawType, 2> F2;
    682   F0 half_denominator = RoundingHalfSum(a, F0::One());
    683   // Newton-Raphson division
    684   // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
    685   // Refer to that page for the logic behind the 48/17 and 32/17 constants.
    686   const F2 constant_48_over_17 =
    687       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
    688   const F2 constant_neg_32_over_17 =
    689       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
    690   F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
    691   for (int i = 0; i < 3; i++) {
    692     F2 half_denominator_times_x = half_denominator * x;
    693     F2 one_minus_half_denominator_times_x =
    694         F2::One() - half_denominator_times_x;
    695     x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
    696   }
    697   return Rescale<0>(x - F2::One());
    698 }
    699 
    700 // Returns -tanh(x) for x < 0.
    701 template <typename tRawType, int tIntegerBits>
    702 FixedPoint<tRawType, 0> neg_tanh_on_negative_values(
    703     FixedPoint<tRawType, tIntegerBits> a) {
    704   return one_minus_x_over_one_plus_x_for_x_in_0_1(
    705       exp_on_negative_values(ExactMulByPot<1>(a)));
    706 }
    707 
    708 // Returns tanh(x) for any x.
    709 template <typename tRawType, int tIntegerBits>
    710 FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
    711   typedef FixedPoint<tRawType, tIntegerBits> InputF;
    712   typedef FixedPoint<tRawType, 0> ResultF;
    713   tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero());
    714   tRawType mask_if_zero = MaskIfZero(a);
    715   InputF n = SelectUsingMask(mask_if_negative, a, -a);
    716   ResultF t = neg_tanh_on_negative_values(n);
    717   return SelectUsingMask(mask_if_zero, ResultF::Zero(),
    718                          SelectUsingMask(mask_if_negative, -t, t));
    719 }
    720 
    721 // Implementation of logistic function.
    722 
    723 // Returns 1 / (1 + x) for x in (0, 1).
    724 template <typename tRawType>
    725 FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1(
    726     FixedPoint<tRawType, 0> a) {
    727   typedef FixedPoint<tRawType, 0> F0;
    728   typedef FixedPoint<tRawType, 2> F2;
    729   F0 half_denominator = RoundingHalfSum(a, F0::One());
    730   // Newton-Raphson division
    731   // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
    732   // Refer to that page for the logic behind the 48/17 and 32/17 constants.
    733   const F2 constant_48_over_17 =
    734       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
    735   const F2 constant_neg_32_over_17 =
    736       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
    737   F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
    738   for (int i = 0; i < 3; i++) {
    739     F2 half_denominator_times_x = half_denominator * x;
    740     F2 one_minus_half_denominator_times_x =
    741         F2::One() - half_denominator_times_x;
    742     x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
    743   }
    744   return Rescale<0>(ExactMulByPot<-1>(x));
    745 }
    746 
    747 // Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0.
    748 template <typename tRawType, int tIntegerBits>
    749 FixedPoint<tRawType, 0> logistic_on_positive_values(
    750     FixedPoint<tRawType, tIntegerBits> a) {
    751   return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a));
    752 }
    753 
    754 // Returns logistic(x) = 1 / (1 + exp(-x)) for any x.
    755 template <typename tRawType, int tIntegerBits>
    756 FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
    757   typedef FixedPoint<tRawType, tIntegerBits> InputF;
    758   typedef FixedPoint<tRawType, 0> ResultF;
    759   tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero());
    760   tRawType mask_if_zero = MaskIfZero(a);
    761   InputF abs_input = SelectUsingMask(mask_if_positive, a, -a);
    762   ResultF result_if_positive = logistic_on_positive_values(abs_input);
    763   ResultF result_if_negative = ResultF::One() - result_if_positive;
    764   const ResultF one_half =
    765       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5);
    766   return SelectUsingMask(mask_if_zero, one_half,
    767                          SelectUsingMask(mask_if_positive, result_if_positive,
    768                                          result_if_negative));
    769 }
    770 
    771 }  // end namespace gemmlowp
    772 
    773 #ifdef GEMMLOWP_NEON
    774 #include "./fixedpoint_neon.h"
    775 #elif defined(GEMMLOWP_SSE4)
    776 #include "./fixedpoint_sse.h"
    777 #endif
    778 
    779 #endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_H_
    780