Home | History | Annotate | Download | only in cuda
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Include cuBLAS headers early, and then set EIGEN_HAS_CUDA_FP16
     17 // if we have new enough CUDA (which we will only know after including
     18 // cuda.h). This ensures that Eigen's Half.h does not attempt to make its own
     19 // __half typedef if CUDA has already defined one (and conversely, that we do
     20 // not include <cuda_fp16.h> after Half.h has made its typedef).
     21 #include "cuda/include/cuda.h"
     22 #include "cuda/include/cublas_v2.h"
     23 
     24 #if CUDA_VERSION >= 7050
     25 #define EIGEN_HAS_CUDA_FP16
     26 #endif
     27 
     28 #if CUDA_VERSION >= 8000
     29 #define SE_CUDA_DATA_HALF CUDA_R_16F
     30 #else
     31 #define SE_CUDA_DATA_HALF CUBLAS_DATA_HALF
     32 #endif
     33 
     34 #include "tensorflow/stream_executor/cuda/cuda_blas.h"
     35 
     36 #include <assert.h>
     37 #include <complex>
     38 
     39 #include "tensorflow/core/util/env_var.h"
     40 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
     41 #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
     42 #include "tensorflow/stream_executor/cuda/cuda_helpers.h"
     43 #include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
     44 #include "tensorflow/stream_executor/cuda/cuda_stream.h"
     45 #include "tensorflow/stream_executor/cuda/cuda_timer.h"
     46 #include "tensorflow/stream_executor/device_memory.h"
     47 #include "tensorflow/stream_executor/lib/env.h"
     48 #include "tensorflow/stream_executor/lib/initialize.h"
     49 #include "tensorflow/stream_executor/lib/status.h"
     50 #include "tensorflow/stream_executor/lib/status_macros.h"
     51 #include "tensorflow/stream_executor/lib/strcat.h"
     52 #include "tensorflow/stream_executor/lib/stringprintf.h"
     53 #include "tensorflow/stream_executor/platform/logging.h"
     54 #include "tensorflow/stream_executor/platform/port.h"
     55 #include "tensorflow/stream_executor/plugin_registry.h"
     56 #include "tensorflow/stream_executor/scratch_allocator.h"
     57 #include "tensorflow/stream_executor/stream_executor.h"
     58 
     59 namespace perftools {
     60 namespace gputools {
     61 namespace cuda {
     62 
     63 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuBlasPlugin);
     64 
     65 namespace wrap {
     66 
// Generates a functor shim for one cuBLAS entry point.  The shim first
// activates the CUDA context of the owning StreamExecutor (via
// ScopedActivateExecutorContext) and then forwards the remaining arguments to
// the real ::__name symbol.  kName records the routine's name so failures can
// be logged by DoBlasInternalImpl.
#define PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(__name)                      \
  struct WrapperShim__##__name {                                    \
    static const char *kName;                                       \
    template <typename... Args>                                     \
    cublasStatus_t operator()(CUDAExecutor *parent, Args... args) { \
      cuda::ScopedActivateExecutorContext sac{parent};              \
      return ::__name(args...);                                     \
    }                                                               \
  } __name;                                                         \
  const char *WrapperShim__##__name::kName = #__name;

// The v2 cuBLAS API needs no special handling here, so the V2 wrapper is a
// plain alias for PERFTOOLS_GPUTOOLS_CUBLAS_WRAP.
#define PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(__name) \
  PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(__name)
     80 
// Name table of the cuBLAS BLAS-1/2/3 routines wrapped by this file.  Applies
// __macro once per routine; used below with PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP
// to generate a shim for each one.
#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
  __macro(cublasSnrm2)                    \
  __macro(cublasDnrm2)                    \
  __macro(cublasScnrm2)                   \
  __macro(cublasDznrm2)                   \
  __macro(cublasSdot)                     \
  __macro(cublasDdot)                     \
  __macro(cublasCdotu)                    \
  __macro(cublasCdotc)                    \
  __macro(cublasZdotu)                    \
  __macro(cublasZdotc)                    \
  __macro(cublasSscal)                    \
  __macro(cublasDscal)                    \
  __macro(cublasCscal)                    \
  __macro(cublasCsscal)                   \
  __macro(cublasZscal)                    \
  __macro(cublasZdscal)                   \
  __macro(cublasSaxpy)                    \
  __macro(cublasDaxpy)                    \
  __macro(cublasCaxpy)                    \
  __macro(cublasZaxpy)                    \
  __macro(cublasScopy)                    \
  __macro(cublasDcopy)                    \
  __macro(cublasCcopy)                    \
  __macro(cublasZcopy)                    \
  __macro(cublasSswap)                    \
  __macro(cublasDswap)                    \
  __macro(cublasCswap)                    \
  __macro(cublasZswap)                    \
  __macro(cublasIsamax)                   \
  __macro(cublasIdamax)                   \
  __macro(cublasIcamax)                   \
  __macro(cublasIzamax)                   \
  __macro(cublasIsamin)                   \
  __macro(cublasIdamin)                   \
  __macro(cublasIcamin)                   \
  __macro(cublasIzamin)                   \
  __macro(cublasSasum)                    \
  __macro(cublasDasum)                    \
  __macro(cublasScasum)                   \
  __macro(cublasDzasum)                   \
  __macro(cublasSrot)                     \
  __macro(cublasDrot)                     \
  __macro(cublasCrot)                     \
  __macro(cublasCsrot)                    \
  __macro(cublasZrot)                     \
  __macro(cublasZdrot)                    \
  __macro(cublasSrotg)                    \
  __macro(cublasDrotg)                    \
  __macro(cublasCrotg)                    \
  __macro(cublasZrotg)                    \
  __macro(cublasSrotm)                    \
  __macro(cublasDrotm)                    \
  __macro(cublasSrotmg)                   \
  __macro(cublasDrotmg)                   \
  __macro(cublasSgemv)                    \
  __macro(cublasDgemv)                    \
  __macro(cublasCgemv)                    \
  __macro(cublasZgemv)                    \
  __macro(cublasSgbmv)                    \
  __macro(cublasDgbmv)                    \
  __macro(cublasCgbmv)                    \
  __macro(cublasZgbmv)                    \
  __macro(cublasStrmv)                    \
  __macro(cublasDtrmv)                    \
  __macro(cublasCtrmv)                    \
  __macro(cublasZtrmv)                    \
  __macro(cublasStbmv)                    \
  __macro(cublasDtbmv)                    \
  __macro(cublasCtbmv)                    \
  __macro(cublasZtbmv)                    \
  __macro(cublasStpmv)                    \
  __macro(cublasDtpmv)                    \
  __macro(cublasCtpmv)                    \
  __macro(cublasZtpmv)                    \
  __macro(cublasStrsv)                    \
  __macro(cublasDtrsv)                    \
  __macro(cublasCtrsv)                    \
  __macro(cublasZtrsv)                    \
  __macro(cublasStpsv)                    \
  __macro(cublasDtpsv)                    \
  __macro(cublasCtpsv)                    \
  __macro(cublasZtpsv)                    \
  __macro(cublasStbsv)                    \
  __macro(cublasDtbsv)                    \
  __macro(cublasCtbsv)                    \
  __macro(cublasZtbsv)                    \
  __macro(cublasSsymv)                    \
  __macro(cublasDsymv)                    \
  __macro(cublasCsymv)                    \
  __macro(cublasZsymv)                    \
  __macro(cublasChemv)                    \
  __macro(cublasZhemv)                    \
  __macro(cublasSsbmv)                    \
  __macro(cublasDsbmv)                    \
  __macro(cublasChbmv)                    \
  __macro(cublasZhbmv)                    \
  __macro(cublasSspmv)                    \
  __macro(cublasDspmv)                    \
  __macro(cublasChpmv)                    \
  __macro(cublasZhpmv)                    \
  __macro(cublasSger)                     \
  __macro(cublasDger)                     \
  __macro(cublasCgeru)                    \
  __macro(cublasCgerc)                    \
  __macro(cublasZgeru)                    \
  __macro(cublasZgerc)                    \
  __macro(cublasSsyr)                     \
  __macro(cublasDsyr)                     \
  __macro(cublasCsyr)                     \
  __macro(cublasZsyr)                     \
  __macro(cublasCher)                     \
  __macro(cublasZher)                     \
  __macro(cublasSspr)                     \
  __macro(cublasDspr)                     \
  __macro(cublasChpr)                     \
  __macro(cublasZhpr)                     \
  __macro(cublasSsyr2)                    \
  __macro(cublasDsyr2)                    \
  __macro(cublasCsyr2)                    \
  __macro(cublasZsyr2)                    \
  __macro(cublasCher2)                    \
  __macro(cublasZher2)                    \
  __macro(cublasSspr2)                    \
  __macro(cublasDspr2)                    \
  __macro(cublasChpr2)                    \
  __macro(cublasZhpr2)                    \
  __macro(cublasSgemm)                    \
  __macro(cublasDgemm)                    \
  __macro(cublasCgemm)                    \
  __macro(cublasZgemm)                    \
  __macro(cublasSsyrk)                    \
  __macro(cublasDsyrk)                    \
  __macro(cublasCsyrk)                    \
  __macro(cublasZsyrk)                    \
  __macro(cublasCherk)                    \
  __macro(cublasZherk)                    \
  __macro(cublasSsyr2k)                   \
  __macro(cublasDsyr2k)                   \
  __macro(cublasCsyr2k)                   \
  __macro(cublasZsyr2k)                   \
  __macro(cublasCher2k)                   \
  __macro(cublasZher2k)                   \
  __macro(cublasSsyrkx)                   \
  __macro(cublasDsyrkx)                   \
  __macro(cublasCsyrkx)                   \
  __macro(cublasZsyrkx)                   \
  __macro(cublasCherkx)                   \
  __macro(cublasZherkx)                   \
  __macro(cublasSsymm)                    \
  __macro(cublasDsymm)                    \
  __macro(cublasCsymm)                    \
  __macro(cublasZsymm)                    \
  __macro(cublasChemm)                    \
  __macro(cublasZhemm)                    \
  __macro(cublasStrsm)                    \
  __macro(cublasDtrsm)                    \
  __macro(cublasCtrsm)                    \
  __macro(cublasZtrsm)                    \
  __macro(cublasStrmm)                    \
  __macro(cublasDtrmm)                    \
  __macro(cublasCtrmm)                    \
  __macro(cublasZtrmm)                    \
  __macro(cublasSgeam)                    \
  __macro(cublasDgeam)                    \
  __macro(cublasCgeam)                    \
  __macro(cublasZgeam)                    \
  __macro(cublasSdgmm)                    \
  __macro(cublasDdgmm)                    \
  __macro(cublasCdgmm)                    \
  __macro(cublasZdgmm)
    252 
// Instantiate one wrapper shim per cuBLAS routine this file calls: handle
// management/configuration first, then the batched GEMMs, then every routine
// named in CUBLAS_BLAS_ROUTINE_EACH.
PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasCreate)
PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasDestroy)
PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasSetStream)
PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasSetPointerMode)
PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasGetPointerMode)
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSgemmBatched)
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasDgemmBatched)
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasCgemmBatched)
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasZgemmBatched)
CUBLAS_BLAS_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP)

// Only wrapped when the toolkit provides it (CUDA >= 7.5).
#if CUDA_VERSION >= 7050
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSgemmEx)
#endif

// Only wrapped when the toolkit provides it (CUDA >= 8.0).
#if CUDA_VERSION >= 8000
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGemmEx)
#endif

// Math-mode (tensor op) control; only wrapped on CUDA >= 9.0.
#if CUDA_VERSION >= 9000
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGetMathMode)
PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSetMathMode)
#endif
    276 
    277 }  // namespace wrap
    278 
    279 static string ToString(cublasStatus_t status) {
    280   switch (status) {
    281     case CUBLAS_STATUS_SUCCESS:
    282       return "CUBLAS_STATUS_SUCCESS";
    283     case CUBLAS_STATUS_NOT_INITIALIZED:
    284       return "CUBLAS_STATUS_NOT_INITIALIZED";
    285     case CUBLAS_STATUS_ALLOC_FAILED:
    286       return "CUBLAS_STATUS_ALLOC_FAILED";
    287     case CUBLAS_STATUS_INVALID_VALUE:
    288       return "CUBLAS_STATUS_INVALID_VALUE";
    289     case CUBLAS_STATUS_ARCH_MISMATCH:
    290       return "CUBLAS_STATUS_ARCH_MISMATCH";
    291     case CUBLAS_STATUS_MAPPING_ERROR:
    292       return "CUBLAS_STATUS_MAPPING_ERROR";
    293     case CUBLAS_STATUS_EXECUTION_FAILED:
    294       return "CUBLAS_STATUS_EXECUTION_FAILED";
    295     case CUBLAS_STATUS_INTERNAL_ERROR:
    296       return "CUBLAS_STATUS_INTERNAL_ERROR";
    297 #if CUDA_VERSION >= 8000
    298     case CUBLAS_STATUS_NOT_SUPPORTED:
    299       return "CUBLAS_STATUS_NOT_SUPPORTED";
    300     case CUBLAS_STATUS_LICENSE_ERROR:
    301       return "CUBLAS_STATUS_LICENSE_ERROR";
    302 #endif
    303     default:
    304       return port::StrCat("<invalid cublas status: ", status, ">");
    305   }
    306 }
    307 
    308 // Decide whether to enable TENSOR_OP_MATH
    309 static bool TensorOpMathEnabled() {
    310   static bool is_enabled = [] {
    311     bool is_disabled;
    312     TF_CHECK_OK(
    313         tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUBLAS_TENSOR_OP_MATH",
    314                                        /*default_val=*/false, &is_disabled));
    315     return !is_disabled;
    316   }();
    317   return is_enabled;
    318 }
    319 
    320 // cuBLAS has interfaces that permit pointers to be passed from either the host
    321 // memory space or the device memory space; however, you must instruct it as to
    322 // which address space those pointers are in with cublasSetPointerMode.
    323 //
    324 // This helper sets the cuBLAS pointer mode to a desired value for a cuBLAS call
    325 // you are about to perform in a given scope.
    326 //
    327 // The prior cuBLAS pointer mode is retained and restored when this object goes
    328 // out of scope.
    329 class ScopedCublasPointerMode {
    330  public:
    331   // Note that, because the setting of the cublas pointer mode is fallible,
    332   // construction of this scoped datatype must be paired with a call to
    333   // Init().
    334   //
    335   // Parameters:
    336   //  handle: The cublas library handle to act upon in setting the pointer mode.
    337   explicit ScopedCublasPointerMode(CUDAExecutor *parent, cublasHandle_t handle)
    338       : parent_(parent), handle_(handle), ok_(false) {}
    339 
    340   // Attempts the switch to the requested scoped pointer mode, new_mode.
    341   //
    342   // Note that when false is returned, an appropriate error has already been
    343   // logged.
    344   bool Init(cublasPointerMode_t new_mode) {
    345     cublasStatus_t ret =
    346         wrap::cublasGetPointerMode(parent_, handle_, &old_mode_);
    347     if (ret != CUBLAS_STATUS_SUCCESS) {
    348       LOG(ERROR) << "failed to get old cublas pointer mode: " << ToString(ret);
    349       return ok_ = false;
    350     }
    351 
    352     ret = wrap::cublasSetPointerMode(parent_, handle_, new_mode);
    353     if (ret != CUBLAS_STATUS_SUCCESS) {
    354       LOG(ERROR) << "failed to set new cublas pointer mode: " << ToString(ret);
    355       return ok_ = false;
    356     }
    357 
    358     return ok_ = true;
    359   }
    360 
    361   // Switches back to the prior pointer mode, if the switch operation was
    362   // successful in the first place.
    363   ~ScopedCublasPointerMode() {
    364     if (ok_) {
    365       cublasStatus_t ret =
    366           wrap::cublasSetPointerMode(parent_, handle_, old_mode_);
    367       if (ret != CUBLAS_STATUS_SUCCESS) {
    368         LOG(ERROR) << "failed to set former cublas pointer mode: "
    369                    << ToString(ret);
    370       }
    371     }
    372   }
    373 
    374  private:
    375   CUDAExecutor *parent_;   // Executor establishing this pointer mode for.
    376   cublasHandle_t handle_;  // Handle to the cuBLAS instance of interest.
    377   cublasPointerMode_t old_mode_;  // Prior cuBLAS pointer mode, to be restored.
    378   bool ok_;                       // Whether the change was successful.
    379 };
    380 
    381 #if CUDA_VERSION >= 9000
    382 // cuBLAS has interfaces that permit computations to use the Volta hardware.
    383 // This must be enabled via the cublasGet/SetMathMode APIs.
    384 //
    385 // This helper sets the cuBLAS math mode to a desired value for a cuBLAS call
    386 // you are about to perform in a given scope.
    387 //
    388 // The prior cuBLAS math mode is retained and restored when this object goes
    389 // out of scope.
    390 class ScopedCublasMathMode {
    391  public:
    392   // Note that, because the setting of the cublas math mode is fallible,
    393   // construction of this scoped datatype must be paired with a call to
    394   // Init().
    395   //
    396   // Parameters:
    397   //  handle: The cublas library handle to act upon in setting the math mode.
    398   explicit ScopedCublasMathMode(CUDAExecutor *parent, cublasHandle_t handle)
    399       : parent_(parent), handle_(handle), ok_(false) {}
    400 
    401   // Attempts the switch to the requested scoped math mode, new_mode.
    402   //
    403   // Note that when false is returned, an appropriate error has already been
    404   // logged.
    405   bool Init(cublasMath_t new_mode) {
    406     cublasStatus_t ret = wrap::cublasGetMathMode(parent_, handle_, &old_mode_);
    407     if (ret != CUBLAS_STATUS_SUCCESS) {
    408       LOG(ERROR) << "failed to get old cublas math mode: " << ToString(ret);
    409       return ok_ = false;
    410     }
    411 
    412     ret = wrap::cublasSetMathMode(parent_, handle_, new_mode);
    413     if (ret != CUBLAS_STATUS_SUCCESS) {
    414       LOG(ERROR) << "failed to set new cublas math mode: " << ToString(ret);
    415       return ok_ = false;
    416     }
    417     return ok_ = true;
    418   }
    419 
    420   // Switches back to the prior math mode, if the switch operation was
    421   // successful in the first place.
    422   ~ScopedCublasMathMode() {
    423     if (ok_) {
    424       cublasStatus_t ret = wrap::cublasSetMathMode(parent_, handle_, old_mode_);
    425       if (ret != CUBLAS_STATUS_SUCCESS) {
    426         LOG(ERROR) << "failed to set former cublas math mode: "
    427                    << ToString(ret);
    428       }
    429     }
    430   }
    431 
    432  private:
    433   CUDAExecutor *parent_;   // Executor establishing this math mode for.
    434   cublasHandle_t handle_;  // Handle to the cuBLAS instance of interest.
    435   cublasMath_t old_mode_;  // Prior cuBLAS math mode, to be restored.
    436   bool ok_;                // Whether the change was successful.
    437 };
    438 #endif  // CUDA_VERSION >= 9000
    439 
    440 bool CUDABlas::Init() {
    441   cublasStatus_t ret = wrap::cublasCreate(parent_, &blas_);
    442   if (ret != CUBLAS_STATUS_SUCCESS) {
    443     LOG(ERROR) << "failed to create cublas handle: " << ToString(ret);
    444     return false;
    445   }
    446 
    447   return true;
    448 }
    449 
// Does not create the cuBLAS handle (blas_ starts null); callers must invoke
// Init() before issuing any BLAS calls.  parent must be non-null.
CUDABlas::CUDABlas(cuda::CUDAExecutor *parent)
    : parent_(CHECK_NOTNULL(parent)), blas_(nullptr) {}
    452 
    453 CUDABlas::~CUDABlas() {
    454   if (blas_ != nullptr) {
    455     wrap::cublasDestroy(parent_, blas_);
    456   }
    457 }
    458 
    459 bool CUDABlas::SetStream(Stream *stream) {
    460   CHECK(stream != nullptr);
    461   CHECK(AsCUDAStreamValue(stream) != nullptr);
    462   CHECK(blas_ != nullptr);
    463   cublasStatus_t ret =
    464       wrap::cublasSetStream(parent_, blas_, AsCUDAStreamValue(stream));
    465   if (ret != CUBLAS_STATUS_SUCCESS) {
    466     LOG(ERROR) << "failed to set stream for cuBLAS calls: " << ToString(ret);
    467     return false;
    468   }
    469 
    470   return true;
    471 }
    472 
    473 namespace {
    474 
    475 // Helper functions transforming blas arguments into cuBLAS arguments.
    476 
    477 cublasOperation_t CUDABlasTranspose(blas::Transpose trans) {
    478   switch (trans) {
    479     case blas::Transpose::kNoTranspose:
    480       return CUBLAS_OP_N;
    481     case blas::Transpose::kTranspose:
    482       return CUBLAS_OP_T;
    483     case blas::Transpose::kConjugateTranspose:
    484       return CUBLAS_OP_C;
    485     default:
    486       LOG(FATAL) << "Invalid value of blas::Transpose.";
    487   }
    488 }
    489 
    490 cublasFillMode_t CUDABlasUpperLower(blas::UpperLower uplo) {
    491   switch (uplo) {
    492     case blas::UpperLower::kUpper:
    493       return CUBLAS_FILL_MODE_UPPER;
    494     case blas::UpperLower::kLower:
    495       return CUBLAS_FILL_MODE_LOWER;
    496     default:
    497       LOG(FATAL) << "Invalid value of blas::UpperLower.";
    498   }
    499 }
    500 
    501 cublasDiagType_t CUDABlasDiagonal(blas::Diagonal diag) {
    502   switch (diag) {
    503     case blas::Diagonal::kUnit:
    504       return CUBLAS_DIAG_UNIT;
    505     case blas::Diagonal::kNonUnit:
    506       return CUBLAS_DIAG_NON_UNIT;
    507     default:
    508       LOG(FATAL) << "Invalid value of blas::Diagonal.";
    509   }
    510 }
    511 
    512 cublasSideMode_t CUDABlasSide(blas::Side side) {
    513   switch (side) {
    514     case blas::Side::kLeft:
    515       return CUBLAS_SIDE_LEFT;
    516     case blas::Side::kRight:
    517       return CUBLAS_SIDE_RIGHT;
    518     default:
    519       LOG(FATAL) << "Invalid value of blas::Side.";
    520   }
    521 }
    522 
// CUDADataType<T>::type translates from a C++ type (e.g. float) to a
// cudaDataType_t (e.g. CUDA_R_32F).  CUDAComputationType(ty) translates from a
// blas::ComputationType to a cudaDataType_t.
//
// These are used to build the argument type and computation type args to
// cublasGemmEx.  cublasGemmEx and cudaDataType_t are available only on
// CUDA >= 8.0.
#if CUDA_VERSION >= 8000
// The primary template is declared but intentionally left undefined:
// instantiating CUDADataType<T> for an unsupported T fails at compile time.
template <typename T>
struct CUDADataType;

// SE_CUDA_DATA_HALF resolves to CUDA_R_16F on CUDA >= 8.0 and to
// CUBLAS_DATA_HALF before that (see the #define near the top of this file).
template <>
struct CUDADataType<Eigen::half> {
  static constexpr cudaDataType_t type = SE_CUDA_DATA_HALF;
};

template <>
struct CUDADataType<std::complex<Eigen::half>> {
  static constexpr cudaDataType_t type = CUDA_C_16F;
};

template <>
struct CUDADataType<float> {
  static constexpr cudaDataType_t type = CUDA_R_32F;
};

template <>
struct CUDADataType<std::complex<float>> {
  static constexpr cudaDataType_t type = CUDA_C_32F;
};

template <>
struct CUDADataType<double> {
  static constexpr cudaDataType_t type = CUDA_R_64F;
};

template <>
struct CUDADataType<std::complex<double>> {
  static constexpr cudaDataType_t type = CUDA_C_64F;
};

template <>
struct CUDADataType<int> {
  static constexpr cudaDataType_t type = CUDA_R_32I;
};

template <>
struct CUDADataType<int8> {
  static constexpr cudaDataType_t type = CUDA_R_8I;
};

template <>
struct CUDADataType<std::complex<int8>> {
  static constexpr cudaDataType_t type = CUDA_C_8I;
};

template <>
struct CUDADataType<uint8> {
  static constexpr cudaDataType_t type = CUDA_R_8U;
};

template <>
struct CUDADataType<std::complex<uint8>> {
  static constexpr cudaDataType_t type = CUDA_C_8U;
};
    588 
    589 cudaDataType_t CUDAComputationType(blas::ComputationType ty) {
    590   switch (ty) {
    591     case blas::ComputationType::kF16:
    592       return CUDA_R_16F;
    593     case blas::ComputationType::kF32:
    594       return CUDA_R_32F;
    595     case blas::ComputationType::kF64:
    596       return CUDA_R_64F;
    597     case blas::ComputationType::kI32:
    598       return CUDA_R_32I;
    599     case blas::ComputationType::kComplexF32:
    600       return CUDA_C_32F;
    601     case blas::ComputationType::kComplexF64:
    602       return CUDA_C_64F;
    603   }
    604 }
    605 #endif
    606 
    607 }  // namespace
    608 
// Runs one wrapped cuBLAS routine on the given stream using this object's
// handle.  Under the handle mutex it: binds the stream, installs the
// requested pointer mode for the duration of the call, optionally enables
// CUBLAS_TENSOR_OP_MATH (CUDA 9+ only), then invokes cublas_func with the
// remaining args.  Returns true iff the routine reports
// CUBLAS_STATUS_SUCCESS; failures are logged only when err_on_failure is set.
template <typename FuncT, typename... Args>
bool CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
                                  bool pointer_mode_host, bool err_on_failure,
                                  bool use_tensor_op_math, Args... args) {
  // Serialize all use of blas_: the handle carries per-call state (stream,
  // pointer mode, math mode) that concurrent callers must not interleave.
  mutex_lock lock{mu_};

  CHECK(blas_ != nullptr);
  if (!SetStream(stream)) {
    return false;
  }

  // Tell cuBLAS whether scalar pointers in args point into host or device
  // memory; the prior mode is restored when this scope ends.
  ScopedCublasPointerMode pointer_mode{parent_, blas_};
  if (!pointer_mode.Init(pointer_mode_host ? CUBLAS_POINTER_MODE_HOST
                                           : CUBLAS_POINTER_MODE_DEVICE)) {
    return false;
  }
#if CUDA_VERSION >= 9000
  // Scoped so the prior math mode is restored after the call.
  ScopedCublasMathMode math_mode{parent_, blas_};
  if (use_tensor_op_math) {
    if (!math_mode.Init(CUBLAS_TENSOR_OP_MATH)) {
      return false;
    }
  }
#endif
  cublasStatus_t ret = cublas_func(parent_, blas_, args...);
  if (err_on_failure && ret != CUBLAS_STATUS_SUCCESS) {
    LOG(ERROR) << "failed to run cuBLAS routine " << cublas_func.kName << ": "
               << ToString(ret);
  }
  return ret == CUBLAS_STATUS_SUCCESS;
}
    640 
// BLAS ASUM (sum of absolute values of x), written to *result.  The result
// lives in device memory, hence pointer_mode_host is false.
bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *result) {
  return DoBlasInternal(wrap::cublasSasum, stream,
                        false /* = pointer_mode_host */, elem_count,
                        CUDAMemory(x), incx, CUDAMemoryMutable(result));
}

// Double-precision ASUM.
bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *result) {
  return DoBlasInternal(wrap::cublasDasum, stream,
                        false /* = pointer_mode_host */, elem_count,
                        CUDAMemory(x), incx, CUDAMemoryMutable(result));
}

// Single-precision complex ASUM; note the result is a real scalar.
bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<float> *result) {
  return DoBlasInternal(
      wrap::cublasScasum, stream, false /* = pointer_mode_host */, elem_count,
      CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
}

// Double-precision complex ASUM; note the result is a real scalar.
bool CUDABlas::DoBlasAsum(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<double> *result) {
  return DoBlasInternal(
      wrap::cublasDzasum, stream, false /* = pointer_mode_host */, elem_count,
      CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
}
    672 
// BLAS AXPY: y := alpha * x + y.  alpha is passed by address from host stack
// memory, hence pointer_mode_host is true.
bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha,
                          const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *y, int incy) {
  return DoBlasInternal(wrap::cublasSaxpy, stream,
                        true /* = pointer_mode_host */, elem_count, &alpha,
                        CUDAMemory(x), incx, CUDAMemoryMutable(y), incy);
}

// Double-precision AXPY.
bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha,
                          const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *y, int incy) {
  return DoBlasInternal(wrap::cublasDaxpy, stream,
                        true /* = pointer_mode_host */, elem_count, &alpha,
                        CUDAMemory(x), incx, CUDAMemoryMutable(y), incy);
}

// Single-precision complex AXPY.
bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
                          std::complex<float> alpha,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<std::complex<float>> *y, int incy) {
  return DoBlasInternal(wrap::cublasCaxpy, stream,
                        true /* = pointer_mode_host */, elem_count,
                        CUDAComplex(&alpha), CUDAComplex(CUDAMemory(x)), incx,
                        CUDAComplex(CUDAMemoryMutable(y)), incy);
}

// Double-precision complex AXPY.
bool CUDABlas::DoBlasAxpy(Stream *stream, uint64 elem_count,
                          std::complex<double> alpha,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<std::complex<double>> *y, int incy) {
  return DoBlasInternal(wrap::cublasZaxpy, stream,
                        true /* = pointer_mode_host */, elem_count,
                        CUDAComplex(&alpha), CUDAComplex(CUDAMemory(x)), incx,
                        CUDAComplex(CUDAMemoryMutable(y)), incy);
}
    708 
// BLAS COPY: y := x (element strides incx/incy).  No scalar parameters are
// read back, so the host pointer mode is used.
bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<float> &x, int incx,
                          DeviceMemory<float> *y, int incy) {
  return DoBlasInternal(wrap::cublasScopy, stream,
                        true /* = pointer_mode_host */, elem_count,
                        CUDAMemory(x), incx, CUDAMemoryMutable(y), incy);
}

// Double-precision COPY.
bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<double> &x, int incx,
                          DeviceMemory<double> *y, int incy) {
  return DoBlasInternal(wrap::cublasDcopy, stream,
                        true /* = pointer_mode_host */, elem_count,
                        CUDAMemory(x), incx, CUDAMemoryMutable(y), incy);
}

// Single-precision complex COPY.
bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<float>> &x, int incx,
                          DeviceMemory<std::complex<float>> *y, int incy) {
  return DoBlasInternal(wrap::cublasCcopy, stream,
                        true /* = pointer_mode_host */, elem_count,
                        CUDAComplex(CUDAMemory(x)), incx,
                        CUDAComplex(CUDAMemoryMutable(y)), incy);
}

