/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/determinant_op.h"

#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {
namespace functor {

typedef Eigen::GpuDevice GPUDevice;

namespace {
__device__ int PermutationOrder(int n, const int* pivots) {
  // Compute the order of the permutation from the number of transpositions
  // encoded in the pivot array, see:
  // http://icl.cs.utk.edu/lapack-forum/viewtopic.php?f=2&t=340
  int order = 0;
  for (int i = 0; i < n - 1; ++i) {
    // Note: Internally, the cuBLAS code uses the Fortran convention (1-based
    // indexing), so we expect pivots[i] == i + 1 for rows that were not
    // moved. For example, with n = 3 and pivots = {3, 2, 3}, row 0 was
    // swapped with row 2 and row 1 stayed in place, so order == 1.
    order += pivots[i] != (i + 1);
  }
  return order;
}

#if defined(__CUDACC__)
// Hack around missing support for complex in NVCC.
template <typename T>
__device__ inline std::complex<T> complex_multiply(const std::complex<T>& a,
                                                   const std::complex<T>& b) {
  const T a_real = Eigen::numext::real(a);
  const T a_imag = Eigen::numext::imag(a);
  const T b_real = Eigen::numext::real(b);
  const T b_imag = Eigen::numext::imag(b);
  return std::complex<T>(a_real * b_real - a_imag * b_imag,
                         a_real * b_imag + a_imag * b_real);
}
__device__ inline complex64 operator*(const complex64& a, const complex64& b) {
  return complex_multiply<float>(a, b);
}
__device__ inline complex64 operator*(const complex64& a, const float& b) {
  return complex64(Eigen::numext::real(a) * b, Eigen::numext::imag(a) * b);
}
__device__ inline complex64 operator/(const complex64& a, const float& b) {
  const float inv_b = 1.0f / b;
  return a * inv_b;
}
__device__ inline complex128 operator*(const complex128& a,
                                       const complex128& b) {
  return complex_multiply<double>(a, b);
}
__device__ inline complex128 operator*(const complex128& a, const double& b) {
  return complex128(Eigen::numext::real(a) * b, Eigen::numext::imag(a) * b);
}
__device__ inline complex128 operator/(const complex128& a, const double& b) {
  const double inv_b = 1.0 / b;
  return a * inv_b;
}
#endif
}  // namespace

// This kernel computes either the determinant or the log_abs_determinant,
// depending on the value of the template parameter. If compute_log_abs_det
// is false, the sign argument is ignored.
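//
// Both variants follow from the same identity: for a pivoted LU
// factorization P * A = L * U with unit-diagonal L,
//   det(A) = (-1)^order * prod_i U(i, i),
// where order is the number of row transpositions encoded in the pivots.
// Accumulating in log space avoids overflow/underflow for large matrices:
//   log|det(A)| = sum_i log|U(i, i)|,
//   sign(det(A)) = (-1)^order * prod_i (U(i, i) / |U(i, i)|).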
template <typename Scalar, bool compute_log_abs_det = true>
__global__ void DeterminantFromPivotedLUKernel(int nthreads, int n,
                                               const Scalar* lu_factor,
                                               const int* all_pivots,
                                               Scalar* sign,
                                               Scalar* log_abs_det) {
  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
  const int matrix_size = n * n;
  const int stride = n + 1;
  // We only parallelize over batches here. Performance is not critical,
  // since this cheap O(n) kernel always follows an O(n^3) LU factorization.
  // The main purpose is to avoid having to copy the LU decomposition to
  // host memory.
  CUDA_1D_KERNEL_LOOP(o_idx, nthreads) {
    // Initialize sign to (-1)^order.
    const int order = PermutationOrder(n, all_pivots + o_idx * n);
    Scalar prod_sign = order % 2 ? Scalar(-1) : Scalar(1);
    RealScalar sum_log_abs_det = RealScalar(0);
    int i_idx = matrix_size * o_idx;
    // Walk the diagonal of U with stride n + 1, accumulating the log of the
    // absolute value and the sign of each diagonal entry separately.
    for (int i = 0; i < n; ++i, i_idx += stride) {
      const RealScalar abs_i = Eigen::numext::abs(lu_factor[i_idx]);
      sum_log_abs_det += Eigen::numext::log(abs_i);
      prod_sign = prod_sign * (lu_factor[i_idx] / abs_i);
    }
    // A singular (or overflowed) factor yields a non-finite log; normalize
    // the result to sign == 0 and log_abs_det == +/-infinity.
    if (!Eigen::numext::isfinite(sum_log_abs_det)) {
      prod_sign = Scalar(0);
      sum_log_abs_det = sum_log_abs_det > 0
                            ? -Eigen::numext::log(RealScalar(0))
                            : Eigen::numext::log(RealScalar(0));
    }
    if (compute_log_abs_det) {
      sign[o_idx] = prod_sign;
      log_abs_det[o_idx] = Scalar(sum_log_abs_det);
    } else {
      log_abs_det[o_idx] = prod_sign * Eigen::numext::exp(sum_log_abs_det);
    }
  }
}

template <typename Scalar>
struct DeterminantFromPivotedLUFunctor<GPUDevice, Scalar> {
  void operator()(const GPUDevice& device,
                  typename TTypes<Scalar, 3>::ConstTensor lu_factor,
                  const int* pivots, typename TTypes<Scalar, 1>::Tensor output,
                  int* info) {
    const int64 num_matrices = output.size();
    const int64 n = lu_factor.dimension(2);
    CudaLaunchConfig config = GetCudaLaunchConfig(num_matrices, device);
    DeterminantFromPivotedLUKernel<Scalar, /*compute_log_abs_det=*/false>
        <<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
            config.virtual_thread_count, n, lu_factor.data(), pivots, nullptr,
            output.data());
  }
};

template struct DeterminantFromPivotedLUFunctor<GPUDevice, float>;
template struct DeterminantFromPivotedLUFunctor<GPUDevice, double>;
template struct DeterminantFromPivotedLUFunctor<GPUDevice, complex64>;
template struct DeterminantFromPivotedLUFunctor<GPUDevice, complex128>;

template <typename Scalar>
struct LogDeterminantFromPivotedLUFunctor<GPUDevice, Scalar> {
  void operator()(const GPUDevice& device,
                  typename TTypes<Scalar, 3>::ConstTensor lu_factor,
                  const int* pivots, typename TTypes<Scalar, 1>::Tensor sign,
                  typename TTypes<Scalar, 1>::Tensor log_abs_det) {
    const int64 num_matrices = sign.size();
    const int64 n = lu_factor.dimension(2);
    CudaLaunchConfig config = GetCudaLaunchConfig(num_matrices, device);
    DeterminantFromPivotedLUKernel<Scalar, /*compute_log_abs_det=*/true>
        <<<config.block_count, config.thread_per_block, 0, device.stream()>>>(
            config.virtual_thread_count, n, lu_factor.data(), pivots,
            sign.data(), log_abs_det.data());
  }
};

template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, float>;
template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, double>;
template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, complex64>;
template struct LogDeterminantFromPivotedLUFunctor<GPUDevice, complex128>;
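// A minimal usage sketch (hypothetical caller; the variable names below are
// made up for illustration): the determinant ops first compute a partially
// pivoted LU factorization of each matrix in the batch on the device, then
// invoke one of the functors above on the resulting factors and pivots:
//
//   functor::LogDeterminantFromPivotedLUFunctor<GPUDevice, float> log_det;
//   log_det(context->eigen_device<GPUDevice>(), lu_factor_tensor,
//           pivots_ptr, sign_tensor, log_abs_det_tensor);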

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA