1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/linalg_ops.cc. 17 18 #if GOOGLE_CUDA 19 #define EIGEN_USE_GPU 20 #endif // GOOGLE_CUDA 21 22 #include "third_party/eigen3/Eigen/Cholesky" 23 #include "third_party/eigen3/Eigen/Core" 24 #include "tensorflow/core/framework/kernel_def_builder.h" 25 #include "tensorflow/core/framework/op_kernel.h" 26 #include "tensorflow/core/framework/register_types.h" 27 #include "tensorflow/core/framework/tensor_shape.h" 28 #include "tensorflow/core/kernels/linalg_ops_common.h" 29 #include "tensorflow/core/lib/core/errors.h" 30 #include "tensorflow/core/platform/logging.h" 31 #include "tensorflow/core/platform/types.h" 32 33 #if GOOGLE_CUDA 34 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 35 #include "tensorflow/core/kernels/cuda_solvers.h" 36 #include "tensorflow/core/kernels/matrix_band_part_op.h" 37 #include "tensorflow/core/platform/stream_executor.h" 38 #endif 39 40 namespace tensorflow { 41 42 static const char kErrMsg[] = 43 "Cholesky decomposition was not successful. The input might not be valid."; 44 45 template <class Scalar> 46 class CholeskyOp : public LinearAlgebraOp<Scalar> { 47 public: 48 INHERIT_LINALG_TYPEDEFS(Scalar); 49 50 explicit CholeskyOp(OpKernelConstruction* context) : Base(context) {} 51 52 void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs, 53 MatrixMaps* outputs) final { 54 const ConstMatrixMap& input = inputs[0]; 55 if (input.rows() == 0) { 56 // If X is an empty matrix (0 rows, 0 col), X * X' == X. 57 // Therefore, we return X. 58 return; 59 } 60 // Perform the actual LL^T Cholesky decomposition. This will only use 61 // the lower triangular part of data_in by default. The upper triangular 62 // part of the matrix will not be read. 63 Eigen::LLT< 64 Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> 65 llt_decomposition(input); 66 67 OP_REQUIRES(context, llt_decomposition.info() == Eigen::Success, 68 errors::InvalidArgument(kErrMsg)); 69 70 // Output the lower triangular in a dense form. 71 outputs->at(0) = llt_decomposition.matrixL(); 72 } 73 }; 74 75 #if GOOGLE_CUDA 76 typedef Eigen::GpuDevice GPUDevice; 77 78 namespace functor { 79 #define DECLARE_GPU_SPEC(T) \ 80 template <> \ 81 struct MatrixBandPartFunctor<GPUDevice, T> { \ 82 void operator()(OpKernelContext* context, const GPUDevice& device, \ 83 int num_upper_diags, int num_lower_diags, \ 84 typename TTypes<T, 3>::ConstTensor input, \ 85 typename TTypes<T, 3>::Tensor output); \ 86 }; \ 87 extern template struct MatrixBandPartFunctor<GPUDevice, T>; 88 89 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); 90 TF_CALL_complex64(DECLARE_GPU_SPEC); 91 TF_CALL_complex128(DECLARE_GPU_SPEC); 92 } // namespace functor 93 94 template <class Scalar> 95 class CholeskyOpGpu : public AsyncOpKernel { 96 public: 97 explicit CholeskyOpGpu(OpKernelConstruction* context) 98 : AsyncOpKernel(context) {} 99 100 void ComputeAsync(OpKernelContext* context, DoneCallback done) final { 101 const Tensor& input = context->input(0); 102 const int ndims = input.dims(); 103 const int64 n = input.dim_size(ndims - 1); 104 // Validate inputs. 105 OP_REQUIRES_ASYNC( 106 context, ndims >= 2, 107 errors::InvalidArgument("Input must have rank >= 2, got ", ndims), 108 done); 109 OP_REQUIRES_ASYNC( 110 context, input.dim_size(ndims - 2) == n, 111 errors::InvalidArgument("Input matrices must be squares, got", 112 input.dim_size(ndims - 2), " != ", n), 113 done); 114 115 if (input.NumElements() == 0) { 116 // If X is an empty matrix (0 rows, 0 col), X * X' == X. 117 // Therefore, we return X. 118 context->set_output(0, input); 119 done(); 120 return; 121 } 122 123 // Allocate output. 124 // TODO(rmlarsen): Convert to std::make_unique when available. 125 std::unique_ptr<CudaSolver> solver(new CudaSolver(context)); 126 Tensor* output; 127 OP_REQUIRES_OK_ASYNC(context, 128 context->forward_input_or_allocate_output( 129 {0}, 0, input.shape(), &output), 130 done); 131 132 // Copy the lower triangular part of the input matrices to the output and 133 // set the strictly upper triangular part to zero. We use a pre-existing 134 // kernel MatrixBandPart to do this for all matrices in the batch at once, 135 // before we launch each of the Cholesky factorization kernels in paralle. 136 auto input_reshaped = input.template flat_inner_dims<Scalar, 3>(); 137 auto output_reshaped = output->template flat_inner_dims<Scalar, 3>(); 138 functor::MatrixBandPartFunctor<GPUDevice, Scalar> band_part; 139 band_part(context, context->eigen_device<GPUDevice>(), 140 n /* num_lower_diags */, 0 /* num_upper_diags */, input_reshaped, 141 output_reshaped); 142 143 // Launch a Cholesky kernel for each matrix in the batch. 144 const int64 batch_size = input_reshaped.dimension(0); 145 std::vector<DeviceLapackInfo> dev_info; 146 dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrf")); 147 // TODO(rmlarsen): Parallelize over batches if it turns out to be 148 // an important use case. 149 for (int batch = 0; batch < batch_size; ++batch) { 150 OP_REQUIRES_OK_ASYNC(context, 151 solver->Potrf(CUBLAS_FILL_MODE_UPPER, n, 152 &output_reshaped(batch, 0, 0), n, 153 &dev_info.back()(batch)), 154 done); 155 } 156 157 // Register callback to check info after kernels finish. 158 auto info_checker = [context, done]( 159 const Status& status, 160 const std::vector<HostLapackInfo>& /* unused */) { 161 OP_REQUIRES_ASYNC(context, status.ok(), errors::InvalidArgument(kErrMsg), 162 done); 163 done(); 164 }; 165 CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info, 166 std::move(info_checker)); 167 } 168 }; 169 170 REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<float>), float); 171 REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<double>), double); 172 REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<complex64>), complex64); 173 REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<complex128>), complex128); 174 175 #endif // GOOGLE_CUDA 176 177 REGISTER_LINALG_OP("Cholesky", (CholeskyOp<float>), float); 178 REGISTER_LINALG_OP("Cholesky", (CholeskyOp<double>), double); 179 REGISTER_LINALG_OP("Cholesky", (CholeskyOp<complex64>), complex64); 180 REGISTER_LINALG_OP("Cholesky", (CholeskyOp<complex128>), complex128); 181 REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<float>), float); 182 REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<double>), double); 183 184 } // namespace tensorflow 185