/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include <numeric>

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/LU"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#endif

namespace tensorflow {

static const char kErrMsg[] = "Input matrix is not invertible.";

template <class Scalar>
class MatrixSolveOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit MatrixSolveOp(OpKernelConstruction* context) : Base(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    Base::ValidateSquareSolver(context, input_matrix_shapes);
  }

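  // The solution X of A * X = B (or adjoint(A) * X = B) has as many rows as A
  // has columns and as many columns as B, i.e. shape [N, K] for a square A of
  // shape [N, N] and a B of shape [N, K].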
  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
                                      input_matrix_shapes[1].dim_size(1)})});
  }

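  // Rough flop count used for cost-based sharding: the LU factorization costs
  // O(rows^3) and each of the num_rhss triangular solves costs O(rows^2),
  // which gives rows^2 * (rows + num_rhss) up to a constant factor.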
  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
    double cost = rows * rows * (rows + num_rhss);
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }

  bool EnableInputForwarding() const final { return false; }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const ConstMatrixMap& matrix = inputs[0];
    const ConstMatrixMap& rhs = inputs[1];
    if (matrix.rows() == 0 || rhs.cols() == 0) {
      // To be consistent with the MatrixInverse op, we define the solution for
      // an empty set of equations as the empty matrix.
      return;
    }
    Eigen::PartialPivLU<Matrix> lu_decomposition(matrix.rows());
    if (adjoint_) {
      // TODO(rmlarsen): For Eigen 3.2, this creates a temporary copy.
      // Make sure to backport: https://bitbucket.org/eigen/eigen/commits/
      // bd2219a74c96dfe3f6bc2c23588749e36d2d8173
      lu_decomposition.compute(matrix.adjoint());
    } else {
      lu_decomposition.compute(matrix);
    }

    // PartialPivLU cannot give strong guarantees on invertibility,
    // but we can at least guard against exact zero pivots. This can occur as
    // a result of basic user mistakes, such as providing integer-valued
    // matrices that are exactly singular, or due to underflow if this
    // code is run with denormals being flushed to zero.
    const RealScalar min_abs_pivot =
        lu_decomposition.matrixLU().diagonal().cwiseAbs().minCoeff();
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument(kErrMsg));

    // TODO(rmlarsen): Add check based on condition number estimation.
    // The necessary changes to Eigen are in
    // https://bitbucket.org/eigen/eigen/pull-requests/174/
    // add-matrix-condition-number-estimation/diff
    outputs->at(0) = lu_decomposition.solve(rhs);
  }

 private:
  bool adjoint_;

  TF_DISALLOW_COPY_AND_ASSIGN(MatrixSolveOp);
};

#if GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;

template <class Scalar>
class MatrixSolveOpGpu : public AsyncOpKernel {
 public:
  explicit MatrixSolveOpGpu(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

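  // Solves A * X = B (or adjoint(A) * X = B) on the GPU in four steps: copy
  // (and, for the adjoint case, transpose) the input matrices, LU-factorize
  // them with getrf/getrfBatched, solve the factored systems with
  // getrs/getrsBatched against a column-major copy of the right-hand sides,
  // and transpose the solutions back to row-major storage.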
  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const Tensor& rhs = context->input(1);
    const int ndims = input.dims();
    const int64 n = input.dim_size(ndims - 1);
    const int64 nrhs = rhs.dim_size(ndims - 1);
    // Validate inputs.
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);
    OP_REQUIRES_ASYNC(context, rhs.dims() == ndims,
                      errors::InvalidArgument(
                          "Input and right-hand side must have same rank, got ",
                          ndims, " != ", rhs.dims()),
                      done);
    OP_REQUIRES_ASYNC(
        context, input.dim_size(ndims - 2) == n,
        errors::InvalidArgument("Input matrices must be square, got ",
                                input.dim_size(ndims - 2), " != ", n),
        done);
    OP_REQUIRES_ASYNC(context, rhs.dim_size(ndims - 2) == n,
                      errors::InvalidArgument(
                          "Input matrix and right-hand side must have the "
                          "same number of rows, got ",
                          n, " != ", rhs.dim_size(ndims - 2)),
                      done);

    // Allocate output.
    Tensor* output;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->forward_input_or_allocate_output({1}, 0, rhs.shape(), &output),
        done);

    // To be consistent with the MatrixInverse op, we define the solution for
    // an empty set of equations as the empty matrix.
    if (rhs.NumElements() == 0) {
      done();
      return;
    }

    // TODO(rmlarsen): Convert to std::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));

    // Make a copy of the input for the factorization step, or, if adjoint_ is
    // false, try to reuse the input buffer if this op owns it exclusively.
    Tensor input_copy;
    const GPUDevice& device = context->eigen_device<GPUDevice>();
    if (adjoint_) {
      // For the adjoint case, it is simpler to always make a transposed copy up
      // front.
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->allocate_scoped_tensor(DataTypeToEnum<Scalar>::value,
                                         input.shape(), &input_copy),
          done);
      OP_REQUIRES_OK_ASYNC(context,
                           DoMatrixTranspose(device, input, &input_copy), done);
    } else {
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->forward_input_or_allocate_scoped_tensor(
              {0}, DataTypeToEnum<Scalar>::value, input.shape(), &input_copy),
          done);
      if (!input.SharesBufferWith(input_copy)) {
        device.memcpy(input_copy.flat<Scalar>().data(),
                      input.flat<Scalar>().data(),
                      input.NumElements() * sizeof(Scalar));
      }
    }
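    // Collapse all leading (batch) dimensions into one, viewing the copied
    // input as a [batch_size, n, n] tensor of square matrices.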
    auto input_copy_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
    const int64 batch_size = input_copy_reshaped.dimension(0);

    // Allocate pivots on the device.
    Tensor pivots;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<int>::value,
                                       TensorShape{batch_size, n}, &pivots),
        done);
    auto pivots_mat = pivots.template matrix<int>();

    // 1. Compute the partially pivoted LU factorization(s) of the
    // matrix/matrices.
    std::vector<DeviceLapackInfo> dev_info;
    auto input_copy_ptrs = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "input_copy_ptrs",
        /* on_host */ true);
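    // Heuristic: use the batched cuBLAS interface when the matrices are small
    // relative to the batch size (n <= 128 * batch_size); for larger matrices
    // the per-matrix cuSolver calls below are faster.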
    const int kMaxMatrixSizeToBatchSizeRatio = 128;
    const bool use_batched_solver =
        n <= kMaxMatrixSizeToBatchSizeRatio * batch_size;
    if (use_batched_solver) {
      // For small matrices or large batch sizes, we use the batched interface
      // from cuBLAS.
      const Scalar** input_copy_ptrs_base =
          reinterpret_cast<const Scalar**>(input_copy_ptrs.mutable_data());
      for (int batch = 0; batch < batch_size; ++batch) {
        input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0);
      }
      dev_info.push_back(
          solver->GetDeviceLapackInfo(batch_size, "getrfBatched"));
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->GetrfBatched(n, input_copy_ptrs_base, n, pivots_mat.data(),
                               &dev_info.back(), batch_size),
          done);
    } else {
      // For small batch sizes or large matrices, we use the non-batched
      // interface from cuSolver, which is much faster for large matrices.
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Getrf(n, n, &input_copy_reshaped(batch, 0, 0), n,
                          &pivots_mat(batch, 0), &dev_info.back()(batch)),
            done);
      }
    }

    // 2. Make a transposed copy of the right-hand sides. This is necessary
    // because cuBLAS assumes column-major storage while TensorFlow uses
    // row-major.
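    // E.g. a right-hand side of shape [..., n, nrhs] is stored transposed as
    // [..., nrhs, n], so that each n x nrhs matrix is laid out column-major.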
    TensorShape transposed_rhs_shape(rhs.shape());
    transposed_rhs_shape.RemoveLastDims(2);
    transposed_rhs_shape.AddDim(nrhs);
    transposed_rhs_shape.AddDim(n);
    Tensor transposed_rhs;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<Scalar>::value,
                                       transposed_rhs_shape, &transposed_rhs),
        done);
    if (nrhs > 1) {
      OP_REQUIRES_OK_ASYNC(
          context, DoMatrixTranspose(device, rhs, &transposed_rhs), done);
    } else {
      device.memcpy(transposed_rhs.flat<Scalar>().data(),
                    rhs.flat<Scalar>().data(),
                    rhs.NumElements() * sizeof(Scalar));
    }

    // 3. Solve op(A) X = B (in column-major form).
    // We use a trick here: if adjoint_ is true, we converted A to column-major
    // form above. If adjoint_ is false, we leave A in row-major form and use
    // trans_a = CUBLAS_OP_T to effectively transform it to column-major on the
    // fly. (This means that we actually use the LU factorization of A^T in that
    // case, but that is equally good for solving AX=B.) This way we save an
    // explicit transpose in the more common case of adjoint_ == false.
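    // Concretely: a row-major A reinterpreted as column-major data is A^T, so
    // the factorization above is effectively an LU factorization of A^T, and
    // getrs with CUBLAS_OP_T then solves (A^T)^T X = A X = B.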
    auto input_copy_ptr_array = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "input_copy_ptr_array",
        /* on_host */ true);
    auto transposed_rhs_ptr_array = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "transposed_rhs_ptr_array",
        /* on_host */ true);
    auto transposed_rhs_reshaped =
        transposed_rhs.template flat_inner_dims<Scalar, 3>();
    if (use_batched_solver) {
      const Scalar** input_copy_ptrs_base =
          reinterpret_cast<const Scalar**>(input_copy_ptr_array.mutable_data());
      const Scalar** transposed_rhs_ptrs_base =
          reinterpret_cast<const Scalar**>(
              transposed_rhs_ptr_array.mutable_data());
      for (int batch = 0; batch < batch_size; ++batch) {
        input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0);
        transposed_rhs_ptrs_base[batch] = &transposed_rhs_reshaped(batch, 0, 0);
      }
      int host_info = 0;
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->GetrsBatched(adjoint_ ? CUBLAS_OP_C : CUBLAS_OP_T, n, nrhs,
                               input_copy_ptrs_base, n, pivots_mat.data(),
                               transposed_rhs_ptrs_base, n, &host_info,
                               batch_size),
          done);
      OP_REQUIRES_ASYNC(
          context, host_info == 0,
          errors::InvalidArgument("The ", -host_info,
                                  "'th argument to cublas*getrsBatched had "
                                  "an illegal value."),
          done);
    } else {
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrs"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Getrs(adjoint_ ? CUBLAS_OP_C : CUBLAS_OP_T, n, nrhs,
                          &input_copy_reshaped(batch, 0, 0), n,
                          &pivots_mat(batch, 0),
                          &transposed_rhs_reshaped(batch, 0, 0), n,
                          &dev_info.back()(batch)),
            done);
      }
    }

    // 4. Transpose X to get the final result in row-major form.
    if (nrhs > 1) {
      OP_REQUIRES_OK_ASYNC(
          context, DoMatrixTranspose(device, transposed_rhs, output), done);
    } else {
      device.memcpy(output->flat<Scalar>().data(),
                    transposed_rhs.flat<Scalar>().data(),
                    transposed_rhs.NumElements() * sizeof(Scalar));
    }

    // Callback for checking info after kernels finish. Also capture the
    // temporary Tensors/ScratchSpace so they don't get deallocated before the
    // kernels run. TODO(rmlarsen): Use move capture once C++14 becomes
    // available.
    auto info_checker = [context, done, dev_info](
                            const Status& status,
                            const std::vector<HostLapackInfo>& host_infos) {
      if (!status.ok() && errors::IsInvalidArgument(status) &&
          !host_infos.empty()) {
        for (int i = 0; i < host_infos[0].size(); ++i) {
          // Match the CPU error message for singular matrices. Otherwise
          // just print the original error message from the status below.
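          // A positive getrf info value i means that U(i, i) is exactly zero,
          // i.e. the matrix is singular; negative values indicate an invalid
          // argument and fall through to the generic status check below.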
          OP_REQUIRES_ASYNC(context, host_infos[0].data()[i] <= 0,
                            errors::InvalidArgument(kErrMsg), done);
        }
      }
      OP_REQUIRES_OK_ASYNC(context, status, done);
      done();
    };
    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(info_checker));
  }

 private:
  bool adjoint_;
};

REGISTER_LINALG_OP_GPU("MatrixSolve", (MatrixSolveOpGpu<float>), float);
REGISTER_LINALG_OP_GPU("MatrixSolve", (MatrixSolveOpGpu<double>), double);
REGISTER_LINALG_OP_GPU("MatrixSolve", (MatrixSolveOpGpu<complex64>), complex64);
REGISTER_LINALG_OP_GPU("MatrixSolve", (MatrixSolveOpGpu<complex128>),
                       complex128);

#endif  // GOOGLE_CUDA

REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<float>), float);
REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<double>), double);
REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<complex64>), complex64);
REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<complex128>), complex128);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<float>), float);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<double>), double);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<complex64>), complex64);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<complex128>), complex128);
}  // namespace tensorflow