/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/determinant_op.h"

#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {
namespace functor {

typedef Eigen::GpuDevice GPUDevice;
namespace {
__device__ int PermutationOrder(int n, const int* pivots) {
  // Compute the order of the permutation from the number of transpositions
  // encoded in the pivot array, see:
  // http://icl.cs.utk.edu/lapack-forum/viewtopic.php?f=2&t=340
  int order = 0;
  for (int i = 0; i < n - 1; ++i) {
    // Notice: Internally, the cuBlas code uses Fortran convention (1-based)
    // indexing so we expect pivots[i] == i + 1 for rows that were not moved.
    order += pivots[i] != (i + 1);
  }
  return order;
}
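// Illustrative example (values chosen for explanation, not taken from any
// particular factorization): with n = 3 and pivots = {3, 2, 3}, row 1 was
// swapped with row 3 while row 2 stayed in place, so exactly one entry
// differs from its 1-based index and order == 1. The permutation therefore
// contributes a factor (-1)^1 = -1 to the determinant.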

#if defined(__CUDACC__)
// Hack around missing support for complex in NVCC.
template <typename T>
__device__ inline std::complex<T> complex_multiply(const std::complex<T>& a,
                                                   const std::complex<T>& b) {
  const T a_real = Eigen::numext::real(a);
  const T a_imag = Eigen::numext::imag(a);
  const T b_real = Eigen::numext::real(b);
  const T b_imag = Eigen::numext::imag(b);
  return std::complex<T>(a_real * b_real - a_imag * b_imag,
                         a_real * b_imag + a_imag * b_real);
}
__device__ inline complex64 operator*(const complex64& a, const complex64& b) {
  return complex_multiply<float>(a, b);
}
__device__ inline complex64 operator*(const complex64& a, const float& b) {
  return complex64(Eigen::numext::real(a) * b, Eigen::numext::imag(a) * b);
}
__device__ inline complex64 operator/(const complex64& a, const float& b) {
  const float inv_b = 1.0f / b;
  return a * inv_b;
}
__device__ inline complex128 operator*(const complex128& a,
                                       const complex128& b) {
  return complex_multiply<double>(a, b);
}
__device__ inline complex128 operator*(const complex128& a, const double& b) {
  return complex128(Eigen::numext::real(a) * b, Eigen::numext::imag(a) * b);
}
__device__ inline complex128 operator/(const complex128& a, const double& b) {
  const double inv_b = 1.0 / b;
  return a * inv_b;
}
#endif
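// The device-side operator overloads above let the templated kernel below use
// '*' and '/' uniformly for both real and complex Scalar types when compiled
// with NVCC.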
}  // namespace

// This kernel computes either determinant or log_abs_determinant, depending
// on the value of the template parameter. If compute_log_abs_det is false,
// the sign argument is ignored.
template <typename Scalar, bool compute_log_abs_det = true>
__global__ void DeterminantFromPivotedLUKernel(int nthreads, int n,
                                               const Scalar* lu_factor,
                                               const int* all_pivots,
                                               Scalar* sign,
                                               Scalar* log_abs_det) {
  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
  const int matrix_size = n * n;
  const int stride = n + 1;
  // We only parallelize over batches here. Performance is not critical,
  // since this cheap O(n) kernel always follows an O(n^3) LU factorization.
  // The main purpose is to avoid having to copy the LU decomposition to
  // host memory.
  CUDA_1D_KERNEL_LOOP(o_idx, nthreads) {
    // Initialize sign to (-1)^order.
    const int order = PermutationOrder(n, all_pivots + o_idx * n);
    Scalar prod_sign = order % 2 ? Scalar(-1) : Scalar(1);
    RealScalar sum_log_abs_det = RealScalar(0);
    int i_idx = matrix_size * o_idx;
    for (int i = 0; i < n; ++i, i_idx += stride) {
      const RealScalar abs_i = Eigen::numext::abs(lu_factor[i_idx]);
      sum_log_abs_det += Eigen::numext::log(abs_i);
      prod_sign = prod_sign * (lu_factor[i_idx] / abs_i);
    }
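    // A non-finite sum arises from a zero pivot (log(0) == -infinity) or from
    // infinite/NaN entries in the factor. In that case the accumulated sign is
    // not meaningful, so report a zero sign and clamp log|det| to the
    // corresponding signed infinity.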
    if (!Eigen::numext::isfinite(sum_log_abs_det)) {
      prod_sign = Scalar(0);
      sum_log_abs_det = sum_log_abs_det > 0 ? -Eigen::numext::log(RealScalar(0))
                                            : Eigen::numext::log(RealScalar(0));
    }
    if (compute_log_abs_det) {
      sign[o_idx] = prod_sign;
      log_abs_det[o_idx] = Scalar(sum_log_abs_det);
    } else {
      log_abs_det[o_idx] = prod_sign * Eigen::numext::exp(sum_log_abs_det);
    }
  }
}
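// Sketch of the math implemented by the kernel above (for reference only):
// given the pivoted LU factorization P * A = L * U with unit-diagonal L,
//   det(A) = det(P)^{-1} * det(L) * det(U) = (-1)^order * prod_i U(i, i),
// so
//   log|det(A)|   = sum_i log|U(i, i)|
//   sign(det(A))  = (-1)^order * prod_i U(i, i) / |U(i, i)|,
// which is exactly what the kernel accumulates from the diagonal entries of
// lu_factor and the pivot array.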

template <typename Scalar>
struct DeterminantFromPivotedLUFunctor<GPUDevice, Scalar> {
  void operator()(const GPUDevice& device,
                  typename TTypes<Scalar, 3>::ConstTensor lu_factor,
                  const int* pivots, typename TTypes<Scalar, 1>::Tensor output,
                  int* info) {
    const int64 num_matrices = output.size();
    const int64 n = lu_factor.dimension(2);
    CudaLaunchConfig config = GetCudaLaunchConfig(num_matrices, device);
    DeterminantFromPivotedLUKernel<Scalar, /*compute_log_abs_det=*/false>
        <<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
            config.virtual_thread_count, n, lu_factor.data(), pivots, nullptr,
            output.data());
  }
};

template struct DeterminantFromPivotedLUFunctor<GPUDevice, float>;
template struct DeterminantFromPivotedLUFunctor<GPUDevice, double>;
template struct DeterminantFromPivotedLUFunctor<GPUDevice, complex64>;
template struct DeterminantFromPivotedLUFunctor<GPUDevice, complex128>;

template <typename Scalar>
struct LogDeterminantFromPivotedLUFunctor<GPUDevice, Scalar> {
  void operator()(const GPUDevice& device,
                  typename TTypes<Scalar, 3>::ConstTensor lu_factor,
                  const int* pivots, typename TTypes<Scalar, 1>::Tensor sign,
                  typename TTypes<Scalar, 1>::Tensor log_abs_det) {
    const int64 num_matrices = sign.size();
    const int64 n = lu_factor.dimension(2);
    CudaLaunchConfig config = GetCudaLaunchConfig(num_matrices, device);
    DeterminantFromPivotedLUKernel<Scalar, /*compute_log_abs_det=*/true>
        <<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
            config.virtual_thread_count, n, lu_factor.data(), pivots,
            sign.data(), log_abs_det.data());
  }
};

template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, float>;
template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, double>;
template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, complex64>;
template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, complex128>;

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA