/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/diag_op.h"

#include <algorithm>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Generate a diagonal tensor whose diagonal is set to the input tensor.
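// For example, a rank-1 input [1, 2, 3] produces the rank-2 output
// [[1, 0, 0], [0, 2, 0], [0, 0, 3]]. In general, an input of shape
// [D1,..., Dk] produces an output of shape [D1,..., Dk, D1,..., Dk].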
template <typename Device, typename T>
class DiagOp : public OpKernel {
 public:
  explicit DiagOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& diagonal = context->input(0);
    const int num_dims = diagonal.dims();
    OP_REQUIRES(
        context, 0 != num_dims,
        errors::InvalidArgument("Input must be at least rank 1, got 0"));
    TensorShape out_shape;
    for (int i = 0; i < num_dims; ++i) {
      out_shape.AddDim(diagonal.dim_size(i));
    }
    for (int i = 0; i < num_dims; ++i) {
      out_shape.AddDim(diagonal.dim_size(i));
    }
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, out_shape, &output_tensor));
    functor::DiagFunctor<Device, T> diagFunc;
    Status s =
        diagFunc(context, diagonal.NumElements(), diagonal.flat<T>().data(),
                 output_tensor->flat<T>().data());
    OP_REQUIRES_OK(context, s);
  }
};

// Extract the diagonal from the input tensor.
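// For example, the rank-2 input [[1, 0, 0], [0, 2, 0], [0, 0, 3]] produces
// the rank-1 output [1, 2, 3].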
template <typename Device, typename T>
class DiagPartOp : public OpKernel {
 public:
  explicit DiagPartOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor = context->input(0);
    const int num_dims = tensor.dims();
    const int out_dims = num_dims / 2;
    OP_REQUIRES(context, 0 == num_dims % 2,
                errors::InvalidArgument("The rank of the tensor should be "
                                        "even and positive, got shape ",
                                        tensor.shape().DebugString()));
    for (int i = 0; i < out_dims; i++) {
      OP_REQUIRES(
          context, tensor.dim_size(i) == tensor.dim_size(i + out_dims),
          errors::InvalidArgument("Invalid shape ",
                                  tensor.shape().DebugString(), ": dimensions ",
                                  i, " and ", i + out_dims, " do not match."));
    }

    TensorShape out_shape;
    for (int i = 0; i < out_dims; ++i) {
      out_shape.AddDim(tensor.dim_size(i));
    }

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
    functor::DiagPartFunctor<Device, T> diagPartFunc;
    Status s = diagPartFunc(context, out_shape.num_elements(),
                            tensor.flat<T>().data(), output->flat<T>().data());
    OP_REQUIRES_OK(context, s);
  }
};

// Implementation of the functor specialization for CPU.
//
// According to the diagonal definition,
// `output[i1,..., ik, i1,..., ik] = input[i1,..., ik]`.
//
// Let the shape of the input be [s1,..., sk]. Then any offset into the
// input's buffer can be represented by a coordinate [i1,..., ik],
// where `index = i1*(s2*...*sk) + i2*(s3*...*sk) +... + ik`.
//
// Let `new_index` be the offset into the output's buffer at coordinate
// [i1,..., ik, i1,..., ik]. Then we have
// `new_index = i1*(s2*...*sk*s1*...*sk) + i2*(s3*...*sk*s1*...*sk) +... +
//              ik*(s1*...*sk) + i1*(s2*...*sk) + i2*(s3*...*sk) +... + ik
//            = (i1*(s2*...*sk) + i2*(s3*...*sk) +... + ik) * (1 + s1*...*sk)
//            = index * (1 + s1*...*sk)`
//
// Letting `size = s1*...*sk`, we finally have `new_index = index * (1 + size)`,
// which is the transfer function used below.
// This trick keeps the implementation clear and easy to parallelize.
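//
// As a concrete example, for an input of shape [2, 3] (size = 6), the element
// at coordinate [1, 1] has index = 1*3 + 1 = 4 and is written to
// new_index = 4 * (1 + 6) = 28, i.e. coordinate [1, 1, 1, 1] in the output of
// shape [2, 3, 2, 3].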
namespace functor {
template <typename T>
struct DiagFunctor<CPUDevice, T> {
  EIGEN_ALWAYS_INLINE Status operator()(OpKernelContext* context,
                                        const int64 size, const T* in, T* out) {
    // This shard is responsible for writing the output values in the index
    // range [start*size, limit*size).
    auto subDiag = [in, out, size](int64 start, int64 limit) {
      std::fill(out + size * start, out + size * limit, T());
      for (int64 index = start; index < limit; ++index) {
        out[(1 + size) * index] = in[index];
      }
    };

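    // cost_per_unit is Shard's estimate of the cost of processing one unit of
    // work; it is used to decide how finely to split the work across threads.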
    // Here, 5 is an empirical factor for cost_per_unit.
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, size, 5 * size,
          subDiag);
    return Status::OK();
  }
};

template <typename T>
struct DiagPartFunctor<CPUDevice, T> {
  EIGEN_ALWAYS_INLINE Status operator()(OpKernelContext* context,
                                        const int64 size, const T* in, T* out) {
    // This shard is responsible for extracting the output values in the index
    // range [start, limit).
    auto subDiagPart = [in, out, size](int64 start, int64 limit) {
      for (int64 index = start; index < limit; ++index) {
        out[index] = in[(1 + size) * index];
      }
    };

    // Here, 5 is an empirical factor for cost_per_unit.
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, size, 5,
          subDiagPart);
    return Status::OK();
  }
};
}  // namespace functor

// Register the CPU kernels.
#define REGISTER_DIAGOP(T)                                    \
  REGISTER_KERNEL_BUILDER(                                    \
      Name("Diag").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DiagOp<CPUDevice, T>)

TF_CALL_double(REGISTER_DIAGOP);
TF_CALL_float(REGISTER_DIAGOP);
TF_CALL_int32(REGISTER_DIAGOP);
TF_CALL_int64(REGISTER_DIAGOP);
TF_CALL_complex64(REGISTER_DIAGOP);
TF_CALL_complex128(REGISTER_DIAGOP);
#undef REGISTER_DIAGOP

#define REGISTER_DIAGPARTOP(T)                                    \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("DiagPart").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DiagPartOp<CPUDevice, T>)

TF_CALL_double(REGISTER_DIAGPARTOP);
TF_CALL_float(REGISTER_DIAGPARTOP);
TF_CALL_int32(REGISTER_DIAGPARTOP);
TF_CALL_int64(REGISTER_DIAGPARTOP);
TF_CALL_complex64(REGISTER_DIAGPARTOP);
TF_CALL_complex128(REGISTER_DIAGPARTOP);
#undef REGISTER_DIAGPARTOP

// Register the GPU kernels.
#ifdef GOOGLE_CUDA

// Forward declarations of the functor specializations for GPU.
namespace functor {
extern template struct DiagFunctor<GPUDevice, double>;
extern template struct DiagFunctor<GPUDevice, float>;
extern template struct DiagFunctor<GPUDevice, int32>;
extern template struct DiagFunctor<GPUDevice, int64>;
extern template struct DiagFunctor<GPUDevice, complex64>;
extern template struct DiagFunctor<GPUDevice, complex128>;
}  // namespace functor

#define REGISTER_DIAGOP_GPU(T)                                \
  REGISTER_KERNEL_BUILDER(                                    \
      Name("Diag").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DiagOp<GPUDevice, T>)

TF_CALL_double(REGISTER_DIAGOP_GPU);
TF_CALL_float(REGISTER_DIAGOP_GPU);
TF_CALL_int32(REGISTER_DIAGOP_GPU);
TF_CALL_int64(REGISTER_DIAGOP_GPU);
TF_CALL_complex64(REGISTER_DIAGOP_GPU);
TF_CALL_complex128(REGISTER_DIAGOP_GPU);
#undef REGISTER_DIAGOP_GPU

// Forward declarations of the functor specializations for GPU.
namespace functor {
extern template struct DiagPartFunctor<GPUDevice, double>;
extern template struct DiagPartFunctor<GPUDevice, float>;
extern template struct DiagPartFunctor<GPUDevice, int32>;
extern template struct DiagPartFunctor<GPUDevice, int64>;
extern template struct DiagPartFunctor<GPUDevice, complex64>;
extern template struct DiagPartFunctor<GPUDevice, complex128>;
}  // namespace functor

#define REGISTER_DIAGPARTOP_GPU(T)                                \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("DiagPart").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DiagPartOp<GPUDevice, T>)

TF_CALL_double(REGISTER_DIAGPARTOP_GPU);
TF_CALL_float(REGISTER_DIAGPARTOP_GPU);
TF_CALL_int32(REGISTER_DIAGPARTOP_GPU);
TF_CALL_int64(REGISTER_DIAGPARTOP_GPU);
TF_CALL_complex64(REGISTER_DIAGPARTOP_GPU);
TF_CALL_complex128(REGISTER_DIAGPARTOP_GPU);
#undef REGISTER_DIAGPARTOP_GPU

#endif  // GOOGLE_CUDA

}  // namespace tensorflow