/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "third_party/eigen3/Eigen/Cholesky"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/matrix_band_part_op.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif

namespace tensorflow {

static const char kErrMsg[] =
    "Cholesky decomposition was not successful. The input might not be valid.";

template <class Scalar>
class CholeskyOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit CholeskyOp(OpKernelConstruction* context) : Base(context) {}

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    const ConstMatrixMap& input = inputs[0];
    if (input.rows() == 0) {
      // If X is an empty matrix (0 rows, 0 columns), X * X' == X.
      // Therefore, we return X.
      return;
    }
    // Perform the actual LL^T Cholesky decomposition. This will only use
    // the lower triangular part of the input by default. The upper
    // triangular part of the matrix will not be read.
    Eigen::LLT<
        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
        llt_decomposition(input);

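    // Eigen::LLT reports Eigen::NumericalIssue through info() when the input
    // is not (numerically) symmetric positive definite.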
    OP_REQUIRES(context, llt_decomposition.info() == Eigen::Success,
                errors::InvalidArgument(kErrMsg));

    // Output the lower triangular in a dense form.
    outputs->at(0) = llt_decomposition.matrixL();
  }
};

#if GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;

namespace functor {
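// Forward declarations of the GPU specializations of MatrixBandPartFunctor.
// The definitions are compiled in the corresponding .cu.cc translation unit;
// this file only needs to reference them.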
#define DECLARE_GPU_SPEC(T)                                            \
  template <>                                                          \
  struct MatrixBandPartFunctor<GPUDevice, T> {                         \
    void operator()(OpKernelContext* context, const GPUDevice& device, \
                    int num_lower_diags, int num_upper_diags,          \
                    typename TTypes<T, 3>::ConstTensor input,          \
                    typename TTypes<T, 3>::Tensor output);             \
  };                                                                   \
  extern template struct MatrixBandPartFunctor<GPUDevice, T>;

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC);
}  // namespace functor

template <class Scalar>
class CholeskyOpGpu : public AsyncOpKernel {
 public:
  explicit CholeskyOpGpu(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const int ndims = input.dims();
    // Validate inputs. Check the rank before reading any dimension sizes.
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);
    const int64 n = input.dim_size(ndims - 1);
    OP_REQUIRES_ASYNC(
        context, input.dim_size(ndims - 2) == n,
        errors::InvalidArgument("Input matrices must be square, got ",
                                input.dim_size(ndims - 2), " != ", n),
        done);

    if (input.NumElements() == 0) {
      // If X is an empty matrix (0 rows, 0 columns), X * X' == X.
      // Therefore, we return X.
      context->set_output(0, input);
      done();
      return;
    }

    // Allocate the output, reusing the input buffer when it can be forwarded.
    // TODO(rmlarsen): Convert to std::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
    Tensor* output;
    OP_REQUIRES_OK_ASYNC(context,
                         context->forward_input_or_allocate_output(
                             {0}, 0, input.shape(), &output),
                         done);

    // Copy the lower triangular part of the input matrices to the output and
    // set the strictly upper triangular part to zero. We use a pre-existing
    // kernel MatrixBandPart to do this for all matrices in the batch at once,
    // before we launch each of the Cholesky factorization kernels in parallel.
    auto input_reshaped = input.template flat_inner_dims<Scalar, 3>();
    auto output_reshaped = output->template flat_inner_dims<Scalar, 3>();
    functor::MatrixBandPartFunctor<GPUDevice, Scalar> band_part;
    band_part(context, context->eigen_device<GPUDevice>(),
              n /* num_lower_diags */, 0 /* num_upper_diags */, input_reshaped,
              output_reshaped);

    // Launch a Cholesky kernel for each matrix in the batch.
    const int64 batch_size = input_reshaped.dimension(0);
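    // Each potrf call records a per-matrix LAPACK status code in one slot of
    // this device-side info array.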
    std::vector<DeviceLapackInfo> dev_info;
    dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrf"));
    // TODO(rmlarsen): Parallelize over batches if it turns out to be
    // an important use case.
    for (int batch = 0; batch < batch_size; ++batch) {
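      // cuSolver assumes column-major (Fortran) storage, so factoring the
      // "upper" triangle of the row-major output is equivalent to factoring
      // its lower triangle.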
      OP_REQUIRES_OK_ASYNC(context,
                           solver->Potrf(CUBLAS_FILL_MODE_UPPER, n,
                                         &output_reshaped(batch, 0, 0), n,
                                         &dev_info.back()(batch)),
                           done);
    }

    // Register callback to check info after kernels finish.
    auto info_checker = [context, done](
                            const Status& status,
                            const std::vector<HostLapackInfo>& /* unused */) {
      OP_REQUIRES_ASYNC(context, status.ok(), errors::InvalidArgument(kErrMsg),
                        done);
      done();
    };
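    // Ownership of the solver is transferred so that it (and the scratch
    // buffers it owns) stays alive until the asynchronous info check runs.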
    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(info_checker));
  }
};

REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<float>), float);
REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<double>), double);
REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<complex64>), complex64);
REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<complex128>), complex128);

#endif  // GOOGLE_CUDA

REGISTER_LINALG_OP("Cholesky", (CholeskyOp<float>), float);
REGISTER_LINALG_OP("Cholesky", (CholeskyOp<double>), double);
REGISTER_LINALG_OP("Cholesky", (CholeskyOp<complex64>), complex64);
REGISTER_LINALG_OP("Cholesky", (CholeskyOp<complex128>), complex128);
REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<float>), float);
REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<double>), double);

}  // namespace tensorflow