/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.
//
// This header file is used by the individual qr_*op*.cc files for registering
// individual kernels. A separate file is used for each instantiated kernel to
// improve compilation times.
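//
// A minimal sketch of what one such instantiation file might contain, using
// the REGISTER_LINALG_OP macro from linalg_ops_common.h. The file name and
// exact registration line below are illustrative assumptions, not taken from
// this header:
//
//   // qr_op_float.cc (hypothetical)
//   #include "tensorflow/core/kernels/qr_op_impl.h"
//   namespace tensorflow {
//   REGISTER_LINALG_OP("Qr", (QrOp<float>), float);
//   }  // namespace tensorflow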
#include <algorithm>
#include <numeric>

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "third_party/eigen3/Eigen/QR"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/eye_functor.h"
#include "tensorflow/core/kernels/matrix_band_part_op.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

template <class Scalar>
class QrOp : public LinearAlgebraOp<Scalar> {
 public:
  typedef LinearAlgebraOp<Scalar> Base;

  explicit QrOp(OpKernelConstruction* context) : Base(context) {
    OP_REQUIRES_OK(context, context->GetAttr("full_matrices", &full_matrices_));
  }

  using TensorShapes = typename Base::TensorShapes;

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    Base::ValidateSingleMatrix(context, input_matrix_shapes);
  }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    int64 m = input_matrix_shapes[0].dim_size(0);
    int64 n = input_matrix_shapes[0].dim_size(1);
    int64 min_size = std::min(m, n);
    if (full_matrices_) {
      return TensorShapes({TensorShape({m, m}), TensorShape({m, n})});
    } else {
      return TensorShapes(
          {TensorShape({m, min_size}), TensorShape({min_size, n})});
    }
  }
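  // For example (sizes chosen here purely for illustration): a 5 x 3 input
  // yields Q of shape [5, 5] and R of shape [5, 3] when full_matrices is
  // true, and Q of shape [5, 3] and R of shape [3, 3] when it is false.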

  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    double m = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    double n = static_cast<double>(input_matrix_shapes[0].dim_size(1));
    double max_size = std::max(m, n);
    double min_size = std::min(m, n);
    double cost = 2 * max_size * min_size * min_size -
                  2 * min_size * min_size * min_size / 3.;
    // TODO(jpoulson): Increase the cost if full_matrices is true in a manner
    // that reflects the algorithm used for the expansion.
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }
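  // The cost above is the standard Householder-QR flop count,
  // 2 * max(m, n) * min(m, n)^2 - (2 / 3) * min(m, n)^3. As a worked example
  // (sizes chosen for illustration only), an m = 1000, n = 500 input costs
  // roughly 2 * 1000 * 500^2 - (2 / 3) * 500^3 = 5.0e8 - 0.83e8 ~ 4.2e8 flops.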

  using Matrix = typename Base::Matrix;
  using MatrixMaps = typename Base::MatrixMaps;
  using ConstMatrixMap = typename Base::ConstMatrixMap;
  using ConstMatrixMaps = typename Base::ConstMatrixMaps;

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    Eigen::HouseholderQR<Matrix> qr(inputs[0]);
    const int m = inputs[0].rows();
    const int n = inputs[0].cols();
    const int min_size = std::min(m, n);

    if (full_matrices_) {
      outputs->at(0) = qr.householderQ();
      outputs->at(1) = qr.matrixQR().template triangularView<Eigen::Upper>();
    } else {
      // TODO(jpoulson): Exploit the fact that Householder transformations can
      // be expanded faster than they can be applied to an arbitrary matrix
      // (Cf. LAPACK's DORGQR).
      Matrix tmp = Matrix::Identity(m, min_size);
      outputs->at(0) = qr.householderQ() * tmp;
      auto qr_top = qr.matrixQR().block(0, 0, min_size, n);
      outputs->at(1) = qr_top.template triangularView<Eigen::Upper>();
    }
  }

 private:
  bool full_matrices_;

  TF_DISALLOW_COPY_AND_ASSIGN(QrOp);
};
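
// A minimal standalone sketch of the reduced-QR recipe used in
// QrOp::ComputeMatrix above, written against plain Eigen types. The sizes,
// the MatrixXd scalar type, and the function name are illustrative choices,
// not part of this kernel:
//
//   #include <algorithm>
//   #include <iostream>
//   #include "third_party/eigen3/Eigen/QR"
//
//   void ReducedQrExample() {
//     const int m = 5, n = 3, k = std::min(m, n);
//     Eigen::MatrixXd a = Eigen::MatrixXd::Random(m, n);
//     Eigen::HouseholderQR<Eigen::MatrixXd> qr(a);
//     // Expand only the first k Householder columns of Q.
//     Eigen::MatrixXd q = qr.householderQ() * Eigen::MatrixXd::Identity(m, k);
//     // R is the upper triangle of the packed QR factorization.
//     Eigen::MatrixXd r = qr.matrixQR()
//                             .block(0, 0, k, n)
//                             .triangularView<Eigen::Upper>();
//     // Q * R should reconstruct the input up to rounding error.
//     std::cout << "reconstruction error: " << (a - q * r).norm() << std::endl;
//   }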

#if GOOGLE_CUDA

typedef Eigen::GpuDevice GPUDevice;

template <class Scalar>
class QrOpGpu : public AsyncOpKernel {
 public:
  explicit QrOpGpu(OpKernelConstruction* context) : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("full_matrices", &full_matrices_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const int ndims = input.dims();
    const int64 m = input.dim_size(ndims - 2);
    const int64 n = input.dim_size(ndims - 1);
    const int64 min_size = std::min(m, n);
    const int64 batch_size =
        input.template flat_inner_dims<Scalar, 3>().dimension(0);

    // Validate inputs.
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);

    // Allocate output.
    // If full_matrices_ is true then Q is m x m and R is m x n.
    // Otherwise, Q is m x min(m, n), and R is min(m, n) x n.
    Tensor* q;
    TensorShape q_shape = input.shape();
    q_shape.set_dim(ndims - 1, full_matrices_ ? m : min_size);
    OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, q_shape, &q),
                         done);
    Tensor* r;
    TensorShape r_shape = input.shape();
    r_shape.set_dim(ndims - 2, full_matrices_ ? m : min_size);
    OP_REQUIRES_OK_ASYNC(context, context->allocate_output(1, r_shape, &r),
                         done);

    if (input.NumElements() == 0) {
      done();
      return;
    }

    // TODO(rmlarsen): Convert to std::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));

    // Allocate temporaries.
    Tensor input_transposed;
    TensorShape transposed_shape = input.shape();
    transposed_shape.set_dim(ndims - 2, input.dim_size(ndims - 1));
    transposed_shape.set_dim(ndims - 1, input.dim_size(ndims - 2));

    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<Scalar>::value,
                                       transposed_shape, &input_transposed),
        done);

    Tensor tau;
    OP_REQUIRES_OK_ASYNC(context,
                         solver->allocate_scoped_tensor(
                             DataTypeToEnum<Scalar>::value,
                             TensorShape({batch_size, min_size}), &tau),
                         done);

    // Transpose input, since cuSolver uses column-major, while TensorFlow uses
    // row-major storage.
    const GPUDevice& device = context->eigen_device<GPUDevice>();
    OP_REQUIRES_OK_ASYNC(
        context, DoMatrixTranspose(device, input, &input_transposed), done);

    // Compute QR decomposition in-place in input_transposed.
    std::vector<DeviceLapackInfo> dev_info;
    dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "geqrf"));
    auto input_transposed_reshaped =
        input_transposed.flat_inner_dims<Scalar, 3>();
    auto tau_matrix = tau.matrix<Scalar>();
    auto r_reshaped = r->flat_inner_dims<Scalar, 3>();
    for (int batch = 0; batch < batch_size; ++batch) {
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->Geqrf(m, n, &input_transposed_reshaped(batch, 0, 0), m,
                        &tau_matrix(batch, 0),
                        dev_info.back().mutable_data() + batch),
          done);
    }
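    // At this point each batch slice of input_transposed holds the packed
    // LAPACK-style geqrf result: R in and above the diagonal, the Householder
    // vectors below it, and the corresponding scalar factors in tau.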

    // Generate R. R is equal to the upper triangle of the decomposition
    // stored in input_transposed. Crop, transpose (to get back to row-major)
    // and copy it to the output buffer.
    if (full_matrices_ || m == n) {
      OP_REQUIRES_OK_ASYNC(
          context, DoMatrixTranspose(device, input_transposed, r), done);
    } else {
      const Scalar alpha(1);
      const Scalar beta(0);
      const Scalar* dummy = nullptr;
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Geam(CUBLAS_OP_T, CUBLAS_OP_N, n,
                         full_matrices_ ? m : min_size, &alpha,
                         &input_transposed_reshaped(batch, 0, 0), m, &beta,
                         dummy, n, &r_reshaped(batch, 0, 0), n),
            done);
      }
    }
    // Extract the upper triangle of r (i.e. zero out the strictly lower
    // triangle).
    functor::MatrixBandPartFunctor<GPUDevice, Scalar> band_part;
    auto r_reshaped_const =
        const_cast<const Tensor*>(r)->flat_inner_dims<Scalar, 3>();
    band_part(context, device, 0 /* num_lower_diags */,
              -1 /* num_upper_diags */, r_reshaped_const, r_reshaped);

    // Generate Q from the decomposition in input_transposed.
    if (m != n && (full_matrices_ || m < n)) {
      // Generate full m x m matrix Q by computing the product Q^T * I,
      // where the transpose is to get back to row-major form.
      // In the complex case we actually form Q^H * I and conjugate it
      // to get Q in row-major form.
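      // (Q^H written to a column-major buffer reads back row-major as its
      // transpose, i.e. conj(Q), so the final elementwise conjugation below
      // recovers Q itself.)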
      functor::EyeFunctor<GPUDevice, Scalar> eye;
      auto q_reshaped = q->flat_inner_dims<Scalar, 3>();
      eye(device, q_reshaped);
      for (int batch = 0; batch < batch_size; ++batch) {
        // Notice: It appears that Unmqr does not write a zero into *info upon
        // success (probably a bug), so we simply re-use the info array already
        // zeroed by Geqrf above.
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Unmqr(CUBLAS_SIDE_LEFT, CublasAdjointOp<Scalar>(), m, m,
                          min_size, &input_transposed_reshaped(batch, 0, 0), m,
                          &tau_matrix(batch, 0), &q_reshaped(batch, 0, 0), m,
                          dev_info.back().mutable_data() + batch),
            done);
      }
      if (Eigen::NumTraits<Scalar>::IsComplex) {
        functor::UnaryFunctor<GPUDevice, functor::conj<Scalar>> conj;
        conj(device, q->flat<Scalar>() /*out*/,
             const_cast<const Tensor*>(q)->flat<Scalar>() /*in*/);
      }
    } else {
      // Generate m x n matrix Q. In this case we can use the more efficient
      // algorithm in Ungqr to generate Q in place.
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "orgqr"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Ungqr(
                m, n, min_size, &input_transposed_reshaped(batch, 0, 0), m,
                &tau_matrix(batch, 0), dev_info.back().mutable_data() + batch),
            done);
      }
      OP_REQUIRES_OK_ASYNC(
          context, DoMatrixTranspose(device, input_transposed, q), done);
    }

    // Asynchronously check return status from cuSolver kernels.
    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(done));
  }

 private:
  bool full_matrices_;

  TF_DISALLOW_COPY_AND_ASSIGN(QrOpGpu);
};

#endif  // GOOGLE_CUDA

}  // namespace tensorflow