Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_CWISE_OPS_H_
     17 #define TENSORFLOW_KERNELS_CWISE_OPS_H_
     18 
     19 #include <cmath>
     20 #include <functional>
     21 #include <type_traits>
     22 
     23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     24 
     25 #include "tensorflow/core/framework/numeric_types.h"
     26 #include "tensorflow/core/framework/tensor_types.h"
     27 #include "tensorflow/core/kernels/bounds_check.h"
     28 
     29 namespace Eigen {
     30 namespace numext {
     31 #if GOOGLE_CUDA
     32 template <>
     33 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE std::complex<float> exp(
     34     const std::complex<float>& x) {
     35   auto com = ::expf(x.real());
     36   auto res_real = com * ::cosf(x.imag());
     37   auto res_imag = com * ::sinf(x.imag());
     38   return std::complex<float>(res_real, res_imag);
     39 }
     40 template <>
     41 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE std::complex<double> exp(
     42     const std::complex<double>& x) {
     43   auto com = ::exp(x.real());
     44   auto res_real = com * ::cos(x.imag());
     45   auto res_imag = com * ::sin(x.imag());
     46   return std::complex<double>(res_real, res_imag);
     47 }
     48 #endif
     49 }  // namespace numext
     50 
     51 namespace internal {
     52 
     53 template <typename T>
     54 struct scalar_asinh_op {
     55   EIGEN_EMPTY_STRUCT_CTOR(scalar_asinh_op)
     56   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const {
     57 #if EIGEN_HAS_CXX11_MATH
     58     return numext::asinh(a);
     59 #else
     60     return std::asinh(a);
     61 #endif  // EIGEN_HAS_CXX11_MATH
     62   }
     63 };
     64 template <typename T>
     65 struct functor_traits<scalar_asinh_op<T>> {
     66   enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
     67 };
     68 
     69 template <typename T>
     70 struct scalar_acosh_op {
     71   EIGEN_EMPTY_STRUCT_CTOR(scalar_acosh_op)
     72   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const {
     73 #if EIGEN_HAS_CXX11_MATH
     74     return numext::acosh(a);
     75 #else
     76     return std::acosh(a);
     77 #endif  // EIGEN_HAS_CXX11_MATH
     78   }
     79 };
     80 template <typename T>
     81 struct functor_traits<scalar_acosh_op<T>> {
     82   enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
     83 };
     84 
     85 template <typename T>
     86 struct scalar_atanh_op {
     87   EIGEN_EMPTY_STRUCT_CTOR(scalar_atanh_op)
     88   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const {
     89 #if EIGEN_HAS_CXX11_MATH
     90     return numext::atanh(a);
     91 #else
     92     return std::atanh(a);
     93 #endif  // EIGEN_HAS_CXX11_MATH
     94   }
     95 };
     96 template <typename T>
     97 struct functor_traits<scalar_atanh_op<T>> {
     98   enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
     99 };
    100 
    101 // TODO(rmlarsen): This is a workaround for upstream change
    102 // https://bitbucket.org/eigen/eigen/commits/f339468d04d0f87caeb6cab9aef568627e9f6ea9
    103 // that renamed scalar_binary_pow_op to scalar_pow_op and deleted the unary
    104 // version of the latter. Remove once we upgrade to Eigen 3.3.
    105 template <typename Scalar, typename Exponent>
    106 struct scalar_binary_pow_op_google {
    107   EIGEN_EMPTY_STRUCT_CTOR(scalar_binary_pow_op_google)
    108   EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a,
    109                                              const Exponent& b) const {
    110     return numext::pow(a, b);
    111   }
    112 };
    113 
    114 template <typename Scalar, typename Exponent>
    115 struct functor_traits<scalar_binary_pow_op_google<Scalar, Exponent>> {
    116   enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
    117 };
    118 
    119 template <typename Scalar, typename Exponent>
    120 struct safe_scalar_binary_pow_op {
    121   static_assert(std::is_integral<Scalar>::value, "Integer type expected");
    122   static_assert(std::is_integral<Exponent>::value &&
    123                     std::is_signed<Exponent>::value,
    124                 "Signed integer type expected");
    125 
    126   bool* const error;
    127 
    128   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_scalar_binary_pow_op(bool* error)
    129       : error(error) {}
    130 
    131   EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a,
    132                                              const Exponent& b) const {
    133     const Exponent safe_b = tensorflow::internal::SubtleMustCopy(b);
    134     if (TF_PREDICT_TRUE(safe_b >= 0)) {
    135       return numext::pow(a, safe_b);
    136     } else {
    137       *error = true;
    138       return 0;
    139     }
    140   }
    141 };
    142 
    143 template <typename Scalar, typename Exponent>
    144 struct functor_traits<safe_scalar_binary_pow_op<Scalar, Exponent>> {
    145   enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
    146 };
    147 
    148 template <typename T, typename DivOrMod>
    149 struct safe_div_or_mod_op {
    150   static_assert(std::is_integral<T>::value, "Integer type expected");
    151 
    152   bool* const error;
    153 
    154   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_div_or_mod_op(bool* error)
    155       : error(error) {}
    156 
    157   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a,
    158                                                            const T& b) const {
    159     const T safe_b = tensorflow::internal::SubtleMustCopy(b);
    160     if (TF_PREDICT_TRUE(safe_b != 0)) {
    161       return DivOrMod()(a, safe_b);
    162     } else {
    163       *error = true;
    164       return 0;
    165     }
    166   }
    167 };
    168 
    169 template <typename T, typename DivOrMod>
    170 struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> {
    171   enum {
    172     Cost = functor_traits<DivOrMod>::Cost + NumTraits<T>::AddCost,
    173     PacketAccess = false,
    174   };
    175 };
    176 
    177 // scalar_left and scalar_right are template helpers to partially
    178 // apply a binary function.
    179 //
    180 // Suppose Binary is a binary functor f(x, y), scalar_left<> is a
    181 // unary functor g_x(y) = f(x, y), where x is provided via the
    182 // constructor. Similarly, scalar_right<> is a unary functor g_y(x) =
    183 // f(x, y).
    184 
    185 template <typename Tout, typename Tin, typename Binary>
    186 struct scalar_left : private Binary {
    187   typedef Tout result_type;
    188   const Tin* left;
    189 
    190   EIGEN_DEVICE_FUNC inline scalar_left(const scalar_left& other) = default;
    191 
    192   template <typename... Args>
    193   EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c, Args... args)
    194       : Binary(args...), left(c) {}
    195 
    196   EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const {
    197     return Binary::operator()(*left, right);
    198   }
    199 
    200   template <typename Packet>
    201   EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const {
    202     const Packet left_packet = Eigen::internal::pset1<Packet>(*left);
    203     return Binary::packetOp(left_packet, right_packet);
    204   }
    205 };
    206 
    207 template <typename Tout, typename Tin, typename Binary>
    208 struct functor_traits<scalar_left<Tout, Tin, Binary>> {
    209   enum {
    210     Cost = functor_traits<Binary>::Cost,
    211     PacketAccess = functor_traits<Binary>::PacketAccess,
    212   };
    213 };
    214 
    215 template <typename Tout, typename Tin, typename Binary>
    216 struct scalar_right : private Binary {
    217   typedef Tout result_type;
    218   const Tin* right;
    219 
    220   EIGEN_DEVICE_FUNC inline scalar_right(const scalar_right& other) = default;
    221 
    222   template <typename... Args>
    223   EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c, Args... args)
    224       : Binary(args...), right(c) {}
    225 
    226   EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const {
    227     return Binary::operator()(left, *right);
    228   }
    229 
    230   template <typename Packet>
    231   EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const {
    232     const Packet right_packet = Eigen::internal::pset1<Packet>(*right);
    233     return Binary::packetOp(left_packet, right_packet);
    234   }
    235 };
    236 
    237 template <typename Tout, typename Tin, typename Binary>
    238 struct functor_traits<scalar_right<Tout, Tin, Binary>> {
    239   enum {
    240     Cost = functor_traits<Binary>::Cost,
    241     PacketAccess = functor_traits<Binary>::PacketAccess,
    242   };
    243 };
    244 
    245 // similar to std::equal_to, but with the DEVICE_FUNC qualifier
    246 template <class T>
    247 struct equal_to : std::binary_function<T, T, bool> {
    248   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
    249                                                         const T& y) const {
    250     return x == y;
    251   }
    252 };
    253 
    254 // similar to std::not_equal_to, but with the DEVICE_FUNC qualifier
    255 template <class T>
    256 struct not_equal_to : std::binary_function<T, T, bool> {
    257   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
    258                                                         const T& y) const {
    259     return x != y;
    260   }
    261 };
    262 
    263 // similar to std::greater, but with the DEVICE_FUNC qualifier
    264 template <class T>
    265 struct greater : std::binary_function<T, T, bool> {
    266   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
    267                                                         const T& y) const {
    268     return x > y;
    269   }
    270 };
    271 
    272 // similar to std::less, but with the DEVICE_FUNC qualifier
    273 template <class T>
    274 struct less : std::binary_function<T, T, bool> {
    275   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
    276                                                         const T& y) const {
    277     return x < y;
    278   }
    279 };
    280 
    281 // similar to std::greater_equal, but with the DEVICE_FUNC qualifier
    282 template <class T>
    283 struct greater_equal : std::binary_function<T, T, bool> {
    284   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
    285                                                         const T& y) const {
    286     return x >= y;
    287   }
    288 };
    289 
    290 // similar to std::less_equal, but with the DEVICE_FUNC qualifier
    291 template <class T>
    292 struct less_equal : std::binary_function<T, T, bool> {
    293   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
    294                                                         const T& y) const {
    295     return x <= y;
    296   }
    297 };
    298 
    299 // Functor that enables composition of multiple Eigen functors.
    300 template <typename Scalar, typename UnaryFunctor, typename BinaryFunctor>
    301 struct scalar_compose_op {
    302   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
    303   operator()(const Scalar& a, const Scalar& b) const {
    304     return UnaryFunctor()(BinaryFunctor()(a, b));
    305   }
    306   template <typename Packet>
    307   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
    308   packetOp(const Packet& a, const Packet& b) const {
    309     return UnaryFunctor().packetOp(BinaryFunctor().packetOp(a, b));
    310   }
    311 };
    312 
    313 template <typename Scalar, typename UnaryFunctor, typename BinaryFunctor>
    314 struct functor_traits<scalar_compose_op<Scalar, UnaryFunctor, BinaryFunctor>> {
    315   enum {
    316     Cost = functor_traits<UnaryFunctor>::Cost +
    317            functor_traits<BinaryFunctor>::Cost,
    318     PacketAccess = functor_traits<UnaryFunctor>::PacketAccess &&
    319                    functor_traits<BinaryFunctor>::PacketAccess
    320   };
    321 };
    322 
    323 // TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
    324 template <typename T, typename Enable = void>
    325 struct google_floor_div {
    326   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
    327                                                            const T& y) const {
    328     if ((x < T(0)) != (y < T(0))) {
    329       T abs_x = std::abs(x);
    330       T abs_y = std::abs(y);
    331       return -(abs_x + abs_y - 1) / abs_y;
    332     } else {
    333       return x / y;
    334     }
    335   }
    336 };
    337 
    338 template <typename T>
    339 struct google_floor_div<
    340     T, typename std::enable_if<std::is_unsigned<T>::value>::type> {
    341   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
    342                                                            const T& y) const {
    343     return x / y;
    344   }
    345 };
    346 
    347 template <typename Scalar>
    348 struct functor_traits<google_floor_div<Scalar>> {
    349   enum {
    350     Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value +
    351            2 * NumTraits<Scalar>::AddCost,
    352     PacketAccess = false
    353   };
    354 };
    355 
    356 // TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
    357 template <typename T, typename Enable = void>
    358 struct google_floor_div_real {
    359   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
    360                                                            const T& y) const {
    361     return Eigen::numext::floor(x / y);
    362   }
    363 };
    364 
    365 template <typename Scalar>
    366 struct functor_traits<google_floor_div_real<Scalar>> {
    367   enum {
    368     Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value +
    369            2 * NumTraits<Scalar>::AddCost,
    370     PacketAccess = false
    371   };
    372 };
    373 
// TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
    375 template <typename T>
    376 struct google_floor_fmod {
    377   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
    378                                                            const T& y) const {
    379     // EIGEN_STATIC_ASSERT(NUMERIC_TYPE_MUST_BE_REAL);
    380     T trunc_mod = std::fmod(x, y);
    381     return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y);
    382   }
    383 };
    384 
    385 template <typename Scalar>
    386 struct functor_traits<google_floor_fmod<Scalar>> {
    387   enum {
    388     Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value +
    389            2 * NumTraits<Scalar>::AddCost,
    390     PacketAccess = false
    391   };
    392 };
    393 
    394 // TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
    395 template <typename T>
    396 struct google_floor_mod {
    397   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
    398                                                            const T& y) const {
    399     // EIGEN_STATIC_ASSERT(!NUMERIC_TYPE_MUST_BE_REAL);
    400     T trunc_mod = x % y;
    401     return (x < T(0)) == (y < T(0)) ? trunc_mod : (trunc_mod + y) % y;
    402   }
    403 };
    404 
    405 template <typename Scalar>
    406 struct functor_traits<google_floor_mod<Scalar>> {
    407   enum {
    408     Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value +
    409            2 * NumTraits<Scalar>::AddCost,
    410     PacketAccess = false
    411   };
    412 };
    413 
// Helper macros for bracketing code that intentionally compares floating-point
// values with ==. With a GNU-compatible compiler in C++11 or later,
// DISABLE_FLOAT_EQUALITY_WARNING saves the diagnostic state and ignores
// -Wfloat-equal, and ENABLE_FLOAT_EQUALITY_WARNING restores the saved state.
// Elsewhere both expand to nothing.
#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
  _Pragma("GCC diagnostic push")       \
      _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif
    423 
    424 template <typename Scalar>
    425 struct scalar_round_op_google {
    426   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
    427   operator()(const Scalar& x) const {
    428     EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex),
    429                         NUMERIC_TYPE_MUST_BE_REAL)
    430 
    431     Scalar round_val = Eigen::numext::floor(x);
    432     const Scalar fraction = x - round_val;
    433     if (fraction > Scalar(.5)) {
    434       round_val += Scalar(1.0);
    435     } else if (fraction == Scalar(.5)) {
    436       const Scalar nearest_even_int =
    437           round_val - Scalar(2) * Eigen::numext::floor(Scalar(.5) * x);
    438       bool is_odd = (nearest_even_int == Scalar(1));
    439       if (is_odd) {
    440         round_val += Scalar(1);
    441       }
    442     }
    443     return round_val;
    444   }
    445 };
    446 
    447 template <typename Scalar>
    448 struct functor_traits<scalar_round_op_google<Scalar>> {
    449   enum { Cost = 4 * NumTraits<Scalar>::AddCost, PacketAccess = false };
    450 };
    451 
// Remove the float-equality warning helper macros so that they are not
// visible past this point.
#undef ENABLE_FLOAT_EQUALITY_WARNING
#undef DISABLE_FLOAT_EQUALITY_WARNING
    454 
    455 template <typename Scalar>
    456 struct bitwise_xor_op {
    457   EIGEN_EMPTY_STRUCT_CTOR(bitwise_xor_op)
    458   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
    459   operator()(const Scalar& x, const Scalar& y) const {
    460     return x ^ y;
    461   }
    462   typedef typename Eigen::internal::packet_traits<Scalar>::type Packet;
    463   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a,
    464                                                         const Packet& b) const {
    465     return Eigen::internal::pxor(a, b);
    466   }
    467 };
    468 
    469 template <typename Scalar>
    470 struct functor_traits<bitwise_xor_op<Scalar>> {
    471   enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true };
    472 };
    473 
    474 }  // end namespace internal
    475 }  // end namespace Eigen
    476 
    477 namespace tensorflow {
    478 namespace functor {
    479 
    480 ////////////////////////////////////////////////////////////////////////////////
    481 // Helpers
    482 ////////////////////////////////////////////////////////////////////////////////
    483 
    484 // Base template for functors whose input scalar type is T and
    485 // output scalar type is R.
    486 template <typename T, typename F, typename R = T>
    487 struct base {
    488   // func defines operator() and its vectorized version packetOp().
    489   typedef F func;
    490 
    491   // If true, the functor's corresponding binary op will instantiate
    492   // specialized kernels to perform an optimized broadcast
    493   // operation. Each functor for which this is enabled increases the
    494   // code size, so by default this is disabled for binary functors and
    495   // is enabled on a per-op basis as needed.
    496   static const bool use_bcast_optimization = false;
    497 
    498   // operator() has the signature:
    499   //  out_type operator()(in_type in0, in_type in1 ...)
    500   typedef R out_type;
    501   typedef T in_type;
    502 
    503   // TensorFlow provides tensor-ized version of "func". Roughly
    504   // speaking, the tensorflow operation has the signature:
    505   //   tout_type op(tin_type in0)
    506   //   tout_type op(tin_type in0, tin_type in1)
    507   //   tout_type op(tin_type in0, in_type scalar)
    508   typedef typename TTypes<out_type>::Flat tout_type;
    509   typedef typename TTypes<in_type>::ConstFlat tin_type;
    510   typedef typename TTypes<in_type>::ConstScalar tscalar_type;
    511 
    512   // Whether the functor can error out.  Currently applies only to integer
    513   // div and mod.
    514   static const bool has_errors = false;
    515 };
    516 
    517 // For now, we only apply certain speed optimization for
    518 // float/double's broadcast binary op.
    519 template <typename T>
    520 struct use_bcast_optimization {
    521   static const bool value = false;
    522 };
    523 
    524 template <>
    525 struct use_bcast_optimization<float> {
    526   static const bool value = true;
    527 };
    528 
    529 template <>
    530 struct use_bcast_optimization<double> {
    531   static const bool value = true;
    532 };
    533 
    534 ////////////////////////////////////////////////////////////////////////////////
    535 // Unary functors
    536 ////////////////////////////////////////////////////////////////////////////////
    537 
    538 // abs(x) = |x|
    539 // neg(x) = - x
    540 // inverse(x) = 1 / x
    541 // square(x) = x^2
    542 // sqrt(x) = x^(1/2)
    543 // rsqrt(x) = x^(-1/2)
    544 // exp(x) = e^x
    545 // expm1(x) = e^x - 1
    546 // log(x) = natural logarithm of x
    547 // log1p(x) = natural logarithm of 1 + x
    548 // tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
    549 // sigmoid = 1 / (1 + exp(-x))  // a.k.a, logistic
    550 //
    551 // NOTE: We may eventually implement common functions used in NN
// here. E.g., rectifier, softplus, derivatives of tanh, sigmoid, etc.
    553 // For reference, see speech/lstm/eigen_functors.h.
    554 
// Each struct below binds an element type T and an Eigen scalar functor into
// `base`, which supplies the func/in_type/out_type/tensor typedefs.

// abs(x) = |x|. The output type is taken from scalar_abs_op<T>::result_type
// rather than assumed to be T.
template <typename T>
struct abs : base<T, Eigen::internal::scalar_abs_op<T>,
                  typename Eigen::internal::scalar_abs_op<T>::result_type> {};

// neg(x) = -x
template <typename T>
struct neg : base<T, Eigen::internal::scalar_opposite_op<T>> {};

// inverse(x) = 1 / x
template <typename T>
struct inverse : base<T, Eigen::internal::scalar_inverse_op<T>> {};

// square(x) = x^2
template <typename T>
struct square : base<T, Eigen::internal::scalar_square_op<T>> {};

// sqrt(x) = x^(1/2)
template <typename T>
struct sqrt : base<T, Eigen::internal::scalar_sqrt_op<T>> {};

// rsqrt(x) = x^(-1/2)
template <typename T>
struct rsqrt : base<T, Eigen::internal::scalar_rsqrt_op<T>> {};

// exp(x) = e^x
template <typename T>
struct exp : base<T, Eigen::internal::scalar_exp_op<T>> {};

// expm1(x) = e^x - 1
template <typename T>
struct expm1 : base<T, Eigen::internal::scalar_expm1_op<T>> {};

// log(x) = natural logarithm of x
template <typename T>
struct log : base<T, Eigen::internal::scalar_log_op<T>> {};

// log1p(x) = natural logarithm of 1 + x
template <typename T>
struct log1p : base<T, Eigen::internal::scalar_log1p_op<T>> {};
    585 
// Unary functors whose input and output element types are both T. Each one
// forwards to the Eigen scalar op named in its base-class argument.

template <typename T>
struct sign : base<T, Eigen::internal::scalar_sign_op<T>> {};

// Hyperbolic functions.
template <typename T>
struct sinh : base<T, Eigen::internal::scalar_sinh_op<T>> {};

template <typename T>
struct cosh : base<T, Eigen::internal::scalar_cosh_op<T>> {};

template <typename T>
struct tanh : base<T, Eigen::internal::scalar_tanh_op<T>> {};

// Inverse hyperbolic functions, backed by the scalar_a*h_op functors defined
// earlier in this header.
template <typename T>
struct asinh : base<T, Eigen::internal::scalar_asinh_op<T>> {};

template <typename T>
struct acosh : base<T, Eigen::internal::scalar_acosh_op<T>> {};

template <typename T>
struct atanh : base<T, Eigen::internal::scalar_atanh_op<T>> {};

// Gamma- and error-function family.
template <typename T>
struct lgamma : base<T, Eigen::internal::scalar_lgamma_op<T>> {};

template <typename T>
struct digamma : base<T, Eigen::internal::scalar_digamma_op<T>> {};

template <typename T>
struct erf : base<T, Eigen::internal::scalar_erf_op<T>> {};

template <typename T>
struct erfc : base<T, Eigen::internal::scalar_erfc_op<T>> {};

// sigmoid = 1 / (1 + exp(-x)), a.k.a. logistic.
template <typename T>
struct sigmoid : base<T, Eigen::internal::scalar_sigmoid_op<T>> {};
    621 
// Trigonometric functions and their inverses; input and output types are
// both T.
template <typename T>
struct sin : base<T, Eigen::internal::scalar_sin_op<T>> {};

template <typename T>
struct cos : base<T, Eigen::internal::scalar_cos_op<T>> {};

template <typename T>
struct tan : base<T, Eigen::internal::scalar_tan_op<T>> {};

template <typename T>
struct asin : base<T, Eigen::internal::scalar_asin_op<T>> {};

template <typename T>
struct acos : base<T, Eigen::internal::scalar_acos_op<T>> {};

template <typename T>
struct atan : base<T, Eigen::internal::scalar_atan_op<T>> {};

// Boolean negation. Not a template: the element type is fixed to bool.
struct logical_not : base<bool, Eigen::internal::scalar_boolean_not_op<bool>> {
};
    642 
    643 // Flip all bits. Named invert to be consistent with numpy.
    644 template <typename T>
    645 struct invert_op {
    646   EIGEN_EMPTY_STRUCT_CTOR(invert_op)
    647   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const {
    648     return ~a;
    649   }
    650 };
    651 
// TensorFlow functor for bitwise inversion; element functor is invert_op<T>
// and the output type equals the input type.
template <typename T>
struct invert : base<T, invert_op<T>> {};
    654 
// NOTE: std::isinf, std::isnan, std::isfinite are plain functions.
    656 // Therefore we need to wrap them in functors to be used with Eigen's
    657 // type system.
// Classification predicates: input element type T, output element type bool.
template <typename T>
struct isinf : base<T, Eigen::internal::scalar_isinf_op<T>, bool> {};

template <typename T>
struct isnan : base<T, Eigen::internal::scalar_isnan_op<T>, bool> {};

template <typename T>
struct isfinite : base<T, Eigen::internal::scalar_isfinite_op<T>, bool> {};

// Rounding functors; output type equals input type.
template <typename T>
struct floor : base<T, Eigen::internal::scalar_floor_op<T>> {};

// Uses this header's scalar_round_op_google rather than an Eigen-provided
// rounding functor.
template <typename T>
struct round : base<T, Eigen::internal::scalar_round_op_google<T>> {};

template <typename T>
struct ceil : base<T, Eigen::internal::scalar_ceil_op<T>> {};
    675 
    676 /** this should go in Eigen
    677  * \brief Template functor to compute the round to int value of a scalar
    678  */
    679 template <typename Scalar>
    680 struct scalar_rint_op {
    681   EIGEN_EMPTY_STRUCT_CTOR(scalar_rint_op)
    682   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
    683   operator()(const Scalar& a) const {
    684 #if defined(__CUDACC__)
    685     return ::rint(a);
    686 #elif defined(__ANDROID__)
    687     return rint(a);
    688 #else
    689     return std::rint(a);
    690 #endif
    691   }
    692 };
    693 
// TensorFlow functor for rint; element functor is scalar_rint_op<T> and the
// output type equals the input type.
template <typename T>
struct rint : base<T, scalar_rint_op<T>> {};
    696 
    697 ////////////////////////////////////////////////////////////////////////////////
    698 // Binary functors
    699 ////////////////////////////////////////////////////////////////////////////////
    700 
    701 // Binary functors:
    702 //
    703 // add(x, y) = x + y
    704 // sub(x, y) = x - y
    705 // mul(x, y) = x * y
    706 // div(x, y) = x / y
    707 // mod(x, y) = x % y         (int32 and int64 only)
    708 // fmod(x, y) = fmod(x, y)   (float and double only)
    709 // pow(x, y) = x ^ y
    710 // maximum(x, y) = x > y ? x : y
    711 // minimum(x, y) = x < y ? x : y
    712 // squared_difference(x, y) = (x - y) * (x - y)
    713 
// add, sub and mul each declare their own use_bcast_optimization = true,
// which hides the `false` default inherited from `base` and so opts these
// three ops into the specialized broadcast kernels.

// add(x, y) = x + y
template <typename T>
struct add : base<T, Eigen::internal::scalar_sum_op<T>> {
  static const bool use_bcast_optimization = true;
};

// sub(x, y) = x - y
template <typename T>
struct sub : base<T, Eigen::internal::scalar_difference_op<T>> {
  static const bool use_bcast_optimization = true;
};

// mul(x, y) = x * y
template <typename T>
struct mul : base<T, Eigen::internal::scalar_product_op<T>> {
  static const bool use_bcast_optimization = true;
};
    728 
// div(x, y) = x / y, with no check on the divisor.
template <typename T>
struct div : base<T, Eigen::internal::scalar_quotient_op<T>> {};

// Integer division routed through safe_div_or_mod_op, which flags a zero
// divisor through its error pointer instead of dividing; has_errors
// advertises that.
template <typename T>
struct safe_div : base<T, Eigen::internal::safe_div_or_mod_op<
                              T, Eigen::internal::scalar_quotient_op<T>>> {
  static const bool has_errors = true;
};

// fmod(x, y) via Eigen's scalar_fmod_op.
template <typename T>
struct fmod : base<T, Eigen::internal::scalar_fmod_op<T>> {};

// mod(x, y) = x % y via Eigen's scalar_mod2_op, with no check on the divisor.
template <typename T>
struct mod : base<T, Eigen::internal::scalar_mod2_op<T>> {};

// Integer modulo routed through safe_div_or_mod_op (zero divisor is flagged
// as an error rather than evaluated).
template <typename T>
struct safe_mod : base<T, Eigen::internal::safe_div_or_mod_op<
                              T, Eigen::internal::scalar_mod2_op<T>>> {
  static const bool has_errors = true;
};
    748 };
    749 
    750 template <typename T>
    751 struct floor_fmod : base<T, Eigen::internal::google_floor_fmod<T>> {};
    752 
    753 template <typename T>
    754 struct safe_floor_mod : base<T, Eigen::internal::safe_div_or_mod_op<
    755                                     T, Eigen::internal::google_floor_mod<T>>> {
    756   static const bool has_errors = true;
    757 };
    758 
    759 template <typename T>
    760 struct floor_div : base<T, Eigen::internal::google_floor_div<T>> {};
    761 
    762 template <typename T>
    763 struct safe_floor_div : base<T, Eigen::internal::safe_div_or_mod_op<
    764                                     T, Eigen::internal::google_floor_div<T>>> {
    765   static const bool has_errors = true;
    766 };
    767 
    768 template <typename T>
    769 struct floor_div_real : base<T, Eigen::internal::google_floor_div_real<T>> {};
    770 
// pow(x, y) = x ^ y, with base and exponent both of type T.
template <typename T>
struct pow : base<T, Eigen::internal::scalar_binary_pow_op_google<T, T>> {};

// Integer power that reports an error for a negative exponent (see
// safe_scalar_binary_pow_op); has_errors advertises that.
template <typename T>
struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> {
  static const bool has_errors = true;
};

// maximum(x, y) = x > y ? x : y
template <typename T>
struct maximum : base<T, Eigen::internal::scalar_max_op<T>> {};

// minimum(x, y) = x < y ? x : y
template <typename T>
struct minimum : base<T, Eigen::internal::scalar_min_op<T>> {};

// Special functions, each forwarded to the Eigen scalar op of the same name.
template <typename T>
struct igamma : base<T, Eigen::internal::scalar_igamma_op<T>> {};

template <typename T>
struct igammac : base<T, Eigen::internal::scalar_igammac_op<T>> {};

template <typename T>
struct zeta : base<T, Eigen::internal::scalar_zeta_op<T>> {};

template <typename T>
struct polygamma : base<T, Eigen::internal::scalar_polygamma_op<T>> {};
    796 
    797 template <typename Scalar>
    798 struct scalar_atan2_op {
    799   EIGEN_EMPTY_STRUCT_CTOR(scalar_atan2_op)
    800   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
    801   operator()(const Scalar& y, const Scalar& x) const {
    802 #if GOOGLE_CUDA
    803     return ::atan2(y, x);
    804 #else
    805     return std::atan2(y, x);
    806 #endif
    807   }
    808 };
    809 
    810 template <typename T>
    811 struct atan2 : base<T, scalar_atan2_op<T>> {};
    812 
    813 template <typename T>
    814 struct squared_difference
    815     : base<T, Eigen::internal::scalar_compose_op<
    816                   T, Eigen::internal::scalar_square_op<T>,
    817                   Eigen::internal::scalar_difference_op<T>>> {};
    818 
    819 template <typename T>
    820 struct less : base<T, Eigen::internal::less<T>, bool> {};
    821 
    822 template <typename T>
    823 struct less_equal : base<T, Eigen::internal::less_equal<T>, bool> {};
    824 
    825 template <typename T>
    826 struct greater : base<T, Eigen::internal::greater<T>, bool> {};
    827 
    828 template <typename T>
    829 struct greater_equal : base<T, Eigen::internal::greater_equal<T>, bool> {};
    830 
    831 template <typename T>
    832 struct equal_to : base<T, Eigen::internal::equal_to<T>, bool> {};
    833 
    834 template <typename T>
    835 struct not_equal_to : base<T, Eigen::internal::not_equal_to<T>, bool> {};
    836 
    837 struct logical_and : base<bool, Eigen::internal::scalar_boolean_and_op> {};
    838 
    839 struct logical_or : base<bool, Eigen::internal::scalar_boolean_or_op> {};
    840 
    841 template <typename T>
    842 struct bitwise_and_op {
    843   EIGEN_EMPTY_STRUCT_CTOR(bitwise_and_op)
    844   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
    845                                                            const T& y) const {
    846     return x & y;
    847   }
    848 };
    849 
    850 template <typename T>
    851 struct bitwise_or_op {
    852   EIGEN_EMPTY_STRUCT_CTOR(bitwise_or_op)
    853   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
    854                                                            const T& y) const {
    855     return x | y;
    856   }
    857 };
    858 
    859 template <typename T>
    860 struct bitwise_and : base<T, bitwise_and_op<T>> {};
    861 
    862 template <typename T>
    863 struct bitwise_or : base<T, bitwise_or_op<T>> {};
    864 
    865 template <typename T>
    866 struct bitwise_xor : base<T, Eigen::internal::bitwise_xor_op<T>> {};
    867 
    868 template <typename T>
    869 struct left_shift_op {
    870   EIGEN_EMPTY_STRUCT_CTOR(left_shift_op)
    871   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
    872                                                            const T& y) const {
    873     // Avoids UB: don't shift by larger than the bitwidth of T, and
    874     // performs left shifts as unsigned shifts.
    875     T y_clamped = y;
    876     if (y_clamped < 0) {
    877       y_clamped = 0;
    878     } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) {
    879       y_clamped = sizeof(T) * CHAR_BIT - 1;
    880     }
    881     using U = typename std::make_unsigned<T>::type;
    882     return static_cast<T>(static_cast<U>(x) << static_cast<U>(y_clamped));
    883   }
    884 };
    885 
    886 template <typename T>
    887 struct right_shift_op {
    888   EIGEN_EMPTY_STRUCT_CTOR(right_shift_op)
    889   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
    890                                                            const T& y) const {
    891     // Avoids UB: don't shift by larger than the bitwidth of T.
    892     T y_clamped = y;
    893     if (y_clamped < 0) {
    894       y_clamped = 0;
    895     } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) {
    896       y_clamped = sizeof(T) * CHAR_BIT - 1;
    897     }
    898     // Technically right shifts of signed integers are not necessarily
    899     // arithmetic shifts according to the C++ standard. However in practice most
    900     // implementations are arithmetic shifts. If this proves to be a problem in
    901     // practice, we may need to use an alternative implementation.
    902     return x >> y_clamped;
    903   }
    904 };
    905 
    906 template <typename T>
    907 struct left_shift : base<T, left_shift_op<T>> {};
    908 
    909 template <typename T>
    910 struct right_shift : base<T, right_shift_op<T>> {};
    911 
    912 template <typename T>
    913 struct make_complex_func {
    914   typedef std::complex<T> result_type;
    915   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(T real,
    916                                                                T imag) const {
    917     return std::complex<T>(real, imag);
    918   }
    919 };
    920 
    921 template <typename T>
    922 struct make_complex : base<T, make_complex_func<T>, std::complex<T>> {};
    923 
    924 template <typename T>
    925 struct get_real
    926     : base<T, Eigen::internal::scalar_real_op<T>, typename T::value_type> {};
    927 
    928 template <typename T>
    929 struct get_imag
    930     : base<T, Eigen::internal::scalar_imag_op<T>, typename T::value_type> {};
    931 
    932 template <typename T>
    933 struct get_angle
    934     : base<T, Eigen::internal::scalar_arg_op<T>, typename T::value_type> {};
    935 
    936 template <typename T>
    937 struct conj : base<T, Eigen::internal::scalar_conjugate_op<T>> {};
    938 
