/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// CUDA-specific support for BLAS functionality -- this wraps the cuBLAS
// library capabilities and is only included into CUDA implementation code, so
// it will not introduce CUDA headers into other code.

#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_

#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/lib/stringpiece.h"
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/platform/thread_annotations.h"
#include "tensorflow/stream_executor/plugin_registry.h"

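// Forward declaration of the cuBLAS handle type, so that this header does not
// need to pull in cuBLAS (or other CUDA) headers; see the note at the top of
// this file.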
typedef struct cublasContext *cublasHandle_t;

namespace perftools {
namespace gputools {

class Stream;

namespace cuda {

// Opaque and unique identifier for the cuBLAS plugin.
extern const PluginId kCuBlasPlugin;

class CUDAExecutor;

// BLAS plugin for the CUDA platform via the cuBLAS library.
//
// This satisfies the platform-agnostic BlasSupport interface.
//
// Note that the cuBLAS handle that this encapsulates is implicitly tied to the
// context (and, as a result, the device) that the parent CUDAExecutor is tied
// to. This simply happens as an artifact of creating the cuBLAS handle when a
// CUDA context is active.
//
// Thread-safe post-initialization.
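//
// Example use via the platform-agnostic Stream API (an illustrative sketch
// only; it assumes a Stream `stream` bound to a CUDA device and previously
// allocated DeviceMemory<float> buffers `a`, `b`, and `c`):
//
//   stream.ThenBlasGemm(blas::Transpose::kNoTranspose,
//                       blas::Transpose::kNoTranspose, m, n, k,
//                       /*alpha=*/1.0f, a, lda, b, ldb,
//                       /*beta=*/0.0f, &c, ldc);
//
// The Stream routes such calls to this plugin through the BlasSupport
// interface when the underlying platform is CUDA.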
class CUDABlas : public blas::BlasSupport {
 public:
  explicit CUDABlas(CUDAExecutor *parent);

  // Allocates a cuBLAS handle.
  bool Init();

  // Releases the cuBLAS handle, if present.
  ~CUDABlas() override;

  TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES

 private:
  // Tells cuBLAS to enqueue the BLAS operation onto a particular Stream.
  //
  // cuBLAS is stateful, and can only be associated with one stream (in order
  // to enqueue dispatch) at a given time. As a result, this generally must be
  // invoked before calling into cuBLAS.
  bool SetStream(Stream *stream) EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // A helper function that calls the real cuBLAS function together with error
  // handling.
  //
  // cublas_func:         cuBLAS function pointer.
  // stream:              Stream to enqueue the BLAS operation onto.
  // pointer_mode_host:   Indicates whether the pointer to a scalar value is on
  //                      the host (true) or the device (false).
  // err_on_failure:      Whether to print an error if the cuBLAS function
  //                      fails.
  // use_tensor_op_math:  Whether to allow cuBLAS to use tensor op math
  //                      (Tensor Cores) for this call.
  // args:                Arguments of the cuBLAS function.
  template <typename FuncT, typename... Args>
  bool DoBlasInternalImpl(FuncT cublas_func, Stream *stream,
                          bool pointer_mode_host, bool err_on_failure,
                          bool use_tensor_op_math, Args... args);

  // Convenience functions that call DoBlasInternalImpl with different values
  // for err_on_failure.
  template <typename FuncT, typename... Args>
  bool DoBlasInternal(FuncT cublas_func, Stream *stream, bool pointer_mode_host,
                      Args... args) {
    return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
                              /*err_on_failure=*/true,
                              /*use_tensor_op_math=*/false, args...);
  }
  template <typename FuncT, typename... Args>
  bool DoBlasInternalFailureOK(FuncT cublas_func, Stream *stream,
                               bool pointer_mode_host, Args... args) {
    // Tensor op math is hard-coded off in this path, but it can still be
    // enabled with a specific algorithm choice as in
    // DoBlasGemmWithAlgorithmImpl().
    return DoBlasInternalImpl(cublas_func, stream, pointer_mode_host,
                              /*err_on_failure=*/false,
                              /*use_tensor_op_math=*/false, args...);
  }

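  // A typical call site in the corresponding cuda_blas.cc looks roughly like
  // the sketch below (illustrative only; wrap::cublasSscal and
  // CUDAMemoryMutable come from the implementation file, not this header):
  //
  //   bool CUDABlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
  //                             DeviceMemory<float> *x, int incx) {
  //     return DoBlasInternal(wrap::cublasSscal, stream,
  //                           /*pointer_mode_host=*/true, elem_count, &alpha,
  //                           CUDAMemoryMutable(x), incx);
  //   }
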
  // A helper function to implement DoBlasGemmBatched interfaces for generic
  // types.
  template <typename T, typename FuncT>
  port::Status DoBlasGemmBatchedInternal(
      FuncT cublas_func, Stream *stream, blas::Transpose transa,
      blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
      const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
      const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta,
      const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
      int batch_count, ScratchAllocator *scratch_allocator);

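  // For reference, client code typically reaches the batched path through the
  // Stream API, roughly as in this sketch (illustrative only; it assumes
  // port::ArraySlice<DeviceMemory<float> *> containers a_ptrs, b_ptrs, and
  // c_ptrs that were populated by the caller):
  //
  //   stream.ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
  //                              blas::Transpose::kNoTranspose, m, n, k,
  //                              /*alpha=*/1.0f, a_ptrs, lda, b_ptrs, ldb,
  //                              /*beta=*/0.0f, c_ptrs, ldc, batch_count);
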
  // Helper function for implementing DoBlasGemmWithAlgorithm.
  //
  // We take alpha and beta by const reference because CompT might be
  // Eigen::half, and we want to avoid pulling in a dependency on Eigen.  When
  // we pass the references to cuBLAS, we essentially reinterpret_cast to
  // __half, which is safe because Eigen::half inherits from __half.
  template <typename InT, typename OutT, typename CompT>
  bool DoBlasGemmWithAlgorithmImpl(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a,
      int lda, const DeviceMemory<InT> &b, int ldb, const CompT &beta,
      DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
      blas::AlgorithmType algorithm,
      blas::ProfileResult *output_profile_result);

  // Helper function for implementing DoBlasGemmWithProfiling.
  template <typename T, typename ParamType>
  bool DoBlasGemmWithProfilingImpl(
      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
      uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
      int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
      DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result);

  // Helper function for implementing DoBlasGemvWithProfiling.
  template <typename T>
  bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
                                   uint64 m, uint64 n, const T &alpha,
                                   const DeviceMemory<T> &a, int lda,
                                   const DeviceMemory<T> &x, int incx,
                                   const T &beta, DeviceMemory<T> *y, int incy,
                                   blas::ProfileResult *output_profile_result);

  // Mutex that guards the cuBLAS handle for this device.
  mutex mu_;

  // CUDAExecutor which instantiated this CUDABlas.
  // Immutable post-initialization.
  CUDAExecutor *parent_;

  // cuBLAS library handle on the device.
  cublasHandle_t blas_ GUARDED_BY(mu_);

  SE_DISALLOW_COPY_AND_ASSIGN(CUDABlas);
};

}  // namespace cuda
}  // namespace gputools
}  // namespace perftools

#endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_