1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Include cuBLAS headers early, and then set EIGEN_HAS_CUDA_FP16 17 // if we have new enough CUDA (which we will only know after including 18 // cuda.h). This ensures that Eigen's Half.h does not attempt to make its own 19 // __half typedef if CUDA has already defined one (and conversely, that we do 20 // not include <cuda_fp16.h> after Half.h has made its typedef). 
#include "cuda/include/cuda.h"
#include "cuda/include/cublas_v2.h"

#if CUDA_VERSION >= 7050
#define EIGEN_HAS_CUDA_FP16
#endif

#if CUDA_VERSION >= 8000
#define SE_CUDA_DATA_HALF CUDA_R_16F
#else
#define SE_CUDA_DATA_HALF CUBLAS_DATA_HALF
#endif

#include "tensorflow/stream_executor/cuda/cuda_blas.h"

#include <assert.h>
#include <complex>

#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/cuda/cuda_timer.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/status_macros.h"
#include "tensorflow/stream_executor/lib/strcat.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/scratch_allocator.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace perftools {
namespace gputools {
namespace cuda {

// Registers this cuBLAS implementation as a BLAS plugin.
PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuBlasPlugin);

namespace wrap {

// Generates a callable shim wrap::__name for the cuBLAS entry point ::__name.
// The shim activates the executor's CUDA context (via
// ScopedActivateExecutorContext) before forwarding the call, and records the
// routine name in kName so failures can be reported by name.
#define PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(__name)                      \
  struct WrapperShim__##__name {                                    \
    static const char *kName;                                       \
    template <typename... Args>                                     \
    cublasStatus_t operator()(CUDAExecutor *parent, Args... args) { \
      cuda::ScopedActivateExecutorContext sac{parent};              \
      return ::__name(args...);                                     \
    }                                                               \
  } __name;                                                         \
  const char *WrapperShim__##__name::kName = #__name;

// The cuBLAS v2 entry points are wrapped identically to the others.
#define PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(__name) \
  PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(__name)

// Invokes __macro once for every BLAS level-1/2/3 routine wrapped below.
#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
  __macro(cublasSnrm2)                    \
  __macro(cublasDnrm2)                    \
  __macro(cublasScnrm2)                   \
  __macro(cublasDznrm2)                   \
  __macro(cublasSdot)                     \
  __macro(cublasDdot)                     \
  __macro(cublasCdotu)                    \
  __macro(cublasCdotc)                    \
  __macro(cublasZdotu)                    \
  __macro(cublasZdotc)                    \
  __macro(cublasSscal)                    \
  __macro(cublasDscal)                    \
  __macro(cublasCscal)                    \
  __macro(cublasCsscal)                   \
  __macro(cublasZscal)                    \
  __macro(cublasZdscal)                   \
  __macro(cublasSaxpy)                    \
  __macro(cublasDaxpy)                    \
  __macro(cublasCaxpy)                    \
  __macro(cublasZaxpy)                    \
  __macro(cublasScopy)                    \
  __macro(cublasDcopy)                    \
  __macro(cublasCcopy)                    \
  __macro(cublasZcopy)                    \
  __macro(cublasSswap)                    \
  __macro(cublasDswap)                    \
  __macro(cublasCswap)                    \
  __macro(cublasZswap)                    \
  __macro(cublasIsamax)                   \
  __macro(cublasIdamax)                   \
  __macro(cublasIcamax)                   \
  __macro(cublasIzamax)                   \
  __macro(cublasIsamin)                   \
  __macro(cublasIdamin)                   \
  __macro(cublasIcamin)                   \
  __macro(cublasIzamin)                   \
  __macro(cublasSasum)                    \
  __macro(cublasDasum)                    \
  __macro(cublasScasum)                   \
  __macro(cublasDzasum)                   \
  __macro(cublasSrot)                     \
  __macro(cublasDrot)                     \
  __macro(cublasCrot)                     \
  __macro(cublasCsrot)                    \
  __macro(cublasZrot)                     \
  __macro(cublasZdrot)                    \
  __macro(cublasSrotg)                    \
  __macro(cublasDrotg)                    \
  __macro(cublasCrotg)                    \
  __macro(cublasZrotg)                    \
  __macro(cublasSrotm)                    \
  __macro(cublasDrotm)                    \
  __macro(cublasSrotmg)                   \
  __macro(cublasDrotmg)                   \
  __macro(cublasSgemv)                    \
  __macro(cublasDgemv)                    \
  __macro(cublasCgemv)                    \
  __macro(cublasZgemv)                    \
  __macro(cublasSgbmv)                    \
  __macro(cublasDgbmv)                    \
  __macro(cublasCgbmv)                    \
  __macro(cublasZgbmv)                    \
  __macro(cublasStrmv)                    \
  __macro(cublasDtrmv)                    \
  __macro(cublasCtrmv)                    \
  __macro(cublasZtrmv)                    \
  __macro(cublasStbmv)                    \
  __macro(cublasDtbmv)                    \
  __macro(cublasCtbmv)                    \
  __macro(cublasZtbmv)                    \
  __macro(cublasStpmv)                    \
  __macro(cublasDtpmv)                    \
  __macro(cublasCtpmv)                    \
  __macro(cublasZtpmv)                    \
  __macro(cublasStrsv)                    \
  __macro(cublasDtrsv)                    \
  __macro(cublasCtrsv)                    \
  __macro(cublasZtrsv)                    \
  __macro(cublasStpsv)                    \
  __macro(cublasDtpsv)                    \
  __macro(cublasCtpsv)                    \
  __macro(cublasZtpsv)                    \
  __macro(cublasStbsv)                    \
  __macro(cublasDtbsv)                    \
  __macro(cublasCtbsv)                    \
  __macro(cublasZtbsv)                    \
  __macro(cublasSsymv)                    \
  __macro(cublasDsymv)                    \
  __macro(cublasCsymv)                    \
  __macro(cublasZsymv)                    \
  __macro(cublasChemv)                    \
  __macro(cublasZhemv)                    \
  __macro(cublasSsbmv)                    \
  __macro(cublasDsbmv)                    \
  __macro(cublasChbmv)                    \
  __macro(cublasZhbmv)                    \
  __macro(cublasSspmv)                    \
  __macro(cublasDspmv)                    \
  __macro(cublasChpmv)                    \
  __macro(cublasZhpmv)                    \
  __macro(cublasSger)                     \
  __macro(cublasDger)                     \
  __macro(cublasCgeru)                    \
  __macro(cublasCgerc)                    \
  __macro(cublasZgeru)                    \
  __macro(cublasZgerc)                    \
  __macro(cublasSsyr)                     \
  __macro(cublasDsyr)                     \
  __macro(cublasCsyr)                     \
  __macro(cublasZsyr)                     \
  __macro(cublasCher)                     \
  __macro(cublasZher)                     \
  __macro(cublasSspr)                     \
  __macro(cublasDspr)                     \
  __macro(cublasChpr)                     \
  __macro(cublasZhpr)                     \
  __macro(cublasSsyr2)                    \
  __macro(cublasDsyr2)                    \
  __macro(cublasCsyr2)                    \
  __macro(cublasZsyr2)                    \
  __macro(cublasCher2)                    \
  __macro(cublasZher2)                    \
  __macro(cublasSspr2)                    \
  __macro(cublasDspr2)                    \
  __macro(cublasChpr2)                    \
  __macro(cublasZhpr2)                    \
  __macro(cublasSgemm)                    \
  __macro(cublasDgemm)                    \
  __macro(cublasCgemm)                    \
  __macro(cublasZgemm)                    \
  __macro(cublasSsyrk)                    \
  __macro(cublasDsyrk)                    \
  __macro(cublasCsyrk)                    \
  __macro(cublasZsyrk)                    \
  __macro(cublasCherk)                    \
  __macro(cublasZherk)                    \
  __macro(cublasSsyr2k)                   \
  __macro(cublasDsyr2k)                   \
  __macro(cublasCsyr2k)                   \
  __macro(cublasZsyr2k)                   \
  __macro(cublasCher2k)                   \
  __macro(cublasZher2k)                   \
  __macro(cublasSsyrkx)                   \
  __macro(cublasDsyrkx)                   \
  __macro(cublasCsyrkx)                   \
  __macro(cublasZsyrkx)                   \
  __macro(cublasCherkx)                   \
  __macro(cublasZherkx)                   \
  __macro(cublasSsymm)                    \
  __macro(cublasDsymm)                    \
  __macro(cublasCsymm)                    \
  __macro(cublasZsymm)                    \
  __macro(cublasChemm)                    \
  __macro(cublasZhemm)                    \
  __macro(cublasStrsm)                    \
  __macro(cublasDtrsm)                    \
  __macro(cublasCtrsm)                    \
  __macro(cublasZtrsm)                    \
  __macro(cublasStrmm)                    \
  __macro(cublasDtrmm)                    \
  __macro(cublasCtrmm)                    \
  __macro(cublasZtrmm)                    \
  __macro(cublasSgeam)                    \
  __macro(cublasDgeam)                    \
  __macro(cublasCgeam)                    \
  __macro(cublasZgeam)                    \
  __macro(cublasSdgmm)                    \
  __macro(cublasDdgmm)                    \
  __macro(cublasCdgmm)                    \
  __macro(cublasZdgmm)

// Handle/stream/pointer-mode management entry points.
PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasCreate)
PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasDestroy)
PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasSetStream)
PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasSetPointerMode)
PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasGetPointerMode)
// Batched GEMM entry points.
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSgemmBatched)
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasDgemmBatched)
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasCgemmBatched)
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasZgemmBatched)
// Wrap every routine in the list above.
CUBLAS_BLAS_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP)

// Mixed-precision GEMM requires CUDA 7.5+.
#if CUDA_VERSION >= 7050
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSgemmEx)
#endif

// Generic typed GEMM requires CUDA 8.0+.
#if CUDA_VERSION >= 8000
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGemmEx)
#endif

// Math-mode (tensor op) control requires CUDA 9.0+.
#if CUDA_VERSION >= 9000
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGetMathMode)
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSetMathMode)
#endif

}  // namespace wrap

// Converts a cublasStatus_t error code to a human-readable string.
static string ToString(cublasStatus_t status) {
  switch (status) {
    case CUBLAS_STATUS_SUCCESS:
      return "CUBLAS_STATUS_SUCCESS";
    case CUBLAS_STATUS_NOT_INITIALIZED:
      return "CUBLAS_STATUS_NOT_INITIALIZED";
    case CUBLAS_STATUS_ALLOC_FAILED:
      return "CUBLAS_STATUS_ALLOC_FAILED";
    case CUBLAS_STATUS_INVALID_VALUE:
      return "CUBLAS_STATUS_INVALID_VALUE";
    case CUBLAS_STATUS_ARCH_MISMATCH:
      return "CUBLAS_STATUS_ARCH_MISMATCH";
    case CUBLAS_STATUS_MAPPING_ERROR:
      return "CUBLAS_STATUS_MAPPING_ERROR";
    case CUBLAS_STATUS_EXECUTION_FAILED:
      return "CUBLAS_STATUS_EXECUTION_FAILED";
    case CUBLAS_STATUS_INTERNAL_ERROR:
      return "CUBLAS_STATUS_INTERNAL_ERROR";
#if CUDA_VERSION >= 8000
    case CUBLAS_STATUS_NOT_SUPPORTED:
      return "CUBLAS_STATUS_NOT_SUPPORTED";
    case CUBLAS_STATUS_LICENSE_ERROR:
      return "CUBLAS_STATUS_LICENSE_ERROR";
#endif
    default:
      return port::StrCat("<invalid cublas status: ", status, ">");
  }
}

// Decide whether to enable TENSOR_OP_MATH.  Tensor-op math is on by default
// and is disabled by setting the TF_DISABLE_CUBLAS_TENSOR_OP_MATH environment
// variable to true.  The environment variable is read once; the result is
// cached for the lifetime of the process.
static bool TensorOpMathEnabled() {
  static bool is_enabled = [] {
    bool is_disabled;
    TF_CHECK_OK(
        tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUBLAS_TENSOR_OP_MATH",
                                       /*default_val=*/false, &is_disabled));
    return !is_disabled;
  }();
  return is_enabled;
}

// cuBLAS has interfaces that permit pointers to be passed from either the host
// memory space or the device memory space; however, you must instruct it as to
// which address space those pointers are in with cublasSetPointerMode.
//
// This helper sets the cuBLAS pointer mode to a desired value for a cuBLAS call
// you are about to perform in a given scope.
//
// The prior cuBLAS pointer mode is retained and restored when this object goes
// out of scope.
class ScopedCublasPointerMode {
 public:
  // Note that, because the setting of the cublas pointer mode is fallible,
  // construction of this scoped datatype must be paired with a call to
  // Init().
  //
  // Parameters:
  //  handle: The cublas library handle to act upon in setting the pointer mode.
  explicit ScopedCublasPointerMode(CUDAExecutor *parent, cublasHandle_t handle)
      : parent_(parent), handle_(handle), ok_(false) {}

  // Attempts the switch to the requested scoped pointer mode, new_mode.
  //
  // Note that when false is returned, an appropriate error has already been
  // logged.
  bool Init(cublasPointerMode_t new_mode) {
    // Capture the current mode first so the destructor can restore it.
    cublasStatus_t ret =
        wrap::cublasGetPointerMode(parent_, handle_, &old_mode_);
    if (ret != CUBLAS_STATUS_SUCCESS) {
      LOG(ERROR) << "failed to get old cublas pointer mode: " << ToString(ret);
      return ok_ = false;
    }

    ret = wrap::cublasSetPointerMode(parent_, handle_, new_mode);
    if (ret != CUBLAS_STATUS_SUCCESS) {
      LOG(ERROR) << "failed to set new cublas pointer mode: " << ToString(ret);
      return ok_ = false;
    }

    return ok_ = true;
  }

  // Switches back to the prior pointer mode, if the switch operation was
  // successful in the first place.
  ~ScopedCublasPointerMode() {
    if (ok_) {
      cublasStatus_t ret =
          wrap::cublasSetPointerMode(parent_, handle_, old_mode_);
      if (ret != CUBLAS_STATUS_SUCCESS) {
        LOG(ERROR) << "failed to set former cublas pointer mode: "
                   << ToString(ret);
      }
    }
  }

 private:
  CUDAExecutor *parent_;   // Executor establishing this pointer mode for.
  cublasHandle_t handle_;  // Handle to the cuBLAS instance of interest.
  cublasPointerMode_t old_mode_;  // Prior cuBLAS pointer mode, to be restored.
  bool ok_;                       // Whether the change was successful.
};

#if CUDA_VERSION >= 9000
// cuBLAS has interfaces that permit computations to use the Volta hardware.
// This must be enabled via the cublasGet/SetMathMode APIs.
//
// This helper sets the cuBLAS math mode to a desired value for a cuBLAS call
// you are about to perform in a given scope.
//
// The prior cuBLAS math mode is retained and restored when this object goes
// out of scope.
class ScopedCublasMathMode {
 public:
  // Note that, because the setting of the cublas math mode is fallible,
  // construction of this scoped datatype must be paired with a call to
  // Init().
  //
  // Parameters:
  //  handle: The cublas library handle to act upon in setting the math mode.
  explicit ScopedCublasMathMode(CUDAExecutor *parent, cublasHandle_t handle)
      : parent_(parent), handle_(handle), ok_(false) {}

  // Attempts the switch to the requested scoped math mode, new_mode.
  //
  // Note that when false is returned, an appropriate error has already been
  // logged.
  bool Init(cublasMath_t new_mode) {
    // Capture the current mode first so the destructor can restore it.
    cublasStatus_t ret = wrap::cublasGetMathMode(parent_, handle_, &old_mode_);
    if (ret != CUBLAS_STATUS_SUCCESS) {
      LOG(ERROR) << "failed to get old cublas math mode: " << ToString(ret);
      return ok_ = false;
    }

    ret = wrap::cublasSetMathMode(parent_, handle_, new_mode);
    if (ret != CUBLAS_STATUS_SUCCESS) {
      LOG(ERROR) << "failed to set new cublas math mode: " << ToString(ret);
      return ok_ = false;
    }
    return ok_ = true;
  }

  // Switches back to the prior math mode, if the switch operation was
  // successful in the first place.
  ~ScopedCublasMathMode() {
    if (ok_) {
      cublasStatus_t ret = wrap::cublasSetMathMode(parent_, handle_, old_mode_);
      if (ret != CUBLAS_STATUS_SUCCESS) {
        LOG(ERROR) << "failed to set former cublas math mode: "
                   << ToString(ret);
      }
    }
  }

 private:
  CUDAExecutor *parent_;   // Executor establishing this math mode for.
  cublasHandle_t handle_;  // Handle to the cuBLAS instance of interest.
  cublasMath_t old_mode_;  // Prior cuBLAS math mode, to be restored.
  bool ok_;                // Whether the change was successful.
};
#endif  // CUDA_VERSION >= 9000

// Creates the cuBLAS handle; returns false (after logging) on failure.
bool CUDABlas::Init() {
  cublasStatus_t ret = wrap::cublasCreate(parent_, &blas_);
  if (ret != CUBLAS_STATUS_SUCCESS) {
    LOG(ERROR) << "failed to create cublas handle: " << ToString(ret);
    return false;
  }

  return true;
}

// The handle is created lazily by Init(), not here.
CUDABlas::CUDABlas(cuda::CUDAExecutor *parent)
    : parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {}

// Releases the cuBLAS handle, if one was successfully created.
CUDABlas::~CUDABlas() {
  if (blas_ != nullptr) {
    wrap::cublasDestroy(parent_, blas_);
  }
}

// Binds subsequent cuBLAS calls on this handle to the given stream.
bool CUDABlas::SetStream(Stream *stream) {
  CHECK(stream != nullptr);
  CHECK(AsCUDAStreamValue(stream) != nullptr);
  CHECK(blas_ != nullptr);
  cublasStatus_t ret =
      wrap::cublasSetStream(parent_, blas_, AsCUDAStreamValue(stream));
  if (ret != CUBLAS_STATUS_SUCCESS) {
    LOG(ERROR) << "failed to set stream for cuBLAS calls: " << ToString(ret);
    return false;
  }

  return true;
}

namespace {

// Helper functions transforming blas arguments into cuBLAS arguments.
476 477 cublasOperation_t CUDABlasTranspose(blas::Transpose trans) { 478 switch (trans) { 479 case blas::Transpose::kNoTranspose: 480 return CUBLAS_OP_N; 481 case blas::Transpose::kTranspose: 482 return CUBLAS_OP_T; 483 case blas::Transpose::kConjugateTranspose: 484 return CUBLAS_OP_C; 485 default: 486 LOG(FATAL) << "Invalid value of blas::Transpose."; 487 } 488 } 489 490 cublasFillMode_t CUDABlasUpperLower(blas::UpperLower uplo) { 491 switch (uplo) { 492 case blas::UpperLower::kUpper: 493 return CUBLAS_FILL_MODE_UPPER; 494 case blas::UpperLower::kLower: 495 return CUBLAS_FILL_MODE_LOWER; 496 default: 497 LOG(FATAL) << "Invalid value of blas::UpperLower."; 498 } 499 } 500 501 cublasDiagType_t CUDABlasDiagonal(blas::Diagonal diag) { 502 switch (diag) { 503 case blas::Diagonal::kUnit: 504 return CUBLAS_DIAG_UNIT; 505 case blas::Diagonal::kNonUnit: 506 return CUBLAS_DIAG_NON_UNIT; 507 default: 508 LOG(FATAL) << "Invalid value of blas::Diagonal."; 509 } 510 } 511 512 cublasSideMode_t CUDABlasSide(blas::Side side) { 513 switch (side) { 514 case blas::Side::kLeft: 515 return CUBLAS_SIDE_LEFT; 516 case blas::Side::kRight: 517 return CUBLAS_SIDE_RIGHT; 518 default: 519 LOG(FATAL) << "Invalid value of blas::Side."; 520 } 521 } 522 523 // CUDADataType<T>::type translates from a C++ type (e.g. float) to a 524 // cudaDataType_t (e.g. CUDA_R_32F). CUDAComputationType(ty) translates from a 525 // blas::ComputationType to a cudaDataType_t. 526 // 527 // These are used to build the argument type and computation type args to 528 // cublasGemmEx. cublasGemmEx and cudaDataType_t are available only on 529 // CUDA >= 8.0. 
#if CUDA_VERSION >= 8000
// Primary template; only the specializations below are defined.
template <typename T>
struct CUDADataType;

template <>
struct CUDADataType<Eigen::half> {
  // SE_CUDA_DATA_HALF is CUDA_R_16F on CUDA >= 8.0, CUBLAS_DATA_HALF before.
  static constexpr cudaDataType_t type = SE_CUDA_DATA_HALF;
};

template <>
struct CUDADataType<std::complex<Eigen::half>> {
  static constexpr cudaDataType_t type = CUDA_C_16F;
};

template <>
struct CUDADataType<float> {
  static constexpr cudaDataType_t type = CUDA_R_32F;
};

template <>
struct CUDADataType<std::complex<float>> {
  static constexpr cudaDataType_t type = CUDA_C_32F;
};

template <>
struct CUDADataType<double> {
  static constexpr cudaDataType_t type = CUDA_R_64F;
};

template <>
struct CUDADataType<std::complex<double>> {
  static constexpr cudaDataType_t type = CUDA_C_64F;
};

template <>
struct CUDADataType<int> {
  static constexpr cudaDataType_t type = CUDA_R_32I;
};

template <>
struct CUDADataType<int8> {
  static constexpr cudaDataType_t type = CUDA_R_8I;
};

template <>
struct CUDADataType<std::complex<int8>> {
  static constexpr cudaDataType_t type = CUDA_C_8I;
};

template <>
struct CUDADataType<uint8> {
  static constexpr cudaDataType_t type = CUDA_R_8U;
};

template <>
struct CUDADataType<std::complex<uint8>> {
  static constexpr cudaDataType_t type = CUDA_C_8U;
};

// Translates a blas::ComputationType into the cudaDataType_t that cublasGemmEx
// expects for its computeType argument.
cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
  switch (ty) {
    case blas::ComputationType::kF16:
      return CUDA_R_16F;
    case blas::ComputationType::kF32:
      return CUDA_R_32F;
    case blas::ComputationType::kF64:
      return CUDA_R_64F;
    case blas::ComputationType::kI32:
      return CUDA_R_32I;
    case blas::ComputationType::kComplexF32:
      return CUDA_C_32F;
    case blas::ComputationType::kComplexF64:
      return CUDA_C_64F;
  }
}
#endif

}  // namespace

// Shared driver for all cuBLAS calls made by this class.  Under the handle
// mutex it: binds the handle to `stream`, scopes the requested pointer mode
// (host vs. device scalars), optionally scopes CUBLAS_TENSOR_OP_MATH on
// CUDA 9.0+, then invokes `cublas_func` with `args`.  Returns true iff the
// call reported CUBLAS_STATUS_SUCCESS; failures are logged only when
// `err_on_failure` is set.
template <typename FuncT, typename... Args>
bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
                                  bool pointer_mode_host, bool err_on_failure,
                                  bool use_tensor_op_math, Args... args) {
  // Serializes access to blas_; the cuBLAS handle state (stream, pointer mode,
  // math mode) is mutated below.
  mutex_lock lock{mu_};

  CHECK(blas_ != nullptr);
  if (!SetStream(stream)) {
    return false;
  }

  // Pointer mode is restored when pointer_mode goes out of scope.
  ScopedCublasPointerMode pointer_mode{parent_, blas_};
  if (!pointer_mode.Init(pointer_mode_host ? CUBLAS_POINTER_MODE_HOST
                                           : CUBLAS_POINTER_MODE_DEVICE)) {
    return false;
  }
#if CUDA_VERSION >= 9000
  // Math mode is likewise restored on scope exit.
  ScopedCublasMathMode math_mode{parent_, blas_};
  if (use_tensor_op_math) {
    if (!math_mode.Init(CUBLAS_TENSOR_OP_MATH)) {
      return false;
    }
  }
#endif
  cublasStatus_t ret = cublas_func(parent_, blas_, args...);
  if (err_on_failure && ret != CUBLAS_STATUS_SUCCESS) {
    LOG(ERROR) << "failed to run cuBLAS routine " << cublas_func.kName << ": "
               << ToString(ret);
  }
  return ret == CUBLAS_STATUS_SUCCESS;
}

// The DoBlas* methods below adapt StreamExecutor arguments to the matching
// cuBLAS routine via DoBlasInternal (declared in cuda_blas.h;
// NOTE(review): presumably forwards to DoBlasInternalImpl above — confirm
// against the header).  Pointer-mode convention visible in these calls:
// routines whose scalar result/scalar inputs live in device memory pass
// /*pointer_mode_host=*/false; routines taking host-side alpha/c/s scalars
// pass true.

bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *result) {
  return DoBlasInternal(wrap::cublasSasum, stream,
                        false /* = pointer_mode_host */, elem_count,
                        CUDAMemory(x), incx, CUDAMemoryMutable(result));
}

bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *result) {
  return DoBlasInternal(wrap::cublasDasum, stream,
                        false /* = pointer_mode_host */, elem_count,
                        CUDAMemory(x), incx, CUDAMemoryMutable(result));
}

bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<float> *result) {
  return DoBlasInternal(
      wrap::cublasScasum, stream, false /* = pointer_mode_host */, elem_count,
      CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
}

bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<double> *result) {
  return DoBlasInternal(
      wrap::cublasDzasum, stream, false /* = pointer_mode_host */, elem_count,
      CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
}

bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
                          const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *y, int incy) {
  return DoBlasInternal(wrap::cublasSaxpy, stream,
                        true /* = pointer_mode_host */, elem_count, &alpha,
                        CUDAMemory(x), incx, CUDAMemoryMutable(y), incy);
}

bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
                          const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *y, int incy) {
  return DoBlasInternal(wrap::cublasDaxpy, stream,
                        true /* = pointer_mode_host */, elem_count, &alpha,
                        CUDAMemory(x), incx, CUDAMemoryMutable(y), incy);
}

bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<std::complex<float>> *y, int incy) {
  return DoBlasInternal(wrap::cublasCaxpy, stream,
                        true /* = pointer_mode_host */, elem_count,
                        CUDAComplex(&alpha), CUDAComplex(CUDAMemory(x)), incx,
                        CUDAComplex(CUDAMemoryMutable(y)), incy);
}

bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<std::complex<double>> *y, int incy) {
  return DoBlasInternal(wrap::cublasZaxpy, stream,
                        true /* = pointer_mode_host */, elem_count,
                        CUDAComplex(&alpha), CUDAComplex(CUDAMemory(x)), incx,
                        CUDAComplex(CUDAMemoryMutable(y)), incy);
}

bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *y, int incy) {
  return DoBlasInternal(wrap::cublasScopy, stream,
                        true /* = pointer_mode_host */, elem_count,
                        CUDAMemory(x), incx, CUDAMemoryMutable(y), incy);
}

bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *y, int incy) {
  return DoBlasInternal(wrap::cublasDcopy, stream,
                        true /* = pointer_mode_host */, elem_count,
                        CUDAMemory(x), incx, CUDAMemoryMutable(y), incy);
}

bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<std::complex<float>> *y, int incy) {
  return DoBlasInternal(wrap::cublasCcopy, stream,
                        true /* = pointer_mode_host */, elem_count,
                        CUDAComplex(CUDAMemory(x)), incx,
                        CUDAComplex(CUDAMemoryMutable(y)), incy);
}

bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<std::complex<double>> *y, int incy) {
  return DoBlasInternal(wrap::cublasZcopy, stream,
                        true /* = pointer_mode_host */, elem_count,
                        CUDAComplex(CUDAMemory(x)), incx,
                        CUDAComplex(CUDAMemoryMutable(y)), incy);
}

bool CUDABlas::DoBlasDot(Stream *stream, uint64 elem_count,
                         const DeviceMemory<float> &x, int incx,
                         const DeviceMemory<float> &y, int incy,
                         DeviceMemory<float> *result) {
  return DoBlasInternal(
      wrap::cublasSdot, stream, false /* = pointer_mode_host */, elem_count,
      CUDAMemory(x), incx, CUDAMemory(y), incy, CUDAMemoryMutable(result));
}

bool CUDABlas::DoBlasDot(Stream *stream, uint64 elem_count,
                         const DeviceMemory<double> &x, int incx,
                         const DeviceMemory<double> &y, int incy,
                         DeviceMemory<double> *result) {
  return DoBlasInternal(
      wrap::cublasDdot, stream, false /* = pointer_mode_host */, elem_count,
      CUDAMemory(x), incx, CUDAMemory(y), incy, CUDAMemoryMutable(result));
}

bool CUDABlas::DoBlasDotc(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *result) {
  return DoBlasInternal(
      wrap::cublasCdotc, stream, false /* = pointer_mode_host */, elem_count,
      CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
      CUDAComplex(CUDAMemoryMutable(result)));
}

bool CUDABlas::DoBlasDotc(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *result) {
  return DoBlasInternal(
      wrap::cublasZdotc, stream, false /* = pointer_mode_host */, elem_count,
      CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
      CUDAComplex(CUDAMemoryMutable(result)));
}

bool CUDABlas::DoBlasDotu(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          const DeviceMemory<std::complex<float>> &y, int incy,
                          DeviceMemory<std::complex<float>> *result) {
  return DoBlasInternal(
      wrap::cublasCdotu, stream, false /* = pointer_mode_host */, elem_count,
      CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
      CUDAComplex(CUDAMemoryMutable(result)));
}

bool CUDABlas::DoBlasDotu(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          const DeviceMemory<std::complex<double>> &y, int incy,
                          DeviceMemory<std::complex<double>> *result) {
  return DoBlasInternal(
      wrap::cublasZdotu, stream, false /* = pointer_mode_host */, elem_count,
      CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
      CUDAComplex(CUDAMemoryMutable(result)));
}

bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
                          const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *result) {
  return DoBlasInternal(wrap::cublasSnrm2, stream,
                        false /* = pointer_mode_host */, elem_count,
                        CUDAMemory(x), incx, CUDAMemoryMutable(result));
}

bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
                          const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *result) {
  return DoBlasInternal(wrap::cublasDnrm2, stream,
                        false /* = pointer_mode_host */, elem_count,
                        CUDAMemory(x), incx, CUDAMemoryMutable(result));
}

bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<float> *result) {
  return DoBlasInternal(
      wrap::cublasScnrm2, stream, false /* = pointer_mode_host */, elem_count,
      CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
}

bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<double> *result) {
  return DoBlasInternal(
      wrap::cublasDznrm2, stream, false /* = pointer_mode_host */, elem_count,
      CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
}

bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
                         DeviceMemory<float> *x, int incx,
                         DeviceMemory<float> *y, int incy, float c, float s) {
  return DoBlasInternal(
      wrap::cublasSrot, stream, true /* = pointer_mode_host */, elem_count,
      CUDAMemoryMutable(x), incx, CUDAMemoryMutable(y), incy, &c, &s);
}

bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
                         DeviceMemory<double> *x, int incx,
                         DeviceMemory<double> *y, int incy, double c,
                         double s) {
  return DoBlasInternal(
      wrap::cublasDrot, stream, true /* = pointer_mode_host */, elem_count,
      CUDAMemoryMutable(x), incx, CUDAMemoryMutable(y), incy, &c, &s);
}

bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
                         DeviceMemory<std::complex<float>> *x, int incx,
                         DeviceMemory<std::complex<float>> *y, int incy,
                         float c, float s) {
  return DoBlasInternal(wrap::cublasCsrot, stream,
                        true /* = pointer_mode_host */, elem_count,
                        CUDAComplex(CUDAMemoryMutable(x)), incx,
                        CUDAComplex(CUDAMemoryMutable(y)), incy, &c, &s);
}

bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
                         DeviceMemory<std::complex<double>> *x, int incx,
                         DeviceMemory<std::complex<double>> *y, int incy,
                         double c, double s) {
  return DoBlasInternal(wrap::cublasZdrot, stream,
                        true /* = pointer_mode_host */, elem_count,
                        CUDAComplex(CUDAMemoryMutable(x)), incx,
                        CUDAComplex(CUDAMemoryMutable(y)), incy, &c, &s);
}

bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
                          DeviceMemory<float> *b, DeviceMemory<float> *c,
                          DeviceMemory<float> *s) {
  return DoBlasInternal(wrap::cublasSrotg, stream,
                        false /* = pointer_mode_host */, CUDAMemoryMutable(a),
                        CUDAMemoryMutable(b), CUDAMemoryMutable(c),
                        CUDAMemoryMutable(s));
}

bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
                          DeviceMemory<double> *b, DeviceMemory<double> *c,
                          DeviceMemory<double> *s) {
  return DoBlasInternal(wrap::cublasDrotg, stream,
                        false /* = pointer_mode_host */,
                        CUDAComplex(CUDAMemoryMutable(a)), CUDAMemoryMutable(b),
                        CUDAMemoryMutable(c), CUDAMemoryMutable(s));
}

bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
                          DeviceMemory<std::complex<float>> *b,
                          DeviceMemory<float> *c,
                          DeviceMemory<std::complex<float>> *s) {
  return DoBlasInternal(
      wrap::cublasCrotg, stream, false /* = pointer_mode_host */,
      CUDAComplex(CUDAMemoryMutable(a)), CUDAComplex(CUDAMemoryMutable(b)),
      CUDAComplex(CUDAMemoryMutable(c)), CUDAComplex(CUDAMemoryMutable(s)));
}

bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
                          DeviceMemory<std::complex<double>> *b,
                          DeviceMemory<double> *c,
                          DeviceMemory<std::complex<double>> *s) {
  return DoBlasInternal(
      wrap::cublasZrotg, stream, false /* = pointer_mode_host */,
      CUDAComplex(CUDAMemoryMutable(a)), CUDAComplex(CUDAMemoryMutable(b)),
      CUDAComplex(CUDAMemoryMutable(c)), CUDAComplex(CUDAMemoryMutable(s)));
}

bool CUDABlas::DoBlasRotm(Stream *stream, uint64 elem_count,
                          DeviceMemory<float> *x, int incx,
                          DeviceMemory<float> *y, int incy,
                          const DeviceMemory<float> &param) {
  return DoBlasInternal(wrap::cublasSrotm, stream,
                        false /* = pointer_mode_host */, elem_count,
                        CUDAMemoryMutable(x), incx, CUDAMemoryMutable(y), incy,
                        CUDAMemory(param));
}

bool CUDABlas::DoBlasRotm(Stream *stream, uint64 elem_count,
                          DeviceMemory<double> *x, int incx,
                          DeviceMemory<double> *y, int incy,
                          const DeviceMemory<double> &param) {
  return DoBlasInternal(wrap::cublasDrotm, stream,
                        false /* = pointer_mode_host */, elem_count,
                        CUDAMemoryMutable(x), incx, CUDAMemoryMutable(y), incy,
                        CUDAMemory(param));
}

bool CUDABlas::DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
                           DeviceMemory<float> *d2, DeviceMemory<float> *x1,
                           const DeviceMemory<float> &y1,
                           DeviceMemory<float> *param) {
  return DoBlasInternal(wrap::cublasSrotmg, stream,
                        false /* = pointer_mode_host */, CUDAMemoryMutable(d1),
                        CUDAMemoryMutable(d2), CUDAMemoryMutable(x1),
                        CUDAMemory(y1), CUDAMemoryMutable(param));
}

bool CUDABlas::DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
                           DeviceMemory<double> *d2, DeviceMemory<double> *x1,
                           const DeviceMemory<double> &y1,
                           DeviceMemory<double> *param) {
  return DoBlasInternal(wrap::cublasDrotmg, stream,
                        false /* = pointer_mode_host */, CUDAMemoryMutable(d1),
                        CUDAMemoryMutable(d2), CUDAMemoryMutable(x1),
                        CUDAMemory(y1), CUDAMemoryMutable(param));
}

bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
                          DeviceMemory<float> *x, int incx) {
  return DoBlasInternal(wrap::cublasSscal, stream,
                        true /* = pointer_mode_host */, elem_count, &alpha,
                        CUDAMemoryMutable(x), incx);
}

bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
                          DeviceMemory<double> *x, int incx) {
  return DoBlasInternal(wrap::cublasDscal, stream,
                        true /* = pointer_mode_host */, elem_count, &alpha,
                        CUDAMemoryMutable(x), incx);
}

bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
                          DeviceMemory<std::complex<float>> *x, int incx) {
  return DoBlasInternal(
      wrap::cublasCsscal, stream, true /* = pointer_mode_host */, elem_count,
      CUDAComplex(&alpha), CUDAComplex(CUDAMemoryMutable(x)), incx);
}

bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
                          DeviceMemory<std::complex<double>> *x, int incx) {
  return DoBlasInternal(
      wrap::cublasZdscal, stream, true /* = pointer_mode_host */, elem_count,
      CUDAComplex(&alpha), CUDAComplex(CUDAMemoryMutable(x)), incx);
}

bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count,
                          std::complex<float> alpha,
                          DeviceMemory<std::complex<float>> *x, int incx) {
  return DoBlasInternal(
      wrap::cublasCscal, stream, true /* = pointer_mode_host */, elem_count,
      CUDAComplex(&alpha), CUDAComplex(CUDAMemoryMutable(x)), incx);
}

bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count,
                          std::complex<double> alpha,
                          DeviceMemory<std::complex<double>> *x, int incx) {
  return DoBlasInternal(
      wrap::cublasZscal, stream, true /* = pointer_mode_host */, elem_count,
      CUDAComplex(&alpha), CUDAComplex(CUDAMemoryMutable(x)), incx);
}

bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
                          DeviceMemory<float> *x, int incx,
                          DeviceMemory<float> *y, int incy) {
  return DoBlasInternal(wrap::cublasSswap, stream,
                        true /* = pointer_mode_host */, elem_count,
                        CUDAMemoryMutable(x), incx, CUDAMemoryMutable(y), incy);
}

bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
                          DeviceMemory<double> *x, int incx,
                          DeviceMemory<double> *y, int incy) {
  return DoBlasInternal(wrap::cublasDswap, stream,
                        true /* = pointer_mode_host */, elem_count,
                        CUDAMemoryMutable(x), incx, CUDAMemoryMutable(y), incy);
}

bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
                          DeviceMemory<std::complex<float>> *x, int incx,
                          DeviceMemory<std::complex<float>> *y, int incy) {
  return DoBlasInternal(wrap::cublasCswap, stream,
                        true /* = pointer_mode_host */, elem_count,
                        CUDAComplex(CUDAMemoryMutable(x)), incx,
                        CUDAComplex(CUDAMemoryMutable(y)), incy);
}

bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
                          DeviceMemory<std::complex<double>> *x, int incx,
                          DeviceMemory<std::complex<double>> *y, int incy) {
  return DoBlasInternal(wrap::cublasZswap, stream,
                        true /* = pointer_mode_host */, elem_count,
                        CUDAComplex(CUDAMemoryMutable(x)), incx,
                        CUDAComplex(CUDAMemoryMutable(y)), incy);
}

bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
                           const DeviceMemory<float> &x, int incx,
                           DeviceMemory<int> *result) {
  return DoBlasInternal(wrap::cublasIsamax, stream,
                        false /* = pointer_mode_host */, elem_count,
                        CUDAMemory(x), incx, CUDAMemoryMutable(result));
}

bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
                           const DeviceMemory<double> &x, int incx,
                           DeviceMemory<int> *result) {
  return DoBlasInternal(wrap::cublasIdamax, stream,
                        false /* = pointer_mode_host */, elem_count,
                        CUDAMemory(x), incx, CUDAMemoryMutable(result));
}

bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
                           const DeviceMemory<std::complex<float>> &x, int incx,
                           DeviceMemory<int> *result) {
  return DoBlasInternal(
      wrap::cublasIcamax, stream, false /* = pointer_mode_host */, elem_count,
      CUDAComplex(CUDAMemory(x)), incx,
CUDAMemoryMutable(result)); 1048 } 1049 1050 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count, 1051 const DeviceMemory<std::complex<double>> &x, 1052 int incx, DeviceMemory<int> *result) { 1053 return DoBlasInternal( 1054 wrap::cublasIzamax, stream, false /* = pointer_mode_host */, elem_count, 1055 CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result)); 1056 } 1057 1058 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count, 1059 const DeviceMemory<float> &x, int incx, 1060 DeviceMemory<int> *result) { 1061 return DoBlasInternal( 1062 wrap::cublasIsamin, stream, false /* = pointer_mode_host */, elem_count, 1063 CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result)); 1064 } 1065 1066 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count, 1067 const DeviceMemory<double> &x, int incx, 1068 DeviceMemory<int> *result) { 1069 return DoBlasInternal( 1070 wrap::cublasIdamin, stream, false /* = pointer_mode_host */, elem_count, 1071 CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result)); 1072 } 1073 1074 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count, 1075 const DeviceMemory<std::complex<float>> &x, int incx, 1076 DeviceMemory<int> *result) { 1077 return DoBlasInternal( 1078 wrap::cublasIcamin, stream, false /* = pointer_mode_host */, elem_count, 1079 CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result)); 1080 } 1081 1082 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count, 1083 const DeviceMemory<std::complex<double>> &x, 1084 int incx, DeviceMemory<int> *result) { 1085 return DoBlasInternal( 1086 wrap::cublasIzamin, stream, false /* = pointer_mode_host */, elem_count, 1087 CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result)); 1088 } 1089 1090 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, 1091 uint64 n, uint64 kl, uint64 ku, float alpha, 1092 const DeviceMemory<float> &a, int lda, 1093 const DeviceMemory<float> &x, int incx, float beta, 1094 DeviceMemory<float> *y, 
int incy) { 1095 return DoBlasInternal( 1096 wrap::cublasSgbmv, stream, true /* = pointer_mode_host */, 1097 CUDABlasTranspose(trans), m, n, kl, ku, &alpha, CUDAMemory(a), lda, 1098 CUDAMemory(x), incx, &beta, CUDAMemoryMutable(y), incy); 1099 } 1100 1101 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, 1102 uint64 n, uint64 kl, uint64 ku, double alpha, 1103 const DeviceMemory<double> &a, int lda, 1104 const DeviceMemory<double> &x, int incx, double beta, 1105 DeviceMemory<double> *y, int incy) { 1106 return DoBlasInternal( 1107 wrap::cublasDgbmv, stream, true /* = pointer_mode_host */, 1108 CUDABlasTranspose(trans), m, n, kl, ku, &alpha, CUDAMemory(a), lda, 1109 CUDAMemory(x), incx, &beta, CUDAMemoryMutable(y), incy); 1110 } 1111 1112 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, 1113 uint64 n, uint64 kl, uint64 ku, 1114 std::complex<float> alpha, 1115 const DeviceMemory<std::complex<float>> &a, int lda, 1116 const DeviceMemory<std::complex<float>> &x, int incx, 1117 std::complex<float> beta, 1118 DeviceMemory<std::complex<float>> *y, int incy) { 1119 return DoBlasInternal( 1120 wrap::cublasCgbmv, stream, true /* = pointer_mode_host */, 1121 CUDABlasTranspose(trans), m, n, kl, ku, CUDAComplex(&alpha), 1122 CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx, 1123 CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy); 1124 } 1125 1126 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, 1127 uint64 n, uint64 kl, uint64 ku, 1128 std::complex<double> alpha, 1129 const DeviceMemory<std::complex<double>> &a, int lda, 1130 const DeviceMemory<std::complex<double>> &x, int incx, 1131 std::complex<double> beta, 1132 DeviceMemory<std::complex<double>> *y, int incy) { 1133 return DoBlasInternal( 1134 wrap::cublasZgbmv, stream, true /* = pointer_mode_host */, 1135 CUDABlasTranspose(trans), m, n, kl, ku, CUDAComplex(&alpha), 1136 CUDAComplex(CUDAMemory(a)), lda, 
CUDAComplex(CUDAMemory(x)), incx, 1137 CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy); 1138 } 1139 1140 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, 1141 uint64 n, float alpha, const DeviceMemory<float> &a, 1142 int lda, const DeviceMemory<float> &x, int incx, 1143 float beta, DeviceMemory<float> *y, int incy) { 1144 return DoBlasInternal( 1145 wrap::cublasSgemv, stream, true /* = pointer_mode_host */, 1146 CUDABlasTranspose(trans), m, n, &alpha, CUDAMemory(a), lda, CUDAMemory(x), 1147 incx, &beta, CUDAMemoryMutable(y), incy); 1148 } 1149 1150 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, 1151 uint64 n, double alpha, const DeviceMemory<double> &a, 1152 int lda, const DeviceMemory<double> &x, int incx, 1153 double beta, DeviceMemory<double> *y, int incy) { 1154 return DoBlasInternal( 1155 wrap::cublasDgemv, stream, true /* = pointer_mode_host */, 1156 CUDABlasTranspose(trans), m, n, &alpha, CUDAMemory(a), lda, CUDAMemory(x), 1157 incx, &beta, CUDAMemoryMutable(y), incy); 1158 } 1159 1160 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, 1161 uint64 n, std::complex<float> alpha, 1162 const DeviceMemory<std::complex<float>> &a, int lda, 1163 const DeviceMemory<std::complex<float>> &x, int incx, 1164 std::complex<float> beta, 1165 DeviceMemory<std::complex<float>> *y, int incy) { 1166 return DoBlasInternal( 1167 wrap::cublasCgemv, stream, true /* = pointer_mode_host */, 1168 CUDABlasTranspose(trans), m, n, CUDAComplex(&alpha), 1169 CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx, 1170 CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy); 1171 } 1172 1173 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, 1174 uint64 n, std::complex<double> alpha, 1175 const DeviceMemory<std::complex<double>> &a, int lda, 1176 const DeviceMemory<std::complex<double>> &x, int incx, 1177 std::complex<double> beta, 1178 
DeviceMemory<std::complex<double>> *y, int incy) { 1179 return DoBlasInternal( 1180 wrap::cublasZgemv, stream, true /* = pointer_mode_host */, 1181 CUDABlasTranspose(trans), m, n, CUDAComplex(&alpha), 1182 CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx, 1183 CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy); 1184 } 1185 1186 bool CUDABlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, 1187 const DeviceMemory<float> &x, int incx, 1188 const DeviceMemory<float> &y, int incy, 1189 DeviceMemory<float> *a, int lda) { 1190 return DoBlasInternal( 1191 wrap::cublasSger, stream, true /* = pointer_mode_host */, m, n, &alpha, 1192 CUDAMemory(x), incx, CUDAMemory(y), incy, CUDAMemoryMutable(a), lda); 1193 } 1194 1195 bool CUDABlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha, 1196 const DeviceMemory<double> &x, int incx, 1197 const DeviceMemory<double> &y, int incy, 1198 DeviceMemory<double> *a, int lda) { 1199 return DoBlasInternal( 1200 wrap::cublasDger, stream, true /* = pointer_mode_host */, m, n, &alpha, 1201 CUDAMemory(x), incx, CUDAMemory(y), incy, CUDAMemoryMutable(a), lda); 1202 } 1203 1204 bool CUDABlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n, 1205 std::complex<float> alpha, 1206 const DeviceMemory<std::complex<float>> &x, int incx, 1207 const DeviceMemory<std::complex<float>> &y, int incy, 1208 DeviceMemory<std::complex<float>> *a, int lda) { 1209 return DoBlasInternal( 1210 wrap::cublasCgerc, stream, true /* = pointer_mode_host */, m, n, 1211 CUDAComplex(&alpha), CUDAComplex(CUDAMemory(x)), incx, 1212 CUDAComplex(CUDAMemory(y)), incy, CUDAComplex(CUDAMemoryMutable(a)), lda); 1213 } 1214 1215 bool CUDABlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n, 1216 std::complex<double> alpha, 1217 const DeviceMemory<std::complex<double>> &x, int incx, 1218 const DeviceMemory<std::complex<double>> &y, int incy, 1219 DeviceMemory<std::complex<double>> *a, int lda) { 1220 return DoBlasInternal( 1221 
wrap::cublasZgerc, stream, true /* = pointer_mode_host */, m, n, 1222 CUDAComplex(&alpha), CUDAComplex(CUDAMemory(x)), incx, 1223 CUDAComplex(CUDAMemory(y)), incy, CUDAComplex(CUDAMemoryMutable(a)), lda); 1224 } 1225 1226 bool CUDABlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n, 1227 std::complex<float> alpha, 1228 const DeviceMemory<std::complex<float>> &x, int incx, 1229 const DeviceMemory<std::complex<float>> &y, int incy, 1230 DeviceMemory<std::complex<float>> *a, int lda) { 1231 return DoBlasInternal( 1232 wrap::cublasCgeru, stream, true /* = pointer_mode_host */, m, n, 1233 CUDAComplex(&alpha), CUDAComplex(CUDAMemory(x)), incx, 1234 CUDAComplex(CUDAMemory(y)), incy, CUDAComplex(CUDAMemoryMutable(a)), lda); 1235 } 1236 1237 bool CUDABlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n, 1238 std::complex<double> alpha, 1239 const DeviceMemory<std::complex<double>> &x, int incx, 1240 const DeviceMemory<std::complex<double>> &y, int incy, 1241 DeviceMemory<std::complex<double>> *a, int lda) { 1242 return DoBlasInternal( 1243 wrap::cublasZgeru, stream, true /* = pointer_mode_host */, m, n, 1244 CUDAComplex(&alpha), CUDAComplex(CUDAMemory(x)), incx, 1245 CUDAComplex(CUDAMemory(y)), incy, CUDAComplex(CUDAMemoryMutable(a)), lda); 1246 } 1247 1248 bool CUDABlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, 1249 uint64 k, std::complex<float> alpha, 1250 const DeviceMemory<std::complex<float>> &a, int lda, 1251 const DeviceMemory<std::complex<float>> &x, int incx, 1252 std::complex<float> beta, 1253 DeviceMemory<std::complex<float>> *y, int incy) { 1254 return DoBlasInternal( 1255 wrap::cublasChbmv, stream, true /* = pointer_mode_host */, 1256 CUDABlasUpperLower(uplo), n, k, CUDAComplex(&alpha), 1257 CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx, 1258 CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy); 1259 } 1260 1261 bool CUDABlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, 1262 uint64 k, 
std::complex<double> alpha, 1263 const DeviceMemory<std::complex<double>> &a, int lda, 1264 const DeviceMemory<std::complex<double>> &x, int incx, 1265 std::complex<double> beta, 1266 DeviceMemory<std::complex<double>> *y, int incy) { 1267 return DoBlasInternal( 1268 wrap::cublasZhbmv, stream, true /* = pointer_mode_host */, 1269 CUDABlasUpperLower(uplo), n, k, CUDAComplex(&alpha), 1270 CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx, 1271 CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy); 1272 } 1273 1274 bool CUDABlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, 1275 std::complex<float> alpha, 1276 const DeviceMemory<std::complex<float>> &a, int lda, 1277 const DeviceMemory<std::complex<float>> &x, int incx, 1278 std::complex<float> beta, 1279 DeviceMemory<std::complex<float>> *y, int incy) { 1280 return DoBlasInternal( 1281 wrap::cublasChemv, stream, true /* = pointer_mode_host */, 1282 CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha), 1283 CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx, 1284 CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy); 1285 } 1286 1287 bool CUDABlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, 1288 std::complex<double> alpha, 1289 const DeviceMemory<std::complex<double>> &a, int lda, 1290 const DeviceMemory<std::complex<double>> &x, int incx, 1291 std::complex<double> beta, 1292 DeviceMemory<std::complex<double>> *y, int incy) { 1293 return DoBlasInternal( 1294 wrap::cublasZhemv, stream, true /* = pointer_mode_host */, 1295 CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha), 1296 CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx, 1297 CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy); 1298 } 1299 1300 bool CUDABlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, 1301 float alpha, 1302 const DeviceMemory<std::complex<float>> &x, int incx, 1303 DeviceMemory<std::complex<float>> *a, int lda) { 1304 return 
DoBlasInternal( 1305 wrap::cublasCher, stream, true /* = pointer_mode_host */, 1306 CUDABlasUpperLower(uplo), n, &alpha, CUDAComplex(CUDAMemory(x)), incx, 1307 CUDAComplex(CUDAMemoryMutable(a)), lda); 1308 } 1309 1310 bool CUDABlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, 1311 double alpha, 1312 const DeviceMemory<std::complex<double>> &x, int incx, 1313 DeviceMemory<std::complex<double>> *a, int lda) { 1314 return DoBlasInternal( 1315 wrap::cublasZher, stream, true /* = pointer_mode_host */, 1316 CUDABlasUpperLower(uplo), n, &alpha, CUDAComplex(CUDAMemory(x)), incx, 1317 CUDAComplex(CUDAMemoryMutable(a)), lda); 1318 } 1319 1320 bool CUDABlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, 1321 std::complex<float> alpha, 1322 const DeviceMemory<std::complex<float>> &x, int incx, 1323 const DeviceMemory<std::complex<float>> &y, int incy, 1324 DeviceMemory<std::complex<float>> *a, int lda) { 1325 return DoBlasInternal( 1326 wrap::cublasCher2, stream, true /* = pointer_mode_host */, 1327 CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha), 1328 CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy, 1329 CUDAComplex(CUDAMemoryMutable(a)), lda); 1330 } 1331 1332 bool CUDABlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, 1333 std::complex<double> alpha, 1334 const DeviceMemory<std::complex<double>> &x, int incx, 1335 const DeviceMemory<std::complex<double>> &y, int incy, 1336 DeviceMemory<std::complex<double>> *a, int lda) { 1337 return DoBlasInternal( 1338 wrap::cublasZher2, stream, true /* = pointer_mode_host */, 1339 CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha), 1340 CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy, 1341 CUDAComplex(CUDAMemoryMutable(a)), lda); 1342 } 1343 1344 bool CUDABlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, 1345 std::complex<float> alpha, 1346 const DeviceMemory<std::complex<float>> &ap, 1347 const DeviceMemory<std::complex<float>> &x, int incx, 
1348 std::complex<float> beta, 1349 DeviceMemory<std::complex<float>> *y, int incy) { 1350 return DoBlasInternal( 1351 wrap::cublasChpmv, stream, true /* = pointer_mode_host */, 1352 CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha), 1353 CUDAComplex(CUDAMemory(ap)), CUDAComplex(CUDAMemory(x)), incx, 1354 CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy); 1355 } 1356 1357 bool CUDABlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, 1358 std::complex<double> alpha, 1359 const DeviceMemory<std::complex<double>> &ap, 1360 const DeviceMemory<std::complex<double>> &x, int incx, 1361 std::complex<double> beta, 1362 DeviceMemory<std::complex<double>> *y, int incy) { 1363 return DoBlasInternal( 1364 wrap::cublasZhpmv, stream, true /* = pointer_mode_host */, 1365 CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha), 1366 CUDAComplex(CUDAMemory(ap)), CUDAComplex(CUDAMemory(x)), incx, 1367 CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy); 1368 } 1369 1370 bool CUDABlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, 1371 float alpha, 1372 const DeviceMemory<std::complex<float>> &x, int incx, 1373 DeviceMemory<std::complex<float>> *ap) { 1374 return DoBlasInternal( 1375 wrap::cublasChpr, stream, true /* = pointer_mode_host */, 1376 CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha), 1377 CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemoryMutable(ap))); 1378 } 1379 1380 bool CUDABlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, 1381 double alpha, 1382 const DeviceMemory<std::complex<double>> &x, int incx, 1383 DeviceMemory<std::complex<double>> *ap) { 1384 return DoBlasInternal( 1385 wrap::cublasZhpr, stream, true /* = pointer_mode_host */, 1386 CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha), 1387 CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemoryMutable(ap))); 1388 } 1389 1390 bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, 1391 std::complex<float> alpha, 1392 const 
DeviceMemory<std::complex<float>> &x, int incx, 1393 const DeviceMemory<std::complex<float>> &y, int incy, 1394 DeviceMemory<std::complex<float>> *ap) { 1395 return DoBlasInternal( 1396 wrap::cublasChpr2, stream, true /* = pointer_mode_host */, 1397 CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha), 1398 CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy, 1399 CUDAComplex(CUDAMemoryMutable(ap))); 1400 } 1401 1402 bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, 1403 std::complex<double> alpha, 1404 const DeviceMemory<std::complex<double>> &x, int incx, 1405 const DeviceMemory<std::complex<double>> &y, int incy, 1406 DeviceMemory<std::complex<double>> *ap) { 1407 return DoBlasInternal( 1408 wrap::cublasZhpr2, stream, true /* = pointer_mode_host */, 1409 CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha), 1410 CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy, 1411 CUDAComplex(CUDAMemoryMutable(ap))); 1412 } 1413 1414 bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, 1415 uint64 k, float alpha, const DeviceMemory<float> &a, 1416 int lda, const DeviceMemory<float> &x, int incx, 1417 float beta, DeviceMemory<float> *y, int incy) { 1418 return DoBlasInternal( 1419 wrap::cublasSsbmv, stream, true /* = pointer_mode_host */, 1420 CUDABlasUpperLower(uplo), n, k, &alpha, CUDAMemory(a), lda, CUDAMemory(x), 1421 incx, &beta, CUDAMemoryMutable(y), incy); 1422 } 1423 1424 bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, 1425 uint64 k, double alpha, const DeviceMemory<double> &a, 1426 int lda, const DeviceMemory<double> &x, int incx, 1427 double beta, DeviceMemory<double> *y, int incy) { 1428 return DoBlasInternal( 1429 wrap::cublasDsbmv, stream, true /* = pointer_mode_host */, 1430 CUDABlasUpperLower(uplo), n, k, &alpha, CUDAMemory(a), lda, CUDAMemory(x), 1431 incx, &beta, CUDAMemoryMutable(y), incy); 1432 } 1433 1434 bool CUDABlas::DoBlasSpmv(Stream *stream, 
blas::UpperLower uplo, uint64 n, 1435 float alpha, const DeviceMemory<float> &ap, 1436 const DeviceMemory<float> &x, int incx, float beta, 1437 DeviceMemory<float> *y, int incy) { 1438 return DoBlasInternal(wrap::cublasSspmv, stream, 1439 true /* = pointer_mode_host */, 1440 CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(ap), 1441 CUDAMemory(x), incx, &beta, CUDAMemoryMutable(y), incy); 1442 } 1443 1444 bool CUDABlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, 1445 double alpha, const DeviceMemory<double> &ap, 1446 const DeviceMemory<double> &x, int incx, double beta, 1447 DeviceMemory<double> *y, int incy) { 1448 return DoBlasInternal(wrap::cublasDspmv, stream, 1449 true /* = pointer_mode_host */, 1450 CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(ap), 1451 CUDAMemory(x), incx, &beta, CUDAMemoryMutable(y), incy); 1452 } 1453 1454 bool CUDABlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, 1455 float alpha, const DeviceMemory<float> &x, int incx, 1456 DeviceMemory<float> *ap) { 1457 return DoBlasInternal(wrap::cublasSspr, stream, 1458 true /* = pointer_mode_host */, 1459 CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x), 1460 incx, CUDAMemoryMutable(ap)); 1461 } 1462 1463 bool CUDABlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, 1464 double alpha, const DeviceMemory<double> &x, int incx, 1465 DeviceMemory<double> *ap) { 1466 return DoBlasInternal(wrap::cublasDspr, stream, 1467 true /* = pointer_mode_host */, 1468 CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x), 1469 incx, CUDAMemoryMutable(ap)); 1470 } 1471 1472 bool CUDABlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, 1473 float alpha, const DeviceMemory<float> &x, int incx, 1474 const DeviceMemory<float> &y, int incy, 1475 DeviceMemory<float> *ap) { 1476 return DoBlasInternal(wrap::cublasSspr2, stream, 1477 true /* = pointer_mode_host */, 1478 CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x), 1479 incx, CUDAMemory(y), incy, 
CUDAMemoryMutable(ap)); 1480 } 1481 1482 bool CUDABlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, 1483 double alpha, const DeviceMemory<double> &x, int incx, 1484 const DeviceMemory<double> &y, int incy, 1485 DeviceMemory<double> *ap) { 1486 return DoBlasInternal(wrap::cublasDspr2, stream, 1487 true /* = pointer_mode_host */, 1488 CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x), 1489 incx, CUDAMemory(y), incy, CUDAMemoryMutable(ap)); 1490 } 1491 1492 bool CUDABlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, 1493 float alpha, const DeviceMemory<float> &a, int lda, 1494 const DeviceMemory<float> &x, int incx, float beta, 1495 DeviceMemory<float> *y, int incy) { 1496 return DoBlasInternal(wrap::cublasSsymv, stream, 1497 true /* = pointer_mode_host */, 1498 CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(a), lda, 1499 CUDAMemory(x), incx, &beta, CUDAMemoryMutable(y), incy); 1500 } 1501 1502 bool CUDABlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, 1503 double alpha, const DeviceMemory<double> &a, int lda, 1504 const DeviceMemory<double> &x, int incx, double beta, 1505 DeviceMemory<double> *y, int incy) { 1506 return DoBlasInternal(wrap::cublasDsymv, stream, 1507 true /* = pointer_mode_host */, 1508 CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(a), lda, 1509 CUDAMemory(x), incx, &beta, CUDAMemoryMutable(y), incy); 1510 } 1511 1512 bool CUDABlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, 1513 float alpha, const DeviceMemory<float> &x, int incx, 1514 DeviceMemory<float> *a, int lda) { 1515 return DoBlasInternal(wrap::cublasSsyr, stream, 1516 true /* = pointer_mode_host */, 1517 CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x), 1518 incx, CUDAMemoryMutable(a), lda); 1519 } 1520 1521 bool CUDABlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, 1522 double alpha, const DeviceMemory<double> &x, int incx, 1523 DeviceMemory<double> *a, int lda) { 1524 return DoBlasInternal(wrap::cublasDsyr, 
stream, 1525 true /* = pointer_mode_host */, 1526 CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x), 1527 incx, CUDAMemoryMutable(a), lda); 1528 } 1529 1530 bool CUDABlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, 1531 float alpha, const DeviceMemory<float> &x, int incx, 1532 const DeviceMemory<float> &y, int incy, 1533 DeviceMemory<float> *a, int lda) { 1534 return DoBlasInternal(wrap::cublasSsyr2, stream, 1535 true /* = pointer_mode_host */, 1536 CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x), 1537 incx, CUDAMemory(y), incy, CUDAMemoryMutable(a), lda); 1538 } 1539 1540 bool CUDABlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, 1541 double alpha, const DeviceMemory<double> &x, int incx, 1542 const DeviceMemory<double> &y, int incy, 1543 DeviceMemory<double> *a, int lda) { 1544 return DoBlasInternal(wrap::cublasDsyr2, stream, 1545 true /* = pointer_mode_host */, 1546 CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x), 1547 incx, CUDAMemory(y), incy, CUDAMemoryMutable(a), lda); 1548 } 1549 1550 bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, 1551 blas::Transpose trans, blas::Diagonal diag, uint64 n, 1552 uint64 k, const DeviceMemory<float> &a, int lda, 1553 DeviceMemory<float> *x, int incx) { 1554 return DoBlasInternal(wrap::cublasStbmv, stream, 1555 true /* = pointer_mode_host */, 1556 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), 1557 CUDABlasDiagonal(diag), n, k, CUDAMemory(a), lda, 1558 CUDAMemoryMutable(x), incx); 1559 } 1560 1561 bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, 1562 blas::Transpose trans, blas::Diagonal diag, uint64 n, 1563 uint64 k, const DeviceMemory<double> &a, int lda, 1564 DeviceMemory<double> *x, int incx) { 1565 return DoBlasInternal(wrap::cublasDtbmv, stream, 1566 true /* = pointer_mode_host */, 1567 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), 1568 CUDABlasDiagonal(diag), n, k, CUDAMemory(a), lda, 1569 CUDAMemoryMutable(x), incx); 1570 } 1571 1572 
bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, 1573 blas::Transpose trans, blas::Diagonal diag, uint64 n, 1574 uint64 k, const DeviceMemory<std::complex<float>> &a, 1575 int lda, DeviceMemory<std::complex<float>> *x, 1576 int incx) { 1577 return DoBlasInternal( 1578 wrap::cublasCtbmv, stream, true /* = pointer_mode_host */, 1579 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), 1580 CUDABlasDiagonal(diag), n, k, CUDAComplex(CUDAMemory(a)), lda, 1581 CUDAComplex(CUDAMemoryMutable(x)), incx); 1582 } 1583 1584 bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, 1585 blas::Transpose trans, blas::Diagonal diag, uint64 n, 1586 uint64 k, const DeviceMemory<std::complex<double>> &a, 1587 int lda, DeviceMemory<std::complex<double>> *x, 1588 int incx) { 1589 return DoBlasInternal( 1590 wrap::cublasZtbmv, stream, true /* = pointer_mode_host */, 1591 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), 1592 CUDABlasDiagonal(diag), n, k, CUDAComplex(CUDAMemory(a)), lda, 1593 CUDAComplex(CUDAMemoryMutable(x)), incx); 1594 } 1595 1596 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo, 1597 blas::Transpose trans, blas::Diagonal diag, uint64 n, 1598 uint64 k, const DeviceMemory<float> &a, int lda, 1599 DeviceMemory<float> *x, int incx) { 1600 return DoBlasInternal(wrap::cublasStbsv, stream, 1601 true /* = pointer_mode_host */, 1602 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), 1603 CUDABlasDiagonal(diag), n, k, CUDAMemory(a), lda, 1604 CUDAMemoryMutable(x), incx); 1605 } 1606 1607 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo, 1608 blas::Transpose trans, blas::Diagonal diag, uint64 n, 1609 uint64 k, const DeviceMemory<double> &a, int lda, 1610 DeviceMemory<double> *x, int incx) { 1611 return DoBlasInternal(wrap::cublasDtbsv, stream, 1612 true /* = pointer_mode_host */, 1613 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), 1614 CUDABlasDiagonal(diag), n, k, CUDAMemory(a), lda, 1615 CUDAMemoryMutable(x), incx); 
1616 } 1617 1618 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo, 1619 blas::Transpose trans, blas::Diagonal diag, uint64 n, 1620 uint64 k, const DeviceMemory<std::complex<float>> &a, 1621 int lda, DeviceMemory<std::complex<float>> *x, 1622 int incx) { 1623 return DoBlasInternal( 1624 wrap::cublasCtbsv, stream, true /* = pointer_mode_host */, 1625 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), 1626 CUDABlasDiagonal(diag), n, k, CUDAComplex(CUDAMemory(a)), lda, 1627 CUDAComplex(CUDAMemoryMutable(x)), incx); 1628 } 1629 1630 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo, 1631 blas::Transpose trans, blas::Diagonal diag, uint64 n, 1632 uint64 k, const DeviceMemory<std::complex<double>> &a, 1633 int lda, DeviceMemory<std::complex<double>> *x, 1634 int incx) { 1635 return DoBlasInternal( 1636 wrap::cublasZtbsv, stream, true /* = pointer_mode_host */, 1637 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), 1638 CUDABlasDiagonal(diag), n, k, CUDAComplex(CUDAMemory(a)), lda, 1639 CUDAComplex(CUDAMemoryMutable(x)), incx); 1640 } 1641 1642 bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo, 1643 blas::Transpose trans, blas::Diagonal diag, uint64 n, 1644 const DeviceMemory<float> &ap, DeviceMemory<float> *x, 1645 int incx) { 1646 return DoBlasInternal( 1647 wrap::cublasStpmv, stream, true /* = pointer_mode_host */, 1648 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), 1649 CUDABlasDiagonal(diag), n, CUDAMemory(ap), CUDAMemoryMutable(x), incx); 1650 } 1651 1652 bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo, 1653 blas::Transpose trans, blas::Diagonal diag, uint64 n, 1654 const DeviceMemory<double> &ap, 1655 DeviceMemory<double> *x, int incx) { 1656 return DoBlasInternal( 1657 wrap::cublasDtpmv, stream, true /* = pointer_mode_host */, 1658 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), 1659 CUDABlasDiagonal(diag), n, CUDAMemory(ap), CUDAMemoryMutable(x), incx); 1660 } 1661 1662 bool 
// The level-2 triangular routines below are thin shims: each converts the
// StreamExecutor enums/handles to their cuBLAS equivalents and forwards the
// call via DoBlasInternal.
// NOTE(review): this definition's leading "bool" return type is on the
// preceding line, outside this chunk.
CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                     blas::Transpose trans, blas::Diagonal diag, uint64 n,
                     const DeviceMemory<std::complex<float>> &ap,
                     DeviceMemory<std::complex<float>> *x, int incx) {
  // Packed triangular matrix-vector multiply, complex<float> (cublasCtpmv).
  return DoBlasInternal(wrap::cublasCtpmv, stream,
                        true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(ap)),
                        CUDAComplex(CUDAMemoryMutable(x)), incx);
}