// Double-precision complex COPY.
bool CUDABlas::DoBlasCopy(Stream *stream, uint64 elem_count,
                          const DeviceMemory<std::complex<double>> &x, int incx,
                          DeviceMemory<std::complex<double>> *y, int incy) {
  return DoBlasInternal(wrap::cublasZcopy, stream,
                        true /* = pointer_mode_host */, elem_count,
                        CUDAComplex(CUDAMemory(x)), incx,
                        CUDAComplex(CUDAMemoryMutable(y)), incy);
}
    742 
    743 bool CUDABlas::DoBlasDot(Stream *stream, uint64 elem_count,
    744                          const DeviceMemory<float> &x, int incx,
    745                          const DeviceMemory<float> &y, int incy,
    746                          DeviceMemory<float> *result) {
    747   return DoBlasInternal(
    748       wrap::cublasSdot, stream, false /* = pointer_mode_host */, elem_count,
    749       CUDAMemory(x), incx, CUDAMemory(y), incy, CUDAMemoryMutable(result));
    750 }
    751 
    752 bool CUDABlas::DoBlasDot(Stream *stream, uint64 elem_count,
    753                          const DeviceMemory<double> &x, int incx,
    754                          const DeviceMemory<double> &y, int incy,
    755                          DeviceMemory<double> *result) {
    756   return DoBlasInternal(
    757       wrap::cublasDdot, stream, false /* = pointer_mode_host */, elem_count,
    758       CUDAMemory(x), incx, CUDAMemory(y), incy, CUDAMemoryMutable(result));
    759 }
    760 
    761 bool CUDABlas::DoBlasDotc(Stream *stream, uint64 elem_count,
    762                           const DeviceMemory<std::complex<float>> &x, int incx,
    763                           const DeviceMemory<std::complex<float>> &y, int incy,
    764                           DeviceMemory<std::complex<float>> *result) {
    765   return DoBlasInternal(
    766       wrap::cublasCdotc, stream, false /* = pointer_mode_host */, elem_count,
    767       CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
    768       CUDAComplex(CUDAMemoryMutable(result)));
    769 }
    770 
    771 bool CUDABlas::DoBlasDotc(Stream *stream, uint64 elem_count,
    772                           const DeviceMemory<std::complex<double>> &x, int incx,
    773                           const DeviceMemory<std::complex<double>> &y, int incy,
    774                           DeviceMemory<std::complex<double>> *result) {
    775   return DoBlasInternal(
    776       wrap::cublasZdotc, stream, false /* = pointer_mode_host */, elem_count,
    777       CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
    778       CUDAComplex(CUDAMemoryMutable(result)));
    779 }
    780 
    781 bool CUDABlas::DoBlasDotu(Stream *stream, uint64 elem_count,
    782                           const DeviceMemory<std::complex<float>> &x, int incx,
    783                           const DeviceMemory<std::complex<float>> &y, int incy,
    784                           DeviceMemory<std::complex<float>> *result) {
    785   return DoBlasInternal(
    786       wrap::cublasCdotu, stream, false /* = pointer_mode_host */, elem_count,
    787       CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
    788       CUDAComplex(CUDAMemoryMutable(result)));
    789 }
    790 
    791 bool CUDABlas::DoBlasDotu(Stream *stream, uint64 elem_count,
    792                           const DeviceMemory<std::complex<double>> &x, int incx,
    793                           const DeviceMemory<std::complex<double>> &y, int incy,
    794                           DeviceMemory<std::complex<double>> *result) {
    795   return DoBlasInternal(
    796       wrap::cublasZdotu, stream, false /* = pointer_mode_host */, elem_count,
    797       CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
    798       CUDAComplex(CUDAMemoryMutable(result)));
    799 }
    800 
    801 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
    802                           const DeviceMemory<float> &x, int incx,
    803                           DeviceMemory<float> *result) {
    804   return DoBlasInternal(wrap::cublasSnrm2, stream,
    805                         false /* = pointer_mode_host */, elem_count,
    806                         CUDAMemory(x), incx, CUDAMemoryMutable(result));
    807 }
    808 
    809 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
    810                           const DeviceMemory<double> &x, int incx,
    811                           DeviceMemory<double> *result) {
    812   return DoBlasInternal(wrap::cublasDnrm2, stream,
    813                         false /* = pointer_mode_host */, elem_count,
    814                         CUDAMemory(x), incx, CUDAMemoryMutable(result));
    815 }
    816 
    817 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
    818                           const DeviceMemory<std::complex<float>> &x, int incx,
    819                           DeviceMemory<float> *result) {
    820   return DoBlasInternal(
    821       wrap::cublasScnrm2, stream, false /* = pointer_mode_host */, elem_count,
    822       CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
    823 }
    824 
    825 bool CUDABlas::DoBlasNrm2(Stream *stream, uint64 elem_count,
    826                           const DeviceMemory<std::complex<double>> &x, int incx,
    827                           DeviceMemory<double> *result) {
    828   return DoBlasInternal(
    829       wrap::cublasDznrm2, stream, false /* = pointer_mode_host */, elem_count,
    830       CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
    831 }
    832 
    833 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
    834                          DeviceMemory<float> *x, int incx,
    835                          DeviceMemory<float> *y, int incy, float c, float s) {
    836   return DoBlasInternal(
    837       wrap::cublasSrot, stream, true /* = pointer_mode_host */, elem_count,
    838       CUDAMemoryMutable(x), incx, CUDAMemoryMutable(y), incy, &c, &s);
    839 }
    840 
    841 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
    842                          DeviceMemory<double> *x, int incx,
    843                          DeviceMemory<double> *y, int incy, double c,
    844                          double s) {
    845   return DoBlasInternal(
    846       wrap::cublasDrot, stream, true /* = pointer_mode_host */, elem_count,
    847       CUDAMemoryMutable(x), incx, CUDAMemoryMutable(y), incy, &c, &s);
    848 }
    849 
    850 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
    851                          DeviceMemory<std::complex<float>> *x, int incx,
    852                          DeviceMemory<std::complex<float>> *y, int incy,
    853                          float c, float s) {
    854   return DoBlasInternal(wrap::cublasCsrot, stream,
    855                         true /* = pointer_mode_host */, elem_count,
    856                         CUDAComplex(CUDAMemoryMutable(x)), incx,
    857                         CUDAComplex(CUDAMemoryMutable(y)), incy, &c, &s);
    858 }
    859 
    860 bool CUDABlas::DoBlasRot(Stream *stream, uint64 elem_count,
    861                          DeviceMemory<std::complex<double>> *x, int incx,
    862                          DeviceMemory<std::complex<double>> *y, int incy,
    863                          double c, double s) {
    864   return DoBlasInternal(wrap::cublasZdrot, stream,
    865                         true /* = pointer_mode_host */, elem_count,
    866                         CUDAComplex(CUDAMemoryMutable(x)), incx,
    867                         CUDAComplex(CUDAMemoryMutable(y)), incy, &c, &s);
    868 }
    869 
    870 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<float> *a,
    871                           DeviceMemory<float> *b, DeviceMemory<float> *c,
    872                           DeviceMemory<float> *s) {
    873   return DoBlasInternal(wrap::cublasSrotg, stream,
    874                         false /* = pointer_mode_host */, CUDAMemoryMutable(a),
    875                         CUDAMemoryMutable(b), CUDAMemoryMutable(c),
    876                         CUDAMemoryMutable(s));
    877 }
    878 
    879 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<double> *a,
    880                           DeviceMemory<double> *b, DeviceMemory<double> *c,
    881                           DeviceMemory<double> *s) {
    882   return DoBlasInternal(wrap::cublasDrotg, stream,
    883                         false /* = pointer_mode_host */,
    884                         CUDAComplex(CUDAMemoryMutable(a)), CUDAMemoryMutable(b),
    885                         CUDAMemoryMutable(c), CUDAMemoryMutable(s));
    886 }
    887 
    888 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a,
    889                           DeviceMemory<std::complex<float>> *b,
    890                           DeviceMemory<float> *c,
    891                           DeviceMemory<std::complex<float>> *s) {
    892   return DoBlasInternal(
    893       wrap::cublasCrotg, stream, false /* = pointer_mode_host */,
    894       CUDAComplex(CUDAMemoryMutable(a)), CUDAComplex(CUDAMemoryMutable(b)),
    895       CUDAComplex(CUDAMemoryMutable(c)), CUDAComplex(CUDAMemoryMutable(s)));
    896 }
    897 
    898 bool CUDABlas::DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a,
    899                           DeviceMemory<std::complex<double>> *b,
    900                           DeviceMemory<double> *c,
    901                           DeviceMemory<std::complex<double>> *s) {
    902   return DoBlasInternal(
    903       wrap::cublasZrotg, stream, false /* = pointer_mode_host */,
    904       CUDAComplex(CUDAMemoryMutable(a)), CUDAComplex(CUDAMemoryMutable(b)),
    905       CUDAComplex(CUDAMemoryMutable(c)), CUDAComplex(CUDAMemoryMutable(s)));
    906 }
    907 
    908 bool CUDABlas::DoBlasRotm(Stream *stream, uint64 elem_count,
    909                           DeviceMemory<float> *x, int incx,
    910                           DeviceMemory<float> *y, int incy,
    911                           const DeviceMemory<float> &param) {
    912   return DoBlasInternal(wrap::cublasSrotm, stream,
    913                         false /* = pointer_mode_host */, elem_count,
    914                         CUDAMemoryMutable(x), incx, CUDAMemoryMutable(y), incy,
    915                         CUDAMemory(param));
    916 }
    917 
    918 bool CUDABlas::DoBlasRotm(Stream *stream, uint64 elem_count,
    919                           DeviceMemory<double> *x, int incx,
    920                           DeviceMemory<double> *y, int incy,
    921                           const DeviceMemory<double> &param) {
    922   return DoBlasInternal(wrap::cublasDrotm, stream,
    923                         false /* = pointer_mode_host */, elem_count,
    924                         CUDAMemoryMutable(x), incx, CUDAMemoryMutable(y), incy,
    925                         CUDAMemory(param));
    926 }
    927 
    928 bool CUDABlas::DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1,
    929                            DeviceMemory<float> *d2, DeviceMemory<float> *x1,
    930                            const DeviceMemory<float> &y1,
    931                            DeviceMemory<float> *param) {
    932   return DoBlasInternal(wrap::cublasSrotmg, stream,
    933                         false /* = pointer_mode_host */, CUDAMemoryMutable(d1),
    934                         CUDAMemoryMutable(d2), CUDAMemoryMutable(x1),
    935                         CUDAMemory(y1), CUDAMemoryMutable(param));
    936 }
    937 
    938 bool CUDABlas::DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1,
    939                            DeviceMemory<double> *d2, DeviceMemory<double> *x1,
    940                            const DeviceMemory<double> &y1,
    941                            DeviceMemory<double> *param) {
    942   return DoBlasInternal(wrap::cublasDrotmg, stream,
    943                         false /* = pointer_mode_host */, CUDAMemoryMutable(d1),
    944                         CUDAMemoryMutable(d2), CUDAMemoryMutable(x1),
    945                         CUDAMemory(y1), CUDAMemoryMutable(param));
    946 }
    947 
    948 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
    949                           DeviceMemory<float> *x, int incx) {
    950   return DoBlasInternal(wrap::cublasSscal, stream,
    951                         true /* = pointer_mode_host */, elem_count, &alpha,
    952                         CUDAMemoryMutable(x), incx);
    953 }
    954 
    955 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
    956                           DeviceMemory<double> *x, int incx) {
    957   return DoBlasInternal(wrap::cublasDscal, stream,
    958                         true /* = pointer_mode_host */, elem_count, &alpha,
    959                         CUDAMemoryMutable(x), incx);
    960 }
    961 
    962 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
    963                           DeviceMemory<std::complex<float>> *x, int incx) {
    964   return DoBlasInternal(
    965       wrap::cublasCsscal, stream, true /* = pointer_mode_host */, elem_count,
    966       CUDAComplex(&alpha), CUDAComplex(CUDAMemoryMutable(x)), incx);
    967 }
    968 
    969 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha,
    970                           DeviceMemory<std::complex<double>> *x, int incx) {
    971   return DoBlasInternal(
    972       wrap::cublasZdscal, stream, true /* = pointer_mode_host */, elem_count,
    973       CUDAComplex(&alpha), CUDAComplex(CUDAMemoryMutable(x)), incx);
    974 }
    975 
    976 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count,
    977                           std::complex<float> alpha,
    978                           DeviceMemory<std::complex<float>> *x, int incx) {
    979   return DoBlasInternal(
    980       wrap::cublasCscal, stream, true /* = pointer_mode_host */, elem_count,
    981       CUDAComplex(&alpha), CUDAComplex(CUDAMemoryMutable(x)), incx);
    982 }
    983 
    984 bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count,
    985                           std::complex<double> alpha,
    986                           DeviceMemory<std::complex<double>> *x, int incx) {
    987   return DoBlasInternal(
    988       wrap::cublasZscal, stream, true /* = pointer_mode_host */, elem_count,
    989       CUDAComplex(&alpha), CUDAComplex(CUDAMemoryMutable(x)), incx);
    990 }
    991 
    992 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
    993                           DeviceMemory<float> *x, int incx,
    994                           DeviceMemory<float> *y, int incy) {
    995   return DoBlasInternal(wrap::cublasSswap, stream,
    996                         true /* = pointer_mode_host */, elem_count,
    997                         CUDAMemoryMutable(x), incx, CUDAMemoryMutable(y), incy);
    998 }
    999 
   1000 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
   1001                           DeviceMemory<double> *x, int incx,
   1002                           DeviceMemory<double> *y, int incy) {
   1003   return DoBlasInternal(wrap::cublasDswap, stream,
   1004                         true /* = pointer_mode_host */, elem_count,
   1005                         CUDAMemoryMutable(x), incx, CUDAMemoryMutable(y), incy);
   1006 }
   1007 
   1008 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
   1009                           DeviceMemory<std::complex<float>> *x, int incx,
   1010                           DeviceMemory<std::complex<float>> *y, int incy) {
   1011   return DoBlasInternal(wrap::cublasCswap, stream,
   1012                         true /* = pointer_mode_host */, elem_count,
   1013                         CUDAComplex(CUDAMemoryMutable(x)), incx,
   1014                         CUDAComplex(CUDAMemoryMutable(y)), incy);
   1015 }
   1016 
   1017 bool CUDABlas::DoBlasSwap(Stream *stream, uint64 elem_count,
   1018                           DeviceMemory<std::complex<double>> *x, int incx,
   1019                           DeviceMemory<std::complex<double>> *y, int incy) {
   1020   return DoBlasInternal(wrap::cublasZswap, stream,
   1021                         true /* = pointer_mode_host */, elem_count,
   1022                         CUDAComplex(CUDAMemoryMutable(x)), incx,
   1023                         CUDAComplex(CUDAMemoryMutable(y)), incy);
   1024 }
   1025 
   1026 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
   1027                            const DeviceMemory<float> &x, int incx,
   1028                            DeviceMemory<int> *result) {
   1029   return DoBlasInternal(wrap::cublasIsamax, stream,
   1030                         false /* = pointer_mode_host */, elem_count,
   1031                         CUDAMemory(x), incx, CUDAMemoryMutable(result));
   1032 }
   1033 
   1034 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
   1035                            const DeviceMemory<double> &x, int incx,
   1036                            DeviceMemory<int> *result) {
   1037   return DoBlasInternal(wrap::cublasIdamax, stream,
   1038                         false /* = pointer_mode_host */, elem_count,
   1039                         CUDAMemory(x), incx, CUDAMemoryMutable(result));
   1040 }
   1041 
   1042 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
   1043                            const DeviceMemory<std::complex<float>> &x, int incx,
   1044                            DeviceMemory<int> *result) {
   1045   return DoBlasInternal(
   1046       wrap::cublasIcamax, stream, false /* = pointer_mode_host */, elem_count,
   1047       CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
   1048 }
   1049 
   1050 bool CUDABlas::DoBlasIamax(Stream *stream, uint64 elem_count,
   1051                            const DeviceMemory<std::complex<double>> &x,
   1052                            int incx, DeviceMemory<int> *result) {
   1053   return DoBlasInternal(
   1054       wrap::cublasIzamax, stream, false /* = pointer_mode_host */, elem_count,
   1055       CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
   1056 }
   1057 
   1058 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
   1059                            const DeviceMemory<float> &x, int incx,
   1060                            DeviceMemory<int> *result) {
   1061   return DoBlasInternal(
   1062       wrap::cublasIsamin, stream, false /* = pointer_mode_host */, elem_count,
   1063       CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
   1064 }
   1065 
   1066 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
   1067                            const DeviceMemory<double> &x, int incx,
   1068                            DeviceMemory<int> *result) {
   1069   return DoBlasInternal(
   1070       wrap::cublasIdamin, stream, false /* = pointer_mode_host */, elem_count,
   1071       CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
   1072 }
   1073 
   1074 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
   1075                            const DeviceMemory<std::complex<float>> &x, int incx,
   1076                            DeviceMemory<int> *result) {
   1077   return DoBlasInternal(
   1078       wrap::cublasIcamin, stream, false /* = pointer_mode_host */, elem_count,
   1079       CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
   1080 }
   1081 
   1082 bool CUDABlas::DoBlasIamin(Stream *stream, uint64 elem_count,
   1083                            const DeviceMemory<std::complex<double>> &x,
   1084                            int incx, DeviceMemory<int> *result) {
   1085   return DoBlasInternal(
   1086       wrap::cublasIzamin, stream, false /* = pointer_mode_host */, elem_count,
   1087       CUDAComplex(CUDAMemory(x)), incx, CUDAMemoryMutable(result));
   1088 }
   1089 
   1090 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
   1091                           uint64 n, uint64 kl, uint64 ku, float alpha,
   1092                           const DeviceMemory<float> &a, int lda,
   1093                           const DeviceMemory<float> &x, int incx, float beta,
   1094                           DeviceMemory<float> *y, int incy) {
   1095   return DoBlasInternal(
   1096       wrap::cublasSgbmv, stream, true /* = pointer_mode_host */,
   1097       CUDABlasTranspose(trans), m, n, kl, ku, &alpha, CUDAMemory(a), lda,
   1098       CUDAMemory(x), incx, &beta, CUDAMemoryMutable(y), incy);
   1099 }
   1100 
   1101 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
   1102                           uint64 n, uint64 kl, uint64 ku, double alpha,
   1103                           const DeviceMemory<double> &a, int lda,
   1104                           const DeviceMemory<double> &x, int incx, double beta,
   1105                           DeviceMemory<double> *y, int incy) {
   1106   return DoBlasInternal(
   1107       wrap::cublasDgbmv, stream, true /* = pointer_mode_host */,
   1108       CUDABlasTranspose(trans), m, n, kl, ku, &alpha, CUDAMemory(a), lda,
   1109       CUDAMemory(x), incx, &beta, CUDAMemoryMutable(y), incy);
   1110 }
   1111 
   1112 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
   1113                           uint64 n, uint64 kl, uint64 ku,
   1114                           std::complex<float> alpha,
   1115                           const DeviceMemory<std::complex<float>> &a, int lda,
   1116                           const DeviceMemory<std::complex<float>> &x, int incx,
   1117                           std::complex<float> beta,
   1118                           DeviceMemory<std::complex<float>> *y, int incy) {
   1119   return DoBlasInternal(
   1120       wrap::cublasCgbmv, stream, true /* = pointer_mode_host */,
   1121       CUDABlasTranspose(trans), m, n, kl, ku, CUDAComplex(&alpha),
   1122       CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx,
   1123       CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
   1124 }
   1125 
   1126 bool CUDABlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m,
   1127                           uint64 n, uint64 kl, uint64 ku,
   1128                           std::complex<double> alpha,
   1129                           const DeviceMemory<std::complex<double>> &a, int lda,
   1130                           const DeviceMemory<std::complex<double>> &x, int incx,
   1131                           std::complex<double> beta,
   1132                           DeviceMemory<std::complex<double>> *y, int incy) {
   1133   return DoBlasInternal(
   1134       wrap::cublasZgbmv, stream, true /* = pointer_mode_host */,
   1135       CUDABlasTranspose(trans), m, n, kl, ku, CUDAComplex(&alpha),
   1136       CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx,
   1137       CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
   1138 }
   1139 
   1140 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
   1141                           uint64 n, float alpha, const DeviceMemory<float> &a,
   1142                           int lda, const DeviceMemory<float> &x, int incx,
   1143                           float beta, DeviceMemory<float> *y, int incy) {
   1144   return DoBlasInternal(
   1145       wrap::cublasSgemv, stream, true /* = pointer_mode_host */,
   1146       CUDABlasTranspose(trans), m, n, &alpha, CUDAMemory(a), lda, CUDAMemory(x),
   1147       incx, &beta, CUDAMemoryMutable(y), incy);
   1148 }
   1149 
   1150 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
   1151                           uint64 n, double alpha, const DeviceMemory<double> &a,
   1152                           int lda, const DeviceMemory<double> &x, int incx,
   1153                           double beta, DeviceMemory<double> *y, int incy) {
   1154   return DoBlasInternal(
   1155       wrap::cublasDgemv, stream, true /* = pointer_mode_host */,
   1156       CUDABlasTranspose(trans), m, n, &alpha, CUDAMemory(a), lda, CUDAMemory(x),
   1157       incx, &beta, CUDAMemoryMutable(y), incy);
   1158 }
   1159 
   1160 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
   1161                           uint64 n, std::complex<float> alpha,
   1162                           const DeviceMemory<std::complex<float>> &a, int lda,
   1163                           const DeviceMemory<std::complex<float>> &x, int incx,
   1164                           std::complex<float> beta,
   1165                           DeviceMemory<std::complex<float>> *y, int incy) {
   1166   return DoBlasInternal(
   1167       wrap::cublasCgemv, stream, true /* = pointer_mode_host */,
   1168       CUDABlasTranspose(trans), m, n, CUDAComplex(&alpha),
   1169       CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx,
   1170       CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
   1171 }
   1172 
   1173 bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m,
   1174                           uint64 n, std::complex<double> alpha,
   1175                           const DeviceMemory<std::complex<double>> &a, int lda,
   1176                           const DeviceMemory<std::complex<double>> &x, int incx,
   1177                           std::complex<double> beta,
   1178                           DeviceMemory<std::complex<double>> *y, int incy) {
   1179   return DoBlasInternal(
   1180       wrap::cublasZgemv, stream, true /* = pointer_mode_host */,
   1181       CUDABlasTranspose(trans), m, n, CUDAComplex(&alpha),
   1182       CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx,
   1183       CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
   1184 }
   1185 
   1186 bool CUDABlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha,
   1187                          const DeviceMemory<float> &x, int incx,
   1188                          const DeviceMemory<float> &y, int incy,
   1189                          DeviceMemory<float> *a, int lda) {
   1190   return DoBlasInternal(
   1191       wrap::cublasSger, stream, true /* = pointer_mode_host */, m, n, &alpha,
   1192       CUDAMemory(x), incx, CUDAMemory(y), incy, CUDAMemoryMutable(a), lda);
   1193 }
   1194 
   1195 bool CUDABlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha,
   1196                          const DeviceMemory<double> &x, int incx,
   1197                          const DeviceMemory<double> &y, int incy,
   1198                          DeviceMemory<double> *a, int lda) {
   1199   return DoBlasInternal(
   1200       wrap::cublasDger, stream, true /* = pointer_mode_host */, m, n, &alpha,
   1201       CUDAMemory(x), incx, CUDAMemory(y), incy, CUDAMemoryMutable(a), lda);
   1202 }
   1203 
   1204 bool CUDABlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
   1205                           std::complex<float> alpha,
   1206                           const DeviceMemory<std::complex<float>> &x, int incx,
   1207                           const DeviceMemory<std::complex<float>> &y, int incy,
   1208                           DeviceMemory<std::complex<float>> *a, int lda) {
   1209   return DoBlasInternal(
   1210       wrap::cublasCgerc, stream, true /* = pointer_mode_host */, m, n,
   1211       CUDAComplex(&alpha), CUDAComplex(CUDAMemory(x)), incx,
   1212       CUDAComplex(CUDAMemory(y)), incy, CUDAComplex(CUDAMemoryMutable(a)), lda);
   1213 }
   1214 
   1215 bool CUDABlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n,
   1216                           std::complex<double> alpha,
   1217                           const DeviceMemory<std::complex<double>> &x, int incx,
   1218                           const DeviceMemory<std::complex<double>> &y, int incy,
   1219                           DeviceMemory<std::complex<double>> *a, int lda) {
   1220   return DoBlasInternal(
   1221       wrap::cublasZgerc, stream, true /* = pointer_mode_host */, m, n,
   1222       CUDAComplex(&alpha), CUDAComplex(CUDAMemory(x)), incx,
   1223       CUDAComplex(CUDAMemory(y)), incy, CUDAComplex(CUDAMemoryMutable(a)), lda);
   1224 }
   1225 
   1226 bool CUDABlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
   1227                           std::complex<float> alpha,
   1228                           const DeviceMemory<std::complex<float>> &x, int incx,
   1229                           const DeviceMemory<std::complex<float>> &y, int incy,
   1230                           DeviceMemory<std::complex<float>> *a, int lda) {
   1231   return DoBlasInternal(
   1232       wrap::cublasCgeru, stream, true /* = pointer_mode_host */, m, n,
   1233       CUDAComplex(&alpha), CUDAComplex(CUDAMemory(x)), incx,
   1234       CUDAComplex(CUDAMemory(y)), incy, CUDAComplex(CUDAMemoryMutable(a)), lda);
   1235 }
   1236 
   1237 bool CUDABlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n,
   1238                           std::complex<double> alpha,
   1239                           const DeviceMemory<std::complex<double>> &x, int incx,
   1240                           const DeviceMemory<std::complex<double>> &y, int incy,
   1241                           DeviceMemory<std::complex<double>> *a, int lda) {
   1242   return DoBlasInternal(
   1243       wrap::cublasZgeru, stream, true /* = pointer_mode_host */, m, n,
   1244       CUDAComplex(&alpha), CUDAComplex(CUDAMemory(x)), incx,
   1245       CUDAComplex(CUDAMemory(y)), incy, CUDAComplex(CUDAMemoryMutable(a)), lda);
   1246 }
   1247 
   1248 bool CUDABlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
   1249                           uint64 k, std::complex<float> alpha,
   1250                           const DeviceMemory<std::complex<float>> &a, int lda,
   1251                           const DeviceMemory<std::complex<float>> &x, int incx,
   1252                           std::complex<float> beta,
   1253                           DeviceMemory<std::complex<float>> *y, int incy) {
   1254   return DoBlasInternal(
   1255       wrap::cublasChbmv, stream, true /* = pointer_mode_host */,
   1256       CUDABlasUpperLower(uplo), n, k, CUDAComplex(&alpha),
   1257       CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx,
   1258       CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
   1259 }
   1260 
   1261 bool CUDABlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
   1262                           uint64 k, std::complex<double> alpha,
   1263                           const DeviceMemory<std::complex<double>> &a, int lda,
   1264                           const DeviceMemory<std::complex<double>> &x, int incx,
   1265                           std::complex<double> beta,
   1266                           DeviceMemory<std::complex<double>> *y, int incy) {
   1267   return DoBlasInternal(
   1268       wrap::cublasZhbmv, stream, true /* = pointer_mode_host */,
   1269       CUDABlasUpperLower(uplo), n, k, CUDAComplex(&alpha),
   1270       CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx,
   1271       CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
   1272 }
   1273 
   1274 bool CUDABlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
   1275                           std::complex<float> alpha,
   1276                           const DeviceMemory<std::complex<float>> &a, int lda,
   1277                           const DeviceMemory<std::complex<float>> &x, int incx,
   1278                           std::complex<float> beta,
   1279                           DeviceMemory<std::complex<float>> *y, int incy) {
   1280   return DoBlasInternal(
   1281       wrap::cublasChemv, stream, true /* = pointer_mode_host */,
   1282       CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
   1283       CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx,
   1284       CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
   1285 }
   1286 
   1287 bool CUDABlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n,
   1288                           std::complex<double> alpha,
   1289                           const DeviceMemory<std::complex<double>> &a, int lda,
   1290                           const DeviceMemory<std::complex<double>> &x, int incx,
   1291                           std::complex<double> beta,
   1292                           DeviceMemory<std::complex<double>> *y, int incy) {
   1293   return DoBlasInternal(
   1294       wrap::cublasZhemv, stream, true /* = pointer_mode_host */,
   1295       CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
   1296       CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(x)), incx,
   1297       CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
   1298 }
   1299 
   1300 bool CUDABlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
   1301                          float alpha,
   1302                          const DeviceMemory<std::complex<float>> &x, int incx,
   1303                          DeviceMemory<std::complex<float>> *a, int lda) {
   1304   return DoBlasInternal(
   1305       wrap::cublasCher, stream, true /* = pointer_mode_host */,
   1306       CUDABlasUpperLower(uplo), n, &alpha, CUDAComplex(CUDAMemory(x)), incx,
   1307       CUDAComplex(CUDAMemoryMutable(a)), lda);
   1308 }
   1309 
   1310 bool CUDABlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n,
   1311                          double alpha,
   1312                          const DeviceMemory<std::complex<double>> &x, int incx,
   1313                          DeviceMemory<std::complex<double>> *a, int lda) {
   1314   return DoBlasInternal(
   1315       wrap::cublasZher, stream, true /* = pointer_mode_host */,
   1316       CUDABlasUpperLower(uplo), n, &alpha, CUDAComplex(CUDAMemory(x)), incx,
   1317       CUDAComplex(CUDAMemoryMutable(a)), lda);
   1318 }
   1319 
   1320 bool CUDABlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
   1321                           std::complex<float> alpha,
   1322                           const DeviceMemory<std::complex<float>> &x, int incx,
   1323                           const DeviceMemory<std::complex<float>> &y, int incy,
   1324                           DeviceMemory<std::complex<float>> *a, int lda) {
   1325   return DoBlasInternal(
   1326       wrap::cublasCher2, stream, true /* = pointer_mode_host */,
   1327       CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
   1328       CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
   1329       CUDAComplex(CUDAMemoryMutable(a)), lda);
   1330 }
   1331 
   1332 bool CUDABlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n,
   1333                           std::complex<double> alpha,
   1334                           const DeviceMemory<std::complex<double>> &x, int incx,
   1335                           const DeviceMemory<std::complex<double>> &y, int incy,
   1336                           DeviceMemory<std::complex<double>> *a, int lda) {
   1337   return DoBlasInternal(
   1338       wrap::cublasZher2, stream, true /* = pointer_mode_host */,
   1339       CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
   1340       CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
   1341       CUDAComplex(CUDAMemoryMutable(a)), lda);
   1342 }
   1343 
   1344 bool CUDABlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
   1345                           std::complex<float> alpha,
   1346                           const DeviceMemory<std::complex<float>> &ap,
   1347                           const DeviceMemory<std::complex<float>> &x, int incx,
   1348                           std::complex<float> beta,
   1349                           DeviceMemory<std::complex<float>> *y, int incy) {
   1350   return DoBlasInternal(
   1351       wrap::cublasChpmv, stream, true /* = pointer_mode_host */,
   1352       CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
   1353       CUDAComplex(CUDAMemory(ap)), CUDAComplex(CUDAMemory(x)), incx,
   1354       CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
   1355 }
   1356 
   1357 bool CUDABlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
   1358                           std::complex<double> alpha,
   1359                           const DeviceMemory<std::complex<double>> &ap,
   1360                           const DeviceMemory<std::complex<double>> &x, int incx,
   1361                           std::complex<double> beta,
   1362                           DeviceMemory<std::complex<double>> *y, int incy) {
   1363   return DoBlasInternal(
   1364       wrap::cublasZhpmv, stream, true /* = pointer_mode_host */,
   1365       CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
   1366       CUDAComplex(CUDAMemory(ap)), CUDAComplex(CUDAMemory(x)), incx,
   1367       CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(y)), incy);
   1368 }
   1369 
   1370 bool CUDABlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
   1371                          float alpha,
   1372                          const DeviceMemory<std::complex<float>> &x, int incx,
   1373                          DeviceMemory<std::complex<float>> *ap) {
   1374   return DoBlasInternal(
   1375       wrap::cublasChpr, stream, true /* = pointer_mode_host */,
   1376       CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
   1377       CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemoryMutable(ap)));
   1378 }
   1379 
   1380 bool CUDABlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n,
   1381                          double alpha,
   1382                          const DeviceMemory<std::complex<double>> &x, int incx,
   1383                          DeviceMemory<std::complex<double>> *ap) {
   1384   return DoBlasInternal(
   1385       wrap::cublasZhpr, stream, true /* = pointer_mode_host */,
   1386       CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
   1387       CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemoryMutable(ap)));
   1388 }
   1389 
   1390 bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
   1391                           std::complex<float> alpha,
   1392                           const DeviceMemory<std::complex<float>> &x, int incx,
   1393                           const DeviceMemory<std::complex<float>> &y, int incy,
   1394                           DeviceMemory<std::complex<float>> *ap) {
   1395   return DoBlasInternal(
   1396       wrap::cublasChpr2, stream, true /* = pointer_mode_host */,
   1397       CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
   1398       CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
   1399       CUDAComplex(CUDAMemoryMutable(ap)));
   1400 }
   1401 
   1402 bool CUDABlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
   1403                           std::complex<double> alpha,
   1404                           const DeviceMemory<std::complex<double>> &x, int incx,
   1405                           const DeviceMemory<std::complex<double>> &y, int incy,
   1406                           DeviceMemory<std::complex<double>> *ap) {
   1407   return DoBlasInternal(
   1408       wrap::cublasZhpr2, stream, true /* = pointer_mode_host */,
   1409       CUDABlasUpperLower(uplo), n, CUDAComplex(&alpha),
   1410       CUDAComplex(CUDAMemory(x)), incx, CUDAComplex(CUDAMemory(y)), incy,
   1411       CUDAComplex(CUDAMemoryMutable(ap)));
   1412 }
   1413 
   1414 bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
   1415                           uint64 k, float alpha, const DeviceMemory<float> &a,
   1416                           int lda, const DeviceMemory<float> &x, int incx,
   1417                           float beta, DeviceMemory<float> *y, int incy) {
   1418   return DoBlasInternal(
   1419       wrap::cublasSsbmv, stream, true /* = pointer_mode_host */,
   1420       CUDABlasUpperLower(uplo), n, k, &alpha, CUDAMemory(a), lda, CUDAMemory(x),
   1421       incx, &beta, CUDAMemoryMutable(y), incy);
   1422 }
   1423 
   1424 bool CUDABlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n,
   1425                           uint64 k, double alpha, const DeviceMemory<double> &a,
   1426                           int lda, const DeviceMemory<double> &x, int incx,
   1427                           double beta, DeviceMemory<double> *y, int incy) {
   1428   return DoBlasInternal(
   1429       wrap::cublasDsbmv, stream, true /* = pointer_mode_host */,
   1430       CUDABlasUpperLower(uplo), n, k, &alpha, CUDAMemory(a), lda, CUDAMemory(x),
   1431       incx, &beta, CUDAMemoryMutable(y), incy);
   1432 }
   1433 
   1434 bool CUDABlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
   1435                           float alpha, const DeviceMemory<float> &ap,
   1436                           const DeviceMemory<float> &x, int incx, float beta,
   1437                           DeviceMemory<float> *y, int incy) {
   1438   return DoBlasInternal(wrap::cublasSspmv, stream,
   1439                         true /* = pointer_mode_host */,
   1440                         CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(ap),
   1441                         CUDAMemory(x), incx, &beta, CUDAMemoryMutable(y), incy);
   1442 }
   1443 
   1444 bool CUDABlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n,
   1445                           double alpha, const DeviceMemory<double> &ap,
   1446                           const DeviceMemory<double> &x, int incx, double beta,
   1447                           DeviceMemory<double> *y, int incy) {
   1448   return DoBlasInternal(wrap::cublasDspmv, stream,
   1449                         true /* = pointer_mode_host */,
   1450                         CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(ap),
   1451                         CUDAMemory(x), incx, &beta, CUDAMemoryMutable(y), incy);
   1452 }
   1453 
   1454 bool CUDABlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
   1455                          float alpha, const DeviceMemory<float> &x, int incx,
   1456                          DeviceMemory<float> *ap) {
   1457   return DoBlasInternal(wrap::cublasSspr, stream,
   1458                         true /* = pointer_mode_host */,
   1459                         CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x),
   1460                         incx, CUDAMemoryMutable(ap));
   1461 }
   1462 
   1463 bool CUDABlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n,
   1464                          double alpha, const DeviceMemory<double> &x, int incx,
   1465                          DeviceMemory<double> *ap) {
   1466   return DoBlasInternal(wrap::cublasDspr, stream,
   1467                         true /* = pointer_mode_host */,
   1468                         CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x),
   1469                         incx, CUDAMemoryMutable(ap));
   1470 }
   1471 
   1472 bool CUDABlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
   1473                           float alpha, const DeviceMemory<float> &x, int incx,
   1474                           const DeviceMemory<float> &y, int incy,
   1475                           DeviceMemory<float> *ap) {
   1476   return DoBlasInternal(wrap::cublasSspr2, stream,
   1477                         true /* = pointer_mode_host */,
   1478                         CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x),
   1479                         incx, CUDAMemory(y), incy, CUDAMemoryMutable(ap));
   1480 }
   1481 
   1482 bool CUDABlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
   1483                           double alpha, const DeviceMemory<double> &x, int incx,
   1484                           const DeviceMemory<double> &y, int incy,
   1485                           DeviceMemory<double> *ap) {
   1486   return DoBlasInternal(wrap::cublasDspr2, stream,
   1487                         true /* = pointer_mode_host */,
   1488                         CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x),
   1489                         incx, CUDAMemory(y), incy, CUDAMemoryMutable(ap));
   1490 }
   1491 
   1492 bool CUDABlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
   1493                           float alpha, const DeviceMemory<float> &a, int lda,
   1494                           const DeviceMemory<float> &x, int incx, float beta,
   1495                           DeviceMemory<float> *y, int incy) {
   1496   return DoBlasInternal(wrap::cublasSsymv, stream,
   1497                         true /* = pointer_mode_host */,
   1498                         CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(a), lda,
   1499                         CUDAMemory(x), incx, &beta, CUDAMemoryMutable(y), incy);
   1500 }
   1501 
   1502 bool CUDABlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
   1503                           double alpha, const DeviceMemory<double> &a, int lda,
   1504                           const DeviceMemory<double> &x, int incx, double beta,
   1505                           DeviceMemory<double> *y, int incy) {
   1506   return DoBlasInternal(wrap::cublasDsymv, stream,
   1507                         true /* = pointer_mode_host */,
   1508                         CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(a), lda,
   1509                         CUDAMemory(x), incx, &beta, CUDAMemoryMutable(y), incy);
   1510 }
   1511 
   1512 bool CUDABlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
   1513                          float alpha, const DeviceMemory<float> &x, int incx,
   1514                          DeviceMemory<float> *a, int lda) {
   1515   return DoBlasInternal(wrap::cublasSsyr, stream,
   1516                         true /* = pointer_mode_host */,
   1517                         CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x),
   1518                         incx, CUDAMemoryMutable(a), lda);
   1519 }
   1520 
   1521 bool CUDABlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
   1522                          double alpha, const DeviceMemory<double> &x, int incx,
   1523                          DeviceMemory<double> *a, int lda) {
   1524   return DoBlasInternal(wrap::cublasDsyr, stream,
   1525                         true /* = pointer_mode_host */,
   1526                         CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x),
   1527                         incx, CUDAMemoryMutable(a), lda);
   1528 }
   1529 
   1530 bool CUDABlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
   1531                           float alpha, const DeviceMemory<float> &x, int incx,
   1532                           const DeviceMemory<float> &y, int incy,
   1533                           DeviceMemory<float> *a, int lda) {
   1534   return DoBlasInternal(wrap::cublasSsyr2, stream,
   1535                         true /* = pointer_mode_host */,
   1536                         CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x),
   1537                         incx, CUDAMemory(y), incy, CUDAMemoryMutable(a), lda);
   1538 }
   1539 
   1540 bool CUDABlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
   1541                           double alpha, const DeviceMemory<double> &x, int incx,
   1542                           const DeviceMemory<double> &y, int incy,
   1543                           DeviceMemory<double> *a, int lda) {
   1544   return DoBlasInternal(wrap::cublasDsyr2, stream,
   1545                         true /* = pointer_mode_host */,
   1546                         CUDABlasUpperLower(uplo), n, &alpha, CUDAMemory(x),
   1547                         incx, CUDAMemory(y), incy, CUDAMemoryMutable(a), lda);
   1548 }
   1549 
   1550 bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
   1551                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1552                           uint64 k, const DeviceMemory<float> &a, int lda,
   1553                           DeviceMemory<float> *x, int incx) {
   1554   return DoBlasInternal(wrap::cublasStbmv, stream,
   1555                         true /* = pointer_mode_host */,
   1556                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1557                         CUDABlasDiagonal(diag), n, k, CUDAMemory(a), lda,
   1558                         CUDAMemoryMutable(x), incx);
   1559 }
   1560 
   1561 bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
   1562                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1563                           uint64 k, const DeviceMemory<double> &a, int lda,
   1564                           DeviceMemory<double> *x, int incx) {
   1565   return DoBlasInternal(wrap::cublasDtbmv, stream,
   1566                         true /* = pointer_mode_host */,
   1567                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1568                         CUDABlasDiagonal(diag), n, k, CUDAMemory(a), lda,
   1569                         CUDAMemoryMutable(x), incx);
   1570 }
   1571 
   1572 bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
   1573                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1574                           uint64 k, const DeviceMemory<std::complex<float>> &a,
   1575                           int lda, DeviceMemory<std::complex<float>> *x,
   1576                           int incx) {
   1577   return DoBlasInternal(
   1578       wrap::cublasCtbmv, stream, true /* = pointer_mode_host */,
   1579       CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1580       CUDABlasDiagonal(diag), n, k, CUDAComplex(CUDAMemory(a)), lda,
   1581       CUDAComplex(CUDAMemoryMutable(x)), incx);
   1582 }
   1583 
   1584 bool CUDABlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
   1585                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1586                           uint64 k, const DeviceMemory<std::complex<double>> &a,
   1587                           int lda, DeviceMemory<std::complex<double>> *x,
   1588                           int incx) {
   1589   return DoBlasInternal(
   1590       wrap::cublasZtbmv, stream, true /* = pointer_mode_host */,
   1591       CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1592       CUDABlasDiagonal(diag), n, k, CUDAComplex(CUDAMemory(a)), lda,
   1593       CUDAComplex(CUDAMemoryMutable(x)), incx);
   1594 }
   1595 
   1596 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
   1597                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1598                           uint64 k, const DeviceMemory<float> &a, int lda,
   1599                           DeviceMemory<float> *x, int incx) {
   1600   return DoBlasInternal(wrap::cublasStbsv, stream,
   1601                         true /* = pointer_mode_host */,
   1602                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1603                         CUDABlasDiagonal(diag), n, k, CUDAMemory(a), lda,
   1604                         CUDAMemoryMutable(x), incx);
   1605 }
   1606 
   1607 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
   1608                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1609                           uint64 k, const DeviceMemory<double> &a, int lda,
   1610                           DeviceMemory<double> *x, int incx) {
   1611   return DoBlasInternal(wrap::cublasDtbsv, stream,
   1612                         true /* = pointer_mode_host */,
   1613                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1614                         CUDABlasDiagonal(diag), n, k, CUDAMemory(a), lda,
   1615                         CUDAMemoryMutable(x), incx);
   1616 }
   1617 
   1618 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
   1619                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1620                           uint64 k, const DeviceMemory<std::complex<float>> &a,
   1621                           int lda, DeviceMemory<std::complex<float>> *x,
   1622                           int incx) {
   1623   return DoBlasInternal(
   1624       wrap::cublasCtbsv, stream, true /* = pointer_mode_host */,
   1625       CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1626       CUDABlasDiagonal(diag), n, k, CUDAComplex(CUDAMemory(a)), lda,
   1627       CUDAComplex(CUDAMemoryMutable(x)), incx);
   1628 }
   1629 
   1630 bool CUDABlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
   1631                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1632                           uint64 k, const DeviceMemory<std::complex<double>> &a,
   1633                           int lda, DeviceMemory<std::complex<double>> *x,
   1634                           int incx) {
   1635   return DoBlasInternal(
   1636       wrap::cublasZtbsv, stream, true /* = pointer_mode_host */,
   1637       CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1638       CUDABlasDiagonal(diag), n, k, CUDAComplex(CUDAMemory(a)), lda,
   1639       CUDAComplex(CUDAMemoryMutable(x)), incx);
   1640 }
   1641 
   1642 bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
   1643                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1644                           const DeviceMemory<float> &ap, DeviceMemory<float> *x,
   1645                           int incx) {
   1646   return DoBlasInternal(
   1647       wrap::cublasStpmv, stream, true /* = pointer_mode_host */,
   1648       CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1649       CUDABlasDiagonal(diag), n, CUDAMemory(ap), CUDAMemoryMutable(x), incx);
   1650 }
   1651 
   1652 bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
   1653                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1654                           const DeviceMemory<double> &ap,
   1655                           DeviceMemory<double> *x, int incx) {
   1656   return DoBlasInternal(
   1657       wrap::cublasDtpmv, stream, true /* = pointer_mode_host */,
   1658       CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1659       CUDABlasDiagonal(diag), n, CUDAMemory(ap), CUDAMemoryMutable(x), incx);
   1660 }
   1661 
   1662 bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
   1663                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1664                           const DeviceMemory<std::complex<float>> &ap,
   1665                           DeviceMemory<std::complex<float>> *x, int incx) {
   1666   return DoBlasInternal(wrap::cublasCtpmv, stream,
   1667                         true /* = pointer_mode_host */,
   1668                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1669                         CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(ap)),
   1670                         CUDAComplex(CUDAMemoryMutable(x)), incx);
   1671 }
   1672 
   1673 bool CUDABlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
   1674                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1675                           const DeviceMemory<std::complex<double>> &ap,
   1676                           DeviceMemory<std::complex<double>> *x, int incx) {
   1677   return DoBlasInternal(wrap::cublasZtpmv, stream,
   1678                         true /* = pointer_mode_host */,
   1679                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1680                         CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(ap)),
   1681                         CUDAComplex(CUDAMemoryMutable(x)), incx);
   1682 }
   1683 
   1684 bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
   1685                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1686                           const DeviceMemory<float> &ap, DeviceMemory<float> *x,
   1687                           int incx) {
   1688   return DoBlasInternal(
   1689       wrap::cublasStpsv, stream, true /* = pointer_mode_host */,
   1690       CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1691       CUDABlasDiagonal(diag), n, CUDAMemory(ap), CUDAMemoryMutable(x), incx);
   1692 }
   1693 
   1694 bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
   1695                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1696                           const DeviceMemory<double> &ap,
   1697                           DeviceMemory<double> *x, int incx) {
   1698   return DoBlasInternal(
   1699       wrap::cublasDtpsv, stream, true /* = pointer_mode_host */,
   1700       CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1701       CUDABlasDiagonal(diag), n, CUDAMemory(ap), CUDAMemoryMutable(x), incx);
   1702 }
   1703 
   1704 bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
   1705                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1706                           const DeviceMemory<std::complex<float>> &ap,
   1707                           DeviceMemory<std::complex<float>> *x, int incx) {
   1708   return DoBlasInternal(wrap::cublasCtpsv, stream,
   1709                         true /* = pointer_mode_host */,
   1710                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1711                         CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(ap)),
   1712                         CUDAComplex(CUDAMemoryMutable(x)), incx);
   1713 }
   1714 
   1715 bool CUDABlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
   1716                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1717                           const DeviceMemory<std::complex<double>> &ap,
   1718                           DeviceMemory<std::complex<double>> *x, int incx) {
   1719   return DoBlasInternal(wrap::cublasZtpsv, stream,
   1720                         true /* = pointer_mode_host */,
   1721                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1722                         CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(ap)),
   1723                         CUDAComplex(CUDAMemoryMutable(x)), incx);
   1724 }
   1725 
   1726 bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
   1727                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1728                           const DeviceMemory<float> &a, int lda,
   1729                           DeviceMemory<float> *x, int incx) {
   1730   return DoBlasInternal(wrap::cublasStrmv, stream,
   1731                         true /* = pointer_mode_host */,
   1732                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1733                         CUDABlasDiagonal(diag), n, CUDAMemory(a), lda,
   1734                         CUDAMemoryMutable(x), incx);
   1735 }
   1736 
   1737 bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
   1738                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1739                           const DeviceMemory<double> &a, int lda,
   1740                           DeviceMemory<double> *x, int incx) {
   1741   return DoBlasInternal(wrap::cublasDtrmv, stream,
   1742                         true /* = pointer_mode_host */,
   1743                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1744                         CUDABlasDiagonal(diag), n, CUDAMemory(a), lda,
   1745                         CUDAMemoryMutable(x), incx);
   1746 }
   1747 
   1748 bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
   1749                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1750                           const DeviceMemory<std::complex<float>> &a, int lda,
   1751                           DeviceMemory<std::complex<float>> *x, int incx) {
   1752   return DoBlasInternal(wrap::cublasCtrmv, stream,
   1753                         true /* = pointer_mode_host */,
   1754                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1755                         CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(a)),
   1756                         lda, CUDAComplex(CUDAMemoryMutable(x)), incx);
   1757 }
   1758 
   1759 bool CUDABlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
   1760                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1761                           const DeviceMemory<std::complex<double>> &a, int lda,
   1762                           DeviceMemory<std::complex<double>> *x, int incx) {
   1763   return DoBlasInternal(wrap::cublasZtrmv, stream,
   1764                         true /* = pointer_mode_host */,
   1765                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1766                         CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(a)),
   1767                         lda, CUDAComplex(CUDAMemoryMutable(x)), incx);
   1768 }
   1769 
   1770 bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
   1771                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1772                           const DeviceMemory<float> &a, int lda,
   1773                           DeviceMemory<float> *x, int incx) {
   1774   return DoBlasInternal(wrap::cublasStrsv, stream,
   1775                         true /* = pointer_mode_host */,
   1776                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1777                         CUDABlasDiagonal(diag), n, CUDAMemory(a), lda,
   1778                         CUDAMemoryMutable(x), incx);
   1779 }
   1780 
   1781 bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
   1782                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1783                           const DeviceMemory<double> &a, int lda,
   1784                           DeviceMemory<double> *x, int incx) {
   1785   return DoBlasInternal(wrap::cublasDtrsv, stream,
   1786                         true /* = pointer_mode_host */,
   1787                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1788                         CUDABlasDiagonal(diag), n, CUDAMemory(a), lda,
   1789                         CUDAMemoryMutable(x), incx);
   1790 }
   1791 
   1792 bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
   1793                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1794                           const DeviceMemory<std::complex<float>> &a, int lda,
   1795                           DeviceMemory<std::complex<float>> *x, int incx) {
   1796   return DoBlasInternal(wrap::cublasCtrsv, stream,
   1797                         true /* = pointer_mode_host */,
   1798                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1799                         CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(a)),
   1800                         lda, CUDAComplex(CUDAMemoryMutable(x)), incx);
   1801 }
   1802 
   1803 bool CUDABlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
   1804                           blas::Transpose trans, blas::Diagonal diag, uint64 n,
   1805                           const DeviceMemory<std::complex<double>> &a, int lda,
   1806                           DeviceMemory<std::complex<double>> *x, int incx) {
   1807   return DoBlasInternal(wrap::cublasZtrsv, stream,
   1808                         true /* = pointer_mode_host */,
   1809                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans),
   1810                         CUDABlasDiagonal(diag), n, CUDAComplex(CUDAMemory(a)),
   1811                         lda, CUDAComplex(CUDAMemoryMutable(x)), incx);
   1812 }
   1813 
