/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.
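//
// For reference (summarizing those docs): the output keeps element (i, j) of
// each innermost matrix iff
//   (num_lower < 0 || i - j <= num_lower) &&
//   (num_upper < 0 || j - i <= num_upper)
// and zeroes it otherwise. For example, num_lower = 1, num_upper = 2 on a
// 4x4 matrix keeps the band marked 'x' ('.' is zeroed):
//   [[x x x .]
//    [x x x x]
//    [. x x x]
//    [. . x x]]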

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/matrix_band_part_op.h"

#include <algorithm>
#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class MatrixBandPartOp : public OpKernel {
 public:
  explicit MatrixBandPartOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input.shape().DebugString()));
    // Collapse all leading dimensions into one batch dimension, giving a
    // [batch, rows, cols] view of the input.
    auto input_reshaped = input.flat_inner_dims<T, 3>();

    const Tensor& num_lower_in = context->input(1);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in.shape()),
                errors::InvalidArgument("num_lower must be scalar, got shape ",
                                        num_lower_in.shape().DebugString()));

    auto as_int64_scalar = [](const Tensor& tensor) -> int64 {
      if (tensor.dtype() == DT_INT32) {
        return tensor.scalar<int32>()();
      } else {
        return tensor.scalar<int64>()();
      }
    };
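    // The op registration in ../ops/array_ops.cc constrains the index inputs
    // to int32 or int64, so these two branches are exhaustive.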
    const int64 num_lower = as_int64_scalar(num_lower_in);
    OP_REQUIRES(
        context, num_lower <= input_reshaped.dimension(1),
        errors::InvalidArgument(
            "num_lower must be negative or less than or equal to the number "
            "of rows (",
            input_reshaped.dimension(1), "), got: ", num_lower));

    const Tensor& num_upper_in = context->input(2);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in.shape()),
                errors::InvalidArgument("num_upper must be scalar, got shape ",
                                        num_upper_in.shape().DebugString()));
    const int64 num_upper = as_int64_scalar(num_upper_in);
    OP_REQUIRES(context, num_upper <= input_reshaped.dimension(2),
                errors::InvalidArgument(
                    "num_upper must be negative or less than or equal to the "
                    "number of columns (",
                    input_reshaped.dimension(2), "), got: ", num_upper));

    if (input.NumElements() == 0 ||
        ((num_lower < 0 || num_lower == input_reshaped.dimension(1)) &&
         (num_upper < 0 || num_upper == input_reshaped.dimension(2)))) {
      // No-op: the input is empty or the requested band covers the entire
      // matrix, so forward the input tensor to the output without copying.
      context->set_output(0, input);
      return;
    }

    Tensor* output = nullptr;
    // Reuse the input buffer for the output if this kernel holds the only
    // reference to it; otherwise allocate fresh output storage.
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, input_shape, &output));
    auto output_reshaped = output->flat_inner_dims<T, 3>();
    functor::MatrixBandPartFunctor<Device, T> fn;
    fn(context, context->eigen_device<Device>(), num_lower, num_upper,
       input_reshaped, output_reshaped);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixBandPartOp);
};

#define REGISTER_MATRIX_BAND_PART(type)                                    \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("MatrixBandPart").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      MatrixBandPartOp<CPUDevice, type>);
TF_CALL_POD_TYPES(REGISTER_MATRIX_BAND_PART);
#undef REGISTER_MATRIX_BAND_PART

// Registration of the deprecated kernel.
// Delete after 10mar2017.
#define REGISTER_BATCH_MATRIX_BAND_PART(type)             \
  REGISTER_KERNEL_BUILDER(Name("BatchMatrixBandPart")     \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<type>("T"), \
                          MatrixBandPartOp<CPUDevice, type>);
TF_CALL_NUMBER_TYPES(REGISTER_BATCH_MATRIX_BAND_PART);
#undef REGISTER_BATCH_MATRIX_BAND_PART

// Implementation of the functor specialization for CPU.
namespace functor {

// CPU implementation of MatrixBandPartFunctor.
typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Scalar>
struct MatrixBandPartFunctor<CPUDevice, Scalar> {
  void operator()(OpKernelContext* context, const CPUDevice& device,
                  int num_lower_diags, int num_upper_diags,
                  typename TTypes<Scalar, 3>::ConstTensor input,
                  typename TTypes<Scalar, 3>::Tensor output) {
    const int64 b = input.dimension(0);
    const int64 m = input.dimension(1);
    const int64 n = input.dimension(2);
    auto thread_pool =
        context->device()->tensorflow_cpu_worker_threads()->workers;
    const int64 total_rows = b * m;
    // Rough per-row cost estimate (in cycles) used by ParallelFor to choose
    // shard sizes.
    const int64 row_cost = 10 * n;
    const bool in_place = input.data() == output.data();
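    // The work is sharded over the b * m rows of the flattened (batch, row)
    // space: each shard handles rows [begin, end) of that space, and the
    // loops below map the range back to per-batch row sub-ranges.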
    auto compute_shard = [=, &input, &output](int64 begin, int64 end) {
      if (!in_place) {
        std::fill(output.data() + begin * n, output.data() + end * n, Scalar());
      }
      const int64 batch_begin = begin / m;
      const int64 batch_end = (end + m - 1) / m;
      for (int64 batch = batch_begin; batch < batch_end; ++batch) {
        const int64 row_begin = begin > batch * m ? begin % m : 0;
        const int64 row_end = end < (batch + 1) * m ? end % m : m;
        for (int64 row = row_begin; row < row_end; ++row) {
          // Compute the half-open column range [band_start, band_end) that
          // lies inside the band for this row. E.g. row = 2 with
          // num_lower_diags = 1, num_upper_diags = 0 and n = 4 keeps
          // columns [1, 3).
          const int64 band_start =
              num_lower_diags < 0
                  ? 0
                  : std::min(n, std::max(int64{0}, row - num_lower_diags));
          const int64 band_end =
              num_upper_diags < 0 ? n : std::min(n, row + num_upper_diags + 1);
          if (in_place) {
            // In-place: zero out the columns outside the band; the in-band
            // elements are already in position.
            if (band_start > 0) {
              std::fill(&output(batch, row, 0), &output(batch, row, band_start),
                        Scalar());
            }
            if (band_end < n) {
              std::fill(&output(batch, row, band_end), &output(batch, row, n),
                        Scalar());
            }
          } else {
            // Out-of-place: the whole shard was zero-filled above, so only
            // the in-band segment needs to be copied from the input.
            if (band_start < band_end) {
              const Eigen::DSizes<Eigen::DenseIndex, 3> indices(batch, row,
                                                                band_start);
              const Eigen::DSizes<Eigen::DenseIndex, 3> sizes(
                  1, 1, band_end - band_start);
              output.slice(indices, sizes) = input.slice(indices, sizes);
            }
          }
        }
      }
    };
    thread_pool->ParallelFor(total_rows, row_cost, std::move(compute_shard));
  }
};

#define DEFINE_CPU_SPEC(T) template struct MatrixBandPartFunctor<CPUDevice, T>;
TF_CALL_POD_TYPES(DEFINE_CPU_SPEC);
#undef DEFINE_CPU_SPEC

}  // namespace functor

#if GOOGLE_CUDA

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                            \
  template <>                                                          \
  struct MatrixBandPartFunctor<GPUDevice, T> {                         \
    void operator()(OpKernelContext* context, const GPUDevice& device, \
                    int num_lower_diags, int num_upper_diags,          \
                    typename TTypes<T, 3>::ConstTensor input,          \
                    typename TTypes<T, 3>::Tensor output);             \
  };                                                                   \
  extern template struct MatrixBandPartFunctor<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_bool(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_MATRIX_BAND_PART_GPU(type)              \
  REGISTER_KERNEL_BUILDER(Name("MatrixBandPart")         \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("num_lower")   \
                              .HostMemory("num_upper"),  \
                          MatrixBandPartOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_BAND_PART_GPU);
TF_CALL_bool(REGISTER_MATRIX_BAND_PART_GPU);
TF_CALL_complex64(REGISTER_MATRIX_BAND_PART_GPU);
TF_CALL_complex128(REGISTER_MATRIX_BAND_PART_GPU);
#undef REGISTER_MATRIX_BAND_PART_GPU

// Registration of the deprecated kernel.
// Delete after 10mar2017.
#define REGISTER_BATCH_MATRIX_BAND_PART_GPU(type)        \
  REGISTER_KERNEL_BUILDER(Name("BatchMatrixBandPart")    \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("num_lower")   \
                              .HostMemory("num_upper"),  \
                          MatrixBandPartOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_BATCH_MATRIX_BAND_PART_GPU);
#undef REGISTER_BATCH_MATRIX_BAND_PART_GPU

#endif  // GOOGLE_CUDA

}  // namespace tensorflow