Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/linalg_ops.cc.
     17 
     18 #include "third_party/eigen3/Eigen/Cholesky"
     19 #include "third_party/eigen3/Eigen/Core"
     20 #include "third_party/eigen3/Eigen/QR"
     21 #include "tensorflow/core/framework/kernel_def_builder.h"
     22 #include "tensorflow/core/framework/op_kernel.h"
     23 #include "tensorflow/core/framework/tensor_shape.h"
     24 #include "tensorflow/core/kernels/linalg_ops_common.h"
     25 #include "tensorflow/core/lib/core/errors.h"
     26 #include "tensorflow/core/platform/logging.h"
     27 #include "tensorflow/core/platform/types.h"
     28 
     29 namespace tensorflow {
     30 
     31 template <class Scalar>
     32 class MatrixSolveLsOp : public LinearAlgebraOp<Scalar> {
     33  public:
     34   typedef LinearAlgebraOp<Scalar> Base;
     35 
     36   explicit MatrixSolveLsOp(OpKernelConstruction* context) : Base(context) {
     37     OP_REQUIRES_OK(context, context->GetAttr("fast", &fast_));
     38   }
     39 
     40   using TensorShapes = typename Base::TensorShapes;
     41   using Matrix = typename Base::Matrix;
     42   using MatrixMaps = typename Base::MatrixMaps;
     43   using ConstMatrixMap = typename Base::ConstMatrixMap;
     44   using ConstMatrixMaps = typename Base::ConstMatrixMaps;
     45 
     46   // Tell the base class to ignore the regularization parameter
     47   // in context->input(2).
     48   int NumMatrixInputs(const OpKernelContext* context) const final { return 2; }
     49 
     50   void ValidateInputMatrixShapes(
     51       OpKernelContext* context,
     52       const TensorShapes& input_matrix_shapes) const final {
     53     Base::ValidateSolver(context, input_matrix_shapes);
     54   }
     55 
     56   TensorShapes GetOutputMatrixShapes(
     57       const TensorShapes& input_matrix_shapes) const final {
     58     return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
     59                                       input_matrix_shapes[1].dim_size(1)})});
     60   }
     61 
     62   int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
     63     double m = static_cast<double>(input_matrix_shapes[0].dim_size(0));
     64     double n = static_cast<double>(input_matrix_shapes[0].dim_size(1));
     65     double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
     66     double cost = std::max(m, n) * std::min(m, n) * (std::min(m, n) + num_rhss);
     67     return cost >= static_cast<double>(kint64max) ? kint64max
     68                                                   : static_cast<int64>(cost);
     69   }
     70 
     71   bool EnableInputForwarding() const final { return false; }
     72 
     73   void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
     74                      MatrixMaps* outputs) final {
     75     const ConstMatrixMap& matrix = inputs[0];
     76     const ConstMatrixMap& rhs = inputs[1];
     77     const auto& l2_regularizer_in = context->input(2);
     78     OP_REQUIRES(
     79         context, TensorShapeUtils::IsScalar(l2_regularizer_in.shape()),
     80         errors::InvalidArgument("l2_regularizer must be scalar, got shape ",
     81                                 l2_regularizer_in.shape().DebugString()));
     82     const double l2_regularizer = l2_regularizer_in.scalar<double>()();
     83     OP_REQUIRES(context, l2_regularizer >= 0,
     84                 errors::InvalidArgument("l2_regularizer must be >= 0."));
     85 
     86     const int64 rows = matrix.rows();
     87     const int64 cols = matrix.cols();
     88     if (rows == 0 || cols == 0) {
     89       // The result is the empty matrix.
     90       return;
     91     }
     92     if (fast_) {
     93       // The fast branch assumes that matrix is not rank deficient and
     94       // not too ill-conditioned. Specifically, the reciprocal condition number
     95       // should be greater than the square root of the machine precision, i.e.
     96       //   1 / cond(matrix) > sqrt(std::numeric_limits<Scalar>::epsilon()).
     97       // This branch solves over- or underdetermined least-squares problems
     98       // via the normal equations and Cholesky decomposition.
     99       if (matrix.rows() >= matrix.cols()) {
    100         // Overdetermined case (rows >= cols): Solves the ordinary (possibly
    101         // regularized) least-squares problem
    102         //   min || A * X - RHS ||_F^2 + l2_regularizer ||X||_F^2
    103         // by solving the normal equations
    104         //    (A^T * A + l2_regularizer * I) X = A^T RHS
    105         // using Cholesky decomposition.
    106         Matrix gramian(cols, cols);
    107         gramian.template triangularView<Eigen::Lower>() =
    108             matrix.adjoint() * matrix;
    109         if (l2_regularizer > 0) {
    110           gramian +=
    111               (Scalar(l2_regularizer) * Matrix::Ones(cols, 1)).asDiagonal();
    112         }
    113         const Eigen::LLT<Eigen::Ref<Matrix>, Eigen::Lower> llt(gramian);
    114         OP_REQUIRES(
    115             context, llt.info() == Eigen::Success,
    116             errors::InvalidArgument("Input matrix was rank deficient or "
    117                                     "ill-conditioned. Try setting fast=False "
    118                                     "or provide a larger l2_regularizer > 0."));
    119         outputs->at(0).noalias() = matrix.adjoint() * rhs;
    120         llt.solveInPlace(outputs->at(0));
    121       } else {
    122         // Underdetermined case (rows < cols): Solves the minimum-norm problem
    123         //   min ||X||_F^2 s.t. A*X = RHS
    124         // by solving the normal equations of the second kind
    125         //   (A * A^T + l2_regularizer * I) Z = RHS,  X = A^T * Z
    126         // using Cholesky decomposition.
    127         Matrix gramian(rows, rows);
    128         gramian.template triangularView<Eigen::Lower>() =
    129             matrix * matrix.adjoint();
    130         if (l2_regularizer > 0) {
    131           gramian +=
    132               (Scalar(l2_regularizer) * Matrix::Ones(rows, 1)).asDiagonal();
    133         }
    134         const Eigen::LLT<Eigen::Ref<Matrix>, Eigen::Lower> llt(gramian);
    135         OP_REQUIRES(
    136             context, llt.info() == Eigen::Success,
    137             errors::InvalidArgument("Input matrix was rank deficient or "
    138                                     "ill-conditioned. Try setting fast=False "
    139                                     "or provide an l2_regularizer > 0."));
    140         outputs->at(0).noalias() = matrix.adjoint() * llt.solve(rhs);
    141       }
    142     } else {
    143       // Use complete orthogonal decomposition which is backwards stable and
    144       // will compute the minimum-norm solution for rank-deficient matrices.
    145       // This is 6-7 times slower than the fast path.
    146       //
    147       // TODO(rmlarsen): The implementation of
    148       //   Eigen::CompleteOrthogonalDecomposition is not blocked, so for
    149       //   matrices that do not fit in cache, it is significantly slower than
    150       //   the equivalent blocked LAPACK routine xGELSY (e.g. Eigen is ~3x
    151       //   slower for 4k x 4k matrices).
    152       //   See http://www.netlib.org/lapack/lawnspdf/lawn114.pdf
    153       outputs->at(0) = matrix.completeOrthogonalDecomposition().solve(rhs);
    154     }
    155   }
    156 
    157  private:
    158   bool fast_;
    159 };
    160 
    161 }  // namespace tensorflow
    162