Home | History | Annotate | Download | only in fixedpoint
      1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // fixedpoint.h: fixed-point arithmetic, with basic operations and
     16 // a few math functions such as tanh.
     17 
     18 #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
     19 #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_
     20 
     21 #include <cassert>
     22 #include <limits>
     23 
     24 #include "../internal/common.h"
     25 
     26 namespace gemmlowp {
     27 
     28 // Part 1: Low-level integer-arithmetic primitives.
     29 // The implementations here are generic implementations valid for
     30 // scalar types (e.g. std::int32_t). Architecture-specific SIMD types
     31 // (e.g. NEON int32x4_t) may be supported by providing
     32 // specializations for them in separate files.
     33 //
     34 // The purpose of these primitives is two-fold:
     35 //  - They will be used to implement higher-level fixed-point
     36 //    abstractions, namely the FixedPoint class and its arithmetic
     37 //    operators.
     38 //  - They will be directly used to implement some more involved
     39 //    fixed-point computations, e.g. the fixed-point implementation
     40 //    of math functions such as tanh.
     41 
     42 // Some compile-time traits around raw types to handle SIMD aspects:
     43 // number of lanes, underlying scalar type.
     44 template <typename tIntegerType>
     45 struct FixedPointRawTypeTraits {};
     46 
     47 template <>
     48 struct FixedPointRawTypeTraits<std::int32_t> {
     49   typedef std::int32_t ScalarRawType;
     50   static const int kLanes = 1;
     51 };
     52 
     53 template <>
     54 struct FixedPointRawTypeTraits<std::int16_t> {
     55   typedef std::int16_t ScalarRawType;
     56   static const int kLanes = 1;
     57 };
     58 
     59 // Returns a SIMD value duplicating a scalar value across all lanes.
     60 template <typename tRawType>
     61 tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
     62   return x;
     63 }
     64 
     65 // Plain bit-wise AND
     66 template <typename tIntegerType>
     67 tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
     68   return a & b;
     69 }
     70 
     71 // Plain bit-wise OR
     72 template <typename tIntegerType>
     73 tIntegerType BitOr(tIntegerType a, tIntegerType b) {
     74   return a | b;
     75 }
     76 
     77 // Plain bit-wise XOR
     78 template <typename tIntegerType>
     79 tIntegerType BitXor(tIntegerType a, tIntegerType b) {
     80   return a ^ b;
     81 }
     82 
     83 // Plain bit-wise NOT
     84 template <typename tIntegerType>
     85 tIntegerType BitNot(tIntegerType a) {
     86   return ~a;
     87 }
     88 
     89 // Integer addition. Not saturating. Overflow is undefined behavior.
     90 template <typename tIntegerType>
     91 tIntegerType Add(tIntegerType a, tIntegerType b) {
     92   return a + b;
     93 }
     94 
     95 // Integer subtraction. Not saturating. Overflow is undefined behavior.
     96 template <typename tIntegerType>
     97 tIntegerType Mul(tIntegerType a, tIntegerType b) {
     98   return a * b;
     99 }
    100 
    101 template <typename tIntegerType>
    102 tIntegerType Sub(tIntegerType a, tIntegerType b) {
    103   return a - b;
    104 }
    105 
    106 // Integer unary negative. Not saturating. Overflow is undefined behavior.
    107 template <typename tIntegerType>
    108 tIntegerType Neg(tIntegerType a) {
    109   return -a;
    110 }
    111 
    112 // Integer arithmetic left-shift, equivalent to multiplying with a
    113 // power of two. Not saturating. Overflow is undefined behavior.
    114 template <typename tIntegerType>
    115 tIntegerType ShiftLeft(tIntegerType a, int offset) {
    116   return a << offset;
    117 }
    118 
    119 // Integer arithmetic right-shift. Not rounding.
    120 // Relying on implementation-defined, but in-practice-consistent,
    121 // C++ compiler behavior.
    122 template <typename tIntegerType>
    123 tIntegerType ShiftRight(tIntegerType a, int offset) {
    124   return a >> offset;
    125 }
    126 
    127 // Each bit of the result is set to the corresponding bit of either then_val or
    128 // else_val depending on whether the corresponding bit of if_mask is set.
    129 // Equivalent to the VBSL instruction in ARM NEON.
    130 template <typename tIntegerType>
    131 tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
    132                              tIntegerType else_val) {
    133   return BitXor(BitAnd(if_mask, then_val), BitAnd(BitNot(if_mask), else_val));
    134 }
    135 
    136 // For each input scalar, the corresponding bits of the result are set if the
    137 // input scalar is non-zero.
    138 template <typename tIntegerType>
    139 tIntegerType MaskIfNonZero(tIntegerType a) {
    140   static const tIntegerType zero = 0;
    141   return a ? BitNot(zero) : zero;
    142 }
    143 
    144 // For each input scalar, the corresponding bits of the result are set if the
    145 // input scalar is zero.
    146 template <typename tIntegerType>
    147 tIntegerType MaskIfZero(tIntegerType a) {
    148   return MaskIfNonZero<tIntegerType>(!a);
    149 }
    150 
    151 // For each pair of input scalars, the corresponding bits of the result are
    152 // set if the input scalars are equal.
    153 template <typename tIntegerType>
    154 tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
    155   return MaskIfNonZero<tIntegerType>(a == b);
    156 }
    157 
    158 // For each pair of input scalars, the corresponding bits of the result are
    159 // set if the input scalars are not equal.
    160 template <typename tIntegerType>
    161 tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
    162   return MaskIfNonZero<tIntegerType>(a != b);
    163 }
    164 
    165 // For each pair of input scalars, the corresponding bits of the result are
    166 // set if the input scalars a, b satisfy a > b.
    167 template <typename tIntegerType>
    168 tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
    169   return MaskIfNonZero<tIntegerType>(a > b);
    170 }
    171 
    172 // For each pair of input scalars, the corresponding bits of the result are
    173 // set if the input scalars a, b satisfy a >= b.
    174 template <typename tIntegerType>
    175 tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
    176   return MaskIfNonZero<tIntegerType>(a >= b);
    177 }
    178 
    179 // For each pair of input scalars, the corresponding bits of the result are
    180 // set if the input scalars a, b satisfy a < b.
    181 template <typename tIntegerType>
    182 tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
    183   return MaskIfNonZero<tIntegerType>(a < b);
    184 }
    185 
    186 // For each pair of input scalars, the corresponding bits of the result are
    187 // set if the input scalars a, b satisfy a <= b.
    188 template <typename tIntegerType>
    189 tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
    190   return MaskIfNonZero<tIntegerType>(a <= b);
    191 }
    192 
    193 // Returns true if all of the input scalars are nonzero.
    194 // This function may currently assume that each of the input scalars has either
    195 // all or none of its bits set. Otherwise, its behavior is currently undefined.
    196 template <typename tIntegerType>
    197 bool All(tIntegerType a) {
    198   return a;
    199 }
    200 
    201 // Returns true if any of the input scalars are nonzero.
    202 // This function may currently assume that each of the input scalars has either
    203 // all or none of its bits set. Otherwise, its behavior is currently undefined.
    204 template <typename tIntegerType>
    205 bool Any(tIntegerType a) {
    206   return a;
    207 }
    208 
    209 // Returns (a+b)/2, rounded to the nearest integer.
    210 // Equivalent to VRHADD in the ARM NEON instruction set.
    211 template <typename IntegerType>
    212 IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
    213   static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
    214   return a;
    215 }
    216 
    217 template <>
    218 inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
    219   std::int64_t a64 = a;
    220   std::int64_t b64 = b;
    221   std::int64_t sum = a64 + b64;
    222   std::int64_t sign = sum >= 0 ? 1 : -1;
    223   return static_cast<std::int32_t>((sum + sign) / 2);
    224 }
    225 
    226 template <>
    227 inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
    228   std::int32_t a32 = a;
    229   std::int32_t b32 = b;
    230   std::int32_t sum = a32 + b32;
    231   std::int32_t sign = sum >= 0 ? 1 : -1;
    232   return static_cast<std::int16_t>((sum + sign) / 2);
    233 }
    234 
    235 template <typename IntegerType>
    236 IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
    237   static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
    238   return a;
    239 }
    240 
    241 // So far this is only needed for int16.
    242 template <>
    243 inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
    244   std::int32_t a32 = a;
    245   std::int32_t b32 = b;
    246   std::int32_t sum = a32 + b32;
    247   return static_cast<std::int16_t>(std::min(32767, std::max(-32768, sum)));
    248 }
    249 
    250 // Returns a+b, saturating if the integers are 16bit or narrower,
    251 // otherwise just a plain addition.
    252 template <typename IntegerType, bool Is16Bit>
    253 struct AddSaturatingIf16BitImpl {
    254   static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); }
    255 };
    256 template <typename IntegerType>
    257 struct AddSaturatingIf16BitImpl<IntegerType, true> {
    258   static IntegerType Run(IntegerType a, IntegerType b) {
    259     return SaturatingAdd(a, b);
    260   }
    261 };
    262 template <typename IntegerType>
    263 IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
    264   using ScalarType =
    265       typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
    266   return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a,
    267                                                                              b);
    268 }
    269 
    270 // Returns the integer that represents the product of two fixed-point
    271 // numbers, interpreting all integers as fixed-point values in the
    272 // interval [-1, 1), rounding to the nearest value, and saturating
    273 // -1 * -1 to the maximum value (since 1 is not in the half-open
    274 // interval [-1, 1)).
    275 //
    276 // [The explanation below specializes to std::int32_t for example purpose.]
    277 //
    278 // The mapping between IntegerType and the interval [-1, 1) is unique and
    279 // implied by IntegerType, which is assumed to be signed. For example,
    280 // for IntegerType==std::int32_t, the mapping is
    281 //   real_value = integer_value / 2^31.
    282 // So in this case, and leaving aside rounding and saturating, this
    283 // function computes ((a / 2^31) * (b / 2^31)) * 2^31, which simplifies to
    284 //   (a * b) / 2^31.
    285 //
    286 // The 'doubling' part in the name of this function comes from the fact that
    287 // this operation is very close to a "multiply-high" operation, keeping only
    288 // the top half bits, except that that would be effectively computing
    289 //   (a * b) / 2^32,
    290 // so here we are computing 2x that, since
    291 //   1/2^31 = 2 * 1/2^32.
    292 // The idea is to use all of the available 32 bits in the destination int32
    293 // value.
    294 //
    295 // [End of the explanation specializing to int32.]
    296 //
    297 // This is equivalent to the VQRDMULH instruction in ARM NEON.
    298 template <typename IntegerType>
    299 IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
    300   static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
    301   return a;
    302 }
    303 
    304 // This function implements the same computation as the ARMv7 NEON VQRDMULH
    305 // instruction.
    306 template <>
    307 inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
    308                                                       std::int32_t b) {
    309   bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min();
    310   std::int64_t a_64(a);
    311   std::int64_t b_64(b);
    312   std::int64_t ab_64 = a_64 * b_64;
    313   std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
    314   std::int32_t ab_x2_high32 =
    315       static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31));
    316   return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
    317 }
    318 
    319 template <>
    320 inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a,
    321                                                       std::int16_t b) {
    322   bool overflow = a == b && a == std::numeric_limits<std::int16_t>::min();
    323   std::int32_t a_32(a);
    324   std::int32_t b_32(b);
    325   std::int32_t ab_32 = a_32 * b_32;
    326   std::int16_t nudge = ab_32 >= 0 ? (1 << 14) : (1 - (1 << 14));
    327   std::int16_t ab_x2_high16 =
    328       static_cast<std::int16_t>((ab_32 + nudge) / (1 << 15));
    329   return overflow ? std::numeric_limits<std::int16_t>::max() : ab_x2_high16;
    330 }
    331 
    332 // Correctly-rounded-to-nearest division by a power-of-two.
    333 // Also known as a rounding arithmetic right shift.
    334 template <typename IntegerType>
    335 inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) {
    336   assert(exponent >= 0);
    337   assert(exponent <= 31);
    338   const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
    339   const IntegerType zero = Dup<IntegerType>(0);
    340   const IntegerType one = Dup<IntegerType>(1);
    341   const IntegerType remainder = BitAnd(x, mask);
    342   const IntegerType threshold =
    343       Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one));
    344   return Add(ShiftRight(x, exponent),
    345              BitAnd(MaskIfGreaterThan(remainder, threshold), one));
    346 }
    347 
    348 // Returns the product of a run-time integer value by a compile-time power
    349 // of two, with either a positive exponent (equivalent to an arithmetic
    350 // left shift, saturating) or a negative exponent (equivalent to an arithmetic
    351 // right shift, rounding to nearest).
    352 template <int Exponent, typename IntegerType,
    353           int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
    354 struct ImplSaturatingRoundingMultiplyByPOT {};
    355 
    356 template <int Exponent, typename IntegerType>
    357 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> {
    358   static IntegerType eval(IntegerType x) { return x; }
    359 };
    360 
    361 template <int Exponent, typename IntegerType>
    362 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> {
    363   static IntegerType eval(IntegerType x) {
    364     using ScalarIntegerType =
    365         typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
    366     const IntegerType min =
    367         Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
    368     const IntegerType max =
    369         Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
    370     const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
    371 
    372     const std::int32_t threshold =
    373         ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1);
    374     const IntegerType positive_mask =
    375         MaskIfGreaterThan(x, Dup<IntegerType>(threshold));
    376     const IntegerType negative_mask =
    377         MaskIfLessThan(x, Dup<IntegerType>(-threshold));
    378 
    379     IntegerType result = ShiftLeft(x, Exponent);
    380     result = SelectUsingMask(positive_mask, max, result);
    381     result = SelectUsingMask(negative_mask, min, result);
    382     return result;
    383   }
    384 };
    385 
    386 template <int Exponent, typename IntegerType>
    387 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> {
    388   static IntegerType eval(IntegerType x) {
    389     return RoundingDivideByPOT<IntegerType>(x, -Exponent);
    390   }
    391 };
    392 
    393 template <int Exponent, typename IntegerType>
    394 IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) {
    395   return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x);
    396 }
    397 
    398 // Part 2: the FixedPoint class.
    399 
    400 // A FixedPoint object represents a fixed-point value stored in the underlying
    401 // integer type tRawType, if tRawType is a plain scalar integer type.
    402 // Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which
    403 // case a FixedPoint object represents a corresponding SIMD vector of fixed
    404 // point values.
    405 //
    406 // tIntegerBits describes the range of the fixed-point format: if
    407 // tIntegerBits == m then the range of representable values is the half-open
    408 // interval [-2^m; 2^m) where the open boundary on the right side means that
    409 // 2^m is not representable (how close the maximum representable value is to
    410 // it, depends on bit-depth of tRawType).
    411 //
    412 // In "Q format notation",
    413 //   https://en.wikipedia.org/wiki/Q_(number_format)
    414 // we are describing the format
    415 //   Qm.n
    416 // where
    417 //   m = tIntegerBits
    418 // and
    419 //   n = NumberOfBits(tRawType) - (m + 1)
    420 // Note that the (m + 1) in the above line is because we adopt the convention
    421 // that we count the integer bits exclusively of the sign bit; so (m + 1) is
    422 // the total number of integer bits inclusive of the sign bit.
    423 //
    424 // Accordingly, the number of integral representable values in our range
    425 //   [-2^m ; 2^m)
    426 // is equal to 2^(m+1).
    427 template <typename tRawType, int tIntegerBits>
    428 class FixedPoint {
    429  public:
    430   typedef tRawType RawType;
    431 
    432   typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
    433   typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
    434 
    435   static const int kTotalBits = 8 * sizeof(ScalarRawType);
    436   static const int kIntegerBits = tIntegerBits;
    437   static const int kFractionalBits = kTotalBits - 1 - kIntegerBits;
    438   static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
    439                 "bad IntegerBits");
    440 
    441   typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType;
    442 
    443   static const ScalarRawType ScalarRawMin() {
    444     return std::numeric_limits<ScalarRawType>::min();
    445   }
    446 
    447   static const ScalarRawType ScalarRawMax() {
    448     return std::numeric_limits<ScalarRawType>::max();
    449   }
    450 
    451   static const ScalarRawType RawMin() {
    452     return VectorFromScalar(ScalarRawMin());
    453   }
    454 
    455   static const ScalarRawType RawMax() {
    456     return VectorFromScalar(ScalarRawMax());
    457   }
    458 
    459   static FixedPoint FromRaw(RawType x) {
    460     FixedPoint retval;
    461     retval.raw() = x;
    462     return retval;
    463   }
    464 
    465   static FixedPoint FromScalarRaw(ScalarRawType x) {
    466     FixedPoint retval;
    467     retval.raw() = Dup<RawType>(x);
    468     return retval;
    469   }
    470 
    471   static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) {
    472     return FromScalarRaw(x.raw());
    473   }
    474 
    475   template <int Exponent>
    476   static FixedPoint ConstantPOT() {
    477     static const int kOffset = kFractionalBits + Exponent;
    478     static_assert(
    479         kOffset < 31,
    480         "Constant not exactly representable in this fixed-point format");
    481     return FromScalarRaw(ScalarRawType(1) << kOffset);
    482   }
    483 
    484   static FixedPoint Zero() { return FromScalarRaw(0); }
    485 
    486   static FixedPoint One() {
    487     return FromScalarRaw(
    488         kIntegerBits == 0
    489             ? ScalarRawMax()
    490             : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits)));
    491   }
    492 
    493   static FixedPoint FromDouble(double x) {
    494     const double min_bound = static_cast<double>(ScalarRawMin());
    495     const double max_bound = static_cast<double>(ScalarRawMax());
    496     return FromScalarRaw(static_cast<ScalarRawType>(std::min(
    497         std::max(round(x * static_cast<double>(1ll << kFractionalBits)),
    498                  min_bound),
    499         max_bound)));
    500   }
    501 
    502   RawType raw() const { return i_; }
    503   RawType& raw() { return i_; }
    504 
    505  private:
    506   RawType i_;
    507 };
    508 
    509 // Part 3: implementation of arithmetic operators for the
    510 // FixedPoint class, and a few related functions.
    511 
    512 // A FixedPoint multiplication is just a
    513 // SaturatingRoundingDoublingHighMul operation on the underlying
    514 // raw integer values. The IntegerBits simply add up, as is obvious
    515 // from the fact that the range is [-2^IntegerBits, 2^IntegerBits).
    516 template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b>
    517 FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(
    518     FixedPoint<tRawType, tIntegerBits_a> a,
    519     FixedPoint<tRawType, tIntegerBits_b> b) {
    520   FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c;
    521   c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw());
    522   return c;
    523 }
    524 
    525 // Tweaking IntegerBits gives exact multiplication by a power of two.
    526 template <int tExponent, typename tRawType, int tIntegerBits>
    527 FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(
    528     FixedPoint<tRawType, tIntegerBits> a) {
    529   FixedPoint<tRawType, tExponent + tIntegerBits> c;
    530   c.raw() = a.raw();
    531   return c;
    532 }
    533 
    534 // If we want to leave IntegerBits fixed, then multiplication
    535 // by a power of two has to be saturating/rounding, not exact anymore.
    536 template <int tExponent, typename tRawType, int tIntegerBits>
    537 FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(
    538     FixedPoint<tRawType, tIntegerBits> a) {
    539   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
    540       SaturatingRoundingMultiplyByPOT<tExponent>(a.raw()));
    541 }
    542 
    543 // Generic arithmetic operators.
    544 
    545 #define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName)                     \
    546   template <typename tRawType, int tIntegerBits>                               \
    547   FixedPoint<tRawType, tIntegerBits> FuncName(                                 \
    548       FixedPoint<tRawType, tIntegerBits> a) {                                  \
    549     return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \
    550   }
    551 
    552 #define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \
    553   template <typename tRawType, int tIntegerBits>            \
    554   FixedPoint<tRawType, tIntegerBits> FuncName(              \
    555       FixedPoint<tRawType, tIntegerBits> a,                 \
    556       FixedPoint<tRawType, tIntegerBits> b) {               \
    557     return FixedPoint<tRawType, tIntegerBits>::FromRaw(     \
    558         ImplFuncName(a.raw(), b.raw()));                    \
    559   }
    560 
    561 MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
    562 MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
    563 MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
    564 MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
    565 MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
    566 MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
    567 MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
    568 MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)
    569 
    570 #undef MAKE_FIXEDPOINT_UNARY_FUNC
    571 #undef MAKE_FIXEDPOINT_BINARY_FUNC
    572 
    573 #define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName)  \
    574   template <typename tRawType, int tIntegerBits>            \
    575   tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
    576     return FuncName(a.raw());                               \
    577   }
    578 
    579 #define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \
    580   template <typename tRawType, int tIntegerBits>            \
    581   tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a,   \
    582                     FixedPoint<tRawType, tIntegerBits> b) { \
    583     return FuncName(a.raw(), b.raw());                      \
    584   }
    585 
    586 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero)
    587 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero)
    588 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual)
    589 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual)
    590 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan)
    591 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual)
    592 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan)
    593 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual)
    594 
    595 #undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW
    596 #undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW
    597 
    598 template <typename tRawType, int tIntegerBits>
    599 FixedPoint<tRawType, tIntegerBits> SelectUsingMask(
    600     tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val,
    601     FixedPoint<tRawType, tIntegerBits> else_val) {
    602   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
    603       SelectUsingMask(if_mask, then_val.raw(), else_val.raw()));
    604 }
    605 
    606 template <typename tRawType, int tIntegerBits>
    607 bool operator==(FixedPoint<tRawType, tIntegerBits> a,
    608                 FixedPoint<tRawType, tIntegerBits> b) {
    609   return All(MaskIfEqual(a.raw(), b.raw()));
    610 }
    611 
    612 template <typename tRawType, int tIntegerBits>
    613 bool operator!=(FixedPoint<tRawType, tIntegerBits> a,
    614                 FixedPoint<tRawType, tIntegerBits> b) {
    615   return !(a == b);
    616 }
    617 
    618 template <typename tRawType, int tIntegerBits>
    619 FixedPoint<tRawType, tIntegerBits> SaturatingAdd(
    620     FixedPoint<tRawType, tIntegerBits> a,
    621     FixedPoint<tRawType, tIntegerBits> b) {
    622   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
    623       SaturatingAdd(a.raw(), b.raw()));
    624 }
    625 
    626 template <typename tRawType, int tIntegerBits>
    627 FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit(
    628     FixedPoint<tRawType, tIntegerBits> a,
    629     FixedPoint<tRawType, tIntegerBits> b) {
    630   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
    631       AddSaturatingIf16Bit(a.raw(), b.raw()));
    632 }
    633 
    634 // Conversion to floating-point.
    635 template <typename tRawType, int tIntegerBits>
    636 double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
    637   static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1,
    638                 "not applicable to SIMD types");
    639   typedef FixedPoint<tRawType, tIntegerBits> F;
    640   return x.raw() / static_cast<double>(1ll << F::kFractionalBits);
    641 }
    642 
    643 // Rescale changes the number of IntegerBits and updates the underlying
    644 // raw integer value accordingly.
    645 template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
    646 FixedPoint<tRawType, tIntegerBitsDst> Rescale(
    647     FixedPoint<tRawType, tIntegerBitsSrc> x) {
    648   static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
    649   FixedPoint<tRawType, tIntegerBitsDst> result;
    650   result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
    651   return result;
    652 }
    653 
// CheckedFixedPointConstant allows specifying fixed-point constants
// initialized as real numbers, in a way that does not compile floating-point
// arithmetic in production code, yet still checks agreement with the
// floating-point expressions when asserts are enabled.
//
// The raw integer value provided is always an int32, encoding a 32-bit
// fixed-point value, regardless of the actual Scalar type. This allows
// writing generic code that applies just as well to the 32-bit and 16-bit
// cases. In the 16-bit case, the raw integer value is internally
// rounding-shifted by 16 bits to the right.
    664 template <typename FixedPointType>
    665 inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(
    666     std::int32_t int32_value) {
    667   typedef typename FixedPointType::ScalarRawType ScalarRawType;
    668   static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType);
    669   return static_cast<ScalarRawType>(
    670       RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits));
    671 }
    672 #ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
    673 template <typename FixedPointType>
    674 FixedPointType CheckedFixedPointConstant(std::int32_t raw_value,
    675                                          double double_value) {
    676   const FixedPointType result = FixedPointType::FromScalarRaw(raw_value);
    677   assert(result == FixedPointType::FromDouble(double_value));
    678   return result;
    679 }
    680 #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType,                   \
    681                                              ScalarRawInt32Value, DoubleValue) \
    682   (gemmlowp::CheckedFixedPointConstant<FixedPointType>(                        \
    683       gemmlowp::RescaleConstantInitializer<FixedPointType>(                    \
    684           ScalarRawInt32Value),                                                \
    685       DoubleValue))
    686 
    687 #else
    688 #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType,                   \
    689                                              ScalarRawInt32Value, DoubleValue) \
    690   (FixedPointType::FromScalarRaw(                                              \
    691       gemmlowp::RescaleConstantInitializer<FixedPointType>(                    \
    692           ScalarRawInt32Value)))
    693 #endif
    694 
    695 // Implementation of exponential function.
    696 
    697 // Returns exp(x) for x in [-1/4, 0).
    698 template <typename tRawType>
    699 FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl(
    700     FixedPoint<tRawType, 0> a) {
    701   typedef FixedPoint<tRawType, 0> F;
    702   const F constant_term =
    703       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0));
    704   const F constant_1_over_3 =
    705       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0);
    706   // We're evaluating a Taylor expansion around -1/8, so we do the change of
    707   // variable: x = a + 1/8.
    708   // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28.
    709   F x = a + F::template ConstantPOT<-3>();
    710   F x2 = x * x;
    711   F x3 = x2 * x;
    712   F x4 = x2 * x2;
    713   F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4);
    714   F x4_over_24_plus_x3_over_6_plus_x2_over_2 =
    715       SaturatingRoundingMultiplyByPOT<-1>(
    716           ((x4_over_4 + x3) * constant_1_over_3) + x2);
    717   return AddSaturatingIf16Bit(
    718       constant_term,
    719       constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2));
    720 }
    721 
    722 // Returns exp(x) for x < 0.
    723 template <typename tRawType, int tIntegerBits>
    724 FixedPoint<tRawType, 0> exp_on_negative_values(
    725     FixedPoint<tRawType, tIntegerBits> a) {
    726   typedef FixedPoint<tRawType, tIntegerBits> InputF;
    727   typedef FixedPoint<tRawType, 0> ResultF;
    728   static const int kFractionalBits = InputF::kFractionalBits;
    729   static const int kIntegerBits = InputF::kIntegerBits;
    730   static const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
    731   InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
    732   InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
    733   ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
    734       Rescale<0>(a_mod_quarter_minus_one_quarter));
    735   tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw();
    736 
    737 #define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)         \
    738   if (kIntegerBits > Exponent) {                                            \
    739     const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(       \
    740         ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \
    741     static constexpr int kShiftAmount =                                     \
    742         kIntegerBits > Exponent ? kFractionalBits + Exponent : 0;           \
    743     result = SelectUsingMask(                                               \
    744         MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \
    745         result * kMultiplier, result);                                      \
    746   }
    747 
    748   GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);
    749   GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);
    750   GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);
    751   GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);
    752   GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);
    753   GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);
    754   GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);
    755 
    756 #undef GEMMLOWP_EXP_BARREL_SHIFTER
    757 
    758   if (kIntegerBits > 5) {
    759     static const int b = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
    760     const InputF clamp =
    761         GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0);
    762     result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
    763   }
    764 
    765   result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result);
    766   return result;
    767 }
    768 
    769 // Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)).
    770 
    771 // Returns (1 - x) / (1 + x) for x in (0, 1).
    772 template <typename tRawType>
    773 FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1(
    774     FixedPoint<tRawType, 0> a) {
    775   typedef FixedPoint<tRawType, 0> F0;
    776   typedef FixedPoint<tRawType, 2> F2;
    777   F0 half_denominator = RoundingHalfSum(a, F0::One());
    778   // Newton-Raphson division
    779   // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
    780   // Refer to that page for the logic behind the 48/17 and 32/17 constants.
    781   const F2 constant_48_over_17 =
    782       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
    783   const F2 constant_neg_32_over_17 =
    784       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
    785   F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
    786   for (int i = 0; i < 3; i++) {
    787     F2 half_denominator_times_x = half_denominator * x;
    788     F2 one_minus_half_denominator_times_x =
    789         F2::One() - half_denominator_times_x;
    790     x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
    791   }
    792   return Rescale<0>(x - F2::One());
    793 }
    794 
    795 // Returns -tanh(x) for x < 0.
    796 template <typename tRawType, int tIntegerBits>
    797 FixedPoint<tRawType, 0> neg_tanh_on_negative_values(
    798     FixedPoint<tRawType, tIntegerBits> a) {
    799   return one_minus_x_over_one_plus_x_for_x_in_0_1(
    800       exp_on_negative_values(ExactMulByPot<1>(a)));
    801 }
    802 
    803 // Returns tanh(x) for any x.
    804 template <typename tRawType, int tIntegerBits>
    805 FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
    806   typedef FixedPoint<tRawType, tIntegerBits> InputF;
    807   typedef FixedPoint<tRawType, 0> ResultF;
    808   tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero());
    809   tRawType mask_if_zero = MaskIfZero(a);
    810   InputF n = SelectUsingMask(mask_if_negative, a, -a);
    811   ResultF t = neg_tanh_on_negative_values(n);
    812   return SelectUsingMask(mask_if_zero, ResultF::Zero(),
    813                          SelectUsingMask(mask_if_negative, -t, t));
    814 }
    815 
    816 // Implementation of logistic function.
    817 
    818 // Returns 1 / (1 + x) for x in (0, 1).
    819 template <typename tRawType>
    820 FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1(
    821     FixedPoint<tRawType, 0> a) {
    822   typedef FixedPoint<tRawType, 0> F0;
    823   typedef FixedPoint<tRawType, 2> F2;
    824   F0 half_denominator = RoundingHalfSum(a, F0::One());
    825   // Newton-Raphson division
    826   // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
    827   // Refer to that page for the logic behind the 48/17 and 32/17 constants.
    828   const F2 constant_48_over_17 =
    829       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
    830   const F2 constant_neg_32_over_17 =
    831       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
    832   F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
    833   for (int i = 0; i < 3; i++) {
    834     F2 half_denominator_times_x = half_denominator * x;
    835     F2 one_minus_half_denominator_times_x =
    836         F2::One() - half_denominator_times_x;
    837     x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
    838   }
    839   return Rescale<0>(ExactMulByPot<-1>(x));
    840 }
    841 
    842 // Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0.
    843 template <typename tRawType, int tIntegerBits>
    844 FixedPoint<tRawType, 0> logistic_on_positive_values(
    845     FixedPoint<tRawType, tIntegerBits> a) {
    846   return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a));
    847 }
    848 
    849 // Returns logistic(x) = 1 / (1 + exp(-x)) for any x.
    850 template <typename tRawType, int tIntegerBits>
    851 FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
    852   typedef FixedPoint<tRawType, tIntegerBits> InputF;
    853   typedef FixedPoint<tRawType, 0> ResultF;
    854   tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero());
    855   tRawType mask_if_zero = MaskIfZero(a);
    856   InputF abs_input = SelectUsingMask(mask_if_positive, a, -a);
    857   ResultF result_if_positive = logistic_on_positive_values(abs_input);
    858   ResultF result_if_negative = ResultF::One() - result_if_positive;
    859   const ResultF one_half =
    860       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5);
    861   return SelectUsingMask(mask_if_zero, one_half,
    862                          SelectUsingMask(mask_if_positive, result_if_positive,
    863                                          result_if_negative));
    864 }
    865 
    866 }  // end namespace gemmlowp
    867 
    868 #ifdef GEMMLOWP_NEON
    869 #include "./fixedpoint_neon.h"
    870 #elif defined(GEMMLOWP_SSE4)
    871 #include "./fixedpoint_sse.h"
    872 #elif defined(GEMMLOWP_MSA)
    873 #include "./fixedpoint_msa.h"
    874 #endif
    875 
    876 #endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_H_
    877