1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/linalg_ops.cc. 17 18 #include "third_party/eigen3/Eigen/Core" 19 #include "tensorflow/core/framework/kernel_def_builder.h" 20 #include "tensorflow/core/framework/op_kernel.h" 21 #include "tensorflow/core/framework/tensor_shape.h" 22 #include "tensorflow/core/kernels/linalg_ops_common.h" 23 #include "tensorflow/core/lib/core/errors.h" 24 #include "tensorflow/core/platform/logging.h" 25 #include "tensorflow/core/platform/macros.h" 26 #include "tensorflow/core/platform/types.h" 27 28 #if GOOGLE_CUDA 29 #include "tensorflow/core/platform/stream_executor.h" 30 #endif // GOOGLE_CUDA 31 32 namespace tensorflow { 33 34 #if GOOGLE_CUDA 35 namespace { 36 template <typename Scalar> 37 perftools::gputools::DeviceMemory<Scalar> AsDeviceMemory( 38 const Scalar* cuda_memory) { 39 perftools::gputools::DeviceMemoryBase wrapped( 40 const_cast<Scalar*>(cuda_memory)); 41 perftools::gputools::DeviceMemory<Scalar> typed(wrapped); 42 return typed; 43 } 44 } // namespace 45 #endif // GOOGLE_CUDA 46 47 template <class Scalar> 48 class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> { 49 public: 50 INHERIT_LINALG_TYPEDEFS(Scalar); 51 52 explicit MatrixTriangularSolveOp(OpKernelConstruction* context) 53 : Base(context), lower_(true), adjoint_(false) { 54 OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_)); 55 OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_)); 56 } 57 58 void ValidateInputMatrixShapes( 59 OpKernelContext* context, 60 const TensorShapes& input_matrix_shapes) const final { 61 Base::ValidateSquareSolver(context, input_matrix_shapes); 62 } 63 64 TensorShapes GetOutputMatrixShapes( 65 const TensorShapes& input_matrix_shapes) const final { 66 return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1), 67 input_matrix_shapes[1].dim_size(1)})}); 68 } 69 70 int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final { 71 double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0)); 72 double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1)); 73 double cost = rows * rows * num_rhss * 74 (Eigen::TensorOpCost::AddCost<Scalar>() + 75 Eigen::TensorOpCost::MulCost<Scalar>()); 76 return cost >= static_cast<double>(kint64max) ? kint64max 77 : static_cast<int64>(cost); 78 } 79 80 bool EnableInputForwarding() const final { return false; } 81 82 void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs, 83 MatrixMaps* outputs) final { 84 const ConstMatrixMap& matrix = inputs[0]; 85 const ConstMatrixMap& rhs = inputs[1]; 86 MatrixMap& output = outputs->at(0); 87 88 if (matrix.rows() == 0 || rhs.cols() == 0) { 89 // To be consistent with the MatrixInverse op, we define the solution for 90 // an empty set of equation as the empty matrix. 91 return; 92 } 93 const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff(); 94 OP_REQUIRES(context, min_abs_pivot > RealScalar(0), 95 errors::InvalidArgument("Input matrix is not invertible.")); 96 if (lower_) { 97 auto triangle = matrix.template triangularView<Eigen::Lower>(); 98 if (adjoint_) { 99 output.noalias() = triangle.adjoint().solve(rhs); 100 } else { 101 output.noalias() = triangle.solve(rhs); 102 } 103 } else { 104 auto triangle = matrix.template triangularView<Eigen::Upper>(); 105 if (adjoint_) { 106 output.noalias() = triangle.adjoint().solve(rhs); 107 } else { 108 output.noalias() = triangle.solve(rhs); 109 } 110 } 111 } 112 113 private: 114 bool lower_; 115 bool adjoint_; 116 117 TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOp); 118 }; 119 120 REGISTER_LINALG_OP_CPU("MatrixTriangularSolve", 121 (MatrixTriangularSolveOp<float>), float); 122 REGISTER_LINALG_OP_CPU("MatrixTriangularSolve", 123 (MatrixTriangularSolveOp<double>), double); 124 REGISTER_LINALG_OP_CPU("MatrixTriangularSolve", 125 (MatrixTriangularSolveOp<complex64>), complex64); 126 REGISTER_LINALG_OP_CPU("MatrixTriangularSolve", 127 (MatrixTriangularSolveOp<complex128>), complex128); 128 REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve", 129 (MatrixTriangularSolveOp<float>), float); 130 REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve", 131 (MatrixTriangularSolveOp<double>), double); 132 133 #ifdef GOOGLE_CUDA 134 135 // TODO(rmlarsen): Re-factor to 136 // 1. Enable buffer forwarding from rhs->out. 137 // 2. Save Memcpy when buffer forwarding is used. 138 // 3. Copy entire rhs in a single Memcpy when forwarding is not used. 139 template <class Scalar> 140 class MatrixTriangularSolveOpGPU : public LinearAlgebraOp<Scalar> { 141 public: 142 INHERIT_LINALG_TYPEDEFS(Scalar); 143 144 explicit MatrixTriangularSolveOpGPU(OpKernelConstruction* context) 145 : Base(context), lower_(true), adjoint_(false) { 146 OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_)); 147 OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_)); 148 } 149 150 void ValidateInputMatrixShapes( 151 OpKernelContext* context, 152 const TensorShapes& input_matrix_shapes) const final { 153 Base::ValidateSquareSolver(context, input_matrix_shapes); 154 } 155 156 TensorShapes GetOutputMatrixShapes( 157 const TensorShapes& input_matrix_shapes) const final { 158 return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1), 159 input_matrix_shapes[1].dim_size(1)})}); 160 } 161 162 int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final { 163 double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0)); 164 double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1)); 165 double cost = rows * rows * num_rhss * 166 (Eigen::TensorOpCost::AddCost<Scalar>() + 167 Eigen::TensorOpCost::MulCost<Scalar>()); 168 return cost >= static_cast<double>(kint64max) ? kint64max 169 : static_cast<int64>(cost); 170 } 171 172 bool EnableInputForwarding() const final { return false; } 173 174 void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs, 175 MatrixMaps* outputs) final { 176 const ConstMatrixMap& matrix = inputs[0]; 177 const ConstMatrixMap& rhs = inputs[1]; 178 MatrixMap& output = outputs->at(0); 179 180 if (matrix.rows() == 0 || rhs.cols() == 0) { 181 // To be consistent with the MatrixInverse op, we define the solution for 182 // an empty set of equation as the empty matrix. 183 return; 184 } 185 186 auto matrix_ptr = AsDeviceMemory(matrix.data()); 187 auto rhs_ptr = AsDeviceMemory(rhs.data()); 188 auto out_ptr = AsDeviceMemory(output.data()); 189 190 auto* stream = context->op_device_context()->stream(); 191 uint64 rhs_elems = rhs.rows() * rhs.cols(); 192 bool copy_status = 193 stream->ThenMemcpyD2D(&out_ptr, rhs_ptr, sizeof(Scalar) * rhs_elems) 194 .ok(); 195 if (!copy_status) { 196 context->SetStatus( 197 errors::Internal("Failed to copy rhs into output before solve")); 198 } 199 200 // Cublas does 201 // output = matrix \ rhs 202 // where matrix, rhs and output are assumed to be in column major. 203 // We want the output to be in row-major, so we can compute 204 // output' = rhs' / matrix' (' stands for transpose) 205 // Upper/lower needs to be swapped for this. 206 207 perftools::gputools::blas::UpperLower upper_lower_matrix; 208 perftools::gputools::blas::Transpose transpose_matrix; 209 if (lower_) { 210 upper_lower_matrix = perftools::gputools::blas::UpperLower::kUpper; 211 } else { 212 upper_lower_matrix = perftools::gputools::blas::UpperLower::kLower; 213 } 214 if (adjoint_) { 215 transpose_matrix = 216 perftools::gputools::blas::Transpose::kConjugateTranspose; 217 } else { 218 transpose_matrix = perftools::gputools::blas::Transpose::kNoTranspose; 219 } 220 uint64 leading_dim_matrix = matrix.cols(); 221 uint64 leading_dim_output = output.cols(); 222 uint64 colmajor_rows = output.cols(); 223 uint64 colmajor_cols = output.rows(); 224 bool blas_launch_status = 225 stream 226 ->ThenBlasTrsm( 227 perftools::gputools::blas::Side::kRight /*side*/, 228 upper_lower_matrix /*uplo*/, transpose_matrix /*trans*/, 229 perftools::gputools::blas::Diagonal::kNonUnit /*diag*/, 230 colmajor_rows /*m*/, colmajor_cols /*n*/, Scalar(1.0) /*alpha*/, 231 matrix_ptr, leading_dim_matrix /*lda*/, &out_ptr, 232 leading_dim_output /*ldb*/) 233 .ok(); 234 if (!blas_launch_status) { 235 context->SetStatus(errors::Internal("Blas TRSM launch failed")); 236 } 237 } 238 239 private: 240 bool lower_; 241 bool adjoint_; 242 243 TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOpGPU); 244 }; 245 246 REGISTER_LINALG_OP_GPU("MatrixTriangularSolve", 247 (MatrixTriangularSolveOpGPU<float>), float); 248 REGISTER_LINALG_OP_GPU("MatrixTriangularSolve", 249 (MatrixTriangularSolveOpGPU<double>), double); 250 REGISTER_LINALG_OP_GPU("MatrixTriangularSolve", 251 (MatrixTriangularSolveOpGPU<complex64>), complex64); 252 REGISTER_LINALG_OP_GPU("MatrixTriangularSolve", 253 (MatrixTriangularSolveOpGPU<complex128>), complex128); 254 REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve", 255 (MatrixTriangularSolveOpGPU<float>), float); 256 REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve", 257 (MatrixTriangularSolveOpGPU<double>), double); 258 259 #endif // GOOGLE_CUDA 260 261 } // namespace tensorflow 262