Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/linalg_ops.cc.
     17 
     18 #include "third_party/eigen3/Eigen/Core"
     19 #include "tensorflow/core/framework/kernel_def_builder.h"
     20 #include "tensorflow/core/framework/op_kernel.h"
     21 #include "tensorflow/core/framework/tensor_shape.h"
     22 #include "tensorflow/core/kernels/linalg_ops_common.h"
     23 #include "tensorflow/core/lib/core/errors.h"
     24 #include "tensorflow/core/platform/logging.h"
     25 #include "tensorflow/core/platform/macros.h"
     26 #include "tensorflow/core/platform/types.h"
     27 
     28 #if GOOGLE_CUDA
     29 #include "tensorflow/core/platform/stream_executor.h"
     30 #endif  // GOOGLE_CUDA
     31 
     32 namespace tensorflow {
     33 
     34 #if GOOGLE_CUDA
     35 namespace {
     36 template <typename Scalar>
     37 perftools::gputools::DeviceMemory<Scalar> AsDeviceMemory(
     38     const Scalar* cuda_memory) {
     39   perftools::gputools::DeviceMemoryBase wrapped(
     40       const_cast<Scalar*>(cuda_memory));
     41   perftools::gputools::DeviceMemory<Scalar> typed(wrapped);
     42   return typed;
     43 }
     44 }  // namespace
     45 #endif  // GOOGLE_CUDA
     46 
     47 template <class Scalar>
     48 class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> {
     49  public:
     50   INHERIT_LINALG_TYPEDEFS(Scalar);
     51 
     52   explicit MatrixTriangularSolveOp(OpKernelConstruction* context)
     53       : Base(context), lower_(true), adjoint_(false) {
     54     OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
     55     OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
     56   }
     57 
     58   void ValidateInputMatrixShapes(
     59       OpKernelContext* context,
     60       const TensorShapes& input_matrix_shapes) const final {
     61     Base::ValidateSquareSolver(context, input_matrix_shapes);
     62   }
     63 
     64   TensorShapes GetOutputMatrixShapes(
     65       const TensorShapes& input_matrix_shapes) const final {
     66     return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
     67                                       input_matrix_shapes[1].dim_size(1)})});
     68   }
     69 
     70   int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
     71     double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
     72     double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
     73     double cost = rows * rows * num_rhss *
     74                   (Eigen::TensorOpCost::AddCost<Scalar>() +
     75                    Eigen::TensorOpCost::MulCost<Scalar>());
     76     return cost >= static_cast<double>(kint64max) ? kint64max
     77                                                   : static_cast<int64>(cost);
     78   }
     79 
     80   bool EnableInputForwarding() const final { return false; }
     81 
     82   void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
     83                      MatrixMaps* outputs) final {
     84     const ConstMatrixMap& matrix = inputs[0];
     85     const ConstMatrixMap& rhs = inputs[1];
     86     MatrixMap& output = outputs->at(0);
     87 
     88     if (matrix.rows() == 0 || rhs.cols() == 0) {
     89       // To be consistent with the MatrixInverse op, we define the solution for
     90       // an empty set of equation as the empty matrix.
     91       return;
     92     }
     93     const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
     94     OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
     95                 errors::InvalidArgument("Input matrix is not invertible."));
     96     if (lower_) {
     97       auto triangle = matrix.template triangularView<Eigen::Lower>();
     98       if (adjoint_) {
     99         output.noalias() = triangle.adjoint().solve(rhs);
    100       } else {
    101         output.noalias() = triangle.solve(rhs);
    102       }
    103     } else {
    104       auto triangle = matrix.template triangularView<Eigen::Upper>();
    105       if (adjoint_) {
    106         output.noalias() = triangle.adjoint().solve(rhs);
    107       } else {
    108         output.noalias() = triangle.solve(rhs);
    109       }
    110     }
    111   }
    112 
    113  private:
    114   bool lower_;
    115   bool adjoint_;
    116 
    117   TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOp);
    118 };
    119 
    120 REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
    121                        (MatrixTriangularSolveOp<float>), float);
    122 REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
    123                        (MatrixTriangularSolveOp<double>), double);
    124 REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
    125                        (MatrixTriangularSolveOp<complex64>), complex64);
    126 REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
    127                        (MatrixTriangularSolveOp<complex128>), complex128);
    128 REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
    129                        (MatrixTriangularSolveOp<float>), float);
    130 REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
    131                        (MatrixTriangularSolveOp<double>), double);
    132 
    133 #ifdef GOOGLE_CUDA
    134 
    135 // TODO(rmlarsen): Re-factor to
    136 // 1. Enable buffer forwarding from rhs->out.
    137 // 2. Save Memcpy when buffer forwarding is used.
    138 // 3. Copy entire rhs in a single Memcpy when forwarding is not used.
    139 template <class Scalar>
    140 class MatrixTriangularSolveOpGPU : public LinearAlgebraOp<Scalar> {
    141  public:
    142   INHERIT_LINALG_TYPEDEFS(Scalar);
    143 
    144   explicit MatrixTriangularSolveOpGPU(OpKernelConstruction* context)
    145       : Base(context), lower_(true), adjoint_(false) {
    146     OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
    147     OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
    148   }
    149 
    150   void ValidateInputMatrixShapes(
    151       OpKernelContext* context,
    152       const TensorShapes& input_matrix_shapes) const final {
    153     Base::ValidateSquareSolver(context, input_matrix_shapes);
    154   }
    155 
    156   TensorShapes GetOutputMatrixShapes(
    157       const TensorShapes& input_matrix_shapes) const final {
    158     return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
    159                                       input_matrix_shapes[1].dim_size(1)})});
    160   }
    161 
    162   int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    163     double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    164     double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
    165     double cost = rows * rows * num_rhss *
    166                   (Eigen::TensorOpCost::AddCost<Scalar>() +
    167                    Eigen::TensorOpCost::MulCost<Scalar>());
    168     return cost >= static_cast<double>(kint64max) ? kint64max
    169                                                   : static_cast<int64>(cost);
    170   }
    171 
    172   bool EnableInputForwarding() const final { return false; }
    173 
    174   void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
    175                      MatrixMaps* outputs) final {
    176     const ConstMatrixMap& matrix = inputs[0];
    177     const ConstMatrixMap& rhs = inputs[1];
    178     MatrixMap& output = outputs->at(0);
    179 
    180     if (matrix.rows() == 0 || rhs.cols() == 0) {
    181       // To be consistent with the MatrixInverse op, we define the solution for
    182       // an empty set of equation as the empty matrix.
    183       return;
    184     }
    185 
    186     auto matrix_ptr = AsDeviceMemory(matrix.data());
    187     auto rhs_ptr = AsDeviceMemory(rhs.data());
    188     auto out_ptr = AsDeviceMemory(output.data());
    189 
    190     auto* stream = context->op_device_context()->stream();
    191     uint64 rhs_elems = rhs.rows() * rhs.cols();
    192     bool copy_status =
    193         stream->ThenMemcpyD2D(&out_ptr, rhs_ptr, sizeof(Scalar) * rhs_elems)
    194             .ok();
    195     if (!copy_status) {
    196       context->SetStatus(
    197           errors::Internal("Failed to copy rhs into output before solve"));
    198     }
    199 
    200     // Cublas does
    201     // output = matrix \ rhs
    202     // where matrix, rhs and output are assumed to be in column major.
    203     // We want the output to be in row-major, so we can compute
    204     // output' = rhs' / matrix' (' stands for transpose)
    205     // Upper/lower needs to be swapped for this.
    206 
    207     perftools::gputools::blas::UpperLower upper_lower_matrix;
    208     perftools::gputools::blas::Transpose transpose_matrix;
    209     if (lower_) {
    210       upper_lower_matrix = perftools::gputools::blas::UpperLower::kUpper;
    211     } else {
    212       upper_lower_matrix = perftools::gputools::blas::UpperLower::kLower;
    213     }
    214     if (adjoint_) {
    215       transpose_matrix =
    216           perftools::gputools::blas::Transpose::kConjugateTranspose;
    217     } else {
    218       transpose_matrix = perftools::gputools::blas::Transpose::kNoTranspose;
    219     }
    220     uint64 leading_dim_matrix = matrix.cols();
    221     uint64 leading_dim_output = output.cols();
    222     uint64 colmajor_rows = output.cols();
    223     uint64 colmajor_cols = output.rows();
    224     bool blas_launch_status =
    225         stream
    226             ->ThenBlasTrsm(
    227                 perftools::gputools::blas::Side::kRight /*side*/,
    228                 upper_lower_matrix /*uplo*/, transpose_matrix /*trans*/,
    229                 perftools::gputools::blas::Diagonal::kNonUnit /*diag*/,
    230                 colmajor_rows /*m*/, colmajor_cols /*n*/, Scalar(1.0) /*alpha*/,
    231                 matrix_ptr, leading_dim_matrix /*lda*/, &out_ptr,
    232                 leading_dim_output /*ldb*/)
    233             .ok();
    234     if (!blas_launch_status) {
    235       context->SetStatus(errors::Internal("Blas TRSM launch failed"));
    236     }
    237   }
    238 
    239  private:
    240   bool lower_;
    241   bool adjoint_;
    242 
    243   TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOpGPU);
    244 };
    245 
    246 REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
    247                        (MatrixTriangularSolveOpGPU<float>), float);
    248 REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
    249                        (MatrixTriangularSolveOpGPU<double>), double);
    250 REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
    251                        (MatrixTriangularSolveOpGPU<complex64>), complex64);
    252 REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
    253                        (MatrixTriangularSolveOpGPU<complex128>), complex128);
    254 REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
    255                        (MatrixTriangularSolveOpGPU<float>), float);
    256 REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
    257                        (MatrixTriangularSolveOpGPU<double>), double);
    258 
    259 #endif  // GOOGLE_CUDA
    260 
    261 }  // namespace tensorflow
    262