/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

#if GOOGLE_CUDA
namespace {
// Wraps a raw device pointer in a typed DeviceMemory handle for StreamExecutor.
template <typename Scalar>
se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* cuda_memory) {
  se::DeviceMemoryBase wrapped(const_cast<Scalar*>(cuda_memory));
  se::DeviceMemory<Scalar> typed(wrapped);
  return typed;
}
}  // namespace
#endif  // GOOGLE_CUDA

template <class Scalar>
class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit MatrixTriangularSolveOp(OpKernelConstruction* context)
      : Base(context), lower_(true), adjoint_(false) {
    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    Base::ValidateSquareSolver(context, input_matrix_shapes);
  }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
                                      input_matrix_shapes[1].dim_size(1)})});
  }

  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
    double cost = rows * rows * num_rhss *
                  (Eigen::TensorOpCost::AddCost<Scalar>() +
                   Eigen::TensorOpCost::MulCost<Scalar>());
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }

  bool EnableInputForwarding() const final { return false; }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const ConstMatrixMap& matrix = inputs[0];
    const ConstMatrixMap& rhs = inputs[1];
    MatrixMap& output = outputs->at(0);

    if (matrix.rows() == 0 || rhs.cols() == 0) {
      // To be consistent with the MatrixInverse op, we define the solution
      // for an empty set of equations as the empty matrix.
      return;
    }
    const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument("Input matrix is not invertible."));
    if (lower_) {
      auto triangle = matrix.template triangularView<Eigen::Lower>();
      if (adjoint_) {
        output.noalias() = triangle.adjoint().solve(rhs);
      } else {
        output.noalias() = triangle.solve(rhs);
      }
    } else {
      auto triangle = matrix.template triangularView<Eigen::Upper>();
      if (adjoint_) {
        output.noalias() = triangle.adjoint().solve(rhs);
      } else {
        output.noalias() = triangle.solve(rhs);
      }
    }
  }

 private:
  bool lower_;
  bool adjoint_;

  TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOp);
};

REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<double>), double);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<complex64>), complex64);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOp<complex128>), complex128);
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
                       (MatrixTriangularSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
                       (MatrixTriangularSolveOp<double>), double);

#if GOOGLE_CUDA

// TODO(rmlarsen): Re-factor to
// 1. Enable buffer forwarding from rhs->out.
// 2. Save Memcpy when buffer forwarding is used.
// 3. Copy entire rhs in a single Memcpy when forwarding is not used.
template <class Scalar>
class MatrixTriangularSolveOpGPU : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit MatrixTriangularSolveOpGPU(OpKernelConstruction* context)
      : Base(context), lower_(true), adjoint_(false) {
    OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_));
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    Base::ValidateSquareSolver(context, input_matrix_shapes);
  }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
                                      input_matrix_shapes[1].dim_size(1)})});
  }

  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
    double cost = rows * rows * num_rhss *
                  (Eigen::TensorOpCost::AddCost<Scalar>() +
                   Eigen::TensorOpCost::MulCost<Scalar>());
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }

  bool EnableInputForwarding() const final { return false; }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const ConstMatrixMap& matrix = inputs[0];
    const ConstMatrixMap& rhs = inputs[1];
    MatrixMap& output = outputs->at(0);

    if (matrix.rows() == 0 || rhs.cols() == 0) {
      // To be consistent with the MatrixInverse op, we define the solution
      // for an empty set of equations as the empty matrix.
      return;
    }

    auto matrix_ptr = AsDeviceMemory(matrix.data());
    auto rhs_ptr = AsDeviceMemory(rhs.data());
    auto out_ptr = AsDeviceMemory(output.data());

    auto* stream = context->op_device_context()->stream();
    // TRSM solves in place, so copy rhs into the output buffer first.
    uint64 rhs_elems = rhs.rows() * rhs.cols();
    bool copy_status =
        stream->ThenMemcpyD2D(&out_ptr, rhs_ptr, sizeof(Scalar) * rhs_elems)
            .ok();
    if (!copy_status) {
      context->SetStatus(
          errors::Internal("Failed to copy rhs into output before solve"));
      return;
    }

    // cuBLAS's trsm computes
    //   output = matrix \ rhs
    // where matrix, rhs and output are assumed to be in column-major order.
    // Our buffers are row-major; a row-major matrix reinterpreted as
    // column-major is its transpose, so we instead solve
    //   output' = rhs' / matrix'    (' stands for transpose)
    // by passing Side::kRight. Transposing also flips the triangle, so
    // upper/lower needs to be swapped.

    se::blas::UpperLower upper_lower_matrix;
    se::blas::Transpose transpose_matrix;
    if (lower_) {
      upper_lower_matrix = se::blas::UpperLower::kUpper;
    } else {
      upper_lower_matrix = se::blas::UpperLower::kLower;
    }
    if (adjoint_) {
      transpose_matrix = se::blas::Transpose::kConjugateTranspose;
    } else {
      transpose_matrix = se::blas::Transpose::kNoTranspose;
    }
    uint64 leading_dim_matrix = matrix.cols();
    uint64 leading_dim_output = output.cols();
    uint64 colmajor_rows = output.cols();
    uint64 colmajor_cols = output.rows();
    bool blas_launch_status =
        stream
            ->ThenBlasTrsm(
                se::blas::Side::kRight /*side*/, upper_lower_matrix /*uplo*/,
                transpose_matrix /*trans*/,
                se::blas::Diagonal::kNonUnit /*diag*/, colmajor_rows /*m*/,
                colmajor_cols /*n*/, Scalar(1.0) /*alpha*/, matrix_ptr,
                leading_dim_matrix /*lda*/, &out_ptr,
                leading_dim_output /*ldb*/)
            .ok();
    if (!blas_launch_status) {
      context->SetStatus(errors::Internal("Blas TRSM launch failed"));
    }
  }

 private:
  bool lower_;
  bool adjoint_;

  TF_DISALLOW_COPY_AND_ASSIGN(MatrixTriangularSolveOpGPU);
};

REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<float>), float);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<double>), double);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<complex64>), complex64);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<complex128>), complex128);
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<float>), float);
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
                       (MatrixTriangularSolveOpGPU<double>), double);

#endif  // GOOGLE_CUDA

}  // namespace tensorflow