// Half-precision (fp16) GEMM with fp32 host-side alpha/beta, implemented via
// cublasSgemmEx with SE_CUDA_DATA_HALF operands. Requires CUDA 7.5+; returns
// false (after logging) on older toolkits.
bool CUDABlas::DoBlasGemm(
    Stream *stream, blas::Transpose transa,
    blas::Transpose transb, uint64 m, uint64 n, uint64 k,
    float alpha, const DeviceMemory<Eigen::half> &a, int lda,
    const DeviceMemory<Eigen::half> &b, int ldb, float beta,
    DeviceMemory<Eigen::half> *c, int ldc) {
#if CUDA_VERSION >= 7050
  VLOG(1) << port::Printf(
      "doing cuBLAS SGEMM: at=%d bt=%d m=%llu n=%llu "
      "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
      "c=%p ldc=%d",
      static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
      a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
  // Leading-dimension sanity checks: violations are only logged as warnings
  // here; the cuBLAS call below is still attempted.
  if (transa == blas::Transpose::kNoTranspose) {
    if (lda < static_cast<int64>(m)) {
      LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
                      "precondition violation";
    }
  } else {
    if (lda < static_cast<int64>(k)) {
      LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
                   << ") (transpose case); precondition violation";
    }
  }
  if (transb == blas::Transpose::kNoTranspose) {
    if (ldb < static_cast<int64>(k)) {
      LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
                   << ") (no transpose case); precondition violation";
    }
  } else {
    if (ldb < static_cast<int64>(n)) {
      LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
                      "precondition violation";
    }
  }

  bool use_tensor_ops = false;
#if CUDA_VERSION >= 9000
  int cc_major, cc_minor;
  stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major,
                                                                   &cc_minor);

  // Tensor ops require an sm_70 (Volta) or newer GPU, and must not have been
  // disabled via TensorOpMathEnabled().
  if (cc_major >= 7 && TensorOpMathEnabled()) {
    use_tensor_ops = true;
  }
#endif

  return DoBlasInternalImpl(
      wrap::cublasSgemmEx, stream, true /* = pointer_mode_host */,
      true /* = err_on_failure= */, use_tensor_ops, CUDABlasTranspose(transa),
      CUDABlasTranspose(transb), m, n, k, &alpha, CUDAMemory(a),
      SE_CUDA_DATA_HALF, lda, CUDAMemory(b), SE_CUDA_DATA_HALF, ldb, &beta,
      CUDAMemoryMutable(c), SE_CUDA_DATA_HALF, ldc);

#else
  LOG(ERROR) << "fp16 sgemm is not implemented in this cuBLAS version "
             << "(need at least CUDA 7.5)";
  return false;
#endif
}
   1875 
   1876 bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
   1877                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
   1878                           float alpha, const DeviceMemory<float> &a, int lda,
   1879                           const DeviceMemory<float> &b, int ldb, float beta,
   1880                           DeviceMemory<float> *c, int ldc) {
   1881   VLOG(1) << port::Printf(
   1882       "doing cuBLAS SGEMM: at=%d bt=%d m=%llu n=%llu "
   1883       "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f "
   1884       "c=%p ldc=%d",
   1885       static_cast<int>(transa), static_cast<int>(transb), m, n, k, alpha,
   1886       a.opaque(), lda, b.opaque(), ldb, beta, c->opaque(), ldc);
   1887   if (transa == blas::Transpose::kNoTranspose) {
   1888     if (lda < static_cast<int64>(m)) {
   1889       LOG(WARNING) << "GEMM lda was smaller than m (no transpose case); "
   1890                       "precondition violation";
   1891     }
   1892   } else {
   1893     if (lda < static_cast<int64>(k)) {
   1894       LOG(WARNING) << "GEMM lda (" << lda << ") was smaller than k (" << k
   1895                    << ") (transpose case); precondition violation";
   1896     }
   1897   }
   1898   if (transb == blas::Transpose::kNoTranspose) {
   1899     if (ldb < static_cast<int64>(k)) {
   1900       LOG(WARNING) << "GEMM ldb (" << ldb << ") was smaller than k (" << k
   1901                    << ") (no transpose case); precondition violation";
   1902     }
   1903   } else {
   1904     if (ldb < static_cast<int64>(n)) {
   1905       LOG(WARNING) << "GEMM ldb was smaller than n (transpose case); "
   1906                       "precondition violation";
   1907     }
   1908   }
   1909   return DoBlasInternal(
   1910       wrap::cublasSgemm, stream, true /* = pointer_mode_host */,
   1911       CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
   1912       CUDAMemory(a), lda, CUDAMemory(b), ldb, &beta, CUDAMemoryMutable(c), ldc);
   1913 }
   1914 
   1915 bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
   1916                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
   1917                           double alpha, const DeviceMemory<double> &a, int lda,
   1918                           const DeviceMemory<double> &b, int ldb, double beta,
   1919                           DeviceMemory<double> *c, int ldc) {
   1920   return DoBlasInternal(
   1921       wrap::cublasDgemm, stream, true /* = pointer_mode_host */,
   1922       CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
   1923       CUDAMemory(a), lda, CUDAMemory(b), ldb, &beta, CUDAMemoryMutable(c), ldc);
   1924 }
   1925 
   1926 bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
   1927                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
   1928                           std::complex<float> alpha,
   1929                           const DeviceMemory<std::complex<float>> &a, int lda,
   1930                           const DeviceMemory<std::complex<float>> &b, int ldb,
   1931                           std::complex<float> beta,
   1932                           DeviceMemory<std::complex<float>> *c, int ldc) {
   1933   return DoBlasInternal(
   1934       wrap::cublasCgemm, stream, true /* = pointer_mode_host */,
   1935       CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
   1936       CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
   1937       CUDAComplex(CUDAMemory(b)), ldb, CUDAComplex(&beta),
   1938       CUDAComplex(CUDAMemoryMutable(c)), ldc);
   1939 }
   1940 
   1941 bool CUDABlas::DoBlasGemm(Stream *stream, blas::Transpose transa,
   1942                           blas::Transpose transb, uint64 m, uint64 n, uint64 k,
   1943                           std::complex<double> alpha,
   1944                           const DeviceMemory<std::complex<double>> &a, int lda,
   1945                           const DeviceMemory<std::complex<double>> &b, int ldb,
   1946                           std::complex<double> beta,
   1947                           DeviceMemory<std::complex<double>> *c, int ldc) {
   1948   return DoBlasInternal(
   1949       wrap::cublasZgemm, stream, true /* = pointer_mode_host */,
   1950       CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
   1951       CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
   1952       CUDAComplex(CUDAMemory(b)), ldb, CUDAComplex(&beta),
   1953       CUDAComplex(CUDAMemoryMutable(c)), ldc);
   1954 }
   1955 
   1956 bool CUDABlas::DoBlasGemvWithProfiling(
   1957     Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha,
   1958     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
   1959     int incx, float beta, DeviceMemory<float> *y, int incy,
   1960     blas::ProfileResult *output_profile_result) {
   1961   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
   1962                                      incx, beta, y, incy,
   1963                                      output_profile_result);
   1964 }
   1965 
   1966 bool CUDABlas::DoBlasGemvWithProfiling(
   1967     Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha,
   1968     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
   1969     int incx, double beta, DeviceMemory<double> *y, int incy,
   1970     blas::ProfileResult *output_profile_result) {
   1971   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
   1972                                      incx, beta, y, incy,
   1973                                      output_profile_result);
   1974 }
   1975 
   1976 bool CUDABlas::DoBlasGemvWithProfiling(
   1977     Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
   1978     std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a,
   1979     int lda, const DeviceMemory<std::complex<float>> &x, int incx,
   1980     std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
   1981     blas::ProfileResult *output_profile_result) {
   1982   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
   1983                                      incx, beta, y, incy,
   1984                                      output_profile_result);
   1985 }
   1986 
   1987 bool CUDABlas::DoBlasGemvWithProfiling(
   1988     Stream *stream, blas::Transpose trans, uint64 m, uint64 n,
   1989     std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a,
   1990     int lda, const DeviceMemory<std::complex<double>> &x, int incx,
   1991     std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
   1992     blas::ProfileResult *output_profile_result) {
   1993   return DoBlasGemvWithProfilingImpl(stream, trans, m, n, alpha, a, lda, x,
   1994                                      incx, beta, y, incy,
   1995                                      output_profile_result);
   1996 }
   1997 
   1998 bool CUDABlas::DoBlasGemmWithProfiling(
   1999     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
   2000     uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
   2001     int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
   2002     DeviceMemory<Eigen::half> *c, int ldc,
   2003     blas::ProfileResult *output_profile_result) {
   2004   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
   2005                                      lda, b, ldb, beta, c, ldc,
   2006                                      output_profile_result);
   2007 }
   2008 
   2009 bool CUDABlas::DoBlasGemmWithProfiling(
   2010     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
   2011     uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
   2012     const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
   2013     int ldc, blas::ProfileResult *output_profile_result) {
   2014   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
   2015                                      lda, b, ldb, beta, c, ldc,
   2016                                      output_profile_result);
   2017 }
   2018 
   2019 bool CUDABlas::DoBlasGemmWithProfiling(
   2020     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
   2021     uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
   2022     const DeviceMemory<double> &b, int ldb, double beta,
   2023     DeviceMemory<double> *c, int ldc,
   2024     blas::ProfileResult *output_profile_result) {
   2025   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
   2026                                      lda, b, ldb, beta, c, ldc,
   2027                                      output_profile_result);
   2028 }
   2029 
   2030 bool CUDABlas::DoBlasGemmWithProfiling(
   2031     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
   2032     uint64 n, uint64 k, std::complex<float> alpha,
   2033     const DeviceMemory<std::complex<float>> &a, int lda,
   2034     const DeviceMemory<std::complex<float>> &b, int ldb,
   2035     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
   2036     blas::ProfileResult *output_profile_result) {
   2037   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
   2038                                      lda, b, ldb, beta, c, ldc,
   2039                                      output_profile_result);
   2040 }
   2041 
   2042 bool CUDABlas::DoBlasGemmWithProfiling(
   2043     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
   2044     uint64 n, uint64 k, std::complex<double> alpha,
   2045     const DeviceMemory<std::complex<double>> &a, int lda,
   2046     const DeviceMemory<std::complex<double>> &b, int ldb,
   2047     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
   2048     blas::ProfileResult *output_profile_result) {
   2049   return DoBlasGemmWithProfilingImpl(stream, transa, transb, m, n, k, alpha, a,
   2050                                      lda, b, ldb, beta, c, ldc,
   2051                                      output_profile_result);
   2052 }
   2053 
