/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.
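//
// Illustrative example (values chosen here, not taken from the op docs):
// with lower = true and adjoint = false,
//   matrix = [2 0; 1 3],  rhs = [2; 13]
// the solve returns output = [1; 4], since 2*1 == 2 and 1*1 + 3*4 == 13.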

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

#if GOOGLE_CUDA
namespace {
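// Wraps a raw device pointer in a typed StreamExecutor DeviceMemory
// handle. The handle does not own the memory; it only lets
// StreamExecutor routines consume an existing device buffer.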
template <typename Scalar>
se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* cuda_memory) {
  se::DeviceMemoryBase wrapped(const_cast<Scalar*>(cuda_memory));
  se::DeviceMemory<Scalar> typed(wrapped);
  return typed;
}
}  // namespace
#endif  // GOOGLE_CUDA

template <class Scalar>
class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit MatrixTriangularSolveOp(OpKernelConstruction* context)
      : Base(context), lower_(true), adjoint_(false) {
    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    Base::ValidateSquareSolver(context, input_matrix_shapes);
  }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
                                      input_matrix_shapes[1].dim_size(1)})});
  }

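  // Cost model: charges one add and one multiply per (matrix entry,
  // rhs column) pair, i.e. O(rows^2 * num_rhss) work, clamped to
  // kint64max so the cast below cannot overflow for huge problems.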
  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
    double cost = rows * rows * num_rhss *
                  (Eigen::TensorOpCost::AddCost<Scalar>() +
                   Eigen::TensorOpCost::MulCost<Scalar>());
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }

  bool EnableInputForwarding() const final { return false; }

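  // Solves matrix * output = rhs, or adjoint(matrix) * output = rhs when
  // the adjoint attribute is set, using Eigen's triangular solver.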
  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const ConstMatrixMap& matrix = inputs[0];
    const ConstMatrixMap& rhs = inputs[1];
    MatrixMap& output = outputs->at(0);

    if (matrix.rows() == 0 || rhs.cols() == 0) {
      // To be consistent with the MatrixInverse op, we define the solution
      // for an empty set of equations as the empty matrix.
      return;
    }
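    // A triangular matrix is singular if and only if some diagonal entry
    // is zero, so checking the smallest absolute diagonal value detects
    // non-invertible inputs exactly.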
    const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument("Input matrix is not invertible."));
    if (lower_) {
      auto triangle = matrix.template triangularView<Eigen::Lower>();
      if (adjoint_) {
        output.noalias() = triangle.adjoint().solve(rhs);
      } else {
        output.noalias() = triangle.solve(rhs);
      }
    } else {
      auto triangle = matrix.template triangularView<Eigen::Upper>();
      if (adjoint_) {
        output.noalias() = triangle.adjoint().solve(rhs);
      } else {
        output.noalias() = triangle.solve(rhs);
      }
    }
  }

 private:
  bool lower_;
  bool adjoint_;

  TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOp);
};

REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<double>), double);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<complex64>), complex64);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<complex128>), complex128);
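// The "Batch*" names below are legacy aliases, registered so that graphs
// built with the older batch op names keep resolving to the same kernel.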
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
                       (MatrixTriangularSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
                       (MatrixTriangularSolveOp<double>), double);

#if GOOGLE_CUDA

// TODO(rmlarsen): Refactor to
// 1. Enable buffer forwarding from rhs->out.
// 2. Save a Memcpy when buffer forwarding is used.
// 3. Copy the entire rhs in a single Memcpy when forwarding is not used.
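// GPU variant: shape validation, output shapes, and the cost model match
// the CPU op above; the solve itself is dispatched to the platform BLAS
// (cuBLAS) trsm routine through StreamExecutor, after first copying rhs
// into the output buffer so the solve can run in place.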
template <class Scalar>
class MatrixTriangularSolveOpGPU : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit MatrixTriangularSolveOpGPU(OpKernelConstruction* context)
      : Base(context), lower_(true), adjoint_(false) {
    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    Base::ValidateSquareSolver(context, input_matrix_shapes);
  }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
                                      input_matrix_shapes[1].dim_size(1)})});
  }

  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
    double cost = rows * rows * num_rhss *
                  (Eigen::TensorOpCost::AddCost<Scalar>() +
                   Eigen::TensorOpCost::MulCost<Scalar>());
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }

  bool EnableInputForwarding() const final { return false; }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const ConstMatrixMap& matrix = inputs[0];
    const ConstMatrixMap& rhs = inputs[1];
    MatrixMap& output = outputs->at(0);

    if (matrix.rows() == 0 || rhs.cols() == 0) {
      // To be consistent with the MatrixInverse op, we define the solution
      // for an empty set of equations as the empty matrix.
      return;
    }

    auto matrix_ptr = AsDeviceMemory(matrix.data());
    auto rhs_ptr = AsDeviceMemory(rhs.data());
    auto out_ptr = AsDeviceMemory(output.data());

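    // BLAS trsm overwrites its B argument with the solution, so copy rhs
    // into the output buffer first and solve in place there.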
    auto* stream = context->op_device_context()->stream();
    uint64 rhs_elems = rhs.rows() * rhs.cols();
    bool copy_status =
        stream->ThenMemcpyD2D(&out_ptr, rhs_ptr, sizeof(Scalar) * rhs_elems)
            .ok();
    if (!copy_status) {
      context->SetStatus(
          errors::Internal("Failed to copy rhs into output before solve"));
      return;
    }

    // cuBLAS solves
    //   output = matrix \ rhs
    // with matrix, rhs, and output assumed column-major. Our buffers are
    // row-major, so we instead compute
    //   output' = rhs' / matrix'   (' denotes transpose),
    // which requires swapping the upper/lower triangle flag.
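    // Concretely, a row-major array reinterpreted as column-major is the
    // transpose: cuBLAS sees matrix' and rhs'. Since (op(matrix) * X)' =
    // X' * op(matrix)', a right-sided trsm on the transposed data yields
    // the row-major solution, and the lower triangle of matrix appears to
    // cuBLAS as the upper triangle of matrix'.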

    se::blas::UpperLower upper_lower_matrix;
    se::blas::Transpose transpose_matrix;
    if (lower_) {
      upper_lower_matrix = se::blas::UpperLower::kUpper;
    } else {
      upper_lower_matrix = se::blas::UpperLower::kLower;
    }
    if (adjoint_) {
      transpose_matrix = se::blas::Transpose::kConjugateTranspose;
    } else {
      transpose_matrix = se::blas::Transpose::kNoTranspose;
    }
    uint64 leading_dim_matrix = matrix.cols();
    uint64 leading_dim_output = output.cols();
    uint64 colmajor_rows = output.cols();
    uint64 colmajor_cols = output.rows();
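    // In cuBLAS's column-major view this solves
    //   X * op(matrix') = 1.0 * rhs'   (side = Right),
    // overwriting out_ptr with X = output'.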
    bool blas_launch_status =
        stream
            ->ThenBlasTrsm(
                se::blas::Side::kRight /*side*/, upper_lower_matrix /*uplo*/,
                transpose_matrix /*trans*/,
                se::blas::Diagonal::kNonUnit /*diag*/, colmajor_rows /*m*/,
                colmajor_cols /*n*/, Scalar(1.0) /*alpha*/, matrix_ptr,
                leading_dim_matrix /*lda*/, &out_ptr,
                leading_dim_output /*ldb*/)
            .ok();
    if (!blas_launch_status) {
      context->SetStatus(errors::Internal("Blas TRSM launch failed"));
    }
  }

 private:
  bool lower_;
  bool adjoint_;

  TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOpGPU);
};

REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<float>), float);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<double>), double);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<complex64>), complex64);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<complex128>), complex128);
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<float>), float);
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<double>), double);

#endif  // GOOGLE_CUDA

}  // namespace tensorflow