Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/array_ops.cc.
     17 
     18 #define EIGEN_USE_THREADS
     19 
     20 #if GOOGLE_CUDA
     21 #define EIGEN_USE_GPU
     22 #endif  // GOOGLE_CUDA
     23 
     24 #include "tensorflow/core/kernels/matrix_set_diag_op.h"
     25 
     26 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     27 #include "tensorflow/core/framework/op_kernel.h"
     28 #include "tensorflow/core/framework/register_types.h"
     29 #include "tensorflow/core/framework/tensor.h"
     30 #include "tensorflow/core/framework/tensor_shape.h"
     31 #include "tensorflow/core/framework/tensor_types.h"
     32 #include "tensorflow/core/framework/types.h"
     33 #include "tensorflow/core/lib/core/threadpool.h"
     34 #include "tensorflow/core/platform/logging.h"
     35 #include "tensorflow/core/platform/macros.h"
     36 
     37 namespace tensorflow {
     38 
     39 typedef Eigen::ThreadPoolDevice CPUDevice;
     40 typedef Eigen::GpuDevice GPUDevice;
     41 
     42 template <typename Device, typename T>
     43 class MatrixSetDiagOp : public OpKernel {
     44  public:
     45   explicit MatrixSetDiagOp(OpKernelConstruction* context) : OpKernel(context) {}
     46 
     47   void Compute(OpKernelContext* context) override {
     48     const Tensor& input = context->input(0);
     49     const Tensor& diag = context->input(1);
     50 
     51     const TensorShape& input_shape = input.shape();
     52     const TensorShape& diag_shape = diag.shape();
     53     const int rank = input_shape.dims();
     54 
     55     // Preliminary validation of sizes.
     56     OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
     57                 errors::InvalidArgument(
     58                     "input must be at least 2-dim, received shape: ",
     59                     input.shape().DebugString()));
     60 
     61     // Check to make sure the last dimension of diag is equal to the smaller of
     62     // the last two dimensions of input.
     63     const int64 min_dim = std::min(input_shape.dim_size(rank - 1),
     64                                    input_shape.dim_size(rank - 2));
     65     TensorShape expected_diag_shape = input_shape;
     66     expected_diag_shape.RemoveLastDims(2);
     67     expected_diag_shape.AddDim(min_dim);
     68     OP_REQUIRES(context, expected_diag_shape == diag_shape,
     69                 errors::InvalidArgument(
     70                     "must have diagonal.shape == input.shape[:-2] + "
     71                     "min(input.shape[-2:]), but received input shape: ",
     72                     input_shape.DebugString(),
     73                     " and diagonal shape: ", diag_shape.DebugString()));
     74 
     75     if (input.NumElements() == 0) {
     76       // This is a no-op.
     77       context->set_output(0, input);
     78       return;
     79     }
     80 
     81     auto input_reshaped = input.flat_inner_dims<T, 3>();
     82     auto diag_reshaped = diag.flat_inner_dims<T, 2>();
     83     Tensor* output = nullptr;
     84     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
     85                                 {0}, 0, input_shape, &output));
     86     auto output_reshaped = output->flat_inner_dims<T, 3>();
     87     functor::MatrixSetDiag<Device, T>::Compute(
     88         context, context->eigen_device<Device>(), input_reshaped, diag_reshaped,
     89         output_reshaped);
     90   }
     91 
     92  private:
     93   TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
     94 };
     95 
     96 #define REGISTER_MATRIX_SET_DIAG(type)                                    \
     97   REGISTER_KERNEL_BUILDER(                                                \
     98       Name("MatrixSetDiag").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
     99       MatrixSetDiagOp<CPUDevice, type>);
    100 TF_CALL_POD_TYPES(REGISTER_MATRIX_SET_DIAG);
    101 #undef REGISTER_MATRIX_SET_DIAG
    102 
    103 // Registration of the deprecated kernel.
    104 // Delete after 10mar2017.
    105 #define REGISTER_BATCH_MATRIX_SET_DIAG(type)                                   \
    106   REGISTER_KERNEL_BUILDER(                                                     \
    107       Name("BatchMatrixSetDiag").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
    108       MatrixSetDiagOp<CPUDevice, type>);
    109 TF_CALL_POD_TYPES(REGISTER_BATCH_MATRIX_SET_DIAG);
    110 #undef REGISTER_BATCH_MATRIX_SET_DIAG
    111 
    112 namespace functor {
    113 
    114 // Implementation of the functor specialization for CPU.
    115 template <typename T>
    116 struct MatrixSetDiag<CPUDevice, T> {
    117   static void Compute(OpKernelContext* context, const CPUDevice& device,
    118                       typename TTypes<T, 3>::ConstTensor input,
    119                       typename TTypes<T, 2>::ConstTensor diag,
    120                       typename TTypes<T, 3>::Tensor output) {
    121     if (input.data() != output.data()) {
    122       output.device(device) = input;
    123     }
    124     auto compute_shard = [&output, &diag](int64 begin, int64 end) {
    125       for (int64 batch = begin; batch < end; ++batch) {
    126         for (int64 col = 0; col < diag.dimension(1); ++col) {
    127           output(batch, col, col) = diag(batch, col);
    128         }
    129       }
    130     };
    131     auto thread_pool =
    132         context->device()->tensorflow_cpu_worker_threads()->workers;
    133     int64 cost_per_batch = 10 * output.dimension(1);  // Heuristic.
    134     thread_pool->ParallelFor(output.dimension(0), cost_per_batch,
    135                              std::move(compute_shard));
    136   }
    137 };
    138 
    139 }  // namespace functor
    140 
    141 #if GOOGLE_CUDA
    142 
    143 // Forward declarations of the functor specializations for GPU.
    144 namespace functor {
    145 #define DECLARE_GPU_SPEC(T)                         \
    146   template <>                                       \
    147   void MatrixSetDiag<GPUDevice, T>::Compute(        \
    148       OpKernelContext* context, const GPUDevice& d, \
    149       typename TTypes<T, 3>::ConstTensor input,     \
    150       typename TTypes<T, 2>::ConstTensor diag,      \
    151       typename TTypes<T, 3>::Tensor output);        \
    152   extern template struct MatrixSetDiag<GPUDevice, T>;
    153 
    154 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
    155 TF_CALL_bool(DECLARE_GPU_SPEC);
    156 TF_CALL_complex64(DECLARE_GPU_SPEC);
    157 TF_CALL_complex128(DECLARE_GPU_SPEC);
    158 
    159 }  // namespace functor
    160 
    161 // Registration of the GPU implementations.
    162 #define REGISTER_MATRIX_SET_DIAG_GPU(type)                                \
    163   REGISTER_KERNEL_BUILDER(                                                \
    164       Name("MatrixSetDiag").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
    165       MatrixSetDiagOp<GPUDevice, type>);
    166 TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_SET_DIAG_GPU);
    167 TF_CALL_bool(REGISTER_MATRIX_SET_DIAG_GPU);
    168 TF_CALL_complex64(REGISTER_MATRIX_SET_DIAG_GPU);
    169 TF_CALL_complex128(REGISTER_MATRIX_SET_DIAG_GPU);
    170 #undef REGISTER_MATRIX_SET_DIAG_GPU
    171 
    172 // Registration of the deprecated kernel.
    173 // Delete after 10mar2017.
    174 #define REGISTER_BATCH_MATRIX_SET_DIAG_GPU(type)                               \
    175   REGISTER_KERNEL_BUILDER(                                                     \
    176       Name("BatchMatrixSetDiag").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
    177       MatrixSetDiagOp<GPUDevice, type>);
    178 TF_CALL_GPU_NUMBER_TYPES(REGISTER_BATCH_MATRIX_SET_DIAG_GPU);
    179 #undef REGISTER_BATCH_MATRIX_SET_DIAG_GPU
    180 
    181 #endif  // GOOGLE_CUDA
    182 
    183 }  // namespace tensorflow
    184