Home | History | Annotate | Download | only in internal
      1 // Copyright 2015 Google Inc. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // fixedpoint.h: fixed-point arithmetic, with basic operations and
     16 // a few math functions such as tanh.
     17 
     18 // This is only used in output.h
     19 // for some specific output pipeline stages (tanh); most of gemmlowp
     20 // uses only plain integer arithmetic, not fixed-point arithmetic.
     21 // At the most basic level, we distinguish between plain integer
     22 // arithmetic and fixed-point arithmetic by the type of multiplication
     23 // that is used: plain integer arithmetic uses plain (overflowing)
     24 // integer multiplication, whereas fixed-point arithmetic uses
     25 // "multiply-high" instructions, which means using only the most
     26 // significant bits of the product, or equivalently, multiplying
     27 // fixed-point numbers in the [-1 .. +1] interval.
     28 
     29 #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
     30 #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_
     31 
#include "common.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <type_traits>
     36 
     37 namespace gemmlowp {
     38 
     39 template <typename tIntegerType>
     40 tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
     41   return a & b;
     42 }
     43 
     44 template <typename tIntegerType>
     45 tIntegerType BitOr(tIntegerType a, tIntegerType b) {
     46   return a | b;
     47 }
     48 
     49 template <typename tIntegerType>
     50 tIntegerType BitXor(tIntegerType a, tIntegerType b) {
     51   return a ^ b;
     52 }
     53 
     54 template <typename tIntegerType>
     55 tIntegerType BitNot(tIntegerType a) {
     56   return ~a;
     57 }
     58 
     59 template <typename tIntegerType>
     60 tIntegerType Add(tIntegerType a, tIntegerType b) {
     61   return a + b;
     62 }
     63 
     64 template <typename tIntegerType>
     65 tIntegerType Sub(tIntegerType a, tIntegerType b) {
     66   return a - b;
     67 }
     68 
     69 template <typename tIntegerType>
     70 tIntegerType Neg(tIntegerType a) {
     71   return -a;
     72 }
     73 
     74 template <typename tIntegerType>
     75 tIntegerType ShiftLeft(tIntegerType a, int offset) {
     76   return a * (1 << offset);
     77 }
     78 
     79 template <typename tIntegerType>
     80 tIntegerType ShiftRight(tIntegerType a, int offset) {
     81   return a / (1 << offset);
     82 }
     83 
     84 template <typename tIntegerType>
     85 tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
     86                              tIntegerType else_val) {
     87   return BitXor(BitAnd(if_mask, then_val), BitAnd(BitNot(if_mask), else_val));
     88 }
     89 
     90 template <typename tIntegerType>
     91 tIntegerType MaskIfNonZero(tIntegerType a) {
     92   static const tIntegerType zero = 0;
     93   return a ? BitNot(zero) : zero;
     94 }
     95 
     96 template <typename tIntegerType>
     97 tIntegerType MaskIfZero(tIntegerType a) {
     98   return MaskIfNonZero<tIntegerType>(!a);
     99 }
    100 
    101 template <typename tIntegerType>
    102 tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
    103   return MaskIfNonZero<tIntegerType>(a == b);
    104 }
    105 
    106 template <typename tIntegerType>
    107 tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
    108   return MaskIfNonZero<tIntegerType>(a != b);
    109 }
    110 
    111 template <typename tIntegerType>
    112 tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
    113   return MaskIfNonZero<tIntegerType>(a > b);
    114 }
    115 
    116 template <typename tIntegerType>
    117 tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
    118   return MaskIfNonZero<tIntegerType>(a >= b);
    119 }
    120 
    121 template <typename tIntegerType>
    122 tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
    123   return MaskIfNonZero<tIntegerType>(a < b);
    124 }
    125 
    126 template <typename tIntegerType>
    127 tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
    128   return MaskIfNonZero<tIntegerType>(a <= b);
    129 }
    130 
    131 template <typename tIntegerType>
    132 bool All(tIntegerType a) {
    133   return a;
    134 }
    135 
    136 template <typename tIntegerType>
    137 bool Any(tIntegerType a) {
    138   return a;
    139 }
    140 
    141 template <typename IntegerType>
    142 IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
    143   static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
    144   return a;
    145 }
    146 
    147 template <>
    148 inline int32_t RoundingHalfSum(int32_t a, int32_t b) {
    149   int64_t a64 = a;
    150   int64_t b64 = b;
    151   int64_t sum = a64 + b64;
    152   int64_t sign = sum >= 0 ? 1 : -1;
    153   return static_cast<int32_t>((sum + sign) / 2);
    154 }
    155 
    156 template <typename IntegerType>
    157 IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
    158   static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
    159   return a;
    160 }
    161 
    162 // This function implements the same computation as the ARMv7 NEON VQRDMULH
    163 // instruction.
    164 template <>
    165 inline int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) {
    166   bool overflow = a == b && a == std::numeric_limits<int32_t>::min();
    167   int64_t a_64(a);
    168   int64_t b_64(b);
    169   int64_t ab_64 = a_64 * b_64;
    170   int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
    171   int32_t ab_x2_high32 = static_cast<int32_t>((ab_64 + nudge) / (1ll << 31));
    172   return overflow ? std::numeric_limits<int32_t>::max() : ab_x2_high32;
    173 }
    174 
    175 template <int Exponent, typename IntegerType,
    176           int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
    177 struct ImplSaturatingRoundingMultiplyByPOT {};
    178 
    179 template <int Exponent, typename IntegerType>
    180 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> {
    181   static IntegerType eval(IntegerType x) { return x; }
    182 };
    183 
    184 template <int Exponent>
    185 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32_t, 1> {
    186   static int32_t eval(int32_t x) {
    187     const int64_t min = std::numeric_limits<int32_t>::min();
    188     const int64_t max = std::numeric_limits<int32_t>::max();
    189     return x >= (1 << (31 - Exponent)) ? max : x <= -(1 << (31 - Exponent))
    190                                                    ? min
    191                                                    : x * (1 << Exponent);
    192   }
    193 };
    194 
    195 template <int Exponent>
    196 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32_t, -1> {
    197   static int32_t eval(int32_t x) {
    198     int32_t b = (std::abs(x) & (1 << (-Exponent - 1))) >> (-Exponent - 1);
    199     int32_t nudge = x >= 0 ? b : -b;
    200     return x / (1 << -Exponent) + nudge;
    201   }
    202 };
    203 
    204 template <int Exponent, typename IntegerType>
    205 IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) {
    206   return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x);
    207 }
    208 
    209 template <typename tIntegerType>
    210 struct FixedPointRawTypeTraits {};
    211 
    212 template <>
    213 struct FixedPointRawTypeTraits<int32_t> {
    214   typedef int32_t ScalarRawType;
    215   static const int kLanes = 1;
    216 };
    217 
    218 template <typename tRawType>
    219 tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
    220   return x;
    221 }
    222 
    223 template <typename tRawType, int tIntegerBits>
    224 class FixedPoint {
    225  public:
    226   typedef tRawType RawType;
    227 
    228   typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
    229   typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
    230 
    231   static const int kTotalBits = 8 * sizeof(ScalarRawType);
    232   static const int kIntegerBits = tIntegerBits;
    233   static const int kFractionalBits = kTotalBits - 1 - kIntegerBits;
    234   static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
    235                 "bad IntegerBits");
    236 
    237   typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType;
    238 
    239   static const ScalarRawType ScalarRawMin() {
    240     return std::numeric_limits<ScalarRawType>::min();
    241   }
    242 
    243   static const ScalarRawType ScalarRawMax() {
    244     return std::numeric_limits<ScalarRawType>::max();
    245   }
    246 
    247   static const ScalarRawType RawMin() {
    248     return VectorFromScalar(ScalarRawMin());
    249   }
    250 
    251   static const ScalarRawType RawMax() {
    252     return VectorFromScalar(ScalarRawMax());
    253   }
    254 
    255   static FixedPoint FromRaw(RawType x) {
    256     FixedPoint retval;
    257     retval.raw() = x;
    258     return retval;
    259   }
    260 
    261   static FixedPoint FromScalarRaw(ScalarRawType x) {
    262     FixedPoint retval;
    263     retval.raw() = Dup<RawType>(x);
    264     return retval;
    265   }
    266 
    267   static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) {
    268     return FromScalarRaw(x.raw());
    269   }
    270 
    271   template <int Exponent>
    272   static FixedPoint ConstantPOT() {
    273     static const int kOffset = kFractionalBits + Exponent;
    274     static_assert(
    275         kOffset < 31,
    276         "Constant not exactly representable in this fixed-point format");
    277     return FromScalarRaw(ScalarRawType(1) << kOffset);
    278   }
    279 
    280   static FixedPoint Zero() { return FromScalarRaw(0); }
    281 
    282   static FixedPoint One() {
    283     return FromScalarRaw(kIntegerBits == 0
    284                              ? ScalarRawMax()
    285                              : (ScalarRawType(1) << kFractionalBits));
    286   }
    287 
    288   RawType raw() const { return i_; }
    289   RawType& raw() { return i_; }
    290 
    291  private:
    292   RawType i_;
    293 };
    294 
    295 template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b>
    296 FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(
    297     FixedPoint<tRawType, tIntegerBits_a> a,
    298     FixedPoint<tRawType, tIntegerBits_b> b) {
    299   FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c;
    300   c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw());
    301   return c;
    302 }
    303 
    304 template <int tExponent, typename tRawType, int tIntegerBits>
    305 FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(
    306     FixedPoint<tRawType, tIntegerBits> a) {
    307   FixedPoint<tRawType, tExponent + tIntegerBits> c;
    308   c.raw() = a.raw();
    309   return c;
    310 }
    311 
    312 template <int tExponent, typename tRawType, int tIntegerBits>
    313 FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(
    314     FixedPoint<tRawType, tIntegerBits> a) {
    315   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
    316       SaturatingRoundingMultiplyByPOT<tExponent>(a.raw()));
    317 }
    318 
