/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_
#define TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/kernels/cwise_ops_sycl_common.h"
#endif

#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_gradients.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif

class BinaryOpShared : public OpKernel {
 public:
  explicit BinaryOpShared(OpKernelConstruction* ctx, DataType out, DataType in);

 protected:
  struct BinaryOpState {
    // Sets up bcast with the shapes of in0 and in1, ensures that the bcast
    // is valid, and if so, sets out, either by allocating a new buffer via
    // ctx->output(...) or by creating an alias for an owned input buffer for
    // in-place computation.
    // The caller must check ctx->status() upon return for a non-OK status.
    // If ctx->status().ok() is true, then out is guaranteed to be allocated.
    BinaryOpState(OpKernelContext* ctx);

    const Tensor& in0;
    const Tensor& in1;

    BCast bcast;
    Tensor* out = nullptr;
    int64 out_num_elements;

    int64 in0_num_elements;
    int64 in1_num_elements;

    int ndims;
  };

  void SetUnimplementedError(OpKernelContext* ctx);
  void SetComputeError(OpKernelContext* ctx);
};

// Coefficient-wise binary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::add.
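//
// A Functor (see cwise_ops.h for the concrete definitions) is expected to
// provide, as used throughout this header:
//   in_type / out_type       - input/output scalar types;
//   tin_type / tout_type /   - Eigen tensor-map types consumed and produced
//     tscalar_type             by the BinaryFunctor specializations below;
//   func                     - the Eigen functor that is actually applied;
//   has_errors               - whether func reports failures via a bool*;
//   use_bcast_optimization   - whether the 2-D broadcast fast paths apply.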
template <typename Device, typename Functor>
class BinaryOp : public BinaryOpShared {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit BinaryOp(OpKernelConstruction* ctx)
      : BinaryOpShared(ctx, DataTypeToEnum<Tout>::v(),
                       DataTypeToEnum<Tin>::v()) {}

  void Compute(OpKernelContext* ctx) override {
    // 'state': Shared helper not dependent on T, to reduce code size.
    BinaryOpState state(ctx);
    if (!ctx->status().ok()) return;
    Tensor* out = state.out;
    BCast* bcast = &state.bcast;
    auto& in0 = state.in0;
    auto& in1 = state.in1;
    if (state.out_num_elements == 0) {
      return;
    }
    const int ndims = state.ndims;
    const Device& eigen_device = ctx->eigen_device<Device>();
    bool error = false;
    bool* const error_ptr = Functor::has_errors ? &error : nullptr;
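    // Dispatch on the number of (collapsed) broadcast dimensions computed by
    // BCast. Illustrative example: for inputs of shapes [2, 3] and [3], BCast
    // reduces this to the ndims == 2 case with the second operand reshaped to
    // [1, 3] and broadcast along the first dimension; when the two shapes are
    // identical, everything collapses into the ndims <= 1 fast path.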
    if (ndims <= 1) {
      auto out_flat = out->flat<Tout>();
      if (state.in1_num_elements == 1) {
        // tensor op scalar
        functor::BinaryFunctor<Device, Functor, 1>().Right(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template scalar<Tin>(), error_ptr);
      } else if (state.in0_num_elements == 1) {
        // scalar op tensor
        functor::BinaryFunctor<Device, Functor, 1>().Left(
            eigen_device, out_flat, in0.template scalar<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      } else {
        functor::BinaryFunctor<Device, Functor, 1>()(
            eigen_device, out_flat, in0.template flat<Tin>(),
            in1.template flat<Tin>(), error_ptr);
      }
    } else if (ndims == 2) {
      functor::BinaryFunctor<Device, Functor, 2>().BCast(
          eigen_device, out->shaped<Tout, 2>(bcast->result_shape()),
          in0.template shaped<Tin, 2>(bcast->x_reshape()),
          BCast::ToIndexArray<2>(bcast->x_bcast()),
          in1.template shaped<Tin, 2>(bcast->y_reshape()),
          BCast::ToIndexArray<2>(bcast->y_bcast()), error_ptr);
    } else if (ndims == 3) {
      functor::BinaryFunctor<Device, Functor, 3>().BCast(
          eigen_device, out->shaped<Tout, 3>(bcast->result_shape()),
          in0.template shaped<Tin, 3>(bcast->x_reshape()),
          BCast::ToIndexArray<3>(bcast->x_bcast()),
          in1.template shaped<Tin, 3>(bcast->y_reshape()),
          BCast::ToIndexArray<3>(bcast->y_bcast()), error_ptr);
    } else if (ndims == 4) {
      functor::BinaryFunctor<Device, Functor, 4>().BCast(
          eigen_device, out->shaped<Tout, 4>(bcast->result_shape()),
          in0.template shaped<Tin, 4>(bcast->x_reshape()),
          BCast::ToIndexArray<4>(bcast->x_bcast()),
          in1.template shaped<Tin, 4>(bcast->y_reshape()),
          BCast::ToIndexArray<4>(bcast->y_bcast()), error_ptr);
    } else if (ndims == 5) {
      functor::BinaryFunctor<Device, Functor, 5>().BCast(
          eigen_device, out->shaped<Tout, 5>(bcast->result_shape()),
          in0.template shaped<Tin, 5>(bcast->x_reshape()),
          BCast::ToIndexArray<5>(bcast->x_bcast()),
          in1.template shaped<Tin, 5>(bcast->y_reshape()),
          BCast::ToIndexArray<5>(bcast->y_bcast()), error_ptr);
    } else {
      SetUnimplementedError(ctx);
    }
    if (Functor::has_errors && error) {
      SetComputeError(ctx);
    }
  }
};

template <typename Device, typename T>
class ApproximateEqualOp : public OpKernel {
 public:
  explicit ApproximateEqualOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float tolerance;
    OP_REQUIRES_OK(context, context->GetAttr("tolerance", &tolerance));
    tolerance_ = T(tolerance);
  }
  void Compute(OpKernelContext* context) override {
    const Tensor& x_input = context->input(0);
    const Tensor& y_input = context->input(1);
    OP_REQUIRES(
        context, x_input.shape() == y_input.shape(),
        errors::InvalidArgument("x and y must be of the same shape. ",
                                "x shape: ", x_input.shape().DebugString(),
                                ". y shape: ", y_input.shape().DebugString()));
    Tensor* z_output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, x_input.shape(), &z_output));
    const Device& d = context->eigen_device<Device>();
    typename TTypes<T>::ConstFlat x(x_input.flat<T>());
    typename TTypes<T>::ConstFlat y(y_input.flat<T>());
    typename TTypes<bool>::Flat z(z_output->flat<bool>());
    functor::ApproximateEqual<Device, T>()(d, x, y, tolerance_, z);
  }

 private:
  T tolerance_;
};

// Basic coefficient-wise binary operations that are known not to require
// any broadcasting. This is the case, for example, for the gradients of
// unary operations.
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops_gradients.h (included above). E.g.,
//   functor::tanh_grad.
template <typename Device, typename Functor>
class SimpleBinaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit SimpleBinaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);
    auto in0_flat = in0.flat<Tin>();
    auto in1_flat = in1.flat<Tin>();
    const Device& eigen_device = ctx->eigen_device<Device>();

    Tensor* out = nullptr;
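    // When the input and output scalar types match, try to compute in place:
    // forward_input_or_allocate_output attempts to reuse the buffer of input
    // 0 or 1 as output 0, and only allocates a fresh buffer if neither input
    // can be forwarded.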
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0, 1}, 0, in0.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out));
    }
    auto out_flat = out->flat<Tout>();
    functor::SimpleBinaryFunctor<Device, Functor>()(eigen_device, out_flat,
                                                    in0_flat, in1_flat);
  }
};

// Coefficient-wise unary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_ops.h. E.g., functor::sqrt.
template <typename Device, typename Functor>
class UnaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.
  // Tin may be different from Tout. E.g., abs: complex64 -> float

  explicit UnaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    auto in = DataTypeToEnum<Tin>::v();
    auto out = DataTypeToEnum<Tout>::v();
    OP_REQUIRES_OK(ctx, ctx->MatchSignature({in}, {out}));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    Tensor* out = nullptr;
    if (std::is_same<Tin, Tout>::value) {
      OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                              {0}, 0, inp.shape(), &out));
    } else {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    }
    functor::UnaryFunctor<Device, Functor>()(
        ctx->eigen_device<Device>(), out->flat<Tout>(), inp.flat<Tin>());
  }
};

template <typename Device, VariantUnaryOp OpEnum>
class UnaryVariantOp : public OpKernel {
 public:
  explicit UnaryVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(inp.shape()),
        errors::InvalidArgument("Non-scalar variants are not supported."));
    const Variant& v = inp.scalar<Variant>()();
    Variant v_out;
    OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(ctx, OpEnum, v, &v_out));
    Tensor out(cpu_allocator(), DT_VARIANT, TensorShape());
    out.scalar<Variant>()() = std::move(v_out);
    ctx->set_output(0, std::move(out));
  }
};
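
// Kernels built from the class templates above are registered in the sharded
// cwise_op_*.cc files using the REGISTER* macros defined at the end of this
// header. The lines below are an illustrative sketch only; the actual op
// names, functors, and type lists live in those .cc files:
//
//   REGISTER3(BinaryOp, CPU, "Add", functor::add, float, Eigen::half, double);
//   REGISTER(UnaryOp, CPU, "Sqrt", functor::sqrt, float);
//   REGISTER_VARIANT(UnaryVariantOp, CPU, "ZerosLike",
//                    ZEROS_LIKE_VARIANT_UNARY_OP);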

namespace functor {

template <typename D, typename Out, typename Rhs>
void Assign(const D& d, Out out, Rhs rhs) {
  out.device(d) = rhs;
}

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, NDIMS>
// for functors with no error checking.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, false> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

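  // Left()/Right() handle the scalar-operand cases: Left() computes
  // Binary(scalar, in(i)) and Right() computes Binary(in(i), scalar), using
  // the scalar_left/scalar_right adaptors to turn the binary functor into a
  // unary one applied via unaryExpr().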
  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func;
    if (AllOne<NDIMS>(bcast0) && AllOne<NDIMS>(bcast1)) {
      Assign(dev, out, in0.binaryExpr(in1, func));
    } else if (AllOne<NDIMS>(bcast0)) {
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, in0.binaryExpr(rhs, func));
    } else if (AllOne<NDIMS>(bcast1)) {
      auto lhs = in0.broadcast(bcast0);
      Assign(dev, out, lhs.binaryExpr(in1, func));
    } else {
      auto lhs = in0.broadcast(bcast0);
      auto rhs = in1.broadcast(bcast1);
      Assign(dev, out, lhs.binaryExpr(rhs, func));
    }
  }
};

// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor, 2>
// for functors with no error checking.
template <typename Functor>
struct BinaryFunctor<CPUDevice, Functor, 2, false> {
  enum { NDIMS = 2 };

  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

#if !defined(EIGEN_HAS_INDEX_LIST)
  inline Eigen::DSizes<int, 2> NByOne(int n) {
    return Eigen::DSizes<int, 2>(n, 1);
  }
  inline Eigen::DSizes<int, 2> OneByM(int m) {
    return Eigen::DSizes<int, 2>(1, m);
  }
#else
  inline Eigen::IndexList<int, Eigen::type2index<1>> NByOne(int n) {
    Eigen::IndexList<int, Eigen::type2index<1>> ret;
    ret.set(0, n);
    return ret;
  }
  inline Eigen::IndexList<Eigen::type2index<1>, int> OneByM(int m) {
    Eigen::IndexList<Eigen::type2index<1>, int> ret;
    ret.set(1, m);
    return ret;
  }
#endif

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typedef typename Functor::in_type T;
    typename Functor::func func;
    if (Functor::use_bcast_optimization && use_bcast_optimization<T>::value) {
      // Optimize for speed by using Eigen::type2index and avoiding
      // .broadcast() when we know it is a no-op.
      //
      // Here, we need to handle 6 cases depending on how many "1"s appear in
      // in0's and in1's shapes (4 dimension sizes in total). The two shapes
      // cannot contain more than two 1s among those four sizes, because such
      // cases are simplified to the NDIMS == 1 case.
      //
      // Because this optimization increases the binary size for each
      // Functor (+, -, *, /, <, <=, etc.), type, and ndim combination, we
      // only apply it for selected ops/types/ndims.
      //
      // Because NDIMS, Functor::use_bcast_optimization and
      // use_bcast_optimization<T> are compile-time constants, gcc does a
      // decent job of avoiding generating code when the conditions are not
      // met.
      const int a = in0.dimension(0);  // in0 is shape [a, b]
      const int b = in0.dimension(1);
      const int c = in1.dimension(0);  // in1 is shape [c, d]
      const int d = in1.dimension(1);
      if ((a == 1) && (d == 1)) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if ((b == 1) && (c == 1)) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (a == 1) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (b == 1) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (c == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (d == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      const bool bcast0_all_one = AllOne<NDIMS>(bcast0);
      const bool bcast1_all_one = AllOne<NDIMS>(bcast1);
      if (bcast0_all_one && !bcast1_all_one) {
        auto lhs = in0;  // No need to do broadcast for in0
        auto rhs = in1.broadcast(bcast1);
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      if (!bcast0_all_one && bcast1_all_one) {
        auto lhs = in0.broadcast(bcast0);
        auto rhs = in1;  // No need to do broadcast for in1
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
    }

    // Fallback path. Always works and probably slower.
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Version of BinaryFunctor with error handling.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS, true> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func(error)));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data(), error)));
  }

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error) {
    typename Functor::func func(error);
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Partial specialization of UnaryFunctor<Device=CPUDevice, Functor>.
template <typename Functor>
struct UnaryFunctor<CPUDevice, Functor> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in) {
    Assign(d, out, in.unaryExpr(typename Functor::func()));
  }
};

// Partial specialization of ApproximateEqual<Device=CPUDevice, T>.
template <typename T>
struct ApproximateEqual<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat x,
                  typename TTypes<T>::ConstFlat y, T tolerance,
                  typename TTypes<bool>::Flat z) {
    auto diff = x - y;
    z.device(d) = diff.abs() <= tolerance;
  }
};

}  // end namespace functor

#define REGISTER(OP, D, N, F, T)                                             \
  REGISTER_KERNEL_BUILDER(Name(N).Device(DEVICE_##D).TypeConstraint<T>("T"), \
                          OP<D##Device, F<T>>);

#define REGISTER_VARIANT(OP, D, N, ENUM)                       \
  REGISTER_KERNEL_BUILDER(                                     \
      Name(N).Device(DEVICE_##D).TypeConstraint<Variant>("T"), \
      OP<D##Device, ENUM>);

// Macros to register kernels for multiple types (T0, T1, etc.) on
// device type "D" (CPU or GPU) for operation "N" (e.g., sqrt) using
// the functor "F" (e.g., functor::sqrt).

#if defined(__ANDROID_TYPES_SLIM__)
// Note that __ANDROID_TYPES_SLIM__ is also checked in the cwise_ops*.cc files.
// Normally Android TensorFlow is built with a reduced number of types (float).
// Override on the command-line using "--copt=-D__ANDROID_TYPES_FULL__"
// to generate a library with full type support with a consequent increase in
// code size.
#define REGISTER2(OP, D, N, F, T0, T1) REGISTER(OP, D, N, F, T0)
#define REGISTER3(OP, D, N, F, T0, T1, T2) REGISTER(OP, D, N, F, T0)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) REGISTER(OP, D, N, F, T0)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) REGISTER(OP, D, N, F, T0)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER(OP, D, N, F, T0)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER(OP, D, N, F, T0)
#else  // !defined(__ANDROID_TYPES_SLIM__)
#define REGISTER2(OP, D, N, F, T0, T1) \
  REGISTER(OP, D, N, F, T0)            \
  REGISTER(OP, D, N, F, T1)
#define REGISTER3(OP, D, N, F, T0, T1, T2) \
  REGISTER2(OP, D, N, F, T0, T1)           \
  REGISTER(OP, D, N, F, T2)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
  REGISTER2(OP, D, N, F, T0, T1)               \
  REGISTER2(OP, D, N, F, T2, T3)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) \
  REGISTER3(OP, D, N, F, T0, T1, T2)               \
  REGISTER2(OP, D, N, F, T3, T4)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) \
  REGISTER3(OP, D, N, F, T0, T1, T2)                   \
  REGISTER3(OP, D, N, F, T3, T4, T5)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                   \
  REGISTER3(OP, D, N, F, T4, T5, T6)
#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                       \
  REGISTER4(OP, D, N, F, T4, T5, T6, T7)
#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \
  REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4)                       \
  REGISTER4(OP, D, N, F, T5, T6, T7, T8)

// Instead of adding REGISTER10, etc., shard the .cc files - see
// cwise_op_equal_to_*.cc for an example.

#endif  // defined(__ANDROID_TYPES_SLIM__)
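
// For example (illustrative only), in the full-type build a use such as
//
//   REGISTER3(SimpleBinaryOp, CPU, "TanhGrad", functor::tanh_grad, float,
//             Eigen::half, double);
//
// expands into three REGISTER_KERNEL_BUILDER calls, one per type, each
// registering SimpleBinaryOp<CPUDevice, functor::tanh_grad<T>> for the
// "TanhGrad" op with the corresponding "T" type constraint. The actual type
// lists for each op live in the sharded cwise_op_*.cc files.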

}  // end namespace tensorflow

#endif  // TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_