/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// CUDA-specific support for BLAS functionality -- this wraps the cuBLAS library
// capabilities, and is only included into CUDA implementation code -- it will
// not introduce cuda headers into other code.

#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_

#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/lib/stringpiece.h"
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/platform/thread_annotations.h"
#include "tensorflow/stream_executor/plugin_registry.h"

// Forward-declare the cuBLAS handle type so this header does not have to pull
// in cublas.h (keeping CUDA headers out of non-CUDA translation units).
typedef struct cublasContext *cublasHandle_t;

namespace perftools {
namespace gputools {

class Stream;

namespace cuda {

// Opaque and unique identifier for the cuBLAS plugin.
extern const PluginId kCuBlasPlugin;

class CUDAExecutor;

// BLAS plugin for CUDA platform via cuBLAS library.
//
// This satisfies the platform-agnostic BlasSupport interface.
//
// Note that the cuBLAS handle that this encapsulates is implicitly tied to the
// context (and, as a result, the device) that the parent CUDAExecutor is tied
// to. This simply happens as an artifact of creating the cuBLAS handle when a
// CUDA context is active.
//
// Thread-safe post-initialization.
class CUDABlas : public blas::BlasSupport {
 public:
  // Does not take ownership of `parent`; the parent CUDAExecutor must outlive
  // this object. Call Init() before any BLAS operation.
  explicit CUDABlas(CUDAExecutor *parent);

  // Allocates a cuBLAS handle.
  bool Init();

  // Releases the cuBLAS handle, if present.
  ~CUDABlas() override;

  // Declares the overrides of the blas::BlasSupport virtual methods
  // (DoBlasGemm, DoBlasAxpy, etc.); definitions live in the .cc file.
  TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES

 private:
  // Tells cuBLAS to enqueue the BLAS operation onto a particular Stream.
  //
  // cuBLAS is stateful, and can only be associated with one stream (in order
  // to enqueue dispatch) at a given time. As a result, this generally must be
  // invoked before calling into cuBLAS.
  bool SetStream(Stream *stream) EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // A helper function that calls the real cuBLAS function together with error
  // handling.
  //
  // cublas_func:        cuBLAS function pointer.
  // stream:             Stream to enqueue the BLAS operation onto.
  // pointer_mode_host:  Indicate if the pointer to a scalar value is from host
  //                     (true) or device (false).
  // err_on_failure:     Whether to print an error if the cublas function
  //                     fails.
  // use_tensor_op_math: Whether to enable tensor-op math for this call.
  // args:               Arguments of cuBLAS function.
  //
  // Returns true iff the cuBLAS call reported success.
  template <typename FuncT, typename... Args>
  bool DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
                          bool pointer_mode_host, bool err_on_failure,
                          bool use_tensor_op_math, Args... args);

  // Convenience functions that call DoBlasInternalImpl with different values
  // for err_on_failure.
  template <typename FuncT, typename... Args>
  bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host,
                      Args... args) {
    return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
                              /*err_on_failure=*/true, /*use_tensor_ops=*/false,
                              args...);
  }
  template <typename FuncT, typename... Args>
  bool DoBlasInternalFailureOK(FuncT cublas_func, Stream *stream,
                               bool pointer_mode_host, Args... args) {
    // Tensor ops are hard-coded off in this path, but can still be enabled with
    // a specific algorithm choice as in DoBlasGemmWithAlgorithmImpl().
    return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
                              /*err_on_failure=*/false,
                              /*use_tensor_ops=*/false, args...);
  }

  // A helper function to implement DoBlasGemmBatched interfaces for generic
  // types.
  template <typename T, typename FuncT>
  port::Status DoBlasGemmBatchedInternal(
      FuncT cublas_func, Stream *stream, blas::Transpose transa,
      blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
      const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
      const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta,
      const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);

  // Helper function for implementing DoBlasGemmWithAlgorithm.
  //
  // We take alpha and beta by const reference because T might be Eigen::half,
  // and we want to avoid pulling in a dependency on Eigen. When we pass the
  // references to cublas, we essentially reinterpret_cast to __half, which is
  // safe because Eigen::half inherits from __half.
  template <typename InT, typename OutT, typename CompT>
  bool DoBlasGemmWithAlgorithmImpl(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a,
      int lda, const DeviceMemory<InT> &b, int ldb, const CompT &beta,
      DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
      blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);

  // Helper function for implementing DoBlasGemmWithProfiling.
  template <typename T, typename ParamType>
  bool DoBlasGemmWithProfilingImpl(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
      int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
      DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result);

  // Helper function for implementing DoBlasGemvWithProfiling.
  template <typename T>
  bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
                                   uint64 m, uint64 n, const T &alpha,
                                   const DeviceMemory<T> &a, int lda,
                                   const DeviceMemory<T> &x, int incx,
                                   const T &beta, DeviceMemory<T> *y, int incy,
                                   blas::ProfileResult *output_profile_result);

  // mutex that guards the cuBLAS handle for this device.
  mutex mu_;

  // CUDAExecutor which instantiated this CUDABlas.
  // Immutable post-initialization.
  CUDAExecutor *parent_;

  // cuBLAS library handle on the device.
  cublasHandle_t blas_ GUARDED_BY(mu_);

  SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas);
};

}  // namespace cuda
}  // namespace gputools
}  // namespace perftools

#endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_