/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
   ==============================================================================
*/
#ifdef GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_solvers.h"

#include <chrono>
#include <complex>
#include <unordered_map>
#include <vector>

#include "cuda/include/cublas_v2.h"
#include "cuda/include/cusolverDn.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/cuda.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"

using ::perftools::gputools::cuda::ScopedActivateExecutorContext;

// The CUDA cublas_api.h API contains const-correctness errors. Instead of
// casting away constness on our data, we reinterpret the cuBLAS functions as
// what they were clearly meant to be, so that we can call the functions
// naturally.
//
// (The error is that input-only arrays are bound to parameter types
// "const T**" instead of the correct "const T* const*".)
extern "C" {
using getrs_S = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const float* const*, int, const int*, float**,
                               int, int*, int);
using getrs_D = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const double* const*, int, const int*, double**,
                               int, int*, int);
using getrs_C = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const float2* const*, int, const int*, float2**,
                               int, int*, int);
using getrs_Z = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const double2* const*, int, const int*,
                               double2**, int, int*, int);

using getri_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
                               const int*, float**, int, int*, int);
using getri_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
                               const int*, double**, int, int*, int);
using getri_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
                               const int*, float2**, int, int*, int);
using getri_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
                               const int*, double2**, int, int*, int);

using matinv_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
                                float**, int, int*, int);
using matinv_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
                                double**, int, int*, int);
using matinv_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
                                float2**, int, int*, int);
using matinv_Z = cublasStatus_t(cublasContext*, int, const double2* const*,
                                int, double2**, int, int*, int);
}
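
// These corrected signatures are used further below by reinterpret_cast-ing
// the corresponding cuBlas entry points, e.g.,
//   reinterpret_cast<getrs_S*>(cublasSgetrsBatched)
// so that host arrays of device pointers can be passed as "const T* const*"
// without casting away constness.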

namespace tensorflow {
namespace {

inline bool CopyHostToDevice(OpKernelContext* context, void* dst,
                             const void* src, uint64 bytes) {
  auto stream = context->op_device_context()->stream();
  perftools::gputools::DeviceMemoryBase wrapped_dst(dst);
  return stream->ThenMemcpy(&wrapped_dst, src, bytes).ok();
}

// A set of initialized handles to the underlying Cuda libraries used by
// CudaSolver. We maintain one such set of handles per unique stream.
struct CudaSolverHandles {
  explicit CudaSolverHandles(cudaStream_t stream) {
    CHECK(cusolverDnCreate(&cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
        << "Failed to create cuSolverDN instance.";
    CHECK(cusolverDnSetStream(cusolver_dn_handle, stream) ==
          CUSOLVER_STATUS_SUCCESS)
        << "Failed to set cuSolverDN stream.";
    CHECK(cublasCreate(&cublas_handle) == CUBLAS_STATUS_SUCCESS)
        << "Failed to create cuBlas instance.";
    CHECK(cublasSetStream(cublas_handle, stream) == CUBLAS_STATUS_SUCCESS)
        << "Failed to set cuBlas stream.";
  }

  ~CudaSolverHandles() {
    CHECK(cublasDestroy(cublas_handle) == CUBLAS_STATUS_SUCCESS)
        << "Failed to destroy cuBlas instance.";
    CHECK(cusolverDnDestroy(cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
        << "Failed to destroy cuSolverDN instance.";
  }
  cublasHandle_t cublas_handle;
  cusolverDnHandle_t cusolver_dn_handle;
};

static mutex handle_map_mutex(LINKER_INITIALIZED);

using HandleMap =
    std::unordered_map<cudaStream_t, std::unique_ptr<CudaSolverHandles>>;

// Returns a singleton map used for storing initialized handles for each unique
// cuda stream.
HandleMap* GetHandleMapSingleton() {
  static HandleMap* cm = new HandleMap;
  return cm;
}

}  // namespace

#define TF_RETURN_IF_CUSOLVER_ERROR(expr)                      \
  do {                                                         \
    auto status = (expr);                                      \
    if (TF_PREDICT_FALSE(status != CUSOLVER_STATUS_SUCCESS)) { \
      return errors::Internal(                                 \
          __FILE__, ":", __LINE__,                             \
          ": cuSolverDN call failed with status = ", status);  \
    }                                                          \
  } while (0)

#define TF_RETURN_IF_CUBLAS_ERROR(expr)                                  \
  do {                                                                   \
    auto status = (expr);                                                \
    if (TF_PREDICT_FALSE(status != CUBLAS_STATUS_SUCCESS)) {             \
      return errors::Internal(__FILE__, ":", __LINE__,                   \
                              ": cuBlas call failed status = ", status); \
    }                                                                    \
  } while (0)
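
// For example,
//   TF_RETURN_IF_CUSOLVER_ERROR(
//       cusolverDnSgetrf_bufferSize(handle, m, n, A, lda, &lwork));
// returns an errors::Internal Status from the enclosing function if the call
// does not return CUSOLVER_STATUS_SUCCESS.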

CudaSolver::CudaSolver(OpKernelContext* context) : context_(context) {
  mutex_lock lock(handle_map_mutex);
  const cudaStream_t* cu_stream_ptr = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(context->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->CudaStreamMemberHack()));
  cuda_stream_ = *cu_stream_ptr;
  HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton());
  auto it = handle_map->find(cuda_stream_);
  if (it == handle_map->end()) {
    LOG(INFO) << "Creating CudaSolver handles for stream " << cuda_stream_;
    // Previously unseen Cuda stream. Initialize a set of Cuda solver library
    // handles for it.
    std::unique_ptr<CudaSolverHandles> new_handles(
        new CudaSolverHandles(cuda_stream_));
    it =
        handle_map->insert(std::make_pair(cuda_stream_, std::move(new_handles)))
            .first;
  }
  cusolver_dn_handle_ = it->second->cusolver_dn_handle;
  cublas_handle_ = it->second->cublas_handle;
}

CudaSolver::~CudaSolver() {
  for (auto tensor_ref : scratch_tensor_refs_) {
    tensor_ref.Unref();
  }
}

// static
void CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
    std::unique_ptr<CudaSolver> solver,
    const std::vector<DeviceLapackInfo>& dev_lapack_infos,
    std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
        info_checker_callback) {
  CHECK(info_checker_callback != nullptr);
  std::vector<HostLapackInfo> host_lapack_infos;
  if (dev_lapack_infos.empty()) {
    info_checker_callback(Status::OK(), host_lapack_infos);
    return;
  }

  // Launch memcpys to copy info back from the device to the host.
  for (const auto& dev_lapack_info : dev_lapack_infos) {
    bool success = true;
    auto host_copy = dev_lapack_info.CopyToHost(&success);
    OP_REQUIRES(
        solver->context(), success,
        errors::Internal(
            "Failed to launch copy of dev_lapack_info to host, debug_info = ",
            dev_lapack_info.debug_info()));
    host_lapack_infos.push_back(std::move(host_copy));
  }

  // This callback checks that all batch items in all calls were processed
  // successfully and passes status to the info_checker_callback accordingly.
  auto* stream = solver->context()->op_device_context()->stream();
  auto wrapped_info_checker_callback =
      [stream](
          CudaSolver* solver,
          std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
              info_checker_callback,
          std::vector<HostLapackInfo> host_lapack_infos) {
        ScopedActivateExecutorContext scoped_activation{stream->parent()};
        Status status;
        for (const auto& host_lapack_info : host_lapack_infos) {
          for (int i = 0; i < host_lapack_info.size() && status.ok(); ++i) {
            const int info_value = host_lapack_info(i);
            if (info_value != 0) {
              status = errors::InvalidArgument(
                  "Got info = ", info_value, " for batch index ", i,
                  ", expected info = 0. Debug_info = ",
                  host_lapack_info.debug_info());
            }
          }
          if (!status.ok()) {
            break;
          }
        }
        // Delete solver to release temp tensor refs.
        delete solver;

        // Delegate further error checking to provided functor.
        info_checker_callback(status, host_lapack_infos);
      };
  // Note: An std::function cannot have unique_ptr arguments (it must be copy
  // constructible and therefore so must its arguments). Therefore, we release
  // solver into a raw pointer to be deleted at the end of
  // wrapped_info_checker_callback.
  // Release ownership of solver. It will be deleted in the cb callback.
  auto solver_raw_ptr = solver.release();
  auto cb =
      std::bind(wrapped_info_checker_callback, solver_raw_ptr,
                std::move(info_checker_callback), std::move(host_lapack_infos));

  solver_raw_ptr->context()
      ->device()
      ->tensorflow_gpu_device_info()
      ->event_mgr->ThenExecute(stream, std::move(cb));
}

// static
void CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
    std::unique_ptr<CudaSolver> solver,
    const std::vector<DeviceLapackInfo>& dev_lapack_info,
    AsyncOpKernel::DoneCallback done) {
  OpKernelContext* context = solver->context();
  auto wrapped_done = [context, done](
                          const Status& status,
                          const std::vector<HostLapackInfo>& /* unused */) {
    if (done != nullptr) {
      OP_REQUIRES_OK_ASYNC(context, status, done);
      done();
    } else {
      OP_REQUIRES_OK(context, status);
    }
  };
  CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_lapack_info,
                                      wrapped_done);
}
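
// Typical usage from an asynchronous op kernel (illustrative sketch only,
// assuming the GetDeviceLapackInfo helper declared in cuda_solvers.h):
//
//   std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
//   std::vector<DeviceLapackInfo> dev_info;
//   dev_info.emplace_back(solver->GetDeviceLapackInfo(batch_size, "getrf"));
//   OP_REQUIRES_OK_ASYNC(context,
//                        solver->Getrf(m, n, a_ptr, lda, pivots_ptr,
//                                      dev_info.back().mutable_data()),
//                        done);
//   CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver),
//                                                   dev_info, std::move(done));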

// Allocates a temporary tensor. The CudaSolver object maintains a
// TensorReference to the underlying Tensor to prevent it from being
// deallocated prematurely.
Status CudaSolver::allocate_scoped_tensor(DataType type,
                                          const TensorShape& shape,
                                          Tensor* out_temp) {
  const Status status = context_->allocate_temp(type, shape, out_temp);
  if (status.ok()) {
    scratch_tensor_refs_.emplace_back(*out_temp);
  }
  return status;
}

Status CudaSolver::forward_input_or_allocate_scoped_tensor(
    gtl::ArraySlice<int> candidate_input_indices, DataType type,
    const TensorShape& shape, Tensor* out_temp) {
  const Status status = context_->forward_input_or_allocate_temp(
      candidate_input_indices, type, shape, out_temp);
  if (status.ok()) {
    scratch_tensor_refs_.emplace_back(*out_temp);
  }
  return status;
}

// Macro that specializes a solver method for all 4 standard
// numeric types.
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
#define TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) m(float, S) m(double, D)

// Macros to construct cusolverDn method names.
#define DN_SOLVER_FN(method, type_prefix) cusolverDn##type_prefix##method
#define DN_SOLVER_NAME(method, type_prefix) "cusolverDn" #type_prefix #method
#define DN_BUFSIZE_FN(method, type_prefix) \
  cusolverDn##type_prefix##method##_bufferSize

// Macros to construct cublas method names.
#define BLAS_SOLVER_FN(method, type_prefix) cublas##type_prefix##method
#define BLAS_SOLVER_NAME(method, type_prefix) "cublas" #type_prefix #method
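
// For example, DN_SOLVER_FN(getrf, S) expands to cusolverDnSgetrf,
// DN_BUFSIZE_FN(getrf, S) to cusolverDnSgetrf_bufferSize, and
// BLAS_SOLVER_FN(getrfBatched, S) to cublasSgetrfBatched.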

//=============================================================================
// Wrappers of cuSolverDN computational methods begin here.
//
// WARNING to implementers: The function signatures listed in the online docs
// are sometimes inaccurate, e.g., they are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cusolver_api.h header file.
//
// NOTE: The cuSolver functions called below appear not to be thread-safe,
// so we put a global lock around the calls. Since these functions only put a
// kernel on the shared stream, it is not a big performance hit.
// TODO(rmlarsen): Investigate if the locking is still needed in Cuda 9.
//=============================================================================

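// Geam computes the matrix-matrix addition/transposition
//   C = alpha * op(A) + beta * op(B).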
template <typename Scalar, typename SolverFnT>
static inline Status GeamImpl(SolverFnT solver, cublasHandle_t cublas_handle,
                              cublasOperation_t transa,
                              cublasOperation_t transb, int m, int n,
                              const Scalar* alpha, /* host or device pointer */
                              const Scalar* A, int lda,
                              const Scalar* beta, /* host or device pointer */
                              const Scalar* B, int ldb, Scalar* C, int ldc) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, transa, transb, m, n,
                                   reinterpret_cast<const CudaScalar*>(alpha),
                                   reinterpret_cast<const CudaScalar*>(A), lda,
                                   reinterpret_cast<const CudaScalar*>(beta),
                                   reinterpret_cast<const CudaScalar*>(B), ldb,
                                   reinterpret_cast<CudaScalar*>(C), ldc));
  return Status::OK();
}

#define GEAM_INSTANCE(Scalar, type_prefix)                                     \
  template <>                                                                  \
  Status CudaSolver::Geam<Scalar>(                                             \
      cublasOperation_t transa, cublasOperation_t transb, int m, int n,        \
      const Scalar* alpha, /* host or device pointer */                        \
      const Scalar* A, int lda,                                                \
      const Scalar* beta, /* host or device pointer */                         \
      const Scalar* B, int ldb, Scalar* C, int ldc) const {                    \
    return GeamImpl(BLAS_SOLVER_FN(geam, type_prefix), cublas_handle_, transa, \
                    transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);        \
  }

TF_CALL_LAPACK_TYPES(GEAM_INSTANCE);

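// Potrf computes the Cholesky factorization of a symmetric (Hermitian)
// positive definite matrix in place: A = L * L^H or A = U^H * U, depending
// on uplo.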
template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status PotrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasFillMode_t uplo, int n, Scalar* A, int lda,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, uplo, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, uplo, n, CUDAComplex(A), lda,
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return Status::OK();
}

#define POTRF_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status CudaSolver::Potrf<Scalar>(cublasFillMode_t uplo, int n, Scalar* A,  \
                                   int lda, int* dev_lapack_info) {          \
    return PotrfImpl(DN_BUFSIZE_FN(potrf, type_prefix),                      \
                     DN_SOLVER_FN(potrf, type_prefix), this, context_,       \
                     cusolver_dn_handle_, uplo, n, A, lda, dev_lapack_info); \
  }

TF_CALL_LAPACK_TYPES(POTRF_INSTANCE);

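// Getrf computes the LU factorization with partial pivoting of an m-by-n
// matrix in place; the pivot indices are returned in dev_pivots.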
template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GetrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, Scalar* A, int lda, int* dev_pivots,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, m, n, CUDAComplex(A), lda,
      CUDAComplex(dev_workspace.mutable_data()), dev_pivots, dev_lapack_info));
  return Status::OK();
}

#define GETRF_INSTANCE(Scalar, type_prefix)                                 \
  template <>                                                               \
  Status CudaSolver::Getrf<Scalar>(int m, int n, Scalar* A, int lda,        \
                                   int* dev_pivots, int* dev_lapack_info) { \
    return GetrfImpl(DN_BUFSIZE_FN(getrf, type_prefix),                     \
                     DN_SOLVER_FN(getrf, type_prefix), this, context_,      \
                     cusolver_dn_handle_, m, n, A, lda, dev_pivots,         \
                     dev_lapack_info);                                      \
  }

TF_CALL_LAPACK_TYPES(GETRF_INSTANCE);

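// Getrs solves op(A) * X = B for X, using the LU factorization and pivots
// previously computed by Getrf.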
template <typename Scalar, typename SolverFnT>
static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasOperation_t trans, int n, int nrhs,
                               const Scalar* A, int lda, const int* pivots,
                               Scalar* B, int ldb, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, trans, n, nrhs,
                                     CUDAComplex(A), lda, pivots,
                                     CUDAComplex(B), ldb, dev_lapack_info));
  return Status::OK();
}

#define GETRS_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status CudaSolver::Getrs<Scalar>(                                          \
      cublasOperation_t trans, int n, int nrhs, const Scalar* A, int lda,    \
      const int* pivots, Scalar* B, int ldb, int* dev_lapack_info) const {   \
    return GetrsImpl(DN_SOLVER_FN(getrs, type_prefix), context_,             \
                     cusolver_dn_handle_, trans, n, nrhs, A, lda, pivots, B, \
                     ldb, dev_lapack_info);                                  \
  }

TF_CALL_LAPACK_TYPES(GETRS_INSTANCE);

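// Geqrf computes the QR factorization A = Q * R; the Householder reflectors
// defining Q are stored in place in A and in tau.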
template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, Scalar* A, int lda, Scalar* tau,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, m, n, CUDAComplex(A), lda, CUDAComplex(tau),
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return Status::OK();
}

#define GEQRF_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                  \
  Status CudaSolver::Geqrf<Scalar>(int m, int n, Scalar* A, int lda,           \
                                   Scalar* tau, int* dev_lapack_info) {        \
    return GeqrfImpl(DN_BUFSIZE_FN(geqrf, type_prefix),                        \
                     DN_SOLVER_FN(geqrf, type_prefix), this, context_,         \
                     cusolver_dn_handle_, m, n, A, lda, tau, dev_lapack_info); \
  }

TF_CALL_LAPACK_TYPES(GEQRF_INSTANCE);

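// Unmqr (Ormqr for real types) overwrites dev_c with op(Q) * C or C * op(Q),
// where Q is the orthogonal/unitary matrix implicitly defined by the output
// of Geqrf.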
template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status UnmqrImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasSideMode_t side, cublasOperation_t trans,
                               int m, int n, int k, const Scalar* dev_a,
                               int lda, const Scalar* dev_tau, Scalar* dev_c,
                               int ldc, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
              CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
      CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc,
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return Status::OK();
}

// Unfortunately, the LAPACK function name differs for the real and complex
// cases (the complex ones are prefixed with "UN" for "unitary"), so we
// instantiate each one separately.
#define UNMQR_INSTANCE(Scalar, function_prefix, type_prefix)                  \
  template <>                                                                 \
  Status CudaSolver::Unmqr(cublasSideMode_t side, cublasOperation_t trans,    \
                           int m, int n, int k, const Scalar* dev_a, int lda, \
                           const Scalar* dev_tau, Scalar* dev_c, int ldc,     \
                           int* dev_lapack_info) {                            \
    return UnmqrImpl(DN_BUFSIZE_FN(function_prefix##mqr, type_prefix),        \
                     DN_SOLVER_FN(function_prefix##mqr, type_prefix), this,   \
                     context_, cusolver_dn_handle_, side, trans, m, n, k,     \
                     dev_a, lda, dev_tau, dev_c, ldc, dev_lapack_info);       \
  }

UNMQR_INSTANCE(float, or, S);
UNMQR_INSTANCE(double, or, D);
UNMQR_INSTANCE(complex64, un, C);
UNMQR_INSTANCE(complex128, un, Z);

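// Ungqr (Orgqr for real types) overwrites dev_a with the n leading columns of
// the orthogonal/unitary matrix Q defined by the output of Geqrf.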
template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status UngqrImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, int k, Scalar* dev_a, int lda,
                               const Scalar* dev_tau, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k,
                                      CUDAComplex(dev_a), lda,
                                      CUDAComplex(dev_tau), &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(
      solver(cusolver_dn_handle, m, n, k, CUDAComplex(dev_a), lda,
             CUDAComplex(dev_tau), CUDAComplex(dev_workspace.mutable_data()),
             lwork, dev_lapack_info));
  return Status::OK();
}

#define UNGQR_INSTANCE(Scalar, function_prefix, type_prefix)                \
  template <>                                                               \
  Status CudaSolver::Ungqr(int m, int n, int k, Scalar* dev_a, int lda,     \
                           const Scalar* dev_tau, int* dev_lapack_info) {   \
    return UngqrImpl(DN_BUFSIZE_FN(function_prefix##gqr, type_prefix),      \
                     DN_SOLVER_FN(function_prefix##gqr, type_prefix), this, \
                     context_, cusolver_dn_handle_, m, n, k, dev_a, lda,    \
                     dev_tau, dev_lapack_info);                             \
  }

UNGQR_INSTANCE(float, or, S);
UNGQR_INSTANCE(double, or, D);
UNGQR_INSTANCE(complex64, un, C);
UNGQR_INSTANCE(complex128, un, Z);

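// Heevd (Syevd for real types) computes the eigenvalues, and optionally the
// eigenvectors, of a Hermitian (symmetric) matrix; the eigenvalues are
// returned in ascending order in dev_W.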
template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status HeevdImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cusolverEigMode_t jobz, cublasFillMode_t uplo,
                               int n, Scalar* dev_A, int lda,
                               typename Eigen::NumTraits<Scalar>::Real* dev_W,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, jobz, uplo, n,
                                      CUDAComplex(dev_A), lda,
                                      CUDAComplex(dev_W), &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(
      solver(cusolver_dn_handle, jobz, uplo, n, CUDAComplex(dev_A), lda,
             CUDAComplex(dev_W), CUDAComplex(dev_workspace.mutable_data()),
             lwork, dev_lapack_info));
  return Status::OK();
}

#define HEEVD_INSTANCE(Scalar, function_prefix, type_prefix)                   \
  template <>                                                                  \
  Status CudaSolver::Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo,      \
                           int n, Scalar* dev_A, int lda,                      \
                           typename Eigen::NumTraits<Scalar>::Real* dev_W,     \
                           int* dev_lapack_info) {                             \
    return HeevdImpl(DN_BUFSIZE_FN(function_prefix##evd, type_prefix),         \
                     DN_SOLVER_FN(function_prefix##evd, type_prefix), this,    \
                     context_, cusolver_dn_handle_, jobz, uplo, n, dev_A, lda, \
                     dev_W, dev_lapack_info);                                  \
  }

HEEVD_INSTANCE(float, sy, S);
HEEVD_INSTANCE(double, sy, D);
HEEVD_INSTANCE(complex64, he, C);
HEEVD_INSTANCE(complex128, he, Z);

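// Gesvd computes the singular value decomposition A = U * S * V^H. Only the
// real types are instantiated below (see TF_CALL_LAPACK_TYPES_NO_COMPLEX).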
template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GesvdImpl(
    BufSizeFnT bufsize, SolverFnT solver, CudaSolver* cuda_solver,
    OpKernelContext* context, cusolverDnHandle_t cusolver_dn_handle,
    signed char jobu, signed char jobvt, int m, int n, Scalar* A, int lda,
    Scalar* S, Scalar* U, int ldu, Scalar* VT, int ldvt, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, jobu, jobvt, m, n,
                                     CUDAComplex(A), lda, S, CUDAComplex(U),
                                     ldu, CUDAComplex(VT), ldvt,
                                     CUDAComplex(dev_workspace.mutable_data()),
                                     lwork, nullptr, dev_lapack_info));
  return Status::OK();
}

#define GESVD_INSTANCE(Scalar, type_prefix)                              \
  template <>                                                            \
  Status CudaSolver::Gesvd<Scalar>(                                      \
      signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,  \
      int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,    \
      int ldvt, int* dev_lapack_info) {                                  \
    return GesvdImpl(DN_BUFSIZE_FN(gesvd, type_prefix),                  \
                     DN_SOLVER_FN(gesvd, type_prefix), this, context_,   \
                     cusolver_dn_handle_, jobu, jobvt, m, n, dev_A, lda, \
                     dev_S, dev_U, ldu, dev_VT, ldvt, dev_lapack_info);  \
  }

TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVD_INSTANCE);

//=============================================================================
// Wrappers of cuBlas computational methods begin here.
//
// WARNING to implementers: The function signatures listed in the online docs
// are sometimes inaccurate, e.g., they are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cublas_api.h header file.
//=============================================================================
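
// The batched cuBlas routines below all follow the same pattern: the host
// array of device pointers (one pointer per batch item) is first copied to a
// device scratch buffer, because cuBlas expects the pointer array itself to
// reside in device memory.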
template <typename Scalar, typename SolverFnT>
static inline Status GetrfBatchedImpl(SolverFnT solver, CudaSolver* cuda_solver,
                                      OpKernelContext* context,
                                      cublasHandle_t cublas_handle, int n,
                                      const Scalar* const host_a_dev_ptrs[],
                                      int lda, int* dev_pivots,
                                      DeviceLapackInfo* dev_lapack_info,
                                      int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("GetrfBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(
      solver(cublas_handle, n,
             reinterpret_cast<CudaScalar**>(dev_a_dev_ptrs.mutable_data()), lda,
             dev_pivots, dev_lapack_info->mutable_data(), batch_size));
  return Status::OK();
}

#define GETRF_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status CudaSolver::GetrfBatched(                                             \
      int n, const Scalar* const host_a_dev_ptrs[], int lda, int* dev_pivots,  \
      DeviceLapackInfo* dev_lapack_info, int batch_size) {                     \
    return GetrfBatchedImpl(BLAS_SOLVER_FN(getrfBatched, type_prefix), this,   \
                            context_, cublas_handle_, n, host_a_dev_ptrs, lda, \
                            dev_pivots, dev_lapack_info, batch_size);          \
  }

TF_CALL_LAPACK_TYPES(GETRF_BATCHED_INSTANCE);

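// GetrsBatched solves the batch of linear systems op(A[i]) * X[i] = B[i],
// using the LU factorizations and pivots computed by GetrfBatched.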
template <typename Scalar, typename SolverFnT>
static inline Status GetrsBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, cublasOperation_t trans, int n, int nrhs,
    const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
    const Scalar* const host_b_dev_ptrs[], int ldb,
    DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_b_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes())) {
    return errors::Internal("GetrsBatched: failed to copy pointers to device");
  }
  if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
                        host_b_dev_ptrs /* source */, dev_b_dev_ptrs.bytes())) {
    return errors::Internal("GetrsBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(solver(
      cublas_handle, trans, n, nrhs,
      reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
      dev_pivots, reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()),
      ldb, dev_lapack_info->mutable_data(), batch_size));
  return Status::OK();
}

#define GETRS_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status CudaSolver::GetrsBatched(                                             \
      cublasOperation_t trans, int n, int nrhs,                                \
      const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,   \
      const Scalar* const host_b_dev_ptrs[], int ldb,                          \
      DeviceLapackInfo* dev_lapack_info, int batch_size) {                     \
    return GetrsBatchedImpl(reinterpret_cast<getrs_##type_prefix*>(            \
                                BLAS_SOLVER_FN(getrsBatched, type_prefix)),    \
                            this, context_, cublas_handle_, trans, n, nrhs,    \
                            host_a_dev_ptrs, lda, dev_pivots, host_b_dev_ptrs, \
                            ldb, dev_lapack_info, batch_size);                 \
  }

TF_CALL_LAPACK_TYPES(GETRS_BATCHED_INSTANCE);

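// GetriBatched computes the inverses of a batch of matrices from the LU
// factorizations and pivots previously computed by GetrfBatched.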
template <typename Scalar, typename SolverFnT>
static inline Status GetriBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
    int lda, const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],
    int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
      sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
      !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
                        host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
    return errors::Internal("GetriBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(
      solver(cublas_handle, n,
             reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
             lda, dev_pivots,
             reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()),
             ldainv, dev_lapack_info->mutable_data(), batch_size));
  return Status::OK();
}

#define GETRI_BATCHED_INSTANCE(Scalar, type_prefix)                          \
  template <>                                                                \
  Status CudaSolver::GetriBatched(                                           \
      int n, const Scalar* const host_a_dev_ptrs[], int lda,                 \
      const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],      \
      int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {       \
    return GetriBatchedImpl(                                                 \
        reinterpret_cast<getri_##type_prefix*>(                              \
            BLAS_SOLVER_FN(getriBatched, type_prefix)),                      \
        this, context_, cublas_handle_, n, host_a_dev_ptrs, lda, dev_pivots, \
        host_a_inv_dev_ptrs, ldainv, dev_lapack_info, batch_size);           \
  }

TF_CALL_LAPACK_TYPES(GETRI_BATCHED_INSTANCE);

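// MatInvBatched inverts a batch of matrices directly, without a separate call
// to GetrfBatched; cuBlas documents this fused routine for small matrices
// only (n <= 32), so larger matrices should use GetrfBatched + GetriBatched.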
template <typename Scalar, typename SolverFnT>
static inline Status MatInvBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
    int lda, const Scalar* const host_a_inv_dev_ptrs[], int ldainv,
    DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
      sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */, dev_a_dev_ptrs.bytes()) ||
      !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
                        host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
    return errors::Internal("MatInvBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(solver(
      cublas_handle, n,
      reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
      reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()), ldainv,
      dev_lapack_info->mutable_data(), batch_size));
  return Status::OK();
}

#define MATINV_BATCHED_INSTANCE(Scalar, type_prefix)                          \
  template <>                                                                 \
  Status CudaSolver::MatInvBatched(                                           \
      int n, const Scalar* const host_a_dev_ptrs[], int lda,                  \
      const Scalar* const host_a_inv_dev_ptrs[], int ldainv,                  \
      DeviceLapackInfo* dev_lapack_info, int batch_size) {                    \
    return MatInvBatchedImpl(reinterpret_cast<matinv_##type_prefix*>(         \
                                 BLAS_SOLVER_FN(matinvBatched, type_prefix)), \
                             this, context_, cublas_handle_, n,               \
                             host_a_dev_ptrs, lda, host_a_inv_dev_ptrs,       \
                             ldainv, dev_lapack_info, batch_size);            \
  }

TF_CALL_LAPACK_TYPES(MATINV_BATCHED_INSTANCE);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA