/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
#ifdef GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_solvers.h"

#include <chrono>
#include <complex>
#include <unordered_map>
#include <vector>

#include "cuda/include/cublas_v2.h"
#include "cuda/include/cusolverDn.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/cuda.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"

using ::perftools::gputools::cuda::ScopedActivateExecutorContext;

// The CUDA cublas_api.h API contains const-correctness errors. Instead of
// casting away constness on our data, we reinterpret the CuBLAS functions
// as what they were clearly meant to be, and thus we can call the functions
// naturally.
//
// (The error is that input-only arrays are bound to parameter types
// "const T**" instead of the correct "const T* const*".)
extern "C" {
using getrs_S = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const float* const*, int, const int*, float**,
                               int, int*, int);
using getrs_D = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const double* const*, int, const int*, double**,
                               int, int*, int);
using getrs_C = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const float2* const*, int, const int*, float2**,
                               int, int*, int);
using getrs_Z = cublasStatus_t(cublasContext*, cublasOperation_t, int, int,
                               const double2* const*, int, const int*,
                               double2**, int, int*, int);

using getri_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
                               const int*, float**, int, int*, int);
using getri_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
                               const int*, double**, int, int*, int);
using getri_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
                               const int*, float2**, int, int*, int);
using getri_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
                               const int*, double2**, int, int*, int);

using matinv_S = cublasStatus_t(cublasContext*, int, const float* const*, int,
                                float**, int, int*, int);
using matinv_D = cublasStatus_t(cublasContext*, int, const double* const*, int,
                                double**, int, int*, int);
using matinv_C = cublasStatus_t(cublasContext*, int, const float2* const*, int,
                                float2**, int, int*, int);
using matinv_Z = cublasStatus_t(cublasContext*, int, const double2* const*,
                                int, double2**, int, int*, int);
}
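
// Illustrative note (a sketch, not part of the CuBLAS API): the batched
// cuBlas wrappers near the end of this file cast the library entry points to
// the corrected signatures above before calling them, e.g.:
//
//   auto* getrs = reinterpret_cast<getrs_S*>(cublasSgetrsBatched);
//   // 'getrs' now accepts "const float* const*" input arrays directly.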

namespace tensorflow {
namespace {

inline bool CopyHostToDevice(OpKernelContext* context, void* dst,
                             const void* src, uint64 bytes) {
  auto stream = context->op_device_context()->stream();
  perftools::gputools::DeviceMemoryBase wrapped_dst(dst);
  return stream->ThenMemcpy(&wrapped_dst, src, bytes).ok();
}

// A set of initialized handles to the underlying Cuda libraries used by
// CudaSolver. We maintain one such set of handles per unique stream.
struct CudaSolverHandles {
  explicit CudaSolverHandles(cudaStream_t stream) {
    CHECK(cusolverDnCreate(&cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
        << "Failed to create cuSolverDN instance.";
    CHECK(cusolverDnSetStream(cusolver_dn_handle, stream) ==
          CUSOLVER_STATUS_SUCCESS)
        << "Failed to set cuSolverDN stream.";
    CHECK(cublasCreate(&cublas_handle) == CUBLAS_STATUS_SUCCESS)
        << "Failed to create cuBlas instance.";
    CHECK(cublasSetStream(cublas_handle, stream) == CUBLAS_STATUS_SUCCESS)
        << "Failed to set cuBlas stream.";
  }

  ~CudaSolverHandles() {
    CHECK(cublasDestroy(cublas_handle) == CUBLAS_STATUS_SUCCESS)
        << "Failed to destroy cuBlas instance.";
    CHECK(cusolverDnDestroy(cusolver_dn_handle) == CUSOLVER_STATUS_SUCCESS)
        << "Failed to destroy cuSolverDN instance.";
  }
  cublasHandle_t cublas_handle;
  cusolverDnHandle_t cusolver_dn_handle;
};

static mutex handle_map_mutex(LINKER_INITIALIZED);

using HandleMap =
    std::unordered_map<cudaStream_t, std::unique_ptr<CudaSolverHandles>>;

// Returns a singleton map used for storing initialized handles for each unique
// cuda stream.
HandleMap* GetHandleMapSingleton() {
  static HandleMap* cm = new HandleMap;
  return cm;
}

}  // namespace

#define TF_RETURN_IF_CUSOLVER_ERROR(expr)                      \
  do {                                                         \
    auto status = (expr);                                      \
    if (TF_PREDICT_FALSE(status != CUSOLVER_STATUS_SUCCESS)) { \
      return errors::Internal(                                 \
          __FILE__, ":", __LINE__,                             \
          ": cuSolverDN call failed with status =", status);   \
    }                                                          \
  } while (0)

#define TF_RETURN_IF_CUBLAS_ERROR(expr)                                  \
  do {                                                                   \
    auto status = (expr);                                                \
    if (TF_PREDICT_FALSE(status != CUBLAS_STATUS_SUCCESS)) {             \
      return errors::Internal(__FILE__, ":", __LINE__,                   \
                              ": cuBlas call failed status = ", status); \
    }                                                                    \
  } while (0)

CudaSolver::CudaSolver(OpKernelContext* context) : context_(context) {
  mutex_lock lock(handle_map_mutex);
  const cudaStream_t* cu_stream_ptr = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(context->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->CudaStreamMemberHack()));
  cuda_stream_ = *cu_stream_ptr;
  HandleMap* handle_map = CHECK_NOTNULL(GetHandleMapSingleton());
  auto it = handle_map->find(cuda_stream_);
  if (it == handle_map->end()) {
    LOG(INFO) << "Creating CudaSolver handles for stream " << cuda_stream_;
    // Previously unseen Cuda stream. Initialize a set of Cuda solver library
    // handles for it.
    std::unique_ptr<CudaSolverHandles> new_handles(
        new CudaSolverHandles(cuda_stream_));
    it = handle_map
             ->insert(std::make_pair(cuda_stream_, std::move(new_handles)))
             .first;
  }
  cusolver_dn_handle_ = it->second->cusolver_dn_handle;
  cublas_handle_ = it->second->cublas_handle;
}

CudaSolver::~CudaSolver() {
  for (auto tensor_ref : scratch_tensor_refs_) {
    tensor_ref.Unref();
  }
}

// static
void CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
    std::unique_ptr<CudaSolver> solver,
    const std::vector<DeviceLapackInfo>& dev_lapack_infos,
    std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
        info_checker_callback) {
  CHECK(info_checker_callback != nullptr);
  std::vector<HostLapackInfo> host_lapack_infos;
  if (dev_lapack_infos.empty()) {
    info_checker_callback(Status::OK(), host_lapack_infos);
    return;
  }

  // Launch memcpys to copy info back from the device to the host.
  for (const auto& dev_lapack_info : dev_lapack_infos) {
    bool success = true;
    auto host_copy = dev_lapack_info.CopyToHost(&success);
    OP_REQUIRES(
        solver->context(), success,
        errors::Internal(
            "Failed to launch copy of dev_lapack_info to host, debug_info = ",
            dev_lapack_info.debug_info()));
    host_lapack_infos.push_back(std::move(host_copy));
  }

  // This callback checks that all batch items in all calls were processed
  // successfully and passes status to the info_checker_callback accordingly.
  auto* stream = solver->context()->op_device_context()->stream();
  auto wrapped_info_checker_callback =
      [stream](
          CudaSolver* solver,
          std::function<void(const Status&, const std::vector<HostLapackInfo>&)>
              info_checker_callback,
          std::vector<HostLapackInfo> host_lapack_infos) {
        ScopedActivateExecutorContext scoped_activation{stream->parent()};
        Status status;
        for (const auto& host_lapack_info : host_lapack_infos) {
          for (int i = 0; i < host_lapack_info.size() && status.ok(); ++i) {
            const int info_value = host_lapack_info(i);
            if (info_value != 0) {
              status = errors::InvalidArgument(
                  "Got info = ", info_value, " for batch index ", i,
                  ", expected info = 0. Debug_info = ",
                  host_lapack_info.debug_info());
            }
          }
          if (!status.ok()) {
            break;
          }
        }
        // Delete solver to release temp tensor refs.
        delete solver;

        // Delegate further error checking to the provided functor.
        info_checker_callback(status, host_lapack_infos);
      };
  // Note: An std::function cannot have unique_ptr arguments (it must be copy
  // constructible and therefore so must its arguments). Therefore, we release
  // solver into a raw pointer to be deleted at the end of
  // wrapped_info_checker_callback.
  // Release ownership of solver. It will be deleted in the cb callback.
  auto solver_raw_ptr = solver.release();
  auto cb =
      std::bind(wrapped_info_checker_callback, solver_raw_ptr,
                std::move(info_checker_callback), std::move(host_lapack_infos));

  solver_raw_ptr->context()
      ->device()
      ->tensorflow_gpu_device_info()
      ->event_mgr->ThenExecute(stream, std::move(cb));
}

// static
void CudaSolver::CheckLapackInfoAndDeleteSolverAsync(
    std::unique_ptr<CudaSolver> solver,
    const std::vector<DeviceLapackInfo>& dev_lapack_info,
    AsyncOpKernel::DoneCallback done) {
  OpKernelContext* context = solver->context();
  auto wrapped_done = [context, done](
                          const Status& status,
                          const std::vector<HostLapackInfo>& /* unused */) {
    if (done != nullptr) {
      OP_REQUIRES_OK_ASYNC(context, status, done);
      done();
    } else {
      OP_REQUIRES_OK(context, status);
    }
  };
  CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_lapack_info,
                                      wrapped_done);
}
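
// Illustrative usage sketch (not code from this file): an asynchronous op
// kernel typically heap-allocates a CudaSolver, enqueues solver calls that
// write per-batch status into DeviceLapackInfo buffers, and then hands
// ownership of the solver to one of the overloads above, e.g.:
//
//   std::unique_ptr<CudaSolver> solver(new CudaSolver(context));
//   std::vector<DeviceLapackInfo> dev_info;
//   // ... enqueue solver calls that record their status in dev_info ...
//   CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver),
//                                                   dev_info, done);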

// Allocates a temporary tensor. The CudaSolver object maintains a
// TensorReference to the underlying Tensor to prevent it from being
// deallocated prematurely.
Status CudaSolver::allocate_scoped_tensor(DataType type,
                                          const TensorShape& shape,
                                          Tensor* out_temp) {
  const Status status = context_->allocate_temp(type, shape, out_temp);
  if (status.ok()) {
    scratch_tensor_refs_.emplace_back(*out_temp);
  }
  return status;
}

Status CudaSolver::forward_input_or_allocate_scoped_tensor(
    gtl::ArraySlice<int> candidate_input_indices, DataType type,
    const TensorShape& shape, Tensor* out_temp) {
  const Status status = context_->forward_input_or_allocate_temp(
      candidate_input_indices, type, shape, out_temp);
  if (status.ok()) {
    scratch_tensor_refs_.emplace_back(*out_temp);
  }
  return status;
}

// Macro that specializes a solver method for all 4 standard
// numeric types.
#define TF_CALL_LAPACK_TYPES(m) \
  m(float, S) m(double, D) m(std::complex<float>, C) m(std::complex<double>, Z)
#define TF_CALL_LAPACK_TYPES_NO_COMPLEX(m) m(float, S) m(double, D)

// Macros to construct cusolverDn method names.
#define DN_SOLVER_FN(method, type_prefix) cusolverDn##type_prefix##method
#define DN_SOLVER_NAME(method, type_prefix) "cusolverDn" #type_prefix #method
#define DN_BUFSIZE_FN(method, type_prefix) \
  cusolverDn##type_prefix##method##_bufferSize

// Macros to construct cublas method names.
#define BLAS_SOLVER_FN(method, type_prefix) cublas##type_prefix##method
#define BLAS_SOLVER_NAME(method, type_prefix) "cublas" #type_prefix #method
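
// For example, DN_SOLVER_FN(potrf, S) expands to cusolverDnSpotrf,
// DN_BUFSIZE_FN(potrf, S) to cusolverDnSpotrf_bufferSize, and
// BLAS_SOLVER_FN(geam, Z) to cublasZgeam.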

//=============================================================================
// Wrappers of cuSolverDN computational methods begin here.
//
// WARNING to implementers: The function signatures listed in the online docs
// are sometimes inaccurate, e.g., are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cusolver_api.h header file.
//
// NOTE: The cuSolver functions called below appear not to be threadsafe,
// so we put a global lock around the calls. Since these functions only put a
// kernel on the shared stream, it is not a big performance hit.
// TODO(rmlarsen): Investigate if the locking is still needed in Cuda 9.
//=============================================================================

template <typename Scalar, typename SolverFnT>
static inline Status GeamImpl(SolverFnT solver, cublasHandle_t cublas_handle,
                              cublasOperation_t transa,
                              cublasOperation_t transb, int m, int n,
                              const Scalar* alpha, /* host or device pointer */
                              const Scalar* A, int lda,
                              const Scalar* beta, /* host or device pointer */
                              const Scalar* B, int ldb, Scalar* C, int ldc) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  TF_RETURN_IF_CUBLAS_ERROR(solver(cublas_handle, transa, transb, m, n,
                                   reinterpret_cast<const CudaScalar*>(alpha),
                                   reinterpret_cast<const CudaScalar*>(A), lda,
                                   reinterpret_cast<const CudaScalar*>(beta),
                                   reinterpret_cast<const CudaScalar*>(B), ldb,
                                   reinterpret_cast<CudaScalar*>(C), ldc));
  return Status::OK();
}

#define GEAM_INSTANCE(Scalar, type_prefix)                                     \
  template <>                                                                  \
  Status CudaSolver::Geam<Scalar>(                                             \
      cublasOperation_t transa, cublasOperation_t transb, int m, int n,        \
      const Scalar* alpha, /* host or device pointer */                        \
      const Scalar* A, int lda,                                                \
      const Scalar* beta, /* host or device pointer */                         \
      const Scalar* B, int ldb, Scalar* C, int ldc) const {                    \
    return GeamImpl(BLAS_SOLVER_FN(geam, type_prefix), cublas_handle_, transa, \
                    transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);        \
  }

TF_CALL_LAPACK_TYPES(GEAM_INSTANCE);
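
// The dense cuSolverDN wrappers below all follow the LAPACK-style workspace
// convention: query the required workspace size with the *_bufferSize
// routine, allocate device scratch through the CudaSolver, then launch the
// solver kernel with that workspace. Illustrative shape of each wrapper
// (a sketch, not literal code):
//
//   int lwork;
//   bufsize(handle, ..., &lwork);                      // workspace query
//   auto workspace =
//       cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
//   solver(handle, ..., workspace.mutable_data(), lwork, dev_lapack_info);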

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status PotrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasFillMode_t uplo, int n, Scalar* A, int lda,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, uplo, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, uplo, n, CUDAComplex(A), lda,
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return Status::OK();
}

#define POTRF_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status CudaSolver::Potrf<Scalar>(cublasFillMode_t uplo, int n, Scalar* A,  \
                                   int lda, int* dev_lapack_info) {          \
    return PotrfImpl(DN_BUFSIZE_FN(potrf, type_prefix),                      \
                     DN_SOLVER_FN(potrf, type_prefix), this, context_,       \
                     cusolver_dn_handle_, uplo, n, A, lda, dev_lapack_info); \
  }

TF_CALL_LAPACK_TYPES(POTRF_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GetrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, Scalar* A, int lda, int* dev_pivots,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, m, n, CUDAComplex(A), lda,
      CUDAComplex(dev_workspace.mutable_data()), dev_pivots, dev_lapack_info));
  return Status::OK();
}

#define GETRF_INSTANCE(Scalar, type_prefix)                                 \
  template <>                                                               \
  Status CudaSolver::Getrf<Scalar>(int m, int n, Scalar* A, int lda,        \
                                   int* dev_pivots, int* dev_lapack_info) { \
    return GetrfImpl(DN_BUFSIZE_FN(getrf, type_prefix),                     \
                     DN_SOLVER_FN(getrf, type_prefix), this, context_,      \
                     cusolver_dn_handle_, m, n, A, lda, dev_pivots,         \
                     dev_lapack_info);                                      \
  }

TF_CALL_LAPACK_TYPES(GETRF_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status GetrsImpl(SolverFnT solver, OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasOperation_t trans, int n, int nrhs,
                               const Scalar* A, int lda, const int* pivots,
                               Scalar* B, int ldb, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, trans, n, nrhs,
                                     CUDAComplex(A), lda, pivots,
                                     CUDAComplex(B), ldb, dev_lapack_info));
  return Status::OK();
}

#define GETRS_INSTANCE(Scalar, type_prefix)                                  \
  template <>                                                                \
  Status CudaSolver::Getrs<Scalar>(                                          \
      cublasOperation_t trans, int n, int nrhs, const Scalar* A, int lda,    \
      const int* pivots, Scalar* B, int ldb, int* dev_lapack_info) const {   \
    return GetrsImpl(DN_SOLVER_FN(getrs, type_prefix), context_,             \
                     cusolver_dn_handle_, trans, n, nrhs, A, lda, pivots, B, \
                     ldb, dev_lapack_info);                                  \
  }

TF_CALL_LAPACK_TYPES(GETRS_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GeqrfImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, Scalar* A, int lda, Scalar* tau,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, m, n, CUDAComplex(A), lda, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, m, n, CUDAComplex(A), lda, CUDAComplex(tau),
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return Status::OK();
}

#define GEQRF_INSTANCE(Scalar, type_prefix)                                    \
  template <>                                                                  \
  Status CudaSolver::Geqrf<Scalar>(int m, int n, Scalar* A, int lda,           \
                                   Scalar* tau, int* dev_lapack_info) {        \
    return GeqrfImpl(DN_BUFSIZE_FN(geqrf, type_prefix),                        \
                     DN_SOLVER_FN(geqrf, type_prefix), this, context_,         \
                     cusolver_dn_handle_, m, n, A, lda, tau, dev_lapack_info); \
  }

TF_CALL_LAPACK_TYPES(GEQRF_INSTANCE);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status UnmqrImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cublasSideMode_t side, cublasOperation_t trans,
                               int m, int n, int k, const Scalar* dev_a,
                               int lda, const Scalar* dev_tau, Scalar* dev_c,
                               int ldc, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(
      bufsize(cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
              CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(solver(
      cusolver_dn_handle, side, trans, m, n, k, CUDAComplex(dev_a), lda,
      CUDAComplex(dev_tau), CUDAComplex(dev_c), ldc,
      CUDAComplex(dev_workspace.mutable_data()), lwork, dev_lapack_info));
  return Status::OK();
}

// Unfortunately the LAPACK function name differs for the real and complex case
// (complex ones are prefixed with "UN" for "unitary"), so we instantiate each
// one separately.
#define UNMQR_INSTANCE(Scalar, function_prefix, type_prefix)                  \
  template <>                                                                 \
  Status CudaSolver::Unmqr(cublasSideMode_t side, cublasOperation_t trans,    \
                           int m, int n, int k, const Scalar* dev_a, int lda, \
                           const Scalar* dev_tau, Scalar* dev_c, int ldc,     \
                           int* dev_lapack_info) {                            \
    return UnmqrImpl(DN_BUFSIZE_FN(function_prefix##mqr, type_prefix),        \
                     DN_SOLVER_FN(function_prefix##mqr, type_prefix), this,   \
                     context_, cusolver_dn_handle_, side, trans, m, n, k,     \
                     dev_a, lda, dev_tau, dev_c, ldc, dev_lapack_info);       \
  }

UNMQR_INSTANCE(float, or, S);
UNMQR_INSTANCE(double, or, D);
UNMQR_INSTANCE(complex64, un, C);
UNMQR_INSTANCE(complex128, un, Z);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status UngqrImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle, int m,
                               int n, int k, Scalar* dev_a, int lda,
                               const Scalar* dev_tau, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, k,
                                      CUDAComplex(dev_a), lda,
                                      CUDAComplex(dev_tau), &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(
      solver(cusolver_dn_handle, m, n, k, CUDAComplex(dev_a), lda,
             CUDAComplex(dev_tau), CUDAComplex(dev_workspace.mutable_data()),
             lwork, dev_lapack_info));
  return Status::OK();
}

#define UNGQR_INSTANCE(Scalar, function_prefix, type_prefix)                \
  template <>                                                               \
  Status CudaSolver::Ungqr(int m, int n, int k, Scalar* dev_a, int lda,     \
                           const Scalar* dev_tau, int* dev_lapack_info) {   \
    return UngqrImpl(DN_BUFSIZE_FN(function_prefix##gqr, type_prefix),      \
                     DN_SOLVER_FN(function_prefix##gqr, type_prefix), this, \
                     context_, cusolver_dn_handle_, m, n, k, dev_a, lda,    \
                     dev_tau, dev_lapack_info);                             \
  }

UNGQR_INSTANCE(float, or, S);
UNGQR_INSTANCE(double, or, D);
UNGQR_INSTANCE(complex64, un, C);
UNGQR_INSTANCE(complex128, un, Z);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status HeevdImpl(BufSizeFnT bufsize, SolverFnT solver,
                               CudaSolver* cuda_solver,
                               OpKernelContext* context,
                               cusolverDnHandle_t cusolver_dn_handle,
                               cusolverEigMode_t jobz, cublasFillMode_t uplo,
                               int n, Scalar* dev_A, int lda,
                               typename Eigen::NumTraits<Scalar>::Real* dev_W,
                               int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, jobz, uplo, n,
                                      CUDAComplex(dev_A), lda,
                                      CUDAComplex(dev_W), &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  /* Launch the solver kernel. */
  TF_RETURN_IF_CUSOLVER_ERROR(
      solver(cusolver_dn_handle, jobz, uplo, n, CUDAComplex(dev_A), lda,
             CUDAComplex(dev_W), CUDAComplex(dev_workspace.mutable_data()),
             lwork, dev_lapack_info));
  return Status::OK();
}

#define HEEVD_INSTANCE(Scalar, function_prefix, type_prefix)                   \
  template <>                                                                  \
  Status CudaSolver::Heevd(cusolverEigMode_t jobz, cublasFillMode_t uplo,      \
                           int n, Scalar* dev_A, int lda,                      \
                           typename Eigen::NumTraits<Scalar>::Real* dev_W,     \
                           int* dev_lapack_info) {                             \
    return HeevdImpl(DN_BUFSIZE_FN(function_prefix##evd, type_prefix),         \
                     DN_SOLVER_FN(function_prefix##evd, type_prefix), this,    \
                     context_, cusolver_dn_handle_, jobz, uplo, n, dev_A, lda, \
                     dev_W, dev_lapack_info);                                  \
  }

HEEVD_INSTANCE(float, sy, S);
HEEVD_INSTANCE(double, sy, D);
HEEVD_INSTANCE(complex64, he, C);
HEEVD_INSTANCE(complex128, he, Z);

template <typename Scalar, typename BufSizeFnT, typename SolverFnT>
static inline Status GesvdImpl(
    BufSizeFnT bufsize, SolverFnT solver, CudaSolver* cuda_solver,
    OpKernelContext* context, cusolverDnHandle_t cusolver_dn_handle,
    signed char jobu, signed char jobvt, int m, int n, Scalar* A, int lda,
    Scalar* S, Scalar* U, int ldu, Scalar* VT, int ldvt, int* dev_lapack_info) {
  mutex_lock lock(handle_map_mutex);
  /* Get amount of workspace memory required. */
  int lwork;
  TF_RETURN_IF_CUSOLVER_ERROR(bufsize(cusolver_dn_handle, m, n, &lwork));
  /* Allocate device memory for workspace. */
  auto dev_workspace =
      cuda_solver->GetScratchSpace<Scalar>(lwork, "", /* on_host */ false);
  TF_RETURN_IF_CUSOLVER_ERROR(solver(cusolver_dn_handle, jobu, jobvt, m, n,
                                     CUDAComplex(A), lda, S, CUDAComplex(U),
                                     ldu, CUDAComplex(VT), ldvt,
                                     CUDAComplex(dev_workspace.mutable_data()),
                                     lwork, nullptr, dev_lapack_info));
  return Status::OK();
}

#define GESVD_INSTANCE(Scalar, type_prefix)                              \
  template <>                                                            \
  Status CudaSolver::Gesvd<Scalar>(                                      \
      signed char jobu, signed char jobvt, int m, int n, Scalar* dev_A,  \
      int lda, Scalar* dev_S, Scalar* dev_U, int ldu, Scalar* dev_VT,    \
      int ldvt, int* dev_lapack_info) {                                  \
    return GesvdImpl(DN_BUFSIZE_FN(gesvd, type_prefix),                  \
                     DN_SOLVER_FN(gesvd, type_prefix), this, context_,   \
                     cusolver_dn_handle_, jobu, jobvt, m, n, dev_A, lda, \
                     dev_S, dev_U, ldu, dev_VT, ldvt, dev_lapack_info);  \
  }

TF_CALL_LAPACK_TYPES_NO_COMPLEX(GESVD_INSTANCE);

//=============================================================================
// Wrappers of cuBlas computational methods begin here.
//
// WARNING to implementers: The function signatures listed in the online docs
// are sometimes inaccurate, e.g., are missing 'const' on pointers
// to immutable arguments, while the actual headers have them as expected.
// Check the actual declarations in the cublas_api.h header file.
//=============================================================================
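
// The batched wrappers below take a host-side array of device pointers, one
// pointer per batch item. Each wrapper first copies that pointer array into
// device scratch memory, because the cuBlas batched kernels expect the
// pointer array itself to reside on the device. Illustrative call shape
// (a sketch; the matrix pointers must point to device memory):
//
//   const Scalar* host_ptrs[kBatchSize] = {dev_matrix_0, dev_matrix_1, ...};
//   TF_RETURN_IF_ERROR(solver->GetrfBatched(n, host_ptrs, lda, dev_pivots,
//                                           &dev_lapack_info, kBatchSize));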
template <typename Scalar, typename SolverFnT>
static inline Status GetrfBatchedImpl(SolverFnT solver,
                                      CudaSolver* cuda_solver,
                                      OpKernelContext* context,
                                      cublasHandle_t cublas_handle, int n,
                                      const Scalar* const host_a_dev_ptrs[],
                                      int lda, int* dev_pivots,
                                      DeviceLapackInfo* dev_lapack_info,
                                      int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */,
                        dev_a_dev_ptrs.bytes())) {
    return errors::Internal("GetrfBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(
      solver(cublas_handle, n,
             reinterpret_cast<CudaScalar**>(dev_a_dev_ptrs.mutable_data()),
             lda, dev_pivots, dev_lapack_info->mutable_data(), batch_size));
  return Status::OK();
}

#define GETRF_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status CudaSolver::GetrfBatched(                                             \
      int n, const Scalar* const host_a_dev_ptrs[], int lda, int* dev_pivots,  \
      DeviceLapackInfo* dev_lapack_info, int batch_size) {                     \
    return GetrfBatchedImpl(BLAS_SOLVER_FN(getrfBatched, type_prefix), this,   \
                            context_, cublas_handle_, n, host_a_dev_ptrs, lda, \
                            dev_pivots, dev_lapack_info, batch_size);          \
  }

TF_CALL_LAPACK_TYPES(GETRF_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status GetrsBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, cublasOperation_t trans, int n, int nrhs,
    const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,
    const Scalar* const host_b_dev_ptrs[], int ldb,
    DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_b_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */,
                        dev_a_dev_ptrs.bytes())) {
    return errors::Internal("GetrsBatched: failed to copy pointers to device");
  }
  if (!CopyHostToDevice(context, dev_b_dev_ptrs.mutable_data() /* dest */,
                        host_b_dev_ptrs /* source */,
                        dev_b_dev_ptrs.bytes())) {
    return errors::Internal("GetrsBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(solver(
      cublas_handle, trans, n, nrhs,
      reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
      dev_pivots,
      reinterpret_cast<CudaScalar**>(dev_b_dev_ptrs.mutable_data()), ldb,
      dev_lapack_info->mutable_data(), batch_size));
  return Status::OK();
}

#define GETRS_BATCHED_INSTANCE(Scalar, type_prefix)                            \
  template <>                                                                  \
  Status CudaSolver::GetrsBatched(                                             \
      cublasOperation_t trans, int n, int nrhs,                                \
      const Scalar* const host_a_dev_ptrs[], int lda, const int* dev_pivots,   \
      const Scalar* const host_b_dev_ptrs[], int ldb,                          \
      DeviceLapackInfo* dev_lapack_info, int batch_size) {                     \
    return GetrsBatchedImpl(reinterpret_cast<getrs_##type_prefix*>(            \
                                BLAS_SOLVER_FN(getrsBatched, type_prefix)),    \
                            this, context_, cublas_handle_, trans, n, nrhs,    \
                            host_a_dev_ptrs, lda, dev_pivots, host_b_dev_ptrs, \
                            ldb, dev_lapack_info, batch_size);                 \
  }

TF_CALL_LAPACK_TYPES(GETRS_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status GetriBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
    int lda, const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],
    int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
      sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */,
                        dev_a_dev_ptrs.bytes()) ||
      !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
                        host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
    return errors::Internal("GetriBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(
      solver(cublas_handle, n,
             reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()),
             lda, dev_pivots,
             reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()),
             ldainv, dev_lapack_info->mutable_data(), batch_size));
  return Status::OK();
}

#define GETRI_BATCHED_INSTANCE(Scalar, type_prefix)                          \
  template <>                                                                \
  Status CudaSolver::GetriBatched(                                           \
      int n, const Scalar* const host_a_dev_ptrs[], int lda,                 \
      const int* dev_pivots, const Scalar* const host_a_inv_dev_ptrs[],      \
      int ldainv, DeviceLapackInfo* dev_lapack_info, int batch_size) {       \
    return GetriBatchedImpl(                                                 \
        reinterpret_cast<getri_##type_prefix*>(                              \
            BLAS_SOLVER_FN(getriBatched, type_prefix)),                      \
        this, context_, cublas_handle_, n, host_a_dev_ptrs, lda, dev_pivots, \
        host_a_inv_dev_ptrs, ldainv, dev_lapack_info, batch_size);           \
  }

TF_CALL_LAPACK_TYPES(GETRI_BATCHED_INSTANCE);

template <typename Scalar, typename SolverFnT>
static inline Status MatInvBatchedImpl(
    SolverFnT solver, CudaSolver* cuda_solver, OpKernelContext* context,
    cublasHandle_t cublas_handle, int n, const Scalar* const host_a_dev_ptrs[],
    int lda, const Scalar* const host_a_inv_dev_ptrs[], int ldainv,
    DeviceLapackInfo* dev_lapack_info, int batch_size) {
  mutex_lock lock(handle_map_mutex);
  using CudaScalar = typename CUDAComplexT<Scalar>::type;
  ScratchSpace<uint8> dev_a_dev_ptrs =
      cuda_solver->GetScratchSpace<uint8>(sizeof(CudaScalar*) * batch_size, "",
                                          /* on_host */ false);
  ScratchSpace<uint8> dev_a_inv_dev_ptrs = cuda_solver->GetScratchSpace<uint8>(
      sizeof(CudaScalar*) * batch_size, "", /* on_host */ false);
  if (!CopyHostToDevice(context, dev_a_dev_ptrs.mutable_data() /* dest */,
                        host_a_dev_ptrs /* source */,
                        dev_a_dev_ptrs.bytes()) ||
      !CopyHostToDevice(context, dev_a_inv_dev_ptrs.mutable_data(),
                        host_a_inv_dev_ptrs, dev_a_inv_dev_ptrs.bytes())) {
    return errors::Internal("MatInvBatched: failed to copy pointers to device");
  }
  TF_RETURN_IF_CUBLAS_ERROR(solver(
      cublas_handle, n,
      reinterpret_cast<const CudaScalar* const*>(dev_a_dev_ptrs.data()), lda,
      reinterpret_cast<CudaScalar**>(dev_a_inv_dev_ptrs.mutable_data()),
      ldainv, dev_lapack_info->mutable_data(), batch_size));
  return Status::OK();
}

#define MATINV_BATCHED_INSTANCE(Scalar, type_prefix)                          \
  template <>                                                                 \
  Status CudaSolver::MatInvBatched(                                           \
      int n, const Scalar* const host_a_dev_ptrs[], int lda,                  \
      const Scalar* const host_a_inv_dev_ptrs[], int ldainv,                  \
      DeviceLapackInfo* dev_lapack_info, int batch_size) {                    \
    return MatInvBatchedImpl(reinterpret_cast<matinv_##type_prefix*>(         \
                                 BLAS_SOLVER_FN(matinvBatched, type_prefix)), \
                             this, context_, cublas_handle_, n,               \
                             host_a_dev_ptrs, lda, host_a_inv_dev_ptrs,       \
                             ldainv, dev_lapack_info, batch_size);            \
  }

TF_CALL_LAPACK_TYPES(MATINV_BATCHED_INSTANCE);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA