/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

// This header declares the class CudaSolver, which contains wrappers of linear
// algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow
// kernels.

#ifdef GOOGLE_CUDA

#include <functional>
#include <vector>

#include "cuda/include/cublas_v2.h"
#include "cuda/include/cusolverDn.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace tensorflow {

// Type traits to get CUDA complex types from std::complex<T>.
template <typename T>
struct CUDAComplexT {
  typedef T type;
};
template <>
struct CUDAComplexT<std::complex<float>> {
  typedef cuComplex type;
};
template <>
struct CUDAComplexT<std::complex<double>> {
  typedef cuDoubleComplex type;
};
// Converts pointers of std::complex<> to pointers of
// cuComplex/cuDoubleComplex. No type conversion for non-complex types.
template <typename T>
inline const typename CUDAComplexT<T>::type* CUDAComplex(const T* p) {
  return reinterpret_cast<const typename CUDAComplexT<T>::type*>(p);
}
template <typename T>
inline typename CUDAComplexT<T>::type* CUDAComplex(T* p) {
  return reinterpret_cast<typename CUDAComplexT<T>::type*>(p);
}

// Template to give the Cublas adjoint operation for real and complex types.
template <typename T>
cublasOperation_t CublasAdjointOp() {
  return Eigen::NumTraits<T>::IsComplex ? CUBLAS_OP_C : CUBLAS_OP_T;
}
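
// For illustration, a sketch (not part of this header's API) of how these
// helpers would typically be used when forwarding std::complex<float> buffers
// to a raw cuBlas call; the handle and device pointers are hypothetical:
//
//   std::complex<float> alpha{1.0f, 0.0f};
//   // dev_x and dev_y point to n complex elements in device memory.
//   cublasCaxpy(cublas_handle, n, CUDAComplex(&alpha), CUDAComplex(dev_x), 1,
//               CUDAComplex(dev_y), 1);
//   // CublasAdjointOp picks the adjoint op for the scalar type:
//   // CUBLAS_OP_C for complex types, CUBLAS_OP_T for real ones.
//   cublasOperation_t adjoint = CublasAdjointOp<std::complex<float>>();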

// Container of LAPACK info data (an array of int) generated on-device by
// a CudaSolver call. One or more such objects can be passed to
// CudaSolver::CheckLapackInfoAndDeleteSolverAsync() along with a callback to
// check the LAPACK info data after the corresponding kernels
// finish and LAPACK info has been copied from the device to the host.
class DeviceLapackInfo;

// Host-side copy of LAPACK info.
class HostLapackInfo;

// The CudaSolver class provides a simplified templated API for the dense
// linear solvers implemented in cuSolverDN
// (http://docs.nvidia.com/cuda/cusolver) and cuBlas
// (http://docs.nvidia.com/cuda/cublas/#blas-like-extension/).
// An object of this class wraps static cuSolver and cuBlas instances,
// and will launch Cuda kernels on the stream wrapped by the GPU device
// in the OpKernelContext provided to the constructor.
//
// Notice: All the computational member functions are asynchronous and simply
// launch one or more Cuda kernels on the Cuda stream wrapped by the CudaSolver
// object. To check the final status of the kernels that were run, call
// CudaSolver::CheckLapackInfoAndDeleteSolverAsync() with a callback that
// will be invoked with the status of the kernels launched thus far as
// arguments.
//
// Example of an asynchronous TensorFlow kernel using CudaSolver:
//
// template <typename Scalar>
// class SymmetricPositiveDefiniteSolveOpGpu : public AsyncOpKernel {
//  public:
//   explicit SymmetricPositiveDefiniteSolveOpGpu(OpKernelConstruction* context)
//       : AsyncOpKernel(context) { }
//   void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
//     // 1. Set up input and output device ptrs. See, e.g.,
//     // matrix_inverse_op.cc for a full example.
//     ...
//
//     // 2. Initialize the solver object.
//     std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
//
//     // 3. Launch the two compute kernels back to back on the stream without
//     // synchronizing.
//     std::vector<DeviceLapackInfo> dev_info;
//     const int batch_size = 1;
//     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrf"));
//     // Compute the Cholesky decomposition of the input matrix.
//     OP_REQUIRES_OK_ASYNC(context,
//                          solver->Potrf(uplo, n, dev_matrix_ptrs, n,
//                                        dev_info.back().mutable_data()),
//                          done);
//     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrs"));
//     // Use the Cholesky decomposition of the input matrix to solve A X = RHS.
//     OP_REQUIRES_OK_ASYNC(context,
//                          solver->Potrs(uplo, n, nrhs, dev_matrix_ptrs, n,
//                                        dev_output_ptrs, ldrhs,
//                                        dev_info.back().mutable_data()),
//                          done);
//
//     // 4. Check the status after the computation finishes and call done.
//     CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
//         std::move(solver), dev_info, std::move(done));
//   }
// };

template <typename Scalar>
class ScratchSpace;

class CudaSolver {
 public:
  // This object stores a pointer to context, which must outlive it.
  explicit CudaSolver(OpKernelContext* context);
  virtual ~CudaSolver();

  // Launches a memcpy of solver status data specified by dev_lapack_info from
  // device to the host, and asynchronously invokes the given callback when
  // the copy is complete. The first Status argument to the callback will be
  // Status::OK if all lapack infos retrieved are zero; otherwise an error
  // status is given. The second argument contains a host-side copy of the
  // entire set of infos retrieved, and can be used for generating detailed
  // error messages.
  // `info_checker_callback` must call the DoneCallback of any asynchronous
  // OpKernel within which `solver` is used.
  static void CheckLapackInfoAndDeleteSolverAsync(
      std::unique_ptr<CudaSolver> solver,
      const std::vector<DeviceLapackInfo>& dev_lapack_info,
      std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
          info_checker_callback);

  // Simpler version to use if no special error checking / messages are needed
  // apart from checking that the Status of all calls was Status::OK.
  // `done` may be nullptr.
  static void CheckLapackInfoAndDeleteSolverAsync(
      std::unique_ptr<CudaSolver> solver,
      const std::vector<DeviceLapackInfo>& dev_lapack_info,
      AsyncOpKernel::DoneCallback done);
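
  // For illustration, a hypothetical sketch of the callback variant above,
  // used to emit a custom error message before completing an async kernel
  // (`solver`, `dev_info`, and `done` as in the class-level example):
  //
  //   CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
  //       std::move(solver), dev_info,
  //       [context, done](const Status& status,
  //                       const std::vector<HostLapackInfo>& host_infos) {
  //         if (!status.ok() && !host_infos.empty()) {
  //           LOG(ERROR) << "Call " << host_infos[0].debug_info()
  //                      << " returned info = " << host_infos[0](0);
  //         }
  //         OP_REQUIRES_OK_ASYNC(context, status, done);
  //         done();
  //       });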

  // Returns a ScratchSpace. The CudaSolver object maintains a TensorReference
  // to the underlying Tensor to prevent it from being deallocated prematurely.
  template <typename Scalar>
  ScratchSpace<Scalar> GetScratchSpace(const TensorShape& shape,
                                       const string& debug_info, bool on_host);
  template <typename Scalar>
  ScratchSpace<Scalar> GetScratchSpace(int64 size, const string& debug_info,
                                       bool on_host);

  // Returns a DeviceLapackInfo that will live for the duration of the
  // CudaSolver object.
  inline DeviceLapackInfo GetDeviceLapackInfo(int64 size,
                                              const string& debug_info);

  // Allocates a temporary tensor that will live for the duration of the
  // CudaSolver object.
  Status allocate_scoped_tensor(DataType type, const TensorShape& shape,
                                Tensor* scoped_tensor);
  Status forward_input_or_allocate_scoped_tensor(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, Tensor* input_alias_or_new_scoped_tensor);

  OpKernelContext* context() { return context_; }

  // ====================================================================
  // Wrappers for cuSolverDN and cuBlas solvers start here.
  //
  // Apart from capitalization of the first letter, the method names below
  // map to those in cuSolverDN and cuBlas, which follow the naming
  // convention in LAPACK; see, e.g.,
  // http://docs.nvidia.com/cuda/cusolver/#naming-convention

  // This function performs the matrix-matrix addition/transposition
  //   C = alpha * op(A) + beta * op(B).
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-geam
  // NOTE(ebrevdo): Does not support in-place transpose of non-square
  // matrices.
  template <typename Scalar>
  Status Geam(cublasOperation_t transa, cublasOperation_t transb, int m, int n,
              const Scalar* alpha, /* host or device pointer */
              const Scalar* A, int lda,
              const Scalar* beta, /* host or device pointer */
              const Scalar* B, int ldb, Scalar* C,
              int ldc) const TF_MUST_USE_RESULT;

  // Computes the Cholesky factorization A = L * L^T (A = L * L^H for complex
  // types) for a single matrix.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf
  template <typename Scalar>
  Status Potrf(cublasFillMode_t uplo, int n, Scalar* dev_A, int lda,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // LU factorization.
  // Computes the LU factorization with partial pivoting P * A = L * U.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrf
  template <typename Scalar>
  Status Getrf(int m, int n, Scalar* dev_A, int lda, int* dev_pivots,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Uses the LU factorization to solve A * X = B.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrs
  template <typename Scalar>
  Status Getrs(cublasOperation_t trans, int n, int nrhs, const Scalar* A,
               int lda, const int* pivots, Scalar* B, int ldb,
               int* dev_lapack_info) const TF_MUST_USE_RESULT;

  // Computes partially pivoted LU factorizations for a batch of small
  // matrices. Returns Status::OK() if the kernel was launched successfully.
  // See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrfbatched
  template <typename Scalar>
  Status GetrfBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                      int* dev_pivots, DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;
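
  // For illustration, a hypothetical factor-then-solve for a batch of
  // matrices might chain this with GetrsBatched (declared below); the
  // pointer arrays and dimensions are placeholders:
  //
  //   DeviceLapackInfo dev_info =
  //       solver->GetDeviceLapackInfo(batch_size, "getrf");
  //   TF_RETURN_IF_ERROR(solver->GetrfBatched(
  //       n, host_a_dev_ptrs, n, dev_pivots, &dev_info, batch_size));
  //   TF_RETURN_IF_ERROR(solver->GetrsBatched(
  //       CUBLAS_OP_N, n, nrhs, host_a_dev_ptrs, n, dev_pivots,
  //       host_b_dev_ptrs, n, &dev_info, batch_size));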

  // Batched linear solver using the LU factorization from GetrfBatched.
  // See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
  template <typename Scalar>
  Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
                      const Scalar* const dev_Aarray[], int lda,
                      const int* devIpiv, const Scalar* const dev_Barray[],
                      int ldb, DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

  // Computes matrix inverses for a batch of small matrices. Uses the outputs
  // from GetrfBatched. Returns Status::OK() if the kernel was launched
  // successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getribatched
  template <typename Scalar>
  Status GetriBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                      const int* dev_pivots,
                      const Scalar* const host_a_inverse_dev_ptrs[],
                      int ldainv, DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

  // Computes matrix inverses for a batch of small matrices with size n < 32.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-matinvbatched
  template <typename Scalar>
  Status MatInvBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                       const Scalar* const host_a_inverse_dev_ptrs[],
                       int ldainv, DeviceLapackInfo* dev_lapack_info,
                       int batch_size) TF_MUST_USE_RESULT;

  // QR factorization.
  // Computes the QR factorization A = Q * R.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-geqrf
  template <typename Scalar>
  Status Geqrf(int m, int n, Scalar* dev_A, int lda, Scalar* dev_tau,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Overwrites the matrix C by the product of C and the unitary Householder
  // matrix Q. The Householder matrix Q is represented by the output from
  // Geqrf in dev_a and dev_tau.
  // Notice: If Scalar is real, only trans=CUBLAS_OP_N or trans=CUBLAS_OP_T is
  // supported. If Scalar is complex, only trans=CUBLAS_OP_N or
  // trans=CUBLAS_OP_C is supported.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-ormqr
  template <typename Scalar>
  Status Unmqr(cublasSideMode_t side, cublasOperation_t trans, int m, int n,
               int k, const Scalar* dev_a, int lda, const Scalar* dev_tau,
               Scalar* dev_c, int ldc,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Overwrites the QR factorization produced by Geqrf with the unitary
  // Householder matrix Q. On input, the Householder matrix Q is represented
  // by the output from Geqrf in dev_a and dev_tau. On output, dev_a is
  // overwritten with the first n columns of Q. Requires m >= n >= 0.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-orgqr
  template <typename Scalar>
  Status Ungqr(int m, int n, int k, Scalar* dev_a, int lda,
               const Scalar* dev_tau, int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Hermitian (symmetric for real types) eigendecomposition.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-syevd
  template <typename Scalar>
  Status Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, int n,
               Scalar* dev_A, int lda,
               typename Eigen::NumTraits<Scalar>::Real* dev_W,
               int* dev_lapack_info) TF_MUST_USE_RESULT;
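
  // For illustration, a hypothetical call computing both eigenvalues and
  // eigenvectors of a single n x n Hermitian matrix stored in dev_A, with
  // eigenvalues written to dev_W:
  //
  //   DeviceLapackInfo dev_info = solver->GetDeviceLapackInfo(1, "heevd");
  //   TF_RETURN_IF_ERROR(solver->Heevd(CUSOLVER_EIG_MODE_VECTOR,
  //                                    CUBLAS_FILL_MODE_LOWER, n, dev_A, n,
  //                                    dev_W, dev_info.mutable_data()));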

  // Singular value decomposition.
  // Returns Status::OK() if the kernel was launched successfully.
  // TODO(rmlarsen, volunteers): Add support for complex types.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd
  template <typename Scalar>
  Status Gesvd(signed char jobu, signed char jobvt, int m, int n,
               Scalar* dev_A, int lda, Scalar* dev_S, Scalar* dev_U, int ldu,
               Scalar* dev_VT, int ldvt,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

 private:
  OpKernelContext* context_;  // not owned.
  cudaStream_t cuda_stream_;
  cusolverDnHandle_t cusolver_dn_handle_;
  cublasHandle_t cublas_handle_;
  std::vector<TensorReference> scratch_tensor_refs_;

  TF_DISALLOW_COPY_AND_ASSIGN(CudaSolver);
};

// Helper class to allocate scratch memory and keep track of debug info.
// Mostly a thin wrapper around Tensor & allocate_temp.
template <typename Scalar>
class ScratchSpace {
 public:
  ScratchSpace(OpKernelContext* context, int64 size, bool on_host)
      : ScratchSpace(context, TensorShape({size}), "", on_host) {}

  ScratchSpace(OpKernelContext* context, int64 size, const string& debug_info,
               bool on_host)
      : ScratchSpace(context, TensorShape({size}), debug_info, on_host) {}

  ScratchSpace(OpKernelContext* context, const TensorShape& shape,
               const string& debug_info, bool on_host)
      : context_(context), debug_info_(debug_info), on_host_(on_host) {
    AllocatorAttributes alloc_attr;
    if (on_host) {
      // Allocate pinned memory on the host to avoid unnecessary
      // synchronization.
      alloc_attr.set_on_host(true);
      alloc_attr.set_gpu_compatible(true);
    }
    TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<Scalar>::value, shape,
                                       &scratch_tensor_, alloc_attr));
  }

  virtual ~ScratchSpace() {}

  Scalar* mutable_data() {
    return scratch_tensor_.template flat<Scalar>().data();
  }
  const Scalar* data() const {
    return scratch_tensor_.template flat<Scalar>().data();
  }
  Scalar& operator()(int64 i) {
    return scratch_tensor_.template flat<Scalar>()(i);
  }
  const Scalar& operator()(int64 i) const {
    return scratch_tensor_.template flat<Scalar>()(i);
  }
  int64 bytes() const { return scratch_tensor_.TotalBytes(); }
  int64 size() const { return scratch_tensor_.NumElements(); }
  const string& debug_info() const { return debug_info_; }

  Tensor& tensor() { return scratch_tensor_; }
  const Tensor& tensor() const { return scratch_tensor_; }

  // Returns true if this ScratchSpace is in host memory.
  bool on_host() const { return on_host_; }

 protected:
  OpKernelContext* context() const { return context_; }

 private:
  OpKernelContext* context_;  // not owned
  const string debug_info_;
  const bool on_host_;
  Tensor scratch_tensor_;
};

class HostLapackInfo : public ScratchSpace<int> {
 public:
  HostLapackInfo(OpKernelContext* context, int64 size,
                 const string& debug_info)
      : ScratchSpace<int>(context, size, debug_info, /* on_host */ true) {}
};
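
// For illustration, a hypothetical direct use of ScratchSpace to stage a
// small pinned host buffer that a device-to-host copy can target without
// extra synchronization (sizes and names are placeholders):
//
//   ScratchSpace<int> host_scratch(context, batch_size, "pivots",
//                                  /* on_host */ true);
//   // host_scratch.mutable_data() is pinned and safe as a memcpy target.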

class DeviceLapackInfo : public ScratchSpace<int> {
 public:
  DeviceLapackInfo(OpKernelContext* context, int64 size,
                   const string& debug_info)
      : ScratchSpace<int>(context, size, debug_info, /* on_host */ false) {}

  // Allocates a new scratch space on the host and launches a copy of the
  // contents of *this to the new scratch space. Sets success to true if
  // the copy kernel was launched successfully.
  HostLapackInfo CopyToHost(bool* success) const {
    CHECK(success != nullptr);
    HostLapackInfo copy(context(), size(), debug_info());
    auto stream = context()->op_device_context()->stream();
    perftools::gputools::DeviceMemoryBase wrapped_src(
        static_cast<void*>(const_cast<int*>(this->data())));
    *success =
        stream->ThenMemcpy(copy.mutable_data(), wrapped_src, this->bytes())
            .ok();
    return copy;
  }
};

template <typename Scalar>
ScratchSpace<Scalar> CudaSolver::GetScratchSpace(const TensorShape& shape,
                                                 const string& debug_info,
                                                 bool on_host) {
  ScratchSpace<Scalar> new_scratch_space(context_, shape, debug_info, on_host);
  scratch_tensor_refs_.emplace_back(new_scratch_space.tensor());
  return std::move(new_scratch_space);
}

template <typename Scalar>
ScratchSpace<Scalar> CudaSolver::GetScratchSpace(int64 size,
                                                 const string& debug_info,
                                                 bool on_host) {
  return GetScratchSpace<Scalar>(TensorShape({size}), debug_info, on_host);
}

inline DeviceLapackInfo CudaSolver::GetDeviceLapackInfo(
    int64 size, const string& debug_info) {
  DeviceLapackInfo new_dev_info(context_, size, debug_info);
  scratch_tensor_refs_.emplace_back(new_dev_info.tensor());
  return new_dev_info;
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA