/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#include <cmath>

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/determinant_op.h"
#endif

#include "third_party/eigen3/Eigen/LU"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/fill_functor.h"
#endif

namespace tensorflow {

// A helper function that computes the sign of the determinant of 'inputs' and
// the log of its absolute value via a partially pivoted LU factorization.
//
// Returns the log of the absolute value of the determinant; the sign of the
// determinant is written to 'sign'.
template <class Scalar>
static typename Eigen::NumTraits<Scalar>::Real SLogDet(
    const Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>& inputs,
    Scalar* sign) {
  using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
  RealScalar log_abs_det = 0;
  *sign = 1;
  // The determinant of an empty matrix is defined to be 1.
  // (https://en.wikipedia.org/wiki/Determinant)
  if (inputs.size() > 0) {
    // Compute the log-determinant through a partially pivoted LU
    // decomposition.
    using Eigen::Dynamic;
    Eigen::PartialPivLU<Eigen::Matrix<Scalar, Dynamic, Dynamic>> lu(inputs);
    Eigen::Matrix<Scalar, Dynamic, Dynamic> LU = lu.matrixLU();
    *sign = lu.permutationP().determinant();
    auto diag = LU.diagonal().array().eval();
    auto abs_diag = diag.cwiseAbs().eval();
    log_abs_det += abs_diag.log().sum();
    *sign *= (diag / abs_diag).prod();
  }
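  // If the result is not finite (e.g. the matrix is singular, so some diagonal
  // entry of U is zero and its log is -infinity), report a zero sign and clamp
  // log_abs_det to +/-infinity.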
  if (!Eigen::numext::isfinite(log_abs_det)) {
    *sign = 0;
    log_abs_det =
        log_abs_det > 0 ? -std::log(RealScalar(0)) : std::log(RealScalar(0));
  }
  return log_abs_det;
}
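
// For reference: with P * A = L * U and L unit lower triangular,
//   det(A) = det(P)^-1 * prod(diag(U)),
// so log|det(A)| = sum_i log|U(i, i)| and
//   sign(det(A)) = det(P) * prod_i sign(U(i, i)),
// which is exactly what SLogDet accumulates above. A minimal usage sketch
// (hypothetical caller, not part of this kernel):
//
//   Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> a(2, 2);
//   a << 4.0, 0.0, 0.0, 0.5;
//   double sign;
//   const double log_abs_det = SLogDet<double>(a, &sign);
//   // sign == 1 and log_abs_det == log(2), since det(a) == 2.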

template <class Scalar>
class LogDeterminantOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit LogDeterminantOp(OpKernelConstruction* context) : Base(context) {}

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shapes) const final {
    return TensorShapes({TensorShape({}), TensorShape({})});
  }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    Scalar sign;
    const RealScalar log_abs_det = SLogDet(
        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>(inputs[0]),
        &sign);

    outputs->at(0)(0, 0) = sign;
    outputs->at(1)(0, 0) = log_abs_det;
  }
};
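
// Callers can recover the determinant itself from the two outputs above as
//   det(A) = sign * exp(log_abs_det),
// which is how DeterminantOp below combines them.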

template <class Scalar>
class DeterminantOp : public LinearAlgebraOp<Scalar> {
 public:
  INHERIT_LINALG_TYPEDEFS(Scalar);

  explicit DeterminantOp(OpKernelConstruction* context) : Base(context) {}

  TensorShapes GetOutputMatrixShapes(
      const TensorShapes& input_matrix_shape) const final {
    return TensorShapes({TensorShape({})});
  }

  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    Scalar sign;
    const RealScalar log_abs_det = SLogDet(
        Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>(inputs[0]),
        &sign);
    outputs->at(0)(0, 0) = sign * std::exp(log_abs_det);
  }
};
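
// Note: computing the determinant as sign * exp(log|det|) accumulates the
// product of the diagonal of U in log space, so intermediate products cannot
// overflow or underflow; only the final exp() can.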

#if GOOGLE_CUDA

typedef Eigen::GpuDevice GPUDevice;

template <class Scalar>
class DeterminantOpGpu : public AsyncOpKernel {
 public:
  explicit DeterminantOpGpu(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const int ndims = input.dims();
    const int64 n = input.dim_size(ndims - 1);
    // Validate inputs.
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);
    OP_REQUIRES_ASYNC(
        context, input.dim_size(ndims - 2) == n,
        errors::InvalidArgument("Input matrices must be square, got ",
                                input.dim_size(ndims - 2), " != ", n),
        done);

    // Allocate output.
    TensorShape out_shape;
    for (int dim = 0; dim < ndims - 2; ++dim) {
      out_shape.AddDim(input.dim_size(dim));
    }
    out_shape.AppendShape(TensorShape({}));
    Tensor* out;
    OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, out_shape, &out),
                         done);

    // By definition, the determinant of an empty matrix is equal to one.
    const GPUDevice& d = context->eigen_device<GPUDevice>();
    if (input.NumElements() == 0) {
      functor::SetOneFunctor<GPUDevice, Scalar> f;
      f(d, out->template flat<Scalar>());
      done();
      return;
    }

    // TODO(rmlarsen): Convert to absl::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));

    // Reuse the input buffer or make a copy for the factorization step,
    // depending on whether this op owns it exclusively.
    Tensor input_copy;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->forward_input_or_allocate_scoped_tensor(
            {0}, DataTypeToEnum<Scalar>::value, input.shape(), &input_copy),
        done);
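    // If forwarding succeeded, input_copy aliases the input buffer and no copy
    // is needed; otherwise it points at freshly allocated scratch that the
    // memcpy below fills in.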
    if (!input.SharesBufferWith(input_copy)) {
      d.memcpy(input_copy.flat<Scalar>().data(), input.flat<Scalar>().data(),
               input.NumElements() * sizeof(Scalar));
    }
    auto input_copy_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
    const int64 batch_size = input_copy_reshaped.dimension(0);

    // Allocate pivots on the device.
    Tensor pivots;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<int>::value,
                                       TensorShape{batch_size, n}, &pivots),
        done);
    auto pivots_mat = pivots.template matrix<int>();

    // Prepare pointer arrays for cuBlas' batch interface.
    // TODO(rmlarsen): Find a way to encode pointer arrays in pinned host memory
    // without the ugly casting.
    auto input_copy_ptrs = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "input_copy_ptrs",
        /* on_host */ true);
    auto output_reshaped = out->template flat_inner_dims<Scalar, 1>();

    // Compute the partially pivoted LU factorization(s) of the matrix/matrices.
    std::vector<DeviceLapackInfo> dev_info;
    if (n / batch_size <= 128) {
      // For small matrices or large batch sizes, we use the batched interface
      // from cuBlas.
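      // For example, a batch of 1000 4x4 matrices (n / batch_size == 0) takes
      // this path, while a single 1024x1024 matrix (n / batch_size == 1024)
      // falls through to the per-matrix cuSolver call below.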
      const Scalar** input_copy_ptrs_base =
          reinterpret_cast<const Scalar**>(input_copy_ptrs.mutable_data());
      for (int batch = 0; batch < batch_size; ++batch) {
        input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0);
      }
      dev_info.push_back(
          solver->GetDeviceLapackInfo(batch_size, "getrfBatched"));
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->GetrfBatched(n, input_copy_ptrs_base, n, pivots_mat.data(),
                               &dev_info.back(), batch_size),
          done);
    } else {
      // For large matrices or small batch sizes we use the non-batched
      // interface from cuSolver, which is much faster for large matrices.
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Getrf(n, n, &input_copy_reshaped(batch, 0, 0), n,
                          &pivots_mat(batch, 0), &dev_info.back()(batch)),
            done);
      }
    }

    // Compute the determinant for each batch as (-1)^s * prod(diag(U)),
    // where s is the number of transpositions in the permutation encoded in
    // pivots and U is the upper triangular factor of the LU factorization,
    // which is written to input_copy by the Getrf{Batched} kernel.
    functor::DeterminantFromPivotedLUFunctor<GPUDevice, Scalar> functor;
    functor(d,
            const_cast<const Tensor*>(&input_copy)
                ->template flat_inner_dims<Scalar, 3>(),
            pivots_mat.data(), output_reshaped, dev_info.back().mutable_data());

    // Register a callback to check the info outputs after the kernels finish.
    auto info_checker = [context, done](
                            const Status& status,
                            const std::vector<HostLapackInfo>& host_infos) {
      if (!status.ok() && errors::IsInvalidArgument(status) &&
          !host_infos.empty()) {
        for (int i = 0; i < host_infos[0].size(); ++i) {
          // It is OK for a matrix to be singular (signaled by info > 0),
          // corresponding to a determinant of zero, but we do want to catch
          // invalid arguments to Getrf{Batched}.
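          // (This mirrors the LAPACK/cuSolver getrf convention: info == 0 means
          // success, info > 0 means U(i, i) is exactly zero, and info < 0 means
          // the i-th argument passed to the routine was invalid.)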
          OP_REQUIRES_ASYNC(
              context, host_infos[0](i) >= 0,
              errors::InvalidArgument("Invalid input argument no. ",
                                      host_infos[0].data()[i],
                                      " for batch index ", i, "."),
              done);
        }
      }
      done();
    };
    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(info_checker));
  }
};

template <class Scalar>
class LogDeterminantOpGpu : public AsyncOpKernel {
 public:
  explicit LogDeterminantOpGpu(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const int ndims = input.dims();
    const int64 n = input.dim_size(ndims - 1);
    // Validate inputs.
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);
    OP_REQUIRES_ASYNC(
        context, input.dim_size(ndims - 2) == n,
        errors::InvalidArgument("Input matrices must be square, got ",
                                input.dim_size(ndims - 2), " != ", n),
        done);

    // Allocate output.
    TensorShape out_shape;
    for (int dim = 0; dim < ndims - 2; ++dim) {
      out_shape.AddDim(input.dim_size(dim));
    }
    out_shape.AppendShape(TensorShape({}));
    Tensor* sign;
    OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, out_shape, &sign),
                         done);
    Tensor* log_abs_det;
    OP_REQUIRES_OK_ASYNC(
        context, context->allocate_output(1, out_shape, &log_abs_det), done);

    // By definition, the determinant of an empty matrix is equal to one.
    const GPUDevice& d = context->eigen_device<GPUDevice>();
    if (input.NumElements() == 0) {
      functor::SetOneFunctor<GPUDevice, Scalar> one_func;
      one_func(d, sign->template flat<Scalar>());
      functor::SetZeroFunctor<GPUDevice, Scalar> zero_func;
      zero_func(d, log_abs_det->template flat<Scalar>());
      done();
      return;
    }

    // TODO(rmlarsen): Convert to absl::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));

    // Reuse the input buffer or make a copy for the factorization step,
    // depending on whether this op owns it exclusively.
    Tensor input_copy;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->forward_input_or_allocate_scoped_tensor(
            {0}, DataTypeToEnum<Scalar>::value, input.shape(), &input_copy),
        done);
    if (!input.SharesBufferWith(input_copy)) {
      d.memcpy(input_copy.flat<Scalar>().data(), input.flat<Scalar>().data(),
               input.NumElements() * sizeof(Scalar));
    }
    auto input_copy_reshaped = input_copy.template flat_inner_dims<Scalar, 3>();
    const int64 batch_size = input_copy_reshaped.dimension(0);

    // Allocate pivots on the device.
    Tensor pivots;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->allocate_scoped_tensor(DataTypeToEnum<int>::value,
                                       TensorShape{batch_size, n}, &pivots),
        done);
    auto pivots_mat = pivots.template matrix<int>();

    // Prepare pointer arrays for cuBlas' batch interface.
    // TODO(rmlarsen): Find a way to encode pointer arrays in pinned host memory
    // without the ugly casting.
    auto input_copy_ptrs = solver->GetScratchSpace<uint8>(
        sizeof(Scalar*) * batch_size, "input_copy_ptrs",
        /* on_host */ true);

    // Compute the partially pivoted LU factorization(s) of the matrix/matrices.
    std::vector<DeviceLapackInfo> dev_info;
    if (n / batch_size <= 128) {
      // For small matrices or large batch sizes, we use the batched interface
      // from cuBlas.
      const Scalar** input_copy_ptrs_base =
          reinterpret_cast<const Scalar**>(input_copy_ptrs.mutable_data());
      for (int batch = 0; batch < batch_size; ++batch) {
        input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0);
      }
      dev_info.push_back(
          solver->GetDeviceLapackInfo(batch_size, "getrfBatched"));
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->GetrfBatched(n, input_copy_ptrs_base, n, pivots_mat.data(),
                               &dev_info.back(), batch_size),
          done);
    } else {
      // For large matrices or small batch sizes we use the non-batched
      // interface from cuSolver, which is much faster for large matrices.
      dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
      for (int batch = 0; batch < batch_size; ++batch) {
        OP_REQUIRES_OK_ASYNC(
            context,
            solver->Getrf(n, n, &input_copy_reshaped(batch, 0, 0), n,
                          &pivots_mat(batch, 0), &dev_info.back()(batch)),
            done);
      }
    }

    auto input_copy_reshaped_const =
        const_cast<const Tensor*>(&input_copy)
            ->template flat_inner_dims<Scalar, 3>();
    auto sign_reshaped = sign->flat<Scalar>();
    auto log_abs_det_reshaped = log_abs_det->flat<Scalar>();
    // Compute the determinant for each batch as (-1)^s * prod(diag(U)),
    // where s is the number of transpositions in the permutation encoded in
    // pivots and U is the upper triangular factor of the LU factorization,
    // which is written to input_copy by the Getrf{Batched} kernel.
    functor::LogDeterminantFromPivotedLUFunctor<GPUDevice, Scalar> functor;
    functor(d, input_copy_reshaped_const, pivots_mat.data(), sign_reshaped,
            log_abs_det_reshaped);

    // Register a callback to check the info outputs after the kernels finish.
    auto info_checker = [context, done](
                            const Status& status,
                            const std::vector<HostLapackInfo>& host_infos) {
      if (!status.ok() && errors::IsInvalidArgument(status) &&
          !host_infos.empty()) {
        for (int i = 0; i < host_infos[0].size(); ++i) {
          // It is OK for a matrix to be singular (signaled by info > 0),
          // corresponding to a determinant of zero, but we do want to catch
          // invalid arguments to Getrf{Batched}.
          OP_REQUIRES_ASYNC(
              context, host_infos[0](i) >= 0,
              errors::InvalidArgument("Invalid input argument no. ",
                                      host_infos[0].data()[i],
                                      " for batch index ", i, "."),
              done);
        }
      }
      done();
    };
    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(info_checker));
  }
};

REGISTER_LINALG_OP_GPU("MatrixDeterminant", (DeterminantOpGpu<float>), float);
REGISTER_LINALG_OP_GPU("MatrixDeterminant", (DeterminantOpGpu<double>), double);
REGISTER_LINALG_OP_GPU("MatrixDeterminant", (DeterminantOpGpu<complex64>),
                       complex64);
REGISTER_LINALG_OP_GPU("MatrixDeterminant", (DeterminantOpGpu<complex128>),
                       complex128);

REGISTER_LINALG_OP_GPU("LogMatrixDeterminant", (LogDeterminantOpGpu<float>),
                       float);
REGISTER_LINALG_OP_GPU("LogMatrixDeterminant", (LogDeterminantOpGpu<double>),
                       double);
REGISTER_LINALG_OP_GPU("LogMatrixDeterminant", (LogDeterminantOpGpu<complex64>),
                       complex64);
REGISTER_LINALG_OP_GPU("LogMatrixDeterminant",
                       (LogDeterminantOpGpu<complex128>), complex128);
#endif  // GOOGLE_CUDA

REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<float>), float);
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<double>), double);
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<complex64>), complex64);
REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<complex128>),
                   complex128);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<float>), float);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<double>), double);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<complex64>),
                   complex64);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<complex128>),
                   complex128);

REGISTER_LINALG_OP("LogMatrixDeterminant", (LogDeterminantOp<float>), float);
REGISTER_LINALG_OP("LogMatrixDeterminant", (LogDeterminantOp<double>), double);
REGISTER_LINALG_OP("LogMatrixDeterminant", (LogDeterminantOp<complex64>),
                   complex64);
REGISTER_LINALG_OP("LogMatrixDeterminant", (LogDeterminantOp<complex128>),
                   complex128);
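
// At the Python level these kernels back tf.linalg.det ("MatrixDeterminant")
// and tf.linalg.slogdet ("LogMatrixDeterminant"); "BatchMatrixDeterminant" is
// a deprecated alias kept for backward compatibility.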
}  // namespace tensorflow