////////////////////////////////////////////////////////////////////////////////
// Functors that take 1 or 2 tensors, apply the base functor to each
// coefficient of the input tensors, and put the results in the output
// tensor.
////////////////////////////////////////////////////////////////////////////////
    944 template <typename Device, typename Functor>
    945 struct UnaryFunctor {
    946   // Computes on device "d": out[i] = Functor(in[i])
    947   void operator()(const Device& d, typename Functor::tout_type out,
    948                   typename Functor::tin_type in);
    949 };
    950 
    951 template <typename Device, typename Functor, int NDIMS,
    952           bool has_errors = Functor::has_errors>
    953 struct BinaryFunctor {
    954   // Computes on device "d": out[i] = Functor(in0[i], in1[i])
    955   void operator()(const Device& d, typename Functor::tout_type out,
    956                   typename Functor::tin_type in0,
    957                   typename Functor::tin_type in1, bool* error);
    958 
    959   // Computes on device "d": out[i] = Functor(scalar[0], in[i])
    960   void Left(const Device& d, typename Functor::tout_type out,
    961             typename Functor::tscalar_type scalar,
    962             typename Functor::tin_type in, bool* error);
    963 
    964   // Computes on device "d": out[i] = Functor(in[i], scalar[0])
    965   void Right(const Device& d, typename Functor::tout_type out,
    966              typename Functor::tin_type in,
    967              typename Functor::tscalar_type scalar, bool* error);
    968 
    969   // Computes on device "d":
    970   //   out = Functor(in0.broadcast(bcast0), in1.broadcast(bcast1))
    971   //
    972   // TODO(zhifengc): makes BCast a template member function on NDIMS
    973   // instead making BinaryFunctor templates on NDIMS.
    974   void BCast(const Device& d,
    975              typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
    976              typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
    977              typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
    978              typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
    979              typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
    980              bool* error);
    981 };
    982 
    983 template <typename Device, typename T>
    984 struct ApproximateEqual {
    985   void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
    986                   typename TTypes<T>::ConstFlat y, T tolerance,
    987                   typename TTypes<bool>::Flat z);
    988 };
    989 
    990 template <int NDIMS>
    991 bool AllOne(const typename Eigen::array<Eigen::DenseIndex, NDIMS>& a) {
    992   for (size_t i = 0; i < a.size(); ++i) {
    993     if (a[i] != 1) return false;
    994   }
    995   return true;
    996 }
    997 
    998 template <typename Device, typename T>
    999 struct SelectFunctor {
   1000   void operator()(const Device& d, typename TTypes<T>::Flat out,
   1001                   typename TTypes<bool>::ConstFlat cond_flat,
   1002                   typename TTypes<T>::ConstFlat then_flat,
   1003                   typename TTypes<T>::ConstFlat else_flat);
   1004 };
   1005 
   1006 template <typename Device, typename T>
   1007 struct SelectScalarFunctor {
   1008   void operator()(const Device& d, typename TTypes<T>::Flat out,
   1009                   typename TTypes<bool>::ConstScalar cond,
   1010                   typename TTypes<T>::ConstFlat then_flat,
   1011                   typename TTypes<T>::ConstFlat else_flat);
   1012 };
   1013 
   1014 template <typename Device, typename T>
   1015 struct BatchSelectFunctor {
   1016   void operator()(const Device& d,
   1017                   typename TTypes<T>::Matrix output_flat_outer_dims,
   1018                   TTypes<bool>::ConstVec cond_vec,
   1019                   typename TTypes<T>::ConstMatrix then_flat_outer_dims,
   1020                   typename TTypes<T>::ConstMatrix else_flat_outer_dims);
   1021 };
   1022 
   1023 }  // end namespace functor
   1024 }  // end namespace tensorflow
   1025 
   1026 #endif  // TENSORFLOW_KERNELS_CWISE_OPS_H_
   1027