/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#if GOOGLE_CUDA

#include <numeric>
#include <type_traits>

#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/cast_op.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

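// GPU implementation of the SelfAdjointEigV2 op: computes the eigenvalues
// (and, when the compute_v attribute is set, the eigenvectors) of a batch of
// self-adjoint matrices (real symmetric or complex Hermitian) using
// cuSOLVER's syevd/heevd routines through the CudaSolver wrapper.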
template <class Scalar>
class SelfAdjointEigV2OpGpu : public AsyncOpKernel {
 public:
  explicit SelfAdjointEigV2OpGpu(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("compute_v", &compute_v_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const int ndims = input.dims();
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);
    const int64 n = input.dim_size(ndims - 1);
    OP_REQUIRES_ASYNC(
        context, input.dim_size(ndims - 2) == n,
        errors::InvalidArgument("Input matrices must be square, got ",
                                input.dim_size(ndims - 2), " != ", n),
        done);
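    // The two innermost dimensions hold the n x n matrices; all leading
    // dimensions are flattened into a single batch dimension.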
    const int64 batch_size =
        input.template flat_inner_dims<Scalar, 3>().dimension(0);

    // Allocate outputs.
    Tensor* eigenvalues;
    TensorShape eigenvalues_shape = input.shape();
    eigenvalues_shape.RemoveLastDims(1);
    OP_REQUIRES_OK_ASYNC(
        context, context->allocate_output(0, eigenvalues_shape, &eigenvalues),
        done);
    Tensor* eigenvectors;
    TensorShape eigenvectors_shape =
        compute_v_ ? input.shape() : TensorShape({});
    OP_REQUIRES_OK_ASYNC(
        context, context->allocate_output(1, eigenvectors_shape, &eigenvectors),
        done);

    if (input.NumElements() == 0) {
      done();
      return;
    }

    // Allocate workspace.
    // TODO(rmlarsen): Convert to std::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
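    // cuSOLVER returns eigenvalues as real numbers even for complex Hermitian
    // input, so stage them in a real-typed scratch tensor when Scalar is
    // complex and write directly into the output when Scalar is already real.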
    Tensor eigenvalues_real;
    using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
    if (std::is_same<Scalar, RealScalar>::value) {
      eigenvalues_real = *eigenvalues;
    } else {
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->allocate_scoped_tensor(DataTypeToEnum<RealScalar>::value,
                                         eigenvalues_shape, &eigenvalues_real),
          done);
    }

    Tensor input_copy;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->forward_input_or_allocate_scoped_tensor(
            {0}, DataTypeToEnum<Scalar>::value, input.shape(), &input_copy),
        done);
    // For real symmetric matrices, row-major and column-major are the same. For
    // complex Hermitian, row-major and column-major differ by a conjugation,
    // which is still cheaper than a transpose.
    const GPUDevice& device = context->eigen_device<GPUDevice>();
    if (!input.SharesBufferWith(input_copy)) {
      if (Eigen::NumTraits<Scalar>::IsComplex) {
        functor::UnaryFunctor<GPUDevice, functor::conj<Scalar>> conj;
        conj(device, input_copy.flat<Scalar>() /*out*/,
             input.flat<Scalar>() /*in*/);
      } else {
        device.memcpy(input_copy.flat<Scalar>().data(),
                      input.flat<Scalar>().data(),
                      input.NumElements() * sizeof(Scalar));
      }
    } else if (Eigen::NumTraits<Scalar>::IsComplex) {
      functor::UnaryFunctor<GPUDevice, functor::conj<Scalar>> conj;
      conj(device, const_cast<Tensor*>(&input)->flat<Scalar>() /*out*/,
           input.flat<Scalar>() /*in*/);
    }

    // Compute eigen decomposition in-place in input_copy.
    std::vector<DeviceLapackInfo> dev_info;
    dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "heevd"));
    auto input_copy_reshaped = input_copy.flat_inner_dims<Scalar, 3>();
    auto eigenvalues_real_reshaped =
        eigenvalues_real.flat_inner_dims<RealScalar, 2>();
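    // One Heevd call is issued per matrix in the batch. When compute_v_ is
    // true the eigenvectors overwrite the corresponding matrix in input_copy,
    // and each call writes its LAPACK-style status code into its slot of the
    // dev_info buffer for the asynchronous check below.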
    for (int batch = 0; batch < batch_size; ++batch) {
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->Heevd(compute_v_ ? CUSOLVER_EIG_MODE_VECTOR
                                   : CUSOLVER_EIG_MODE_NOVECTOR,
                        CUBLAS_FILL_MODE_UPPER, n,
                        &input_copy_reshaped(batch, 0, 0), n,
                        &eigenvalues_real_reshaped(batch, 0),
                        dev_info.back().mutable_data() + batch),
          done);
    }

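    // For complex input the eigenvalues were computed into the real-typed
    // scratch tensor; cast them to the complex output type. For real input
    // they were written directly into the output tensor above.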
    if (!std::is_same<Scalar, RealScalar>::value) {
      functor::CastFunctor<GPUDevice, Scalar, RealScalar> cast;
      cast(device, eigenvalues->flat<Scalar>(),
           const_cast<const Tensor*>(&eigenvalues_real)->flat<RealScalar>());
    }

    if (compute_v_) {
      // Transpose eigenvectors now stored in input_copy in column-major form to
      // output in row-major form.
      OP_REQUIRES_OK_ASYNC(
          context, DoMatrixTranspose(device, input_copy, eigenvectors), done);
    }

    // Asynchronously check return status from cuSolver kernels.
    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(done));
  }

 private:
  bool compute_v_;

  TF_DISALLOW_COPY_AND_ASSIGN(SelfAdjointEigV2OpGpu);
};

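// Register the GPU kernel for each element type handled above: float, double,
// complex64, and complex128.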
#define REGISTER(Scalar)                                                       \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("SelfAdjointEigV2").Device(DEVICE_GPU).TypeConstraint<Scalar>("T"), \
      (SelfAdjointEigV2OpGpu<Scalar>))

REGISTER(float);
REGISTER(double);
REGISTER(complex64);
REGISTER(complex128);

#undef REGISTER

}  // namespace tensorflow

#endif  // GOOGLE_CUDA