// Packed triangular matrix-vector multiply, complex<double> (cublasZtpmv).
bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &ap,
                          DeviceMemory<std::complex<double>> *x, int incx) {
  return DoBlasInternal(wrap::cublasZtpmv, stream,
                        true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(ap)),
                        CUDAComplex(CUDAMemoryMutable(x)), incx);
}

// Packed triangular solve, float (cublasStpsv).
bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                          int incx) {
  return DoBlasInternal(
      wrap::cublasStpsv, stream, true /* = pointer_mode_host */,
      CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
      CUDABlasDiagonal(diag), n, CUDAMemory(ap), CUDAMemoryMutable(x), incx);
}

// Packed triangular solve, double (cublasDtpsv).
bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<double> &ap,
                          DeviceMemory<double> *x, int incx) {
  return DoBlasInternal(
      wrap::cublasDtpsv, stream, true /* = pointer_mode_host */,
      CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
      CUDABlasDiagonal(diag), n, CUDAMemory(ap), CUDAMemoryMutable(x), incx);
}

// Packed triangular solve, complex<float> (cublasCtpsv).
bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<float>> &ap,
                          DeviceMemory<std::complex<float>> *x, int incx) {
  return DoBlasInternal(wrap::cublasCtpsv, stream,
                        true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(ap)),
                        CUDAComplex(CUDAMemoryMutable(x)), incx);
}

// Packed triangular solve, complex<double> (cublasZtpsv).
bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &ap,
                          DeviceMemory<std::complex<double>> *x, int incx) {
  return DoBlasInternal(wrap::cublasZtpsv, stream,
                        true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(ap)),
                        CUDAComplex(CUDAMemoryMutable(x)), incx);
}

// Triangular matrix-vector multiply, float (cublasStrmv).
bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *x, int incx) {
  return DoBlasInternal(wrap::cublasStrmv, stream,
                        true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, CUDAMemory(a), lda,
                        CUDAMemoryMutable(x), incx);
}

// Triangular matrix-vector multiply, double (cublasDtrmv).
bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *x, int incx) {
  return DoBlasInternal(wrap::cublasDtrmv, stream,
                        true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, CUDAMemory(a), lda,
                        CUDAMemoryMutable(x), incx);
}

// Triangular matrix-vector multiply, complex<float> (cublasCtrmv).
bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          DeviceMemory<std::complex<float>> *x, int incx) {
  return DoBlasInternal(wrap::cublasCtrmv, stream,
                        true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(a)),
                        lda, CUDAComplex(CUDAMemoryMutable(x)), incx);
}

// Triangular matrix-vector multiply, complex<double> (cublasZtrmv).
bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          DeviceMemory<std::complex<double>> *x, int incx) {
  return DoBlasInternal(wrap::cublasZtrmv, stream,
                        true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(a)),
                        lda, CUDAComplex(CUDAMemoryMutable(x)), incx);
}

// Triangular solve, float (cublasStrsv).
bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<float> &a, int lda,
                          DeviceMemory<float> *x, int incx) {
  return DoBlasInternal(wrap::cublasStrsv, stream,
                        true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, CUDAMemory(a), lda,
                        CUDAMemoryMutable(x), incx);
}

// Triangular solve, double (cublasDtrsv).
bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<double> &a, int lda,
                          DeviceMemory<double> *x, int incx) {
  return DoBlasInternal(wrap::cublasDtrsv, stream,
                        true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, CUDAMemory(a), lda,
                        CUDAMemoryMutable(x), incx);
}

// Triangular solve, complex<float> (cublasCtrsv).  This definition continues
// on the next chunk line.
bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          DeviceMemory<std::complex<float>> *x,
int incx) {
  return DoBlasInternal(wrap::cublasCtrsv, stream,
                        true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(a)),
                        lda, CUDAComplex(CUDAMemoryMutable(x)), incx);
}

// Triangular solve, complex<double> (cublasZtrsv).
bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, blas::Diagonal diag, uint64 n,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          DeviceMemory<std::complex<double>> *x, int incx) {
  return DoBlasInternal(wrap::cublasZtrsv, stream,
                        true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
                        CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(a)),
                        lda, CUDAComplex(CUDAMemoryMutable(x)), incx);
}

// GEMM on fp16 (Eigen::half) data with fp32 alpha/beta, via cublasSgemmEx.
// Requires CUDA >= 7.5; on older toolkits this logs an error and returns
// false.
bool CUDABlas::DoBlasGemm(
    Stream *stream, blas::Transpose transa,
    blas::Transpose transb, uint64 m, uint64 n, uint64 k,
    float alpha, const DeviceMemory<Eigen::half> &a, int lda,
    const DeviceMemory<Eigen::half> &b, int ldb, float beta,
    DeviceMemory<Eigen::half> *c, int ldc) {
#if CUDA_VERSION >= 7050
  VLOG(1) << port::Printf(
      "doing cuBLAS SGEMM: at=%d bt=%d m=%llu n=%llu "
      "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
      "c=%p ldc=%d",
      static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
      a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
  // Leading-dimension sanity checks: warn (but do not fail) when lda/ldb look
  // too small for the requested shapes.
  if (transa == blas::Transpose::kNoTranspose) {
    if (lda < static_cast<int64>(m)) {
      LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
                      "precondition violation";
    }
  } else {
    if (lda < static_cast<int64>(k)) {
      LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
                   << ") (transpose case); precondition violation";
    }
  }
  if (transb == blas::Transpose::kNoTranspose) {
    if (ldb < static_cast<int64>(k)) {
      LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
                   << ") (no transpose case); precondition violation";
    }
  } else {
    if (ldb < static_cast<int64>(n)) {
      LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
                      "precondition violation";
    }
  }

  bool use_tensor_ops = false;
#if CUDA_VERSION >= 9000
  int cc_major, cc_minor;
  stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
                                                                   &cc_minor);

  // GPUs < sm_70 don't have the tensor-op hardware introduced in Volta.
  if (cc_major >= 7 && TensorOpMathEnabled()) {
    use_tensor_ops = true;
  }
#endif

  return DoBlasInternalImpl(
      wrap::cublasSgemmEx, stream, true /* = pointer_mode_host */,
      true /* = err_on_failure= */, use_tensor_ops, CUDABlasTranspose(transa),
      CUDABlasTranspose(transb), m, n, k, &alpha, CUDAMemory(a),
      SE_CUDA_DATA_HALF, lda, CUDAMemory(b), SE_CUDA_DATA_HALF, ldb, &beta,
      CUDAMemoryMutable(c), SE_CUDA_DATA_HALF, ldc);

#else
  LOG(ERROR) << "fp16 sgemm is not implemented in this cuBLAS version "
             << "(need at least CUDA 7.5)";
  return false;
#endif
}

// Single-precision GEMM (cublasSgemm).
bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                          float alpha, const DeviceMemory<float> &a, int lda,
                          const DeviceMemory<float> &b, int ldb, float beta,
                          DeviceMemory<float> *c, int ldc) {
  VLOG(1) << port::Printf(
      "doing cuBLAS SGEMM: at=%d bt=%d m=%llu n=%llu "
      "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
      "c=%p ldc=%d",
      static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
      a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
  // Same leading-dimension sanity checks as the fp16 overload above.
  if (transa == blas::Transpose::kNoTranspose) {
    if (lda < static_cast<int64>(m)) {
      LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
                      "precondition violation";
    }
  } else {
    if (lda < static_cast<int64>(k)) {
      LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
                   << ") (transpose case); precondition violation";
    }
  }
  if (transb == blas::Transpose::kNoTranspose) {
    if (ldb < static_cast<int64>(k)) {
      LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
                   << ") (no transpose case); precondition violation";
    }
  } else {
    if (ldb < static_cast<int64>(n)) {
      LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
                      "precondition violation";
    }
  }
  return DoBlasInternal(
      wrap::cublasSgemm, stream, true /* = pointer_mode_host */,
      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
      CUDAMemory(a), lda, CUDAMemory(b), ldb, &beta, CUDAMemoryMutable(c), ldc);
}

// Double-precision GEMM (cublasDgemm).
bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                          double alpha, const DeviceMemory<double> &a, int lda,
                          const DeviceMemory<double> &b, int ldb, double beta,
                          DeviceMemory<double> *c, int ldc) {
  return DoBlasInternal(
      wrap::cublasDgemm, stream, true /* = pointer_mode_host */,
      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
      CUDAMemory(a), lda, CUDAMemory(b), ldb, &beta, CUDAMemoryMutable(c), ldc);
}

// Complex single-precision GEMM (cublasCgemm).
bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &b, int ldb,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *c, int ldc) {
  return DoBlasInternal(
      wrap::cublasCgemm, stream, true /* = pointer_mode_host */,
      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
      CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
      CUDAComplex(CUDAMemory(b)), ldb, CUDAComplex(&beta),
      CUDAComplex(CUDAMemoryMutable(c)), ldc);
}

// Complex double-precision GEMM (cublasZgemm).
bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
                          blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &b, int ldb,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *c, int ldc) {
  return DoBlasInternal(
      wrap::cublasZgemm, stream, true /* = pointer_mode_host */,
      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
      CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
      CUDAComplex(CUDAMemory(b)), ldb, CUDAComplex(&beta),
      CUDAComplex(CUDAMemoryMutable(c)), ldc);
}

// Profiled GEMV entry points: all forward to DoBlasGemvWithProfilingImpl.
bool CUDABlas::DoBlasGemvWithProfiling(
    Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
    const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
    int incx, float beta, DeviceMemory<float> *y, int incy,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
                                     incx, beta, y, incy,
                                     output_profile_result);
}

bool CUDABlas::DoBlasGemvWithProfiling(
    Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
    const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
    int incx, double beta, DeviceMemory<double> *y, int incy,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
                                     incx, beta, y, incy,
                                     output_profile_result);
}

// complex<float> overload; the definition continues on the next chunk line.
bool CUDABlas::DoBlasGemvWithProfiling(
    Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
    std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
    int lda, const DeviceMemory<std::complex<float>> &x, int incx,
    std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
    blas::ProfileResult
*output_profile_result) {
  return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
                                     incx, beta, y, incy,
                                     output_profile_result);
}

bool CUDABlas::DoBlasGemvWithProfiling(
    Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
    std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
    int lda, const DeviceMemory<std::complex<double>> &x, int incx,
    std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
                                     incx, beta, y, incy,
                                     output_profile_result);
}

// Profiled GEMM entry points: all forward to DoBlasGemmWithProfilingImpl.
bool CUDABlas::DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
    int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
    DeviceMemory<Eigen::half> *c, int ldc,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
                                     lda, b, ldb, beta, c, ldc,
                                     output_profile_result);
}

bool CUDABlas::DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
    const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
    int ldc, blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
                                     lda, b, ldb, beta, c, ldc,
                                     output_profile_result);
}

bool CUDABlas::DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
    const DeviceMemory<double> &b, int ldb, double beta,
    DeviceMemory<double> *c, int ldc,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
                                     lda, b, ldb, beta, c, ldc,
                                     output_profile_result);
}

bool CUDABlas::DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<float> alpha,
    const DeviceMemory<std::complex<float>> &a, int lda,
    const DeviceMemory<std::complex<float>> &b, int ldb,
    std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
                                     lda, b, ldb, beta, c, ldc,
                                     output_profile_result);
}

bool CUDABlas::DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<double> alpha,
    const DeviceMemory<std::complex<double>> &a, int lda,
    const DeviceMemory<std::complex<double>> &b, int ldb,
    std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
                                     lda, b, ldb, beta, c, ldc,
                                     output_profile_result);
}

// Runs DoBlasGemv and, when output_profile_result is non-null, times the call
// with a CUDATimer and records the timing/algorithm in the result.
template <typename T>
bool CUDABlas::DoBlasGemvWithProfilingImpl(
    Stream *stream, blas::Transpose trans, uint64 m, uint64 n, const T &alpha,
    const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
    const T &beta, DeviceMemory<T> *y, int incy,
    blas::ProfileResult *output_profile_result) {
  struct TimerDeleter {
    void operator()(CUDATimer *t) {
      t->Destroy();
      delete t;
    }
  };
  std::unique_ptr<CUDATimer, TimerDeleter> timer;
  if (output_profile_result != nullptr) {
    timer.reset(new CUDATimer(parent_));
    if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
      return false;
    }
  }

  // Call blasGemv.
  bool result =
      DoBlasGemv(stream, trans, m, n, alpha, a, lda, x, incx, beta, y, incy);

  if (timer != nullptr && result) {
    // CUDATimer will CHECK-fail if we Stop() it while the stream is in an error
    // state.
    if (!timer->Stop(AsCUDAStream(stream))) {
      return false;
    }
    output_profile_result->set_is_valid(true);
    output_profile_result->set_algorithm(blas::kDefaultBlasGemv);
    output_profile_result->set_elapsed_time_in_ms(
        timer->GetElapsedMilliseconds());
  }
  return result;
}

// Runs DoBlasGemm and, when output_profile_result is non-null, times the call
// with a CUDATimer and records the timing/algorithm in the result.
template <typename T, typename ParamType>
bool CUDABlas::DoBlasGemmWithProfilingImpl(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
    int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
    DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
  struct TimerDeleter {
    void operator()(CUDATimer *t) {
      t->Destroy();
      delete t;
    }
  };
  std::unique_ptr<CUDATimer, TimerDeleter> timer;
  if (output_profile_result != nullptr) {
    timer.reset(new CUDATimer(parent_));
    if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
      return false;
    }
  }

  // Call blasGemm.
  bool result = DoBlasGemm(stream, transa, transb, m, n, k, alpha, a, lda, b,
                           ldb, beta, c, ldc);

  if (timer != nullptr && result) {
    // CUDATimer will CHECK-fail if we Stop() it while the stream is in an error
    // state.
    if (!timer->Stop(AsCUDAStream(stream))) {
      return false;
    }
    output_profile_result->set_is_valid(true);
    output_profile_result->set_algorithm(blas::kDefaultBlasGemm);
    output_profile_result->set_elapsed_time_in_ms(
        timer->GetElapsedMilliseconds());
  }
  return result;
}

// True iff `algo` names one of the CUDA 9 tensor-op GEMM algorithm variants.
static bool UsesTensorOps(blas::AlgorithmType algo) {
#if CUDA_VERSION >= 9000
  cublasGemmAlgo_t cublas_algo = static_cast<cublasGemmAlgo_t>(algo);
  return cublas_algo >= CUBLAS_GEMM_DEFAULT_TENSOR_OP;
#else
  return false;
#endif
}

// True iff tensor ops can be used for inputs of type InType on a device with
// the given major compute capability (needs CUDA >= 9, sm_70+, fp16 inputs,
// and TensorOpMathEnabled()).
template <typename InType>
static bool TensorOpsAvailable(int cc_major) {
#if CUDA_VERSION >= 9000
  if (cc_major >= 7 && TensorOpMathEnabled() &&
      std::is_same<InType, Eigen::half>::value) {
    return true;
  }
#endif
  return false;
}

// Shared implementation for the DoBlasGemmWithAlgorithm overloads; dispatches
// to cublasGemmEx with an explicit algorithm selection.  The body continues on
// the next chunk line.
template <typename InT, typename OutT, typename CompT>
bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a, int lda,
    const DeviceMemory<InT> &b, int ldb, const CompT &beta,
    DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
  // CUDA < version 8 and GPUs < sm_50 don't support cublasGemmEx.
#if CUDA_VERSION < 8000
  return false;
