// tensorflow/core/kernels: Cumsum/Cumprod scan kernel implementation.
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 #if GOOGLE_CUDA
     18 #define EIGEN_USE_GPU
     19 #endif  // GOOGLE_CUDA
     20 
     21 #include "tensorflow/core/framework/numeric_op.h"
     22 #include "tensorflow/core/framework/op_kernel.h"
     23 #include "tensorflow/core/framework/register_types.h"
     24 #include "tensorflow/core/framework/tensor.h"
     25 #include "tensorflow/core/framework/types.h"
     26 #include "tensorflow/core/kernels/bounds_check.h"
     27 
     28 #include "third_party/eigen3/Eigen/Core"
     29 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     30 
     31 #include "tensorflow/core/kernels/scan_ops.h"
     32 
     33 namespace tensorflow {
     34 
     35 typedef Eigen::ThreadPoolDevice CPUDevice;
     36 typedef Eigen::GpuDevice GPUDevice;
     37 
     38 template <typename Device, class T, typename Reducer, typename Tidx>
     39 class ScanOp : public OpKernel {
     40  public:
     41   explicit ScanOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
     42     OP_REQUIRES_OK(ctx, ctx->GetAttr("reverse", &reverse_));
     43     OP_REQUIRES_OK(ctx, ctx->GetAttr("exclusive", &exclusive_));
     44   }
     45 
     46   void Compute(OpKernelContext* ctx) override {
     47     const Tensor& input = ctx->input(0);
     48     const Tensor& tensor_axis = ctx->input(1);
     49 
     50     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_axis.shape()),
     51                 errors::InvalidArgument("ScanOp: axis must be a scalar, not ",
     52                                         tensor_axis.shape().DebugString()));
     53 
     54     const Tidx axis_arg =
     55         internal::SubtleMustCopy(tensor_axis.scalar<Tidx>()());
     56     const Tidx axis = (axis_arg < 0) ? input.dims() + axis_arg : axis_arg;
     57     OP_REQUIRES(ctx, FastBoundsCheck(axis, input.dims()),
     58                 errors::InvalidArgument(
     59                     "ScanOp: Expected scan axis in the range [", -input.dims(),
     60                     ", ", input.dims(), "), but got ", axis));
     61 
     62     const TensorShape& output_shape = input.shape();
     63     Tensor* output = nullptr;
     64     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output));
     65 
     66     // Exit early if there's nothing to compute
     67     if (output_shape.num_elements() == 0) return;
     68 
     69     const Device& d = ctx->eigen_device<Device>();
     70     Reducer reducer;
     71 
     72     // Dim reduction.
     73     int64 reduced_shape[3] = {1, 1, 1};
     74     for (Tidx i = 0; i < axis; ++i) {
     75       reduced_shape[0] *= input.dim_size(i);
     76     }
     77     reduced_shape[1] = input.dim_size(axis);
     78     for (Tidx i = axis + 1; i < input.dims(); ++i) {
     79       reduced_shape[2] *= input.dim_size(i);
     80     }
     81 
     82     functor::Scan<Device, Reducer, T>()(d, input.shaped<T, 3>(reduced_shape),
     83                                         output->shaped<T, 3>(reduced_shape),
     84                                         reducer, reverse_, exclusive_);
     85   }
     86 
     87  private:
     88   bool reverse_;
     89   bool exclusive_;
     90 };
     91 
     92 #ifdef GOOGLE_CUDA
     93 namespace functor {
     94 
     95 // Forward declarations of GPU functors
     96 #define DECLARE(REDUCER, T)                                                 \
     97   template <>                                                               \
     98   void Scan<GPUDevice, REDUCER, T>::operator()(                             \
     99       const GPUDevice& d, TTypes<T, 3>::ConstTensor in,                     \
    100       TTypes<T, 3>::Tensor out, const REDUCER& reducer, const bool reverse, \
    101       const bool exclusive);                                                \
    102   extern template struct Scan<GPUDevice, REDUCER, T>;
    103 
    104 #define DECLARE_FOR_ALL_REDUCERS(T)           \
    105   DECLARE(Eigen::internal::SumReducer<T>, T); \
    106   DECLARE(Eigen::internal::ProdReducer<T>, T);
    107 
    108 TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_ALL_REDUCERS);
    109 
    110 #undef DECLARE_FOR_ALL_REDUCERS
    111 #undef DECLARE
    112 
    113 }  // namespace functor
    114 #endif  // GOOGLE_CUDA
    115 
// Register Cumsum kernels.
// CPU Cumsum: ScanOp with Eigen's SumReducer, registered once per value type
// T, accepting both int32 and int64 for the "axis" input (Tidx).
#define REGISTER_CPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("Cumsum")                                                     \
          .Device(DEVICE_CPU)                                            \
          .TypeConstraint<type>("T")                                     \
          .TypeConstraint<int32>("Tidx"),                                \
      ScanOp<CPUDevice, type, Eigen::internal::SumReducer<type>, int32>) \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("Cumsum")                                                     \
          .Device(DEVICE_CPU)                                            \
          .TypeConstraint<type>("T")                                     \
          .TypeConstraint<int64>("Tidx"),                                \
      ScanOp<CPUDevice, type, Eigen::internal::SumReducer<type>, int64>)
// Instantiate the registrations for every CPU-supported numeric type.
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
    132 
#if GOOGLE_CUDA
// GPU Cumsum: same pairing as the CPU registrations, but "axis" is pinned to
// host memory because Compute() reads its scalar value directly.
#define REGISTER_GPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("Cumsum")                                                     \
          .Device(DEVICE_GPU)                                            \
          .TypeConstraint<type>("T")                                     \
          .TypeConstraint<int32>("Tidx")                                 \
          .HostMemory("axis"),                                           \
      ScanOp<GPUDevice, type, Eigen::internal::SumReducer<type>, int32>) \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("Cumsum")                                                     \
          .Device(DEVICE_GPU)                                            \
          .TypeConstraint<type>("T")                                     \
          .TypeConstraint<int64>("Tidx")                                 \
          .HostMemory("axis"),                                           \
      ScanOp<GPUDevice, type, Eigen::internal::SumReducer<type>, int64>)
// Instantiate for the numeric types supported on GPU.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS)
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA
    152 
// Register Cumprod kernels.
// CPU Cumprod: identical to the Cumsum registrations above except that the
// reducer is Eigen's ProdReducer.
#define REGISTER_CPU_KERNELS(type)                                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("Cumprod")                                                     \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<type>("T")                                      \
          .TypeConstraint<int32>("Tidx"),                                 \
      ScanOp<CPUDevice, type, Eigen::internal::ProdReducer<type>, int32>) \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("Cumprod")                                                     \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<type>("T")                                      \
          .TypeConstraint<int64>("Tidx"),                                 \
      ScanOp<CPUDevice, type, Eigen::internal::ProdReducer<type>, int64>)
// Instantiate the registrations for every CPU-supported numeric type.
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
    169 
#if GOOGLE_CUDA
// GPU Cumprod: like the CPU registrations but on DEVICE_GPU, with "axis"
// pinned to host memory because Compute() reads its scalar value directly.
#define REGISTER_GPU_KERNELS(type)                                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("Cumprod")                                                     \
          .Device(DEVICE_GPU)                                             \
          .TypeConstraint<type>("T")                                      \
          .TypeConstraint<int32>("Tidx")                                  \
          .HostMemory("axis"),                                            \
      ScanOp<GPUDevice, type, Eigen::internal::ProdReducer<type>, int32>) \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("Cumprod")                                                     \
          .Device(DEVICE_GPU)                                             \
          .TypeConstraint<type>("T")                                      \
          .TypeConstraint<int64>("Tidx")                                  \
          .HostMemory("axis"),                                            \
      ScanOp<GPUDevice, type, Eigen::internal::ProdReducer<type>, int64>)
// Instantiate for the numeric types supported on GPU.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS)
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA
    189 
    190 }  // namespace tensorflow
    191