Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/math_ops.cc.
     17 
     18 #define EIGEN_USE_THREADS
     19 
     20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     21 #include "tensorflow/core/framework/op_kernel.h"
     22 #include "tensorflow/core/kernels/cwise_ops.h"
     23 #include "tensorflow/core/kernels/cwise_ops_common.h"
     24 #include "tensorflow/core/kernels/relu_op_functor.h"
     25 
     26 namespace tensorflow {
     27 
// Forward declaration of the kernel so that UnaryOpsCompositionBase<T> can
// befriend it (the kernel calls the private ExportComputeFns()).
template <typename T>
class UnaryOpsComposition;  // forward declare kernel

// Per-element-type registry of compute functions. Only the explicit
// specializations defined below (float, Eigen::half, double) exist.
template <typename T>
struct UnaryOpsCompositionSupport;
     33 
     34 template <typename T>
     35 struct UnaryOpsCompositionBase {
     36   using InputBuffer = typename TTypes<T>::ConstFlat;
     37   using OutputBuffer = typename TTypes<T>::Flat;
     38 
     39   using ComputeFn = void (*)(const InputBuffer&, OutputBuffer*);
     40 
     41   struct ComputeFnRegistration {
     42     ComputeFn compute_fn;
     43     int cost;
     44   };
     45 
     46   bool HasComputeFn(const string& name) {
     47     return compute_fns.find(name) != compute_fns.end();
     48   }
     49 
     50  protected:
     51   void RegisterComputeFn(const string& name, ComputeFn compute_fn, int cost) {
     52     VLOG(5) << "Register compute fn: name=" << name << " cost=" << cost;
     53     compute_fns[name] = {compute_fn, cost};
     54   }
     55 
     56  private:
     57   friend class UnaryOpsComposition<T>;
     58 
     59   Status ExportComputeFns(const std::vector<string>& op_names,
     60                           std::vector<ComputeFn>* fns, int* cost) {
     61     for (const string& op_name : op_names) {
     62       auto it = compute_fns.find(op_name);
     63       if (it == compute_fns.end())
     64         return errors::InvalidArgument(
     65             "Do not have a compute function registered for op: ", op_name);
     66 
     67       const ComputeFnRegistration& reg = it->second;
     68       fns->push_back(reg.compute_fn);
     69       *cost += reg.cost;
     70     }
     71 
     72     return Status::OK();
     73   }
     74 
     75   std::unordered_map<string, ComputeFnRegistration> compute_fns;
     76 };
     77 
     78 template <typename T>
     79 class UnaryOpsComposition : public OpKernel {
     80  public:
     81   using Kernel = UnaryOpsComposition<T>;
     82 
     83   using Scalar = T;
     84   using Packet = typename Eigen::internal::packet_traits<T>::type;
     85 
     86   using Support = UnaryOpsCompositionSupport<T>;
     87 
     88   using InputBuffer = typename Support::InputBuffer;
     89   using OutputBuffer = typename Support::OutputBuffer;
     90   using ComputeFn = typename Support::ComputeFn;
     91 
     92   explicit UnaryOpsComposition(OpKernelConstruction* context)
     93       : OpKernel(context) {
     94     OP_REQUIRES_OK(context, context->GetAttr("op_names", &op_names_));
     95 
     96     OP_REQUIRES(context, !op_names_.empty(),
     97                 errors::InvalidArgument(
     98                     "Unary op composition must have at least one op"));
     99 
    100     OP_REQUIRES_OK(context,
    101                    support_.ExportComputeFns(op_names_, &fns_, &cost_));
    102 
    103     VLOG(2) << "Composed unary op: [" << str_util::Join(op_names_, ", ")
    104             << "]; cost=" << cost_;
    105   }
    106 
    107   void Compute(OpKernelContext* ctx) override {
    108     const Tensor& in = ctx->input(0);
    109     Tensor* out = nullptr;
    110     OP_REQUIRES_OK(
    111         ctx, ctx->forward_input_or_allocate_output({0}, 0, in.shape(), &out));
    112 
    113     InputBuffer in_flat = in.flat<T>();
    114     OutputBuffer out_flat = out->flat<T>();
    115 
    116     const std::size_t num_fns = fns_.size();
    117     auto compute_fn = [this, &in_flat, &out_flat, &num_fns](int64 begin,
    118                                                             int64 end) {
    119       int64 len = end - begin;
    120       const InputBuffer in_slice(in_flat.data() + begin, len);
    121       const InputBuffer scratch_slice(out_flat.data() + begin, len);
    122       OutputBuffer out_slice(out_flat.data() + begin, len);
    123 
    124       fns_[0](in_slice, &out_slice);
    125       for (int i = 1; i < num_fns; ++i) {
    126         fns_[i](scratch_slice, &out_slice);
    127       }
    128     };
    129 
    130     const CPUDevice& device = ctx->eigen_device<CPUDevice>();
    131     const int kOverheadCycles = static_cast<int>(num_fns) * 10;
    132     Eigen::TensorOpCost cost(/*bytes_loaded=*/sizeof(T) * num_fns,
    133                              /*bytes_stored=*/sizeof(T) * num_fns,
    134                              kOverheadCycles + cost_);
    135     device.parallelFor(in.NumElements(), cost, AlignBlockSize,
    136                        std::move(compute_fn));
    137   }
    138 
    139  private:
    140   static const int kPacketSize = Eigen::internal::unpacket_traits<Packet>::size;
    141 
    142   static inline int64 AlignBlockSize(int64 block_size) {
    143     // Align block size to packet size and account for unrolling in run above.
    144     if (block_size >= 16 * kPacketSize) {
    145       return (block_size + 4 * kPacketSize - 1) & ~(4 * kPacketSize - 1);
    146     }
    147     // Aligning to 4 * PacketSize would increase block size by more than 25%.
    148     return (block_size + kPacketSize - 1) & ~(kPacketSize - 1);
    149   }
    150 
    151   Support support_;
    152 
    153   std::vector<string> op_names_;
    154   std::vector<ComputeFn> fns_;
    155   int cost_ = 0;
    156 };
    157 
    158 // Register compute functions for UnaryOp functors.
    159 #define REGISTER_COMPUTE_FN_HELPER(name, functor)                              \
    160   static_assert(std::is_same<functor::in_type, functor::out_type>::value,      \
    161                 "Functor must have same input and output types");              \
    162                                                                                \
    163   static inline void Compute##name(const InputBuffer& in, OutputBuffer* out) { \
    164     *out = in.unaryExpr(functor::func());                                      \
    165   }                                                                            \
    166   static inline int Cost##name() {                                             \
    167     return Eigen::internal::functor_traits<functor::func>::Cost;               \
    168   }
    169 
    170 // Register compute function for the Relu/Relu6/Elu/Selu.
    171 #define REGISTER_RELU_HELPER()                                                \
    172   template <typename T>                                                       \
    173   using functor_traits = Eigen::internal::functor_traits<T>;                  \
    174                                                                               \
    175   static inline void ComputeRelu(const InputBuffer& in, OutputBuffer* out) {  \
    176     auto relu = functor::Relu<Eigen::DefaultDevice, T>();                     \
    177     relu(Eigen::DefaultDevice(), in, *out);                                   \
    178   }                                                                           \
    179                                                                               \
    180   static inline int CostRelu() {                                              \
    181     return functor_traits<Eigen::internal::scalar_max_op<T>>::Cost;           \
    182   }                                                                           \
    183                                                                               \
    184   static inline void ComputeRelu6(const InputBuffer& in, OutputBuffer* out) { \
    185     auto relu6 = functor::Relu6<Eigen::DefaultDevice, T>();                   \
    186     relu6(Eigen::DefaultDevice(), in, *out);                                  \
    187   }                                                                           \
    188                                                                               \
    189   static inline int CostRelu6() {                                             \
    190     return functor_traits<Eigen::internal::scalar_max_op<T>>::Cost +          \
    191            functor_traits<Eigen::internal::scalar_min_op<T>>::Cost;           \
    192   }                                                                           \
    193   static inline void ComputeElu(const InputBuffer& in, OutputBuffer* out) {   \
    194     auto elu = functor::Elu<Eigen::DefaultDevice, T>();                       \
    195     elu(Eigen::DefaultDevice(), in, *out);                                    \
    196   }                                                                           \
    197                                                                               \
    198   static inline int CostElu() {                                               \
    199     return functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost +          \
    200            Eigen::NumTraits<T>::MulCost;                                      \
    201   }                                                                           \
    202   static inline void ComputeSelu(const InputBuffer& in, OutputBuffer* out) {  \
    203     auto selu = functor::Selu<Eigen::DefaultDevice, T>();                     \
    204     selu(Eigen::DefaultDevice(), in, *out);                                   \
    205   }                                                                           \
    206                                                                               \
    207   static inline int CostSelu() {                                              \
    208     return 2 * (functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost +     \
    209                 Eigen::NumTraits<T>::MulCost);                                \
    210   }
    211 
    212 #define REGISTER_COMPUTE_FN(func) \
    213   RegisterComputeFn(#func, Compute##func, Cost##func());
    214 
    215 template <>
    216 struct UnaryOpsCompositionSupport<float> : UnaryOpsCompositionBase<float> {
    217   using T = float;
    218 
    219   UnaryOpsCompositionSupport() {
    220     // UnaryOp functors.
    221     REGISTER_COMPUTE_FN(Abs);
    222     REGISTER_COMPUTE_FN(Acos);
    223     REGISTER_COMPUTE_FN(Acosh);
    224     REGISTER_COMPUTE_FN(Asin);
    225     REGISTER_COMPUTE_FN(Asinh);
    226     REGISTER_COMPUTE_FN(Atan);
    227     REGISTER_COMPUTE_FN(Atanh);
    228     REGISTER_COMPUTE_FN(Ceil);
    229     REGISTER_COMPUTE_FN(Cos);
    230     REGISTER_COMPUTE_FN(Cosh);
    231     REGISTER_COMPUTE_FN(Expm1);
    232     REGISTER_COMPUTE_FN(Exp);
    233     REGISTER_COMPUTE_FN(Floor);
    234     REGISTER_COMPUTE_FN(Inv);
    235     REGISTER_COMPUTE_FN(Log);
    236     REGISTER_COMPUTE_FN(Log1p);
    237     REGISTER_COMPUTE_FN(Neg);
    238     REGISTER_COMPUTE_FN(Reciprocal);
    239     REGISTER_COMPUTE_FN(Rint);
    240     REGISTER_COMPUTE_FN(Round);
    241     REGISTER_COMPUTE_FN(Rsqrt);
    242     REGISTER_COMPUTE_FN(Sigmoid);
    243     REGISTER_COMPUTE_FN(Sin);
    244     REGISTER_COMPUTE_FN(Sinh);
    245     REGISTER_COMPUTE_FN(Sqrt);
    246     REGISTER_COMPUTE_FN(Square);
    247     REGISTER_COMPUTE_FN(Tan);
    248     REGISTER_COMPUTE_FN(Tanh);
    249 
    250     // Additional compute functions not defined via UnaryOp functors.
    251     REGISTER_COMPUTE_FN(Elu);
    252     REGISTER_COMPUTE_FN(Relu);
    253     REGISTER_COMPUTE_FN(Relu6);
    254     REGISTER_COMPUTE_FN(Selu);
    255   }
    256 
    257   REGISTER_RELU_HELPER();
    258 
    259   // clang-format off
    260   REGISTER_COMPUTE_FN_HELPER(Abs,        functor::abs<T>);
    261   REGISTER_COMPUTE_FN_HELPER(Acos,       functor::acos<T>);
    262   REGISTER_COMPUTE_FN_HELPER(Acosh,      functor::acosh<T>);
    263   REGISTER_COMPUTE_FN_HELPER(Asin,       functor::asin<T>);
    264   REGISTER_COMPUTE_FN_HELPER(Asinh,      functor::asinh<T>);
    265   REGISTER_COMPUTE_FN_HELPER(Atan,       functor::atan<T>);
    266   REGISTER_COMPUTE_FN_HELPER(Atanh,      functor::atanh<T>);
    267   REGISTER_COMPUTE_FN_HELPER(Ceil,       functor::ceil<T>);
    268   REGISTER_COMPUTE_FN_HELPER(Cos,        functor::cos<T>);
    269   REGISTER_COMPUTE_FN_HELPER(Cosh,       functor::cosh<T>);
    270   REGISTER_COMPUTE_FN_HELPER(Expm1,      functor::expm1<T>);
    271   REGISTER_COMPUTE_FN_HELPER(Exp,        functor::exp<T>);
    272   REGISTER_COMPUTE_FN_HELPER(Floor,      functor::floor<T>);
    273   REGISTER_COMPUTE_FN_HELPER(Inv,        functor::inverse<T>);
    274   REGISTER_COMPUTE_FN_HELPER(Log,        functor::log<T>);
    275   REGISTER_COMPUTE_FN_HELPER(Log1p,      functor::log1p<T>);
    276   REGISTER_COMPUTE_FN_HELPER(Neg,        functor::neg<T>);
    277   REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>);
    278   REGISTER_COMPUTE_FN_HELPER(Rint,       functor::rint<T>);
    279   REGISTER_COMPUTE_FN_HELPER(Round,      functor::round<T>);
    280   REGISTER_COMPUTE_FN_HELPER(Rsqrt,      functor::rsqrt<T>);
    281   REGISTER_COMPUTE_FN_HELPER(Sigmoid,    functor::sigmoid<T>);
    282   REGISTER_COMPUTE_FN_HELPER(Sin,        functor::sin<T>);
    283   REGISTER_COMPUTE_FN_HELPER(Sinh,       functor::sinh<T>);
    284   REGISTER_COMPUTE_FN_HELPER(Sqrt,       functor::sqrt<T>);
    285   REGISTER_COMPUTE_FN_HELPER(Square,     functor::square<T>);
    286   REGISTER_COMPUTE_FN_HELPER(Tan,        functor::tan<T>);
    287   REGISTER_COMPUTE_FN_HELPER(Tanh,       functor::tanh<T>);
    288   // clang-format on
    289 };
    290 
    291 template <>
    292 struct UnaryOpsCompositionSupport<Eigen::half>
    293     : UnaryOpsCompositionBase<Eigen::half> {
    294   using T = Eigen::half;
    295 
    296   UnaryOpsCompositionSupport() {
    297     REGISTER_COMPUTE_FN(Abs);
    298     REGISTER_COMPUTE_FN(Ceil);
    299     REGISTER_COMPUTE_FN(Cos);
    300     REGISTER_COMPUTE_FN(Expm1);
    301     REGISTER_COMPUTE_FN(Exp);
    302     REGISTER_COMPUTE_FN(Floor);
    303     REGISTER_COMPUTE_FN(Inv);
    304     REGISTER_COMPUTE_FN(Log);
    305     REGISTER_COMPUTE_FN(Log1p);
    306     REGISTER_COMPUTE_FN(Neg);
    307     REGISTER_COMPUTE_FN(Reciprocal);
    308     REGISTER_COMPUTE_FN(Round);
    309     REGISTER_COMPUTE_FN(Rsqrt);
    310     REGISTER_COMPUTE_FN(Sigmoid);
    311     REGISTER_COMPUTE_FN(Sin);
    312     REGISTER_COMPUTE_FN(Sqrt);
    313     REGISTER_COMPUTE_FN(Square);
    314     REGISTER_COMPUTE_FN(Tanh);
    315     // Additional compute functions not defined via UnaryOp functors.
    316     REGISTER_COMPUTE_FN(Elu);
    317     REGISTER_COMPUTE_FN(Relu);
    318     REGISTER_COMPUTE_FN(Relu6);
    319     REGISTER_COMPUTE_FN(Selu);
    320   }
    321 
    322   REGISTER_RELU_HELPER();
    323 
    324   // clang-format off
    325   REGISTER_COMPUTE_FN_HELPER(Abs,        functor::abs<T>);
    326   REGISTER_COMPUTE_FN_HELPER(Ceil,       functor::ceil<T>);
    327   REGISTER_COMPUTE_FN_HELPER(Cos,        functor::cos<T>);
    328   REGISTER_COMPUTE_FN_HELPER(Expm1,      functor::expm1<T>);
    329   REGISTER_COMPUTE_FN_HELPER(Exp,        functor::exp<T>);
    330   REGISTER_COMPUTE_FN_HELPER(Floor,      functor::floor<T>);
    331   REGISTER_COMPUTE_FN_HELPER(Inv,        functor::inverse<T>);
    332   REGISTER_COMPUTE_FN_HELPER(Log,        functor::log<T>);
    333   REGISTER_COMPUTE_FN_HELPER(Log1p,      functor::log1p<T>);
    334   REGISTER_COMPUTE_FN_HELPER(Neg,        functor::neg<T>);
    335   REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>);
    336   REGISTER_COMPUTE_FN_HELPER(Round,      functor::round<T>);
    337   REGISTER_COMPUTE_FN_HELPER(Rsqrt,      functor::rsqrt<T>);
    338   REGISTER_COMPUTE_FN_HELPER(Sigmoid,    functor::sigmoid<T>);
    339   REGISTER_COMPUTE_FN_HELPER(Sin,        functor::sin<T>);
    340   REGISTER_COMPUTE_FN_HELPER(Sqrt,       functor::sqrt<T>);
    341   REGISTER_COMPUTE_FN_HELPER(Square,     functor::square<T>);
    342   REGISTER_COMPUTE_FN_HELPER(Tanh,       functor::tanh<T>);
    343   // clang-format on
    344 };
    345 
    346 template <>
    347 struct UnaryOpsCompositionSupport<double> : UnaryOpsCompositionBase<double> {
    348   using T = double;
    349 
    350   UnaryOpsCompositionSupport() {
    351     REGISTER_COMPUTE_FN(Abs);
    352     REGISTER_COMPUTE_FN(Acos);
    353     REGISTER_COMPUTE_FN(Acosh);
    354     REGISTER_COMPUTE_FN(Asin);
    355     REGISTER_COMPUTE_FN(Asinh);
    356     REGISTER_COMPUTE_FN(Atan);
    357     REGISTER_COMPUTE_FN(Atanh);
    358     REGISTER_COMPUTE_FN(Ceil);
    359     REGISTER_COMPUTE_FN(Cos);
    360     REGISTER_COMPUTE_FN(Cosh);
    361     REGISTER_COMPUTE_FN(Expm1);
    362     REGISTER_COMPUTE_FN(Exp);
    363     REGISTER_COMPUTE_FN(Floor);
    364     REGISTER_COMPUTE_FN(Inv);
    365     REGISTER_COMPUTE_FN(Log);
    366     REGISTER_COMPUTE_FN(Log1p);
    367     REGISTER_COMPUTE_FN(Neg);
    368     REGISTER_COMPUTE_FN(Reciprocal);
    369     REGISTER_COMPUTE_FN(Rint);
    370     REGISTER_COMPUTE_FN(Round);
    371     REGISTER_COMPUTE_FN(Rsqrt);
    372     REGISTER_COMPUTE_FN(Sigmoid);
    373     REGISTER_COMPUTE_FN(Sin);
    374     REGISTER_COMPUTE_FN(Sinh);
    375     REGISTER_COMPUTE_FN(Sqrt);
    376     REGISTER_COMPUTE_FN(Square);
    377     REGISTER_COMPUTE_FN(Tan);
    378     REGISTER_COMPUTE_FN(Tanh);
    379     // Additional compute functions not defined via UnaryOp functors.
    380     REGISTER_COMPUTE_FN(Elu);
    381     REGISTER_COMPUTE_FN(Relu);
    382     REGISTER_COMPUTE_FN(Relu6);
    383     REGISTER_COMPUTE_FN(Selu);
    384   }
    385 
    386   REGISTER_RELU_HELPER();
    387 
    388   // clang-format off
    389   REGISTER_COMPUTE_FN_HELPER(Abs,        functor::abs<T>);
    390   REGISTER_COMPUTE_FN_HELPER(Acos,       functor::acos<T>);
    391   REGISTER_COMPUTE_FN_HELPER(Acosh,      functor::acosh<T>);
    392   REGISTER_COMPUTE_FN_HELPER(Asin,       functor::asin<T>);
    393   REGISTER_COMPUTE_FN_HELPER(Asinh,      functor::asinh<T>);
    394   REGISTER_COMPUTE_FN_HELPER(Atan,       functor::atan<T>);
    395   REGISTER_COMPUTE_FN_HELPER(Atanh,      functor::atanh<T>);
    396   REGISTER_COMPUTE_FN_HELPER(Ceil,       functor::ceil<T>);
    397   REGISTER_COMPUTE_FN_HELPER(Cos,        functor::cos<T>);
    398   REGISTER_COMPUTE_FN_HELPER(Cosh,       functor::cosh<T>);
    399   REGISTER_COMPUTE_FN_HELPER(Expm1,      functor::expm1<T>);
    400   REGISTER_COMPUTE_FN_HELPER(Exp,        functor::exp<T>);
    401   REGISTER_COMPUTE_FN_HELPER(Floor,      functor::floor<T>);
    402   REGISTER_COMPUTE_FN_HELPER(Inv,        functor::inverse<T>);
    403   REGISTER_COMPUTE_FN_HELPER(Log,        functor::log<T>);
    404   REGISTER_COMPUTE_FN_HELPER(Log1p,      functor::log1p<T>);
    405   REGISTER_COMPUTE_FN_HELPER(Neg,        functor::neg<T>);
    406   REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>);
    407   REGISTER_COMPUTE_FN_HELPER(Rint,       functor::rint<T>);
    408   REGISTER_COMPUTE_FN_HELPER(Round,      functor::round<T>);
    409   REGISTER_COMPUTE_FN_HELPER(Rsqrt,      functor::rsqrt<T>);
    410   REGISTER_COMPUTE_FN_HELPER(Sigmoid,    functor::sigmoid<T>);
    411   REGISTER_COMPUTE_FN_HELPER(Sin,        functor::sin<T>);
    412   REGISTER_COMPUTE_FN_HELPER(Sinh,       functor::sinh<T>);
    413   REGISTER_COMPUTE_FN_HELPER(Sqrt,       functor::sqrt<T>);
    414   REGISTER_COMPUTE_FN_HELPER(Square,     functor::square<T>);
    415   REGISTER_COMPUTE_FN_HELPER(Tan,        functor::tan<T>);
    416   REGISTER_COMPUTE_FN_HELPER(Tanh,       functor::tanh<T>);
    417   // clang-format on
    418 };
    419 
    420 // Register the CPU kernels.
    421 #define REGISTER_CPU(T)                                                       \
    422   REGISTER_KERNEL_BUILDER(                                                    \
    423       Name("_UnaryOpsComposition").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
    424       UnaryOpsComposition<T>);
    425 
    426 REGISTER_CPU(float);
    427 REGISTER_CPU(Eigen::half);
    428 REGISTER_CPU(double);
    429 
    430 #undef REGISTER_CPU
    431 
    432 }  // namespace tensorflow
    433