/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/LU"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/eye_functor.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#endif

namespace tensorflow {

template <class Scalar>
class MatrixInverseOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit MatrixInverseOp(OpKernelConstruction* context) : Base(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

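  // Computes the inverse of a single (possibly adjointed) input matrix via a
  // partially pivoted LU decomposition.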
  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const ConstMatrixMap& input = inputs[0];
    if (input.rows() == 0) {
      // By definition, an empty matrix's inverse is an empty matrix.
      return;
    }
    Eigen::PartialPivLU<Matrix> lu_decomposition;
    if (adjoint_) {
      // TODO(rmlarsen): For Eigen 3.2, this creates a temporary copy.
      // Make sure to backport: https://bitbucket.org/eigen/eigen/commits/
      // bd2219a74c96dfe3f6bc2c23588749e36d2d8173
      lu_decomposition.compute(input.adjoint());
    } else {
      lu_decomposition.compute(input);
    }
    // TODO(rmlarsen): Add check based on condition number estimation.
    // PartialPivLU cannot give strong guarantees on invertibility, but
    // we can at least guard against exact zero pivots. This can occur as
    // a result of basic user mistakes, such as providing integer valued
    // matrices that are exactly singular, or due to underflow if this
    // code is run with denormals being flushed to zero.
    const RealScalar min_abs_pivot =
        lu_decomposition.matrixLU().diagonal().cwiseAbs().minCoeff();
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument("Input is not invertible."));
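    // Eigen forms the inverse from the stored LU factors, i.e. it solves
    // A * X = I (or A^H * X = I when adjoint_ is true).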
    outputs->at(0).noalias() = lu_decomposition.inverse();
  }

 private:
  bool adjoint_;

  TF_DISALLOW_COPY_AND_ASSIGN(MatrixInverseOp);
};

#if GOOGLE_CUDA

typedef Eigen::GpuDevice GPUDevice;

template <class Scalar>
class MatrixInverseOpGpu : public AsyncOpKernel {
 public:
  explicit MatrixInverseOpGpu(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const int ndims = input.dims();
    const int64 n = input.dim_size(ndims - 1);
    // Validate inputs.
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);
    OP_REQUIRES_ASYNC(
        context, input.dim_size(ndims - 2) == n,
        errors::InvalidArgument("Input matrices must be square, got ",
                                input.dim_size(ndims - 2), " != ", n),
        done);

    // By definition, an empty matrix's inverse is an empty matrix.
    if (input.NumElements() == 0) {
      context->set_output(0, input);
      done();
      return;
    }

    // Allocate output.
    Tensor* output;
    OP_REQUIRES_OK_ASYNC(context,
                         context->forward_input_or_allocate_output(
                             {0}, 0, input.shape(), &output),
                         done);
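    // When the input buffer is not referenced elsewhere it is forwarded and
    // reused as the output; otherwise a fresh output tensor is allocated.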

    // TODO(rmlarsen): Convert to std::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
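    // The CudaSolver instance wraps the cuSolver/cuBLAS calls used below and
    // keeps its scoped allocations alive until the asynchronous work has
    // completed (see CheckLapackInfoAndDeleteSolverAsync at the end).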

    // Make a copy of the (possibly adjointed) input that we will use for the
    // factorization step.
    Tensor input_copy;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<Scalar>::value,
                                       input.shape(), &input_copy),
        done);
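    // View the copy as a rank-3 tensor of shape [batch, n, n] so that the
    // individual matrices in the batch can be addressed below.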
    auto input_copy_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
    const GPUDevice& device = context->eigen_device<GPUDevice>();
    if (!adjoint_) {
      device.memcpy(input_copy.flat<Scalar>().data(),
                    input.flat<Scalar>().data(),
                    input.NumElements() * sizeof(Scalar));
    } else {
      OP_REQUIRES_OK_ASYNC(
          context, DoConjugateMatrixTranspose(device, input, &input_copy),
          done);
    }
    const int64 batch_size = input_copy_reshaped.dimension(0);

    Tensor pivots;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<int>::value,
                                       TensorShape{batch_size, n}, &pivots),
        done);
    auto pivots_mat = pivots.template matrix<int>();
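    // Build arrays of per-matrix device pointers for the batched interfaces
    // below; the pointer arrays themselves live in host scratch space.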
    auto input_copy_ptr_array = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "input_copy_ptr_array",
        /* on_host */ true);
    auto output_ptr_array = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "output_copy_ptr_array",
        /* on_host */ true);
    auto output_reshaped = output->template flat_inner_dims<Scalar, 3>();
    std::vector<DeviceLapackInfo> dev_info;
    if (n < 32 || batch_size > n) {
      // For small matrices or very large batch sizes, we use the batched
      // interfaces in cuBlas to avoid being dominated by kernel launch
      // overhead.
      // TODO(rmlarsen): Come up with a better heuristic based on a simple
      // cost model.
      const Scalar** input_copy_ptr_array_base =
          reinterpret_cast<const Scalar**>(input_copy_ptr_array.mutable_data());
      const Scalar** output_ptr_array_base =
          reinterpret_cast<const Scalar**>(output_ptr_array.mutable_data());
      for (int batch = 0; batch < batch_size; ++batch) {
        input_copy_ptr_array_base[batch] = &input_copy_reshaped(batch, 0, 0);
        output_ptr_array_base[batch] = &output_reshaped(batch, 0, 0);
      }

      if (n < 32) {
        // MatInvBatched only supports n < 32.
        dev_info.push_back(
            solver->GetDeviceLapackInfo(batch_size, "MatInvBatched"));
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->MatInvBatched(n, input_copy_ptr_array_base, n,
                                  output_ptr_array_base, n, &dev_info.back(),
                                  batch_size),
            done);
      } else {
        // For larger matrices and large batch sizes, we use the batched
        // GETRF/GETRI kernels.
        // 1. Compute the partially pivoted LU factorization(s).
        dev_info.push_back(
            solver->GetDeviceLapackInfo(batch_size, "GetrfBatched"));
        OP_REQUIRES_OK_ASYNC(context,
                             solver->GetrfBatched(n, input_copy_ptr_array_base,
                                                  n, pivots_mat.data(),
                                                  &dev_info.back(), batch_size),
                             done);
        // 2. Compute the inverse(s) from the factorization(s).
        dev_info.push_back(
            solver->GetDeviceLapackInfo(batch_size, "GetriBatched"));
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->GetriBatched(n, input_copy_ptr_array_base, n,
                                 pivots_mat.data(), output_ptr_array_base, n,
                                 &dev_info.back(), batch_size),
            done);
      }
    } else {
      // For large matrices, we compute the inverse of each matrix in the batch
      // sequentially. Here we use the cuSolver methods GETRF/GETRS because they
      // are MUCH faster than their batched cuBlas equivalents for large
      // matrices.
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Getrf(n, n, &input_copy_reshaped(batch, 0, 0), n,
                          &pivots_mat(batch, 0), &dev_info.back()(batch)),
            done);
      }

      // Set all right-hand sides to the identity.
      functor::EyeFunctor<GPUDevice, Scalar> eye;
      eye(device, output_reshaped);

      // Solve A X = I.
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrs"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Getrs(CUBLAS_OP_N, n, n, &input_copy_reshaped(batch, 0, 0),
                          n, &pivots_mat(batch, 0),
                          &output_reshaped(batch, 0, 0), n,
                          &dev_info.back()(batch)),
            done);
      }
    }
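    // The info values follow the usual LAPACK convention: info == 0 means
    // success, while info > 0 from a factorization indicates an exactly zero
    // pivot, i.e. a singular input.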
    // Callback for checking info after kernels finish.
    auto info_checker = [context, done](
                            const Status& status,
                            const std::vector<HostLapackInfo>& host_infos) {
      if (!status.ok() && errors::IsInvalidArgument(status)) {
        for (const auto& host_info : host_infos) {
          for (int i = 0; i < host_info.size(); ++i) {
            // Match the CPU error message for singular matrices. Otherwise
            // just print the original error message from the call itself
            // below.
            OP_REQUIRES_ASYNC(
                context, host_info(i) <= 0,
                errors::InvalidArgument("Input is not invertible."), done);
          }
        }
      }
      OP_REQUIRES_OK_ASYNC(context, status, done);
      done();
    };
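    // Copy the info values back to the host, run info_checker once they are
    // available, and release the solver (and its scoped allocations) when the
    // stream has caught up.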
    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(info_checker));
  }

 private:
  bool adjoint_;
};

REGISTER_LINALG_OP_GPU("MatrixInverse", (MatrixInverseOpGpu<float>), float);
REGISTER_LINALG_OP_GPU("MatrixInverse", (MatrixInverseOpGpu<double>), double);
REGISTER_LINALG_OP_GPU("MatrixInverse", (MatrixInverseOpGpu<complex64>),
                       complex64);
REGISTER_LINALG_OP_GPU("MatrixInverse", (MatrixInverseOpGpu<complex128>),
                       complex128);

#endif  // GOOGLE_CUDA

REGISTER_LINALG_OP("MatrixInverse", (MatrixInverseOp<float>), float);
REGISTER_LINALG_OP("MatrixInverse", (MatrixInverseOp<double>), double);
REGISTER_LINALG_OP("MatrixInverse", (MatrixInverseOp<complex64>), complex64);
REGISTER_LINALG_OP("MatrixInverse", (MatrixInverseOp<complex128>), complex128);
REGISTER_LINALG_OP("BatchMatrixInverse", (MatrixInverseOp<float>), float);
REGISTER_LINALG_OP("BatchMatrixInverse", (MatrixInverseOp<double>), double);
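// Note: "BatchMatrixInverse" appears to be retained as a legacy alias for
// "MatrixInverse"; only the real-valued types are registered for it here.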

}  // namespace tensorflow