Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/array_ops.cc.
     17 
     18 #define EIGEN_USE_THREADS
     19 
     20 #if GOOGLE_CUDA
     21 #define EIGEN_USE_GPU
     22 #endif  // GOOGLE_CUDA
     23 
     24 #include "tensorflow/core/kernels/matrix_diag_op.h"
     25 
     26 #include <memory>
     27 #include <vector>
     28 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     29 #include "tensorflow/core/framework/op_kernel.h"
     30 #include "tensorflow/core/framework/register_types.h"
     31 #include "tensorflow/core/framework/tensor.h"
     32 #include "tensorflow/core/framework/tensor_shape.h"
     33 #include "tensorflow/core/framework/tensor_types.h"
     34 #include "tensorflow/core/framework/types.h"
     35 #include "tensorflow/core/platform/logging.h"
     36 #include "tensorflow/core/platform/macros.h"
     37 
     38 namespace tensorflow {
     39 
     40 typedef Eigen::ThreadPoolDevice CPUDevice;
     41 typedef Eigen::GpuDevice GPUDevice;
     42 
     43 template <typename Device, typename T>
     44 class MatrixDiagPartOp : public OpKernel {
     45  public:
     46   explicit MatrixDiagPartOp(OpKernelConstruction* context)
     47       : OpKernel(context) {}
     48 
     49   void Compute(OpKernelContext* context) override {
     50     const Tensor& input = context->input(0);
     51 
     52     const TensorShape& input_shape = input.shape();
     53     const int rank = input_shape.dims();
     54 
     55     // Preliminary validation of sizes.
     56     OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
     57                 errors::InvalidArgument(
     58                     "input must be at least 2-dim, received shape: ",
     59                     input.shape().DebugString()));
     60 
     61     TensorShape output_shape;
     62     for (int i = 0; i < rank - 2; ++i) {
     63       output_shape.AddDim(input_shape.dim_size(i));
     64     }
     65     const int64 min_dim = std::min(input_shape.dim_size(rank - 2),
     66                                    input_shape.dim_size(rank - 1));
     67     output_shape.AddDim(min_dim);
     68 
     69     Tensor* output = nullptr;
     70     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
     71 
     72     auto output_reshaped = output->flat_inner_dims<T, 2>();
     73     auto input_reshaped = input.flat_inner_dims<T, 3>();
     74 
     75     functor::MatrixDiagPart<Device, T>::Compute(
     76         context->eigen_device<Device>(), input_reshaped, output_reshaped);
     77   }
     78 
     79  private:
     80   TF_DISALLOW_COPY_AND_ASSIGN(MatrixDiagPartOp);
     81 };
     82 
     83 template <typename Device, typename T>
     84 class MatrixDiagOp : public OpKernel {
     85  public:
     86   explicit MatrixDiagOp(OpKernelConstruction* context) : OpKernel(context) {}
     87 
     88   void Compute(OpKernelContext* context) override {
     89     const Tensor& input = context->input(0);
     90 
     91     const TensorShape& input_shape = input.shape();
     92     const int rank = input_shape.dims();
     93 
     94     // Preliminary validation of sizes.
     95     OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input_shape),
     96                 errors::InvalidArgument(
     97                     "input must be at least 1-dim, received shape: ",
     98                     input.shape().DebugString()));
     99 
    100     const int64 k = input_shape.dim_size(rank - 1);
    101     auto input_reshaped = input.flat_inner_dims<T, 2>();
    102 
    103     TensorShape output_shape = input_shape;
    104     output_shape.AddDim(k);
    105 
    106     Tensor* output = nullptr;
    107     OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    108 
    109     auto output_reshaped = output->flat_inner_dims<T, 3>();
    110 
    111     functor::MatrixDiag<Device, T>::Compute(context->eigen_device<Device>(),
    112                                             input_reshaped, output_reshaped);
    113   }
    114 
    115  private:
    116   TF_DISALLOW_COPY_AND_ASSIGN(MatrixDiagOp);
    117 };
    118 
    119 #define REGISTER_MATRIX_DIAG(type)                                         \
    120   REGISTER_KERNEL_BUILDER(                                                 \
    121       Name("MatrixDiag").Device(DEVICE_CPU).TypeConstraint<type>("T"),     \
    122       MatrixDiagOp<CPUDevice, type>);                                      \
    123   REGISTER_KERNEL_BUILDER(                                                 \
    124       Name("MatrixDiagPart").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
    125       MatrixDiagPartOp<CPUDevice, type>);
    126 TF_CALL_POD_TYPES(REGISTER_MATRIX_DIAG);
    127 #undef REGISTER_MATRIX_DIAG
    128 
    129 // Registration of the deprecated kernel.
    130 // Delete after 10mar2017.
    131 #define REGISTER_BATCH_MATRIX_DIAG(type)                                    \
    132   REGISTER_KERNEL_BUILDER(                                                  \
    133       Name("BatchMatrixDiag").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
    134       MatrixDiagOp<CPUDevice, type>);                                       \
    135   REGISTER_KERNEL_BUILDER(Name("BatchMatrixDiagPart")                       \
    136                               .Device(DEVICE_CPU)                           \
    137                               .TypeConstraint<type>("T"),                   \
    138                           MatrixDiagPartOp<CPUDevice, type>);
    139 TF_CALL_POD_TYPES(REGISTER_BATCH_MATRIX_DIAG);
    140 #undef REGISTER_BATCH_MATRIX_DIAG
    141 
    142 // Implementation of the functor specialization for CPU.
    143 namespace functor {
    144 template <typename T>
    145 struct MatrixDiag<CPUDevice, T> {
    146   static void Compute(const CPUDevice& d,
    147                       typename TTypes<T, 2>::ConstTensor input,
    148                       typename TTypes<T, 3>::Tensor output) {
    149     output.device(d) = output.constant(T());
    150     for (int64 r = 0; r < output.dimension(0); ++r) {
    151       for (int64 d = 0; d < output.dimension(1); ++d) {
    152         output(r, d, d) = input(r, d);
    153       }
    154     }
    155   }
    156 };
    157 
    158 template <typename T>
    159 struct MatrixDiagPart<CPUDevice, T> {
    160   static void Compute(const CPUDevice& d,
    161                       typename TTypes<T, 3>::ConstTensor input,
    162                       typename TTypes<T, 2>::Tensor output) {
    163     for (int64 r = 0; r < output.dimension(0); ++r) {
    164       for (int64 d = 0; d < output.dimension(1); ++d) {
    165         output(r, d) = input(r, d, d);
    166       }
    167     }
    168   }
    169 };
    170 
    171 }  // namespace functor
    172 
    173 #if GOOGLE_CUDA
    174 
    175 // Forward declarations of the functor specializations for GPU.
    176 namespace functor {
    177 #define DECLARE_GPU_SPEC(T)                                         \
    178   template <>                                                       \
    179   void MatrixDiag<GPUDevice, T>::Compute(                           \
    180       const GPUDevice& d, typename TTypes<T, 2>::ConstTensor input, \
    181       typename TTypes<T, 3>::Tensor output);                        \
    182   extern template struct MatrixDiag<GPUDevice, T>;                  \
    183   template <>                                                       \
    184   void MatrixDiagPart<GPUDevice, T>::Compute(                       \
    185       const GPUDevice& d, typename TTypes<T, 3>::ConstTensor input, \
    186       typename TTypes<T, 2>::Tensor output);                        \
    187   extern template struct MatrixDiagPart<GPUDevice, T>;
    188 
    189 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
    190 TF_CALL_bool(DECLARE_GPU_SPEC);
    191 TF_CALL_complex64(DECLARE_GPU_SPEC);
    192 TF_CALL_complex128(DECLARE_GPU_SPEC);
    193 
    194 }  // namespace functor
    195 
    196 // Registration of the GPU implementations.
    197 #define REGISTER_MATRIX_DIAG_GPU(type)                                     \
    198   REGISTER_KERNEL_BUILDER(                                                 \
    199       Name("MatrixDiag").Device(DEVICE_GPU).TypeConstraint<type>("T"),     \
    200       MatrixDiagOp<GPUDevice, type>);                                      \
    201   REGISTER_KERNEL_BUILDER(                                                 \
    202       Name("MatrixDiagPart").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
    203       MatrixDiagPartOp<GPUDevice, type>);
    204 TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_DIAG_GPU);
    205 TF_CALL_bool(REGISTER_MATRIX_DIAG_GPU);
    206 TF_CALL_complex64(REGISTER_MATRIX_DIAG_GPU);
    207 TF_CALL_complex128(REGISTER_MATRIX_DIAG_GPU);
    208 #undef REGISTER_MATRIX_DIAG_GPU
    209 
    210 // Registration of the deprecated kernel.
    211 // Delete after 10mar2017.
    212 #define REGISTER_BATCH_MATRIX_DIAG_GPU(type)                                \
    213   REGISTER_KERNEL_BUILDER(                                                  \
    214       Name("BatchMatrixDiag").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
    215       MatrixDiagOp<GPUDevice, type>);                                       \
    216   REGISTER_KERNEL_BUILDER(Name("BatchMatrixDiagPart")                       \
    217                               .Device(DEVICE_GPU)                           \
    218                               .TypeConstraint<type>("T"),                   \
    219                           MatrixDiagPartOp<GPUDevice, type>);
    220 TF_CALL_GPU_NUMBER_TYPES(REGISTER_BATCH_MATRIX_DIAG_GPU);
    221 #undef REGISTER_BATCH_MATRIX_DIAG_GPU
    222 
    223 #endif  // GOOGLE_CUDA
    224 
    225 }  // namespace tensorflow
    226