/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Suite of datatypes to represent data-parallel kernel objects (code entities).
// Kernel is the untyped variant, whereas TypedKernel takes a type signature
// to do some template-based helper generation and give compile-time type
// checking for kernel launch parameters.
//
// Users typically don't see KernelBase; they see typed kernels, analogous to a
// typed function pointer. TypedKernels express their argument types via
// template parameters like so:
//
//  TypedKernel<DeviceMemory<int>*, int>
//
// Which expresses a data parallel kernel signature for:
//
//  void(int*, int);
//
// And for a const memory region:
//
//  TypedKernel<const DeviceMemory<int>&, int>
//
// Corresponds to a data parallel kernel signature for:
//
//  void(const int*, int)
//
// Note that kernels always have a void return type, so results typically must
// be memcpy'd from device memory to the host.
//
// Also note that a scalar integer residing in device memory and an array of
// integers residing in device memory have the same signature: DeviceMemory<T>.
// However, in the future, checks may be added for additional safety that arrays
// of minimum sizes are passed when those minimum sizes are contractually
// expected by the kernel.
//
// For user-defined types whose definitions are appropriately shared between the
// host code doing the launching and the kernel code being launched, the user
// defined types are similarly permitted to be expressed as residing in device
// memory:
//
//  TypedKernel<DeviceMemory<MyUserDefinedStructure>>
//
// And, when the alignment and padding are agreed upon, POD types will also be
// able to be passed by value; for example, it is a common idiom to specify a
// bunch of options simultaneously with a structure:
//
//  TypedKernel<MyOptionsStructurePassedByValue, DeviceMemory<float>>
//
// Which corresponds to a data parallel kernel signature like:
//
//  void(MyOptionsStructurePassedByValue value, float *result);
//
// Users typically won't need to type out the TypedKernel signature in full; it
// will be typedef'd by automatically generated code. For example, see
// perftools::gputools::executor_sample::VecReduceAddKernel.
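//
// As a purely illustrative sketch (the kernel name and parameter list below are
// hypothetical, not taken from any generated code), such a typedef might look
// like:
//
//  typedef TypedKernel<const DeviceMemory<float>&, DeviceMemory<float>*, int>
//      MyVecAddKernel;
//
// which corresponds to a data parallel kernel signature of:
//
//  void(const float*, float*, int);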

#ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
#define TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_

#include <array>
#include <memory>
#include <tuple>
#include <type_traits>
#include <vector>

#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/kernel_cache_config.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/lib/inlined_vector.h"
#include "tensorflow/stream_executor/lib/stringpiece.h"
#include "tensorflow/stream_executor/platform/port.h"

namespace perftools {
namespace gputools {

class DeviceMemoryBase;
template <typename ElemT>
class DeviceMemory;
class StreamExecutor;

namespace internal {
class KernelInterface;
}  // namespace internal

// KernelMetadata holds runtime-queryable attributes of a loaded kernel, such as
// registers allocated, shared memory used, etc.
// Not all platforms support reporting of all information, so each accessor
// returns false if the associated field is not populated in the underlying
// platform.
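//
// Example (a minimal sketch; assumes `kernel` is a KernelBase that has been
// loaded on a platform that reports register usage):
//
//  int regs = 0;
//  if (kernel.metadata().registers_per_thread(&regs)) {
//    // `regs` now holds the per-thread register count.
//  }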
class KernelMetadata {
 public:
  KernelMetadata()
      : has_registers_per_thread_(false), has_shared_memory_bytes_(false) {}

  // Returns the number of registers used per thread executing this kernel.
  bool registers_per_thread(int *registers_per_thread) const;

  // Sets the number of registers used per thread executing this kernel.
  void set_registers_per_thread(int registers_per_thread);

  // Returns the amount of [static] shared memory used per block executing this
  // kernel. Note that dynamic shared memory allocations are not (and cannot be)
  // reported here (since they're not specified until kernel launch time).
  bool shared_memory_bytes(int *shared_memory_bytes) const;

  // Sets the amount of [static] shared memory used per block executing this
  // kernel.
  void set_shared_memory_bytes(int shared_memory_bytes);

 private:
  // Holds the value returned by registers_per_thread above.
  bool has_registers_per_thread_;
  int registers_per_thread_;

  // Holds the value returned by shared_memory_bytes above.
  bool has_shared_memory_bytes_;
  int64 shared_memory_bytes_;
};

// A data-parallel kernel (code entity) for launching via the StreamExecutor,
// analogous to a void* device function pointer. See TypedKernel for the typed
// variant.
//
// Thread-compatible.
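//
// Example (a minimal sketch; assumes the kernel has already been loaded via its
// StreamExecutor, and that kernel_cache_config.h provides a kPreferShared
// value; the preference is only a hint to the runtime):
//
//  unsigned arity = kernel.Arity();  // number of parameters the kernel takes
//  kernel.SetPreferredCacheConfig(KernelCacheConfig::kPreferShared);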
class KernelBase {
 public:
  KernelBase(KernelBase &&from);

  // Constructs an "empty" (not-yet-loaded) kernel instance.
  //
  // parent is the StreamExecutor that will be responsible for loading the
  // implementation of this kernel. It must not be null.
  explicit KernelBase(StreamExecutor *parent);

  // Test-only constructor that can take a mock KernelInterface implementation.
  KernelBase(StreamExecutor *parent, internal::KernelInterface *implementation);

  // Releases resources associated with the kernel instance (i.e.
  // platform-specific implementation).
  ~KernelBase();

  // Returns the number of parameters that this kernel accepts. (Arity refers to
  // nullary, unary, ...).
  unsigned Arity() const;

  // Returns the StreamExecutor that represents the platform this kernel
  // executes upon.
  StreamExecutor *parent() const { return parent_; }

  // Returns a const pointer to the (opaque) platform-dependent implementation.
  const internal::KernelInterface *implementation() const {
    return implementation_.get();
  }

  // Returns a non-const pointer to the (opaque) platform-dependent
  // implementation.
  internal::KernelInterface *implementation() { return implementation_.get(); }

  void set_metadata(const KernelMetadata &metadata) { metadata_ = metadata; }

  const KernelMetadata &metadata() const { return metadata_; }

  // Sets the preferred cache configuration for a kernel. This is just a
  // suggestion to the runtime, and may not be honored during execution.
  void SetPreferredCacheConfig(KernelCacheConfig config);

  // Gets the preferred cache configuration for a kernel.
  KernelCacheConfig GetPreferredCacheConfig() const;

  void set_name(port::StringPiece name);
  const string &name() const { return name_; }
  const string &demangled_name() const { return demangled_name_; }

 private:
  // The StreamExecutor that loads this kernel object.
  StreamExecutor *parent_;

  // Implementation delegated to for platform-specific functionality.
  std::unique_ptr<internal::KernelInterface> implementation_;

  string name_;
  string demangled_name_;

  KernelMetadata metadata_;

  SE_DISALLOW_COPY_AND_ASSIGN(KernelBase);
};

// Whether T is a DeviceMemory-family pointer.
template <typename T>
struct IsDeviceMemoryPointer {
  static constexpr bool value = false;
};

template <typename U>
struct IsDeviceMemoryPointer<DeviceMemory<U> *> {
  static constexpr bool value = true;
};

template <>
struct IsDeviceMemoryPointer<DeviceMemoryBase *> {
  static constexpr bool value = true;
};

// Whether T is a DeviceMemory-family value-like thing (which includes a
// reference). This trait is useful because we pack values in the same manner as
// references.
template <typename T>
struct IsDeviceMemoryValueLike {
  static constexpr bool value = false;
};

template <typename U>
struct IsDeviceMemoryValueLike<DeviceMemory<U> &> {
  static constexpr bool value = true;
};

// We need to treat SharedDeviceMemory types differently than other DeviceMemory
// types (since they maintain no allocations), hence these specializations.
template <typename U>
struct IsDeviceMemoryValueLike<SharedDeviceMemory<U> &> {
  static constexpr bool value = false;
};

template <>
struct IsDeviceMemoryValueLike<DeviceMemoryBase &> {
  static constexpr bool value = true;
};

template <typename U>
struct IsDeviceMemoryValueLike<DeviceMemory<U>> {
  static constexpr bool value = true;
};

template <typename U>
struct IsDeviceMemoryValueLike<SharedDeviceMemory<U>> {
  static constexpr bool value = false;
};

template <>
struct IsDeviceMemoryValueLike<DeviceMemoryBase> {
  static constexpr bool value = true;
};

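// Whether T is a SharedDeviceMemory value or reference.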
template <typename U>
struct IsSharedDeviceMemory {
  static constexpr bool value = false;
};

template <typename U>
struct IsSharedDeviceMemory<SharedDeviceMemory<U> &> {
  static constexpr bool value = true;
};

template <typename U>
struct IsSharedDeviceMemory<SharedDeviceMemory<U>> {
  static constexpr bool value = true;
};

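// For illustration, the traits above evaluate like so (each of these would
// hold as a static_assert):
//
//  static_assert(IsDeviceMemoryPointer<DeviceMemory<float> *>::value, "");
//  static_assert(IsDeviceMemoryValueLike<DeviceMemory<float>>::value, "");
//  static_assert(IsSharedDeviceMemory<SharedDeviceMemory<float>>::value, "");
//  static_assert(!IsDeviceMemoryValueLike<SharedDeviceMemory<float>>::value, "");
//  static_assert(!IsDeviceMemoryPointer<float *>::value, "");
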
// Basic data about a kernel argument.
struct KernelArg {
  bool is_shared;
  const void *address;
  size_t size;
};

// An iterator for traversing all the arguments of a KernelArgsArray.
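//
// Example (sketch): walking the packed arguments of a KernelArgsArrayBase
// `args`:
//
//  KernelArgIterator it = args.arg_iterator();
//  while (it.has_next()) {
//    KernelArg arg = it.next();
//    if (arg.is_shared) {
//      // arg.size is a dynamic shared memory size in bytes; arg.address is
//      // null.
//    } else {
//      // arg.address points at the packed argument value, arg.size bytes long.
//    }
//  }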
class KernelArgIterator {
 public:
  KernelArgIterator(int number_of_argument_addresses,
                    int number_of_shared_memory_arguments,
                    const void *const *arg_addresses_data,
                    const size_t *arg_sizes_data,
                    const size_t *shmem_bytes_data,
                    const size_t *shmem_indices_data)
      : arg_index_(0),
        number_of_arguments_(number_of_argument_addresses +
                             number_of_shared_memory_arguments),
        arg_address_iter_(arg_addresses_data),
        arg_size_iter_(arg_sizes_data),
        shmem_bytes_iter_(shmem_bytes_data),
        shmem_indices_iter_(shmem_indices_data),
        shmem_indices_end_(shmem_indices_data +
                           number_of_shared_memory_arguments) {}

  // Returns true if another argument is present in the iterator.
  bool has_next() { return arg_index_ < number_of_arguments_; }

  // Returns the next argument in the iterator.
  //
  // Returns a default-constructed KernelArg if there is no next argument.
  KernelArg next() {
    KernelArg result = {};
    if (!has_next()) {
      return result;
    } else if ((shmem_indices_iter_ != shmem_indices_end_) &&
               (arg_index_ == *shmem_indices_iter_)) {
      result.is_shared = true;
      result.address = nullptr;
      result.size = *shmem_bytes_iter_;
      ++shmem_indices_iter_;
      ++shmem_bytes_iter_;
    } else {
      result.is_shared = false;
      result.address = *arg_address_iter_;
      result.size = *arg_size_iter_;
      ++arg_address_iter_;
      ++arg_size_iter_;
    }
    ++arg_index_;
    return result;
  }

 private:
  size_t arg_index_;
  size_t number_of_arguments_;
  const void *const *arg_address_iter_;
  const size_t *arg_size_iter_;
  const size_t *shmem_bytes_iter_;
  const size_t *shmem_indices_iter_;
  const size_t *const shmem_indices_end_;
};

// Base class for KernelArgsArray.
//
// Supports all the getter methods that do not depend on the compile-time number
// of arguments template parameter.
//
// This class exists as a way to pass kernel arguments to
// StreamExecutorInterface::Launch. That Launch method is virtual, so it can't
// be templated to accept any KernelArgsArray type; therefore a reference to
// this base type is passed instead.
//
// Performance is not a concern here because each of these methods will be
// called at most once per kernel launch. Past performance concerns with
// KernelArgsArray have been about the argument packing routines, which are
// called once per kernel argument. Those packing routines are now handled by
// the templated KernelArgsArray subclass of this class, where they can take
// advantage of compile-time knowledge of the number of arguments in order to
// be very efficient.
class KernelArgsArrayBase {
 public:
  virtual ~KernelArgsArrayBase() = default;

  // Gets the number of arguments added so far, including shared memory
  // arguments.
  virtual size_t number_of_arguments() const = 0;

  // Gets the total number of shared memory bytes added so far.
  virtual uint64 number_of_shared_bytes() const = 0;

  // Gets the list of argument addresses.
  virtual port::ArraySlice<const void *> argument_addresses() const = 0;

  // Gets an iterator to the arguments in the array.
  virtual KernelArgIterator arg_iterator() const = 0;
};

// A list of arguments for a kernel call.
//
// The template parameter kNumArgs is the maximum number of arguments which can
// be stored in the list.
//
// Contains a list of addresses for non-shared-memory arguments and a list of
// sizes for shared-memory arguments. Since the shared-memory arguments may be
// interspersed with the non-shared-memory arguments, it also stores a list of
// the indices at which the shared-memory arguments appeared.
//
// For example, if the argument address list contains {a, b, c, d, e}, the
// shared-memory arguments list contains the sizes of {A, B, C}, and the
// shared-memory indices list contains {0, 3, 5}, then the original list of
// arguments was {A, a, b, B, c, C, d, e}.
//
// This way of storing the arguments makes CUDA kernel calls efficient because
// they only require the argument address list and the total number of shared
// bytes. It also supports OpenCL kernel calls, which depend on the location and
// size of each shared-memory argument.
//
// Note that the code for adding arguments has been identified as a performance
// hotspot in some real-world applications, so this structure has been optimized
// for the performance of argument adding.
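//
// Example (sketch; `n` is a host integer and `out` is a DeviceMemoryBase
// holding the kernel's output buffer):
//
//  KernelArgsArray<3> args;
//  args.add_argument(n);                  // packed by address and size
//  args.add_device_memory_argument(out);  // packs the opaque device pointer
//  args.add_shared_bytes(4 * 1024);       // 4 KiB of dynamic shared memory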
template <size_t kNumArgs>
class KernelArgsArray : public KernelArgsArrayBase {
 public:
  explicit KernelArgsArray()
      : total_shared_memory_bytes_(0),
        number_of_argument_addresses_(0),
        number_of_shared_memory_arguments_(0) {}

  // Adds an argument to the list.
  //
  // Note that the address of the argument is stored, so the argument must not
  // go out of scope before this KernelArgsArray instance does.
  template <typename T>
  void add_argument(const T &arg) {
    argument_addresses_[number_of_argument_addresses_] =
        static_cast<const void *>(&arg);
    argument_sizes_[number_of_argument_addresses_] = sizeof(arg);
    ++number_of_argument_addresses_;
  }

  // Adds a device memory argument to the list.
  void add_device_memory_argument(const DeviceMemoryBase &arg) {
    const void **copy_ptr =
        &device_memory_opaque_pointers_[number_of_argument_addresses_];
    *copy_ptr = arg.opaque();
    argument_addresses_[number_of_argument_addresses_] = copy_ptr;
    argument_sizes_[number_of_argument_addresses_] = sizeof(void *);
    ++number_of_argument_addresses_;
  }

  // Adds a shared memory argument to the list.
  //
  // The only significant information about a shared argument is its size, so
  // that is the only parameter in this function.
  void add_shared_bytes(size_t number_of_bytes) {
    shared_memory_indices_[number_of_shared_memory_arguments_] =
        number_of_argument_addresses_ + number_of_shared_memory_arguments_;
    shared_memory_bytes_[number_of_shared_memory_arguments_] = number_of_bytes;
    ++number_of_shared_memory_arguments_;
    total_shared_memory_bytes_ += number_of_bytes;
  }

  // Gets the number of arguments added so far, including shared memory
  // arguments.
  size_t number_of_arguments() const override {
    return number_of_argument_addresses_ + number_of_shared_memory_arguments_;
  }

  // Gets the total number of shared memory bytes added so far.
  uint64 number_of_shared_bytes() const override {
    return total_shared_memory_bytes_;
  }

  // Gets the list of argument addresses.
  port::ArraySlice<const void *> argument_addresses() const override {
    return port::ArraySlice<const void *>(argument_addresses_.data(),
                                          number_of_argument_addresses_);
  }

  // Gets an iterator to the arguments in the array.
  KernelArgIterator arg_iterator() const override {
    return KernelArgIterator(
        number_of_argument_addresses_, number_of_shared_memory_arguments_,
        argument_addresses_.data(), argument_sizes_.data(),
        shared_memory_bytes_.data(), shared_memory_indices_.data());
  }

 private:
  // A place to store copies of opaque pointers from device memory arguments.
  std::array<const void *, kNumArgs> device_memory_opaque_pointers_;

  // Addresses for non-shared-memory arguments.
  std::array<const void *, kNumArgs> argument_addresses_;

  // Sizes for non-shared-memory arguments.
  std::array<size_t, kNumArgs> argument_sizes_;

  // Size in bytes for each shared memory argument.
  std::array<size_t, kNumArgs> shared_memory_bytes_;

  // Indices in the arguments array for shared memory arguments.
  std::array<size_t, kNumArgs> shared_memory_indices_;

  // Total of all shared memory sizes.
  size_t total_shared_memory_bytes_;

  // Number of significant entries in argument_addresses_ and argument_sizes_.
  size_t number_of_argument_addresses_;

  // Number of significant entries in shared_memory_bytes_ and
  // shared_memory_indices_.
  size_t number_of_shared_memory_arguments_;
};

// Typed variant of KernelBase, like a typed device function pointer. See the
// file comment for details and example usage.
//
// This class contains template metaprogramming magic to check that the
// parameters passed to a kernel launch are acceptable, and subsequently pack
// them into a form usable by the StreamExecutorInterface implementation (CUDA
// and OpenCL, for example, both bind void*s with associated sizes as kernel
// arguments).
//
// Thread-compatible.
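//
// Example (informal; `executor` stands in for a StreamExecutor*): a kernel
// declared as
//
//  TypedKernel<const DeviceMemory<float> &, float> axpy_kernel(&executor);
//
// packs its DeviceMemory<float> parameter through
// KernelArgsArray::add_device_memory_argument() and its float parameter through
// KernelArgsArray::add_argument(). A TypedKernel declared with a raw float*
// parameter would instead trip the "cannot pass raw pointer to the device"
// static_assert below. The packing itself is driven by Stream (a friend of this
// class); users do not call it directly.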
template <typename... Params>
class TypedKernel : public KernelBase {
 public:
  static constexpr size_t kNumberOfParameters = sizeof...(Params);

  // Delegates to KernelBase::KernelBase(); see that constructor.
  explicit TypedKernel(StreamExecutor *parent) : KernelBase(parent) {}

  // Test-only constructor that can take a mock KernelInterface implementation.
  // Takes ownership of implementation; it should not be null.
  TypedKernel(StreamExecutor *parent, internal::KernelInterface *implementation)
      : KernelBase(parent, implementation) {}

 private:
  // Stream needs access to the specific parameter-packing functionality that
  // the TypedKernel provides for its corresponding type signature (and no other
  // type signatures).
  friend class Stream;

  // This is the main entry point into the magic. Packs the parameters (which
  // must type check against the class template) into the args and sizes
  // arrays.
  //
  // Const refs are taken as parameters on all of the handlers to avoid
  // implicit type promotion of integers.
  //
  // WARNING: as a performance optimization this method may store pointers to
  // some of the input parameters in the kernel args structure, so any params
  // passed into this method must live at least as long as the kernel args
  // structure.
  void PackParams(KernelArgsArray<kNumberOfParameters> *args,
                  Params &... params) const {
    PackOneParam(args, params...);
  }

  template <typename T, typename... RestOfParams>
  void PackOneParam(KernelArgsArray<kNumberOfParameters> *args, const T &arg,
                    const RestOfParams &... rest) const {
    PackOneParam(args, arg);
    PackOneParam(args, rest...);
  }

  // Packs one (non-DeviceMemoryBase) parameter into the arg and sizes array.
  // The enable_if<> is for excluding DeviceMemoryBase args, which have a
  // separate implementation below.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, const T &arg,
      typename std::enable_if<!IsDeviceMemoryValueLike<T>::value &&
                              !IsDeviceMemoryPointer<T>::value &&
                              !IsSharedDeviceMemory<T>::value>::type * =
          nullptr) const {
    static_assert(!std::is_pointer<T>::value,
                  "cannot pass raw pointer to the device");
    static_assert(!std::is_convertible<T, DeviceMemoryBase>::value,
                  "cannot pass device memory as a normal value");
    args->add_argument(arg);
  }

  // DeviceMemoryBase family reference override.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, const T &arg,
      typename std::enable_if<IsDeviceMemoryValueLike<T>::value>::type * =
          nullptr) const {
    args->add_device_memory_argument(arg);
  }

  // DeviceMemoryBase family pointer override.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, T arg,
      typename std::enable_if<IsDeviceMemoryPointer<T>::value>::type * =
          nullptr) const {
    DeviceMemoryBase *ptr = static_cast<DeviceMemoryBase *>(arg);
    args->add_device_memory_argument(*ptr);
  }

  // Dynamic shared device memory has a size, but no associated allocation on
  // the host; internally, the device will allocate storage.
  template <typename T>
  void PackOneParam(
      KernelArgsArray<kNumberOfParameters> *args, T arg,
      typename std::enable_if<IsSharedDeviceMemory<T>::value>::type * =
          nullptr) const {
    args->add_shared_bytes(arg.size());
  }

  // Base case for variadic template expansion - nothing to do!
  void PackOneParam(KernelArgsArray<kNumberOfParameters> *args) const {}

  SE_DISALLOW_COPY_AND_ASSIGN(TypedKernel);
};

// Template metaprogramming helper type that helps us produce better error
// messages at compile time when there are mismatches between the parameter
// type list and the argument type list.
template <typename ParamTuple, typename ArgTuple>
struct KernelInvocationChecker {
  // Whether the parameter tuple and argument tuple match in length.
  static constexpr bool kLengthMatches =
      std::tuple_size<ParamTuple>::value == std::tuple_size<ArgTuple>::value;

  // The (matching) length of the parameters and arguments type lists.
  static constexpr int kTupleLength =
      static_cast<int>(std::tuple_size<ArgTuple>::value);

  // Helper trait to say whether the parameter wants a DeviceMemory-reference
  // compatible type. This is for inexact type matches, so that it doesn't have
  // to be precisely a const DeviceMemory<T>&, but can also be a value that
  // represents the same.
  template <typename ParamType, typename ArgType>
  struct IsCompatibleDeviceMemoryRef {
    static constexpr bool value = false;
  };

  // See type trait definition above.
  template <typename U>
  struct IsCompatibleDeviceMemoryRef<const DeviceMemory<U> &, DeviceMemory<U>> {
    static constexpr bool value = true;
  };

  // See type trait definition above.
  template <typename U>
  struct IsCompatibleDeviceMemoryRef<const SharedDeviceMemory<U> &,
                                     SharedDeviceMemory<U>> {
    static constexpr bool value = true;
  };

  // Returns whether ParamT and ArgT are compatible for data parallel kernel
  // parameter packing without any assert functionality.
  template <typename ParamT, typename ArgT>
  static constexpr bool CompatibleNoAssert() {
    return std::is_same<typename std::remove_const<ParamT>::type,
                        ArgT>::value ||
           IsCompatibleDeviceMemoryRef<ParamT, ArgT>::value;
  }

  // Checks whether ParamT and ArgT are compatible for data parallel kernel
  // parameter packing. kArgumentNumber is unused; it is just for error display.
  //
  // NOTE: if you encounter an error here, you can see the mismatch by looking
  // at the end of the last error message, which will be of the form:
  //
  //    ...::Compatible<const perftools::gputools::DeviceMemory<OneThing> &,
  //                    perftools::gputools::DeviceMemory<AnotherThing>, true,
  //                    0>'
  //    requested here
  //
  // This means that the 0th argument you passed to the kernel invocation should
  // have been DeviceMemory<OneThing> but was observed to be
  // DeviceMemory<AnotherThing>.
  template <typename ParamT, typename ArgT, bool kShouldStaticAssert,
            int kArgumentNumber>
  static constexpr bool Compatible() {
    static_assert(
        kShouldStaticAssert ? CompatibleNoAssert<ParamT, ArgT>() : true,
        "parameter type (LHS) is not compatible with argument type (RHS)");
    return CompatibleNoAssert<ParamT, ArgT>();
  }

  // Checks the parameter/argument match at kArgumentNumber for an out of bounds
  // argument number.
  //
  // This is the base case: we've run out of arguments to check, so we're all
  // good.
  template <int kArgumentNumber, bool kShouldStaticAssert>
  static constexpr bool CheckParam(
      typename std::enable_if<(kArgumentNumber < 0)>::type *dummy = nullptr) {
    return true;
  }

  // Checks the parameter/argument match at kArgumentNumber.
  // kShouldStaticAssert determines whether to assert out on a mismatch, or just
  // yield the constexpr boolean value.
  template <int kArgumentNumber, bool kShouldStaticAssert>
  static constexpr bool CheckParam(
      typename std::enable_if<kArgumentNumber >= 0>::type *dummy = nullptr) {
    typedef typename std::tuple_element<kArgumentNumber, ParamTuple>::type
        ParamT;
    typedef typename std::tuple_element<kArgumentNumber, ArgTuple>::type ArgT;
    return Compatible<ParamT, ArgT, kShouldStaticAssert, kArgumentNumber>() &&
           CheckParam<kArgumentNumber - 1, kShouldStaticAssert>();
  }

  // Checks the parameters/arguments for a match, but doesn't static assert out.
  // This is useful for inspecting whether a set of parameters matches, e.g. in
  // tests.
  static constexpr bool CheckAllNoStaticAssert() {
    return kLengthMatches && CheckParam<kTupleLength - 1, false>();
  }

  // Checks the parameters and static asserts out with a helpful error message
  // (and useful template parameters in the instantiation stack) if there is an
  // error.
  static constexpr bool CheckAllStaticAssert() {
    static_assert(kLengthMatches,
                  "argument length mismatched against typed kernel parameters");
    return kLengthMatches && CheckParam<kTupleLength - 1, true>();
  }
};

// This is a convenience type for checking whether a typed kernel matches
// against a type list.
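//
// For example (informal; the check is the non-asserting variant, so it simply
// yields a constexpr bool):
//
//  static_assert(
//      KernelParamsOk<TypedKernel<const DeviceMemory<int> &, int>,
//                     DeviceMemory<int>, int>::kResult,
//      "kernel parameter list should accept (DeviceMemory<int>, int)");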
template <typename KernelT, typename... Params>
struct KernelParamsOk {
  static constexpr bool kResult = false;
};

// See above.
template <typename... Params, typename... Args>
struct KernelParamsOk<TypedKernel<Params...>, Args...> {
  static constexpr bool kResult = KernelInvocationChecker<
      std::tuple<Params...>, std::tuple<Args...>>::CheckAllNoStaticAssert();
};

}  // namespace gputools
}  // namespace perftools

#endif  // TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_