/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/linalg_ops.cc.

#if GOOGLE_CUDA

#include <memory>
#include <numeric>
#include <type_traits>
#include <utility>
#include <vector>

#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/cast_op.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

// GPU implementation of the SelfAdjointEigV2 op.
//
// For every innermost [n, n] matrix of input 0, computes the eigenvalues and,
// when the boolean attribute `compute_v` is true, the eigenvectors, using one
// CudaSolver::Heevd call per matrix.
//
// Outputs:
//   0: eigenvalues, shape = input shape with the last dimension removed.
//   1: eigenvectors, shape = input shape when compute_v is true, otherwise a
//      rank-0 placeholder (see the NOTE(review) below).
//
// The kernel is asynchronous. `done` is invoked directly on validation errors
// and for empty inputs. Otherwise it is handed to
// CudaSolver::CheckLapackInfoAndDeleteSolverAsync, which checks the cuSolver
// status codes after the enqueued GPU work.
template <class Scalar>
class SelfAdjointEigV2OpGpu : public AsyncOpKernel {
 public:
  explicit SelfAdjointEigV2OpGpu(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("compute_v", &compute_v_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
    const Tensor& input = context->input(0);
    const int ndims = input.dims();
    OP_REQUIRES_ASYNC(
        context, ndims >= 2,
        errors::InvalidArgument("Input must have rank >= 2, got ", ndims),
        done);
    const int64 n = input.dim_size(ndims - 1);
    OP_REQUIRES_ASYNC(
        context, input.dim_size(ndims - 2) == n,
        errors::InvalidArgument("Input matrices must be squares, got ",
                                input.dim_size(ndims - 2), " != ", n),
        done);
    // Number of [n, n] matrices: all leading dimensions collapsed into one.
    const int64 batch_size =
        input.template flat_inner_dims<Scalar, 3>().dimension(0);

    // Allocate outputs.
    Tensor* eigenvalues;
    TensorShape eigenvalues_shape = input.shape();
    eigenvalues_shape.RemoveLastDims(1);
    OP_REQUIRES_OK_ASYNC(
        context, context->allocate_output(0, eigenvalues_shape, &eigenvalues),
        done);
    Tensor* eigenvectors;
    // NOTE(review): when compute_v is false this allocates a rank-0 shape for
    // output 1. Confirm this agrees with the op's registered shape function
    // and the CPU kernel.
    TensorShape eigenvectors_shape =
        compute_v_ ? input.shape() : TensorShape({});
    OP_REQUIRES_OK_ASYNC(
        context, context->allocate_output(1, eigenvectors_shape, &eigenvectors),
        done);

    // Nothing to decompose; both outputs are already allocated.
    if (input.NumElements() == 0) {
      done();
      return;
    }

    // Allocate workspace.
    // TODO(rmlarsen): Convert to std::make_unique when available.
    std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
    // Heevd writes real-valued eigenvalues. For real Scalar they are written
    // straight into output 0 (eigenvalues_real aliases it). For complex Scalar
    // they go into a scoped real buffer and are cast into output 0 below.
    Tensor eigenvalues_real;
    using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
    if (std::is_same<Scalar, RealScalar>::value) {
      eigenvalues_real = *eigenvalues;
    } else {
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->allocate_scoped_tensor(DataTypeToEnum<RealScalar>::value,
                                         eigenvalues_shape, &eigenvalues_real),
          done);
    }

    // Working matrix that Heevd overwrites. When possible the input buffer is
    // forwarded instead of allocating a new one.
    Tensor input_copy;
    OP_REQUIRES_OK_ASYNC(
        context,
        solver->forward_input_or_allocate_scoped_tensor(
            {0}, DataTypeToEnum<Scalar>::value, input.shape(), &input_copy),
        done);
    // For real symmetric matrices, row-major and column-major are the same. For
    // complex Hermitian, row-major and column-major differ by a conjugation,
    // which is still cheaper than a transpose.
    const GPUDevice& device = context->eigen_device<GPUDevice>();
    if (Eigen::NumTraits<Scalar>::IsComplex) {
      // Write conj(input) into input_copy. If input_copy was forwarded from
      // input, the two share one buffer and this conjugates in place. Writing
      // through input_copy avoids const_cast-ing the const input tensor.
      functor::UnaryFunctor<GPUDevice, functor::conj<Scalar>> conj;
      conj(device, input_copy.flat<Scalar>() /*out*/,
           input.flat<Scalar>() /*in*/);
    } else if (!input.SharesBufferWith(input_copy)) {
      // Real input in a freshly allocated buffer: plain device-to-device copy.
      device.memcpy(input_copy.flat<Scalar>().data(),
                    input.flat<Scalar>().data(),
                    input.NumElements() * sizeof(Scalar));
    }

    // Compute eigen decomposition in-place in input_copy.
    std::vector<DeviceLapackInfo> dev_info;
    dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "heevd"));
    auto input_copy_reshaped = input_copy.flat_inner_dims<Scalar, 3>();
    auto eigenvalues_real_reshaped =
        eigenvalues_real.flat_inner_dims<RealScalar, 2>();
    // The index uses the same 64-bit type as batch_size, so the comparison
    // cannot overflow or truncate.
    for (int64 batch = 0; batch < batch_size; ++batch) {
      OP_REQUIRES_OK_ASYNC(
          context,
          solver->Heevd(compute_v_ ? CUSOLVER_EIG_MODE_VECTOR
                                   : CUSOLVER_EIG_MODE_NOVECTOR,
                        CUBLAS_FILL_MODE_UPPER, n,
                        &input_copy_reshaped(batch, 0, 0), n,
                        &eigenvalues_real_reshaped(batch, 0),
                        dev_info.back().mutable_data() + batch),
          done);
    }

    // Complex case: widen the real eigenvalues into the complex output.
    if (!std::is_same<Scalar, RealScalar>::value) {
      functor::CastFunctor<GPUDevice, Scalar, RealScalar> cast;
      cast(device, eigenvalues->flat<Scalar>(),
           const_cast<const Tensor*>(&eigenvalues_real)->flat<RealScalar>());
    }

    if (compute_v_) {
      // Transpose eigenvectors now stored in input_copy in column-major form to
      // output in row-major form.
      OP_REQUIRES_OK_ASYNC(
          context, DoMatrixTranspose(device, input_copy, eigenvectors), done);
    }

    // Asynchronously check return status from cuSolver kernels.
    CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info,
                                                    std::move(done));
  }

 private:
  // Value of the "compute_v" attribute: whether to also emit eigenvectors.
  bool compute_v_;

  TF_DISALLOW_COPY_AND_ASSIGN(SelfAdjointEigV2OpGpu);
};

#define REGISTER(Scalar)                                                       \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("SelfAdjointEigV2").Device(DEVICE_GPU).TypeConstraint<Scalar>("T"), \
      (SelfAdjointEigV2OpGpu<Scalar>))

REGISTER(float);
REGISTER(double);
REGISTER(complex64);
REGISTER(complex128);

#undef REGISTER

}  // namespace tensorflow

#endif  // GOOGLE_CUDA