/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include <numeric>

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/LU"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#endif

namespace tensorflow {

static const char kErrMsg[] = "Input matrix is not invertible.";

template <class Scalar>
class MatrixSolveOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit MatrixSolveOp(OpKernelConstruction* context) : Base(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    Base::ValidateSquareSolver(context, input_matrix_shapes);
  }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
                                      input_matrix_shapes[1].dim_size(1)})});
  }

  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
    // O(n^3) for the LU factorization plus O(n^2) per right-hand side for
    // the triangular solves.
    double cost = rows * rows * (rows + num_rhss);
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }

  bool EnableInputForwarding() const final { return false; }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const ConstMatrixMap& matrix = inputs[0];
    const ConstMatrixMap& rhs = inputs[1];
    if (matrix.rows() == 0 || rhs.cols() == 0) {
      // To be consistent with the MatrixInverse op, we define the solution
      // for an empty set of equations as the empty matrix.
      return;
    }
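    // Solve A x = rhs (or A^H x = rhs) via a partially pivoted LU
    // factorization: compute P * A = L * U once, then obtain x from
    // L * U * x = P * rhs with two triangular solves.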
    Eigen::PartialPivLU<Matrix> lu_decomposition(matrix.rows());
    if (adjoint_) {
      // TODO(rmlarsen): For Eigen 3.2, this creates a temporary copy.
      // Make sure to backport: https://bitbucket.org/eigen/eigen/commits/
      // bd2219a74c96dfe3f6bc2c23588749e36d2d8173
      lu_decomposition.compute(matrix.adjoint());
    } else {
      lu_decomposition.compute(matrix);
    }

    // PartialPivLU cannot give strong guarantees on invertibility, but we can
    // at least guard against exact zero pivots. This can occur as a result of
    // basic user mistakes, such as providing integer-valued matrices that are
    // exactly singular, or due to underflow if this code is run with denormals
    // being flushed to zero.
    const RealScalar min_abs_pivot =
        lu_decomposition.matrixLU().diagonal().cwiseAbs().minCoeff();
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument(kErrMsg));

    // TODO(rmlarsen): Add check based on condition number estimation.
    // The necessary changes to Eigen are in
    // https://bitbucket.org/eigen/eigen/pull-requests/174/
    // add-matrix-condition-number-estimation/diff
    outputs->at(0) = lu_decomposition.solve(rhs);
  }

 private:
  bool adjoint_;

  TF_DISALLOW_COPY_AND_ASSIGN(MatrixSolveOp);
};

#if GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;

template <class Scalar>
class MatrixSolveOpGpu : public AsyncOpKernel {
 public:
  explicit MatrixSolveOpGpu(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const Tensor& rhs = context->input(1);
    const int ndims = input.dims();
    const int64 n = input.dim_size(ndims - 1);
    const int64 nrhs = rhs.dim_size(ndims - 1);
    // Validate inputs.
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);
    OP_REQUIRES_ASYNC(context, rhs.dims() == ndims,
                      errors::InvalidArgument(
                          "Input and right-hand side must have same rank, got ",
                          ndims, " != ", rhs.dims()),
                      done);
    OP_REQUIRES_ASYNC(
        context, input.dim_size(ndims - 2) == n,
        errors::InvalidArgument("Input matrices must be square, got ",
                                input.dim_size(ndims - 2), " != ", n),
        done);
    OP_REQUIRES_ASYNC(context, rhs.dim_size(ndims - 2) == n,
                      errors::InvalidArgument(
                          "Input matrix and right-hand side must have the "
                          "same number of rows, got ",
                          n, " != ", rhs.dim_size(ndims - 2)),
                      done);

    // Allocate output.
    Tensor* output;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->forward_input_or_allocate_output({1}, 0, rhs.shape(), &output),
        done);

    // To be consistent with the MatrixInverse op, we define the solution for
    // an empty set of equations as the empty matrix.
    if (rhs.NumElements() == 0) {
      done();
      return;
    }

    // TODO(rmlarsen): Convert to std::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));

    // Make a copy of the input for the factorization step, or, if adjoint_ is
    // false, try to reuse the input buffer if this op owns it exclusively.
    Tensor input_copy;
    const GPUDevice& device = context->eigen_device<GPUDevice>();
    if (adjoint_) {
      // For the adjoint case, it is simpler to always make a transposed copy
      // up front.
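      // cuBLAS treats device buffers as column-major, so the row-major buffer
      // holding A is read by cuBLAS as A^T. Transposing A up front makes the
      // copy read as A itself in column-major form, which lets the solve in
      // step 3 apply the adjoint directly via CUBLAS_OP_C.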
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->allocate_scoped_tensor(DataTypeToEnum<Scalar>::value,
                                         input.shape(), &input_copy),
          done);
      OP_REQUIRES_OK_ASYNC(context,
                           DoMatrixTranspose(device, input, &input_copy),
                           done);
    } else {
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->forward_input_or_allocate_scoped_tensor(
              {0}, DataTypeToEnum<Scalar>::value, input.shape(), &input_copy),
          done);
      if (!input.SharesBufferWith(input_copy)) {
        device.memcpy(input_copy.flat<Scalar>().data(),
                      input.flat<Scalar>().data(),
                      input.NumElements() * sizeof(Scalar));
      }
    }
    auto input_copy_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
    const int64 batch_size = input_copy_reshaped.dimension(0);

    // Allocate pivots on the device.
    Tensor pivots;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<int>::value,
                                       TensorShape{batch_size, n}, &pivots),
        done);
    auto pivots_mat = pivots.template matrix<int>();

    // 1. Compute the partially pivoted LU factorization(s) of the
    // matrix/matrices.
    std::vector<DeviceLapackInfo> dev_info;
    auto input_copy_ptrs = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "input_copy_ptrs",
        /* on_host */ true);
    if (n / batch_size <= 128) {
      // For small matrices or large batch sizes, we use the batched interface
      // from cuBLAS.
      const Scalar** input_copy_ptrs_base =
          reinterpret_cast<const Scalar**>(input_copy_ptrs.mutable_data());
      for (int batch = 0; batch < batch_size; ++batch) {
        input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0);
      }
      dev_info.push_back(
          solver->GetDeviceLapackInfo(batch_size, "getrfBatched"));
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->GetrfBatched(n, input_copy_ptrs_base, n, pivots_mat.data(),
                               &dev_info.back(), batch_size),
          done);
    } else {
      // For small batch sizes, we use the non-batched interface from cuSolver,
      // which is much faster for large matrices.
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Getrf(n, n, &input_copy_reshaped(batch, 0, 0), n,
                          &pivots_mat(batch, 0), &dev_info.back()(batch)),
            done);
      }
    }

    // 2. Make a transposed copy of the right-hand sides. This is necessary
    // because cuBLAS assumes column-major storage while TensorFlow uses
    // row-major.
    TensorShape transposed_rhs_shape(rhs.shape());
    transposed_rhs_shape.RemoveLastDims(2);
    transposed_rhs_shape.AddDim(nrhs);
    transposed_rhs_shape.AddDim(n);
    Tensor transposed_rhs;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<Scalar>::value,
                                       transposed_rhs_shape, &transposed_rhs),
        done);
    if (nrhs > 1) {
      OP_REQUIRES_OK_ASYNC(
          context, DoMatrixTranspose(device, rhs, &transposed_rhs), done);
    } else {
      // A single right-hand side is a column vector, whose transpose has the
      // same memory layout, so a plain copy suffices.
      device.memcpy(transposed_rhs.flat<Scalar>().data(),
                    rhs.flat<Scalar>().data(),
                    rhs.NumElements() * sizeof(Scalar));
    }
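
    // At this point, in cuBLAS' column-major view, input_copy holds (and
    // step 1 has factorized) A if adjoint_ is true, or A^T if adjoint_ is
    // false, while transposed_rhs holds B.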

    // 3. Solve op(A) X = B (in column-major form).
    // We use a trick here: if adjoint_ is true, we converted A to column-major
    // form above. If adjoint_ is false, we leave A in row-major form and use
    // trans_a = CUBLAS_OP_T to effectively transform it to column-major on the
    // fly. (This means that we actually use the LU factorization of A^T in
    // that case, but that is equally good for solving AX = B.) This way we
    // save an explicit transpose in the more common case of adjoint_ == false.
    auto input_copy_ptr_array = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "input_copy_ptr_array",
        /* on_host */ true);
    auto transposed_rhs_ptr_array = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "transposed_rhs_ptr_array",
        /* on_host */ true);
    auto transposed_rhs_reshaped =
        transposed_rhs.template flat_inner_dims<Scalar, 3>();
    // TODO(rmlarsen): Enable the following branch when I figure out why it
    // causes a segfault.
    if (false && n / batch_size <= 128) {
      dev_info.push_back(
          solver->GetDeviceLapackInfo(batch_size, "GetrsBatched"));
      const Scalar** input_copy_ptrs_base =
          reinterpret_cast<const Scalar**>(input_copy_ptr_array.mutable_data());
      const Scalar** transposed_rhs_ptrs_base =
          reinterpret_cast<const Scalar**>(
              transposed_rhs_ptr_array.mutable_data());
      for (int batch = 0; batch < batch_size; ++batch) {
        input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0);
        transposed_rhs_ptrs_base[batch] = &transposed_rhs_reshaped(batch, 0, 0);
      }
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->GetrsBatched(adjoint_ ? CUBLAS_OP_C : CUBLAS_OP_T, n, nrhs,
                               input_copy_ptrs_base, n, pivots_mat.data(),
                               transposed_rhs_ptrs_base, n, &dev_info.back(),
                               batch_size),
          done);
    } else {
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrs"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Getrs(adjoint_ ? CUBLAS_OP_C : CUBLAS_OP_T, n, nrhs,
                          &input_copy_reshaped(batch, 0, 0), n,
                          &pivots_mat(batch, 0),
                          &transposed_rhs_reshaped(batch, 0, 0), n,
                          &dev_info.back()(batch)),
            done);
      }
    }

    // 4. Transpose X to get the final result in row-major form.
    if (nrhs > 1) {
      OP_REQUIRES_OK_ASYNC(
          context, DoMatrixTranspose(device, transposed_rhs, output), done);
    } else {
      device.memcpy(output->flat<Scalar>().data(),
                    transposed_rhs.flat<Scalar>().data(),
                    transposed_rhs.NumElements() * sizeof(Scalar));
    }

    // Callback for checking info after kernels finish. Also capture the
    // temporary Tensors/ScratchSpace so they don't get deallocated before the
    // kernels run.
    // TODO(rmlarsen): Use move capture once C++14 becomes available.
    auto info_checker = [context, done, dev_info](
                            const Status& status,
                            const std::vector<HostLapackInfo>& host_infos) {
      if (!status.ok() && errors::IsInvalidArgument(status) &&
          !host_infos.empty()) {
        for (int i = 0; i < host_infos[0].size(); ++i) {
          // Match the CPU error message for singular matrices. Otherwise
          // just print the original error message from the status below.
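          // A positive info value i from getrf indicates that entry U(i, i)
          // (1-based) of the LU factorization is exactly zero, i.e. the
          // corresponding input matrix is singular.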
          OP_REQUIRES_ASYNC(context, host_infos[0].data()[i] <= 0,
                            errors::InvalidArgument(kErrMsg), done);
        }
      }
      OP_REQUIRES_OK_ASYNC(context, status, done);
      done();
    };
    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(info_checker));
  }

 private:
  bool adjoint_;
};

REGISTER_LINALG_OP_GPU("MatrixSolve", (MatrixSolveOpGpu<float>), float);
REGISTER_LINALG_OP_GPU("MatrixSolve", (MatrixSolveOpGpu<double>), double);
REGISTER_LINALG_OP_GPU("MatrixSolve", (MatrixSolveOpGpu<complex64>), complex64);
REGISTER_LINALG_OP_GPU("MatrixSolve", (MatrixSolveOpGpu<complex128>),
                       complex128);

#endif  // GOOGLE_CUDA

REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<float>), float);
REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<double>), double);
REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<complex64>), complex64);
REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<complex128>), complex128);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<float>), float);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<double>), double);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<complex64>), complex64);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<complex128>), complex128);

}  // namespace tensorflow