/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

// This header declares the class CudaSolver, which contains wrappers of linear
// algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow
// kernels.

#ifdef GOOGLE_CUDA

#include <functional>
#include <vector>

#include "cuda/include/cublas_v2.h"
#include "cuda/include/cusolverDn.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor.h"

namespace tensorflow {

// Type traits to get CUDA complex types from std::complex<T>.
template <typename T>
struct CUDAComplexT {
  typedef T type;
};
template <>
struct CUDAComplexT<std::complex<float>> {
  typedef cuComplex type;
};
template <>
struct CUDAComplexT<std::complex<double>> {
  typedef cuDoubleComplex type;
};
// Converts pointers of std::complex<> to pointers of
// cuComplex/cuDoubleComplex. No type conversion for non-complex types.
template <typename T>
inline const typename CUDAComplexT<T>::type* CUDAComplex(const T* p) {
  return reinterpret_cast<const typename CUDAComplexT<T>::type*>(p);
}
template <typename T>
inline typename CUDAComplexT<T>::type* CUDAComplex(T* p) {
  return reinterpret_cast<typename CUDAComplexT<T>::type*>(p);
}

// Template to give the Cublas adjoint operation for real and complex types.
template <typename T>
cublasOperation_t CublasAdjointOp() {
  return Eigen::NumTraits<T>::IsComplex ? CUBLAS_OP_C : CUBLAS_OP_T;
}
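
// For example (a minimal sketch; `z_ptr` stands for a hypothetical device
// pointer to std::complex<float> data), the helpers above can be used to
// prepare arguments for a cuBlas call:
//
//   const std::complex<float>* z_ptr = ...;
//   const cuComplex* cu_ptr = CUDAComplex(z_ptr);  // pointer cast only
//   cublasOperation_t adj = CublasAdjointOp<std::complex<float>>();
//   // adj == CUBLAS_OP_C for complex types, CUBLAS_OP_T for real types.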

// Container of LAPACK info data (an array of int) generated on-device by
// a CudaSolver call. One or more such objects can be passed to
// CudaSolver::CheckLapackInfoAndDeleteSolverAsync() along with a callback to
// check the LAPACK info data after the corresponding kernels
// finish and the LAPACK info has been copied from the device to the host.
class DeviceLapackInfo;

// Host-side copy of LAPACK info.
class HostLapackInfo;

// The CudaSolver class provides a simplified templated API for the dense linear
// solvers implemented in cuSolverDN (http://docs.nvidia.com/cuda/cusolver) and
// cuBlas (http://docs.nvidia.com/cuda/cublas/#blas-like-extension/).
// An object of this class wraps static cuSolver and cuBlas instances,
// and will launch Cuda kernels on the stream wrapped by the GPU device
// in the OpKernelContext provided to the constructor.
//
// Notice: All the computational member functions are asynchronous and simply
// launch one or more Cuda kernels on the Cuda stream wrapped by the CudaSolver
// object. To check the final status of the kernels run, call
// CheckLapackInfoAndDeleteSolverAsync() with the CudaSolver object to set a
// callback that will be invoked with the status of the kernels launched thus
// far as arguments.
//
// Example of an asynchronous TensorFlow kernel using CudaSolver:
//
// template <typename Scalar>
// class SymmetricPositiveDefiniteSolveOpGpu : public AsyncOpKernel {
//  public:
//   explicit SymmetricPositiveDefiniteSolveOpGpu(OpKernelConstruction* context)
//       : AsyncOpKernel(context) { }
//   void ComputeAsync(OpKernelContext* context, DoneCallback done) final {
//     // 1. Set up input and output device ptrs. See, e.g.,
//     // matrix_inverse_op.cc for a full example.
//     ...
//
//     // 2. Initialize the solver object.
//     std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
//
//     // 3. Launch the two compute kernels back to back on the stream without
//     // synchronizing.
//     std::vector<DeviceLapackInfo> dev_info;
//     const int batch_size = 1;
//     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrf"));
//     // Compute the Cholesky decomposition of the input matrix.
//     OP_REQUIRES_OK_ASYNC(context,
//                          solver->Potrf(uplo, n, dev_matrix_ptrs, n,
//                                        dev_info.back().mutable_data()),
//                          done);
//     dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "potrs"));
//     // Use the Cholesky decomposition of the input matrix to solve A X = RHS.
//     OP_REQUIRES_OK_ASYNC(context,
//                          solver->Potrs(uplo, n, nrhs, dev_matrix_ptrs, n,
//                                        dev_output_ptrs, ldrhs,
//                                        dev_info.back().mutable_data()),
//                          done);
//
//     // 4. Check the status after the computation finishes and call done.
//     CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver),
//                                                     dev_info,
//                                                     std::move(done));
//   }
// };

template <typename Scalar>
class ScratchSpace;

class CudaSolver {
 public:
  // This object stores a pointer to context, which must outlive it.
  explicit CudaSolver(OpKernelContext* context);
  virtual ~CudaSolver();

  // Launches a memcpy of solver status data specified by dev_lapack_info from
  // device to the host, and asynchronously invokes the given callback when the
  // copy is complete. The first Status argument to the callback will be
  // Status::OK if all lapack infos retrieved are zero, otherwise an error
  // status is given. The second argument contains a host-side copy of the
  // entire set of infos retrieved, and can be used for generating detailed
  // error messages.
  // `info_checker_callback` must call the DoneCallback of any asynchronous
  // OpKernel within which `solver` is used.
  static void CheckLapackInfoAndDeleteSolverAsync(
      std::unique_ptr<CudaSolver> solver,
      const std::vector<DeviceLapackInfo>& dev_lapack_info,
      std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
          info_checker_callback);

  // Simpler version to use if no special error checking / messages are needed
  // apart from checking that the Status of all calls was Status::OK.
  // `done` may be nullptr.
  static void CheckLapackInfoAndDeleteSolverAsync(
      std::unique_ptr<CudaSolver> solver,
      const std::vector<DeviceLapackInfo>& dev_lapack_info,
      AsyncOpKernel::DoneCallback done);

  // Returns a ScratchSpace. The CudaSolver object maintains a TensorReference
  // to the underlying Tensor to prevent it from being deallocated prematurely.
  template <typename Scalar>
  ScratchSpace<Scalar> GetScratchSpace(const TensorShape& shape,
                                       const string& debug_info, bool on_host);
  template <typename Scalar>
  ScratchSpace<Scalar> GetScratchSpace(int64 size, const string& debug_info,
                                       bool on_host);
  // Returns a DeviceLapackInfo that will live for the duration of the
  // CudaSolver object.
  inline DeviceLapackInfo GetDeviceLapackInfo(int64 size,
                                              const string& debug_info);

  // Allocates a temporary tensor that will live for the duration of the
  // CudaSolver object.
  Status allocate_scoped_tensor(DataType type, const TensorShape& shape,
                                Tensor* scoped_tensor);
  Status forward_input_or_allocate_scoped_tensor(
      gtl::ArraySlice<int> candidate_input_indices, DataType type,
      const TensorShape& shape, Tensor* input_alias_or_new_scoped_tensor);

  OpKernelContext* context() { return context_; }

  // ====================================================================
  // Wrappers for cuSolverDN and cuBlas solvers start here.
  //
  // Apart from capitalization of the first letter, the method names below
  // map to those in cuSolverDN and cuBlas, which follow the naming
  // convention in LAPACK; see, e.g.,
  // http://docs.nvidia.com/cuda/cusolver/#naming-convention

  // This function performs the matrix-matrix addition/transposition
  //   C = alpha * op(A) + beta * op(B).
  // Returns Status::OK() if the kernel was launched successfully.  See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-geam
  // NOTE(ebrevdo): Does not support in-place transpose of non-square
  // matrices.
  template <typename Scalar>
  Status Geam(cublasOperation_t transa, cublasOperation_t transb, int m, int n,
              const Scalar* alpha, /* host or device pointer */
              const Scalar* A, int lda,
              const Scalar* beta, /* host or device pointer */
              const Scalar* B, int ldb, Scalar* C,
              int ldc) const TF_MUST_USE_RESULT;
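  //
  // Example (a minimal sketch, following the class-level example above;
  // `dev_in` and `dev_out` are hypothetical device pointers to a column-major
  // m x n input matrix and its n x m transpose): out-of-place transpose with
  // alpha = 1 and beta = 0, in which case B does not contribute and is simply
  // set to the output buffer.
  //
  //   const Scalar alpha(1), beta(0);
  //   OP_REQUIRES_OK_ASYNC(context,
  //                        solver->Geam(CUBLAS_OP_T, CUBLAS_OP_N,
  //                                     /*m=*/n, /*n=*/m, &alpha, dev_in, m,
  //                                     &beta, dev_out, n, dev_out, n),
  //                        done);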

  // Computes the Cholesky factorization A = L * L^H (A = L * L^T for real
  // types) for a single Hermitian (symmetric) positive definite matrix.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-potrf
  template <typename Scalar>
  Status Potrf(cublasFillMode_t uplo, int n, Scalar* dev_A, int lda,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // LU factorization.
  // Computes LU factorization with partial pivoting P * A = L * U.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrf
  template <typename Scalar>
  Status Getrf(int m, int n, Scalar* dev_A, int lda, int* dev_pivots,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Uses LU factorization to solve A * X = B.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-getrs
  template <typename Scalar>
  Status Getrs(cublasOperation_t trans, int n, int nrhs, const Scalar* A,
               int lda, const int* pivots, Scalar* B, int ldb,
               int* dev_lapack_info) const TF_MUST_USE_RESULT;
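  //
  // Example (a minimal sketch, following the class-level example above;
  // `dev_A`, `dev_B` and `dev_pivots` are hypothetical device buffers holding
  // the n x n matrix, the n x nrhs right-hand sides and n pivot indices):
  // solve A * X = B in place via Getrf followed by Getrs.
  //
  //   dev_info.push_back(solver->GetDeviceLapackInfo(1, "getrf"));
  //   OP_REQUIRES_OK_ASYNC(context,
  //                        solver->Getrf(n, n, dev_A, n, dev_pivots,
  //                                      dev_info.back().mutable_data()),
  //                        done);
  //   dev_info.push_back(solver->GetDeviceLapackInfo(1, "getrs"));
  //   OP_REQUIRES_OK_ASYNC(context,
  //                        solver->Getrs(CUBLAS_OP_N, n, nrhs, dev_A, n,
  //                                      dev_pivots, dev_B, n,
  //                                      dev_info.back().mutable_data()),
  //                        done);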

  // Computes partially pivoted LU factorizations for a batch of small matrices.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrfbatched
  template <typename Scalar>
  Status GetrfBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                      int* dev_pivots, DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

  // Batched linear solver using LU factorization from getrfBatched.
  // See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrsbatched
  template <typename Scalar>
  Status GetrsBatched(cublasOperation_t trans, int n, int nrhs,
                      const Scalar* const dev_Aarray[], int lda,
                      const int* devIpiv, const Scalar* const dev_Barray[],
                      int ldb, DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;

  // Computes matrix inverses for a batch of small matrices. Uses the outputs
  // from GetrfBatched. Returns Status::OK() if the kernel was launched
  // successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getribatched
  template <typename Scalar>
  Status GetriBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                      const int* dev_pivots,
                      const Scalar* const host_a_inverse_dev_ptrs[], int ldainv,
                      DeviceLapackInfo* dev_lapack_info,
                      int batch_size) TF_MUST_USE_RESULT;
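  //
  // Example (a minimal sketch, as in matrix_inverse_op.cc; `host_a_dev_ptrs`
  // and `host_a_inv_dev_ptrs` are hypothetical host arrays of device pointers,
  // one per n x n matrix in the batch, and `dev_pivots` is a device buffer of
  // n * batch_size pivot indices): batched matrix inverse via GetrfBatched
  // followed by GetriBatched.
  //
  //   dev_info.push_back(
  //       solver->GetDeviceLapackInfo(batch_size, "getrfBatched"));
  //   OP_REQUIRES_OK_ASYNC(context,
  //                        solver->GetrfBatched(n, host_a_dev_ptrs, n,
  //                                             dev_pivots, &dev_info.back(),
  //                                             batch_size),
  //                        done);
  //   dev_info.push_back(
  //       solver->GetDeviceLapackInfo(batch_size, "getriBatched"));
  //   OP_REQUIRES_OK_ASYNC(context,
  //                        solver->GetriBatched(n, host_a_dev_ptrs, n,
  //                                             dev_pivots,
  //                                             host_a_inv_dev_ptrs, n,
  //                                             &dev_info.back(), batch_size),
  //                        done);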

  // Computes matrix inverses for a batch of small matrices with size n < 32.
  // Returns Status::OK() if the kernel was launched successfully. See:
  // http://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-matinvbatched
  template <typename Scalar>
  Status MatInvBatched(int n, const Scalar* const host_a_dev_ptrs[], int lda,
                       const Scalar* const host_a_inverse_dev_ptrs[],
                       int ldainv, DeviceLapackInfo* dev_lapack_info,
                       int batch_size) TF_MUST_USE_RESULT;

  // QR factorization.
  // Computes QR factorization A = Q * R.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-geqrf
  template <typename Scalar>
  Status Geqrf(int m, int n, Scalar* dev_A, int lda, Scalar* dev_tau,
               int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Overwrites matrix C with the product of C and the unitary Householder
  // matrix Q. The Householder matrix Q is represented by the output from Geqrf
  // in dev_a and dev_tau.
  // Notice: If Scalar is real, only trans=CUBLAS_OP_N or trans=CUBLAS_OP_T is
  // supported. If Scalar is complex, trans=CUBLAS_OP_N or trans=CUBLAS_OP_C is
  // supported.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-ormqr
  template <typename Scalar>
  Status Unmqr(cublasSideMode_t side, cublasOperation_t trans, int m, int n,
               int k, const Scalar* dev_a, int lda, const Scalar* dev_tau,
               Scalar* dev_c, int ldc, int* dev_lapack_info) TF_MUST_USE_RESULT;

  // Overwrites the QR factorization produced by Geqrf with the explicit
  // unitary Householder matrix Q. On input, the Householder matrix Q is
  // represented by the output from Geqrf in dev_a and dev_tau. On output,
  // dev_a is overwritten with the first n columns of Q. Requires m >= n >= 0.
  // Returns Status::OK() if the kernel was launched successfully.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-orgqr
  template <typename Scalar>
  Status Ungqr(int m, int n, int k, Scalar* dev_a, int lda,
               const Scalar* dev_tau, int* dev_lapack_info) TF_MUST_USE_RESULT;
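  //
  // Example (a minimal sketch, following the class-level example above;
  // `dev_A` is a hypothetical device buffer holding an m x n matrix with
  // m >= n, and `dev_tau` a device buffer of n elements): compute a thin QR
  // factorization and overwrite dev_A with the first n columns of Q via Geqrf
  // followed by Ungqr.
  //
  //   dev_info.push_back(solver->GetDeviceLapackInfo(1, "geqrf"));
  //   OP_REQUIRES_OK_ASYNC(context,
  //                        solver->Geqrf(m, n, dev_A, m, dev_tau,
  //                                      dev_info.back().mutable_data()),
  //                        done);
  //   dev_info.push_back(solver->GetDeviceLapackInfo(1, "ungqr"));
  //   OP_REQUIRES_OK_ASYNC(context,
  //                        solver->Ungqr(m, n, n, dev_A, m, dev_tau,
  //                                      dev_info.back().mutable_data()),
  //                        done);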

  // Hermitian (Symmetric) Eigen decomposition.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-syevd
  template <typename Scalar>
  Status Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo, int n,
               Scalar* dev_A, int lda,
               typename Eigen::NumTraits<Scalar>::Real* dev_W,
               int* dev_lapack_info) TF_MUST_USE_RESULT;
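  //
  // Example (a minimal sketch, following the class-level example above;
  // `dev_A` is a hypothetical device buffer holding an n x n self-adjoint
  // matrix and `dev_W` a device buffer of n real values): compute all
  // eigenvalues and eigenvectors. On exit dev_A holds the eigenvectors and
  // dev_W the eigenvalues.
  //
  //   dev_info.push_back(solver->GetDeviceLapackInfo(1, "heevd"));
  //   OP_REQUIRES_OK_ASYNC(context,
  //                        solver->Heevd(CUSOLVER_EIG_MODE_VECTOR,
  //                                      CUBLAS_FILL_MODE_LOWER, n, dev_A, n,
  //                                      dev_W,
  //                                      dev_info.back().mutable_data()),
  //                        done);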

  // Singular value decomposition.
  // Returns Status::OK() if the kernel was launched successfully.
  // TODO(rmlarsen, volunteers): Add support for complex types.
  // See: http://docs.nvidia.com/cuda/cusolver/#cuds-lt-t-gt-gesvd
  template <typename Scalar>
  Status Gesvd(signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,
               int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,
               int ldvt, int* dev_lapack_info) TF_MUST_USE_RESULT;
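  //
  // Example (a minimal sketch, following the class-level example above;
  // `dev_A`, `dev_S`, `dev_U` and `dev_VT` are hypothetical device buffers of
  // sizes m*n, n, m*m and n*n for a real m x n matrix with m >= n, which is
  // what cuSolver's gesvd supports): full SVD A = U * S * V^T with
  // jobu = jobvt = 'A'.
  //
  //   dev_info.push_back(solver->GetDeviceLapackInfo(1, "gesvd"));
  //   OP_REQUIRES_OK_ASYNC(context,
  //                        solver->Gesvd('A', 'A', m, n, dev_A, m, dev_S,
  //                                      dev_U, m, dev_VT, n,
  //                                      dev_info.back().mutable_data()),
  //                        done);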

 private:
  OpKernelContext* context_;  // not owned.
  cudaStream_t cuda_stream_;
  cusolverDnHandle_t cusolver_dn_handle_;
  cublasHandle_t cublas_handle_;
  std::vector<TensorReference> scratch_tensor_refs_;

  TF_DISALLOW_COPY_AND_ASSIGN(CudaSolver);
};

// Helper class to allocate scratch memory and keep track of debug info.
// Mostly a thin wrapper around Tensor & allocate_temp.
template <typename Scalar>
class ScratchSpace {
 public:
  ScratchSpace(OpKernelContext* context, int64 size, bool on_host)
      : ScratchSpace(context, TensorShape({size}), "", on_host) {}

  ScratchSpace(OpKernelContext* context, int64 size, const string& debug_info,
               bool on_host)
      : ScratchSpace(context, TensorShape({size}), debug_info, on_host) {}

  ScratchSpace(OpKernelContext* context, const TensorShape& shape,
               const string& debug_info, bool on_host)
      : context_(context), debug_info_(debug_info), on_host_(on_host) {
    AllocatorAttributes alloc_attr;
    if (on_host) {
      // Allocate pinned memory on the host to avoid unnecessary
      // synchronization.
      alloc_attr.set_on_host(true);
      alloc_attr.set_gpu_compatible(true);
    }
    TF_CHECK_OK(context->allocate_temp(DataTypeToEnum<Scalar>::value, shape,
                                       &scratch_tensor_, alloc_attr));
  }

  virtual ~ScratchSpace() {}

  Scalar* mutable_data() {
    return scratch_tensor_.template flat<Scalar>().data();
  }
  const Scalar* data() const {
    return scratch_tensor_.template flat<Scalar>().data();
  }
  Scalar& operator()(int64 i) {
    return scratch_tensor_.template flat<Scalar>()(i);
  }
  const Scalar& operator()(int64 i) const {
    return scratch_tensor_.template flat<Scalar>()(i);
  }
  int64 bytes() const { return scratch_tensor_.TotalBytes(); }
  int64 size() const { return scratch_tensor_.NumElements(); }
  const string& debug_info() const { return debug_info_; }

  Tensor& tensor() { return scratch_tensor_; }
  const Tensor& tensor() const { return scratch_tensor_; }

  // Returns true if this ScratchSpace is in host memory.
  bool on_host() const { return on_host_; }

 protected:
  OpKernelContext* context() const { return context_; }

 private:
  OpKernelContext* context_;  // not owned
  const string debug_info_;
  const bool on_host_;
  Tensor scratch_tensor_;
};
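
// Example (a minimal sketch; `solver` is a hypothetical CudaSolver* created
// from an OpKernelContext): allocate 64 ints of pinned host scratch memory.
// Going through CudaSolver::GetScratchSpace, rather than constructing a
// ScratchSpace directly, keeps a TensorReference to the underlying Tensor
// alive for the lifetime of the solver object.
//
//   ScratchSpace<int> host_scratch =
//       solver->GetScratchSpace<int>(64, "my_debug_tag", /* on_host */ true);
//   int* buf = host_scratch.mutable_data();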

class HostLapackInfo : public ScratchSpace<int> {
 public:
  HostLapackInfo(OpKernelContext* context, int64 size, const string& debug_info)
      : ScratchSpace<int>(context, size, debug_info, /* on_host */ true) {}
};

class DeviceLapackInfo : public ScratchSpace<int> {
 public:
  DeviceLapackInfo(OpKernelContext* context, int64 size,
                   const string& debug_info)
      : ScratchSpace<int>(context, size, debug_info, /* on_host */ false) {}

  // Allocates a new scratch space on the host and launches a copy of the
  // contents of *this to the new scratch space. Sets success to true if
  // the copy kernel was launched successfully.
  HostLapackInfo CopyToHost(bool* success) const {
    CHECK(success != nullptr);
    HostLapackInfo copy(context(), size(), debug_info());
    auto stream = context()->op_device_context()->stream();
    perftools::gputools::DeviceMemoryBase wrapped_src(
        static_cast<void*>(const_cast<int*>(this->data())));
    *success =
        stream->ThenMemcpy(copy.mutable_data(), wrapped_src, this->bytes())
            .ok();
    return copy;
  }
};
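
// Example (a minimal sketch; `device_info` is a DeviceLapackInfo and the
// caller synchronizes the stream before reading the result): copy the info
// values back to pinned host memory for inspection.
//
//   bool copy_ok = false;
//   HostLapackInfo host_info = device_info.CopyToHost(&copy_ok);
//   // After the stream has been synchronized and if copy_ok is true,
//   // host_info(i) holds the LAPACK info value for the i-th problem.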

template <typename Scalar>
ScratchSpace<Scalar> CudaSolver::GetScratchSpace(const TensorShape& shape,
                                                 const string& debug_info,
                                                 bool on_host) {
  ScratchSpace<Scalar> new_scratch_space(context_, shape, debug_info, on_host);
  scratch_tensor_refs_.emplace_back(new_scratch_space.tensor());
  return std::move(new_scratch_space);
}

template <typename Scalar>
ScratchSpace<Scalar> CudaSolver::GetScratchSpace(int64 size,
                                                 const string& debug_info,
                                                 bool on_host) {
  return GetScratchSpace<Scalar>(TensorShape({size}), debug_info, on_host);
}

inline DeviceLapackInfo CudaSolver::GetDeviceLapackInfo(
    int64 size, const string& debug_info) {
  DeviceLapackInfo new_dev_info(context_, size, debug_info);
  scratch_tensor_refs_.emplace_back(new_dev_info.tensor());
  return new_dev_info;
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA