1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_KERNELS_CWISE_OPS_H_ 17 #define TENSORFLOW_KERNELS_CWISE_OPS_H_ 18 19 #include <cmath> 20 #include <functional> 21 #include <type_traits> 22 23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 24 25 #include "tensorflow/core/framework/numeric_types.h" 26 #include "tensorflow/core/framework/tensor_types.h" 27 #include "tensorflow/core/kernels/bounds_check.h" 28 29 namespace Eigen { 30 namespace numext { 31 #if GOOGLE_CUDA 32 template <> 33 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE std::complex<float> exp( 34 const std::complex<float>& x) { 35 auto com = ::expf(x.real()); 36 auto res_real = com * ::cosf(x.imag()); 37 auto res_imag = com * ::sinf(x.imag()); 38 return std::complex<float>(res_real, res_imag); 39 } 40 template <> 41 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE std::complex<double> exp( 42 const std::complex<double>& x) { 43 auto com = ::exp(x.real()); 44 auto res_real = com * ::cos(x.imag()); 45 auto res_imag = com * ::sin(x.imag()); 46 return std::complex<double>(res_real, res_imag); 47 } 48 #endif 49 } // namespace numext 50 51 namespace internal { 52 53 template <typename T> 54 struct scalar_asinh_op { 55 EIGEN_EMPTY_STRUCT_CTOR(scalar_asinh_op) 56 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const { 57 #if EIGEN_HAS_CXX11_MATH 58 return numext::asinh(a); 59 #else 60 return std::asinh(a); 61 #endif // EIGEN_HAS_CXX11_MATH 62 } 63 }; 64 template <typename T> 65 struct functor_traits<scalar_asinh_op<T>> { 66 enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false }; 67 }; 68 69 template <typename T> 70 struct scalar_acosh_op { 71 EIGEN_EMPTY_STRUCT_CTOR(scalar_acosh_op) 72 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const { 73 #if EIGEN_HAS_CXX11_MATH 74 return numext::acosh(a); 75 #else 76 return std::acosh(a); 77 #endif // EIGEN_HAS_CXX11_MATH 78 } 79 }; 80 template <typename T> 81 struct functor_traits<scalar_acosh_op<T>> { 82 enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false }; 83 }; 84 85 template <typename T> 86 struct scalar_atanh_op { 87 EIGEN_EMPTY_STRUCT_CTOR(scalar_atanh_op) 88 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const { 89 #if EIGEN_HAS_CXX11_MATH 90 return numext::atanh(a); 91 #else 92 return std::atanh(a); 93 #endif // EIGEN_HAS_CXX11_MATH 94 } 95 }; 96 template <typename T> 97 struct functor_traits<scalar_atanh_op<T>> { 98 enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false }; 99 }; 100 101 // TODO(rmlarsen): This is a workaround for upstream change 102 // https://bitbucket.org/eigen/eigen/commits/f339468d04d0f87caeb6cab9aef568627e9f6ea9 103 // that renamed scalar_binary_pow_op to scalar_pow_op and deleted the unary 104 // version of the latter. Remove once we upgrade to Eigen 3.3. 105 template <typename Scalar, typename Exponent> 106 struct scalar_binary_pow_op_google { 107 EIGEN_EMPTY_STRUCT_CTOR(scalar_binary_pow_op_google) 108 EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a, 109 const Exponent& b) const { 110 return numext::pow(a, b); 111 } 112 }; 113 114 template <typename Scalar, typename Exponent> 115 struct functor_traits<scalar_binary_pow_op_google<Scalar, Exponent>> { 116 enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false }; 117 }; 118 119 template <typename Scalar, typename Exponent> 120 struct safe_scalar_binary_pow_op { 121 static_assert(std::is_integral<Scalar>::value, "Integer type expected"); 122 static_assert(std::is_integral<Exponent>::value && 123 std::is_signed<Exponent>::value, 124 "Signed integer type expected"); 125 126 bool* const error; 127 128 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_scalar_binary_pow_op(bool* error) 129 : error(error) {} 130 131 EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a, 132 const Exponent& b) const { 133 const Exponent safe_b = tensorflow::internal::SubtleMustCopy(b); 134 if (TF_PREDICT_TRUE(safe_b >= 0)) { 135 return numext::pow(a, safe_b); 136 } else { 137 *error = true; 138 return 0; 139 } 140 } 141 }; 142 143 template <typename Scalar, typename Exponent> 144 struct functor_traits<safe_scalar_binary_pow_op<Scalar, Exponent>> { 145 enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false }; 146 }; 147 148 template <typename T, typename DivOrMod> 149 struct safe_div_or_mod_op { 150 static_assert(std::is_integral<T>::value, "Integer type expected"); 151 152 bool* const error; 153 154 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_div_or_mod_op(bool* error) 155 : error(error) {} 156 157 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a, 158 const T& b) const { 159 const T safe_b = tensorflow::internal::SubtleMustCopy(b); 160 if (TF_PREDICT_TRUE(safe_b != 0)) { 161 return DivOrMod()(a, safe_b); 162 } else { 163 *error = true; 164 return 0; 165 } 166 } 167 }; 168 169 template <typename T, typename DivOrMod> 170 struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> { 171 enum { 172 Cost = functor_traits<DivOrMod>::Cost + NumTraits<T>::AddCost, 173 PacketAccess = false, 174 }; 175 }; 176 177 // scalar_left and scalar_right are template helpers to partially 178 // apply a binary function. 179 // 180 // Suppose Binary is a binary functor f(x, y), scalar_left<> is a 181 // unary functor g_x(y) = f(x, y), where x is provided via the 182 // constructor. Similarly, scalar_right<> is a unary functor g_y(x) = 183 // f(x, y). 184 185 template <typename Tout, typename Tin, typename Binary> 186 struct scalar_left : private Binary { 187 typedef Tout result_type; 188 const Tin* left; 189 190 EIGEN_DEVICE_FUNC inline scalar_left(const scalar_left& other) = default; 191 192 template <typename... Args> 193 EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c, Args... args) 194 : Binary(args...), left(c) {} 195 196 EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const { 197 return Binary::operator()(*left, right); 198 } 199 200 template <typename Packet> 201 EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const { 202 const Packet left_packet = Eigen::internal::pset1<Packet>(*left); 203 return Binary::packetOp(left_packet, right_packet); 204 } 205 }; 206 207 template <typename Tout, typename Tin, typename Binary> 208 struct functor_traits<scalar_left<Tout, Tin, Binary>> { 209 enum { 210 Cost = functor_traits<Binary>::Cost, 211 PacketAccess = functor_traits<Binary>::PacketAccess, 212 }; 213 }; 214 215 template <typename Tout, typename Tin, typename Binary> 216 struct scalar_right : private Binary { 217 typedef Tout result_type; 218 const Tin* right; 219 220 EIGEN_DEVICE_FUNC inline scalar_right(const scalar_right& other) = default; 221 222 template <typename... Args> 223 EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c, Args... args) 224 : Binary(args...), right(c) {} 225 226 EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const { 227 return Binary::operator()(left, *right); 228 } 229 230 template <typename Packet> 231 EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const { 232 const Packet right_packet = Eigen::internal::pset1<Packet>(*right); 233 return Binary::packetOp(left_packet, right_packet); 234 } 235 }; 236 237 template <typename Tout, typename Tin, typename Binary> 238 struct functor_traits<scalar_right<Tout, Tin, Binary>> { 239 enum { 240 Cost = functor_traits<Binary>::Cost, 241 PacketAccess = functor_traits<Binary>::PacketAccess, 242 }; 243 }; 244 245 // similar to std::equal_to, but with the DEVICE_FUNC qualifier 246 template <class T> 247 struct equal_to : std::binary_function<T, T, bool> { 248 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 249 const T& y) const { 250 return x == y; 251 } 252 }; 253 254 // similar to std::not_equal_to, but with the DEVICE_FUNC qualifier 255 template <class T> 256 struct not_equal_to : std::binary_function<T, T, bool> { 257 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 258 const T& y) const { 259 return x != y; 260 } 261 }; 262 263 // similar to std::greater, but with the DEVICE_FUNC qualifier 264 template <class T> 265 struct greater : std::binary_function<T, T, bool> { 266 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 267 const T& y) const { 268 return x > y; 269 } 270 }; 271 272 // similar to std::less, but with the DEVICE_FUNC qualifier 273 template <class T> 274 struct less : std::binary_function<T, T, bool> { 275 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 276 const T& y) const { 277 return x < y; 278 } 279 }; 280 281 // similar to std::greater_equal, but with the DEVICE_FUNC qualifier 282 template <class T> 283 struct greater_equal : std::binary_function<T, T, bool> { 284 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 285 const T& y) const { 286 return x >= y; 287 } 288 }; 289 290 // similar to std::less_equal, but with the DEVICE_FUNC qualifier 291 template <class T> 292 struct less_equal : std::binary_function<T, T, bool> { 293 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, 294 const T& y) const { 295 return x <= y; 296 } 297 }; 298 299 // Functor that enables composition of multiple Eigen functors. 300 template <typename Scalar, typename UnaryFunctor, typename BinaryFunctor> 301 struct scalar_compose_op { 302 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar 303 operator()(const Scalar& a, const Scalar& b) const { 304 return UnaryFunctor()(BinaryFunctor()(a, b)); 305 } 306 template <typename Packet> 307 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 308 packetOp(const Packet& a, const Packet& b) const { 309 return UnaryFunctor().packetOp(BinaryFunctor().packetOp(a, b)); 310 } 311 }; 312 313 template <typename Scalar, typename UnaryFunctor, typename BinaryFunctor> 314 struct functor_traits<scalar_compose_op<Scalar, UnaryFunctor, BinaryFunctor>> { 315 enum { 316 Cost = functor_traits<UnaryFunctor>::Cost + 317 functor_traits<BinaryFunctor>::Cost, 318 PacketAccess = functor_traits<UnaryFunctor>::PacketAccess && 319 functor_traits<BinaryFunctor>::PacketAccess 320 }; 321 }; 322 323 // TODO(b/32239616): This kernel should be moved into Eigen and vectorized. 324 template <typename T, typename Enable = void> 325 struct google_floor_div { 326 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 327 const T& y) const { 328 if ((x < T(0)) != (y < T(0))) { 329 T abs_x = std::abs(x); 330 T abs_y = std::abs(y); 331 return -(abs_x + abs_y - 1) / abs_y; 332 } else { 333 return x / y; 334 } 335 } 336 }; 337 338 template <typename T> 339 struct google_floor_div< 340 T, typename std::enable_if<std::is_unsigned<T>::value>::type> { 341 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 342 const T& y) const { 343 return x / y; 344 } 345 }; 346 347 template <typename Scalar> 348 struct functor_traits<google_floor_div<Scalar>> { 349 enum { 350 Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value + 351 2 * NumTraits<Scalar>::AddCost, 352 PacketAccess = false 353 }; 354 }; 355 356 // TODO(b/32239616): This kernel should be moved into Eigen and vectorized. 357 template <typename T, typename Enable = void> 358 struct google_floor_div_real { 359 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 360 const T& y) const { 361 return Eigen::numext::floor(x / y); 362 } 363 }; 364 365 template <typename Scalar> 366 struct functor_traits<google_floor_div_real<Scalar>> { 367 enum { 368 Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value + 369 2 * NumTraits<Scalar>::AddCost, 370 PacketAccess = false 371 }; 372 }; 373 374 // TODO(b//32239616): This kernel should be moved into Eigen and vectorized. 375 template <typename T> 376 struct google_floor_fmod { 377 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 378 const T& y) const { 379 // EIGEN_STATIC_ASSERT(NUMERIC_TYPE_MUST_BE_REAL); 380 T trunc_mod = std::fmod(x, y); 381 return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y); 382 } 383 }; 384 385 template <typename Scalar> 386 struct functor_traits<google_floor_fmod<Scalar>> { 387 enum { 388 Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value + 389 2 * NumTraits<Scalar>::AddCost, 390 PacketAccess = false 391 }; 392 }; 393 394 // TODO(b/32239616): This kernel should be moved into Eigen and vectorized. 395 template <typename T> 396 struct google_floor_mod { 397 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 398 const T& y) const { 399 // EIGEN_STATIC_ASSERT(!NUMERIC_TYPE_MUST_BE_REAL); 400 T trunc_mod = x % y; 401 return (x < T(0)) == (y < T(0)) ? trunc_mod : (trunc_mod + y) % y; 402 } 403 }; 404 405 template <typename Scalar> 406 struct functor_traits<google_floor_mod<Scalar>> { 407 enum { 408 Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value + 409 2 * NumTraits<Scalar>::AddCost, 410 PacketAccess = false 411 }; 412 }; 413 414 #if EIGEN_COMP_GNUC && __cplusplus > 199711L 415 #define DISABLE_FLOAT_EQUALITY_WARNING \ 416 _Pragma("GCC diagnostic push") \ 417 _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"") 418 #define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop") 419 #else 420 #define DISABLE_FLOAT_EQUALITY_WARNING 421 #define ENABLE_FLOAT_EQUALITY_WARNING 422 #endif 423 424 template <typename Scalar> 425 struct scalar_round_op_google { 426 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar 427 operator()(const Scalar& x) const { 428 EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), 429 NUMERIC_TYPE_MUST_BE_REAL) 430 431 Scalar round_val = Eigen::numext::floor(x); 432 const Scalar fraction = x - round_val; 433 if (fraction > Scalar(.5)) { 434 round_val += Scalar(1.0); 435 } else if (fraction == Scalar(.5)) { 436 const Scalar nearest_even_int = 437 round_val - Scalar(2) * Eigen::numext::floor(Scalar(.5) * x); 438 bool is_odd = (nearest_even_int == Scalar(1)); 439 if (is_odd) { 440 round_val += Scalar(1); 441 } 442 } 443 return round_val; 444 } 445 }; 446 447 template <typename Scalar> 448 struct functor_traits<scalar_round_op_google<Scalar>> { 449 enum { Cost = 4 * NumTraits<Scalar>::AddCost, PacketAccess = false }; 450 }; 451 452 #undef ENABLE_FLOAT_EQUALITY_WARNING 453 #undef DISABLE_FLOAT_EQUALITY_WARNING 454 455 template <typename Scalar> 456 struct bitwise_xor_op { 457 EIGEN_EMPTY_STRUCT_CTOR(bitwise_xor_op) 458 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar 459 operator()(const Scalar& x, const Scalar& y) const { 460 return x ^ y; 461 } 462 typedef typename Eigen::internal::packet_traits<Scalar>::type Packet; 463 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, 464 const Packet& b) const { 465 return Eigen::internal::pxor(a, b); 466 } 467 }; 468 469 template <typename Scalar> 470 struct functor_traits<bitwise_xor_op<Scalar>> { 471 enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true }; 472 }; 473 474 } // end namespace internal 475 } // end namespace Eigen 476 477 namespace tensorflow { 478 namespace functor { 479 480 //////////////////////////////////////////////////////////////////////////////// 481 // Helpers 482 //////////////////////////////////////////////////////////////////////////////// 483 484 // Base template for functors whose input scalar type is T and 485 // output scalar type is R. 486 template <typename T, typename F, typename R = T> 487 struct base { 488 // func defines operator() and its vectorized version packetOp(). 489 typedef F func; 490 491 // If true, the functor's corresponding binary op will instantiate 492 // specialized kernels to perform an optimized broadcast 493 // operation. Each functor for which this is enabled increases the 494 // code size, so by default this is disabled for binary functors and 495 // is enabled on a per-op basis as needed. 496 static const bool use_bcast_optimization = false; 497 498 // operator() has the signature: 499 // out_type operator()(in_type in0, in_type in1 ...) 500 typedef R out_type; 501 typedef T in_type; 502 503 // TensorFlow provides tensor-ized version of "func". Roughly 504 // speaking, the tensorflow operation has the signature: 505 // tout_type op(tin_type in0) 506 // tout_type op(tin_type in0, tin_type in1) 507 // tout_type op(tin_type in0, in_type scalar) 508 typedef typename TTypes<out_type>::Flat tout_type; 509 typedef typename TTypes<in_type>::ConstFlat tin_type; 510 typedef typename TTypes<in_type>::ConstScalar tscalar_type; 511 512 // Whether the functor can error out. Currently applies only to integer 513 // div and mod. 514 static const bool has_errors = false; 515 }; 516 517 // For now, we only apply certain speed optimization for 518 // float/double's broadcast binary op. 519 template <typename T> 520 struct use_bcast_optimization { 521 static const bool value = false; 522 }; 523 524 template <> 525 struct use_bcast_optimization<float> { 526 static const bool value = true; 527 }; 528 529 template <> 530 struct use_bcast_optimization<double> { 531 static const bool value = true; 532 }; 533 534 //////////////////////////////////////////////////////////////////////////////// 535 // Unary functors 536 //////////////////////////////////////////////////////////////////////////////// 537 538 // abs(x) = |x| 539 // neg(x) = - x 540 // inverse(x) = 1 / x 541 // square(x) = x^2 542 // sqrt(x) = x^(1/2) 543 // rsqrt(x) = x^(-1/2) 544 // exp(x) = e^x 545 // expm1(x) = e^x - 1 546 // log(x) = natural logarithm of x 547 // log1p(x) = natural logarithm of 1 + x 548 // tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) 549 // sigmoid = 1 / (1 + exp(-x)) // a.k.a, logistic 550 // 551 // NOTE: We may eventually implement common functions used in NN 552 // here. E.g., rectifier, softplus, derivatives of tanh, sigmod, etc. 553 // For reference, see speech/lstm/eigen_functors.h. 554 555 template <typename T> 556 struct abs : base<T, Eigen::internal::scalar_abs_op<T>, 557 typename Eigen::internal::scalar_abs_op<T>::result_type> {}; 558 559 template <typename T> 560 struct neg : base<T, Eigen::internal::scalar_opposite_op<T>> {}; 561 562 template <typename T> 563 struct inverse : base<T, Eigen::internal::scalar_inverse_op<T>> {}; 564 565 template <typename T> 566 struct square : base<T, Eigen::internal::scalar_square_op<T>> {}; 567 568 template <typename T> 569 struct sqrt : base<T, Eigen::internal::scalar_sqrt_op<T>> {}; 570 571 template <typename T> 572 struct rsqrt : base<T, Eigen::internal::scalar_rsqrt_op<T>> {}; 573 574 template <typename T> 575 struct exp : base<T, Eigen::internal::scalar_exp_op<T>> {}; 576 577 template <typename T> 578 struct expm1 : base<T, Eigen::internal::scalar_expm1_op<T>> {}; 579 580 template <typename T> 581 struct log : base<T, Eigen::internal::scalar_log_op<T>> {}; 582 583 template <typename T> 584 struct log1p : base<T, Eigen::internal::scalar_log1p_op<T>> {}; 585 586 template <typename T> 587 struct sign : base<T, Eigen::internal::scalar_sign_op<T>> {}; 588 589 template <typename T> 590 struct sinh : base<T, Eigen::internal::scalar_sinh_op<T>> {}; 591 592 template <typename T> 593 struct cosh : base<T, Eigen::internal::scalar_cosh_op<T>> {}; 594 595 template <typename T> 596 struct tanh : base<T, Eigen::internal::scalar_tanh_op<T>> {}; 597 598 template <typename T> 599 struct asinh : base<T, Eigen::internal::scalar_asinh_op<T>> {}; 600 601 template <typename T> 602 struct acosh : base<T, Eigen::internal::scalar_acosh_op<T>> {}; 603 604 template <typename T> 605 struct atanh : base<T, Eigen::internal::scalar_atanh_op<T>> {}; 606 607 template <typename T> 608 struct lgamma : base<T, Eigen::internal::scalar_lgamma_op<T>> {}; 609 610 template <typename T> 611 struct digamma : base<T, Eigen::internal::scalar_digamma_op<T>> {}; 612 613 template <typename T> 614 struct erf : base<T, Eigen::internal::scalar_erf_op<T>> {}; 615 616 template <typename T> 617 struct erfc : base<T, Eigen::internal::scalar_erfc_op<T>> {}; 618 619 template <typename T> 620 struct sigmoid : base<T, Eigen::internal::scalar_sigmoid_op<T>> {}; 621 622 template <typename T> 623 struct sin : base<T, Eigen::internal::scalar_sin_op<T>> {}; 624 625 template <typename T> 626 struct cos : base<T, Eigen::internal::scalar_cos_op<T>> {}; 627 628 template <typename T> 629 struct tan : base<T, Eigen::internal::scalar_tan_op<T>> {}; 630 631 template <typename T> 632 struct asin : base<T, Eigen::internal::scalar_asin_op<T>> {}; 633 634 template <typename T> 635 struct acos : base<T, Eigen::internal::scalar_acos_op<T>> {}; 636 637 template <typename T> 638 struct atan : base<T, Eigen::internal::scalar_atan_op<T>> {}; 639 640 struct logical_not : base<bool, Eigen::internal::scalar_boolean_not_op<bool>> { 641 }; 642 643 // Flip all bits. Named invert to be consistent with numpy. 644 template <typename T> 645 struct invert_op { 646 EIGEN_EMPTY_STRUCT_CTOR(invert_op) 647 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const { 648 return ~a; 649 } 650 }; 651 652 template <typename T> 653 struct invert : base<T, invert_op<T>> {}; 654 655 // NOTE: std::isinf, std::isnan, std::isfinite are plain function. 656 // Therefore we need to wrap them in functors to be used with Eigen's 657 // type system. 658 template <typename T> 659 struct isinf : base<T, Eigen::internal::scalar_isinf_op<T>, bool> {}; 660 661 template <typename T> 662 struct isnan : base<T, Eigen::internal::scalar_isnan_op<T>, bool> {}; 663 664 template <typename T> 665 struct isfinite : base<T, Eigen::internal::scalar_isfinite_op<T>, bool> {}; 666 667 template <typename T> 668 struct floor : base<T, Eigen::internal::scalar_floor_op<T>> {}; 669 670 template <typename T> 671 struct round : base<T, Eigen::internal::scalar_round_op_google<T>> {}; 672 673 template <typename T> 674 struct ceil : base<T, Eigen::internal::scalar_ceil_op<T>> {}; 675 676 /** this should go in Eigen 677 * \brief Template functor to compute the round to int value of a scalar 678 */ 679 template <typename Scalar> 680 struct scalar_rint_op { 681 EIGEN_EMPTY_STRUCT_CTOR(scalar_rint_op) 682 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar 683 operator()(const Scalar& a) const { 684 #if defined(__CUDACC__) 685 return ::rint(a); 686 #elif defined(__ANDROID__) 687 return rint(a); 688 #else 689 return std::rint(a); 690 #endif 691 } 692 }; 693 694 template <typename T> 695 struct rint : base<T, scalar_rint_op<T>> {}; 696 697 //////////////////////////////////////////////////////////////////////////////// 698 // Binary functors 699 //////////////////////////////////////////////////////////////////////////////// 700 701 // Binary functors: 702 // 703 // add(x, y) = x + y 704 // sub(x, y) = x - y 705 // mul(x, y) = x * y 706 // div(x, y) = x / y 707 // mod(x, y) = x % y (int32 and int64 only) 708 // fmod(x, y) = fmod(x, y) (float and double only) 709 // pow(x, y) = x ^ y 710 // maximum(x, y) = x > y ? x : y 711 // minimum(x, y) = x < y ? x : y 712 // squared_difference(x, y) = (x - y) * (x - y) 713 714 template <typename T> 715 struct add : base<T, Eigen::internal::scalar_sum_op<T>> { 716 static const bool use_bcast_optimization = true; 717 }; 718 719 template <typename T> 720 struct sub : base<T, Eigen::internal::scalar_difference_op<T>> { 721 static const bool use_bcast_optimization = true; 722 }; 723 724 template <typename T> 725 struct mul : base<T, Eigen::internal::scalar_product_op<T>> { 726 static const bool use_bcast_optimization = true; 727 }; 728 729 template <typename T> 730 struct div : base<T, Eigen::internal::scalar_quotient_op<T>> {}; 731 732 template <typename T> 733 struct safe_div : base<T, Eigen::internal::safe_div_or_mod_op< 734 T, Eigen::internal::scalar_quotient_op<T>>> { 735 static const bool has_errors = true; 736 }; 737 738 template <typename T> 739 struct fmod : base<T, Eigen::internal::scalar_fmod_op<T>> {}; 740 741 template <typename T> 742 struct mod : base<T, Eigen::internal::scalar_mod2_op<T>> {}; 743 744 template <typename T> 745 struct safe_mod : base<T, Eigen::internal::safe_div_or_mod_op< 746 T, Eigen::internal::scalar_mod2_op<T>>> { 747 static const bool has_errors = true; 748 }; 749 750 template <typename T> 751 struct floor_fmod : base<T, Eigen::internal::google_floor_fmod<T>> {}; 752 753 template <typename T> 754 struct safe_floor_mod : base<T, Eigen::internal::safe_div_or_mod_op< 755 T, Eigen::internal::google_floor_mod<T>>> { 756 static const bool has_errors = true; 757 }; 758 759 template <typename T> 760 struct floor_div : base<T, Eigen::internal::google_floor_div<T>> {}; 761 762 template <typename T> 763 struct safe_floor_div : base<T, Eigen::internal::safe_div_or_mod_op< 764 T, Eigen::internal::google_floor_div<T>>> { 765 static const bool has_errors = true; 766 }; 767 768 template <typename T> 769 struct floor_div_real : base<T, Eigen::internal::google_floor_div_real<T>> {}; 770 771 template <typename T> 772 struct pow : base<T, Eigen::internal::scalar_binary_pow_op_google<T, T>> {}; 773 774 template <typename T> 775 struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> { 776 static const bool has_errors = true; 777 }; 778 779 template <typename T> 780 struct maximum : base<T, Eigen::internal::scalar_max_op<T>> {}; 781 782 template <typename T> 783 struct minimum : base<T, Eigen::internal::scalar_min_op<T>> {}; 784 785 template <typename T> 786 struct igamma : base<T, Eigen::internal::scalar_igamma_op<T>> {}; 787 788 template <typename T> 789 struct igammac : base<T, Eigen::internal::scalar_igammac_op<T>> {}; 790 791 template <typename T> 792 struct zeta : base<T, Eigen::internal::scalar_zeta_op<T>> {}; 793 794 template <typename T> 795 struct polygamma : base<T, Eigen::internal::scalar_polygamma_op<T>> {}; 796 797 template <typename Scalar> 798 struct scalar_atan2_op { 799 EIGEN_EMPTY_STRUCT_CTOR(scalar_atan2_op) 800 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 801 operator()(const Scalar& y, const Scalar& x) const { 802 #if GOOGLE_CUDA 803 return ::atan2(y, x); 804 #else 805 return std::atan2(y, x); 806 #endif 807 } 808 }; 809 810 template <typename T> 811 struct atan2 : base<T, scalar_atan2_op<T>> {}; 812 813 template <typename T> 814 struct squared_difference 815 : base<T, Eigen::internal::scalar_compose_op< 816 T, Eigen::internal::scalar_square_op<T>, 817 Eigen::internal::scalar_difference_op<T>>> {}; 818 819 template <typename T> 820 struct less : base<T, Eigen::internal::less<T>, bool> {}; 821 822 template <typename T> 823 struct less_equal : base<T, Eigen::internal::less_equal<T>, bool> {}; 824 825 template <typename T> 826 struct greater : base<T, Eigen::internal::greater<T>, bool> {}; 827 828 template <typename T> 829 struct greater_equal : base<T, Eigen::internal::greater_equal<T>, bool> {}; 830 831 template <typename T> 832 struct equal_to : base<T, Eigen::internal::equal_to<T>, bool> {}; 833 834 template <typename T> 835 struct not_equal_to : base<T, Eigen::internal::not_equal_to<T>, bool> {}; 836 837 struct logical_and : base<bool, Eigen::internal::scalar_boolean_and_op> {}; 838 839 struct logical_or : base<bool, Eigen::internal::scalar_boolean_or_op> {}; 840 841 template <typename T> 842 struct bitwise_and_op { 843 EIGEN_EMPTY_STRUCT_CTOR(bitwise_and_op) 844 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 845 const T& y) const { 846 return x & y; 847 } 848 }; 849 850 template <typename T> 851 struct bitwise_or_op { 852 EIGEN_EMPTY_STRUCT_CTOR(bitwise_or_op) 853 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 854 const T& y) const { 855 return x | y; 856 } 857 }; 858 859 template <typename T> 860 struct bitwise_and : base<T, bitwise_and_op<T>> {}; 861 862 template <typename T> 863 struct bitwise_or : base<T, bitwise_or_op<T>> {}; 864 865 template <typename T> 866 struct bitwise_xor : base<T, Eigen::internal::bitwise_xor_op<T>> {}; 867 868 template <typename T> 869 struct left_shift_op { 870 EIGEN_EMPTY_STRUCT_CTOR(left_shift_op) 871 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 872 const T& y) const { 873 // Avoids UB: don't shift by larger than the bitwidth of T, and 874 // performs left shifts as unsigned shifts. 875 T y_clamped = y; 876 if (y_clamped < 0) { 877 y_clamped = 0; 878 } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) { 879 y_clamped = sizeof(T) * CHAR_BIT - 1; 880 } 881 using U = typename std::make_unsigned<T>::type; 882 return static_cast<T>(static_cast<U>(x) << static_cast<U>(y_clamped)); 883 } 884 }; 885 886 template <typename T> 887 struct right_shift_op { 888 EIGEN_EMPTY_STRUCT_CTOR(right_shift_op) 889 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, 890 const T& y) const { 891 // Avoids UB: don't shift by larger than the bitwidth of T. 892 T y_clamped = y; 893 if (y_clamped < 0) { 894 y_clamped = 0; 895 } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) { 896 y_clamped = sizeof(T) * CHAR_BIT - 1; 897 } 898 // Technically right shifts of signed integers are not necessarily 899 // arithmetic shifts according to the C++ standard. However in practice most 900 // implementations are arithmetic shifts. If this proves to be a problem in 901 // practice, we may need to use an alternative implementation. 902 return x >> y_clamped; 903 } 904 }; 905 906 template <typename T> 907 struct left_shift : base<T, left_shift_op<T>> {}; 908 909 template <typename T> 910 struct right_shift : base<T, right_shift_op<T>> {}; 911 912 template <typename T> 913 struct make_complex_func { 914 typedef std::complex<T> result_type; 915 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(T real, 916 T imag) const { 917 return std::complex<T>(real, imag); 918 } 919 }; 920 921 template <typename T> 922 struct make_complex : base<T, make_complex_func<T>, std::complex<T>> {}; 923 924 template <typename T> 925 struct get_real 926 : base<T, Eigen::internal::scalar_real_op<T>, typename T::value_type> {}; 927 928 template <typename T> 929 struct get_imag 930 : base<T, Eigen::internal::scalar_imag_op<T>, typename T::value_type> {}; 931 932 template <typename T> 933 struct get_angle 934 : base<T, Eigen::internal::scalar_arg_op<T>, typename T::value_type> {}; 935 936 template <typename T> 937 struct conj : base<T, Eigen::internal::scalar_conjugate_op<T>> {}; 938 939 //////////////////////////////////////////////////////////////////////////////// 940 // Functors takes 1 or 2 tensors, computes the base functor on 941 // coefficient of the input tensors and puts the results in the output 942 // tensor. 943 //////////////////////////////////////////////////////////////////////////////// 944 template <typename Device, typename Functor> 945 struct UnaryFunctor { 946 // Computes on device "d": out[i] = Functor(in[i]) 947 void operator()(const Device& d, typename Functor::tout_type out, 948 typename Functor::tin_type in); 949 }; 950 951 template <typename Device, typename Functor, int NDIMS, 952 bool has_errors = Functor::has_errors> 953 struct BinaryFunctor { 954 // Computes on device "d": out[i] = Functor(in0[i], in1[i]) 955 void operator()(const Device& d, typename Functor::tout_type out, 956 typename Functor::tin_type in0, 957 typename Functor::tin_type in1, bool* error); 958 959 // Computes on device "d": out[i] = Functor(scalar[0], in[i]) 960 void Left(const Device& d, typename Functor::tout_type out, 961 typename Functor::tscalar_type scalar, 962 typename Functor::tin_type in, bool* error); 963 964 // Computes on device "d": out[i] = Functor(in[i], scalar[0]) 965 void Right(const Device& d, typename Functor::tout_type out, 966 typename Functor::tin_type in, 967 typename Functor::tscalar_type scalar, bool* error); 968 969 // Computes on device "d": 970 // out = Functor(in0.broadcast(bcast0), in1.broadcast(bcast1)) 971 // 972 // TODO(zhifengc): makes BCast a template member function on NDIMS 973 // instead making BinaryFunctor templates on NDIMS. 974 void BCast(const Device& d, 975 typename TTypes<typename Functor::out_type, NDIMS>::Tensor out, 976 typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0, 977 typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0, 978 typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1, 979 typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1, 980 bool* error); 981 }; 982 983 template <typename Device, typename T> 984 struct ApproximateEqual { 985 void operator()(const Device& d, typename TTypes<T>::ConstFlat x, 986 typename TTypes<T>::ConstFlat y, T tolerance, 987 typename TTypes<bool>::Flat z); 988 }; 989 990 template <int NDIMS> 991 bool AllOne(const typename Eigen::array<Eigen::DenseIndex, NDIMS>& a) { 992 for (size_t i = 0; i < a.size(); ++i) { 993 if (a[i] != 1) return false; 994 } 995 return true; 996 } 997 998 template <typename Device, typename T> 999 struct SelectFunctor { 1000 void operator()(const Device& d, typename TTypes<T>::Flat out, 1001 typename TTypes<bool>::ConstFlat cond_flat, 1002 typename TTypes<T>::ConstFlat then_flat, 1003 typename TTypes<T>::ConstFlat else_flat); 1004 }; 1005 1006 template <typename Device, typename T> 1007 struct SelectScalarFunctor { 1008 void operator()(const Device& d, typename TTypes<T>::Flat out, 1009 typename TTypes<bool>::ConstScalar cond, 1010 typename TTypes<T>::ConstFlat then_flat, 1011 typename TTypes<T>::ConstFlat else_flat); 1012 }; 1013 1014 template <typename Device, typename T> 1015 struct BatchSelectFunctor { 1016 void operator()(const Device& d, 1017 typename TTypes<T>::Matrix output_flat_outer_dims, 1018 TTypes<bool>::ConstVec cond_vec, 1019 typename TTypes<T>::ConstMatrix then_flat_outer_dims, 1020 typename TTypes<T>::ConstMatrix else_flat_outer_dims); 1021 }; 1022 1023 } // end namespace functor 1024 } // end namespace tensorflow 1025 1026 #endif // TENSORFLOW_KERNELS_CWISE_OPS_H_ 1027