// Shared implementation behind the DoBlasGemvWithProfiling overloads: runs a
// GEMV and, when `output_profile_result` is non-null, brackets it with a
// CUDATimer and records the elapsed time. Returns false if the timer cannot
// be started/stopped or if the GEMV itself fails.
template <typename T>
bool CUDABlas::DoBlasGemvWithProfilingImpl(
    Stream *stream, blas::Transpose trans, uint64 m, uint64 n, const T &alpha,
    const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
    const T &beta, DeviceMemory<T> *y, int incy,
    blas::ProfileResult *output_profile_result) {
  // Deleter that tears down the timer's resources before freeing the object.
  struct TimerDeleter {
    void operator()(CUDATimer *t) {
      t->Destroy();
      delete t;
    }
  };
  std::unique_ptr<CUDATimer, TimerDeleter> timer;
  if (output_profile_result != nullptr) {
    timer.reset(new CUDATimer(parent_));
    if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
      return false;
    }
  }

  // Run the GEMV itself (note: this is GEMV, not GEMM, despite the shared
  // structure with DoBlasGemmWithProfilingImpl).
  bool result =
      DoBlasGemv(stream, trans, m, n, alpha, a, lda, x, incx, beta, y, incy);

  if (timer != nullptr && result) {
    // CUDATimer will CHECK-fail if we Stop() it while the stream is in an error
    // state.
    if (!timer->Stop(AsCUDAStream(stream))) {
      return false;
    }
    output_profile_result->set_is_valid(true);
    output_profile_result->set_algorithm(blas::kDefaultBlasGemv);
    output_profile_result->set_elapsed_time_in_ms(
        timer->GetElapsedMilliseconds());
  }
  return result;
}
   2091 
   2092 template <typename T, typename ParamType>
   2093 bool CUDABlas::DoBlasGemmWithProfilingImpl(
   2094     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
   2095     uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
   2096     int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
   2097     DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
   2098   struct TimerDeleter {
   2099     void operator()(CUDATimer *t) {
   2100       t->Destroy();
   2101       delete t;
   2102     }
   2103   };
   2104   std::unique_ptr<CUDATimer, TimerDeleter> timer;
   2105   if (output_profile_result != nullptr) {
   2106     timer.reset(new CUDATimer(parent_));
   2107     if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
   2108       return false;
   2109     }
   2110   }
   2111 
   2112   // Call blasGemm
   2113   bool result = DoBlasGemm(stream, transa, transb, m, n, k, alpha, a, lda, b,
   2114                            ldb, beta, c, ldc);
   2115 
   2116   if (timer != nullptr && result) {
   2117     // CUDATimer will CHECK-fail if we Stop() it while the stream is in an error
   2118     // state.
   2119     if (!timer->Stop(AsCUDAStream(stream))) {
   2120       return false;
   2121     }
   2122     output_profile_result->set_is_valid(true);
   2123     output_profile_result->set_algorithm(blas::kDefaultBlasGemm);
   2124     output_profile_result->set_elapsed_time_in_ms(
   2125         timer->GetElapsedMilliseconds());
   2126   }
   2127   return result;
   2128 }
   2129 
   2130 static bool UsesTensorOps(blas::AlgorithmType algo) {
   2131 #if CUDA_VERSION >= 9000
   2132   cublasGemmAlgo_t cublas_algo = static_cast<cublasGemmAlgo_t>(algo);
   2133   return cublas_algo >= CUBLAS_GEMM_DEFAULT_TENSOR_OP;
   2134 #else
   2135   return false;
   2136 #endif
   2137 }
   2138 
   2139 template <typename InType>
   2140 static bool TensorOpsAvailable(int cc_major) {
   2141 #if CUDA_VERSION >= 9000
   2142   if (cc_major >= 7 && TensorOpMathEnabled() &&
   2143       std::is_same<InType, Eigen::half>::value) {
   2144     return true;
   2145   }
   2146 #endif
   2147   return false;
   2148 }
   2149 
