// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// fixedpoint.h: fixed-point arithmetic, with basic operations and
// a few math functions such as tanh.

// This is only used in output.h for some specific output pipeline stages
// (tanh); most of gemmlowp uses only plain integer arithmetic, not
// fixed-point arithmetic.
// At the most basic level, we distinguish between plain integer
// arithmetic and fixed-point arithmetic by the type of multiplication
// that is used: plain integer arithmetic uses plain (overflowing)
// integer multiplication, whereas fixed-point arithmetic uses
// "multiply-high" instructions, which means using only the most
// significant bits of the product, or equivalently, multiplying
// fixed-point numbers in the [-1 .. +1] interval.
28 29 #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 30 #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 31 32 #include "common.h" 33 34 #include <limits> 35 #include <cassert> 36 37 namespace gemmlowp { 38 39 template <typename tIntegerType> 40 tIntegerType BitAnd(tIntegerType a, tIntegerType b) { 41 return a & b; 42 } 43 44 template <typename tIntegerType> 45 tIntegerType BitOr(tIntegerType a, tIntegerType b) { 46 return a | b; 47 } 48 49 template <typename tIntegerType> 50 tIntegerType BitXor(tIntegerType a, tIntegerType b) { 51 return a ^ b; 52 } 53 54 template <typename tIntegerType> 55 tIntegerType BitNot(tIntegerType a) { 56 return ~a; 57 } 58 59 template <typename tIntegerType> 60 tIntegerType Add(tIntegerType a, tIntegerType b) { 61 return a + b; 62 } 63 64 template <typename tIntegerType> 65 tIntegerType Sub(tIntegerType a, tIntegerType b) { 66 return a - b; 67 } 68 69 template <typename tIntegerType> 70 tIntegerType Neg(tIntegerType a) { 71 return -a; 72 } 73 74 template <typename tIntegerType> 75 tIntegerType ShiftLeft(tIntegerType a, int offset) { 76 return a * (1 << offset); 77 } 78 79 template <typename tIntegerType> 80 tIntegerType ShiftRight(tIntegerType a, int offset) { 81 return a / (1 << offset); 82 } 83 84 template <typename tIntegerType> 85 tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val, 86 tIntegerType else_val) { 87 return BitXor(BitAnd(if_mask, then_val), BitAnd(BitNot(if_mask), else_val)); 88 } 89 90 template <typename tIntegerType> 91 tIntegerType MaskIfNonZero(tIntegerType a) { 92 static const tIntegerType zero = 0; 93 return a ? 
BitNot(zero) : zero; 94 } 95 96 template <typename tIntegerType> 97 tIntegerType MaskIfZero(tIntegerType a) { 98 return MaskIfNonZero<tIntegerType>(!a); 99 } 100 101 template <typename tIntegerType> 102 tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) { 103 return MaskIfNonZero<tIntegerType>(a == b); 104 } 105 106 template <typename tIntegerType> 107 tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) { 108 return MaskIfNonZero<tIntegerType>(a != b); 109 } 110 111 template <typename tIntegerType> 112 tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) { 113 return MaskIfNonZero<tIntegerType>(a > b); 114 } 115 116 template <typename tIntegerType> 117 tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) { 118 return MaskIfNonZero<tIntegerType>(a >= b); 119 } 120 121 template <typename tIntegerType> 122 tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) { 123 return MaskIfNonZero<tIntegerType>(a < b); 124 } 125 126 template <typename tIntegerType> 127 tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) { 128 return MaskIfNonZero<tIntegerType>(a <= b); 129 } 130 131 template <typename tIntegerType> 132 bool All(tIntegerType a) { 133 return a; 134 } 135 136 template <typename tIntegerType> 137 bool Any(tIntegerType a) { 138 return a; 139 } 140 141 template <typename IntegerType> 142 IntegerType RoundingHalfSum(IntegerType a, IntegerType b) { 143 static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); 144 return a; 145 } 146 147 template <> 148 inline int32_t RoundingHalfSum(int32_t a, int32_t b) { 149 int64_t a64 = a; 150 int64_t b64 = b; 151 int64_t sum = a64 + b64; 152 int64_t sign = sum >= 0 ? 
1 : -1; 153 return static_cast<int32_t>((sum + sign) / 2); 154 } 155 156 template <typename IntegerType> 157 IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) { 158 static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); 159 return a; 160 } 161 162 // This function implements the same computation as the ARMv7 NEON VQRDMULH 163 // instruction. 164 template <> 165 inline int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) { 166 bool overflow = a == b && a == std::numeric_limits<int32_t>::min(); 167 int64_t a_64(a); 168 int64_t b_64(b); 169 int64_t ab_64 = a_64 * b_64; 170 int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30)); 171 int32_t ab_x2_high32 = static_cast<int32_t>((ab_64 + nudge) / (1ll << 31)); 172 return overflow ? std::numeric_limits<int32_t>::max() : ab_x2_high32; 173 } 174 175 template <int Exponent, typename IntegerType, 176 int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)> 177 struct ImplSaturatingRoundingMultiplyByPOT {}; 178 179 template <int Exponent, typename IntegerType> 180 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> { 181 static IntegerType eval(IntegerType x) { return x; } 182 }; 183 184 template <int Exponent> 185 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32_t, 1> { 186 static int32_t eval(int32_t x) { 187 const int64_t min = std::numeric_limits<int32_t>::min(); 188 const int64_t max = std::numeric_limits<int32_t>::max(); 189 return x >= (1 << (31 - Exponent)) ? max : x <= -(1 << (31 - Exponent)) 190 ? min 191 : x * (1 << Exponent); 192 } 193 }; 194 195 template <int Exponent> 196 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32_t, -1> { 197 static int32_t eval(int32_t x) { 198 int32_t b = (std::abs(x) & (1 << (-Exponent - 1))) >> (-Exponent - 1); 199 int32_t nudge = x >= 0 ? 
b : -b; 200 return x / (1 << -Exponent) + nudge; 201 } 202 }; 203 204 template <int Exponent, typename IntegerType> 205 IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) { 206 return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x); 207 } 208 209 template <typename tIntegerType> 210 struct FixedPointRawTypeTraits {}; 211 212 template <> 213 struct FixedPointRawTypeTraits<int32_t> { 214 typedef int32_t ScalarRawType; 215 static const int kLanes = 1; 216 }; 217 218 template <typename tRawType> 219 tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) { 220 return x; 221 } 222 223 template <typename tRawType, int tIntegerBits> 224 class FixedPoint { 225 public: 226 typedef tRawType RawType; 227 228 typedef FixedPointRawTypeTraits<RawType> RawTypeTraits; 229 typedef typename RawTypeTraits::ScalarRawType ScalarRawType; 230 231 static const int kTotalBits = 8 * sizeof(ScalarRawType); 232 static const int kIntegerBits = tIntegerBits; 233 static const int kFractionalBits = kTotalBits - 1 - kIntegerBits; 234 static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, 235 "bad IntegerBits"); 236 237 typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType; 238 239 static const ScalarRawType ScalarRawMin() { 240 return std::numeric_limits<ScalarRawType>::min(); 241 } 242 243 static const ScalarRawType ScalarRawMax() { 244 return std::numeric_limits<ScalarRawType>::max(); 245 } 246 247 static const ScalarRawType RawMin() { 248 return VectorFromScalar(ScalarRawMin()); 249 } 250 251 static const ScalarRawType RawMax() { 252 return VectorFromScalar(ScalarRawMax()); 253 } 254 255 static FixedPoint FromRaw(RawType x) { 256 FixedPoint retval; 257 retval.raw() = x; 258 return retval; 259 } 260 261 static FixedPoint FromScalarRaw(ScalarRawType x) { 262 FixedPoint retval; 263 retval.raw() = Dup<RawType>(x); 264 return retval; 265 } 266 267 static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) { 268 return 
FromScalarRaw(x.raw()); 269 } 270 271 template <int Exponent> 272 static FixedPoint ConstantPOT() { 273 static const int kOffset = kFractionalBits + Exponent; 274 static_assert( 275 kOffset < 31, 276 "Constant not exactly representable in this fixed-point format"); 277 return FromScalarRaw(ScalarRawType(1) << kOffset); 278 } 279 280 static FixedPoint Zero() { return FromScalarRaw(0); } 281 282 static FixedPoint One() { 283 return FromScalarRaw(kIntegerBits == 0 284 ? ScalarRawMax() 285 : (ScalarRawType(1) << kFractionalBits)); 286 } 287 288 RawType raw() const { return i_; } 289 RawType& raw() { return i_; } 290 291 private: 292 RawType i_; 293 }; 294 295 template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b> 296 FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*( 297 FixedPoint<tRawType, tIntegerBits_a> a, 298 FixedPoint<tRawType, tIntegerBits_b> b) { 299 FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c; 300 c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw()); 301 return c; 302 } 303 304 template <int tExponent, typename tRawType, int tIntegerBits> 305 FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot( 306 FixedPoint<tRawType, tIntegerBits> a) { 307 FixedPoint<tRawType, tExponent + tIntegerBits> c; 308 c.raw() = a.raw(); 309 return c; 310 } 311 312 template <int tExponent, typename tRawType, int tIntegerBits> 313 FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT( 314 FixedPoint<tRawType, tIntegerBits> a) { 315 return FixedPoint<tRawType, tIntegerBits>::FromRaw( 316 SaturatingRoundingMultiplyByPOT<tExponent>(a.raw())); 317 } 318 319 #define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName) \ 320 template <typename tRawType, int tIntegerBits> \ 321 FixedPoint<tRawType, tIntegerBits> FuncName( \ 322 FixedPoint<tRawType, tIntegerBits> a) { \ 323 return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \ 324 } 325 326 #define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \ 
327 template <typename tRawType, int tIntegerBits> \ 328 FixedPoint<tRawType, tIntegerBits> FuncName( \ 329 FixedPoint<tRawType, tIntegerBits> a, \ 330 FixedPoint<tRawType, tIntegerBits> b) { \ 331 return FixedPoint<tRawType, tIntegerBits>::FromRaw( \ 332 ImplFuncName(a.raw(), b.raw())); \ 333 } 334 335 MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg) 336 MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot) 337 MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add) 338 MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub) 339 MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd) 340 MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor) 341 MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr) 342 MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum) 343 344 #undef MAKE_FIXEDPOINT_UNARY_FUNC 345 #undef MAKE_FIXEDPOINT_BINARY_FUNC 346 347 #define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName) \ 348 template <typename tRawType, int tIntegerBits> \ 349 tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \ 350 return FuncName(a.raw()); \ 351 } 352 353 #define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \ 354 template <typename tRawType, int tIntegerBits> \ 355 tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a, \ 356 FixedPoint<tRawType, tIntegerBits> b) { \ 357 return FuncName(a.raw(), b.raw()); \ 358 } 359 360 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero) 361 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero) 362 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual) 363 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual) 364 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan) 365 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual) 366 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan) 367 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual) 368 369 #undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW 370 #undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW 371 372 template <typename tRawType, int tIntegerBits> 373 FixedPoint<tRawType, 
tIntegerBits> SelectUsingMask( 374 tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val, 375 FixedPoint<tRawType, tIntegerBits> else_val) { 376 return FixedPoint<tRawType, tIntegerBits>::FromRaw( 377 SelectUsingMask(if_mask, then_val.raw(), else_val.raw())); 378 } 379 380 template <typename tRawType, int tIntegerBits> 381 bool operator==(FixedPoint<tRawType, tIntegerBits> a, 382 FixedPoint<tRawType, tIntegerBits> b) { 383 return All(MaskIfEqual(a.raw(), b.raw())); 384 } 385 386 template <typename tRawType, int tIntegerBits> 387 bool operator!=(FixedPoint<tRawType, tIntegerBits> a, 388 FixedPoint<tRawType, tIntegerBits> b) { 389 return !(a == b); 390 } 391 392 template <typename tRawType, int tIntegerBits> 393 double ToDouble(FixedPoint<tRawType, tIntegerBits> x) { 394 static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1, 395 "not applicable to SIMD types"); 396 typedef FixedPoint<tRawType, tIntegerBits> F; 397 return x.raw() / double(1ll << F::kFractionalBits); 398 } 399 400 template <typename tRawType, int tIntegerBits> 401 FixedPoint<tRawType, tIntegerBits> ToFixedPoint(double x) { 402 typedef FixedPoint<tRawType, tIntegerBits> F; 403 return F::FromScalarRaw(static_cast<int32_t>( 404 std::min(std::max(round(x * double(1ll << F::kFractionalBits)), 405 double(F::ScalarRawMin())), 406 double(F::ScalarRawMax())))); 407 } 408 409 template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc> 410 FixedPoint<tRawType, tIntegerBitsDst> Rescale( 411 FixedPoint<tRawType, tIntegerBitsSrc> x) { 412 static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst; 413 FixedPoint<tRawType, tIntegerBitsDst> result; 414 result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw()); 415 return result; 416 } 417 418 #ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS 419 template <typename FixedPointType> 420 FixedPointType CheckedFixedPointConstant( 421 typename FixedPointType::ScalarRawType raw_value, double double_value) { 422 typedef typename 
FixedPointType::RawType RawType; 423 static const int kIntegerBits = FixedPointType::kIntegerBits; 424 FixedPointType ref = FixedPointType::FromScalarRaw(raw_value); 425 FixedPointType check = ToFixedPoint<RawType, kIntegerBits>(double_value); 426 assert(ref == check); 427 return ref; 428 } 429 #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \ 430 DoubleValue) \ 431 (CheckedFixedPointConstant<FixedPointType>(ScalarRawValue, DoubleValue)) 432 433 #else 434 #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \ 435 DoubleValue) \ 436 (FixedPointType::FromScalarRaw(ScalarRawValue)) 437 #endif 438 439 template <typename tRawType> 440 FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl( 441 FixedPoint<tRawType, 0> a) { 442 typedef FixedPoint<tRawType, 0> F; 443 const F constant_term = 444 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0)); 445 const F constant_1_over_3 = 446 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0); 447 // We're evaluating a Taylor expansion around -1/8, so we do the change of 448 // variable: x = a + 1/8. 449 // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28. 
450 F x = a + F::template ConstantPOT<-3>(); 451 F x2 = x * x; 452 F x3 = x2 * x; 453 F x4 = x2 * x2; 454 F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4); 455 F x4_over_24_plus_x3_over_6_plus_x2_over_2 = 456 SaturatingRoundingMultiplyByPOT<-1>( 457 ((x4_over_4 + x3) * constant_1_over_3) + x2); 458 return constant_term + 459 constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2); 460 } 461 462 template <typename tRawType, int tIntegerBits> 463 FixedPoint<tRawType, 0> exp_on_negative_values( 464 FixedPoint<tRawType, tIntegerBits> a) { 465 typedef FixedPoint<tRawType, tIntegerBits> InputF; 466 typedef FixedPoint<tRawType, 0> ResultF; 467 static const int kFractionalBits = InputF::kFractionalBits; 468 static const int kIntegerBits = InputF::kIntegerBits; 469 static const InputF kOneQuarter = InputF::template ConstantPOT<-2>(); 470 InputF mask = kOneQuarter - InputF::FromScalarRaw(1); 471 InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter; 472 ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl( 473 Rescale<0>(a_mod_quarter_minus_one_quarter)); 474 tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw(); 475 476 #define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier) \ 477 if (kIntegerBits > Exponent) { \ 478 const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( \ 479 ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \ 480 result = SelectUsingMask( \ 481 MaskIfNonZero(BitAnd( \ 482 remainder, Dup<tRawType>(1 << (kFractionalBits + Exponent)))), \ 483 result * kMultiplier, result); \ 484 } 485 486 GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947); 487 GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674); 488 GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084); 489 GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308); 490 GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535); 491 GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401); 492 GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242); 493 494 #undef GEMMLOWP_EXP_BARREL_SHIFTER 495 496 if 
(kIntegerBits > 5) { 497 static const int b = kIntegerBits > 5 ? kFractionalBits + 5 : 0; 498 const InputF clamp = 499 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0); 500 result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result); 501 } 502 503 result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result); 504 return result; 505 } 506 507 template <typename tRawType> 508 FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1( 509 FixedPoint<tRawType, 0> a) { 510 typedef FixedPoint<tRawType, 0> F0; 511 typedef FixedPoint<tRawType, 2> F2; 512 F0 half_denominator = RoundingHalfSum(a, F0::One()); 513 const F2 constant_48_over_17 = 514 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0); 515 const F2 constant_neg_32_over_17 = 516 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0); 517 F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17; 518 for (int i = 0; i < 3; i++) { 519 F2 half_denominator_times_x = half_denominator * x; 520 F2 one_minus_half_denominator_times_x = 521 F2::One() - half_denominator_times_x; 522 x = x + Rescale<2>(x * one_minus_half_denominator_times_x); 523 } 524 return Rescale<0>(x - F2::One()); 525 } 526 527 template <typename tRawType, int tIntegerBits> 528 FixedPoint<tRawType, 0> neg_tanh_on_negative_values( 529 FixedPoint<tRawType, tIntegerBits> a) { 530 return one_minus_x_over_one_plus_x_for_x_in_0_1( 531 exp_on_negative_values(ExactMulByPot<1>(a))); 532 } 533 534 template <typename tRawType, int tIntegerBits> 535 FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) { 536 typedef FixedPoint<tRawType, tIntegerBits> InputF; 537 typedef FixedPoint<tRawType, 0> ResultF; 538 tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero()); 539 tRawType mask_if_zero = MaskIfZero(a); 540 InputF n = SelectUsingMask(mask_if_negative, a, -a); 541 ResultF t = neg_tanh_on_negative_values(n); 542 return SelectUsingMask(mask_if_zero, ResultF::Zero(), 543 
SelectUsingMask(mask_if_negative, -t, t)); 544 } 545 546 } // end namespace gemmlowp 547 548 #ifdef GEMMLOWP_NEON 549 #include "fixedpoint_neon.h" 550 #endif 551 552 #endif // GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 553