/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include <numeric>

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/Eigen/LU"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#endif

namespace tensorflow {

static const char kErrMsg[] = "Input matrix is not invertible.";

template <class Scalar>
class MatrixSolveOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit MatrixSolveOp(OpKernelConstruction* context) : Base(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  void ValidateInputMatrixShapes(
      OpKernelContext* context,
      const TensorShapes& input_matrix_shapes) const final {
    Base::ValidateSquareSolver(context, input_matrix_shapes);
  }

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1),
                                      input_matrix_shapes[1].dim_size(1)})});
  }
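  // An n x n LU factorization costs O(n^3) flops and each of the k
  // triangular solves costs O(n^2), so the total work is roughly
  // n^2 * (n + k), which is the estimate computed below.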
  int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
    double cost = rows * rows * (rows + num_rhss);
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64>(cost);
  }

  bool EnableInputForwarding() const final { return false; }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const ConstMatrixMap& matrix = inputs[0];
    const ConstMatrixMap& rhs = inputs[1];
    if (matrix.rows() == 0 || rhs.cols() == 0) {
      // To be consistent with the MatrixInverse op, we define the solution
      // for an empty set of equations as the empty matrix.
      return;
    }
    Eigen::PartialPivLU<Matrix> lu_decomposition(matrix.rows());
    if (adjoint_) {
      // TODO(rmlarsen): For Eigen 3.2, this creates a temporary copy.
      // Make sure to backport: https://bitbucket.org/eigen/eigen/commits/
      // bd2219a74c96dfe3f6bc2c23588749e36d2d8173
      lu_decomposition.compute(matrix.adjoint());
    } else {
      lu_decomposition.compute(matrix);
    }

    // PartialPivLU cannot give strong guarantees on invertibility,
    // but we can at least guard against exact zero pivots. This can occur
    // as a result of basic user mistakes, such as providing integer-valued
    // matrices that are exactly singular, or due to underflow if this
    // code is run with denormals being flushed to zero.
    const RealScalar min_abs_pivot =
        lu_decomposition.matrixLU().diagonal().cwiseAbs().minCoeff();
    OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
                errors::InvalidArgument(kErrMsg));

    // TODO(rmlarsen): Add check based on condition number estimation.
    // The necessary changes to Eigen are in
    // https://bitbucket.org/eigen/eigen/pull-requests/174/
    // add-matrix-condition-number-estimation/diff
    outputs->at(0) = lu_decomposition.solve(rhs);
  }

 private:
  bool adjoint_;

  TF_DISALLOW_COPY_AND_ASSIGN(MatrixSolveOp);
};

#if GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;

template <class Scalar>
class MatrixSolveOpGpu : public AsyncOpKernel {
 public:
  explicit MatrixSolveOpGpu(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const Tensor& rhs = context->input(1);
    const int ndims = input.dims();
    // Validate inputs. Check the ranks first, before indexing into the
    // shapes below.
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);
    OP_REQUIRES_ASYNC(context, rhs.dims() == ndims,
                      errors::InvalidArgument(
                          "Input and right-hand side must have same rank, got ",
                          ndims, " != ", rhs.dims()),
                      done);
    const int64 n = input.dim_size(ndims - 1);
    const int64 nrhs = rhs.dim_size(ndims - 1);
    OP_REQUIRES_ASYNC(
        context, input.dim_size(ndims - 2) == n,
        errors::InvalidArgument("Input matrices must be square, got ",
                                input.dim_size(ndims - 2), " != ", n),
        done);
    OP_REQUIRES_ASYNC(context, rhs.dim_size(ndims - 2) == n,
                      errors::InvalidArgument(
                          "Input matrix and right-hand side must have the "
                          "same number of rows, got ",
                          n, " != ", rhs.dim_size(ndims - 2)),
                      done);

    // Allocate output.
    Tensor* output;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->forward_input_or_allocate_output({1}, 0, rhs.shape(),
                                                   &output),
        done);

    // To be consistent with the MatrixInverse op, we define the solution for
    // an empty set of equations as the empty matrix.
    if (rhs.NumElements() == 0) {
      done();
      return;
    }

    // TODO(rmlarsen): Convert to std::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));

    // Make a copy of the input for the factorization step, or, if adjoint_
    // is false, try to reuse the input buffer if this op owns it exclusively.
    Tensor input_copy;
    const GPUDevice& device = context->eigen_device<GPUDevice>();
    if (adjoint_) {
      // For the adjoint case, it is simpler to always make a transposed copy
      // up front.
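      // (cuBLAS reads buffers as column-major, so the row-major input is
      // seen as A^T; after this transpose the buffer is seen as A itself,
      // and CUBLAS_OP_C in step 3 below then applies the true adjoint A^H.)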
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->allocate_scoped_tensor(DataTypeToEnum<Scalar>::value,
                                         input.shape(), &input_copy),
          done);
      OP_REQUIRES_OK_ASYNC(context,
                           DoMatrixTranspose(device, input, &input_copy),
                           done);
    } else {
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->forward_input_or_allocate_scoped_tensor(
              {0}, DataTypeToEnum<Scalar>::value, input.shape(), &input_copy),
          done);
      if (!input.SharesBufferWith(input_copy)) {
        device.memcpy(input_copy.flat<Scalar>().data(),
                      input.flat<Scalar>().data(),
                      input.NumElements() * sizeof(Scalar));
      }
    }
    auto input_copy_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
    const int64 batch_size = input_copy_reshaped.dimension(0);

    // Allocate pivots on the device.
    Tensor pivots;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<int>::value,
                                       TensorShape{batch_size, n}, &pivots),
        done);
    auto pivots_mat = pivots.template matrix<int>();

    // 1. Compute the partially pivoted LU factorization(s) of the
    // matrix/matrices.
    std::vector<DeviceLapackInfo> dev_info;
    auto input_copy_ptrs = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "input_copy_ptrs",
        /* on_host */ true);
    const int kMaxMatrixSizeToBatchSizeRatio = 128;
    const bool use_batched_solver =
        n <= kMaxMatrixSizeToBatchSizeRatio * batch_size;
    if (use_batched_solver) {
      // For small matrices or large batch sizes, we use the batched
      // interface from cuBLAS.
      const Scalar** input_copy_ptrs_base =
          reinterpret_cast<const Scalar**>(input_copy_ptrs.mutable_data());
      for (int batch = 0; batch < batch_size; ++batch) {
        input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0);
      }
      dev_info.push_back(
          solver->GetDeviceLapackInfo(batch_size, "getrfBatched"));
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->GetrfBatched(n, input_copy_ptrs_base, n, pivots_mat.data(),
                               &dev_info.back(), batch_size),
          done);
    } else {
      // For small batch sizes or large matrices, we use the non-batched
      // interface from cuSolver, which is much faster for large matrices.
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Getrf(n, n, &input_copy_reshaped(batch, 0, 0), n,
                          &pivots_mat(batch, 0), &dev_info.back()(batch)),
            done);
      }
    }

    // 2. Make a transposed copy of the right-hand sides. This is necessary
    // because cuBLAS assumes column-major storage while TensorFlow uses
    // row-major.
    TensorShape transposed_rhs_shape(rhs.shape());
    transposed_rhs_shape.RemoveLastDims(2);
    transposed_rhs_shape.AddDim(nrhs);
    transposed_rhs_shape.AddDim(n);
    Tensor transposed_rhs;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<Scalar>::value,
                                       transposed_rhs_shape, &transposed_rhs),
        done);
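    // A single right-hand side is an n x 1 matrix whose row-major and
    // column-major layouts coincide, so a plain copy can replace the
    // transpose kernel below.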
    if (nrhs > 1) {
      OP_REQUIRES_OK_ASYNC(
          context, DoMatrixTranspose(device, rhs, &transposed_rhs), done);
    } else {
      device.memcpy(transposed_rhs.flat<Scalar>().data(),
                    rhs.flat<Scalar>().data(),
                    rhs.NumElements() * sizeof(Scalar));
    }

    // 3. Solve op(A) X = B (in column major form).
    // We use a trick here: If adjoint_ is true, we converted A to column
    // major form above. If adjoint_ is false, we leave A in row-major form
    // and use trans_a = CUBLAS_OP_T to effectively transform it to
    // column-major on the fly. (This means that we actually use the
    // LU-factorization of A^T in that case, but that is equally good for
    // solving A X = B.) This way we save an explicit transpose in the more
    // common case of adjoint_ == false.
    auto input_copy_ptr_array = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "input_copy_ptr_array",
        /* on_host */ true);
    auto transposed_rhs_ptr_array = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "transposed_rhs_ptr_array",
        /* on_host */ true);
    auto transposed_rhs_reshaped =
        transposed_rhs.template flat_inner_dims<Scalar, 3>();
    if (use_batched_solver) {
      const Scalar** input_copy_ptrs_base = reinterpret_cast<const Scalar**>(
          input_copy_ptr_array.mutable_data());
      const Scalar** transposed_rhs_ptrs_base =
          reinterpret_cast<const Scalar**>(
              transposed_rhs_ptr_array.mutable_data());
      for (int batch = 0; batch < batch_size; ++batch) {
        input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0);
        transposed_rhs_ptrs_base[batch] =
            &transposed_rhs_reshaped(batch, 0, 0);
      }
      int host_info = 0;
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->GetrsBatched(adjoint_ ? CUBLAS_OP_C : CUBLAS_OP_T, n, nrhs,
                               input_copy_ptrs_base, n, pivots_mat.data(),
                               transposed_rhs_ptrs_base, n, &host_info,
                               batch_size),
          done);
      OP_REQUIRES_ASYNC(
          context, host_info == 0,
          errors::InvalidArgument("The ", -host_info,
                                  "'th argument to cublas*getrsBatched had "
                                  "an illegal value."),
          done);
    } else {
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrs"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Getrs(adjoint_ ? CUBLAS_OP_C : CUBLAS_OP_T, n, nrhs,
                          &input_copy_reshaped(batch, 0, 0), n,
                          &pivots_mat(batch, 0),
                          &transposed_rhs_reshaped(batch, 0, 0), n,
                          &dev_info.back()(batch)),
            done);
      }
    }

    // 4. Transpose X to get the final result in row-major form.
    if (nrhs > 1) {
      OP_REQUIRES_OK_ASYNC(
          context, DoMatrixTranspose(device, transposed_rhs, output), done);
    } else {
      device.memcpy(output->flat<Scalar>().data(),
                    transposed_rhs.flat<Scalar>().data(),
                    transposed_rhs.NumElements() * sizeof(Scalar));
    }

    // Callback for checking info after kernels finish. Also capture the
    // temporary Tensors/ScratchSpace so they don't get deallocated before
    // the kernels run.
    // TODO(rmlarsen): Use move capture once C++14 becomes available.
    auto info_checker = [context, done, dev_info](
                            const Status& status,
                            const std::vector<HostLapackInfo>& host_infos) {
      if (!status.ok() && errors::IsInvalidArgument(status) &&
          !host_infos.empty()) {
        for (int i = 0; i < host_infos[0].size(); ++i) {
          // Match the CPU error message for singular matrices. Otherwise
          // just print the original error message from the status below.
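          // (LAPACK convention: info == i > 0 means pivot U(i,i) of the
          // factorization is exactly zero, i.e. the matrix is singular.)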
          OP_REQUIRES_ASYNC(context, host_infos[0].data()[i] <= 0,
                            errors::InvalidArgument(kErrMsg), done);
        }
      }
      OP_REQUIRES_OK_ASYNC(context, status, done);
      done();
    };
    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(info_checker));
  }

 private:
  bool adjoint_;
};

REGISTER_LINALG_OP_GPU("MatrixSolve", (MatrixSolveOpGpu<float>), float);
REGISTER_LINALG_OP_GPU("MatrixSolve", (MatrixSolveOpGpu<double>), double);
REGISTER_LINALG_OP_GPU("MatrixSolve", (MatrixSolveOpGpu<complex64>), complex64);
REGISTER_LINALG_OP_GPU("MatrixSolve", (MatrixSolveOpGpu<complex128>),
                       complex128);

#endif  // GOOGLE_CUDA

REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<float>), float);
REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<double>), double);
REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<complex64>), complex64);
REGISTER_LINALG_OP("MatrixSolve", (MatrixSolveOp<complex128>), complex128);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<float>), float);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<double>), double);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<complex64>), complex64);
REGISTER_LINALG_OP("BatchMatrixSolve", (MatrixSolveOp<complex128>), complex128);
}  // namespace tensorflow