// Defines a function template FuncName taking one FixedPoint and returning
// a FixedPoint of the same format, whose raw value is ImplFuncName applied
// to the operand's raw value.
#define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName)                     \
  template <typename tRawType, int tIntegerBits>                               \
  FixedPoint<tRawType, tIntegerBits> FuncName(                                 \
      FixedPoint<tRawType, tIntegerBits> a) {                                  \
    return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \
  }

// Same as above for two operands of identical format: the result's raw
// value is ImplFuncName applied to both raw values.
#define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \
  template <typename tRawType, int tIntegerBits>            \
  FixedPoint<tRawType, tIntegerBits> FuncName(              \
      FixedPoint<tRawType, tIntegerBits> a,                 \
      FixedPoint<tRawType, tIntegerBits> b) {               \
    return FixedPoint<tRawType, tIntegerBits>::FromRaw(     \
        ImplFuncName(a.raw(), b.raw()));                    \
  }

// Arithmetic and bitwise operators on FixedPoint, each forwarding to the
// raw-integer helper of the same meaning.
MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)

// The generator macros are only needed for the expansions above.
#undef MAKE_FIXEDPOINT_UNARY_FUNC
#undef MAKE_FIXEDPOINT_BINARY_FUNC
    346 
// Defines a FixedPoint overload of FuncName that forwards to the raw
// overload of the same name and returns its raw result (not a FixedPoint).
#define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName)  \
  template <typename tRawType, int tIntegerBits>            \
  tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
    return FuncName(a.raw());                               \
  }

// Two-operand variant of the macro above; both operands share one format.
#define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \
  template <typename tRawType, int tIntegerBits>            \
  tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a,   \
                    FixedPoint<tRawType, tIntegerBits> b) { \
    return FuncName(a.raw(), b.raw());                      \
  }

// Mask-producing predicates on FixedPoint values; each returns a raw mask.
MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero)
MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan)
MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual)

// The generator macros are only needed for the expansions above.
#undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW
#undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW
    371 
    372 template <typename tRawType, int tIntegerBits>
    373 FixedPoint<tRawType, tIntegerBits> SelectUsingMask(
    374     tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val,
    375     FixedPoint<tRawType, tIntegerBits> else_val) {
    376   return FixedPoint<tRawType, tIntegerBits>::FromRaw(
    377       SelectUsingMask(if_mask, then_val.raw(), else_val.raw()));
    378 }
    379 
    380 template <typename tRawType, int tIntegerBits>
    381 bool operator==(FixedPoint<tRawType, tIntegerBits> a,
    382                 FixedPoint<tRawType, tIntegerBits> b) {
    383   return All(MaskIfEqual(a.raw(), b.raw()));
    384 }
    385 
    386 template <typename tRawType, int tIntegerBits>
    387 bool operator!=(FixedPoint<tRawType, tIntegerBits> a,
    388                 FixedPoint<tRawType, tIntegerBits> b) {
    389   return !(a == b);
    390 }
    391 
    392 template <typename tRawType, int tIntegerBits>
    393 double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
    394   static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1,
    395                 "not applicable to SIMD types");
    396   typedef FixedPoint<tRawType, tIntegerBits> F;
    397   return x.raw() / double(1ll << F::kFractionalBits);
    398 }
    399 
    400 template <typename tRawType, int tIntegerBits>
    401 FixedPoint<tRawType, tIntegerBits> ToFixedPoint(double x) {
    402   typedef FixedPoint<tRawType, tIntegerBits> F;
    403   return F::FromScalarRaw(static_cast<int32_t>(
    404       std::min(std::max(round(x * double(1ll << F::kFractionalBits)),
    405                         double(F::ScalarRawMin())),
    406                double(F::ScalarRawMax()))));
    407 }
    408 
    409 template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
    410 FixedPoint<tRawType, tIntegerBitsDst> Rescale(
    411     FixedPoint<tRawType, tIntegerBitsSrc> x) {
    412   static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
    413   FixedPoint<tRawType, tIntegerBitsDst> result;
    414   result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
    415   return result;
    416 }
    417 
// GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue,
// DoubleValue) yields a FixedPointType built from ScalarRawValue.
// When GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS is defined, it also
// asserts that converting DoubleValue with ToFixedPoint gives the same
// raw value; otherwise DoubleValue is not evaluated at all.
#ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
// Builds the constant from raw_value and asserts that it agrees with the
// fixed-point conversion of double_value.
template <typename FixedPointType>
FixedPointType CheckedFixedPointConstant(
    typename FixedPointType::ScalarRawType raw_value, double double_value) {
  typedef typename FixedPointType::RawType RawType;
  static const int kIntegerBits = FixedPointType::kIntegerBits;
  FixedPointType ref = FixedPointType::FromScalarRaw(raw_value);
  FixedPointType check = ToFixedPoint<RawType, kIntegerBits>(double_value);
  assert(ref == check);
  return ref;
}
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \
                                             DoubleValue)                    \
  (CheckedFixedPointConstant<FixedPointType>(ScalarRawValue, DoubleValue))

#else
// Unchecked variant: only the raw value is used.
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \
                                             DoubleValue)                    \
  (FixedPointType::FromScalarRaw(ScalarRawValue))
#endif
    438 
// Approximates exp(a) for a fixed-point `a` with 0 integer bits. As the
// name says, the input is meant to lie in [-1/4, 0).
// Uses exp(a) = exp(-1/8) * exp(x) with x = a + 1/8, where exp(x) is
// replaced by its degree-4 Taylor polynomial
// 1 + x + x^2/2 + x^3/6 + x^4/24.
template <typename tRawType>
FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl(
    FixedPoint<tRawType, 0> a) {
  typedef FixedPoint<tRawType, 0> F;
  // exp(-1/8), the value of the function at the expansion point.
  const F constant_term =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0));
  const F constant_1_over_3 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0);
  // We're evaluating a Taylor expansion around -1/8, so we do the change of
  // variable: x = a + 1/8.
  // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28.
  F x = a + F::template ConstantPOT<-3>();
  F x2 = x * x;
  F x3 = x2 * x;
  F x4 = x2 * x2;
  F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4);
  // ((x^4/4 + x^3) / 3 + x^2) / 2 == x^4/24 + x^3/6 + x^2/2, which needs
  // only one multiplication by 1/3 and two power-of-two scalings.
  F x4_over_24_plus_x3_over_6_plus_x2_over_2 =
      SaturatingRoundingMultiplyByPOT<-1>(
          ((x4_over_4 + x3) * constant_1_over_3) + x2);
  // c * (1 + p) is written as c + c * p, so the value 1 itself never has
  // to be formed in this 0-integer-bit format.
  return constant_term +
         constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2);
}
    461 
    462 template <typename tRawType, int tIntegerBits>
    463 FixedPoint<tRawType, 0> exp_on_negative_values(
    464     FixedPoint<tRawType, tIntegerBits> a) {
    465   typedef FixedPoint<tRawType, tIntegerBits> InputF;
    466   typedef FixedPoint<tRawType, 0> ResultF;
    467   static const int kFractionalBits = InputF::kFractionalBits;
    468   static const int kIntegerBits = InputF::kIntegerBits;
    469   static const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
    470   InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
    471   InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
    472   ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
    473       Rescale<0>(a_mod_quarter_minus_one_quarter));
    474   tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw();
    475 
    476 #define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)         \
    477   if (kIntegerBits > Exponent) {                                            \
    478     const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(       \
    479         ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \
    480     result = SelectUsingMask(                                               \
    481         MaskIfNonZero(BitAnd(                                               \
    482             remainder, Dup<tRawType>(1 << (kFractionalBits + Exponent)))),  \
    483         result * kMultiplier, result);                                      \
    484   }
    485 
    486   GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);
    487   GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);
    488   GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);
    489   GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);
    490   GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);
    491   GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);
    492   GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);
    493 
    494 #undef GEMMLOWP_EXP_BARREL_SHIFTER
    495 
    496   if (kIntegerBits > 5) {
    497     static const int b = kIntegerBits > 5 ? kFractionalBits + 5 : 0;
    498     const InputF clamp =
    499         GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0);
    500     result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
    501   }
    502 
    503   result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result);
    504   return result;
    505 }
    506 
// Computes (1 - a) / (1 + a) for a fixed-point `a` with 0 integer bits. As
// the name says, the input is meant to lie in [0, 1].
// With d = (1 + a) / 2, the code approximates x ~= 1 / d by iteration and
// returns x - 1 = 2 / (1 + a) - 1 = (1 - a) / (1 + a).
template <typename tRawType>
FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1(
    FixedPoint<tRawType, 0> a) {
  typedef FixedPoint<tRawType, 0> F0;
  typedef FixedPoint<tRawType, 2> F2;
  // d = (a + 1) / 2, formed as a rounding half-sum.
  F0 half_denominator = RoundingHalfSum(a, F0::One());
  const F2 constant_48_over_17 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
  const F2 constant_neg_32_over_17 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
  // Initial estimate of 1 / d: the linear function 48/17 - 32/17 * d.
  F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
  // Three refinement steps of the form x <- x + x * (1 - d * x).
  for (int i = 0; i < 3; i++) {
    F2 half_denominator_times_x = half_denominator * x;
    F2 one_minus_half_denominator_times_x =
        F2::One() - half_denominator_times_x;
    // The product of two F2 values has 4 integer bits; bring it back to 2.
    x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
  }
  return Rescale<0>(x - F2::One());
}
    526 
    527 template <typename tRawType, int tIntegerBits>
    528 FixedPoint<tRawType, 0> neg_tanh_on_negative_values(
    529     FixedPoint<tRawType, tIntegerBits> a) {
    530   return one_minus_x_over_one_plus_x_for_x_in_0_1(
    531       exp_on_negative_values(ExactMulByPot<1>(a)));
    532 }
    533 
    534 template <typename tRawType, int tIntegerBits>
    535 FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
    536   typedef FixedPoint<tRawType, tIntegerBits> InputF;
    537   typedef FixedPoint<tRawType, 0> ResultF;
    538   tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero());
    539   tRawType mask_if_zero = MaskIfZero(a);
    540   InputF n = SelectUsingMask(mask_if_negative, a, -a);
    541   ResultF t = neg_tanh_on_negative_values(n);
    542   return SelectUsingMask(mask_if_zero, ResultF::Zero(),
    543                          SelectUsingMask(mask_if_negative, -t, t));
    544 }
    545 
    546 }  // end namespace gemmlowp
    547 
    548 #ifdef GEMMLOWP_NEON
    549 #include "fixedpoint_neon.h"
    550 #endif
    551 
    552 #endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_H_
    553