1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/linalg_ops.cc. 17 // 18 // This header file is used by the individual qr_*op*.cc files for registering 19 // individual kernels. A separate file is used for each instantiated kernel to 20 // improve compilation times. 21 #include <algorithm> 22 #include <numeric> 23 24 #if GOOGLE_CUDA 25 #define EIGEN_USE_GPU 26 #endif 27 28 #include "third_party/eigen3/Eigen/QR" 29 #include "tensorflow/core/framework/kernel_def_builder.h" 30 #include "tensorflow/core/framework/op_kernel.h" 31 #include "tensorflow/core/framework/tensor.h" 32 #include "tensorflow/core/framework/tensor_shape.h" 33 #include "tensorflow/core/kernels/linalg_ops_common.h" 34 #include "tensorflow/core/lib/core/errors.h" 35 #include "tensorflow/core/platform/logging.h" 36 #include "tensorflow/core/platform/macros.h" 37 #include "tensorflow/core/platform/types.h" 38 39 #if GOOGLE_CUDA 40 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 41 #include "tensorflow/core/kernels/cuda_solvers.h" 42 #include "tensorflow/core/kernels/cwise_ops.h" 43 #include "tensorflow/core/kernels/eye_functor.h" 44 #include "tensorflow/core/kernels/matrix_band_part_op.h" 45 #include "tensorflow/core/kernels/transpose_functor.h" 46 #endif 47 48 namespace tensorflow { 49 50 template <class Scalar> 51 class QrOp : public LinearAlgebraOp<Scalar> { 52 public: 53 typedef LinearAlgebraOp<Scalar> Base; 54 55 explicit QrOp(OpKernelConstruction* context) : Base(context) { 56 OP_REQUIRES_OK(context, context->GetAttr("full_matrices", &full_matrices_)); 57 } 58 59 using TensorShapes = typename Base::TensorShapes; 60 61 void ValidateInputMatrixShapes( 62 OpKernelContext* context, 63 const TensorShapes& input_matrix_shapes) const final { 64 Base::ValidateSingleMatrix(context, input_matrix_shapes); 65 } 66 67 TensorShapes GetOutputMatrixShapes( 68 const TensorShapes& input_matrix_shapes) const final { 69 int64 m = input_matrix_shapes[0].dim_size(0); 70 int64 n = input_matrix_shapes[0].dim_size(1); 71 int64 min_size = std::min(m, n); 72 if (full_matrices_) { 73 return TensorShapes({TensorShape({m, m}), TensorShape({m, n})}); 74 } else { 75 return TensorShapes( 76 {TensorShape({m, min_size}), TensorShape({min_size, n})}); 77 } 78 } 79 80 int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final { 81 double m = static_cast<double>(input_matrix_shapes[0].dim_size(0)); 82 double n = static_cast<double>(input_matrix_shapes[0].dim_size(1)); 83 double max_size = std::max(m, n); 84 double min_size = std::min(m, n); 85 double cost = 2 * max_size * min_size * min_size - 86 2 * min_size * min_size * min_size / 3.; 87 // TODO(jpoulson): Increase the cost if full_matrices is true in a manner 88 // that reflects the algorithm used for the expansion. 89 return cost >= static_cast<double>(kint64max) ? kint64max 90 : static_cast<int64>(cost); 91 } 92 93 using Matrix = typename Base::Matrix; 94 using MatrixMaps = typename Base::MatrixMaps; 95 using ConstMatrixMap = typename Base::ConstMatrixMap; 96 using ConstMatrixMaps = typename Base::ConstMatrixMaps; 97 98 void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs, 99 MatrixMaps* outputs) final { 100 Eigen::HouseholderQR<Matrix> qr(inputs[0]); 101 const int m = inputs[0].rows(); 102 const int n = inputs[0].cols(); 103 const int min_size = std::min(m, n); 104 105 if (full_matrices_) { 106 outputs->at(0) = qr.householderQ(); 107 outputs->at(1) = qr.matrixQR().template triangularView<Eigen::Upper>(); 108 } else { 109 // TODO(jpoulson): Exploit the fact that Householder transformations can 110 // be expanded faster than they can be applied to an arbitrary matrix 111 // (Cf. LAPACK's DORGQR). 112 Matrix tmp = Matrix::Identity(m, min_size); 113 outputs->at(0) = qr.householderQ() * tmp; 114 auto qr_top = qr.matrixQR().block(0, 0, min_size, n); 115 outputs->at(1) = qr_top.template triangularView<Eigen::Upper>(); 116 } 117 } 118 119 private: 120 bool full_matrices_; 121 122 TF_DISALLOW_COPY_AND_ASSIGN(QrOp); 123 }; 124 125 #if GOOGLE_CUDA 126 127 typedef Eigen::GpuDevice GPUDevice; 128 129 template <class Scalar> 130 class QrOpGpu : public AsyncOpKernel { 131 public: 132 explicit QrOpGpu(OpKernelConstruction* context) : AsyncOpKernel(context) { 133 OP_REQUIRES_OK(context, context->GetAttr("full_matrices", &full_matrices_)); 134 } 135 136 void ComputeAsync(OpKernelContext* context, DoneCallback done) final { 137 const Tensor& input = context->input(0); 138 const int ndims = input.dims(); 139 const int64 m = input.dim_size(ndims - 2); 140 const int64 n = input.dim_size(ndims - 1); 141 const int64 min_size = std::min(m, n); 142 const int64 batch_size = 143 input.template flat_inner_dims<Scalar, 3>().dimension(0); 144 145 // Validate inputs. 146 OP_REQUIRES_ASYNC( 147 context, ndims >= 2, 148 errors::InvalidArgument("Input must have rank >= 2, got ", ndims), 149 done); 150 151 // Allocate output. 152 // If full_matrices_ is true then Q is m x m and R is m x n. 153 // Otherwise, Q is m x min(m, n), and R is min(m, n) x n. 154 Tensor* q; 155 TensorShape q_shape = input.shape(); 156 q_shape.set_dim(ndims - 1, full_matrices_ ? m : min_size); 157 OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, q_shape, &q), 158 done); 159 Tensor* r; 160 TensorShape r_shape = input.shape(); 161 r_shape.set_dim(ndims - 2, full_matrices_ ? m : min_size); 162 OP_REQUIRES_OK_ASYNC(context, context->allocate_output(1, r_shape, &r), 163 done); 164 165 if (input.NumElements() == 0) { 166 done(); 167 return; 168 } 169 170 // TODO(rmlarsen): Convert to std::make_unique when available. 171 std::unique_ptr<CudaSolver> solver(new CudaSolver(context)); 172 173 // Allocate temporaries. 174 Tensor input_transposed; 175 TensorShape transposed_shape = input.shape(); 176 transposed_shape.set_dim(ndims - 2, input.dim_size(ndims - 1)); 177 transposed_shape.set_dim(ndims - 1, input.dim_size(ndims - 2)); 178 179 OP_REQUIRES_OK_ASYNC( 180 context, 181 solver->allocate_scoped_tensor(DataTypeToEnum<Scalar>::value, 182 transposed_shape, &input_transposed), 183 done); 184 185 Tensor tau; 186 OP_REQUIRES_OK_ASYNC(context, 187 solver->allocate_scoped_tensor( 188 DataTypeToEnum<Scalar>::value, 189 TensorShape({batch_size, min_size}), &tau), 190 done); 191 192 // Transpose input, since cuSolver uses column-major, while TensorFlow uses 193 // row-major storage. 194 const GPUDevice& device = context->eigen_device<GPUDevice>(); 195 OP_REQUIRES_OK_ASYNC( 196 context, DoMatrixTranspose(device, input, &input_transposed), done); 197 198 // Compute QR decomposition in-place in input_transposed. 199 std::vector<DeviceLapackInfo> dev_info; 200 dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "geqrf")); 201 auto input_transposed_reshaped = 202 input_transposed.flat_inner_dims<Scalar, 3>(); 203 auto tau_matrix = tau.matrix<Scalar>(); 204 auto r_reshaped = r->flat_inner_dims<Scalar, 3>(); 205 for (int batch = 0; batch < batch_size; ++batch) { 206 OP_REQUIRES_OK_ASYNC( 207 context, 208 solver->Geqrf(m, n, &input_transposed_reshaped(batch, 0, 0), m, 209 &tau_matrix(batch, 0), 210 dev_info.back().mutable_data() + batch), 211 done); 212 } 213 214 // Generate R. R is equal to the upper triangle of the decomposition 215 // stored in input_transposed. Crop, transpose (to get back to row-major) 216 // and copy it to the output buffer. 217 if (full_matrices_ || m == n) { 218 OP_REQUIRES_OK_ASYNC( 219 context, DoMatrixTranspose(device, input_transposed, r), done); 220 } else { 221 const Scalar alpha(1); 222 const Scalar beta(0); 223 const Scalar* dummy = nullptr; 224 for (int batch = 0; batch < batch_size; ++batch) { 225 OP_REQUIRES_OK_ASYNC( 226 context, 227 solver->Geam(CUBLAS_OP_T, CUBLAS_OP_N, n, 228 full_matrices_ ? m : min_size, &alpha, 229 &input_transposed_reshaped(batch, 0, 0), m, &beta, 230 dummy, n, &r_reshaped(batch, 0, 0), n), 231 done); 232 } 233 } 234 // Extract the upper triangle of r (i.e. zero out the strictly lower 235 // triangle). 236 functor::MatrixBandPartFunctor<GPUDevice, Scalar> band_part; 237 auto r_reshaped_const = 238 const_cast<const Tensor*>(r)->flat_inner_dims<Scalar, 3>(); 239 band_part(context, device, 0 /* num_lower_diags */, 240 -1 /* num_upper_diags */, r_reshaped_const, r_reshaped); 241 242 // Generate Q from the decomposition in input_transposed. 243 if (m != n && (full_matrices_ || m < n)) { 244 // Generate full m x m matrix Q by computing the product Q^T * I, 245 // where the transpose is to get back to row-major form. 246 // In the complex case we actually form Q^H * I and conjugate it 247 // to get Q in row-major form. 248 functor::EyeFunctor<GPUDevice, Scalar> eye; 249 auto q_reshaped = q->flat_inner_dims<Scalar, 3>(); 250 eye(device, q_reshaped); 251 for (int batch = 0; batch < batch_size; ++batch) { 252 // Notice: It appears that Unmqr does not write a zero into *info upon 253 // success (probably a bug), so we simply re-use the info array already 254 // zeroed by Geqrf above. 255 OP_REQUIRES_OK_ASYNC( 256 context, 257 solver->Unmqr(CUBLAS_SIDE_LEFT, CublasAdjointOp<Scalar>(), m, m, 258 min_size, &input_transposed_reshaped(batch, 0, 0), m, 259 &tau_matrix(batch, 0), &q_reshaped(batch, 0, 0), m, 260 dev_info.back().mutable_data() + batch), 261 done); 262 } 263 if (Eigen::NumTraits<Scalar>::IsComplex) { 264 functor::UnaryFunctor<GPUDevice, functor::conj<Scalar>> conj; 265 conj(device, q->flat<Scalar>() /*out*/, 266 const_cast<const Tensor*>(q)->flat<Scalar>() /*in*/); 267 } 268 } else { 269 // Generate m x n matrix Q. In this case we can use the more efficient 270 // algorithm in Ungqr to generate Q in place. 271 dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "orgqr")); 272 for (int batch = 0; batch < batch_size; ++batch) { 273 OP_REQUIRES_OK_ASYNC( 274 context, 275 solver->Ungqr( 276 m, n, min_size, &input_transposed_reshaped(batch, 0, 0), m, 277 &tau_matrix(batch, 0), dev_info.back().mutable_data() + batch), 278 done); 279 } 280 OP_REQUIRES_OK_ASYNC( 281 context, DoMatrixTranspose(device, input_transposed, q), done); 282 } 283 284 // Asynchronously check return status from cuSolver kernels. 285 CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info, 286 std::move(done)); 287 } 288 289 private: 290 bool full_matrices_; 291 292 TF_DISALLOW_COPY_AND_ASSIGN(QrOpGpu); 293 }; 294 295 #endif 296 297 } // namespace tensorflow 298