1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // fixedpoint.h: fixed-point arithmetic, with basic operations and 16 // a few math functions such as tanh. 17 18 #ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 19 #define GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 20 21 #include <cassert> 22 #include <limits> 23 24 #include "../internal/common.h" 25 26 namespace gemmlowp { 27 28 // Part 1: Low-level integer-arithmetic primitives. 29 // The implementations here are generic implementations valid for 30 // scalar types (e.g. std::int32_t). Architecture-specific SIMD types 31 // (e.g. NEON int32x4_t) may be supported by providing 32 // specializations for them in separate files. 33 // 34 // The purpose of these primitives is two-fold: 35 // - They will be used to implement higher-level fixed-point 36 // abstractions, namely the FixedPoint class and its arithmetic 37 // operators. 38 // - They will be directly used to implement some more involved 39 // fixed-point computations, e.g. the fixed-point implementation 40 // of math functions such as tanh. 41 42 // Some compile-time traits around raw types to handle SIMD aspects: 43 // number of lanes, underlying scalar type. 
// The primary template is empty; each supported raw type provides its own
// specialization (the std::int32_t one follows).
template <typename tIntegerType>
struct FixedPointRawTypeTraits {};

// Scalar case: a std::int32_t is a single lane whose scalar type is itself.
template <>
struct FixedPointRawTypeTraits<std::int32_t> {
  typedef std::int32_t ScalarRawType;
  static const int kLanes = 1;
};

// Returns a SIMD value duplicating a scalar value across all lanes.
// This generic version simply converts the scalar x to tRawType.
template <typename tRawType>
tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
  return x;
}

// Plain bit-wise AND
template <typename tIntegerType>
tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
  return a & b;
}

// Plain bit-wise OR
template <typename tIntegerType>
tIntegerType BitOr(tIntegerType a, tIntegerType b) {
  return a | b;
}

// Plain bit-wise XOR
template <typename tIntegerType>
tIntegerType BitXor(tIntegerType a, tIntegerType b) {
  return a ^ b;
}

// Plain bit-wise NOT
template <typename tIntegerType>
tIntegerType BitNot(tIntegerType a) {
  return ~a;
}

// Integer addition. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Add(tIntegerType a, tIntegerType b) {
  return a + b;
}

// Integer multiplication. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Mul(tIntegerType a, tIntegerType b) {
  return a * b;
}

// Integer subtraction. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Sub(tIntegerType a, tIntegerType b) {
  return a - b;
}

// Integer unary negative. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Neg(tIntegerType a) {
  return -a;
}

// Integer arithmetic left-shift, equivalent to multiplying with a
// power of two. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType ShiftLeft(tIntegerType a, int offset) {
  return a << offset;
}

// Integer arithmetic right-shift. Not rounding.
// Relying on implementation-defined, but in-practice-consistent,
// C++ compiler behavior.
template <typename tIntegerType>
tIntegerType ShiftRight(tIntegerType a, int offset) {
  const tIntegerType shifted = a >> offset;
  return shifted;
}

// Bit-wise select: every result bit is copied from then_val where the
// matching bit of if_mask is 1, and from else_val where it is 0.
// Equivalent to the VBSL instruction in ARM NEON.
template <typename tIntegerType>
tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
                             tIntegerType else_val) {
  const tIntegerType taken_from_then = BitAnd(if_mask, then_val);
  const tIntegerType taken_from_else = BitAnd(BitNot(if_mask), else_val);
  // The two halves use complementary masks, so they never share a set bit
  // and XOR merges them.
  return BitXor(taken_from_then, taken_from_else);
}

// Returns all-ones for each input scalar that is non-zero, and zero
// otherwise.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
  const tIntegerType no_bits = 0;
  if (a) {
    return BitNot(no_bits);
  }
  return no_bits;
}

// Returns all-ones for each input scalar that is zero, and zero otherwise.
template <typename tIntegerType>
tIntegerType MaskIfZero(tIntegerType a) {
  const tIntegerType is_zero = !a;
  return MaskIfNonZero<tIntegerType>(is_zero);
}

// Returns all-ones for each pair of input scalars that are equal, and zero
// otherwise.
template <typename tIntegerType>
tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
  const tIntegerType is_equal = (a == b);
  return MaskIfNonZero<tIntegerType>(is_equal);
}

// Returns all-ones for each pair of input scalars that differ, and zero
// otherwise.
template <typename tIntegerType>
tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
  const tIntegerType is_different = (a != b);
  return MaskIfNonZero<tIntegerType>(is_different);
}

// Returns all-ones for each pair of input scalars a, b with a > b, and zero
// otherwise.
template <typename tIntegerType>
tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
  const tIntegerType is_greater = (a > b);
  return MaskIfNonZero<tIntegerType>(is_greater);
}

// Returns all-ones for each pair of input scalars a, b with a >= b, and zero
// otherwise.
template <typename tIntegerType>
tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
  const tIntegerType is_greater_or_equal = (a >= b);
  return MaskIfNonZero<tIntegerType>(is_greater_or_equal);
}

// Returns all-ones for each pair of input scalars a, b with a < b, and zero
// otherwise.
template <typename tIntegerType>
tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
  const tIntegerType is_less = (a < b);
  return MaskIfNonZero<tIntegerType>(is_less);
}

// Returns all-ones for each pair of input scalars a, b with a <= b, and zero
// otherwise.
template <typename tIntegerType>
tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
  const tIntegerType is_less_or_equal = (a <= b);
  return MaskIfNonZero<tIntegerType>(is_less_or_equal);
}

// Returns true if all of the input scalars are nonzero.
// The input is expected to be a mask, i.e. each scalar has either all or
// none of its bits set; behavior for other inputs is currently undefined.
template <typename tIntegerType>
bool All(tIntegerType a) {
  const bool all_set = a;
  return all_set;
}

// Returns true if any of the input scalars are nonzero.
// The input is expected to be a mask, i.e. each scalar has either all or
// none of its bits set; behavior for other inputs is currently undefined.
template <typename tIntegerType>
bool Any(tIntegerType a) {
  const bool any_set = a;
  return any_set;
}

// Returns (a+b)/2, rounded to the nearest integer.
// Equivalent to VRHADD in the ARM NEON instruction set.
// Primary template: there is no type-generic implementation, so any
// instantiation trips the static_assert. Supported types are handled by
// explicit specializations.
template <typename IntegerType>
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

// std::int32_t implementation. The sum is formed in 64 bits so it cannot
// overflow. Halfway cases are rounded away from zero (e.g. a + b == -1
// yields -1).
// NOTE(review): confirm this tie-breaking agrees with the SIMD
// implementations for negative odd sums.
template <>
inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
  const std::int64_t wide_sum =
      static_cast<std::int64_t>(a) + static_cast<std::int64_t>(b);
  // Push the sum one step away from zero so that the truncating division
  // below rounds to nearest instead of toward zero.
  const std::int64_t biased_sum = (wide_sum < 0) ? wide_sum - 1 : wide_sum + 1;
  return static_cast<std::int32_t>(biased_sum / 2);
}

// Returns the integer that represents the product of two fixed-point
// numbers, interpreting all integers as fixed-point values in the
// interval [-1, 1), rounding to the nearest value, and saturating
// -1 * -1 to the maximum value (since 1 is not in the half-open
// interval [-1, 1)).
//
// [The explanation below specializes to std::int32_t for example purpose.]
//
// The mapping between IntegerType and the interval [-1, 1) is unique and
// implied by IntegerType, which is assumed to be signed. For example,
// for IntegerType==std::int32_t, the mapping is
//   real_value = integer_value / 2^31.
// So in this case, and leaving aside rounding and saturating, this
// function computes ((a / 2^31) * (b / 2^31)) * 2^31, which simplifies to
//   (a * b) / 2^31.
//
// The 'doubling' part in the name of this function comes from the fact that
// this operation is very close to a "multiply-high" operation, keeping only
// the top half bits, except that that would be effectively computing
//   (a * b) / 2^32,
// so here we are computing 2x that, since
//   1/2^31 = 2 * 1/2^32.
// The idea is to use all of the available 32 bits in the destination int32
// value.
//
// [End of the explanation specializing to int32.]
//
// This is equivalent to the VQRDMULH instruction in ARM NEON.
248 template <typename IntegerType> 249 IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) { 250 static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); 251 return a; 252 } 253 254 // This function implements the same computation as the ARMv7 NEON VQRDMULH 255 // instruction. 256 template <> 257 inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, 258 std::int32_t b) { 259 bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min(); 260 std::int64_t a_64(a); 261 std::int64_t b_64(b); 262 std::int64_t ab_64 = a_64 * b_64; 263 std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30)); 264 std::int32_t ab_x2_high32 = 265 static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31)); 266 return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32; 267 } 268 269 // Correctly-rounded-to-nearest division by a power-of-two. 270 // Also known as a rounding arithmetic right shift. 271 template <typename IntegerType> 272 inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) { 273 using ScalarIntegerType = 274 typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType; 275 static_assert(std::is_same<ScalarIntegerType, std::int32_t>::value, 276 "Currently only supporting int32 scalar and SIMD types"); 277 assert(exponent >= 0); 278 assert(exponent <= 31); 279 const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1); 280 const IntegerType zero = Dup<IntegerType>(0); 281 const IntegerType one = Dup<IntegerType>(1); 282 const IntegerType remainder = BitAnd(x, mask); 283 const IntegerType threshold = 284 Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one)); 285 return Add(ShiftRight(x, exponent), 286 BitAnd(MaskIfGreaterThan(remainder, threshold), one)); 287 } 288 289 // Returns the product of a run-time integer value by a compile-time power 290 // of two, with either a positive exponent (equivalent to an arithmetic 291 // left shift, saturating) or a negative 
exponent (equivalent to an arithmetic 292 // right shift, rounding to nearest). 293 template <int Exponent, typename IntegerType, 294 int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)> 295 struct ImplSaturatingRoundingMultiplyByPOT {}; 296 297 template <int Exponent, typename IntegerType> 298 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> { 299 static IntegerType eval(IntegerType x) { return x; } 300 }; 301 302 template <int Exponent, typename IntegerType> 303 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> { 304 static IntegerType eval(IntegerType x) { 305 using ScalarIntegerType = 306 typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType; 307 static_assert(std::is_same<ScalarIntegerType, std::int32_t>::value, 308 "Currently only supporting int32 scalar and SIMD types"); 309 const IntegerType min = 310 Dup<IntegerType>(std::numeric_limits<std::int32_t>::min()); 311 const IntegerType max = 312 Dup<IntegerType>(std::numeric_limits<std::int32_t>::max()); 313 314 const std::int32_t threshold = ((1 << (31 - Exponent)) - 1); 315 const IntegerType positive_mask = 316 MaskIfGreaterThan(x, Dup<IntegerType>(threshold)); 317 const IntegerType negative_mask = 318 MaskIfLessThan(x, Dup<IntegerType>(-threshold)); 319 320 IntegerType result = ShiftLeft(x, Exponent); 321 result = SelectUsingMask(positive_mask, max, result); 322 result = SelectUsingMask(negative_mask, min, result); 323 return result; 324 } 325 }; 326 327 template <int Exponent, typename IntegerType> 328 struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> { 329 static IntegerType eval(IntegerType x) { 330 return RoundingDivideByPOT<IntegerType>(x, -Exponent); 331 } 332 }; 333 334 template <int Exponent, typename IntegerType> 335 IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) { 336 return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x); 337 } 338 339 // Part 2: the FixedPoint class. 
340 341 // A FixedPoint object represents a fixed-point value stored in the underlying 342 // integer type tRawType, if tRawType is a plain scalar integer type. 343 // Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which 344 // case a FixedPoint object represents a corresponding SIMD vector of fixed 345 // point values. 346 // 347 // tIntegerBits describes the range of the fixed-point format: if 348 // tIntegerBits == m then the range of representable values is the half-open 349 // interval [-2^m; 2^m) where the open boundary on the right side means that 350 // 2^m is not representable (how close the maximum representable value is to 351 // it, depends on bit-depth of tRawType). 352 // 353 // In "Q format notation", 354 // https://en.wikipedia.org/wiki/Q_(number_format) 355 // we are describing the format 356 // Qm.n 357 // where 358 // m = tIntegerBits 359 // and 360 // n = NumberOfBits(tRawType) - (m + 1) 361 // Note that the (m + 1) in the above line is because we adopt the convention 362 // that we count the integer bits exclusively of the sign bit; so (m + 1) is 363 // the total number of integer bits inclusive of the sign bit. 364 // 365 // Accordingly, the number of integral representable values in our range 366 // [-2^m ; 2^m) 367 // is equal to 2^(m+1). 
368 template <typename tRawType, int tIntegerBits> 369 class FixedPoint { 370 public: 371 typedef tRawType RawType; 372 373 typedef FixedPointRawTypeTraits<RawType> RawTypeTraits; 374 typedef typename RawTypeTraits::ScalarRawType ScalarRawType; 375 376 static const int kTotalBits = 8 * sizeof(ScalarRawType); 377 static const int kIntegerBits = tIntegerBits; 378 static const int kFractionalBits = kTotalBits - 1 - kIntegerBits; 379 static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, 380 "bad IntegerBits"); 381 382 typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType; 383 384 static const ScalarRawType ScalarRawMin() { 385 return std::numeric_limits<ScalarRawType>::min(); 386 } 387 388 static const ScalarRawType ScalarRawMax() { 389 return std::numeric_limits<ScalarRawType>::max(); 390 } 391 392 static const ScalarRawType RawMin() { 393 return VectorFromScalar(ScalarRawMin()); 394 } 395 396 static const ScalarRawType RawMax() { 397 return VectorFromScalar(ScalarRawMax()); 398 } 399 400 static FixedPoint FromRaw(RawType x) { 401 FixedPoint retval; 402 retval.raw() = x; 403 return retval; 404 } 405 406 static FixedPoint FromScalarRaw(ScalarRawType x) { 407 FixedPoint retval; 408 retval.raw() = Dup<RawType>(x); 409 return retval; 410 } 411 412 static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) { 413 return FromScalarRaw(x.raw()); 414 } 415 416 template <int Exponent> 417 static FixedPoint ConstantPOT() { 418 static const int kOffset = kFractionalBits + Exponent; 419 static_assert( 420 kOffset < 31, 421 "Constant not exactly representable in this fixed-point format"); 422 return FromScalarRaw(ScalarRawType(1) << kOffset); 423 } 424 425 static FixedPoint Zero() { return FromScalarRaw(0); } 426 427 static FixedPoint One() { 428 return FromScalarRaw(kIntegerBits == 0 429 ? 
ScalarRawMax() 430 : (ScalarRawType(1) << kFractionalBits)); 431 } 432 433 static FixedPoint FromDouble(double x) { 434 const double min_bound = static_cast<double>(ScalarRawMin()); 435 const double max_bound = static_cast<double>(ScalarRawMax()); 436 return FromScalarRaw(static_cast<std::int32_t>(std::min( 437 std::max(round(x * static_cast<double>(1ll << kFractionalBits)), 438 min_bound), 439 max_bound))); 440 } 441 442 RawType raw() const { return i_; } 443 RawType& raw() { return i_; } 444 445 private: 446 RawType i_; 447 }; 448 449 // Part 3: implementation of arithmetic operators for the 450 // FixedPoint class, and a few related functions. 451 452 // A FixedPoint multiplication is just a 453 // SaturatingRoundingDoublingHighMul operation on the underlying 454 // raw integer values. The IntegerBits simply add up, as is obvious 455 // from the fact that the range is [-2^IntegerBits, 2^IntegerBits). 456 template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b> 457 FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*( 458 FixedPoint<tRawType, tIntegerBits_a> a, 459 FixedPoint<tRawType, tIntegerBits_b> b) { 460 FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c; 461 c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw()); 462 return c; 463 } 464 465 // Tweaking IntegerBits gives exact multiplication by a power of two. 466 template <int tExponent, typename tRawType, int tIntegerBits> 467 FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot( 468 FixedPoint<tRawType, tIntegerBits> a) { 469 FixedPoint<tRawType, tExponent + tIntegerBits> c; 470 c.raw() = a.raw(); 471 return c; 472 } 473 474 // If we want to leave IntegerBits fixed, then multiplication 475 // by a power of two has to be saturating/rounding, not exact anymore. 
476 template <int tExponent, typename tRawType, int tIntegerBits> 477 FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT( 478 FixedPoint<tRawType, tIntegerBits> a) { 479 return FixedPoint<tRawType, tIntegerBits>::FromRaw( 480 SaturatingRoundingMultiplyByPOT<tExponent>(a.raw())); 481 } 482 483 // Generic arithmetic operators. 484 485 #define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName) \ 486 template <typename tRawType, int tIntegerBits> \ 487 FixedPoint<tRawType, tIntegerBits> FuncName( \ 488 FixedPoint<tRawType, tIntegerBits> a) { \ 489 return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \ 490 } 491 492 #define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \ 493 template <typename tRawType, int tIntegerBits> \ 494 FixedPoint<tRawType, tIntegerBits> FuncName( \ 495 FixedPoint<tRawType, tIntegerBits> a, \ 496 FixedPoint<tRawType, tIntegerBits> b) { \ 497 return FixedPoint<tRawType, tIntegerBits>::FromRaw( \ 498 ImplFuncName(a.raw(), b.raw())); \ 499 } 500 501 MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg) 502 MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot) 503 MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add) 504 MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub) 505 MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd) 506 MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor) 507 MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr) 508 MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum) 509 510 #undef MAKE_FIXEDPOINT_UNARY_FUNC 511 #undef MAKE_FIXEDPOINT_BINARY_FUNC 512 513 #define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName) \ 514 template <typename tRawType, int tIntegerBits> \ 515 tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \ 516 return FuncName(a.raw()); \ 517 } 518 519 #define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \ 520 template <typename tRawType, int tIntegerBits> \ 521 tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a, \ 522 FixedPoint<tRawType, tIntegerBits> b) { \ 523 return FuncName(a.raw(), 
b.raw()); \ 524 } 525 526 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero) 527 MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero) 528 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual) 529 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual) 530 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan) 531 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual) 532 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan) 533 MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual) 534 535 #undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW 536 #undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW 537 538 template <typename tRawType, int tIntegerBits> 539 FixedPoint<tRawType, tIntegerBits> SelectUsingMask( 540 tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val, 541 FixedPoint<tRawType, tIntegerBits> else_val) { 542 return FixedPoint<tRawType, tIntegerBits>::FromRaw( 543 SelectUsingMask(if_mask, then_val.raw(), else_val.raw())); 544 } 545 546 template <typename tRawType, int tIntegerBits> 547 bool operator==(FixedPoint<tRawType, tIntegerBits> a, 548 FixedPoint<tRawType, tIntegerBits> b) { 549 return All(MaskIfEqual(a.raw(), b.raw())); 550 } 551 552 template <typename tRawType, int tIntegerBits> 553 bool operator!=(FixedPoint<tRawType, tIntegerBits> a, 554 FixedPoint<tRawType, tIntegerBits> b) { 555 return !(a == b); 556 } 557 558 // Conversion to floating-point. 559 template <typename tRawType, int tIntegerBits> 560 double ToDouble(FixedPoint<tRawType, tIntegerBits> x) { 561 static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1, 562 "not applicable to SIMD types"); 563 typedef FixedPoint<tRawType, tIntegerBits> F; 564 return x.raw() / static_cast<double>(1ll << F::kFractionalBits); 565 } 566 567 // Rescale changes the number of IntegerBits and updates the underlying 568 // raw integer value accordingly. 
569 template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc> 570 FixedPoint<tRawType, tIntegerBitsDst> Rescale( 571 FixedPoint<tRawType, tIntegerBitsSrc> x) { 572 static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst; 573 FixedPoint<tRawType, tIntegerBitsDst> result; 574 result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw()); 575 return result; 576 } 577 578 // CheckedFixedPointConstant allows to specify fixed-point constants 579 // initialized as real numbers, in a way that does not compile floating-point 580 // arithmetic in production code, yet still checks agreement with the 581 // floating-point expressions when asserts are enabled. 582 #ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS 583 template <typename FixedPointType> 584 FixedPointType CheckedFixedPointConstant( 585 typename FixedPointType::ScalarRawType raw_value, double double_value) { 586 typedef typename FixedPointType::RawType RawType; 587 const FixedPointType result = FixedPointType::FromScalarRaw(raw_value); 588 assert(result == FixedPointType::FromDouble(double_value)); 589 return result; 590 } 591 #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \ 592 DoubleValue) \ 593 (CheckedFixedPointConstant<FixedPointType>(ScalarRawValue, DoubleValue)) 594 595 #else 596 #define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \ 597 DoubleValue) \ 598 (FixedPointType::FromScalarRaw(ScalarRawValue)) 599 #endif 600 601 // Implementation of exponential function. 602 603 // Returns exp(x) for x in [-1/4, 0). 
604 template <typename tRawType> 605 FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl( 606 FixedPoint<tRawType, 0> a) { 607 typedef FixedPoint<tRawType, 0> F; 608 const F constant_term = 609 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0)); 610 const F constant_1_over_3 = 611 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0); 612 // We're evaluating a Taylor expansion around -1/8, so we do the change of 613 // variable: x = a + 1/8. 614 // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28. 615 F x = a + F::template ConstantPOT<-3>(); 616 F x2 = x * x; 617 F x3 = x2 * x; 618 F x4 = x2 * x2; 619 F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4); 620 F x4_over_24_plus_x3_over_6_plus_x2_over_2 = 621 SaturatingRoundingMultiplyByPOT<-1>( 622 ((x4_over_4 + x3) * constant_1_over_3) + x2); 623 return constant_term + 624 constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2); 625 } 626 627 // Returns exp(x) for x < 0. 
628 template <typename tRawType, int tIntegerBits> 629 FixedPoint<tRawType, 0> exp_on_negative_values( 630 FixedPoint<tRawType, tIntegerBits> a) { 631 typedef FixedPoint<tRawType, tIntegerBits> InputF; 632 typedef FixedPoint<tRawType, 0> ResultF; 633 static const int kFractionalBits = InputF::kFractionalBits; 634 static const int kIntegerBits = InputF::kIntegerBits; 635 static const InputF kOneQuarter = InputF::template ConstantPOT<-2>(); 636 InputF mask = kOneQuarter - InputF::FromScalarRaw(1); 637 InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter; 638 ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl( 639 Rescale<0>(a_mod_quarter_minus_one_quarter)); 640 tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw(); 641 642 #define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier) \ 643 if (kIntegerBits > Exponent) { \ 644 const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( \ 645 ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \ 646 static constexpr int kShiftAmount = \ 647 kIntegerBits > Exponent ? kFractionalBits + Exponent : 0; \ 648 result = SelectUsingMask( \ 649 MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \ 650 result * kMultiplier, result); \ 651 } 652 653 GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947); 654 GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674); 655 GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084); 656 GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308); 657 GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535); 658 GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401); 659 GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242); 660 661 #undef GEMMLOWP_EXP_BARREL_SHIFTER 662 663 if (kIntegerBits > 5) { 664 static const int b = kIntegerBits > 5 ? 
kFractionalBits + 5 : 0; 665 const InputF clamp = 666 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0); 667 result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result); 668 } 669 670 result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result); 671 return result; 672 } 673 674 // Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)). 675 676 // Returns (1 - x) / (1 + x) for x in (0, 1). 677 template <typename tRawType> 678 FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1( 679 FixedPoint<tRawType, 0> a) { 680 typedef FixedPoint<tRawType, 0> F0; 681 typedef FixedPoint<tRawType, 2> F2; 682 F0 half_denominator = RoundingHalfSum(a, F0::One()); 683 // Newton-Raphson division 684 // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division 685 // Refer to that page for the logic behind the 48/17 and 32/17 constants. 686 const F2 constant_48_over_17 = 687 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0); 688 const F2 constant_neg_32_over_17 = 689 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0); 690 F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17; 691 for (int i = 0; i < 3; i++) { 692 F2 half_denominator_times_x = half_denominator * x; 693 F2 one_minus_half_denominator_times_x = 694 F2::One() - half_denominator_times_x; 695 x = x + Rescale<2>(x * one_minus_half_denominator_times_x); 696 } 697 return Rescale<0>(x - F2::One()); 698 } 699 700 // Returns -tanh(x) for x < 0. 701 template <typename tRawType, int tIntegerBits> 702 FixedPoint<tRawType, 0> neg_tanh_on_negative_values( 703 FixedPoint<tRawType, tIntegerBits> a) { 704 return one_minus_x_over_one_plus_x_for_x_in_0_1( 705 exp_on_negative_values(ExactMulByPot<1>(a))); 706 } 707 708 // Returns tanh(x) for any x. 
709 template <typename tRawType, int tIntegerBits> 710 FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) { 711 typedef FixedPoint<tRawType, tIntegerBits> InputF; 712 typedef FixedPoint<tRawType, 0> ResultF; 713 tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero()); 714 tRawType mask_if_zero = MaskIfZero(a); 715 InputF n = SelectUsingMask(mask_if_negative, a, -a); 716 ResultF t = neg_tanh_on_negative_values(n); 717 return SelectUsingMask(mask_if_zero, ResultF::Zero(), 718 SelectUsingMask(mask_if_negative, -t, t)); 719 } 720 721 // Implementation of logistic function. 722 723 // Returns 1 / (1 + x) for x in (0, 1). 724 template <typename tRawType> 725 FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1( 726 FixedPoint<tRawType, 0> a) { 727 typedef FixedPoint<tRawType, 0> F0; 728 typedef FixedPoint<tRawType, 2> F2; 729 F0 half_denominator = RoundingHalfSum(a, F0::One()); 730 // Newton-Raphson division 731 // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division 732 // Refer to that page for the logic behind the 48/17 and 32/17 constants. 733 const F2 constant_48_over_17 = 734 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0); 735 const F2 constant_neg_32_over_17 = 736 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0); 737 F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17; 738 for (int i = 0; i < 3; i++) { 739 F2 half_denominator_times_x = half_denominator * x; 740 F2 one_minus_half_denominator_times_x = 741 F2::One() - half_denominator_times_x; 742 x = x + Rescale<2>(x * one_minus_half_denominator_times_x); 743 } 744 return Rescale<0>(ExactMulByPot<-1>(x)); 745 } 746 747 // Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0. 
748 template <typename tRawType, int tIntegerBits> 749 FixedPoint<tRawType, 0> logistic_on_positive_values( 750 FixedPoint<tRawType, tIntegerBits> a) { 751 return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a)); 752 } 753 754 // Returns logistic(x) = 1 / (1 + exp(-x)) for any x. 755 template <typename tRawType, int tIntegerBits> 756 FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) { 757 typedef FixedPoint<tRawType, tIntegerBits> InputF; 758 typedef FixedPoint<tRawType, 0> ResultF; 759 tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero()); 760 tRawType mask_if_zero = MaskIfZero(a); 761 InputF abs_input = SelectUsingMask(mask_if_positive, a, -a); 762 ResultF result_if_positive = logistic_on_positive_values(abs_input); 763 ResultF result_if_negative = ResultF::One() - result_if_positive; 764 const ResultF one_half = 765 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5); 766 return SelectUsingMask(mask_if_zero, one_half, 767 SelectUsingMask(mask_if_positive, result_if_positive, 768 result_if_negative)); 769 } 770 771 } // end namespace gemmlowp 772 773 #ifdef GEMMLOWP_NEON 774 #include "./fixedpoint_neon.h" 775 #elif defined(GEMMLOWP_SSE4) 776 #include "./fixedpoint_sse.h" 777 #endif 778 779 #endif // GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 780