/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.
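//
// MatrixSolve computes X = A^{-1} B (or X = A^{-H} B when adjoint is true)
// for each matrix in a possibly batched input. Sketch of typical usage from
// the Python API, where this kernel backs tf.matrix_solve:
//
//   # A: [..., n, n], b: [..., n, k]
//   x = tf.matrix_solve(A, b, adjoint=False)  # x: [..., n, k]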
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include <numeric>

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/LU"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#endif

namespace tensorflow {

static const char kErrMsg[] = "Input matrix is not invertible.";

template <class Scalar>
class MatrixSolveOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit MatrixSolveOp(OpKernelConstruction* context) : Base(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    Base::ValidateSquareSolver(context, input_matrix_shapes);
  }

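  // The solution X of A * X = B (or A^H * X = B) has as many rows as A has
  // columns and one column per right-hand side in B.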
  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
                                      input_matrix_shapes[1].dim_size(1)})});
  }

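  // Rough flop count: one LU factorization (~rows^3, dropping the 2/3
  // constant) plus forward and back substitution for each right-hand side
  // (~rows^2 per column).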
  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
    double cost = rows * rows * (rows + num_rhss);
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }

  bool EnableInputForwarding() const final { return false; }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const ConstMatrixMap& matrix = inputs[0];
    const ConstMatrixMap& rhs = inputs[1];
    if (matrix.rows() == 0 || rhs.cols() == 0) {
      // To be consistent with the MatrixInverse op, we define the solution for
      // an empty set of equations as the empty matrix.
      return;
    }
    Eigen::PartialPivLU<Matrix> lu_decomposition(matrix.rows());
    if (adjoint_) {
      // TODO(rmlarsen): For Eigen 3.2, this creates a temporary copy.
      // Make sure to backport: https://bitbucket.org/eigen/eigen/commits/
      // bd2219a74c96dfe3f6bc2c23588749e36d2d8173
      lu_decomposition.compute(matrix.adjoint());
    } else {
      lu_decomposition.compute(matrix);
    }

    // PartialPivLU cannot give strong guarantees on invertibility,
    // but we can at least guard against exact zero pivots. This can occur as
    // a result of basic user mistakes, such as providing integer-valued
    // matrices that are exactly singular, or due to underflow if this
    // code is run with denormals being flushed to zero.
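    // The diagonal of the packed LU factor holds the pivots of U; an exact
    // zero there would make the triangular solve below emit Infs/NaNs rather
    // than a clean error.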
    const RealScalar min_abs_pivot =
        lu_decomposition.matrixLU().diagonal().cwiseAbs().minCoeff();
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument(kErrMsg));

    // TODO(rmlarsen): Add check based on condition number estimation.
    // The necessary changes to Eigen are in
    // https://bitbucket.org/eigen/eigen/pull-requests/174/
    // add-matrix-condition-number-estimation/diff
    outputs->at(0) = lu_decomposition.solve(rhs);
  }

 private:
  bool adjoint_;

  TF_DISALLOW_COPY_AND_ASSIGN(MatrixSolveOp);
};

#if GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;

template <class Scalar>
class MatrixSolveOpGpu : public AsyncOpKernel {
 public:
  explicit MatrixSolveOpGpu(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const Tensor& rhs = context->input(1);
    const int ndims = input.dims();
    const int64 n = input.dim_size(ndims - 1);
    const int64 nrhs = rhs.dim_size(ndims - 1);
    // Validate inputs.
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);
    OP_REQUIRES_ASYNC(context, rhs.dims() == ndims,
                      errors::InvalidArgument(
                          "Input and right-hand side must have same rank, got ",
                          ndims, " != ", rhs.dims()),
                      done);
    OP_REQUIRES_ASYNC(
        context, input.dim_size(ndims - 2) == n,
        errors::InvalidArgument("Input matrices must be square, got ",
                                input.dim_size(ndims - 2), " != ", n),
        done);
    OP_REQUIRES_ASYNC(context, rhs.dim_size(ndims - 2) == n,
                      errors::InvalidArgument(
                          "Input matrix and right-hand side must have the "
                          "same number of rows, got ",
                          n, " != ", rhs.dim_size(ndims - 2)),
                      done);

    // Allocate output.
    Tensor* output;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->forward_input_or_allocate_output({1}, 0, rhs.shape(), &output),
        done);

    // To be consistent with the MatrixInverse op, we define the solution for
    // an empty set of equations as the empty matrix.
    if (rhs.NumElements() == 0) {
      done();
      return;
    }

    // TODO(rmlarsen): Convert to std::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));

    // Make a copy of the input for the factorization step, or, if adjoint_ is
    // false, try to reuse the input buffer if this op owns it exclusively.
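    // The factorization below overwrites its input with the packed LU factors,
    // so the op's input tensor must not be clobbered unless we own its buffer
    // exclusively.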
    Tensor input_copy;
    const GPUDevice& device = context->eigen_device<GPUDevice>();
    if (adjoint_) {
      // For the adjoint case, it is simpler to always make a transposed copy up
      // front.
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->allocate_scoped_tensor(DataTypeToEnum<Scalar>::value,
                                         input.shape(), &input_copy),
          done);
      OP_REQUIRES_OK_ASYNC(context,
                           DoMatrixTranspose(device, input, &input_copy), done);
    } else {
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->forward_input_or_allocate_scoped_tensor(
              {0}, DataTypeToEnum<Scalar>::value, input.shape(), &input_copy),
          done);
      if (!input.SharesBufferWith(input_copy)) {
        device.memcpy(input_copy.flat<Scalar>().data(),
                      input.flat<Scalar>().data(),
                      input.NumElements() * sizeof(Scalar));
      }
    }
    auto input_copy_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
    const int64 batch_size = input_copy_reshaped.dimension(0);

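    // Both factorization paths below factor each matrix in place and emit n
    // pivot (row interchange) indices per matrix, stored one row per batch
    // entry.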
    // Allocate pivots on the device.
    Tensor pivots;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<int>::value,
                                       TensorShape{batch_size, n}, &pivots),
        done);
    auto pivots_mat = pivots.template matrix<int>();

    // 1. Compute the partially pivoted LU factorization(s) of the
    // matrix/matrices.
    std::vector<DeviceLapackInfo> dev_info;
    auto input_copy_ptrs = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "input_copy_ptrs",
        /* on_host */ true);
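    // The batched cuBLAS interface operates on an array of per-matrix
    // pointers, staged in the scratch space above. The threshold below is a
    // heuristic crossover: batching pays off for many small matrices, while
    // the per-matrix cuSolver path is faster for a few large ones.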
    if (n / batch_size <= 128) {
      // For small matrices or large batch sizes, we use the batched
      // interface from cuBLAS.
      const Scalar** input_copy_ptrs_base =
          reinterpret_cast<const Scalar**>(input_copy_ptrs.mutable_data());
      for (int batch = 0; batch < batch_size; ++batch) {
        input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0);
      }
      dev_info.push_back(
          solver->GetDeviceLapackInfo(batch_size, "getrfBatched"));
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->GetrfBatched(n, input_copy_ptrs_base, n, pivots_mat.data(),
                               &dev_info.back(), batch_size),
          done);
    } else {
      // For small batch sizes we use the non-batched interface from cuSolver,
      // which is much faster for large matrices.
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Getrf(n, n, &input_copy_reshaped(batch, 0, 0), n,
                          &pivots_mat(batch, 0), &dev_info.back()(batch)),
            done);
      }
    }

    // 2. Make a transposed copy of the right-hand sides. This is necessary
    // because cuBLAS assumes column-major storage while TensorFlow uses
    // row-major.
    TensorShape transposed_rhs_shape(rhs.shape());
    transposed_rhs_shape.RemoveLastDims(2);
    transposed_rhs_shape.AddDim(nrhs);
    transposed_rhs_shape.AddDim(n);
    Tensor transposed_rhs;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<Scalar>::value,
                                       transposed_rhs_shape, &transposed_rhs),
        done);
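    // With a single right-hand side, B and its transpose share the same
    // contiguous memory layout, so a flat copy suffices; multiple columns
    // need a real transpose.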
    if (nrhs > 1) {
      OP_REQUIRES_OK_ASYNC(
          context, DoMatrixTranspose(device, rhs, &transposed_rhs), done);
    } else {
      device.memcpy(transposed_rhs.flat<Scalar>().data(),
                    rhs.flat<Scalar>().data(),
                    rhs.NumElements() * sizeof(Scalar));
    }

    // 3. Solve op(A) X = B (in column-major form).
    // We use a trick here: if adjoint_ is true, we already converted A to
    // column-major form above. If adjoint_ is false, we leave A in row-major
    // form and use trans_a = CUBLAS_OP_T to effectively transform it to
    // column-major on the fly. (This means that we actually use the
    // LU factorization of A^T in that case, but that is equally good for
    // solving AX=B.) This way we save an explicit transpose in the more common
    // case of adjoint_ == false.
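    // Concretely, for adjoint_ == false: cuBLAS reads the row-major buffer
    // holding A as the column-major matrix A^T, so getrf above factored A^T,
    // and getrs with CUBLAS_OP_T solves (A^T)^T X = A X = B.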
    auto input_copy_ptr_array = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "input_copy_ptr_array",
        /* on_host */ true);
    auto transposed_rhs_ptr_array = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "transposed_rhs_ptr_array",
        /* on_host */ true);
    auto transposed_rhs_reshaped =
        transposed_rhs.template flat_inner_dims<Scalar, 3>();
    // TODO(rmlarsen): Enable the following branch when I figure
    // out why it causes a segfault.
    if (false && n / batch_size <= 128) {
      dev_info.push_back(
          solver->GetDeviceLapackInfo(batch_size, "GetrsBatched"));
      const Scalar** input_copy_ptrs_base =
          reinterpret_cast<const Scalar**>(input_copy_ptr_array.mutable_data());
      const Scalar** transposed_rhs_ptrs_base =
          reinterpret_cast<const Scalar**>(
              transposed_rhs_ptr_array.mutable_data());
      for (int batch = 0; batch < batch_size; ++batch) {
        input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0);
        transposed_rhs_ptrs_base[batch] = &transposed_rhs_reshaped(batch, 0, 0);
      }
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->GetrsBatched(adjoint_ ? CUBLAS_OP_C : CUBLAS_OP_T, n, nrhs,
                               input_copy_ptrs_base, n, pivots_mat.data(),
                               transposed_rhs_ptrs_base, n, &dev_info.back(),
                               batch_size),
          done);
    } else {
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrs"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Getrs(adjoint_ ? CUBLAS_OP_C : CUBLAS_OP_T, n, nrhs,
                          &input_copy_reshaped(batch, 0, 0), n,
                          &pivots_mat(batch, 0),
                          &transposed_rhs_reshaped(batch, 0, 0), n,
                          &dev_info.back()(batch)),
            done);
      }
    }

    // 4. Transpose X to get the final result in row-major form.
    if (nrhs > 1) {
      OP_REQUIRES_OK_ASYNC(
          context, DoMatrixTranspose(device, transposed_rhs, output), done);
    } else {
      device.memcpy(output->flat<Scalar>().data(),
                    transposed_rhs.flat<Scalar>().data(),
                    transposed_rhs.NumElements() * sizeof(Scalar));
    }

    // Callback for checking info after kernels finish. Also capture the
    // temporary Tensors/ScratchSpace so they don't get deallocated before the
    // kernels run. TODO(rmlarsen): Use move capture once C++14 becomes
    // available.
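    // LAPACK convention for the info values collected above: info == 0 means
    // success, info < 0 flags an invalid argument, and info > 0 from getrf
    // reports an exactly zero pivot, i.e. a singular matrix.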
    auto info_checker = [context, done, dev_info](
                            const Status& status,
                            const std::vector<HostLapackInfo>& host_infos) {
      if (!status.ok() && errors::IsInvalidArgument(status) &&
          !host_infos.empty()) {
        for (int i = 0; i < host_infos[0].size(); ++i) {
          // Match the CPU error message for singular matrices. Otherwise
          // just print the original error message from the status below.
          OP_REQUIRES_ASYNC(context, host_infos[0].data()[i] <= 0,
                            errors::InvalidArgument(kErrMsg), done);
        }
      }
      OP_REQUIRES_OK_ASYNC(context, status, done);
      done();
    };
    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(info_checker));
  }

 private:
  bool adjoint_;
};

REGISTER_LINALG_OP_GPU("MatrixSolve", (MatrixSolveOpGpu<float>), float);
REGISTER_LINALG_OP_GPU("MatrixSolve", (MatrixSolveOpGpu<double>), double);
REGISTER_LINALG_OP_GPU("MatrixSolve", (MatrixSolveOpGpu<complex64>), complex64);
REGISTER_LINALG_OP_GPU("MatrixSolve", (MatrixSolveOpGpu<complex128>),
                       complex128);

#endif  // GOOGLE_CUDA

REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<float>), float);
REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<double>), double);
REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<complex64>), complex64);
REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<complex128>), complex128);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<float>), float);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<double>), double);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<complex64>), complex64);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<complex128>), complex128);
}  // namespace tensorflow