#else
  int cc_major, cc_minor;
  if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
          &cc_major, &cc_minor) &&
      cc_major < 5) {
    return false;
  }

  // Reject tensor-op algorithms when the device/type cannot use them.
  if (UsesTensorOps(algorithm) && !TensorOpsAvailable<InT>(cc_major)) {
    return false;
  }

  struct TimerDeleter {
    void operator()(CUDATimer *t) {
      t->Destroy();
      delete t;
    }
  };
  std::unique_ptr<CUDATimer, TimerDeleter> timer;
  if (output_profile_result != nullptr) {
    timer.reset(new CUDATimer(parent_));
    if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
      return false;
    }
  }

  cudaDataType_t cuda_in_type = CUDADataType<InT>::type;
  // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
  // we do the following compile-time check on the default value:
  static_assert(blas::kDefaultGemmAlgo == CUBLAS_GEMM_DFALT, "");
  bool result = DoBlasInternalFailureOK(
      wrap::cublasGemmEx, stream, /* pointer_mode_host = */ true,
      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
      CUDAMemory(a), cuda_in_type, lda, CUDAMemory(b), cuda_in_type, ldb, &beta,
      CUDAMemoryMutable(c), CUDADataType<OutT>::type, ldc,
      CUDAComputationType(computation_type),
      static_cast<cublasGemmAlgo_t>(algorithm));

  if (timer != nullptr && result) {
    // CUDATimer will CHECK-fail if we Stop() it while the stream is in an error
    // state.
    if (!timer->Stop(AsCUDAStream(stream))) {
      return false;
    }
    output_profile_result->set_is_valid(true);
    output_profile_result->set_algorithm(algorithm);
    output_profile_result->set_elapsed_time_in_ms(
        timer->GetElapsedMilliseconds());
  }
  return result;
#endif
}

// Appends the cublasGemmEx algorithm enumerators supported by this CUDA
// version to *out_algorithms.  Always returns true.
bool CUDABlas::GetBlasGemmAlgorithms(
    std::vector<blas::AlgorithmType> *out_algorithms) {
  // cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
  // were first introduced in CUDA 8.
  // When the CUDA version is insufficient, nothing is appended here, so the
  // caller must be prepared to handle an empty algorithm list.
#if CUDA_VERSION >= 8000
  for (cublasGemmAlgo_t algo : {
           CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
           CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4,
           CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7,
#if CUDA_VERSION >= 9000
           CUBLAS_GEMM_ALGO8, CUBLAS_GEMM_ALGO9, CUBLAS_GEMM_ALGO10,
           CUBLAS_GEMM_ALGO11, CUBLAS_GEMM_ALGO12, CUBLAS_GEMM_ALGO13,
           CUBLAS_GEMM_ALGO14, CUBLAS_GEMM_ALGO15, CUBLAS_GEMM_ALGO16,
           CUBLAS_GEMM_ALGO17, CUBLAS_GEMM_DFALT_TENSOR_OP,
           CUBLAS_GEMM_ALGO0_TENSOR_OP, CUBLAS_GEMM_ALGO1_TENSOR_OP,
           CUBLAS_GEMM_ALGO2_TENSOR_OP
#endif
       }) {
    out_algorithms->push_back(algo);
  }
#endif
  return true;
}

// int8 inputs / int32 outputs; the definition continues on the next chunk
// line.
bool CUDABlas::DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
    const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
    int ldc, blas::ComputationType computation_type,
    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithAlgorithmImpl(
      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
ldc,
      computation_type, algorithm, output_profile_result);
}

// Remaining DoBlasGemmWithAlgorithm overloads: all forward to
// DoBlasGemmWithAlgorithmImpl.
bool CUDABlas::DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const Eigen::half &alpha,
    const DeviceMemory<Eigen::half> &a, int lda,
    const DeviceMemory<Eigen::half> &b, int ldb, const Eigen::half &beta,
    DeviceMemory<Eigen::half> *c, int ldc,
    blas::ComputationType computation_type, blas::AlgorithmType algorithm,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithAlgorithmImpl(
      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
      computation_type, algorithm, output_profile_result);
}

bool CUDABlas::DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
    const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
    int ldc, blas::ComputationType computation_type,
    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithAlgorithmImpl(
      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
      computation_type, algorithm, output_profile_result);
}

bool CUDABlas::DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
    const DeviceMemory<double> &b, int ldb, double beta,
    DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type,
    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithAlgorithmImpl(
      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
      computation_type, algorithm, output_profile_result);
}

bool CUDABlas::DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<float> alpha,
    const DeviceMemory<std::complex<float>> &a, int lda,
    const DeviceMemory<std::complex<float>> &b, int ldb,
    std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
    blas::ComputationType computation_type, blas::AlgorithmType algorithm,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithAlgorithmImpl(
      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
      computation_type, algorithm, output_profile_result);
}

bool CUDABlas::DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<double> alpha,
    const DeviceMemory<std::complex<double>> &a, int lda,
    const DeviceMemory<std::complex<double>> &b, int ldb,
    std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
    blas::ComputationType computation_type, blas::AlgorithmType algorithm,
    blas::ProfileResult *output_profile_result) {
  return DoBlasGemmWithAlgorithmImpl(
      stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
      computation_type, algorithm, output_profile_result);
}

// Shared implementation for the batched-GEMM entry points: gathers the raw
// device pointers from the wrappers, stages the three pointer arrays in
// device memory (scratch allocator if provided, otherwise stream-temporary
// allocations), and invokes the supplied batched cuBLAS function.
template <typename T, typename FuncT>
port::Status CUDABlas::DoBlasGemmBatchedInternal(
    FuncT cublas_func, Stream *stream, blas::Transpose transa,
    blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
    const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
    const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
    T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
  std::vector<T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
  for (int i = 0; i < batch_count; ++i) {
    a_raw_ptrs.push_back(static_cast<T *>(a_ptrs_to_wrappers[i]->opaque()));
    b_raw_ptrs.push_back(static_cast<T *>(b_ptrs_to_wrappers[i]->opaque()));
    c_raw_ptrs.push_back(static_cast<T *>(c_ptrs_to_wrappers[i]->opaque()));
  }

  typedef typename CUDAComplexT<T>::type CUDA_T;

  const size_t size = batch_count * sizeof(CUDA_T *);

  // Device-side copy of pointers to matrices.
  DeviceMemory<CUDA_T *> a;
  DeviceMemory<CUDA_T *> b;
  DeviceMemory<CUDA_T *> c;

  // If temporary space is allocated for device-side copies of pointers to
  // matrices, that temporary space should not be freed until this function
  // returns. Although the values for these unique_ptrs are not set here, they
  // are declared at this scope so they will be destroyed when the function
  // returns.
  //
  // If a scratch allocator is provided, these pointers will not be used at all.
  std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> a_temporary;
  std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> b_temporary;
  std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> c_temporary;

  // Decide how to allocate device-side copy of pointers to matrices based on
  // whether a scratch allocator was passed.
  if (scratch_allocator != nullptr) {
    SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> a_bytes,
                        scratch_allocator->AllocateBytes(stream, size));
    SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> b_bytes,
                        scratch_allocator->AllocateBytes(stream, size));
    SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> c_bytes,
                        scratch_allocator->AllocateBytes(stream, size));
    a = DeviceMemory<CUDA_T *>(a_bytes);
    b = DeviceMemory<CUDA_T *>(b_bytes);
    c = DeviceMemory<CUDA_T *>(c_bytes);
  } else {
    SE_ASSIGN_OR_RETURN(a_temporary,
                        stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
    SE_ASSIGN_OR_RETURN(b_temporary,
                        stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
    SE_ASSIGN_OR_RETURN(c_temporary,
                        stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
    a = DeviceMemory<CUDA_T *>(*a_temporary->mutable_device_memory());
    b = DeviceMemory<CUDA_T *>(*b_temporary->mutable_device_memory());
    c = DeviceMemory<CUDA_T *>(*c_temporary->mutable_device_memory());
  }

  if (!stream->ThenMemcpy(&a, a_raw_ptrs.data(), size).ok() ||
      !stream->ThenMemcpy(&b, b_raw_ptrs.data(), size).ok() ||
      !stream->ThenMemcpy(&c, c_raw_ptrs.data(), size).ok()) {
    return port::Status(port::error::INTERNAL,
                        "failed to copy memory from host to device in "
                        "CUDABlas::DoBlasGemmBatched");
  }

  bool ok = DoBlasInternal(
      cublas_func, stream, true /* = pointer_mode_host */,
      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
      CUDAComplex(&alpha), const_cast<const CUDA_T **>(CUDAMemory(a)), lda,
      const_cast<const CUDA_T **>(CUDAMemory(b)), ldb, CUDAComplex(&beta),
      const_cast<CUDA_T **>(CUDAMemory(c)), ldc, batch_count);

  if (ok) {
    return port::Status::OK();
  }
  return port::Status(port::error::INTERNAL,
                      "failed BLAS call, see log for details");
}

// Batched SGEMM (cublasSgemmBatched); the definition continues on the next
// chunk line.
bool CUDABlas::DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa,
blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha,
    const port::ArraySlice<DeviceMemory<float> *> &a_array, int lda,
    const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta,
    const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc,
    int batch_count, ScratchAllocator *scratch_allocator) {
  port::Status status = DoBlasGemmBatchedInternal(
      wrap::cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array,
      lda, b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
  if (!status.ok()) {
    LOG(ERROR) << status;
  }
  return status.ok();
}

// Batched DGEMM (cublasDgemmBatched).
bool CUDABlas::DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, double alpha,
    const port::ArraySlice<DeviceMemory<double> *> &a_array, int lda,
    const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb,
    double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array,
    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
  port::Status status = DoBlasGemmBatchedInternal(
      wrap::cublasDgemmBatched, stream, transa, transb, m, n, k, alpha, a_array,
      lda, b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
  if (!status.ok()) {
    LOG(ERROR) << status;
  }
  return status.ok();
}

// Batched CGEMM (cublasCgemmBatched).
bool CUDABlas::DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<float> alpha,
    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a_array,
    int lda,
    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b_array,
    int ldb, std::complex<float> beta,
    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array,
    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
  port::Status status = DoBlasGemmBatchedInternal(
      wrap::cublasCgemmBatched, stream, transa, transb, m, n, k, alpha, a_array,
      lda, b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
  if (!status.ok()) {
    LOG(ERROR) << status;
  }
  return status.ok();
}

// Batched ZGEMM (cublasZgemmBatched).
bool CUDABlas::DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<double> alpha,
    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a_array,
    int lda,
    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b_array,
    int ldb, std::complex<double> beta,
    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array,
    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
  port::Status status = DoBlasGemmBatchedInternal(
      wrap::cublasZgemmBatched, stream, transa, transb, m, n, k, alpha, a_array,
      lda, b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
  if (!status.ok()) {
    LOG(ERROR) << status;
  }
  return status.ok();
}

// Hermitian matrix-matrix multiply, complex<float> (cublasChemm).
bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          const DeviceMemory<std::complex<float>> &b, int ldb,
                          std::complex<float> beta,
                          DeviceMemory<std::complex<float>> *c, int ldc) {
  return DoBlasInternal(
      wrap::cublasChemm, stream, true /* = pointer_mode_host */,
      CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n, CUDAComplex(&alpha),
      CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(b)), ldb,
      CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(c)), ldc);
}

// Hermitian matrix-matrix multiply, complex<double> (cublasZhemm).
bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
                          blas::UpperLower uplo, uint64 m, uint64 n,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          const DeviceMemory<std::complex<double>> &b, int ldb,
                          std::complex<double> beta,
                          DeviceMemory<std::complex<double>> *c, int ldc) {
  return DoBlasInternal(
      wrap::cublasZhemm, stream, true /* = pointer_mode_host */,
      CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n, CUDAComplex(&alpha),
      CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(b)), ldb,
      CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(c)), ldc);
}

// Hermitian rank-k update, complex<float> (cublasCherk).
bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          float alpha,
                          const DeviceMemory<std::complex<float>> &a, int lda,
                          float beta, DeviceMemory<std::complex<float>> *c,
                          int ldc) {
  return DoBlasInternal(wrap::cublasCherk, stream,
                        true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
                        k, CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
                        &beta, CUDAComplex(CUDAMemoryMutable(c)), ldc);
}

// Hermitian rank-k update, complex<double> (cublasZherk).
bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
                          blas::Transpose trans, uint64 n, uint64 k,
                          double alpha,
                          const DeviceMemory<std::complex<double>> &a, int lda,
                          double beta, DeviceMemory<std::complex<double>> *c,
                          int ldc) {
  return DoBlasInternal(wrap::cublasZherk, stream,
                        true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
                        k, CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
                        &beta, CUDAComplex(CUDAMemoryMutable(c)), ldc);
}