// Shared implementation behind the DoBlasGemmWithAlgorithm overloads: runs
// cublasGemmEx with an explicitly requested algorithm, and optionally times
// the call into `output_profile_result`. Returns false (without calling
// cuBLAS) when the toolkit/device/algorithm combination is unsupported.
template <typename InT, typename OutT, typename CompT>
bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a, int lda,
    const DeviceMemory<InT> &b, int ldb, const CompT &beta,
    DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
    blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
// CUDA < version 8 and GPUs < sm_50 don't support cublasGemmEx.
#if CUDA_VERSION < 8000
  return false;
#else
  int cc_major, cc_minor;
  if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
          &cc_major, &cc_minor) &&
      cc_major < 5) {
    return false;
  }

  // A tensor-op algorithm was requested but the device/type combination
  // cannot honor it.
  if (UsesTensorOps(algorithm) && !TensorOpsAvailable<InT>(cc_major)) {
    return false;
  }

  // Deleter that tears down the timer's resources before freeing the object.
  struct TimerDeleter {
    void operator()(CUDATimer *t) {
      t->Destroy();
      delete t;
    }
  };
  std::unique_ptr<CUDATimer, TimerDeleter> timer;
  if (output_profile_result != nullptr) {
    timer.reset(new CUDATimer(parent_));
    if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
      return false;
    }
  }

  cudaDataType_t cuda_in_type = CUDADataType<InT>::type;
  // Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
  // we do the following compile-time check on the default value:
  static_assert(blas::kDefaultGemmAlgo == CUBLAS_GEMM_DFALT, "");
  // Failure here is tolerated (not logged as an error) because callers probe
  // multiple algorithms, not all of which are supported for every shape.
  bool result = DoBlasInternalFailureOK(
      wrap::cublasGemmEx, stream, /* pointer_mode_host = */ true,
      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
      CUDAMemory(a), cuda_in_type, lda, CUDAMemory(b), cuda_in_type, ldb, &beta,
      CUDAMemoryMutable(c), CUDADataType<OutT>::type, ldc,
      CUDAComputationType(computation_type),
      static_cast<cublasGemmAlgo_t>(algorithm));

  if (timer != nullptr && result) {
    // CUDATimer will CHECK-fail if we Stop() it while the stream is in an error
    // state.
    if (!timer->Stop(AsCUDAStream(stream))) {
      return false;
    }
    output_profile_result->set_is_valid(true);
    output_profile_result->set_algorithm(algorithm);
    output_profile_result->set_elapsed_time_in_ms(
        timer->GetElapsedMilliseconds());
  }
  return result;
#endif
}
   2212 
// Appends to `out_algorithms` every cublasGemmAlgo_t value known to this
// build's CUDA toolkit (tensor-op variants only with CUDA 9+). Always
// returns true, even when the list is left empty.
bool CUDABlas::GetBlasGemmAlgorithms(
    std::vector<blas::AlgorithmType> *out_algorithms) {
// cublasGemmAlgo_t (and the function that accepts this type, cublasGemmEx)
// were first introduced in CUDA 8.
// Note that when the CUDA version or compute capability is not sufficient, we
// still return out_algorithms successfully. The caller must ensure that in
// this case the returned vector is empty.
#if CUDA_VERSION >= 8000
  for (cublasGemmAlgo_t algo : {
         CUBLAS_GEMM_DFALT, CUBLAS_GEMM_ALGO0, CUBLAS_GEMM_ALGO1,
             CUBLAS_GEMM_ALGO2, CUBLAS_GEMM_ALGO3, CUBLAS_GEMM_ALGO4,
             CUBLAS_GEMM_ALGO5, CUBLAS_GEMM_ALGO6, CUBLAS_GEMM_ALGO7,
#if CUDA_VERSION >= 9000
             CUBLAS_GEMM_ALGO8, CUBLAS_GEMM_ALGO9, CUBLAS_GEMM_ALGO10,
             CUBLAS_GEMM_ALGO11, CUBLAS_GEMM_ALGO12, CUBLAS_GEMM_ALGO13,
             CUBLAS_GEMM_ALGO14, CUBLAS_GEMM_ALGO15, CUBLAS_GEMM_ALGO16,
             CUBLAS_GEMM_ALGO17, CUBLAS_GEMM_DFALT_TENSOR_OP,
             CUBLAS_GEMM_ALGO0_TENSOR_OP, CUBLAS_GEMM_ALGO1_TENSOR_OP,
             CUBLAS_GEMM_ALGO2_TENSOR_OP
#endif
       }) {
    out_algorithms->push_back(algo);
  }
#endif
  return true;
}
   2239 
   2240 bool CUDABlas::DoBlasGemmWithAlgorithm(
   2241     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
   2242     uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
   2243     const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
   2244     int ldc, blas::ComputationType computation_type,
   2245     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
   2246   return DoBlasGemmWithAlgorithmImpl(
   2247       stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
   2248       computation_type, algorithm, output_profile_result);
   2249 }
   2250 
   2251 bool CUDABlas::DoBlasGemmWithAlgorithm(
   2252     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
   2253     uint64 n, uint64 k, const Eigen::half &alpha,
   2254     const DeviceMemory<Eigen::half> &a, int lda,
   2255     const DeviceMemory<Eigen::half> &b, int ldb, const Eigen::half &beta,
   2256     DeviceMemory<Eigen::half> *c, int ldc,
   2257     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
   2258     blas::ProfileResult *output_profile_result) {
   2259   return DoBlasGemmWithAlgorithmImpl(
   2260       stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
   2261       computation_type, algorithm, output_profile_result);
   2262 }
   2263 
   2264 bool CUDABlas::DoBlasGemmWithAlgorithm(
   2265     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
   2266     uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
   2267     const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
   2268     int ldc, blas::ComputationType computation_type,
   2269     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
   2270   return DoBlasGemmWithAlgorithmImpl(
   2271       stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
   2272       computation_type, algorithm, output_profile_result);
   2273 }
   2274 
   2275 bool CUDABlas::DoBlasGemmWithAlgorithm(
   2276     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
   2277     uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
   2278     const DeviceMemory<double> &b, int ldb, double beta,
   2279     DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type,
   2280     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
   2281   return DoBlasGemmWithAlgorithmImpl(
   2282       stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
   2283       computation_type, algorithm, output_profile_result);
   2284 }
   2285 
   2286 bool CUDABlas::DoBlasGemmWithAlgorithm(
   2287     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
   2288     uint64 n, uint64 k, std::complex<float> alpha,
   2289     const DeviceMemory<std::complex<float>> &a, int lda,
   2290     const DeviceMemory<std::complex<float>> &b, int ldb,
   2291     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
   2292     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
   2293     blas::ProfileResult *output_profile_result) {
   2294   return DoBlasGemmWithAlgorithmImpl(
   2295       stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
   2296       computation_type, algorithm, output_profile_result);
   2297 }
   2298 
   2299 bool CUDABlas::DoBlasGemmWithAlgorithm(
   2300     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
   2301     uint64 n, uint64 k, std::complex<double> alpha,
   2302     const DeviceMemory<std::complex<double>> &a, int lda,
   2303     const DeviceMemory<std::complex<double>> &b, int ldb,
   2304     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
   2305     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
   2306     blas::ProfileResult *output_profile_result) {
   2307   return DoBlasGemmWithAlgorithmImpl(
   2308       stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
   2309       computation_type, algorithm, output_profile_result);
   2310 }
   2311 
// Shared implementation behind the DoBlasGemmBatched overloads. The batched
// cuBLAS API takes *device-side* arrays of matrix pointers, so this routine
// gathers the raw pointers from the wrappers, stages them on the device
// (via `scratch_allocator` when provided, otherwise via stream-temporary
// allocations), and then issues a single batched GEMM call.
template <typename T, typename FuncT>
port::Status CUDABlas::DoBlasGemmBatchedInternal(
    FuncT cublas_func, Stream *stream, blas::Transpose transa,
    blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
    const port::ArraySlice<DeviceMemory<T> *> &a_ptrs_to_wrappers, int lda,
    const port::ArraySlice<DeviceMemory<T> *> &b_ptrs_to_wrappers, int ldb,
    T beta, const port::ArraySlice<DeviceMemory<T> *> &c_ptrs_to_wrappers,
    int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
  // Host-side lists of the raw device pointers for each matrix in the batch.
  std::vector<T *> a_raw_ptrs, b_raw_ptrs, c_raw_ptrs;
  for (int i = 0; i < batch_count; ++i) {
    a_raw_ptrs.push_back(static_cast<T *>(a_ptrs_to_wrappers[i]->opaque()));
    b_raw_ptrs.push_back(static_cast<T *>(b_ptrs_to_wrappers[i]->opaque()));
    c_raw_ptrs.push_back(static_cast<T *>(c_ptrs_to_wrappers[i]->opaque()));
  }

  typedef typename CUDAComplexT<T>::type CUDA_T;

  // Bytes needed for one device-side pointer array.
  const size_t size = batch_count * sizeof(CUDA_T *);

  // Device-side copy of pointers to matrices.
  DeviceMemory<CUDA_T *> a;
  DeviceMemory<CUDA_T *> b;
  DeviceMemory<CUDA_T *> c;

  // If temporary space is allocated for device-side copies of pointers to
  // matrices, that temporary space should not be freed until this function
  // returns. Although the values for these unique_ptrs are not set here, they
  // are declared at this scope so they will be destroyed when the function
  // returns.
  //
  // If a scratch allocator is provided, these pointers will not be used at all.
  std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> a_temporary;
  std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> b_temporary;
  std::unique_ptr<TemporaryDeviceMemory<CUDA_T *>> c_temporary;

  // Decide how to allocate device-side copy of pointers to matrices based on
  // whether a scratch allocator was passed.
  if (scratch_allocator != nullptr) {
    SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> a_bytes,
                        scratch_allocator->AllocateBytes(stream, size));
    SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> b_bytes,
                        scratch_allocator->AllocateBytes(stream, size));
    SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> c_bytes,
                        scratch_allocator->AllocateBytes(stream, size));
    a = DeviceMemory<CUDA_T *>(a_bytes);
    b = DeviceMemory<CUDA_T *>(b_bytes);
    c = DeviceMemory<CUDA_T *>(c_bytes);
  } else {
    SE_ASSIGN_OR_RETURN(a_temporary,
                        stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
    SE_ASSIGN_OR_RETURN(b_temporary,
                        stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
    SE_ASSIGN_OR_RETURN(c_temporary,
                        stream->AllocateTemporaryArray<CUDA_T *>(batch_count));
    a = DeviceMemory<CUDA_T *>(*a_temporary->mutable_device_memory());
    b = DeviceMemory<CUDA_T *>(*b_temporary->mutable_device_memory());
    c = DeviceMemory<CUDA_T *>(*c_temporary->mutable_device_memory());
  }

  // Copy the pointer arrays to the device before the batched call can use
  // them.
  if (!stream->ThenMemcpy(&a, a_raw_ptrs.data(), size).ok() ||
      !stream->ThenMemcpy(&b, b_raw_ptrs.data(), size).ok() ||
      !stream->ThenMemcpy(&c, c_raw_ptrs.data(), size).ok()) {
    return port::Status(port::error::INTERNAL,
                        "failed to copy memory from host to device in "
                        "CUDABlas::DoBlasGemmBatched");
  }

  bool ok = DoBlasInternal(
      cublas_func, stream, true /* = pointer_mode_host */,
      CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
      CUDAComplex(&alpha), const_cast<const CUDA_T **>(CUDAMemory(a)), lda,
      const_cast<const CUDA_T **>(CUDAMemory(b)), ldb, CUDAComplex(&beta),
      const_cast<CUDA_T **>(CUDAMemory(c)), ldc, batch_count);

  if (ok) {
    return port::Status::OK();
  }
  return port::Status(port::error::INTERNAL,
                      "failed BLAS call, see log for details");
}
   2392 
   2393 bool CUDABlas::DoBlasGemmBatched(
   2394     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
   2395     uint64 n, uint64 k, float alpha,
   2396     const port::ArraySlice<DeviceMemory<float> *> &a_array, int lda,
   2397     const port::ArraySlice<DeviceMemory<float> *> &b_array, int ldb, float beta,
   2398     const port::ArraySlice<DeviceMemory<float> *> &c_array, int ldc,
   2399     int batch_count, ScratchAllocator *scratch_allocator) {
   2400   port::Status status = DoBlasGemmBatchedInternal(
   2401       wrap::cublasSgemmBatched, stream, transa, transb, m, n, k, alpha, a_array,
   2402       lda, b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
   2403   if (!status.ok()) {
   2404     LOG(ERROR) << status;
   2405   }
   2406   return status.ok();
   2407 }
   2408 
   2409 bool CUDABlas::DoBlasGemmBatched(
   2410     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
   2411     uint64 n, uint64 k, double alpha,
   2412     const port::ArraySlice<DeviceMemory<double> *> &a_array, int lda,
   2413     const port::ArraySlice<DeviceMemory<double> *> &b_array, int ldb,
   2414     double beta, const port::ArraySlice<DeviceMemory<double> *> &c_array,
   2415     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
   2416   port::Status status = DoBlasGemmBatchedInternal(
   2417       wrap::cublasDgemmBatched, stream, transa, transb, m, n, k, alpha, a_array,
   2418       lda, b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
   2419   if (!status.ok()) {
   2420     LOG(ERROR) << status;
   2421   }
   2422   return status.ok();
   2423 }
   2424 
   2425 bool CUDABlas::DoBlasGemmBatched(
   2426     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
   2427     uint64 n, uint64 k, std::complex<float> alpha,
   2428     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a_array,
   2429     int lda,
   2430     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b_array,
   2431     int ldb, std::complex<float> beta,
   2432     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c_array,
   2433     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
   2434   port::Status status = DoBlasGemmBatchedInternal(
   2435       wrap::cublasCgemmBatched, stream, transa, transb, m, n, k, alpha, a_array,
   2436       lda, b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
   2437   if (!status.ok()) {
   2438     LOG(ERROR) << status;
   2439   }
   2440   return status.ok();
   2441 }
   2442 
   2443 bool CUDABlas::DoBlasGemmBatched(
   2444     Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
   2445     uint64 n, uint64 k, std::complex<double> alpha,
   2446     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a_array,
   2447     int lda,
   2448     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b_array,
   2449     int ldb, std::complex<double> beta,
   2450     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c_array,
   2451     int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
   2452   port::Status status = DoBlasGemmBatchedInternal(
   2453       wrap::cublasZgemmBatched, stream, transa, transb, m, n, k, alpha, a_array,
   2454       lda, b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator);
   2455   if (!status.ok()) {
   2456     LOG(ERROR) << status;
   2457   }
   2458   return status.ok();
   2459 }
   2460 
   2461 bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
   2462                           blas::UpperLower uplo, uint64 m, uint64 n,
   2463                           std::complex<float> alpha,
   2464                           const DeviceMemory<std::complex<float>> &a, int lda,
   2465                           const DeviceMemory<std::complex<float>> &b, int ldb,
   2466                           std::complex<float> beta,
   2467                           DeviceMemory<std::complex<float>> *c, int ldc) {
   2468   return DoBlasInternal(
   2469       wrap::cublasChemm, stream, true /* = pointer_mode_host */,
   2470       CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n, CUDAComplex(&alpha),
   2471       CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(b)), ldb,
   2472       CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(c)), ldc);
   2473 }
   2474 
   2475 bool CUDABlas::DoBlasHemm(Stream *stream, blas::Side side,
   2476                           blas::UpperLower uplo, uint64 m, uint64 n,
   2477                           std::complex<double> alpha,
   2478                           const DeviceMemory<std::complex<double>> &a, int lda,
   2479                           const DeviceMemory<std::complex<double>> &b, int ldb,
   2480                           std::complex<double> beta,
   2481                           DeviceMemory<std::complex<double>> *c, int ldc) {
   2482   return DoBlasInternal(
   2483       wrap::cublasZhemm, stream, true /* = pointer_mode_host */,
   2484       CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n, CUDAComplex(&alpha),
   2485       CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(b)), ldb,
   2486       CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(c)), ldc);
   2487 }
   2488 
   2489 bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
   2490                           blas::Transpose trans, uint64 n, uint64 k,
   2491                           float alpha,
   2492                           const DeviceMemory<std::complex<float>> &a, int lda,
   2493                           float beta, DeviceMemory<std::complex<float>> *c,
   2494                           int ldc) {
   2495   return DoBlasInternal(wrap::cublasCherk, stream,
   2496                         true /* = pointer_mode_host */,
   2497                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
   2498                         k, CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
   2499                         &beta, CUDAComplex(CUDAMemoryMutable(c)), ldc);
   2500 }
   2501 
   2502 bool CUDABlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo,
   2503                           blas::Transpose trans, uint64 n, uint64 k,
   2504                           double alpha,
   2505                           const DeviceMemory<std::complex<double>> &a, int lda,
   2506                           double beta, DeviceMemory<std::complex<double>> *c,
   2507                           int ldc) {
   2508   return DoBlasInternal(wrap::cublasZherk, stream,
   2509                         true /* = pointer_mode_host */,
   2510                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
   2511                         k, CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
   2512                         &beta, CUDAComplex(CUDAMemoryMutable(c)), ldc);
   2513 }
   2514 
   2515 bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
   2516                            blas::Transpose trans, uint64 n, uint64 k,
   2517                            std::complex<float> alpha,
   2518                            const DeviceMemory<std::complex<float>> &a, int lda,
   2519                            const DeviceMemory<std::complex<float>> &b, int ldb,
   2520                            float beta, DeviceMemory<std::complex<float>> *c,
   2521                            int ldc) {
   2522   return DoBlasInternal(wrap::cublasCher2k, stream,
   2523                         true /* = pointer_mode_host */,
   2524                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
   2525                         k, CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
   2526                         CUDAComplex(CUDAMemory(b)), ldb, &beta,
   2527                         CUDAComplex(CUDAMemoryMutable(c)), ldc);
   2528 }
   2529 
   2530 bool CUDABlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
   2531                            blas::Transpose trans, uint64 n, uint64 k,
   2532                            std::complex<double> alpha,
   2533                            const DeviceMemory<std::complex<double>> &a, int lda,
   2534                            const DeviceMemory<std::complex<double>> &b, int ldb,
   2535                            double beta, DeviceMemory<std::complex<double>> *c,
   2536                            int ldc) {
   2537   return DoBlasInternal(wrap::cublasZher2k, stream,
   2538                         true /* = pointer_mode_host */,
   2539                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
   2540                         k, CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
   2541                         CUDAComplex(CUDAMemory(b)), ldb, &beta,
   2542                         CUDAComplex(CUDAMemoryMutable(c)), ldc);
   2543 }
   2544 
   2545 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
   2546                           blas::UpperLower uplo, uint64 m, uint64 n,
   2547                           float alpha, const DeviceMemory<float> &a, int lda,
   2548                           const DeviceMemory<float> &b, int ldb, float beta,
   2549                           DeviceMemory<float> *c, int ldc) {
   2550   return DoBlasInternal(
   2551       wrap::cublasSsymm, stream, true /* = pointer_mode_host */,
   2552       CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n, &alpha, CUDAMemory(a),
   2553       lda, CUDAMemory(b), ldb, &beta, CUDAMemoryMutable(c), ldc);
   2554 }
   2555 
   2556 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
   2557                           blas::UpperLower uplo, uint64 m, uint64 n,
   2558                           double alpha, const DeviceMemory<double> &a, int lda,
   2559                           const DeviceMemory<double> &b, int ldb, double beta,
   2560                           DeviceMemory<double> *c, int ldc) {
   2561   return DoBlasInternal(
   2562       wrap::cublasDsymm, stream, true /* = pointer_mode_host */,
   2563       CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n, &alpha, CUDAMemory(a),
   2564       lda, CUDAMemory(b), ldb, &beta, CUDAMemoryMutable(c), ldc);
   2565 }
   2566 
   2567 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
   2568                           blas::UpperLower uplo, uint64 m, uint64 n,
   2569                           std::complex<float> alpha,
   2570                           const DeviceMemory<std::complex<float>> &a, int lda,
   2571                           const DeviceMemory<std::complex<float>> &b, int ldb,
   2572                           std::complex<float> beta,
   2573                           DeviceMemory<std::complex<float>> *c, int ldc) {
   2574   return DoBlasInternal(
   2575       wrap::cublasCsymm, stream, true /* = pointer_mode_host */,
   2576       CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n, CUDAComplex(&alpha),
   2577       CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(b)), ldb,
   2578       CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(c)), ldc);
   2579 }
   2580 
   2581 bool CUDABlas::DoBlasSymm(Stream *stream, blas::Side side,
   2582                           blas::UpperLower uplo, uint64 m, uint64 n,
   2583                           std::complex<double> alpha,
   2584                           const DeviceMemory<std::complex<double>> &a, int lda,
   2585                           const DeviceMemory<std::complex<double>> &b, int ldb,
   2586                           std::complex<double> beta,
   2587                           DeviceMemory<std::complex<double>> *c, int ldc) {
   2588   return DoBlasInternal(
   2589       wrap::cublasZsymm, stream, true /* = pointer_mode_host */,
   2590       CUDABlasSide(side), CUDABlasUpperLower(uplo), m, n, CUDAComplex(&alpha),
   2591       CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemory(b)), ldb,
   2592       CUDAComplex(&beta), CUDAComplex(CUDAMemoryMutable(c)), ldc);
   2593 }
   2594 
   2595 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
   2596                           blas::Transpose trans, uint64 n, uint64 k,
   2597                           float alpha, const DeviceMemory<float> &a, int lda,
   2598                           float beta, DeviceMemory<float> *c, int ldc) {
   2599   return DoBlasInternal(
   2600       wrap::cublasSsyrk, stream, true /* = pointer_mode_host */,
   2601       CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, k, &alpha,
   2602       CUDAMemory(a), lda, &beta, CUDAMemoryMutable(c), ldc);
   2603 }
   2604 
   2605 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
   2606                           blas::Transpose trans, uint64 n, uint64 k,
   2607                           double alpha, const DeviceMemory<double> &a, int lda,
   2608                           double beta, DeviceMemory<double> *c, int ldc) {
   2609   return DoBlasInternal(
   2610       wrap::cublasDsyrk, stream, true /* = pointer_mode_host */,
   2611       CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, k, &alpha,
   2612       CUDAMemory(a), lda, &beta, CUDAMemoryMutable(c), ldc);
   2613 }
   2614 
   2615 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
   2616                           blas::Transpose trans, uint64 n, uint64 k,
   2617                           std::complex<float> alpha,
   2618                           const DeviceMemory<std::complex<float>> &a, int lda,
   2619                           std::complex<float> beta,
   2620                           DeviceMemory<std::complex<float>> *c, int ldc) {
   2621   return DoBlasInternal(
   2622       wrap::cublasCsyrk, stream, true /* = pointer_mode_host */,
   2623       CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, k,
   2624       CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(&beta),
   2625       CUDAComplex(CUDAMemoryMutable(c)), ldc);
   2626 }
   2627 
   2628 bool CUDABlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
   2629                           blas::Transpose trans, uint64 n, uint64 k,
   2630                           std::complex<double> alpha,
   2631                           const DeviceMemory<std::complex<double>> &a, int lda,
   2632                           std::complex<double> beta,
   2633                           DeviceMemory<std::complex<double>> *c, int ldc) {
   2634   return DoBlasInternal(
   2635       wrap::cublasZsyrk, stream, true /* = pointer_mode_host */,
   2636       CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, k,
   2637       CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(&beta),
   2638       CUDAComplex(CUDAMemoryMutable(c)), ldc);
   2639 }
   2640 
   2641 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
   2642                            blas::Transpose trans, uint64 n, uint64 k,
   2643                            float alpha, const DeviceMemory<float> &a, int lda,
   2644                            const DeviceMemory<float> &b, int ldb, float beta,
   2645                            DeviceMemory<float> *c, int ldc) {
   2646   return DoBlasInternal(
   2647       wrap::cublasSsyr2k, stream, true /* = pointer_mode_host */,
   2648       CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, k, &alpha,
   2649       CUDAMemory(a), lda, CUDAMemory(b), ldb, &beta, CUDAMemoryMutable(c), ldc);
   2650 }
   2651 
   2652 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
   2653                            blas::Transpose trans, uint64 n, uint64 k,
   2654                            double alpha, const DeviceMemory<double> &a, int lda,
   2655                            const DeviceMemory<double> &b, int ldb, double beta,
   2656                            DeviceMemory<double> *c, int ldc) {
   2657   return DoBlasInternal(
   2658       wrap::cublasDsyr2k, stream, true /* = pointer_mode_host */,
   2659       CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n, k, &alpha,
   2660       CUDAMemory(a), lda, CUDAMemory(b), ldb, &beta, CUDAMemoryMutable(c), ldc);
   2661 }
   2662 
   2663 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
   2664                            blas::Transpose trans, uint64 n, uint64 k,
   2665                            std::complex<float> alpha,
   2666                            const DeviceMemory<std::complex<float>> &a, int lda,
   2667                            const DeviceMemory<std::complex<float>> &b, int ldb,
   2668                            std::complex<float> beta,
   2669                            DeviceMemory<std::complex<float>> *c, int ldc) {
   2670   return DoBlasInternal(wrap::cublasCsyr2k, stream,
   2671                         true /* = pointer_mode_host */,
   2672                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
   2673                         k, CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
   2674                         CUDAComplex(CUDAMemory(b)), ldb, CUDAComplex(&beta),
   2675                         CUDAComplex(CUDAMemoryMutable(c)), ldc);
   2676 }
   2677 
   2678 bool CUDABlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
   2679                            blas::Transpose trans, uint64 n, uint64 k,
   2680                            std::complex<double> alpha,
   2681                            const DeviceMemory<std::complex<double>> &a, int lda,
   2682                            const DeviceMemory<std::complex<double>> &b, int ldb,
   2683                            std::complex<double> beta,
   2684                            DeviceMemory<std::complex<double>> *c, int ldc) {
   2685   return DoBlasInternal(wrap::cublasZsyr2k, stream,
   2686                         true /* = pointer_mode_host */,
   2687                         CUDABlasUpperLower(uplo), CUDABlasTranspose(trans), n,
   2688                         k, CUDAComplex(&alpha), CUDAComplex(CUDAMemory(a)), lda,
   2689                         CUDAComplex(CUDAMemory(b)), ldb, CUDAComplex(&beta),
   2690                         CUDAComplex(CUDAMemoryMutable(c)), ldc);
   2691 }
   2692 
   2693 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
   2694                           blas::UpperLower uplo, blas::Transpose transa,
   2695                           blas::Diagonal diag, uint64 m, uint64 n, float alpha,
   2696                           const DeviceMemory<float> &a, int lda,
   2697                           DeviceMemory<float> *b, int ldb) {
   2698   return DoBlasInternal(
   2699       wrap::cublasStrmm, stream, true /* = pointer_mode_host */,
   2700       CUDABlasSide(side), CUDABlasUpperLower(uplo), CUDABlasTranspose(transa),
   2701       CUDABlasDiagonal(diag), m, n, &alpha, CUDAMemory(a), lda,
   2702       CUDAMemoryMutable(b), ldb, CUDAMemoryMutable(b), ldb);
   2703 }
   2704 
   2705 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
   2706                           blas::UpperLower uplo, blas::Transpose transa,
   2707                           blas::Diagonal diag, uint64 m, uint64 n, double alpha,
   2708                           const DeviceMemory<double> &a, int lda,
   2709                           DeviceMemory<double> *b, int ldb) {
   2710   return DoBlasInternal(
   2711       wrap::cublasDtrmm, stream, true /* = pointer_mode_host */,
   2712       CUDABlasSide(side), CUDABlasUpperLower(uplo), CUDABlasTranspose(transa),
   2713       CUDABlasDiagonal(diag), m, n, &alpha, CUDAMemory(a), lda,
   2714       CUDAMemoryMutable(b), ldb, CUDAMemoryMutable(b), ldb);
   2715 }
   2716 
   2717 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
   2718                           blas::UpperLower uplo, blas::Transpose transa,
   2719                           blas::Diagonal diag, uint64 m, uint64 n,
   2720                           std::complex<float> alpha,
   2721                           const DeviceMemory<std::complex<float>> &a, int lda,
   2722                           DeviceMemory<std::complex<float>> *b, int ldb) {
   2723   return DoBlasInternal(
   2724       wrap::cublasCtrmm, stream, true /* = pointer_mode_host */,
   2725       CUDABlasSide(side), CUDABlasUpperLower(uplo), CUDABlasTranspose(transa),
   2726       CUDABlasDiagonal(diag), m, n, CUDAComplex(&alpha),
   2727       CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemoryMutable(b)), ldb,
   2728       CUDAComplex(CUDAMemoryMutable(b)), ldb);
   2729 }
   2730 
   2731 bool CUDABlas::DoBlasTrmm(Stream *stream, blas::Side side,
   2732                           blas::UpperLower uplo, blas::Transpose transa,
   2733                           blas::Diagonal diag, uint64 m, uint64 n,
   2734                           std::complex<double> alpha,
   2735                           const DeviceMemory<std::complex<double>> &a, int lda,
   2736                           DeviceMemory<std::complex<double>> *b, int ldb) {
   2737   return DoBlasInternal(
   2738       wrap::cublasZtrmm, stream, true /* = pointer_mode_host */,
   2739       CUDABlasSide(side), CUDABlasUpperLower(uplo), CUDABlasTranspose(transa),
   2740       CUDABlasDiagonal(diag), m, n, CUDAComplex(&alpha),
   2741       CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemoryMutable(b)), ldb,
   2742       CUDAComplex(CUDAMemoryMutable(b)), ldb);
   2743 }
   2744 
   2745 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
   2746                           blas::UpperLower uplo, blas::Transpose transa,
   2747                           blas::Diagonal diag, uint64 m, uint64 n, float alpha,
   2748                           const DeviceMemory<float> &a, int lda,
   2749                           DeviceMemory<float> *b, int ldb) {
   2750   return DoBlasInternal(wrap::cublasStrsm, stream,
   2751                         true /* = pointer_mode_host */, CUDABlasSide(side),
   2752                         CUDABlasUpperLower(uplo), CUDABlasTranspose(transa),
   2753                         CUDABlasDiagonal(diag), m, n, &alpha, CUDAMemory(a),
   2754                         lda, CUDAMemoryMutable(b), ldb);
   2755 }
   2756 
   2757 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
   2758                           blas::UpperLower uplo, blas::Transpose transa,
   2759                           blas::Diagonal diag, uint64 m, uint64 n, double alpha,
   2760                           const DeviceMemory<double> &a, int lda,
   2761                           DeviceMemory<double> *b, int ldb) {
   2762   return DoBlasInternal(wrap::cublasDtrsm, stream,
   2763                         true /* = pointer_mode_host */, CUDABlasSide(side),
   2764                         CUDABlasUpperLower(uplo), CUDABlasTranspose(transa),
   2765                         CUDABlasDiagonal(diag), m, n, &alpha, CUDAMemory(a),
   2766                         lda, CUDAMemoryMutable(b), ldb);
   2767 }
   2768 
   2769 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
   2770                           blas::UpperLower uplo, blas::Transpose transa,
   2771                           blas::Diagonal diag, uint64 m, uint64 n,
   2772                           std::complex<float> alpha,
   2773                           const DeviceMemory<std::complex<float>> &a, int lda,
   2774                           DeviceMemory<std::complex<float>> *b, int ldb) {
   2775   return DoBlasInternal(
   2776       wrap::cublasCtrsm, stream, true /* = pointer_mode_host */,
   2777       CUDABlasSide(side), CUDABlasUpperLower(uplo), CUDABlasTranspose(transa),
   2778       CUDABlasDiagonal(diag), m, n, CUDAComplex(&alpha),
   2779       CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemoryMutable(b)), ldb);
   2780 }
   2781 
   2782 bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
   2783                           blas::UpperLower uplo, blas::Transpose transa,
   2784                           blas::Diagonal diag, uint64 m, uint64 n,
   2785                           std::complex<double> alpha,
   2786                           const DeviceMemory<std::complex<double>> &a, int lda,
   2787                           DeviceMemory<std::complex<double>> *b, int ldb) {
   2788   return DoBlasInternal(
   2789       wrap::cublasZtrsm, stream, true /* = pointer_mode_host */,
   2790       CUDABlasSide(side), CUDABlasUpperLower(uplo), CUDABlasTranspose(transa),
   2791       CUDABlasDiagonal(diag), m, n, CUDAComplex(&alpha),
   2792       CUDAComplex(CUDAMemory(a)), lda, CUDAComplex(CUDAMemoryMutable(b)), ldb);
   2793 }
   2794 
   2795 }  // namespace cuda
   2796 
   2797 namespace gpu = ::perftools::gputools;
   2798 
   2799 void initialize_cublas() {
   2800   gpu::port::Status status =
   2801       gpu::PluginRegistry::Instance()
   2802           ->RegisterFactory<gpu::PluginRegistry::BlasFactory>(
   2803               gpu::cuda::kCudaPlatformId, gpu::cuda::kCuBlasPlugin, "cuBLAS",
   2804               [](gpu::internal::StreamExecutorInterface
   2805                      *parent) -> gpu::blas::BlasSupport * {
   2806                 gpu::cuda::CUDAExecutor *cuda_executor =
   2807                     dynamic_cast<gpu::cuda::CUDAExecutor *>(parent);
   2808                 if (cuda_executor == nullptr) {
   2809                   LOG(ERROR)
   2810                       << "Attempting to initialize an instance of the cuBLAS "
   2811                       << "support library with a non-CUDA StreamExecutor";
   2812                   return nullptr;
   2813                 }
   2814 
   2815                 gpu::cuda::CUDABlas *blas =
   2816                     new gpu::cuda::CUDABlas(cuda_executor);
   2817                 if (!blas->Init()) {
   2818                   // Note: Init() will log a more specific error.
   2819                   delete blas;
   2820                   return nullptr;
   2821                 }
   2822                 return blas;
   2823               });
   2824 
   2825   if (!status.ok()) {
   2826     LOG(ERROR) << "Unable to register cuBLAS factory: "
   2827                << status.error_message();
   2828   }
   2829 
   2830   gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId,
   2831                                                      gpu::PluginKind::kBlas,
   2832                                                      gpu::cuda::kCuBlasPlugin);
   2833 }
   2834 
   2835 }  // namespace gputools
   2836 }  // namespace perftools
   2837 
// Run cuBLAS plugin registration automatically at module load time.
REGISTER_MODULE_INITIALIZER(register_cublas,
                            { perftools::gputools::initialize_cublas(); });
   2840