// Hermitian rank-2k update, complex<float> (cublasCher2k).  This definition
// continues past the end of this chunk.
bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
                           blas::Transpose trans, uint64 n, uint64 k,
                           std::complex<float> alpha,
                           const DeviceMemory<std::complex<float>> &a, int lda,
                           const DeviceMemory<std::complex<float>> &b, int ldb,
                           float beta, DeviceMemory<std::complex<float>> *c,
                           int ldc) {
  return DoBlasInternal(wrap::cublasCher2k, stream,
                        true /* = pointer_mode_host */,
                        CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
k, CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, 2526 CUDAComplex(CUDAMemory(b)), ldb, &beta, 2527 CUDAComplex(CUDAMemoryMutable(c)), ldc); 2528 } 2529 2530 bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo, 2531 blas::Transpose trans, uint64 n, uint64 k, 2532 std::complex<double> alpha, 2533 const DeviceMemory<std::complex<double>> &a, int lda, 2534 const DeviceMemory<std::complex<double>> &b, int ldb, 2535 double beta, DeviceMemory<std::complex<double>> *c, 2536 int ldc) { 2537 return DoBlasInternal(wrap::cublasZher2k, stream, 2538 true /* = pointer_mode_host */, 2539 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, 2540 k, CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, 2541 CUDAComplex(CUDAMemory(b)), ldb, &beta, 2542 CUDAComplex(CUDAMemoryMutable(c)), ldc); 2543 } 2544 2545 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side, 2546 blas::UpperLower uplo, uint64 m, uint64 n, 2547 float alpha, const DeviceMemory<float> &a, int lda, 2548 const DeviceMemory<float> &b, int ldb, float beta, 2549 DeviceMemory<float> *c, int ldc) { 2550 return DoBlasInternal( 2551 wrap::cublasSsymm, stream, true /* = pointer_mode_host */, 2552 CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n, &alpha, CUDAMemory(a), 2553 lda, CUDAMemory(b), ldb, &beta, CUDAMemoryMutable(c), ldc); 2554 } 2555 2556 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side, 2557 blas::UpperLower uplo, uint64 m, uint64 n, 2558 double alpha, const DeviceMemory<double> &a, int lda, 2559 const DeviceMemory<double> &b, int ldb, double beta, 2560 DeviceMemory<double> *c, int ldc) { 2561 return DoBlasInternal( 2562 wrap::cublasDsymm, stream, true /* = pointer_mode_host */, 2563 CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n, &alpha, CUDAMemory(a), 2564 lda, CUDAMemory(b), ldb, &beta, CUDAMemoryMutable(c), ldc); 2565 } 2566 2567 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side, 2568 blas::UpperLower uplo, uint64 m, uint64 n, 2569 std::complex<float> alpha, 
2570 const DeviceMemory<std::complex<float>> &a, int lda, 2571 const DeviceMemory<std::complex<float>> &b, int ldb, 2572 std::complex<float> beta, 2573 DeviceMemory<std::complex<float>> *c, int ldc) { 2574 return DoBlasInternal( 2575 wrap::cublasCsymm, stream, true /* = pointer_mode_host */, 2576 CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n, CUDAComplex(&alpha), 2577 CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(b)), ldb, 2578 CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(c)), ldc); 2579 } 2580 2581 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side, 2582 blas::UpperLower uplo, uint64 m, uint64 n, 2583 std::complex<double> alpha, 2584 const DeviceMemory<std::complex<double>> &a, int lda, 2585 const DeviceMemory<std::complex<double>> &b, int ldb, 2586 std::complex<double> beta, 2587 DeviceMemory<std::complex<double>> *c, int ldc) { 2588 return DoBlasInternal( 2589 wrap::cublasZsymm, stream, true /* = pointer_mode_host */, 2590 CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n, CUDAComplex(&alpha), 2591 CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(b)), ldb, 2592 CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(c)), ldc); 2593 } 2594 2595 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, 2596 blas::Transpose trans, uint64 n, uint64 k, 2597 float alpha, const DeviceMemory<float> &a, int lda, 2598 float beta, DeviceMemory<float> *c, int ldc) { 2599 return DoBlasInternal( 2600 wrap::cublasSsyrk, stream, true /* = pointer_mode_host */, 2601 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, k, &alpha, 2602 CUDAMemory(a), lda, &beta, CUDAMemoryMutable(c), ldc); 2603 } 2604 2605 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, 2606 blas::Transpose trans, uint64 n, uint64 k, 2607 double alpha, const DeviceMemory<double> &a, int lda, 2608 double beta, DeviceMemory<double> *c, int ldc) { 2609 return DoBlasInternal( 2610 wrap::cublasDsyrk, stream, true /* = pointer_mode_host */, 2611 
CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, k, &alpha, 2612 CUDAMemory(a), lda, &beta, CUDAMemoryMutable(c), ldc); 2613 } 2614 2615 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, 2616 blas::Transpose trans, uint64 n, uint64 k, 2617 std::complex<float> alpha, 2618 const DeviceMemory<std::complex<float>> &a, int lda, 2619 std::complex<float> beta, 2620 DeviceMemory<std::complex<float>> *c, int ldc) { 2621 return DoBlasInternal( 2622 wrap::cublasCsyrk, stream, true /* = pointer_mode_host */, 2623 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, k, 2624 CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(&beta), 2625 CUDAComplex(CUDAMemoryMutable(c)), ldc); 2626 } 2627 2628 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, 2629 blas::Transpose trans, uint64 n, uint64 k, 2630 std::complex<double> alpha, 2631 const DeviceMemory<std::complex<double>> &a, int lda, 2632 std::complex<double> beta, 2633 DeviceMemory<std::complex<double>> *c, int ldc) { 2634 return DoBlasInternal( 2635 wrap::cublasZsyrk, stream, true /* = pointer_mode_host */, 2636 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, k, 2637 CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(&beta), 2638 CUDAComplex(CUDAMemoryMutable(c)), ldc); 2639 } 2640 2641 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, 2642 blas::Transpose trans, uint64 n, uint64 k, 2643 float alpha, const DeviceMemory<float> &a, int lda, 2644 const DeviceMemory<float> &b, int ldb, float beta, 2645 DeviceMemory<float> *c, int ldc) { 2646 return DoBlasInternal( 2647 wrap::cublasSsyr2k, stream, true /* = pointer_mode_host */, 2648 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, k, &alpha, 2649 CUDAMemory(a), lda, CUDAMemory(b), ldb, &beta, CUDAMemoryMutable(c), ldc); 2650 } 2651 2652 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, 2653 blas::Transpose trans, uint64 n, uint64 k, 2654 double alpha, const 
DeviceMemory<double> &a, int lda, 2655 const DeviceMemory<double> &b, int ldb, double beta, 2656 DeviceMemory<double> *c, int ldc) { 2657 return DoBlasInternal( 2658 wrap::cublasDsyr2k, stream, true /* = pointer_mode_host */, 2659 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, k, &alpha, 2660 CUDAMemory(a), lda, CUDAMemory(b), ldb, &beta, CUDAMemoryMutable(c), ldc); 2661 } 2662 2663 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, 2664 blas::Transpose trans, uint64 n, uint64 k, 2665 std::complex<float> alpha, 2666 const DeviceMemory<std::complex<float>> &a, int lda, 2667 const DeviceMemory<std::complex<float>> &b, int ldb, 2668 std::complex<float> beta, 2669 DeviceMemory<std::complex<float>> *c, int ldc) { 2670 return DoBlasInternal(wrap::cublasCsyr2k, stream, 2671 true /* = pointer_mode_host */, 2672 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, 2673 k, CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, 2674 CUDAComplex(CUDAMemory(b)), ldb, CUDAComplex(&beta), 2675 CUDAComplex(CUDAMemoryMutable(c)), ldc); 2676 } 2677 2678 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, 2679 blas::Transpose trans, uint64 n, uint64 k, 2680 std::complex<double> alpha, 2681 const DeviceMemory<std::complex<double>> &a, int lda, 2682 const DeviceMemory<std::complex<double>> &b, int ldb, 2683 std::complex<double> beta, 2684 DeviceMemory<std::complex<double>> *c, int ldc) { 2685 return DoBlasInternal(wrap::cublasZsyr2k, stream, 2686 true /* = pointer_mode_host */, 2687 CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, 2688 k, CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, 2689 CUDAComplex(CUDAMemory(b)), ldb, CUDAComplex(&beta), 2690 CUDAComplex(CUDAMemoryMutable(c)), ldc); 2691 } 2692 2693 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side, 2694 blas::UpperLower uplo, blas::Transpose transa, 2695 blas::Diagonal diag, uint64 m, uint64 n, float alpha, 2696 const DeviceMemory<float> &a, int lda, 2697 
DeviceMemory<float> *b, int ldb) { 2698 return DoBlasInternal( 2699 wrap::cublasStrmm, stream, true /* = pointer_mode_host */, 2700 CUDABlasSide(side), CUDABlasUpperLower(uplo), CUDABlasTranspose(transa), 2701 CUDABlasDiagonal(diag), m, n, &alpha, CUDAMemory(a), lda, 2702 CUDAMemoryMutable(b), ldb, CUDAMemoryMutable(b), ldb); 2703 } 2704 2705 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side, 2706 blas::UpperLower uplo, blas::Transpose transa, 2707 blas::Diagonal diag, uint64 m, uint64 n, double alpha, 2708 const DeviceMemory<double> &a, int lda, 2709 DeviceMemory<double> *b, int ldb) { 2710 return DoBlasInternal( 2711 wrap::cublasDtrmm, stream, true /* = pointer_mode_host */, 2712 CUDABlasSide(side), CUDABlasUpperLower(uplo), CUDABlasTranspose(transa), 2713 CUDABlasDiagonal(diag), m, n, &alpha, CUDAMemory(a), lda, 2714 CUDAMemoryMutable(b), ldb, CUDAMemoryMutable(b), ldb); 2715 } 2716 2717 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side, 2718 blas::UpperLower uplo, blas::Transpose transa, 2719 blas::Diagonal diag, uint64 m, uint64 n, 2720 std::complex<float> alpha, 2721 const DeviceMemory<std::complex<float>> &a, int lda, 2722 DeviceMemory<std::complex<float>> *b, int ldb) { 2723 return DoBlasInternal( 2724 wrap::cublasCtrmm, stream, true /* = pointer_mode_host */, 2725 CUDABlasSide(side), CUDABlasUpperLower(uplo), CUDABlasTranspose(transa), 2726 CUDABlasDiagonal(diag), m, n, CUDAComplex(&alpha), 2727 CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemoryMutable(b)), ldb, 2728 CUDAComplex(CUDAMemoryMutable(b)), ldb); 2729 } 2730 2731 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side, 2732 blas::UpperLower uplo, blas::Transpose transa, 2733 blas::Diagonal diag, uint64 m, uint64 n, 2734 std::complex<double> alpha, 2735 const DeviceMemory<std::complex<double>> &a, int lda, 2736 DeviceMemory<std::complex<double>> *b, int ldb) { 2737 return DoBlasInternal( 2738 wrap::cublasZtrmm, stream, true /* = pointer_mode_host */, 2739 
CUDABlasSide(side), CUDABlasUpperLower(uplo), CUDABlasTranspose(transa), 2740 CUDABlasDiagonal(diag), m, n, CUDAComplex(&alpha), 2741 CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemoryMutable(b)), ldb, 2742 CUDAComplex(CUDAMemoryMutable(b)), ldb); 2743 } 2744 2745 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side, 2746 blas::UpperLower uplo, blas::Transpose transa, 2747 blas::Diagonal diag, uint64 m, uint64 n, float alpha, 2748 const DeviceMemory<float> &a, int lda, 2749 DeviceMemory<float> *b, int ldb) { 2750 return DoBlasInternal(wrap::cublasStrsm, stream, 2751 true /* = pointer_mode_host */, CUDABlasSide(side), 2752 CUDABlasUpperLower(uplo), CUDABlasTranspose(transa), 2753 CUDABlasDiagonal(diag), m, n, &alpha, CUDAMemory(a), 2754 lda, CUDAMemoryMutable(b), ldb); 2755 } 2756 2757 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side, 2758 blas::UpperLower uplo, blas::Transpose transa, 2759 blas::Diagonal diag, uint64 m, uint64 n, double alpha, 2760 const DeviceMemory<double> &a, int lda, 2761 DeviceMemory<double> *b, int ldb) { 2762 return DoBlasInternal(wrap::cublasDtrsm, stream, 2763 true /* = pointer_mode_host */, CUDABlasSide(side), 2764 CUDABlasUpperLower(uplo), CUDABlasTranspose(transa), 2765 CUDABlasDiagonal(diag), m, n, &alpha, CUDAMemory(a), 2766 lda, CUDAMemoryMutable(b), ldb); 2767 } 2768 2769 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side, 2770 blas::UpperLower uplo, blas::Transpose transa, 2771 blas::Diagonal diag, uint64 m, uint64 n, 2772 std::complex<float> alpha, 2773 const DeviceMemory<std::complex<float>> &a, int lda, 2774 DeviceMemory<std::complex<float>> *b, int ldb) { 2775 return DoBlasInternal( 2776 wrap::cublasCtrsm, stream, true /* = pointer_mode_host */, 2777 CUDABlasSide(side), CUDABlasUpperLower(uplo), CUDABlasTranspose(transa), 2778 CUDABlasDiagonal(diag), m, n, CUDAComplex(&alpha), 2779 CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemoryMutable(b)), ldb); 2780 } 2781 2782 bool 
CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side, 2783 blas::UpperLower uplo, blas::Transpose transa, 2784 blas::Diagonal diag, uint64 m, uint64 n, 2785 std::complex<double> alpha, 2786 const DeviceMemory<std::complex<double>> &a, int lda, 2787 DeviceMemory<std::complex<double>> *b, int ldb) { 2788 return DoBlasInternal( 2789 wrap::cublasZtrsm, stream, true /* = pointer_mode_host */, 2790 CUDABlasSide(side), CUDABlasUpperLower(uplo), CUDABlasTranspose(transa), 2791 CUDABlasDiagonal(diag), m, n, CUDAComplex(&alpha), 2792 CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemoryMutable(b)), ldb); 2793 } 2794 2795 } // namespace cuda 2796 2797 namespace gpu = ::perftools::gputools; 2798 2799 void initialize_cublas() { 2800 gpu::port::Status status = 2801 gpu::PluginRegistry::Instance() 2802 ->RegisterFactory<gpu::PluginRegistry::BlasFactory>( 2803 gpu::cuda::kCudaPlatformId, gpu::cuda::kCuBlasPlugin, "cuBLAS", 2804 [](gpu::internal::StreamExecutorInterface 2805 *parent) -> gpu::blas::BlasSupport * { 2806 gpu::cuda::CUDAExecutor *cuda_executor = 2807 dynamic_cast<gpu::cuda::CUDAExecutor *>(parent); 2808 if (cuda_executor == nullptr) { 2809 LOG(ERROR) 2810 << "Attempting to initialize an instance of the cuBLAS " 2811 << "support library with a non-CUDA StreamExecutor"; 2812 return nullptr; 2813 } 2814 2815 gpu::cuda::CUDABlas *blas = 2816 new gpu::cuda::CUDABlas(cuda_executor); 2817 if (!blas->Init()) { 2818 // Note: Init() will log a more specific error. 
2819 delete blas; 2820 return nullptr; 2821 } 2822 return blas; 2823 }); 2824 2825 if (!status.ok()) { 2826 LOG(ERROR) << "Unable to register cuBLAS factory: " 2827 << status.error_message(); 2828 } 2829 2830 gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId, 2831 gpu::PluginKind::kBlas, 2832 gpu::cuda::kCuBlasPlugin); 2833 } 2834 2835 } // namespace gputools 2836 } // namespace perftools 2837 2838 REGISTER_MODULE_INITIALIZER(register_cublas, 2839 { perftools::gputools::initialize_cublas(); }); 2840